ac098639ffbfd01e599e20f0dcb48f26994f2f48
[aubio.git] / python / demos / demo_mel-energy.py
1 #! /usr/bin/env python
2
3 import sys
4 from aubio import fvec, source, pvoc, filterbank
5 from numpy import vstack, zeros
6
7 win_s = 512                 # fft size
8 hop_s = win_s / 4           # hop size
9
10 if len(sys.argv) < 2:
11     print "Usage: %s <filename> [samplerate]" % sys.argv[0]
12     sys.exit(1)
13
14 filename = sys.argv[1]
15
16 samplerate = 0
17 if len( sys.argv ) > 2: samplerate = int(sys.argv[2])
18
19 s = source(filename, samplerate, hop_s)
20 samplerate = s.samplerate
21
22 pv = pvoc(win_s, hop_s)
23
24 f = filterbank(40, win_s)
25 f.set_mel_coeffs_slaney(samplerate)
26
27 energies = zeros((40,))
28 o = {}
29
30 total_frames = 0
31 downsample = 2
32
33 while True:
34     samples, read = s()
35     fftgrain = pv(samples)
36     new_energies = f(fftgrain)
37     print '%f' % (total_frames / float(samplerate) ),
38     print ' '.join(['%f' % b for b in new_energies])
39     energies = vstack( [energies, new_energies] )
40     total_frames += read
41     if read < hop_s: break
42
43 if 1:
44     print "done computing, now plotting"
45     import matplotlib.pyplot as plt
46     from demo_waveform_plot import get_waveform_plot
47     from demo_waveform_plot import set_xlabels_sample2time
48     fig = plt.figure()
49     plt.rc('lines',linewidth='.8')
50     wave = plt.axes([0.1, 0.75, 0.8, 0.19])
51     get_waveform_plot(filename, samplerate, block_size = hop_s, ax = wave )
52     wave.yaxis.set_visible(False)
53     wave.xaxis.set_visible(False)
54
55     n_plots = len(energies.T)
56     all_desc_times = [ x * hop_s  for x in range(len(energies)) ]
57     for i, band in enumerate(energies.T):
58         ax = plt.axes ( [0.1, 0.75 - ((i+1) * 0.65 / n_plots),  0.8, 0.65 / n_plots], sharex = wave )
59         ax.plot(all_desc_times, band, '-', label = 'band %d' % i)
60         #ax.set_ylabel(method, rotation = 0)
61         ax.xaxis.set_visible(False)
62         ax.yaxis.set_visible(False)
63         ax.axis(xmax = all_desc_times[-1], xmin = all_desc_times[0])
64         ax.annotate('band %d' % i, xy=(-10, 0),  xycoords='axes points',
65                 horizontalalignment='right', verticalalignment='bottom',
66                 size = 'xx-small',
67                 )
68     set_xlabels_sample2time( ax, all_desc_times[-1], samplerate) 
69     #plt.ylabel('spectral descriptor value')
70     ax.xaxis.set_visible(True)
71     plt.show()