]> git.leonardobizzoni.com Git - CBuild/commitdiff
element wise op and dot on matrix
authorLeonardoBizzoni <leo2002714@gmail.com>
Mon, 20 Oct 2025 16:54:51 +0000 (18:54 +0200)
committerLeonardoBizzoni <leo2002714@gmail.com>
Mon, 20 Oct 2025 16:54:51 +0000 (18:54 +0200)
extra/linear_algebra.h

index ea533bffd4b6e4e912b4a19e3d3e80f7b611f06b..5f9bb39f2bb3992108401f490deee2b7b1f337fb 100644 (file)
@@ -8,6 +8,11 @@ cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,
                                    int32_t n, bool implementation,
                                    char *func_name, char op);
 
+internal void
+cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type,
+                                    int32_t n, bool implementation,
+                                    char *func_name, char op);
+
 internal char char_toupper(char ch);
 
 static void
@@ -421,6 +426,60 @@ cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type,
   cb_gen_push(gen, "\n}\n\n");
 }
 
+static inline void
+cb_linagen_defun_matnn_add(CB_Generator *gen, char *type,
+                           int32_t n, bool implementation) {
+  cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "add", '+');
+}
+
+static inline void
+cb_linagen_defun_matnn_sub(CB_Generator *gen, char *type,
+                           int32_t n, bool implementation) {
+  cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "sub", '-');
+}
+
+static inline void
+cb_linagen_defun_matnn_hadamard_prod(CB_Generator *gen, char *type,
+                                     int32_t n, bool implementation) {
+  cb_linagen_defun_matnn_element_wise(gen, type, n, implementation,
+                                      "hadamard_prod", '*');
+}
+
+static inline void
+cb_linagen_defun_matnn_hadamard_div(CB_Generator *gen, char *type,
+                                    int32_t n, bool implementation) {
+  cb_linagen_defun_matnn_element_wise(gen, type, n, implementation,
+                                      "hadamard_div", '/');
+}
+
+static void
+cb_linagen_defun_matnn_dot(CB_Generator *gen, char *type,
+                           int32_t n, bool implementation) {
+  char *suffix = strdup(type);
+  *suffix = char_toupper(*type);
+  cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_dot", n, type));
+    cb_gen_push_func_arg(gen, cb_format("Mat%d%s *restrict res", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix));
+  cb_gen_push_func_end(gen, implementation);
+  if (!implementation) { return; }
+
+  cb_gen_push(gen, " {");
+  for (int32_t col = 0; col < n; ++col) {
+    for (int32_t row = 0; row < n; ++row) {
+      cb_gen_push(gen, cb_format("\n  res->values[%d][%d] = ", col, row));
+      char *prefix = "";
+      for (int32_t k = 0; k < n; ++k) {
+        cb_gen_push(gen, cb_format("%s(m1->values[%d][%d] * m2->values[%d][%d])",
+                                   prefix, k, row, col, k));
+        prefix = " + ";
+      }
+      cb_gen_push(gen, ";");
+    }
+  }
+  cb_gen_push(gen, "\n}\n\n");
+}
+
 
 internal void
 cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,
@@ -444,6 +503,30 @@ cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,
   cb_gen_push(gen, "\n}\n\n");
 }
 
+internal void
+cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type,
+                                    int32_t n, bool implementation,
+                                    char *func_name, char op) {
+  char *suffix = strdup(type);
+  *suffix = char_toupper(*type);
+  cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_%s",
+                                        n, type, func_name));
+    cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix));
+  cb_gen_push_func_end(gen, implementation);
+  if (!implementation) { return; }
+
+  cb_gen_push(gen, " {");
+  for (int32_t i = 0; i < n; ++i) {
+    for (int32_t j = 0; j < n; ++j) {
+      cb_gen_push(gen, cb_format("\n  res->values[%d][%d] = m1->values[%d][%d] %c m2->values[%d][%d];",
+                                 i, j, i, j, op, i, j));
+    }
+  }
+  cb_gen_push(gen, "\n}\n\n");
+}
+
 internal char char_toupper(char ch) {
   if (ch >= 'a' && ch <= 'z') {
     return ch - ('a' - 'A');