From: Paul Brossier Date: Fri, 18 Jan 2019 09:47:14 +0000 (+0100) Subject: [batchnorm] accepts any input size, allocate weights in get_output_shape X-Git-Url: https://git.aubio.org/?a=commitdiff_plain;h=72f450af5b21bcd94676ddb0f619143cc88b78cc;p=aubio.git [batchnorm] accepts any input size, allocate weights in get_output_shape --- diff --git a/src/ai/batchnorm.c b/src/ai/batchnorm.c index 4184275a..490da8ba 100644 --- a/src/ai/batchnorm.c +++ b/src/ai/batchnorm.c @@ -34,27 +34,18 @@ struct _aubio_batchnorm_t { static void aubio_batchnorm_debug(aubio_batchnorm_t *c, aubio_tensor_t *input_tensor); -aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs) +aubio_batchnorm_t *new_aubio_batchnorm(void) { aubio_batchnorm_t *c = AUBIO_NEW(aubio_batchnorm_t); - - AUBIO_GOTO_FAILURE((sint_t)n_outputs > 0); - - c->n_outputs = n_outputs; - - c->gamma = new_fvec(n_outputs); - c->beta = new_fvec(n_outputs); - c->moving_mean = new_fvec(n_outputs); - c->moving_variance = new_fvec(n_outputs); - return c; - +#if 0 // no argument so no other possible failure failure: del_aubio_batchnorm(c); return NULL; +#endif } -void del_aubio_batchnorm(aubio_batchnorm_t* c) { +static void aubio_batchnorm_reset(aubio_batchnorm_t *c) { AUBIO_ASSERT(c); if (c->gamma) del_fvec(c->gamma); @@ -64,6 +55,10 @@ void del_aubio_batchnorm(aubio_batchnorm_t* c) { del_fvec(c->moving_mean); if (c->moving_variance) del_fvec(c->moving_variance); +} + +void del_aubio_batchnorm(aubio_batchnorm_t* c) { + aubio_batchnorm_reset(c); AUBIO_FREE(c); } @@ -81,12 +76,26 @@ uint_t aubio_batchnorm_get_output_shape(aubio_batchnorm_t *c, uint_t i; AUBIO_ASSERT(c && input && shape); - AUBIO_ASSERT(c->n_outputs == input->shape[input->ndim - 1]); for (i = 0; i < input->ndim; i++) { shape[i] = input->shape[i]; } + aubio_batchnorm_reset(c); + + c->n_outputs = input->shape[input->ndim - 1]; + + c->gamma = new_fvec(c->n_outputs); + c->beta = new_fvec(c->n_outputs); + c->moving_mean = new_fvec(c->n_outputs); + c->moving_variance = new_fvec(c->n_outputs); + + if (!c->gamma || !c->beta || !c->moving_mean || !c->moving_variance) + { + aubio_batchnorm_reset(c); + return AUBIO_FAIL; + } + aubio_batchnorm_debug(c, input); return AUBIO_OK; diff --git a/src/ai/batchnorm.h b/src/ai/batchnorm.h index fb89174f..6c0e44f4 100644 --- a/src/ai/batchnorm.h +++ b/src/ai/batchnorm.h @@ -40,7 +40,7 @@ extern "C" { typedef struct _aubio_batchnorm_t aubio_batchnorm_t; -aubio_batchnorm_t *new_aubio_batchnorm(uint_t n_outputs); +aubio_batchnorm_t *new_aubio_batchnorm(void); void aubio_batchnorm_do(aubio_batchnorm_t *t, aubio_tensor_t *input_tensor,