From a33c395a66d8e77e81e18964b8b596a909a55aab Mon Sep 17 00:00:00 2001 From: Paul Brossier Date: Mon, 7 Jan 2019 23:26:09 +0100 Subject: [PATCH] [tensor] add matmul --- src/ai/tensor.c | 25 +++++++++++++++++++++++++ src/ai/tensor.h | 3 +++ 2 files changed, 28 insertions(+) diff --git a/src/ai/tensor.c b/src/ai/tensor.c index 7486a75a..b2d3206d 100644 --- a/src/ai/tensor.c +++ b/src/ai/tensor.c @@ -203,3 +203,28 @@ void aubio_tensor_print(aubio_tensor_t *t) aubio_tensor_print_subtensor(t, 0); AUBIO_MSG("\n"); } + +void aubio_tensor_matmul(aubio_tensor_t *a, aubio_tensor_t *b, + aubio_tensor_t *c) +{ + AUBIO_ASSERT (a->shape[0] == c->shape[0]); + AUBIO_ASSERT (a->shape[1] == b->shape[0]); + AUBIO_ASSERT (b->shape[1] == c->shape[1]); +#if !defined(HAVE_BLAS) + uint_t i, j, k; + for (i = 0; i < c->shape[0]; i++) { + for (j = 0; j < c->shape[1]; j++) { + smpl_t sum = 0.; + for (k = 0; k < a->shape[1]; k++) { + sum += a->buffer[i * a->shape[1] + k] + * b->buffer[k * b->shape[1] + j]; + } + c->buffer[i * c->shape[1] + j] = sum; + } + } +#else + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, a->shape[0], + b->shape[1], b->shape[0], 1.F, a->buffer, a->shape[1], b->buffer, + b->shape[1], 0.F, c->buffer, b->shape[1]); +#endif +} diff --git a/src/ai/tensor.h b/src/ai/tensor.h index 7ab76512..44a14c23 100644 --- a/src/ai/tensor.h +++ b/src/ai/tensor.h @@ -148,6 +148,9 @@ void aubio_tensor_print(aubio_tensor_t *t); */ const char_t *aubio_tensor_get_shape_string(aubio_tensor_t *t); +void aubio_tensor_matmul(aubio_tensor_t *a, aubio_tensor_t *b, + aubio_tensor_t *c); + #ifdef __cplusplus } #endif -- 2.11.0