From 1768cdb5f5694da62ee91ea277323d570464c39d Mon Sep 17 00:00:00 2001 From: Paul Brossier Date: Tue, 8 Jan 2019 00:00:16 +0100 Subject: [PATCH] [tests] add tensor_matmul test --- tests/src/ai/test-tensor.c | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/src/ai/test-tensor.c b/tests/src/ai/test-tensor.c index 4626f853..b8b30b46 100644 --- a/tests/src/ai/test-tensor.c +++ b/tests/src/ai/test-tensor.c @@ -270,6 +270,47 @@ int test_get_shape_string(void) return 0; } +int test_matmul(void) +{ + uint_t m = 3, n = 2, k = 4; + uint_t input_shape[2] = {m, k}; + uint_t kernel_shape[2] = {k, n}; + uint_t output_shape[2] = {m, n}; + + aubio_tensor_t *input_tensor = new_aubio_tensor(2, input_shape); + aubio_tensor_t *kernel_tensor = new_aubio_tensor(2, kernel_shape); + aubio_tensor_t *output_tensor = new_aubio_tensor(2, output_shape); + + input_tensor->data[0][0] = 1; + input_tensor->data[1][1] = 1; + input_tensor->data[2][0] = -1; + input_tensor->data[2][1] = 1; + uint_t i; + for (i = 0; i < kernel_tensor->size; i++) { + kernel_tensor->buffer[i] = (smpl_t)i + 1.; + } + + aubio_tensor_matmul(input_tensor, kernel_tensor, output_tensor); + + PRINT_MSG("input: "); + aubio_tensor_print(input_tensor); + PRINT_MSG("kernel: "); + aubio_tensor_print(kernel_tensor); + PRINT_MSG("output: "); + aubio_tensor_print(output_tensor); + + assert (output_tensor->data[0][0] == kernel_tensor->data[0][0]); + assert (output_tensor->data[0][1] == kernel_tensor->data[0][1]); + assert (output_tensor->data[1][0] == kernel_tensor->data[1][0]); + assert (output_tensor->data[1][1] == kernel_tensor->data[1][1]); + assert (output_tensor->data[2][0] == 2); + assert (output_tensor->data[2][1] == 2); + + del_aubio_tensor(output_tensor); + del_aubio_tensor(kernel_tensor); + del_aubio_tensor(input_tensor); + return 0; +} int main(void) { PRINT_MSG("testing 1d tensors\n"); assert (test_1d() == 0); @@ -291,5 +332,7 @@ int main(void) { assert (test_maxtensor() == 0); PRINT_MSG("testing get_shape_string\n"); assert (test_get_shape_string() == 0); + PRINT_MSG("testing matmul\n"); + assert (test_matmul() == 0); return 0; } -- 2.11.0