python/tests/test_source.py: test with interface (PEP 343)
[aubio.git] / python / tests / test_source.py
1 #! /usr/bin/env python
2
3 from nose2 import main
4 from nose2.tools import params
5 from numpy.testing import TestCase, assert_equal
6 from aubio import source
7 from .utils import list_all_sounds
8 import numpy as np
9
10 import warnings
# Ignore UserWarning for the whole test module; append=True adds the filter
# at the end of the filter list instead of the front.
warnings.filterwarnings('ignore', category=UserWarning, append=True)

# Sound files the parametrized tests run against.
list_of_sounds = list_all_sounds('sounds')
# Sample rates passed to source(); the tests use 0 for the native samplerate.
samplerates = [0, 44100, 8000, 32000]
# Hop sizes passed to source().
hop_sizes = [512, 1024, 64]

path = None

# Every (hop_size, samplerate, soundfile) combination, with soundfile varying
# slowest and samplerate fastest. Built with a comprehension so the loop
# variables do not linger as module-level names.
all_params = [
    (hop_size, samplerate, soundfile)
    for soundfile in list_of_sounds
    for hop_size in hop_sizes
    for samplerate in samplerates
]
24
25
class aubio_source_test_case_base(TestCase):
    """Shared fixture for test cases that need at least one sound file."""

    def setUp(self):
        """Skip the test when no sound file was found, else keep the first."""
        if not list_of_sounds:
            self.skipTest("add some sound files in 'python/tests/sounds'")
        self.default_test_sound = list_of_sounds[0]
32
class aubio_source_test_case(aubio_source_test_case_base):
    """Open and close a source on each available sound file."""

    @params(*list_of_sounds)
    def test_close_file(self, filename):
        """Open `filename` and close it once."""
        # a samplerate of 0 asks for the file's native samplerate
        opened = source(filename, 0, 256)
        opened.close()

    @params(*list_of_sounds)
    def test_close_file_twice(self, filename):
        """Open `filename` and close it twice in a row."""
        # a samplerate of 0 asks for the file's native samplerate
        opened = source(filename, 0, 256)
        opened.close()
        opened.close()
49
class aubio_source_read_test_case(aubio_source_test_case_base):
    """Read sound files block by block by calling the source object.

    Subclasses may override `read_from_source` to run the same tests through
    another reading method.
    """

    def read_from_source(self, f):
        """Read from `f` until a short block comes back and count the frames.

        Args:
            f: an opened source; calling it returns ``(samples, read)``.

        Returns:
            The sum of `read` over all blocks.

        Raises:
            AssertionError: if the last block is not zero beyond `read`.
        """
        total_frames = 0
        while True:
            samples, read = f()
            total_frames += read
            if read < f.hop_size:
                # whatever follows the frames actually read must be zeros
                assert_equal(samples[read:], 0)
                break
        return total_frames

    @params(*all_params)
    def test_samplerate_hopsize(self, hop_size, samplerate, soundfile):
        """Read `soundfile` fully for one (hop_size, samplerate) pair.

        The test is skipped when the source cannot be opened with these
        parameters. When the file name contains ``<digits>f_`` exactly once
        and the native samplerate was requested, the digits are taken as the
        expected number of frames.
        """
        try:
            f = source(soundfile, samplerate, hop_size)
        except RuntimeError as e:
            self.skipTest('failed opening with hop_s = {:d}, samplerate = {:d} ({:s})'.format(hop_size, samplerate, str(e)))
        self.assertNotEqual(f.samplerate, 0)
        read_frames = self.read_from_source(f)
        if 'f_' in soundfile and samplerate == 0:
            import re
            # NOTE: an unused, malformed re.compile() call used to rebind `f`
            # here, clobbering the open source; it has been removed.
            match_f = re.findall(r'([0-9]*)f_', soundfile)
            if len(match_f) == 1:
                expected_frames = int(match_f[0])
                self.assertEqual(expected_frames, read_frames)

    @params(*list_of_sounds)
    def test_samplerate_none(self, p):
        """Open `p` without a samplerate argument and read it fully."""
        f = source(p)
        self.assertNotEqual(f.samplerate, 0)
        self.read_from_source(f)

    @params(*list_of_sounds)
    def test_samplerate_0(self, p):
        """Open `p` with samplerate 0 and read it fully."""
        f = source(p, 0)
        self.assertNotEqual(f.samplerate, 0)
        self.read_from_source(f)

    @params(*list_of_sounds)
    def test_zero_hop_size(self, p):
        """Open `p` with samplerate 0 and hop size 0, then read it fully."""
        f = source(p, 0, 0)
        self.assertNotEqual(f.samplerate, 0)
        self.assertNotEqual(f.hop_size, 0)
        self.read_from_source(f)

    @params(*list_of_sounds)
    def test_seek_to_half(self, p):
        """Seek to a random frame and check the remaining frame count.

        Despite its name, the test seeks to a random position in [0, a],
        where `a` is the total number of frames read on the first pass.
        """
        from random import randint
        f = source(p, 0, 0)
        self.assertNotEqual(f.samplerate, 0)
        self.assertNotEqual(f.hop_size, 0)
        a = self.read_from_source(f)
        c = randint(0, a)
        f.seek(c)
        b = self.read_from_source(f)
        # frames skipped by the seek plus frames read afterwards give the total
        self.assertEqual(a, b + c)

    @params(*list_of_sounds)
    def test_duration(self, p):
        """Check that `duration` equals the number of frames actually read."""
        total_frames = 0
        f = source(p)
        duration = f.duration
        while True:
            _, read = f()
            total_frames += read
            if read < f.hop_size:
                break
        self.assertEqual(duration, total_frames)
122
123
class aubio_source_test_wrong_params(TestCase):
    """Invalid arguments that can be checked without any sound file."""

    def test_wrong_file(self):
        """Expect RuntimeError when opening a non-existing path."""
        missing_path = 'path_to/unexisting file.mp3'
        with self.assertRaises(RuntimeError):
            source(missing_path)
129
class aubio_source_test_wrong_params_with_file(aubio_source_test_case_base):
    """Invalid arguments checked against the default test sound."""

    def test_wrong_samplerate(self):
        """Expect ValueError for a samplerate of -1."""
        uri = self.default_test_sound
        with self.assertRaises(ValueError):
            source(uri, -1)

    def test_wrong_hop_size(self):
        """Expect ValueError for a hop size of -1."""
        uri = self.default_test_sound
        with self.assertRaises(ValueError):
            source(uri, 0, -1)

    def test_wrong_channels(self):
        """Expect ValueError when the fourth argument is -1."""
        uri = self.default_test_sound
        with self.assertRaises(ValueError):
            source(uri, 0, 0, -1)

    def test_wrong_seek(self):
        """Expect ValueError when seeking to -1."""
        opened = source(self.default_test_sound)
        with self.assertRaises(ValueError):
            opened.seek(-1)

    def test_wrong_seek_too_large(self):
        """Expect ValueError when seeking far past the end, else skip.

        When no ValueError is raised, assertRaises fails with an
        AssertionError, which is turned into a skip rather than a failure.
        """
        opened = source(self.default_test_sound)
        try:
            with self.assertRaises(ValueError):
                opened.seek(opened.duration + 10 * opened.samplerate)
        except AssertionError:
            self.skipTest('seeking after end of stream failed raising ValueError')
154         except AssertionError:
155             self.skipTest('seeking after end of stream failed raising ValueError')
156
157 class aubio_source_readmulti_test_case(aubio_source_read_test_case):
158
159     def read_from_source(self, f):
160         total_frames = 0
161         while True:
162             samples, read = f.do_multi()
163             total_frames += read
164             if read < f.hop_size:
165                 assert_equal(samples[:,read:], 0)
166                 break
167         #result_str = "read {:.2f}s ({:d} frames in {:d} channels and {:d} blocks at {:d}Hz) from {:s}"
168         #result_params = total_frames / float(f.samplerate), total_frames, f.channels, int(total_frames/f.hop_size), f.samplerate, f.uri
169         #print (result_str.format(*result_params))
170         return total_frames
171
172 class aubio_source_with(aubio_source_test_case_base):
173
174     #@params(*list_of_sounds)
175     @params(*list_of_sounds)
176     def test_read_from_mono(self, filename):
177         total_frames = 0
178         hop_size = 2048
179         with source(filename, 0, hop_size) as input_source:
180             assert_equal(input_source.hop_size, hop_size)
181             #assert_equal(input_source.samplerate, samplerate)
182             total_frames = 0
183             for frames in input_source:
184                 total_frames += frames.shape[-1]
185             # check we read as many samples as we expected
186             assert_equal(total_frames, input_source.duration)
187
188 if __name__ == '__main__':
189     main()