de06d001efab19e86c857468d2fcd40149638672
[aubio.git] / python / demos / demo_tempo_plot.py
1 #! /usr/bin/env python
2
3 import sys
4 from aubio import tempo, source
5
6 win_s = 512                 # fft size
7 hop_s = win_s / 2           # hop size
8
9 if len(sys.argv) < 2:
10     print "Usage: %s <filename> [samplerate]" % sys.argv[0]
11     sys.exit(1)
12
13 filename = sys.argv[1]
14
15 samplerate = 0
16 if len( sys.argv ) > 2: samplerate = int(sys.argv[2])
17
18 s = source(filename, samplerate, hop_s)
19 samplerate = s.samplerate
20 o = tempo("default", win_s, hop_s, samplerate)
21
22 # tempo detection delay, in samples
23 # default to 4 blocks delay to catch up with
24 delay = 4. * hop_s
25
26 # list of beats, in samples
27 beats = []
28
29 # total number of frames read
30 total_frames = 0
31 while True:
32     samples, read = s()
33     is_beat = o(samples)
34     if is_beat:
35         this_beat = o.get_last_s()
36         beats.append(this_beat)
37     total_frames += read
38     if read < hop_s: break
39
40 if len(beats) > 1:
41     # do plotting
42     from numpy import array, arange, mean, median, diff
43     import matplotlib.pyplot as plt
44     bpms = 60./ diff(beats)
45     print 'mean period:', "%.2f" % mean(bpms), 'bpm', 'median', "%.2f" % median(bpms), 'bpm'
46     print 'plotting', filename
47     plt1 = plt.axes([0.1, 0.75, 0.8, 0.19])
48     plt2 = plt.axes([0.1, 0.1, 0.8, 0.65], sharex = plt1)
49     plt.rc('lines',linewidth='.8')
50     for stamp in beats: plt1.plot([stamp, stamp], [-1., 1.], '-r')
51     plt1.axis(xmin = 0., xmax = total_frames / float(samplerate) )
52     plt1.xaxis.set_visible(False)
53     plt1.yaxis.set_visible(False)
54
55     # plot actual periods
56     plt2.plot(beats[1:], bpms, '-', label = 'raw')
57
58     # plot moving median of 5 last periods
59     median_win_s = 5
60     bpms_median = [ median(bpms[i:i + median_win_s:1]) for i in range(len(bpms) - median_win_s ) ]
61     plt2.plot(beats[median_win_s+1:], bpms_median, '-', label = 'median of %d' % median_win_s)
62     # plot moving median of 10 last periods
63     median_win_s = 20
64     bpms_median = [ median(bpms[i:i + median_win_s:1]) for i in range(len(bpms) - median_win_s ) ]
65     plt2.plot(beats[median_win_s+1:], bpms_median, '-', label = 'median of %d' % median_win_s)
66
67     plt2.axis(ymin = min(bpms), ymax = max(bpms))
68     #plt2.axis(ymin = 40, ymax = 240)
69     plt.xlabel('time (mm:ss)')
70     plt.ylabel('beats per minute (bpm)')
71     plt2.set_xticklabels([ "%02d:%02d" % (t/60, t%60) for t in plt2.get_xticks()[:-1]], rotation = 50)
72
73     #plt.savefig('/tmp/t.png', dpi=200)
74     plt2.legend()
75     plt.show()
76
77 else:
78     print 'mean period:', "%.2f" % 0, 'bpm', 'median', "%.2f" % 0, 'bpm',
79     print 'nothing to plot, file too short?'