Merge branch 'master' into feature/pytest
[aubio.git] / python / tests / test_fft.py
1 #! /usr/bin/env python
2
3 from numpy.testing import TestCase
4 from numpy.testing import assert_equal, assert_almost_equal
5 import numpy as np
6 from aubio import fvec, fft, cvec
7 from math import pi, floor
8 from random import random
9
10 class aubio_fft_test_case(TestCase):
11
12     def test_members(self):
13         """ check members are set correctly """
14         win_s = 2048
15         f = fft(win_s)
16         assert_equal (f.win_s, win_s)
17
18     def test_output_dimensions(self):
19         """ check the dimensions of output """
20         win_s = 1024
21         timegrain = fvec(win_s)
22         f = fft (win_s)
23         fftgrain = f (timegrain)
24         del f
25         assert_equal (fftgrain.norm.shape, (win_s/2+1,))
26         assert_equal (fftgrain.phas.shape, (win_s/2+1,))
27
28     def test_zeros(self):
29         """ check the transform of zeros is all zeros """
30         win_s = 512
31         timegrain = fvec(win_s)
32         f = fft (win_s)
33         fftgrain = f (timegrain)
34         assert_equal ( fftgrain.norm, 0 )
35         try:
36             assert_equal ( fftgrain.phas, 0 )
37         except AssertionError:
38             assert_equal (fftgrain.phas[fftgrain.phas > 0], +pi)
39             assert_equal (fftgrain.phas[fftgrain.phas < 0], -pi)
40             assert_equal (np.abs(fftgrain.phas[np.abs(fftgrain.phas) != pi]), 0)
41             self.skipTest('fft(fvec(%d)).phas != +0, ' % win_s \
42                     + 'This is expected when using fftw3 on powerpc.')
43
44     def test_impulse(self):
45         """ check the transform of one impulse at a random place """
46         win_s = 256
47         i = int(floor(random()*win_s))
48         impulse = pi * random()
49         f = fft(win_s)
50         timegrain = fvec(win_s)
51         timegrain[i] = impulse
52         fftgrain = f ( timegrain )
53         #self.plot_this ( fftgrain.phas )
54         assert_almost_equal ( fftgrain.norm, impulse, decimal = 6 )
55         assert_equal ( fftgrain.phas <= pi, True)
56         assert_equal ( fftgrain.phas >= -pi, True)
57
58     def test_impulse_negative(self):
59         """ check the transform of a negative impulse at a random place """
60         win_s = 256
61         i = int(floor(random()*win_s))
62         impulse = -.1
63         f = fft(win_s)
64         timegrain = fvec(win_s)
65         timegrain[0] = 0
66         timegrain[i] = impulse
67         fftgrain = f ( timegrain )
68         #self.plot_this ( fftgrain.phas )
69         assert_almost_equal ( fftgrain.norm, abs(impulse), decimal = 5 )
70         if impulse < 0:
71             # phase can be pi or -pi, as it is not unwrapped
72             #assert_almost_equal ( abs(fftgrain.phas[1:-1]) , pi, decimal = 6 )
73             assert_almost_equal ( fftgrain.phas[0], pi, decimal = 6)
74             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
75         else:
76             #assert_equal ( fftgrain.phas[1:-1] == 0, True)
77             assert_equal ( fftgrain.phas[0], 0)
78             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
79         # now check the resynthesis
80         synthgrain = f.rdo ( fftgrain )
81         #self.plot_this ( fftgrain.phas.T )
82         assert_equal ( fftgrain.phas <= pi, True)
83         assert_equal ( fftgrain.phas >= -pi, True)
84         #self.plot_this ( synthgrain - timegrain )
85         assert_almost_equal ( synthgrain, timegrain, decimal = 6 )
86
87     def test_impulse_at_zero(self):
88         """ check the transform of one impulse at a index 0 """
89         win_s = 1024
90         impulse = pi
91         f = fft(win_s)
92         timegrain = fvec(win_s)
93         timegrain[0] = impulse
94         fftgrain = f ( timegrain )
95         #self.plot_this ( fftgrain.phas )
96         assert_equal ( fftgrain.phas[0], 0)
97         # could be 0 or -0 depending on fft implementation (0 for fftw3, -0 for ooura)
98         assert_almost_equal ( fftgrain.phas[1], 0)
99         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
100
101     def test_rdo_before_do(self):
102         """ check running fft.rdo before fft.do works """
103         win_s = 1024
104         f = fft(win_s)
105         fftgrain = cvec(win_s)
106         t = f.rdo( fftgrain )
107         assert_equal ( t, 0 )
108
109     def plot_this(self, this):
110         from pylab import plot, show
111         plot ( this )
112         show ()
113
114     def test_local_fftgrain(self):
115         """ check aubio.fft() result can be accessed after deletion """
116         def compute_grain(impulse):
117             win_s = 1024
118             timegrain = fvec(win_s)
119             timegrain[0] = impulse
120             f = fft(win_s)
121             fftgrain = f ( timegrain )
122             return fftgrain
123         impulse = pi
124         fftgrain = compute_grain(impulse)
125         assert_equal ( fftgrain.phas[0], 0)
126         assert_almost_equal ( fftgrain.phas[1], 0)
127         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
128
129     def test_local_reconstruct(self):
130         """ check aubio.fft.rdo() result can be accessed after deletion """
131         def compute_grain(impulse):
132             win_s = 1024
133             timegrain = fvec(win_s)
134             timegrain[0] = impulse
135             f = fft(win_s)
136             fftgrain = f ( timegrain )
137             r = f.rdo(fftgrain)
138             return r
139         impulse = pi
140         r = compute_grain(impulse)
141         assert_almost_equal ( r[0], impulse, decimal = 6)
142         assert_almost_equal ( r[1:], 0)
143
144 class aubio_fft_odd_sizes(TestCase):
145
146     def test_reconstruct_with_odd_size(self):
147         win_s = 29
148         self.recontruct(win_s, 'odd sizes not supported')
149
150     def test_reconstruct_with_radix15(self):
151         win_s = 2 ** 4 * 15
152         self.recontruct(win_s, 'radix 15 supported')
153
154     def test_reconstruct_with_radix5(self):
155         win_s = 2 ** 4 * 5
156         self.recontruct(win_s, 'radix 5 supported')
157
158     def test_reconstruct_with_radix3(self):
159         win_s = 2 ** 4 * 3
160         self.recontruct(win_s, 'radix 3 supported')
161
162     def recontruct(self, win_s, skipMessage):
163         try:
164             f = fft(win_s)
165         except RuntimeError:
166             self.skipTest(skipMessage)
167         input_signal = fvec(win_s)
168         input_signal[win_s//2] = 1
169         c = f(input_signal)
170         output_signal = f.rdo(c)
171         assert_almost_equal(input_signal, output_signal)
172
173 class aubio_fft_wrong_params(TestCase):
174
175     def test_large_input_timegrain(self):
176         win_s = 1024
177         f = fft(win_s)
178         t = fvec(win_s + 1)
179         with self.assertRaises(ValueError):
180             f(t)
181
182     def test_small_input_timegrain(self):
183         win_s = 1024
184         f = fft(win_s)
185         t = fvec(1)
186         with self.assertRaises(ValueError):
187             f(t)
188
189     def test_large_input_fftgrain(self):
190         win_s = 1024
191         f = fft(win_s)
192         s = cvec(win_s + 5)
193         with self.assertRaises(ValueError):
194             f.rdo(s)
195
196     def test_small_input_fftgrain(self):
197         win_s = 1024
198         f = fft(win_s)
199         s = cvec(16)
200         with self.assertRaises(ValueError):
201             f.rdo(s)
202
203     def test_wrong_buf_size(self):
204         win_s = -1
205         with self.assertRaises(ValueError):
206             fft(win_s)
207
208     def test_buf_size_too_small(self):
209         win_s = 1
210         with self.assertRaises(RuntimeError):
211             fft(win_s)
212
213 if __name__ == '__main__':
214     from unittest import main
215     main()