int32_t n, bool implementation,
char *func_name, char op);
+internal void
+cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type,
+ int32_t n, bool implementation,
+ char *func_name, char op);
+
internal char char_toupper(char ch);
static void
cb_gen_push(gen, "\n}\n\n");
}
+static inline void
+cb_linagen_defun_matnn_add(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "add", '+');
+}
+
+static inline void
+cb_linagen_defun_matnn_sub(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "sub", '-');
+}
+
+static inline void
+cb_linagen_defun_matnn_hadamard_prod(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ cb_linagen_defun_matnn_element_wise(gen, type, n, implementation,
+ "hadamard_prod", '*');
+}
+
+static inline void
+cb_linagen_defun_matnn_hadamard_div(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ cb_linagen_defun_matnn_element_wise(gen, type, n, implementation,
+ "hadamard_div", '/');
+}
+
+static void
+cb_linagen_defun_matnn_dot(CB_Generator *gen, char *type,
+ int32_t n, bool implementation) {
+ char *suffix = strdup(type);
+ *suffix = char_toupper(*type);
+ cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_dot", n, type));
+ cb_gen_push_func_arg(gen, cb_format("Mat%d%s *restrict res", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix));
+ cb_gen_push_func_end(gen, implementation);
+ if (!implementation) { return; }
+
+ cb_gen_push(gen, " {");
+ for (int32_t col = 0; col < n; ++col) {
+ for (int32_t row = 0; row < n; ++row) {
+ cb_gen_push(gen, cb_format("\n res->values[%d][%d] = ", col, row));
+ char *prefix = "";
+ for (int32_t k = 0; k < n; ++k) {
+ cb_gen_push(gen, cb_format("%s(m1->values[%d][%d] * m2->values[%d][%d])",
+ prefix, k, row, col, k));
+ prefix = " + ";
+ }
+ cb_gen_push(gen, ";");
+ }
+ }
+ cb_gen_push(gen, "\n}\n\n");
+}
+
internal void
cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,
cb_gen_push(gen, "\n}\n\n");
}
+internal void
+cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type,
+ int32_t n, bool implementation,
+ char *func_name, char op) {
+ char *suffix = strdup(type);
+ *suffix = char_toupper(*type);
+ cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_%s",
+ n, type, func_name));
+ cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix));
+ cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix));
+ cb_gen_push_func_end(gen, implementation);
+ if (!implementation) { return; }
+
+ cb_gen_push(gen, " {");
+ for (int32_t i = 0; i < n; ++i) {
+ for (int32_t j = 0; j < n; ++j) {
+ cb_gen_push(gen, cb_format("\n res->values[%d][%d] = m1->values[%d][%d] %c m2->values[%d][%d];",
+ i, j, i, j, op, i, j));
+ }
+ }
+ cb_gen_push(gen, "\n}\n\n");
+}
+
internal char char_toupper(char ch) {
if (ch >= 'a' && ch <= 'z') {
return ch - ('a' - 'A');