use crate::nn_error::{NeuralNetworkError, WeightError};
#[derive(Debug)]
-#[allow(dead_code)]
pub struct NeuralNetwork {
weights: Rc<RefCell<Vec<DMatrix<f64>>>>,
}
pub fn train(&mut self, input: &[f64], target: &[f64]) -> Result<(), NeuralNetworkError> {
- if target.len() != *self.node_count_per_layer.last().unwrap() {
+ if input.len() != *self.node_count_per_layer.first().unwrap() {
return Err(NeuralNetworkError::InsufficientTargetsData {
expected: self.node_count_per_layer[0],
found: input.len(),
});
}
+ if target.len() != *self.node_count_per_layer.last().unwrap() {
+ return Err(NeuralNetworkError::InsufficientTargetsData {
+ expected: *self.node_count_per_layer.last().unwrap(),
+ found: target.len(),
+ });
+ }
// Creates 2 column matrices
let input = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, input);
- let target = DMatrix::from_column_slice(self.node_count_per_layer[0], 1, target);
+ let target = DMatrix::from_column_slice(*self.node_count_per_layer.last().unwrap(), 1, target);
let outputs = self.compute_query(self.weights.borrow().iter(), vec![input]);
let output_error = target - outputs.last().unwrap();
+
{
let weights = Rc::clone(&self.weights);
self.adjust_weights(