From: LeonardoBizzoni
Date: Wed, 29 May 2024 11:57:37 +0000 (+0200)
Subject: NN training
X-Git-Url: http://git.leonardobizzoni.com/?a=commitdiff_plain;h=00343aa714129ceb7e0acb5be8ca73ec36f711ab;p=mnist

NN training
---

diff --git a/src/neural_network.rs b/src/neural_network.rs
index 5342da0..471736c 100644
--- a/src/neural_network.rs
+++ b/src/neural_network.rs
@@ -1,4 +1,9 @@
-use std::slice::Iter;
+use std::{
+    cell::RefCell,
+    iter::{Peekable, Rev},
+    rc::Rc,
+    slice::{Iter, IterMut},
+};
 
 use nalgebra::DMatrix;
 use rand_distr::{Distribution, Normal};
@@ -8,7 +13,7 @@ use crate::nn_error::{NeuralNetworkError, WeightError};
 #[derive(Debug)]
 #[allow(dead_code)]
 pub struct NeuralNetwork {
-    weights: Vec<DMatrix<f64>>,
+    weights: Rc<RefCell<Vec<DMatrix<f64>>>>,
     learning_rate: f64,
 
     node_count_per_layer: Vec<usize>,
@@ -26,9 +31,7 @@ impl NeuralNetwork {
             return Err(NeuralNetworkError::InvalidNumberOfOutputNodes);
         }
 
-        // Weights from input to first hidden (or output) layer
-        let mut weights: Vec<DMatrix<f64>> = vec![];
-
+        let weights: Rc<RefCell<Vec<DMatrix<f64>>>> = Rc::new(RefCell::new(vec![]));
         for (i, node_count) in node_count_per_layer[..node_count_per_layer.len() - 1]
             .iter()
             .enumerate()
@@ -46,13 +49,18 @@ impl NeuralNetwork {
                 }
             };
 
-            weights.push(DMatrix::from_fn(
+            weights.borrow_mut().push(DMatrix::from_fn(
                 node_count_per_layer[i + 1], // nrows = node_count in the next layer
                 *node_count, // ncols = node_count in the current layer
                 |_, _| -> f64 { random.sample(&mut rand::thread_rng()) },
             ));
         }
 
+        #[cfg(debug_assertions)]
+        for (i, weight) in weights.borrow().iter().enumerate() {
+            println!("Weight {i}-{}: {weight}", i + 1);
+        }
+
         Ok(Self {
             weights,
             learning_rate,
@@ -60,11 +68,51 @@ impl NeuralNetwork {
         })
     }
 
-    // pub fn train(&mut self, inputs: &[f64], target: f64) {
-    //     todo!()
-    // }
+    pub fn train(&mut self, input: &[f64], target: &[f64]) -> Result<(), NeuralNetworkError> {
+        if target.len() != *self.node_count_per_layer.last().unwrap() {
+            return Err(NeuralNetworkError::InsufficientTargetsData {
+                expected: self.node_count_per_layer[0],
+                found: input.len(),
+            });
+        }
+
+        // Creates 2 column matrices
+        let input = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, input);
+        let target = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, target);
 
-    pub fn query(&mut self, inputs: Vec<f64>) -> Result<Vec<f64>, NeuralNetworkError> {
+        let outputs = self.compute_query(self.weights.borrow().iter(), vec![input]);
+        let output_error = target - outputs.last().unwrap();
+
+        {
+            let weights = Rc::clone(&self.weights);
+            self.adjust_weights(
+                output_error,
+                weights.borrow_mut().iter_mut().rev(),
+                outputs.iter().rev().peekable(),
+            );
+        }
+
+        Ok(())
+    }
+
+    fn adjust_weights(
+        &mut self,
+        next_error: DMatrix<f64>,
+        mut weights: Rev<IterMut<DMatrix<f64>>>,
+        mut outputs: Peekable<Rev<Iter<DMatrix<f64>>>>,
+    ) {
+        if let (Some(weight), Some(next_output)) = (weights.next(), outputs.next()) {
+            *weight += self.learning_rate
+                * next_error
+                    .zip_map(&next_output, |a, b| a * b)
+                    .zip_map(&next_output.map(|x| 1.0 - x), |a, b| a * b)
+                * (outputs.peek().unwrap().transpose());
+
+            self.adjust_weights(weight.transpose() * next_error, weights, outputs);
+        }
+    }
+
+    pub fn query(&mut self, inputs: &[f64]) -> Result<Vec<f64>, NeuralNetworkError> {
         if inputs.len() != self.node_count_per_layer[0] {
             return Err(NeuralNetworkError::InsufficientInputData {
                 expected: self.node_count_per_layer[0],
@@ -72,19 +120,36 @@ impl NeuralNetwork {
             });
         }
 
-        Ok(self.compute_query(
-            DMatrix::from_vec(self.node_count_per_layer[0], 1, inputs),
-            self.weights.iter(),
-        ))
+        Ok(self
+            .compute_query(
+                self.weights.borrow().iter(),
+                vec![DMatrix::from_column_slice(
+                    self.node_count_per_layer[0],
+                    1,
+                    inputs,
+                )],
+            )
+            .last()
+            .unwrap()
+            .column(0)
+            .as_slice()
+            .into())
     }
 
-    fn compute_query(&self, input: DMatrix<f64>, mut weight_iter: Iter<DMatrix<f64>>) -> Vec<f64> {
+    fn compute_query(
+        &self,
+        mut weight_iter: Iter<DMatrix<f64>>,
+        mut per_layer_output: Vec<DMatrix<f64>>,
+    ) -> Vec<DMatrix<f64>> {
         match weight_iter.next() {
             Some(weight) => {
-                let next_input: DMatrix<f64> = NeuralNetwork::activation_function(weight * input);
-                self.compute_query(next_input, weight_iter)
+                let next_input: DMatrix<f64> =
+                    NeuralNetwork::activation_function(weight * per_layer_output.last().unwrap());
+                per_layer_output.push(next_input);
+
+                self.compute_query(weight_iter, per_layer_output)
             }
-            None => input.column(0).as_slice().into(),
+            None => per_layer_output,
         }
     }
 
@@ -132,6 +197,6 @@ mod tests {
         let nn = NeuralNetwork::new(42.0, vec![1, 3, 3, 3, 3, 1]);
 
         assert!(nn.is_ok());
-        assert_eq!(nn.unwrap().weights.len(), 5);
+        assert_eq!(nn.unwrap().weights.borrow().len(), 5);
     }
 }
diff --git a/src/nn_error.rs b/src/nn_error.rs
index d3e4f5b..63510f2 100644
--- a/src/nn_error.rs
+++ b/src/nn_error.rs
@@ -6,6 +6,7 @@ pub enum NeuralNetworkError {
     WeightError(WeightError),
 
     InsufficientInputData { expected: usize, found: usize },
+    InsufficientTargetsData { expected: usize, found: usize },
 }
 
 #[derive(Debug, Clone, Copy, PartialEq)]