]> git.leonardobizzoni.com Git - mnist/commitdiff
NeuralNetwork creation
authorLeonardoBizzoni <leo2002714@gmail.com>
Sun, 26 May 2024 17:09:51 +0000 (19:09 +0200)
committerLeonardoBizzoni <leo2002714@gmail.com>
Sun, 26 May 2024 17:09:51 +0000 (19:09 +0200)
.envrc [new file with mode: 0644]
.gitignore [new file with mode: 0644]
Cargo.lock [new file with mode: 0644]
Cargo.toml [new file with mode: 0644]
shell.nix [new file with mode: 0644]
src/lib.rs [new file with mode: 0644]
src/neural_network.rs [new file with mode: 0644]

diff --git a/.envrc b/.envrc
new file mode 100644 (file)
index 0000000..1d953f4
--- /dev/null
+++ b/.envrc
@@ -0,0 +1 @@
+use nix
diff --git a/.gitignore b/.gitignore
new file mode 100644 (file)
index 0000000..ea8c4bf
--- /dev/null
@@ -0,0 +1 @@
+/target
diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644 (file)
index 0000000..1cf900f
--- /dev/null
@@ -0,0 +1,274 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "approx"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "autocfg"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
+
+[[package]]
+name = "bytemuck"
+version = "1.16.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5"
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "getrandom"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi",
+]
+
+[[package]]
+name = "libc"
+version = "0.2.155"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c"
+
+[[package]]
+name = "libm"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
+
+[[package]]
+name = "matrixmultiply"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
+[[package]]
+name = "nalgebra"
+version = "0.32.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3ea4908d4f23254adda3daa60ffef0f1ac7b8c3e9a864cf3cc154b251908a2ef"
+dependencies = [
+ "approx",
+ "matrixmultiply",
+ "nalgebra-macros",
+ "num-complex",
+ "num-rational",
+ "num-traits",
+ "simba",
+ "typenum",
+]
+
+[[package]]
+name = "nalgebra-macros"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91761aed67d03ad966ef783ae962ef9bbaca728d2dd7ceb7939ec110fffad998"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "neural_network"
+version = "0.1.0"
+dependencies = [
+ "nalgebra",
+ "rand",
+ "rand_distr",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-rational"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824"
+dependencies = [
+ "num-integer",
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+ "libm",
+]
+
+[[package]]
+name = "paste"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
+
+[[package]]
+name = "ppv-lite86"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.84"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.36"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "rand"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
+dependencies = [
+ "getrandom",
+]
+
+[[package]]
+name = "rand_distr"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
+dependencies = [
+ "num-traits",
+ "rand",
+]
+
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "safe_arch"
+version = "0.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354"
+dependencies = [
+ "bytemuck",
+]
+
+[[package]]
+name = "simba"
+version = "0.8.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae"
+dependencies = [
+ "approx",
+ "num-complex",
+ "num-traits",
+ "paste",
+ "wide",
+]
+
+[[package]]
+name = "syn"
+version = "1.0.109"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "typenum"
+version = "1.17.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
+
+[[package]]
+name = "wasi"
+version = "0.11.0+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+
+[[package]]
+name = "wide"
+version = "0.7.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd8dc749a1b03f3c255a3064a4f5c0ee5ed09b7c6bc6d4525d31f779cd74d7fc"
+dependencies = [
+ "bytemuck",
+ "safe_arch",
+]
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644 (file)
index 0000000..5098bbe
--- /dev/null
@@ -0,0 +1,11 @@
+[package]
+name = "neural_network"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+nalgebra = "0.32.5"
+rand = "0.8.5"
+rand_distr = "0.4.3"
diff --git a/shell.nix b/shell.nix
new file mode 100644 (file)
index 0000000..c83b5fb
--- /dev/null
+++ b/shell.nix
@@ -0,0 +1,15 @@
+{ pkgs ? import <nixpkgs> { } }:
+
+pkgs.mkShell {
+  nativeBuildInputs = with pkgs; [
+    cargo
+    rustc
+    rust-analyzer
+    rustfmt
+    tree-sitter-grammars.tree-sitter-rust
+  ];
+
+  shellHook = ''
+    ${pkgs.onefetch}/bin/onefetch
+  '';
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644 (file)
index 0000000..c56cfd0
--- /dev/null
@@ -0,0 +1 @@
+mod neural_network;
diff --git a/src/neural_network.rs b/src/neural_network.rs
new file mode 100644 (file)
index 0000000..0c2ec03
--- /dev/null
@@ -0,0 +1,136 @@
+use nalgebra::DMatrix;
+use rand_distr::{Distribution, Normal};
+
+#[derive(Debug)]
+pub enum NeuralNetworkError {
+    InvalidNumberOfOutputNodes,
+    InvalidNumberOfInputNodes,
+
+    WeightError(WeightError),
+}
+
+#[derive(Debug)]
+pub enum WeightError {
+    InvalidVariance(f64),
+}
+
+#[derive(Debug)]
+#[allow(dead_code)]
+pub struct NeuralNetwork {
+    weights: Vec<DMatrix<f64>>,
+
+    learning_rate: f64,
+
+    number_of_input_nodes: usize,
+    number_of_output_nodes: usize,
+
+    hlayer_node_count: Vec<usize>,
+}
+
+#[allow(dead_code)]
+impl NeuralNetwork {
+    pub fn new(
+        number_of_input_nodes: usize,
+        number_of_output_nodes: usize,
+        learning_rate: f64,
+        hlayer_node_count: Vec<usize>,
+    ) -> Result<Self, NeuralNetworkError> {
+        if number_of_input_nodes == 0 {
+            return Err(NeuralNetworkError::InvalidNumberOfInputNodes);
+        } else if number_of_output_nodes == 0 {
+            return Err(NeuralNetworkError::InvalidNumberOfOutputNodes);
+        }
+
+        // Weights from input to first hidden (or output) layer
+        let mut weights: Vec<DMatrix<f64>> = vec![DMatrix::from_vec(
+            number_of_input_nodes,
+            1,
+            NeuralNetwork::generate_random_weights(number_of_input_nodes)?,
+        )];
+
+        // Weights from hidden to next hidden (or output)
+        for node_count in hlayer_node_count.iter() {
+            weights.push(DMatrix::from_vec(
+                *node_count,
+                1,
+                NeuralNetwork::generate_random_weights(*node_count)?,
+            ));
+        }
+
+        Ok(Self {
+            weights,
+            number_of_input_nodes,
+            number_of_output_nodes,
+            learning_rate,
+            hlayer_node_count,
+        })
+    }
+
+    // pub fn train(&mut self, inputs: &[f64], target: f64) {
+    //     todo!()
+    // }
+
+    // pub fn compute(&self, inputs: &[f64]) -> Result<f64, NNError> {
+    //     todo!()
+    // }
+
+    // fn activation(&self, sum: f64) -> f64 {
+    //     todo!()
+    // }
+
+    fn generate_random_weights(n: usize) -> Result<Vec<f64>, NeuralNetworkError> {
+        let random = match Normal::new(0.0, (n as f64).powf(-0.5)) {
+            Ok(value) => value,
+            Err(_) => {
+                return Err(NeuralNetworkError::WeightError(
+                    WeightError::InvalidVariance(n as f64),
+                ))
+            }
+        };
+        let mut weights: Vec<f64> = vec![0.0; n];
+
+        for weight in weights.iter_mut() {
+            *weight = random.sample(&mut rand::thread_rng());
+        }
+
+        Ok(weights)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn nn_with_invalid_number_of_inputs() {
+        let nn = NeuralNetwork::new(0, 10, 42.0, vec![]);
+        assert!(nn.is_err())
+    }
+
+    #[test]
+    fn nn_without_outputs() {
+        let nn = NeuralNetwork::new(1, 0, 42.0, vec![]);
+        assert!(nn.is_err())
+    }
+
+    #[test]
+    fn nn_invalid_hidden_layer_node_count() {
+        let nn = NeuralNetwork::new(1, 1, 42.0, vec![0]);
+        assert!(nn.is_err())
+    }
+
+    #[test]
+    fn valid_nn_definition() {
+        // NN with 1 input, no hidden and 1 output
+        let nn = NeuralNetwork::new(1, 1, 42.0, vec![]);
+        assert!(nn.is_ok())
+    }
+    #[test]
+    fn valid_nn_with_hidden_layers_definition() {
+        // NN with 1 input, 4 hidden layers each with 3 nodes, and 1 output
+        let nn = NeuralNetwork::new(1, 1, 42.0, vec![3, 3, 3, 3]);
+
+        assert!(nn.is_ok());
+        assert_eq!(nn.unwrap().weights.len(), 5);
+    }
+}