From: LeonardoBizzoni Date: Tue, 21 Oct 2025 08:24:26 +0000 (+0200) Subject: vector-matrix/matrix-vector mul + identity matrix X-Git-Url: http://git.leonardobizzoni.com/?a=commitdiff_plain;h=ef6a9738e23e19d32e28a645fe2810f9b4db3b96;p=CBuild vector-matrix/matrix-vector mul + identity matrix --- diff --git a/extra/linear_algebra.h b/extra/linear_algebra.h index 5f9bb39..c00390b 100644 --- a/extra/linear_algebra.h +++ b/extra/linear_algebra.h @@ -57,6 +57,9 @@ static void cb_linagen_defun_vecn_dot(CB_Generator *gen, char *type, int32_t n, bool implementation); static void +cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type, + int32_t n, bool implementation); +static void cb_linagen_defun_vecn_magnitude(CB_Generator *gen, char *type, int32_t n, bool implementation); static void @@ -72,6 +75,12 @@ cb_linagen_defun_vecn_distance2(CB_Generator *gen, char *type, static void cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type, int32_t n, bool implementation); +static void +cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type, + int32_t n, bool implementation); +static void +cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type, + int32_t n, bool implementation); // ====================================================================== // Implementations @@ -404,6 +413,32 @@ cb_linagen_defun_vecn_distance2(CB_Generator *gen, char *type, "\n}\n\n"); } +static void +cb_linagen_defun_vecn_mulm(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + char *suffix = strdup(type); + *suffix = char_toupper(*type); + cb_gen_push_func_begin(gen, cb_format("linagen_fn void vec%d%s_mulm", n, type)); + cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix)); + cb_gen_push_func_end(gen, implementation); + if (!implementation) { return; } + + cb_gen_push(gen, " {"); + for (int32_t i = 0; i < n; ++i) { + char *prefix = ""; + cb_gen_push(gen, cb_format("\n res->values[%d] = ", i)); + for (int32_t j = 0; j < n; ++j) { + cb_gen_push(gen, cb_format("%s(v->values[%d] * m->values[%d][%d])", + prefix, j, i, j)); + prefix = " + "; + } + cb_gen_push(gen, ";"); + } + cb_gen_push(gen, "\n}\n\n"); +} + static void cb_linagen_defun_matnn_scale(CB_Generator *gen, char *type, int32_t n, bool implementation) { @@ -480,6 +515,51 @@ cb_linagen_defun_matnn_dot(CB_Generator *gen, char *type, cb_gen_push(gen, "\n}\n\n"); } +static void +cb_linagen_defun_matnn_mulv(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + char *suffix = strdup(type); + *suffix = char_toupper(*type); + cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_mulv", n, type)); + cb_gen_push_func_arg(gen, cb_format("Vec%d%s *restrict res", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Mat%d%s *restrict m", n, suffix)); + cb_gen_push_func_arg(gen, cb_format("const Vec%d%s *restrict v", n, suffix)); + cb_gen_push_func_end(gen, implementation); + if (!implementation) { return; } + + cb_gen_push(gen, " {"); + for (int32_t i = 0; i < n; ++i) { + char *prefix = ""; + cb_gen_push(gen, cb_format("\n res->values[%d] = ", i)); + for (int32_t j = 0; j < n; ++j) { + cb_gen_push(gen, cb_format("%s(m->values[%d][%d] * v->values[%d])", + prefix, j, i, j)); + prefix = " + "; + } + cb_gen_push(gen, ";"); + } + cb_gen_push(gen, "\n}\n\n"); +} + +static void +cb_linagen_defun_matnn_identity(CB_Generator *gen, char *type, + int32_t n, bool implementation) { + char *suffix = strdup(type); + *suffix = char_toupper(*type); + cb_gen_push_func_begin(gen, cb_format("linagen_fn void mat%d%s_identity", + n, type)); + cb_gen_push_func_arg(gen, cb_format("Mat%d%s *res", n, suffix)); + cb_gen_push_func_end(gen, implementation); + if (!implementation) { return; } + + cb_gen_push(gen, " {"); + cb_gen_push(gen, "\n memset(res, 0, sizeof *res);"); + for (int32_t i = 0; i < n; ++i) { + cb_gen_push(gen, cb_format("\n res->values[%d][%d] = 1;", i, i)); + } + cb_gen_push(gen, "\n}\n\n"); +} + internal void cb_linagen_defun_vecn_element_wise(CB_Generator *gen, char *type,