From 46a4b5ba0882f6925715aa28d7ba8aeb1c363c6a Mon Sep 17 00:00:00 2001 From: LeonardoBizzoni Date: Mon, 27 May 2024 14:32:51 +0200 Subject: [PATCH] Fixed creation of weight matrix A weight was supposed to be a DMatrix with: - nrows = node_count in the next layer - ncols = node_count in the current layer so that the resulting matrix could be multiplied with the input matrix (column vector) --- src/neural_network.rs | 63 ++++++++++++++++++++++--------------------- src/nn_error.rs | 6 +++-- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/neural_network.rs b/src/neural_network.rs index ca8966c..b4887c1 100644 --- a/src/neural_network.rs +++ b/src/neural_network.rs @@ -9,49 +9,52 @@ pub struct NeuralNetwork { weights: Vec>, learning_rate: f64, - - number_of_input_nodes: usize, - number_of_output_nodes: usize, - - hlayer_node_count: Vec, + node_count_per_layer: Vec, } #[allow(dead_code)] impl NeuralNetwork { pub fn new( - number_of_input_nodes: usize, - number_of_output_nodes: usize, learning_rate: f64, - hlayer_node_count: Vec, + node_count_per_layer: Vec, ) -> Result { - if number_of_input_nodes == 0 { + if node_count_per_layer.len() <= 0 { return Err(NeuralNetworkError::InvalidNumberOfInputNodes); - } else if number_of_output_nodes == 0 { + } else if node_count_per_layer.len() <= 1 { return Err(NeuralNetworkError::InvalidNumberOfOutputNodes); } // Weights from input to first hidden (or output) layer - let mut weights: Vec> = vec![DMatrix::from_vec( - number_of_input_nodes, - 1, - NeuralNetwork::generate_random_weights(number_of_input_nodes)?, - )]; - - // Weights from hidden to next hidden (or output) - for node_count in hlayer_node_count.iter() { - weights.push(DMatrix::from_vec( - *node_count, - 1, - NeuralNetwork::generate_random_weights(*node_count)?, + let mut weights: Vec> = vec![]; + + for (i, node_count) in node_count_per_layer[..node_count_per_layer.len() - 1] + .iter() + .enumerate() + { + if *node_count <= 0 || node_count_per_layer[i + 1] <= 0 { + return Err(NeuralNetworkError::InvalidNumberOfNodes); + } + + let random = match Normal::new(0.0, (*node_count as f64).powf(-0.5)) { + Ok(value) => value, + Err(_) => { + return Err(NeuralNetworkError::WeightError( + WeightError::InvalidVariance(*node_count as f64), + )) + } + }; + + weights.push(DMatrix::from_fn( + node_count_per_layer[i + 1], // nrows = node_count in the next layer + *node_count, // ncols = node_count in the current layer + |_, _| -> f64 { random.sample(&mut rand::thread_rng()) }, )); } Ok(Self { weights, - number_of_input_nodes, - number_of_output_nodes, learning_rate, - hlayer_node_count, + node_count_per_layer, }) } @@ -92,32 +95,32 @@ mod tests { #[test] fn nn_with_invalid_number_of_inputs() { - let nn = NeuralNetwork::new(0, 10, 42.0, vec![]); + let nn = NeuralNetwork::new(42.0, vec![0, 10]); assert!(nn.is_err()) } #[test] fn nn_without_outputs() { - let nn = NeuralNetwork::new(1, 0, 42.0, vec![]); + let nn = NeuralNetwork::new(42.0, vec![1, 0]); assert!(nn.is_err()) } #[test] fn nn_invalid_hidden_layer_node_count() { - let nn = NeuralNetwork::new(1, 1, 42.0, vec![0]); + let nn = NeuralNetwork::new(42.0, vec![1, 0, 1]); assert!(nn.is_err()) } #[test] fn valid_nn_definition() { // NN with 1 input, no hidden and 1 output - let nn = NeuralNetwork::new(1, 1, 42.0, vec![]); + let nn = NeuralNetwork::new(42.0, vec![1, 1]); assert!(nn.is_ok()) } #[test] fn valid_nn_with_hidden_layers_definition() { // NN with 1 input, 4 hidden layers each with 3 nodes, and 1 output - let nn = NeuralNetwork::new(1, 1, 42.0, vec![3, 3, 3, 3]); + let nn = NeuralNetwork::new(42.0, vec![1, 3, 3, 3, 3, 1]); assert!(nn.is_ok()); assert_eq!(nn.unwrap().weights.len(), 5); diff --git a/src/nn_error.rs b/src/nn_error.rs index d8a962f..65d3bb0 100644 --- a/src/nn_error.rs +++ b/src/nn_error.rs @@ -1,12 +1,14 @@ -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum NeuralNetworkError { InvalidNumberOfOutputNodes, InvalidNumberOfInputNodes, + InvalidNumberOfNodes, WeightError(WeightError), + InsufficientInputData, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq)] pub enum WeightError { InvalidVariance(f64), } -- 2.52.0