cb_linagen_defun_vecn_dot(CB_Generator *gen, char *type,
int32_t n, bool implementation);
static void
+cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type,
+ int32_t n, bool implementation);
+static void
cb_linagen_defun_vecn_magnitude(CB_Generator *gen, char *type,
int32_t n, bool implementation);
static void
static void
cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type,
int32_t n, bool implementation);
+static void
+cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type,
+ int32_t n, bool implementation);
+static void
+cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type,
+ int32_t n, bool implementation);
// ======================================================================
// Implementations
"\n}\n\n");
}
+static void
+cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ char *suffix = strdup(type);
+ *suffix = char_toupper(*type);
+ cb_gen_push_func_begin(gen, cb_format("linagen_fn void vec%d%s_mulm", n, type));
+ cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix));
+ cb_gen_push_func_end(gen, implementation);
+ if (!implementation) { return; }
+
+ cb_gen_push(gen, " {");
+ for (int32_t i = 0; i < n; ++i) {
+ char *prefix = "";
+ cb_gen_push(gen, cb_format("\n res->values[%d] = ", i));
+ for (int32_t j = 0; j < n; ++j) {
+ cb_gen_push(gen, cb_format("%s(v->values[%d] * m->values[%d][%d])",
+ prefix, j, i, j));
+ prefix = " + ";
+ }
+ cb_gen_push(gen, ";");
+ }
+ cb_gen_push(gen, "\n}\n\n");
+}
+
static void
cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type,
int32_t n, bool implementation) {
cb_gen_push(gen, "\n}\n\n");
}
+static void
+cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ char *suffix = strdup(type);
+ *suffix = char_toupper(*type);
+ cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_mulv", n, type));
+ cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix));
+ cb_gen_push_func_end(gen, implementation);
+ if (!implementation) { return; }
+
+ cb_gen_push(gen, " {");
+ for (int32_t i = 0; i < n; ++i) {
+ char *prefix = "";
+ cb_gen_push(gen, cb_format("\n res->values[%d] = ", i));
+ for (int32_t j = 0; j < n; ++j) {
+ cb_gen_push(gen, cb_format("%s(m->values[%d][%d] * v->values[%d])",
+ prefix, j, i, j));
+ prefix = " + ";
+ }
+ cb_gen_push(gen, ";");
+ }
+ cb_gen_push(gen, "\n}\n\n");
+}
+
+static void
+cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ char *suffix = strdup(type);
+ *suffix = char_toupper(*type);
+ cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_identity",
+ n, type));
+ cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix));
+ cb_gen_push_func_end(gen, implementation);
+ if (!implementation) { return; }
+
+ cb_gen_push(gen, " {");
+ cb_gen_push(gen, "\n memset(res, 0, sizeof *res);");
+ for (int32_t i = 0; i < n; ++i) {
+ cb_gen_push(gen, cb_format("\n res->values[%d][%d] = 1;", i, i));
+ }
+ cb_gen_push(gen, "\n}\n\n");
+}
+
internal void
cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,