Merge branch 'feature/mfccparams'
[aubio.git] / python / tests / test_fft.py
1 #! /usr/bin/env python
2
3 from unittest import main
4 from numpy.testing import TestCase
5 from numpy.testing import assert_equal, assert_almost_equal
6 import numpy as np
7 from aubio import fvec, fft, cvec
8 from math import pi, floor
9 from random import random
10
11 class aubio_fft_test_case(TestCase):
12
13     def test_members(self):
14         """ check members are set correctly """
15         win_s = 2048
16         f = fft(win_s)
17         assert_equal (f.win_s, win_s)
18
19     def test_output_dimensions(self):
20         """ check the dimensions of output """
21         win_s = 1024
22         timegrain = fvec(win_s)
23         f = fft (win_s)
24         fftgrain = f (timegrain)
25         del f
26         assert_equal (fftgrain.norm.shape, (win_s/2+1,))
27         assert_equal (fftgrain.phas.shape, (win_s/2+1,))
28
29     def test_zeros(self):
30         """ check the transform of zeros is all zeros """
31         win_s = 512
32         timegrain = fvec(win_s)
33         f = fft (win_s)
34         fftgrain = f (timegrain)
35         assert_equal ( fftgrain.norm, 0 )
36         try:
37             assert_equal ( fftgrain.phas, 0 )
38         except AssertionError:
39             assert_equal (fftgrain.phas[fftgrain.phas > 0], +pi)
40             assert_equal (fftgrain.phas[fftgrain.phas < 0], -pi)
41             assert_equal (np.abs(fftgrain.phas[np.abs(fftgrain.phas) != pi]), 0)
42             self.skipTest('fft(fvec(%d)).phas != +0, ' % win_s \
43                     + 'This is expected when using fftw3 on powerpc.')
44
45     def test_impulse(self):
46         """ check the transform of one impulse at a random place """
47         win_s = 256
48         i = int(floor(random()*win_s))
49         impulse = pi * random()
50         f = fft(win_s)
51         timegrain = fvec(win_s)
52         timegrain[i] = impulse
53         fftgrain = f ( timegrain )
54         #self.plot_this ( fftgrain.phas )
55         assert_almost_equal ( fftgrain.norm, impulse, decimal = 6 )
56         assert_equal ( fftgrain.phas <= pi, True)
57         assert_equal ( fftgrain.phas >= -pi, True)
58
59     def test_impulse_negative(self):
60         """ check the transform of a negative impulse at a random place """
61         win_s = 256
62         i = int(floor(random()*win_s))
63         impulse = -.1
64         f = fft(win_s)
65         timegrain = fvec(win_s)
66         timegrain[0] = 0
67         timegrain[i] = impulse
68         fftgrain = f ( timegrain )
69         #self.plot_this ( fftgrain.phas )
70         assert_almost_equal ( fftgrain.norm, abs(impulse), decimal = 5 )
71         if impulse < 0:
72             # phase can be pi or -pi, as it is not unwrapped
73             #assert_almost_equal ( abs(fftgrain.phas[1:-1]) , pi, decimal = 6 )
74             assert_almost_equal ( fftgrain.phas[0], pi, decimal = 6)
75             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
76         else:
77             #assert_equal ( fftgrain.phas[1:-1] == 0, True)
78             assert_equal ( fftgrain.phas[0], 0)
79             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
80         # now check the resynthesis
81         synthgrain = f.rdo ( fftgrain )
82         #self.plot_this ( fftgrain.phas.T )
83         assert_equal ( fftgrain.phas <= pi, True)
84         assert_equal ( fftgrain.phas >= -pi, True)
85         #self.plot_this ( synthgrain - timegrain )
86         assert_almost_equal ( synthgrain, timegrain, decimal = 6 )
87
88     def test_impulse_at_zero(self):
89         """ check the transform of one impulse at a index 0 """
90         win_s = 1024
91         impulse = pi
92         f = fft(win_s)
93         timegrain = fvec(win_s)
94         timegrain[0] = impulse
95         fftgrain = f ( timegrain )
96         #self.plot_this ( fftgrain.phas )
97         assert_equal ( fftgrain.phas[0], 0)
98         # could be 0 or -0 depending on fft implementation (0 for fftw3, -0 for ooura)
99         assert_almost_equal ( fftgrain.phas[1], 0)
100         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
101
102     def test_rdo_before_do(self):
103         """ check running fft.rdo before fft.do works """
104         win_s = 1024
105         f = fft(win_s)
106         fftgrain = cvec(win_s)
107         t = f.rdo( fftgrain )
108         assert_equal ( t, 0 )
109
110     def plot_this(self, this):
111         from pylab import plot, show
112         plot ( this )
113         show ()
114
115     def test_local_fftgrain(self):
116         """ check aubio.fft() result can be accessed after deletion """
117         def compute_grain(impulse):
118             win_s = 1024
119             timegrain = fvec(win_s)
120             timegrain[0] = impulse
121             f = fft(win_s)
122             fftgrain = f ( timegrain )
123             return fftgrain
124         impulse = pi
125         fftgrain = compute_grain(impulse)
126         assert_equal ( fftgrain.phas[0], 0)
127         assert_almost_equal ( fftgrain.phas[1], 0)
128         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
129
130     def test_local_reconstruct(self):
131         """ check aubio.fft.rdo() result can be accessed after deletion """
132         def compute_grain(impulse):
133             win_s = 1024
134             timegrain = fvec(win_s)
135             timegrain[0] = impulse
136             f = fft(win_s)
137             fftgrain = f ( timegrain )
138             r = f.rdo(fftgrain)
139             return r
140         impulse = pi
141         r = compute_grain(impulse)
142         assert_almost_equal ( r[0], impulse, decimal = 6)
143         assert_almost_equal ( r[1:], 0)
144
145 class aubio_fft_odd_sizes(TestCase):
146
147     def test_reconstruct_with_odd_size(self):
148         win_s = 29
149         self.recontruct(win_s, 'odd sizes not supported')
150
151     def test_reconstruct_with_radix15(self):
152         win_s = 2 ** 4 * 15
153         self.recontruct(win_s, 'radix 15 supported')
154
155     def test_reconstruct_with_radix5(self):
156         win_s = 2 ** 4 * 5
157         self.recontruct(win_s, 'radix 5 supported')
158
159     def test_reconstruct_with_radix3(self):
160         win_s = 2 ** 4 * 3
161         self.recontruct(win_s, 'radix 3 supported')
162
163     def recontruct(self, win_s, skipMessage):
164         try:
165             f = fft(win_s)
166         except RuntimeError:
167             self.skipTest(skipMessage)
168         input_signal = fvec(win_s)
169         input_signal[win_s//2] = 1
170         c = f(input_signal)
171         output_signal = f.rdo(c)
172         assert_almost_equal(input_signal, output_signal)
173
174 class aubio_fft_wrong_params(TestCase):
175
176     def test_large_input_timegrain(self):
177         win_s = 1024
178         f = fft(win_s)
179         t = fvec(win_s + 1)
180         with self.assertRaises(ValueError):
181             f(t)
182
183     def test_small_input_timegrain(self):
184         win_s = 1024
185         f = fft(win_s)
186         t = fvec(1)
187         with self.assertRaises(ValueError):
188             f(t)
189
190     def test_large_input_fftgrain(self):
191         win_s = 1024
192         f = fft(win_s)
193         s = cvec(win_s + 5)
194         with self.assertRaises(ValueError):
195             f.rdo(s)
196
197     def test_small_input_fftgrain(self):
198         win_s = 1024
199         f = fft(win_s)
200         s = cvec(16)
201         with self.assertRaises(ValueError):
202             f.rdo(s)
203
204     def test_wrong_buf_size(self):
205         win_s = -1
206         with self.assertRaises(ValueError):
207             fft(win_s)
208
209     def test_buf_size_too_small(self):
210         win_s = 1
211         with self.assertRaises(RuntimeError):
212             fft(win_s)
213
214 if __name__ == '__main__':
215     main()