]> git.leonardobizzoni.com Git - mnist/commitdiff
Fixed target length check master
authorLeonardoBizzoni <leo2002714@gmail.com>
Thu, 30 May 2024 08:03:02 +0000 (10:03 +0200)
committerLeonardoBizzoni <leo2002714@gmail.com>
Thu, 30 May 2024 08:03:02 +0000 (10:03 +0200)
src/neural_network.rs

index 471736ce6a4b205e3faeabbe356864731a7efded..73dd9fc834c2c9f483c15d41a9c4f9ba51a7b437 100644 (file)
@@ -11,7 +11,6 @@ use rand_distr::{Distribution, Normal};
 use crate::nn_error::{NeuralNetworkError, WeightError};
 
 #[derive(Debug)]
-#[allow(dead_code)]
 pub struct NeuralNetwork {
     weights: Rc<RefCell<Vec<DMatrix<f64>>>>,
 
@@ -69,20 +68,27 @@ impl NeuralNetwork {
     }
 
     pub fn train(&mut self, input: &[f64], target: &[f64]) -> Result<(), NeuralNetworkError> {
-        if target.len() != *self.node_count_per_layer.last().unwrap() {
+        if input.len() != *self.node_count_per_layer.first().unwrap() {
             return Err(NeuralNetworkError::InsufficientTargetsData {
                 expected: self.node_count_per_layer[0],
                 found: input.len(),
             });
         }
+        if target.len() != *self.node_count_per_layer.last().unwrap() {
+            return Err(NeuralNetworkError::InsufficientTargetsData {
+                expected: *self.node_count_per_layer.last().unwrap(),
+                found: target.len(),
+            });
+        }
 
         // Creates 2 column matrices
         let input = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, input);
-        let target = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, target);
+        let target = DMatrix::from_column_slice(*self.node_count_per_layer.last().unwrap(), 1, target);
 
         let outputs = self.compute_query(self.weights.borrow().iter(), vec![input]);
         let output_error = target - outputs.last().unwrap();
 
+
         {
             let weights = Rc::clone(&self.weights);
             self.adjust_weights(