[conv2d] copy params in set_kernel, set_bias
authorPaul Brossier <piem@piem.org>
Tue, 29 Jan 2019 02:36:35 +0000 (03:36 +0100)
committerPaul Brossier <piem@piem.org>
Tue, 29 Jan 2019 02:36:35 +0000 (03:36 +0100)
src/ai/conv2d.c

index 976c28b..586dad7 100644 (file)
@@ -478,12 +478,12 @@ uint_t aubio_conv2d_set_padding_mode(aubio_conv2d_t *c,
 
 uint_t aubio_conv2d_set_kernel(aubio_conv2d_t *c, aubio_tensor_t *kernel)
 {
-  uint_t i;
   AUBIO_ASSERT(c && kernel);
-  for (i = 0; i < c->kernel->ndim; i++) {
-    AUBIO_ASSERT(c->kernel->shape[i] == kernel->shape[i]);
+  if (aubio_tensor_have_same_shape(kernel, c->kernel)) {
+    aubio_tensor_copy(kernel, c->kernel);
+    return AUBIO_OK;
   }
-  return AUBIO_OK;
+  return AUBIO_FAIL;
 }
 
 aubio_tensor_t *aubio_conv2d_get_kernel(aubio_conv2d_t* c)
@@ -495,7 +495,10 @@ aubio_tensor_t *aubio_conv2d_get_kernel(aubio_conv2d_t* c)
 uint_t aubio_conv2d_set_bias(aubio_conv2d_t *c, fvec_t *bias)
 {
   AUBIO_ASSERT(c && bias);
-  AUBIO_ASSERT(c->kernel_shape[1] == bias->length);
+  if (bias->length == c->bias->length) {
+    fvec_copy(bias, c->bias);
+    return AUBIO_OK;
+  }
   return AUBIO_OK;
 }