]> git.leonardobizzoni.com Git - mnist/commitdiff
NN query
authorLeonardoBizzoni <leo2002714@gmail.com>
Mon, 27 May 2024 12:45:30 +0000 (14:45 +0200)
committerLeonardoBizzoni <leo2002714@gmail.com>
Mon, 27 May 2024 12:45:30 +0000 (14:45 +0200)
src/neural_network.rs
src/nn_error.rs

index b4887c1ea655572723af09cc6e1dd50dba9f385b..5342da0e80d23ec63a2c12981ec78e0b11fe4bfa 100644 (file)
@@ -1,3 +1,5 @@
+use std::slice::Iter;
+
 use nalgebra::DMatrix;
 use rand_distr::{Distribution, Normal};
 
@@ -62,30 +64,37 @@ impl NeuralNetwork {
     //     todo!()
     // }
 
-    // pub fn compute(&self, inputs: &[f64]) -> Result<f64, NNError> {
-    //     todo!()
-    // }
+    pub fn query(&mut self, inputs: Vec<f64>) -> Result<Vec<f64>, NeuralNetworkError> {
+        if inputs.len() != self.node_count_per_layer[0] {
+            return Err(NeuralNetworkError::InsufficientInputData {
+                expected: self.node_count_per_layer[0],
+                found: inputs.len(),
+            });
+        }
 
-    // fn activation(&self, sum: f64) -> f64 {
-    //     todo!()
-    // }
+        Ok(self.compute_query(
+            DMatrix::from_vec(self.node_count_per_layer[0], 1, inputs),
+            self.weights.iter(),
+        ))
+    }
 
-    fn generate_random_weights(n: usize) -> Result<Vec<f64>, NeuralNetworkError> {
-        let random = match Normal::new(0.0, (n as f64).powf(-0.5)) {
-            Ok(value) => value,
-            Err(_) => {
-                return Err(NeuralNetworkError::WeightError(
-                    WeightError::InvalidVariance(n as f64),
-                ))
+    fn compute_query(&self, input: DMatrix<f64>, mut weight_iter: Iter<DMatrix<f64>>) -> Vec<f64> {
+        match weight_iter.next() {
+            Some(weight) => {
+                let next_input: DMatrix<f64> = NeuralNetwork::activation_function(weight * input);
+                self.compute_query(next_input, weight_iter)
             }
-        };
-        let mut weights: Vec<f64> = vec![0.0; n];
+            None => input.column(0).as_slice().into(),
+        }
+    }
 
-        for weight in weights.iter_mut() {
-            *weight = random.sample(&mut rand::thread_rng());
+    fn activation_function(mut x: DMatrix<f64>) -> DMatrix<f64> {
+        // Sigmoid function applied to each element in the matrix (column vector)
+        for xi in x.iter_mut() {
+            *xi = 1.0 / (1.0 + std::f64::consts::E.powf(-*xi));
         }
 
-        Ok(weights)
+        x
     }
 }
 
index 65d3bb04bffec81b39201cc8e5af7eee5d3d4c41..d3e4f5b9df6973ee6bef4397b91645ca855ac34c 100644 (file)
@@ -5,7 +5,7 @@ pub enum NeuralNetworkError {
     InvalidNumberOfNodes,
 
     WeightError(WeightError),
-    InsufficientInputData,
+    InsufficientInputData { expected: usize, found: usize },
 }
 
 #[derive(Debug, Clone, Copy, PartialEq)]