From: LeonardoBizzoni Date: Thu, 30 May 2024 08:03:02 +0000 (+0200) Subject: Fixed target length check X-Git-Url: http://git.leonardobizzoni.com/?a=commitdiff_plain;h=refs%2Fheads%2Fmaster;p=mnist Fixed target length check --- diff --git a/src/neural_network.rs b/src/neural_network.rs index 471736c..73dd9fc 100644 --- a/src/neural_network.rs +++ b/src/neural_network.rs @@ -11,7 +11,6 @@ use rand_distr::{Distribution, Normal}; use crate::nn_error::{NeuralNetworkError, WeightError}; #[derive(Debug)] -#[allow(dead_code)] pub struct NeuralNetwork { weights: Rc>>>, @@ -69,20 +68,27 @@ impl NeuralNetwork { } pub fn train(&mut self, input: &[f64], target: &[f64]) -> Result<(), NeuralNetworkError> { - if target.len() != *self.node_count_per_layer.last().unwrap() { + if input.len() != *self.node_count_per_layer.first().unwrap() { return Err(NeuralNetworkError::InsufficientTargetsData { expected: self.node_count_per_layer[0], found: input.len(), }); } + if target.len() != *self.node_count_per_layer.last().unwrap() { + return Err(NeuralNetworkError::InsufficientTargetsData { + expected: *self.node_count_per_layer.last().unwrap(), + found: target.len(), + }); + } // Creates 2 column matrices let input = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, input); - let target = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, target); + let target = DMatrix::from_column_slice(*self.node_count_per_layer.last().unwrap(), 1, target); let outputs = self.compute_query(self.weights.borrow().iter(), vec![input]); let output_error = target - outputs.last().unwrap(); + { let weights = Rc::clone(&self.weights); self.adjust_weights(