]> git.leonardobizzoni.com Git - CBuild/commitdiff
vector-matrix/matrix-vector mul + identity matrix
authorLeonardoBizzoni <leo2002714@gmail.com>
Tue, 21 Oct 2025 08:24:26 +0000 (10:24 +0200)
committerLeonardoBizzoni <leo2002714@gmail.com>
Tue, 21 Oct 2025 08:24:26 +0000 (10:24 +0200)
extra/linear_algebra.h

index 5f9bb39f2bb3992108401f490deee2b7b1f337fb..c00390b5c72eb09884ba9b14e0ccdfac8f14a8c1 100644 (file)
@@ -57,6 +57,9 @@ static void
 cb_linagen_defun_vecn_dot(CB_Generator *gen, char *type,
                           int32_t n, bool implementation);
 static void
+cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type,
+                           int32_t n, bool implementation);
+static void
 cb_linagen_defun_vecn_magnitude(CB_Generator *gen, char *type,
                                 int32_t n, bool implementation);
 static void
@@ -72,6 +75,12 @@ cb_linagen_defun_vecn_distance2(CB_Generator *gen, char *type,
 static void
 cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type,
                              int32_t n, bool implementation);
+static void
+cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type,
+                            int32_t n, bool implementation);
+static void
+cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type,
+                                int32_t n, bool implementation);
 
 // ======================================================================
 // Implementations
@@ -404,6 +413,32 @@ cb_linagen_defun_vecn_distance2(CB_Generator *gen, char *type,
                    "\n}\n\n");
 }
 
+static void
+cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type,
+                           int32_t n, bool implementation) {
+  char *suffix = strdup(type);
+  *suffix = char_toupper(*type);
+  cb_gen_push_func_begin(gen, cb_format("linagen_fn void vec%d%s_mulm", n, type));
+    cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix));
+  cb_gen_push_func_end(gen, implementation);
+  if (!implementation) { return; }
+
+  cb_gen_push(gen, " {");
+  for (int32_t i = 0; i < n; ++i) {
+    char *prefix = "";
+    cb_gen_push(gen, cb_format("\n  res->values[%d] = ", i));
+    for (int32_t j = 0; j < n; ++j) {
+      cb_gen_push(gen, cb_format("%s(v->values[%d] * m->values[%d][%d])",
+                                 prefix, j, i, j));
+      prefix = " + ";
+    }
+    cb_gen_push(gen, ";");
+  }
+  cb_gen_push(gen, "\n}\n\n");
+}
+
 static void
 cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type,
                              int32_t n, bool implementation) {
@@ -480,6 +515,51 @@ cb_linagen_defun_matnn_dot(CB_Generator *gen, char *type,
   cb_gen_push(gen, "\n}\n\n");
 }
 
+static void
+cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type,
+                            int32_t n, bool implementation) {
+  char *suffix = strdup(type);
+  *suffix = char_toupper(*type);
+  cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_mulv", n, type));
+    cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix));
+    cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix));
+  cb_gen_push_func_end(gen, implementation);
+  if (!implementation) { return; }
+
+  cb_gen_push(gen, " {");
+  for (int32_t i = 0; i < n; ++i) {
+    char *prefix = "";
+    cb_gen_push(gen, cb_format("\n  res->values[%d] = ", i));
+    for (int32_t j = 0; j < n; ++j) {
+      cb_gen_push(gen, cb_format("%s(m->values[%d][%d] * v->values[%d])",
+                                 prefix, j, i, j));
+      prefix = " + ";
+    }
+    cb_gen_push(gen, ";");
+  }
+  cb_gen_push(gen, "\n}\n\n");
+}
+
+static void
+cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type,
+                                int32_t n, bool implementation) {
+  char *suffix = strdup(type);
+  *suffix = char_toupper(*type);
+  cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_identity",
+                                        n, type));
+    cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix));
+  cb_gen_push_func_end(gen, implementation);
+  if (!implementation) { return; }
+
+  cb_gen_push(gen, " {");
+  cb_gen_push(gen, "\n  memset(res, 0, sizeof *res);");
+  for (int32_t i = 0; i < n; ++i) {
+    cb_gen_push(gen, cb_format("\n  res->values[%d][%d] = 1;", i, i));
+  }
+  cb_gen_push(gen, "\n}\n\n");
+}
+
 
 internal void
 cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,