[tensor] add matmul
authorPaul Brossier <piem@piem.org>
Mon, 7 Jan 2019 22:26:09 +0000 (23:26 +0100)
committerPaul Brossier <piem@piem.org>
Mon, 7 Jan 2019 22:26:09 +0000 (23:26 +0100)
src/ai/tensor.c
src/ai/tensor.h

index 7486a75..b2d3206 100644 (file)
@@ -203,3 +203,28 @@ void aubio_tensor_print(aubio_tensor_t *t)
   aubio_tensor_print_subtensor(t, 0);
   AUBIO_MSG("\n");
 }
+
+void aubio_tensor_matmul(aubio_tensor_t *a, aubio_tensor_t *b,
+    aubio_tensor_t *c)
+{
+  AUBIO_ASSERT (a->shape[0] == c->shape[0]);
+  AUBIO_ASSERT (a->shape[1] == b->shape[0]);
+  AUBIO_ASSERT (b->shape[1] == c->shape[1]);
+#if !defined(HAVE_BLAS)
+  uint_t i, j, k;
+  for (i = 0; i < c->shape[0]; i++) {
+    for (j = 0; j < c->shape[1]; j++) {
+      smpl_t sum = 0.;
+      for (k = 0; k < a->shape[1]; k++) {
+          sum += a->buffer[i * a->shape[1] + k]
+            * b->buffer[k * b->shape[1] + j];
+      }
+      c->buffer[i * c->shape[1] + j] = sum;
+    }
+  }
+#else
+  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, a->shape[0],
+      b->shape[1], b->shape[0], 1.F, a->buffer, a->shape[1], b->buffer,
+      b->shape[1], 0.F, c->buffer, b->shape[1]);
+#endif
+}
index 7ab7651..44a14c2 100644 (file)
@@ -148,6 +148,9 @@ void aubio_tensor_print(aubio_tensor_t *t);
 */
 const char_t *aubio_tensor_get_shape_string(aubio_tensor_t *t);
 
+void aubio_tensor_matmul(aubio_tensor_t *a, aubio_tensor_t *b,
+    aubio_tensor_t *c);
+
 #ifdef __cplusplus
 }
 #endif