From: LeonardoBizzoni Date: Sun, 26 May 2024 17:09:51 +0000 (+0200) Subject: NeuralNetwork creation X-Git-Url: http://git.leonardobizzoni.com/?a=commitdiff_plain;h=df833b75e196cc51875434da673760c7705c2add;p=mnist NeuralNetwork creation --- df833b75e196cc51875434da673760c7705c2add diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..1d953f4 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use nix diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..1cf900f --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,274 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bytemuck" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "libc" +version = "0.2.155" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "nalgebra" +version = "0.32.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ea4908d4f23254adda3daa60ffef0f1ac7b8c3e9a864cf3cc154b251908a2ef" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "neural_network" +version = "0.1.0" +dependencies = [ + "nalgebra", + "rand", + "rand_distr", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.84" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "safe_arch" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wide" +version = "0.7.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd8dc749a1b03f3c255a3064a4f5c0ee5ed09b7c6bc6d4525d31f779cd74d7fc" +dependencies = [ + "bytemuck", + "safe_arch", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..5098bbe --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "neural_network" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +nalgebra = "0.32.5" +rand = "0.8.5" +rand_distr = "0.4.3" diff --git a/shell.nix b/shell.nix new file mode 100644 index 0000000..c83b5fb --- /dev/null +++ b/shell.nix @@ -0,0 +1,15 @@ +{ pkgs ? import { } }: + +pkgs.mkShell { + nativeBuildInputs = with pkgs; [ + cargo + rustc + rust-analyzer + rustfmt + tree-sitter-grammars.tree-sitter-rust + ]; + + shellHook = '' + ${pkgs.onefetch}/bin/onefetch + ''; +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..c56cfd0 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1 @@ +mod neural_network; diff --git a/src/neural_network.rs b/src/neural_network.rs new file mode 100644 index 0000000..0c2ec03 --- /dev/null +++ b/src/neural_network.rs @@ -0,0 +1,136 @@ +use nalgebra::DMatrix; +use rand_distr::{Distribution, Normal}; + +#[derive(Debug)] +pub enum NeuralNetworkError { + InvalidNumberOfOutputNodes, + InvalidNumberOfInputNodes, + + WeightError(WeightError), +} + +#[derive(Debug)] +pub enum WeightError { + InvalidVariance(f64), +} + +#[derive(Debug)] +#[allow(dead_code)] +pub struct NeuralNetwork { + weights: Vec>, + + learning_rate: f64, + + number_of_input_nodes: usize, + number_of_output_nodes: usize, + + hlayer_node_count: Vec, +} + +#[allow(dead_code)] +impl NeuralNetwork { + pub fn new( + number_of_input_nodes: usize, + number_of_output_nodes: usize, + learning_rate: f64, + hlayer_node_count: Vec, + ) -> Result { + if number_of_input_nodes == 0 { + return Err(NeuralNetworkError::InvalidNumberOfInputNodes); + } else if number_of_output_nodes == 0 { + return Err(NeuralNetworkError::InvalidNumberOfOutputNodes); + } + + // Weights from input to first hidden (or output) layer + let mut weights: Vec> = vec![DMatrix::from_vec( + number_of_input_nodes, + 1, + NeuralNetwork::generate_random_weights(number_of_input_nodes)?, + )]; + + // Weights from hidden to next hidden (or output) + for node_count in hlayer_node_count.iter() { + weights.push(DMatrix::from_vec( + *node_count, + 1, + NeuralNetwork::generate_random_weights(*node_count)?, + )); + } + + Ok(Self { + weights, + number_of_input_nodes, + number_of_output_nodes, + learning_rate, + hlayer_node_count, + }) + } + + // pub fn train(&mut self, inputs: &[f64], target: f64) { + // todo!() + // } + + // pub fn compute(&self, inputs: &[f64]) -> Result { + // todo!() + // } + + // fn activation(&self, sum: f64) -> f64 { + // todo!() + // } + + fn generate_random_weights(n: usize) -> Result, NeuralNetworkError> { + let random = match Normal::new(0.0, (n as f64).powf(-0.5)) { + Ok(value) => value, + Err(_) => { + return Err(NeuralNetworkError::WeightError( + WeightError::InvalidVariance(n as f64), + )) + } + }; + let mut weights: Vec = vec![0.0; n]; + + for weight in weights.iter_mut() { + *weight = random.sample(&mut rand::thread_rng()); + } + + Ok(weights) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn nn_with_invalid_number_of_inputs() { + let nn = NeuralNetwork::new(0, 10, 42.0, vec![]); + assert!(nn.is_err()) + } + + #[test] + fn nn_without_outputs() { + let nn = NeuralNetwork::new(1, 0, 42.0, vec![]); + assert!(nn.is_err()) + } + + #[test] + fn nn_invalid_hidden_layer_node_count() { + let nn = NeuralNetwork::new(1, 1, 42.0, vec![0]); + assert!(nn.is_err()) + } + + #[test] + fn valid_nn_definition() { + // NN with 1 input, no hidden and 1 output + let nn = NeuralNetwork::new(1, 1, 42.0, vec![]); + assert!(nn.is_ok()) + } + #[test] + fn valid_nn_with_hidden_layers_definition() { + // NN with 1 input, 4 hidden layers each with 3 nodes, and 1 output + let nn = NeuralNetwork::new(1, 1, 42.0, vec![3, 3, 3, 3]); + + assert!(nn.is_ok()); + assert_eq!(nn.unwrap().weights.len(), 5); + } +}