From: Paul Brossier Date: Mon, 28 Jan 2019 21:56:27 +0000 (+0100) Subject: [tensor] rewrite and rename have_same_shape X-Git-Url: https://git.aubio.org/?a=commitdiff_plain;h=4b2d174adc612d476ea625577be491917ba95bb3;p=aubio.git [tensor] rewrite and rename have_same_shape --- diff --git a/src/ai/tensor.c b/src/ai/tensor.c index cbf2af95..ca134599 100644 --- a/src/ai/tensor.c +++ b/src/ai/tensor.c @@ -137,15 +137,15 @@ uint_t aubio_tensor_get_subtensor(aubio_tensor_t *t, uint_t i, return AUBIO_OK; } -uint_t aubio_tensor_have_same_size(aubio_tensor_t *t, aubio_tensor_t *s) +uint_t aubio_tensor_have_same_shape(aubio_tensor_t *a, aubio_tensor_t *b) { uint_t n; - if (!t || !s) return 0; - if (t->ndim != s->ndim) return 0; - if (t->size != s->size) return 0; - n = t->ndim; - while (n--) { - if (t->shape[n] != s->shape[n]) { + AUBIO_ASSERT(a && b); + if (a->ndim != b->ndim) { + return 0; + } + for (n = 0; n < a->ndim; n++) { + if (a->shape[n] != b->shape[n]) { return 0; } } diff --git a/tests/src/ai/test-tensor.c b/tests/src/ai/test-tensor.c index b8b30b46..84d69432 100644 --- a/tests/src/ai/test-tensor.c +++ b/tests/src/ai/test-tensor.c @@ -123,28 +123,28 @@ int test_sizes(void) aubio_tensor_t *a = new_aubio_tensor(4, dims); aubio_tensor_t *b = new_aubio_tensor(3, dims); - assert (!aubio_tensor_have_same_size(a, b)); + assert (!aubio_tensor_have_same_shape(a, b)); del_aubio_tensor(b); dims[2] += 1; b = new_aubio_tensor(4, dims); - assert (!aubio_tensor_have_same_size(a, b)); + assert (!aubio_tensor_have_same_shape(a, b)); del_aubio_tensor(b); dims[2] -= 1; dims[0] -= 1; dims[1] += 1; b = new_aubio_tensor(4, dims); - assert (!aubio_tensor_have_same_size(a, b)); + assert (!aubio_tensor_have_same_shape(a, b)); del_aubio_tensor(b); dims[0] += 1; dims[1] -= 1; b = new_aubio_tensor(4, dims); - assert (aubio_tensor_have_same_size(a, b)); + assert (aubio_tensor_have_same_shape(a, b)); - assert (!aubio_tensor_have_same_size(NULL, b)); - assert (!aubio_tensor_have_same_size(a, NULL)); + assert (!aubio_tensor_have_same_shape(NULL, b)); + assert (!aubio_tensor_have_same_shape(a, NULL)); del_aubio_tensor(a); del_aubio_tensor(b); @@ -320,7 +320,7 @@ int main(void) { assert (test_3d() == 0); PRINT_MSG("testing 4d tensors\n"); assert (test_4d() == 0); - PRINT_MSG("testing aubio_tensor_have_same_size\n"); + PRINT_MSG("testing aubio_tensor_have_same_shape\n"); assert (test_sizes() == 0); PRINT_MSG("testing new_aubio_tensor with wrong arguments\n"); assert (test_wrong_args() == 0);