From: LeonardoBizzoni Date: Mon, 20 Oct 2025 16:54:51 +0000 (+0200) Subject: element wise op and dot on matrix X-Git-Url: http://git.leonardobizzoni.com/?a=commitdiff_plain;h=42b3045f5629fd866c59d17277ffe120ed049bb7;p=CBuild element wise op and dot on matrix --- diff --git a/extra/linear_algebra.h b/extra/linear_algebra.h index ea533bf..5f9bb39 100644 --- a/extra/linear_algebra.h +++ b/extra/linear_algebra.h @@ -8,6 +8,11 @@ cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type, int32_t n, bool implementation, char *func_name, char op); +internal void +cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type, + int32_t n, bool implementation, + char *func_name, char op); + internal char char_toupper(char ch); static void @@ -421,6 +426,60 @@ cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type, cb_gen_push(gen, "\n}\n\n"); } +static inline void +cb_linagen_defun_matnn_add(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "add", '+'); +} + +static inline void +cb_linagen_defun_matnn_sub(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, "sub", '-'); +} + +static inline void +cb_linagen_defun_matnn_hadamard_prod(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, + "hadamard_prod", '*'); +} + +static inline void +cb_linagen_defun_matnn_hadamard_div(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + cb_linagen_defun_matnn_element_wise(gen, type, n, implementation, + "hadamard_div", '/'); +} + +static void +cb_linagen_defun_matnn_dot(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + char *suffix = strdup(type); + *suffix = char_toupper(*type); + cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_dot", n, type)); + cb_gen_push_func_arg(gen, cb_format("Mat%d%s *restrict res", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix)); + cb_gen_push_func_end(gen, implementation); + if (!implementation) { return; } + + cb_gen_push(gen, " {"); + for (int32_t col = 0; col < n; ++col) { + for (int32_t row = 0; row < n; ++row) { + cb_gen_push(gen, cb_format("\n res->values[%d][%d] = ", col, row)); + char *prefix = ""; + for (int32_t k = 0; k < n; ++k) { + cb_gen_push(gen, cb_format("%s(m1->values[%d][%d] * m2->values[%d][%d])", + prefix, k, row, col, k)); + prefix = " + "; + } + cb_gen_push(gen, ";"); + } + } + cb_gen_push(gen, "\n}\n\n"); +} + internal void cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type, @@ -444,6 +503,30 @@ cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type, cb_gen_push(gen, "\n}\n\n"); } +internal void +cb_linagen_defun_matnn_element_wise(CB_Generator *gen, char *type, + int32_t n, bool implementation, + char *func_name, char op) { + char *suffix = strdup(type); + *suffix = char_toupper(*type); + cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_%s", + n, type, func_name)); + cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m1", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *m2", n, suffix)); + cb_gen_push_func_end(gen, implementation); + if (!implementation) { return; } + + cb_gen_push(gen, " {"); + for (int32_t i = 0; i < n; ++i) { + for (int32_t j = 0; j < n; ++j) { + cb_gen_push(gen, cb_format("\n res->values[%d][%d] = m1->values[%d][%d] %c m2->values[%d][%d];", + i, j, i, j, op, i, j)); + } + } + cb_gen_push(gen, "\n}\n\n"); +} + internal char char_toupper(char ch) { if (ch >= 'a' && ch <= 'z') { return ch - ('a' - 'A');