From d9d59f5dd5a1af7c0a222c659e9abae363ad272d Mon Sep 17 00:00:00 2001 From: LeonardoBizzoni Date: Mon, 27 May 2024 14:45:30 +0200 Subject: [PATCH] NN query --- src/neural_network.rs | 45 ++++++++++++++++++++++++++----------------- src/nn_error.rs | 2 +- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/neural_network.rs b/src/neural_network.rs index b4887c1..5342da0 100644 --- a/src/neural_network.rs +++ b/src/neural_network.rs @@ -1,3 +1,5 @@ +use std::slice::Iter; + use nalgebra::DMatrix; use rand_distr::{Distribution, Normal}; @@ -62,30 +64,37 @@ impl NeuralNetwork { // todo!() // } - // pub fn compute(&self, inputs: &[f64]) -> Result { - // todo!() - // } + pub fn query(&mut self, inputs: Vec) -> Result, NeuralNetworkError> { + if inputs.len() != self.node_count_per_layer[0] { + return Err(NeuralNetworkError::InsufficientInputData { + expected: self.node_count_per_layer[0], + found: inputs.len(), + }); + } - // fn activation(&self, sum: f64) -> f64 { - // todo!() - // } + Ok(self.compute_query( + DMatrix::from_vec(self.node_count_per_layer[0], 1, inputs), + self.weights.iter(), + )) + } - fn generate_random_weights(n: usize) -> Result, NeuralNetworkError> { - let random = match Normal::new(0.0, (n as f64).powf(-0.5)) { - Ok(value) => value, - Err(_) => { - return Err(NeuralNetworkError::WeightError( - WeightError::InvalidVariance(n as f64), - )) + fn compute_query(&self, input: DMatrix, mut weight_iter: Iter>) -> Vec { + match weight_iter.next() { + Some(weight) => { + let next_input: DMatrix = NeuralNetwork::activation_function(weight * input); + self.compute_query(next_input, weight_iter) } - }; - let mut weights: Vec = vec![0.0; n]; + None => input.column(0).as_slice().into(), + } + } - for weight in weights.iter_mut() { - *weight = random.sample(&mut rand::thread_rng()); + fn activation_function(mut x: DMatrix) -> DMatrix { + // Sigmoid function applied to each element in the matrix (column vector) + for xi in x.iter_mut() { + *xi = 1.0 / (1.0 + std::f64::consts::E.powf(-*xi)); } - Ok(weights) + x } } diff --git a/src/nn_error.rs b/src/nn_error.rs index 65d3bb0..d3e4f5b 100644 --- a/src/nn_error.rs +++ b/src/nn_error.rs @@ -5,7 +5,7 @@ pub enum NeuralNetworkError { InvalidNumberOfNodes, WeightError(WeightError), - InsufficientInputData, + InsufficientInputData { expected: usize, found: usize }, } #[derive(Debug, Clone, Copy, PartialEq)] -- 2.52.0