[batchnorm] accepts any input size, allocate weights in get_output_shape
authorPaul Brossier <piem@piem.org>
Fri, 18 Jan 2019 09:47:14 +0000 (10:47 +0100)
committerPaul Brossier <piem@piem.org>
Wed, 29 Dec 2021 16:52:00 +0000 (11:52 -0500)
src/ai/batchnorm.c
src/ai/batchnorm.h

index 4184275..490da8b 100644 (file)
@@ -34,27 +34,18 @@ struct _aubio_batchnorm_t {
 static void aubio_batchnorm_debug(aubio_batchnorm_t *c,
     aubio_tensor_t *input_tensor);
 
-aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs)
+aubio_batchnorm_t *new_aubio_batchnorm(void)
 {
   aubio_batchnorm_t *c = AUBIO_NEW(aubio_batchnorm_t);
-
-  AUBIO_GOTO_FAILURE((sint_t)n_outputs > 0);
-
-  c->n_outputs = n_outputs;
-
-  c->gamma = new_fvec(n_outputs);
-  c->beta = new_fvec(n_outputs);
-  c->moving_mean = new_fvec(n_outputs);
-  c->moving_variance = new_fvec(n_outputs);
-
   return c;
-
+#if 0 // no argument so no other possible failure
 failure:
   del_aubio_batchnorm(c);
   return NULL;
+#endif
 }
 
-void del_aubio_batchnorm(aubio_batchnorm_t* c) {
+static void aubio_batchnorm_reset(aubio_batchnorm_t *c) {
   AUBIO_ASSERT(c);
   if (c->gamma)
     del_fvec(c->gamma);
@@ -64,6 +55,10 @@ void del_aubio_batchnorm(aubio_batchnorm_t* c) {
     del_fvec(c->moving_mean);
   if (c->moving_variance)
     del_fvec(c->moving_variance);
+}
+
+void del_aubio_batchnorm(aubio_batchnorm_t* c) {
+  aubio_batchnorm_reset(c);
   AUBIO_FREE(c);
 }
 
@@ -81,12 +76,26 @@ uint_t aubio_batchnorm_get_output_shape(aubio_batchnorm_t *c,
   uint_t i;
 
   AUBIO_ASSERT(c && input && shape);
-  AUBIO_ASSERT(c->n_outputs == input->shape[input->ndim - 1]);
 
   for (i = 0; i < input->ndim; i++) {
     shape[i] = input->shape[i];
   }
 
+  aubio_batchnorm_reset(c);
+
+  c->n_outputs = input->shape[input->ndim - 1];
+
+  c->gamma = new_fvec(c->n_outputs);
+  c->beta = new_fvec(c->n_outputs);
+  c->moving_mean = new_fvec(c->n_outputs);
+  c->moving_variance = new_fvec(c->n_outputs);
+
+  if (!c->gamma || !c->beta || !c->moving_mean || !c->moving_variance)
+  {
+    aubio_batchnorm_reset(c);
+    return AUBIO_FAIL;
+  }
+
   aubio_batchnorm_debug(c, input);
 
   return AUBIO_OK;
index fb89174..6c0e44f 100644 (file)
@@ -40,7 +40,7 @@ extern "C" {
 
 typedef struct _aubio_batchnorm_t aubio_batchnorm_t;
 
-aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs);
+aubio_batchnorm_t *new_aubio_batchnorm(void);
 
 void aubio_batchnorm_do(aubio_batchnorm_t *t,
         aubio_tensor_t *input_tensor,