Merge branch 'master' into feature/pytest
[aubio.git] / python / tests / test_filterbank_mel.py
1 #! /usr/bin/env python
2
3 import numpy as np
4 from numpy.testing import TestCase, assert_equal, assert_almost_equal
5 from _tools import assert_warns
6
7 from aubio import fvec, cvec, filterbank, float_type
8
class aubio_filterbank_mel_test_case(TestCase):
    """Tests for the coefficient builders of aubio.filterbank.

    Covers the mel builders (Slaney and HTK variants, plus explicit
    fmin/fmax) and custom triangle bands. It also checks the errors and
    warnings raised for invalid band lists, and the norm/power options.
    """

    # Band edges shared by the 48 kHz triangle-band tests: 11 edges define
    # 11 - 2 = 9 overlapping triangles, hence filterbank(9, 1024).
    triangle_freqs = [40, 80, 200, 400, 800, 1600, 3200, 6400, 12800, 15000,
            24000]

    # Output of the 9-band filterbank above (48 kHz, 1024-point window) when
    # every bin of the input spectrum norm is 1.
    triangle_ones_expected = [0.02070313, 0.02138672, 0.02127604, 0.02135417,
            0.02133301, 0.02133301, 0.02133311, 0.02133334, 0.02133345]

    def test_slaney(self):
        """Slaney mel coefficients: one row per filter, one column per bin"""
        f = filterbank(40, 512)
        f.set_mel_coeffs_slaney(16000)
        a = f.get_coeffs()
        # integer division: a 512-point window has 512 // 2 + 1 = 257 bins
        assert_equal(np.shape(a), (40, 512 // 2 + 1))

    def test_other_slaney(self):
        """Slaney mel coefficients can be built for several window sizes"""
        f = filterbank(40, 512 * 2)
        f.set_mel_coeffs_slaney(44100)
        self.assertIsInstance(f.get_coeffs(), np.ndarray)
        for win_s in [256, 512, 1024, 2048, 4096]:
            f = filterbank(40, win_s)
            f.set_mel_coeffs_slaney(32000)
            self.assertIsInstance(f.get_coeffs(), np.ndarray)

    def test_triangle_freqs_zeros(self):
        """filtering a freshly allocated cvec gives all-zero band energies"""
        f = filterbank(9, 1024)
        freqs = np.array(self.triangle_freqs, dtype=float_type)
        f.set_triangle_bands(freqs, 48000)
        assert_equal(f(cvec(1024)), 0)
        self.assertIsInstance(f.get_coeffs(), np.ndarray)

    def test_triangle_freqs_ones(self):
        """filtering a spectrum of ones gives the reference band energies"""
        f = filterbank(9, 1024)
        freqs = np.array(self.triangle_freqs, dtype=float_type)
        f.set_triangle_bands(freqs, 48000)
        self.assertIsInstance(f.get_coeffs(), np.ndarray)
        spec = cvec(1024)
        spec.norm[:] = 1
        assert_almost_equal(f(spec), self.triangle_ones_expected)

    def test_triangle_freqs_with_zeros(self):
        """make sure set_triangle_bands works when list starts with 0"""
        freq_list = [0, 40, 80]
        freqs = np.array(freq_list, dtype=float_type)
        f = filterbank(len(freqs) - 2, 1024)
        f.set_triangle_bands(freqs, 48000)
        assert_equal(f(cvec(1024)), 0)
        self.assertIsInstance(f.get_coeffs(), np.ndarray)

    def test_triangle_freqs_with_wrong_negative(self):
        """make sure set_triangle_bands fails when list contains a negative"""
        freq_list = [-10, 0, 80]
        f = filterbank(len(freq_list) - 2, 1024)
        with self.assertRaises(ValueError):
            f.set_triangle_bands(fvec(freq_list), 48000)

    def test_triangle_freqs_with_wrong_ordering(self):
        """make sure set_triangle_bands fails when list not ordered"""
        freq_list = [0, 80, 40]
        f = filterbank(len(freq_list) - 2, 1024)
        with self.assertRaises(ValueError):
            f.set_triangle_bands(fvec(freq_list), 48000)

    def test_triangle_freqs_with_large_freq(self):
        """make sure set_triangle_bands warns when freq > nyquist"""
        samplerate = 22050
        # last edge is one above samplerate // 2
        freq_list = [0, samplerate // 4, samplerate // 2 + 1]
        f = filterbank(len(freq_list) - 2, 1024)
        with assert_warns(UserWarning):
            f.set_triangle_bands(fvec(freq_list), samplerate)

    def test_triangle_freqs_with_not_enough_filters(self):
        """make sure set_triangle_bands warns when not enough filters"""
        samplerate = 22050
        freq_list = [0, 100, 1000, 4000, 8000, 10000]
        # one filter fewer than the len(freq_list) - 2 the edges describe
        f = filterbank(len(freq_list) - 3, 1024)
        with assert_warns(UserWarning):
            f.set_triangle_bands(fvec(freq_list), samplerate)

    def test_triangle_freqs_with_too_many_filters(self):
        """make sure set_triangle_bands warns when too many filters"""
        samplerate = 22050
        freq_list = [0, 100, 1000, 4000, 8000, 10000]
        # one filter more than the len(freq_list) - 2 the edges describe
        f = filterbank(len(freq_list) - 1, 1024)
        with assert_warns(UserWarning):
            f.set_triangle_bands(fvec(freq_list), samplerate)

    def test_triangle_freqs_with_double_value(self):
        """make sure set_triangle_bands works with 2 duplicate freqs"""
        samplerate = 22050
        # exactly two equal edges; the three-duplicate case is covered by
        # test_triangle_freqs_with_triple
        freq_list = [0, 100, 1000, 4000, 4000, 10000]
        f = filterbank(len(freq_list) - 2, 1024)
        with assert_warns(UserWarning):
            f.set_triangle_bands(fvec(freq_list), samplerate)

    def test_triangle_freqs_with_triple(self):
        """make sure set_triangle_bands works with 3 duplicate freqs"""
        samplerate = 22050
        freq_list = [0, 100, 1000, 4000, 4000, 4000, 10000]
        f = filterbank(len(freq_list) - 2, 1024)
        with assert_warns(UserWarning):
            f.set_triangle_bands(fvec(freq_list), samplerate)

    def test_triangle_freqs_without_norm(self):
        """make sure set_triangle_bands works without normalization"""
        samplerate = 22050
        freq_list = fvec([0, 100, 1000, 10000])
        f = filterbank(len(freq_list) - 2, 1024)
        f.set_norm(0)
        f.set_triangle_bands(freq_list, samplerate)
        expected = f.get_coeffs()
        f.set_norm(1)
        f.set_triangle_bands(freq_list, samplerate)
        # with norm enabled, triangle i is scaled by 2 / (edge[i+2] - edge[i])
        assert_almost_equal(f.get_coeffs().T,
                expected.T * 2. / (freq_list[2:] - freq_list[:-2]))

    def test_triangle_freqs_wrong_norm(self):
        """set_norm rejects negative values"""
        f = filterbank(10, 1024)
        with self.assertRaises(ValueError):
            f.set_norm(-1)

    def test_triangle_freqs_with_power(self):
        """with power 2, a spectrum of 0.1 gives 1/100 of the ones output"""
        f = filterbank(9, 1024)
        freqs = fvec(self.triangle_freqs)
        f.set_power(2)
        f.set_triangle_bands(freqs, 48000)
        spec = cvec(1024)
        spec.norm[:] = .1
        expected = fvec(self.triangle_ones_expected)
        expected /= 100.
        assert_almost_equal(f(spec), expected)

    def test_mel_coeffs(self):
        """set_mel_coeffs accepts an explicit [fmin, fmax] range"""
        f = filterbank(40, 1024)
        f.set_mel_coeffs(44100, 0, 44100 / 2)
        assert_equal(np.shape(f.get_coeffs()), (40, 1024 // 2 + 1))

    def test_zero_fmax(self):
        """set_mel_coeffs accepts fmax == 0"""
        f = filterbank(40, 1024)
        f.set_mel_coeffs(44100, 0, 0)
        assert_equal(np.shape(f.get_coeffs()), (40, 1024 // 2 + 1))

    def test_wrong_mel_coeffs(self):
        """mel builders reject non-positive samplerates and negative freqs"""
        f = filterbank(40, 1024)
        with self.assertRaises(ValueError):
            f.set_mel_coeffs_slaney(0)
        with self.assertRaises(ValueError):
            f.set_mel_coeffs(44100, 0, -44100 / 2)
        with self.assertRaises(ValueError):
            f.set_mel_coeffs(44100, -0.1, 44100 / 2)
        with self.assertRaises(ValueError):
            f.set_mel_coeffs(-44100, 0.1, 44100 / 2)
        with self.assertRaises(ValueError):
            f.set_mel_coeffs_htk(-1, 0, 0)

    def test_mel_coeffs_htk(self):
        """set_mel_coeffs_htk accepts an explicit [fmin, fmax] range"""
        f = filterbank(40, 1024)
        f.set_mel_coeffs_htk(44100, 0, 44100 / 2)
        assert_equal(np.shape(f.get_coeffs()), (40, 1024 // 2 + 1))
168
if __name__ == '__main__':
    # Allow running this test module directly, without pytest.
    import unittest
    unittest.main()