git.leonardobizzoni.com Git - mnist/commitdiff
NN training
author: LeonardoBizzoni <leo2002714@gmail.com>
Wed, 29 May 2024 11:57:37 +0000 (13:57 +0200)
committer: LeonardoBizzoni <leo2002714@gmail.com>
Wed, 29 May 2024 11:57:37 +0000 (13:57 +0200)
src/neural_network.rs
src/nn_error.rs

index 5342da0e80d23ec63a2c12981ec78e0b11fe4bfa..471736ce6a4b205e3faeabbe356864731a7efded 100644 (file)
@@ -1,4 +1,9 @@
-use std::slice::Iter;
+use std::{
+    cell::RefCell,
+    iter::{Peekable, Rev},
+    rc::Rc,
+    slice::{Iter, IterMut},
+};
 
 use nalgebra::DMatrix;
 use rand_distr::{Distribution, Normal};
@@ -8,7 +13,7 @@ use crate::nn_error::{NeuralNetworkError, WeightError};
 #[derive(Debug)]
 #[allow(dead_code)]
 pub struct NeuralNetwork {
-    weights: Vec<DMatrix<f64>>,
+    weights: Rc<RefCell<Vec<DMatrix<f64>>>>,
 
     learning_rate: f64,
     node_count_per_layer: Vec<usize>,
@@ -26,9 +31,7 @@ impl NeuralNetwork {
             return Err(NeuralNetworkError::InvalidNumberOfOutputNodes);
         }
 
-        // Weights from input to first hidden (or output) layer
-        let mut weights: Vec<DMatrix<f64>> = vec![];
-
+        let weights: Rc<RefCell<Vec<DMatrix<f64>>>> = Rc::new(RefCell::new(vec![]));
         for (i, node_count) in node_count_per_layer[..node_count_per_layer.len() - 1]
             .iter()
             .enumerate()
@@ -46,13 +49,18 @@ impl NeuralNetwork {
                 }
             };
 
-            weights.push(DMatrix::from_fn(
+            weights.borrow_mut().push(DMatrix::from_fn(
                 node_count_per_layer[i + 1], // nrows = node_count in the next layer
                 *node_count,                 // ncols = node_count in the current layer
                 |_, _| -> f64 { random.sample(&mut rand::thread_rng()) },
             ));
         }
 
+        #[cfg(debug_assertions)]
+        for (i, weight) in weights.borrow().iter().enumerate() {
+            println!("Weight {i}-{}: {weight}", i + 1);
+        }
+
         Ok(Self {
             weights,
             learning_rate,
@@ -60,11 +68,51 @@ impl NeuralNetwork {
         })
     }
 
-    // pub fn train(&mut self, inputs: &[f64], target: f64) {
-    //     todo!()
-    // }
+    pub fn train(&mut self, input: &[f64], target: &[f64]) -> Result<(), NeuralNetworkError> {
+        if target.len() != *self.node_count_per_layer.last().unwrap() {
+            return Err(NeuralNetworkError::InsufficientTargetsData {
+                expected: self.node_count_per_layer[0],
+                found: input.len(),
+            });
+        }
+
+        // Creates 2 column matrices
+        let input = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, input);
+        let target = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, target);
 
-    pub fn query(&mut self, inputs: Vec<f64>) -> Result<Vec<f64>, NeuralNetworkError> {
+        let outputs = self.compute_query(self.weights.borrow().iter(), vec![input]);
+        let output_error = target - outputs.last().unwrap();
+
+        {
+            let weights = Rc::clone(&self.weights);
+            self.adjust_weights(
+                output_error,
+                weights.borrow_mut().iter_mut().rev(),
+                outputs.iter().rev().peekable(),
+            );
+        }
+
+        Ok(())
+    }
+
+    fn adjust_weights(
+        &mut self,
+        next_error: DMatrix<f64>,
+        mut weights: Rev<IterMut<DMatrix<f64>>>,
+        mut outputs: Peekable<Rev<Iter<DMatrix<f64>>>>,
+    ) {
+        if let (Some(weight), Some(next_output)) = (weights.next(), outputs.next()) {
+            *weight += self.learning_rate
+                * next_error
+                    .zip_map(&next_output, |a, b| a * b)
+                    .zip_map(&next_output.map(|x| 1.0 - x), |a, b| a * b)
+                * (outputs.peek().unwrap().transpose());
+
+            self.adjust_weights(weight.transpose() * next_error, weights, outputs);
+        }
+    }
+
+    pub fn query(&mut self, inputs: &[f64]) -> Result<Vec<f64>, NeuralNetworkError> {
         if inputs.len() != self.node_count_per_layer[0] {
             return Err(NeuralNetworkError::InsufficientInputData {
                 expected: self.node_count_per_layer[0],
@@ -72,19 +120,36 @@ impl NeuralNetwork {
             });
         }
 
-        Ok(self.compute_query(
-            DMatrix::from_vec(self.node_count_per_layer[0], 1, inputs),
-            self.weights.iter(),
-        ))
+        Ok(self
+            .compute_query(
+                self.weights.borrow().iter(),
+                vec![DMatrix::from_column_slice(
+                    self.node_count_per_layer[0],
+                    1,
+                    inputs,
+                )],
+            )
+            .last()
+            .unwrap()
+            .column(0)
+            .as_slice()
+            .into())
     }
 
-    fn compute_query(&self, input: DMatrix<f64>, mut weight_iter: Iter<DMatrix<f64>>) -> Vec<f64> {
+    fn compute_query(
+        &self,
+        mut weight_iter: Iter<DMatrix<f64>>,
+        mut per_layer_output: Vec<DMatrix<f64>>,
+    ) -> Vec<DMatrix<f64>> {
         match weight_iter.next() {
             Some(weight) => {
-                let next_input: DMatrix<f64> = NeuralNetwork::activation_function(weight * input);
-                self.compute_query(next_input, weight_iter)
+                let next_input: DMatrix<f64> =
+                    NeuralNetwork::activation_function(weight * per_layer_output.last().unwrap());
+                per_layer_output.push(next_input);
+
+                self.compute_query(weight_iter, per_layer_output)
             }
-            None => input.column(0).as_slice().into(),
+            None => per_layer_output,
         }
     }
 
@@ -132,6 +197,6 @@ mod tests {
         let nn = NeuralNetwork::new(42.0, vec![1, 3, 3, 3, 3, 1]);
 
         assert!(nn.is_ok());
-        assert_eq!(nn.unwrap().weights.len(), 5);
+        assert_eq!(nn.unwrap().weights.borrow().len(), 5);
     }
 }
index d3e4f5b9df6973ee6bef4397b91645ca855ac34c..63510f243a9efaed6ba23c549fe0f68999ab61d3 100644 (file)
@@ -6,6 +6,7 @@ pub enum NeuralNetworkError {
 
     WeightError(WeightError),
     InsufficientInputData { expected: usize, found: usize },
+    InsufficientTargetsData { expected: usize, found: usize },
 }
 
 #[derive(Debug, Clone, Copy, PartialEq)]