From e18c30efe717668a4ac020a0b3b0f2a8ec3ee576 Mon Sep 17 00:00:00 2001 From: Paul Brossier Date: Tue, 29 Jan 2019 03:36:35 +0100 Subject: [PATCH] [conv2d] copy params in set_kernel, set_bias --- src/ai/conv2d.c | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ai/conv2d.c b/src/ai/conv2d.c index 976c28bc..586dad71 100644 --- a/src/ai/conv2d.c +++ b/src/ai/conv2d.c @@ -478,12 +478,12 @@ uint_t aubio_conv2d_set_padding_mode(aubio_conv2d_t *c, uint_t aubio_conv2d_set_kernel(aubio_conv2d_t *c, aubio_tensor_t *kernel) { - uint_t i; AUBIO_ASSERT(c && kernel); - for (i = 0; i < c->kernel->ndim; i++) { - AUBIO_ASSERT(c->kernel->shape[i] == kernel->shape[i]); + if (aubio_tensor_have_same_shape(kernel, c->kernel)) { + aubio_tensor_copy(kernel, c->kernel); + return AUBIO_OK; } - return AUBIO_OK; + return AUBIO_FAIL; } aubio_tensor_t *aubio_conv2d_get_kernel(aubio_conv2d_t* c) @@ -495,7 +495,10 @@ aubio_tensor_t *aubio_conv2d_get_kernel(aubio_conv2d_t* c) uint_t aubio_conv2d_set_bias(aubio_conv2d_t *c, fvec_t *bias) { AUBIO_ASSERT(c && bias); - AUBIO_ASSERT(c->kernel_shape[1] == bias->length); + if (bias->length == c->bias->length) { + fvec_copy(bias, c->bias); + return AUBIO_OK; + } return AUBIO_OK; } -- 2.11.0