From 689ba93b06ffce012e4a45ed0637e88d44e9c97d Mon Sep 17 00:00:00 2001 From: Paul Brossier Date: Fri, 18 Jan 2019 10:47:14 +0100 Subject: [PATCH] [batchnorm] accepts any input size, allocate weights in get_output_shape --- src/ai/batchnorm.c | 37 +++++++++++++++++++++++-------------- src/ai/batchnorm.h | 2 +- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/ai/batchnorm.c b/src/ai/batchnorm.c index 4184275a..490da8ba 100644 --- a/src/ai/batchnorm.c +++ b/src/ai/batchnorm.c @@ -34,27 +34,18 @@ struct _aubio_batchnorm_t { static void aubio_batchnorm_debug(aubio_batchnorm_t *c, aubio_tensor_t *input_tensor); -aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs) +aubio_batchnorm_t *new_aubio_batchnorm(void) { aubio_batchnorm_t *c = AUBIO_NEW(aubio_batchnorm_t); - - AUBIO_GOTO_FAILURE((sint_t)n_outputs > 0); - - c->n_outputs = n_outputs; - - c->gamma = new_fvec(n_outputs); - c->beta = new_fvec(n_outputs); - c->moving_mean = new_fvec(n_outputs); - c->moving_variance = new_fvec(n_outputs); - return c; - +#if 0 // no argument so no other possible failure failure: del_aubio_batchnorm(c); return NULL; +#endif } -void del_aubio_batchnorm(aubio_batchnorm_t* c) { +static void aubio_batchnorm_reset(aubio_batchnorm_t *c) { AUBIO_ASSERT(c); if (c->gamma) del_fvec(c->gamma); @@ -64,6 +55,10 @@ void del_aubio_batchnorm(aubio_batchnorm_t* c) { del_fvec(c->moving_mean); if (c->moving_variance) del_fvec(c->moving_variance); +} + +void del_aubio_batchnorm(aubio_batchnorm_t* c) { + aubio_batchnorm_reset(c); AUBIO_FREE(c); } @@ -81,12 +76,26 @@ uint_t aubio_batchnorm_get_output_shape(aubio_batchnorm_t *c, uint_t i; AUBIO_ASSERT(c && input && shape); - AUBIO_ASSERT(c->n_outputs == input->shape[input->ndim - 1]); for (i = 0; i < input->ndim; i++) { shape[i] = input->shape[i]; } + aubio_batchnorm_reset(c); + + c->n_outputs = input->shape[input->ndim - 1]; + + c->gamma = new_fvec(c->n_outputs); + c->beta = new_fvec(c->n_outputs); + c->moving_mean = new_fvec(c->n_outputs); + c->moving_variance = new_fvec(c->n_outputs); + + if (!c->gamma || !c->beta || !c->moving_mean || !c->moving_variance) + { + aubio_batchnorm_reset(c); + return AUBIO_FAIL; + } + aubio_batchnorm_debug(c, input); return AUBIO_OK; diff --git a/src/ai/batchnorm.h b/src/ai/batchnorm.h index fb89174f..6c0e44f4 100644 --- a/src/ai/batchnorm.h +++ b/src/ai/batchnorm.h @@ -40,7 +40,7 @@ extern "C" { typedef struct _aubio_batchnorm_t aubio_batchnorm_t; -aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs); +aubio_batchnorm_t *new_aubio_batchnorm(void); void aubio_batchnorm_do(aubio_batchnorm_t *t, aubio_tensor_t *input_tensor, -- 2.11.0