Merge branch 'master' into awhitening
[aubio.git] / python / tests / test_fft.py
1 #! /usr/bin/env python
2
3 from unittest import main
4 from numpy.testing import TestCase
5 from numpy.testing import assert_equal, assert_almost_equal
6 import numpy as np
7 from aubio import fvec, fft, cvec
8 from math import pi, floor
9 from random import random
10
11 class aubio_fft_test_case(TestCase):
12
13     def test_members(self):
14         """ check members are set correctly """
15         win_s = 2048
16         f = fft(win_s)
17         assert_equal (f.win_s, win_s)
18
19     def test_output_dimensions(self):
20         """ check the dimensions of output """
21         win_s = 1024
22         timegrain = fvec(win_s)
23         f = fft (win_s)
24         fftgrain = f (timegrain)
25         del f
26         assert_equal (fftgrain.norm.shape, (win_s/2+1,))
27         assert_equal (fftgrain.phas.shape, (win_s/2+1,))
28
29     def test_zeros(self):
30         """ check the transform of zeros is all zeros """
31         win_s = 512
32         timegrain = fvec(win_s)
33         f = fft (win_s)
34         fftgrain = f (timegrain)
35         assert_equal ( fftgrain.norm, 0 )
36         try:
37             assert_equal ( fftgrain.phas, 0 )
38         except AssertionError:
39             assert_equal (fftgrain.phas[fftgrain.phas > 0], +pi)
40             assert_equal (fftgrain.phas[fftgrain.phas < 0], -pi)
41             assert_equal (np.abs(fftgrain.phas[np.abs(fftgrain.phas) != pi]), 0)
42             self.skipTest('fft(fvec(%d)).phas != +0, ' % win_s \
43                     + 'This is expected when using fftw3 on powerpc.')
44
45     def test_impulse(self):
46         """ check the transform of one impulse at a random place """
47         win_s = 256
48         i = int(floor(random()*win_s))
49         impulse = pi * random()
50         f = fft(win_s)
51         timegrain = fvec(win_s)
52         timegrain[i] = impulse
53         fftgrain = f ( timegrain )
54         #self.plot_this ( fftgrain.phas )
55         assert_almost_equal ( fftgrain.norm, impulse, decimal = 6 )
56         assert_equal ( fftgrain.phas <= pi, True)
57         assert_equal ( fftgrain.phas >= -pi, True)
58
59     def test_impulse_negative(self):
60         """ check the transform of a negative impulse at a random place """
61         win_s = 256
62         i = int(floor(random()*win_s))
63         impulse = -.1
64         f = fft(win_s)
65         timegrain = fvec(win_s)
66         timegrain[0] = 0
67         timegrain[i] = impulse
68         fftgrain = f ( timegrain )
69         #self.plot_this ( fftgrain.phas )
70         assert_almost_equal ( fftgrain.norm, abs(impulse), decimal = 5 )
71         if impulse < 0:
72             # phase can be pi or -pi, as it is not unwrapped
73             #assert_almost_equal ( abs(fftgrain.phas[1:-1]) , pi, decimal = 6 )
74             assert_almost_equal ( fftgrain.phas[0], pi, decimal = 6)
75             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
76         else:
77             #assert_equal ( fftgrain.phas[1:-1] == 0, True)
78             assert_equal ( fftgrain.phas[0], 0)
79             assert_almost_equal ( np.fmod(fftgrain.phas[-1], pi), 0, decimal = 6)
80         # now check the resynthesis
81         synthgrain = f.rdo ( fftgrain )
82         #self.plot_this ( fftgrain.phas.T )
83         assert_equal ( fftgrain.phas <= pi, True)
84         assert_equal ( fftgrain.phas >= -pi, True)
85         #self.plot_this ( synthgrain - timegrain )
86         assert_almost_equal ( synthgrain, timegrain, decimal = 6 )
87
88     def test_impulse_at_zero(self):
89         """ check the transform of one impulse at a index 0 """
90         win_s = 1024
91         impulse = pi
92         f = fft(win_s)
93         timegrain = fvec(win_s)
94         timegrain[0] = impulse
95         fftgrain = f ( timegrain )
96         #self.plot_this ( fftgrain.phas )
97         assert_equal ( fftgrain.phas[0], 0)
98         # could be 0 or -0 depending on fft implementation (0 for fftw3, -0 for ooura)
99         assert_almost_equal ( fftgrain.phas[1], 0)
100         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
101
102     def test_rdo_before_do(self):
103         """ check running fft.rdo before fft.do works """
104         win_s = 1024
105         f = fft(win_s)
106         fftgrain = cvec(win_s)
107         t = f.rdo( fftgrain )
108         assert_equal ( t, 0 )
109
110     def plot_this(self, this):
111         from pylab import plot, show
112         plot ( this )
113         show ()
114
115     def test_local_fftgrain(self):
116         """ check aubio.fft() result can be accessed after deletion """
117         def compute_grain(impulse):
118             win_s = 1024
119             timegrain = fvec(win_s)
120             timegrain[0] = impulse
121             f = fft(win_s)
122             fftgrain = f ( timegrain )
123             return fftgrain
124         impulse = pi
125         fftgrain = compute_grain(impulse)
126         assert_equal ( fftgrain.phas[0], 0)
127         assert_almost_equal ( fftgrain.phas[1], 0)
128         assert_almost_equal ( fftgrain.norm[0], impulse, decimal = 6 )
129
130     def test_local_reconstruct(self):
131         """ check aubio.fft.rdo() result can be accessed after deletion """
132         def compute_grain(impulse):
133             win_s = 1024
134             timegrain = fvec(win_s)
135             timegrain[0] = impulse
136             f = fft(win_s)
137             fftgrain = f ( timegrain )
138             r = f.rdo(fftgrain)
139             return r
140         impulse = pi
141         r = compute_grain(impulse)
142         assert_almost_equal ( r[0], impulse, decimal = 6)
143         assert_almost_equal ( r[1:], 0)
144
145     def test_large_input_timegrain(self):
146         win_s = 1024
147         f = fft(win_s)
148         t = fvec(win_s + 1)
149         with self.assertRaises(ValueError):
150             f(t)
151
152     def test_small_input_timegrain(self):
153         win_s = 1024
154         f = fft(win_s)
155         t = fvec(1)
156         with self.assertRaises(ValueError):
157             f(t)
158
159     def test_large_input_fftgrain(self):
160         win_s = 1024
161         f = fft(win_s)
162         s = cvec(win_s + 5)
163         with self.assertRaises(ValueError):
164             f.rdo(s)
165
166     def test_small_input_fftgrain(self):
167         win_s = 1024
168         f = fft(win_s)
169         s = cvec(16)
170         with self.assertRaises(ValueError):
171             f.rdo(s)
172
173 class aubio_fft_wrong_params(TestCase):
174
175     def test_wrong_buf_size(self):
176         win_s = -1
177         with self.assertRaises(ValueError):
178             fft(win_s)
179
180     def test_buf_size_not_power_of_two(self):
181         # when compiled with fftw3, aubio supports non power of two fft sizes
182         win_s = 320
183         try:
184             with self.assertRaises(RuntimeError):
185                 fft(win_s)
186         except AssertionError:
187             self.skipTest('creating aubio.fft with size %d did not fail' % win_s)
188
189     def test_buf_size_too_small(self):
190         win_s = 1
191         with self.assertRaises(RuntimeError):
192             fft(win_s)
193
194 if __name__ == '__main__':
195     main()