python/demos/demo_yin_compare.py: fix indentation
[aubio.git] / python / demos / demo_yin_compare.py
1 #! /usr/bin/env python
2 # -*- coding: utf8 -*-
3
4 """ Pure python implementation of the sum of squared difference
5
6     sqd_yin: original sum of squared difference [0]
7         d_t(tau) = x ⊗ kernel
8     sqd_yinfast: sum of squared diff using complex domain [0]
9     sqd_yinfftslow: tappered squared diff [1]
10     sqd_yinfft: modified squared diff using complex domain [1]
11
12 [0]:http://audition.ens.fr/adc/pdf/2002_JASA_YIN.pdf
13 [1]:https://aubio.org/phd/
14 """
15
16 import sys
17 import numpy as np
18 import matplotlib.pyplot as plt
19
20 def sqd_yin(samples):
21     """ compute original sum of squared difference
22
23     Brute-force computation (cost o(N**2), slow)."""
24     B = len(samples)
25     W = B//2
26     yin = np.zeros(W)
27     for j in range(W):
28         for tau in range(1, W):
29             yin[tau] += (samples[j] - samples[j+tau])**2
30     return yin
31
32 def sqd_yinfast(samples):
33     """ compute approximate sum of squared difference
34
35     Using complex convolution (fast, cost o(n*log(n)) )"""
36     # yin_t(tau) = (r_t(0) + r_(t+tau)(0)) - 2r_t(tau)
37     B = len(samples)
38     W = B//2
39     yin = np.zeros(W)
40     sqdiff = np.zeros(W)
41     kernel = np.zeros(B)
42     # compute r_(t+tau)(0)
43     squares = samples**2
44     for tau in range(W):
45         sqdiff[tau] = squares[tau:tau+W].sum()
46     # add r_t(0)
47     sqdiff += sqdiff[0]
48     # compute r_t(tau) using kernel convolution in complex domain
49     samples_fft = np.fft.fft(samples)
50     kernel[1:W+1] = samples[W-1::-1] # first half, reversed
51     kernel_fft = np.fft.fft(kernel)
52     r_t_tau = np.fft.ifft(samples_fft * kernel_fft).real[W:]
53     # compute yin_t(tau)
54     yin = sqdiff - 2 * r_t_tau
55     return yin
56
57 def sqd_yintapered(samples):
58     """ compute tappered sum of squared difference
59
60     Brute-force computation (cost o(N**2), slow)."""
61     B = len(samples)
62     W = B//2
63     yin = np.zeros(W)
64     for tau in range(1, W):
65         for j in range(W - tau):
66             yin[tau] += (samples[j] - samples[j+tau])**2
67     return yin
68
69 def sqd_yinfft(samples):
70     """ compute yinfft modified sum of squared differences
71
72     Very fast, improved performance in transients.
73
74     FIXME: biased."""
75     B = len(samples)
76     W = B//2
77     yin = np.zeros(W)
78     def hanningz(W):
79         return .5 * (1. - np.cos(2. * np.pi * np.arange(W) / W))
80     #win = np.ones(B)
81     win = hanningz(B)
82     sqrmag = np.zeros(B)
83     fftout = np.fft.fft(win*samples)
84     sqrmag[0] = fftout[0].real**2
85     for l in range(1, W):
86         sqrmag[l] = fftout[l].real**2 + fftout[l].imag**2
87         sqrmag[B-l] = sqrmag[l]
88     sqrmag[W] = fftout[W].real**2
89     fftout = np.fft.fft(sqrmag)
90     sqrsum = 2.*sqrmag[:W + 1].sum()
91     yin[0] = 0
92     yin[1:] = sqrsum - fftout.real[1:W]
93     return yin / B
94
95 def cumdiff(yin):
96     """ compute the cumulative mean normalized difference """
97     W = len(yin)
98     yin[0] = 1.
99     cumsum = 0.
100     for tau in range(1, W):
101         cumsum += yin[tau]
102         if cumsum != 0:
103             yin[tau] *= tau/cumsum
104         else:
105             yin[tau] = 1
106     return yin
107
108 def compute_all(x):
109     import time
110     now = time.time()
111
112     yin     = sqd_yin(x)
113     t1 = time.time()
114     print ("yin took %.2fms" % ((t1-now) * 1000.))
115
116     yinfast = sqd_yinfast(x)
117     t2 = time.time()
118     print ("yinfast took: %.2fms" % ((t2-t1) * 1000.))
119
120     yintapered = sqd_yintapered(x)
121     t3 = time.time()
122     print ("yintapered took: %.2fms" % ((t3-t2) * 1000.))
123
124     yinfft  = sqd_yinfft(x)
125     t4 = time.time()
126     print ("yinfft took: %.2fms" % ((t4-t3) * 1000.))
127
128     return yin, yinfast, yintapered, yinfft
129
130 def plot_all(yin, yinfast, yintapered, yinfft):
131     fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey='col')
132
133     axes[0, 0].plot(yin, label='yin')
134     axes[0, 0].plot(yintapered, label='yintapered')
135     axes[0, 0].set_ylim(bottom=0)
136     axes[0, 0].legend()
137     axes[1, 0].plot(yinfast, '-', label='yinfast')
138     axes[1, 0].plot(yinfft, label='yinfft')
139     axes[1, 0].legend()
140
141     axes[0, 1].plot(cumdiff(yin), label='yin')
142     axes[0, 1].plot(cumdiff(yintapered), label='yin tapered')
143     axes[0, 1].set_ylim(bottom=0)
144     axes[0, 1].legend()
145     axes[1, 1].plot(cumdiff(yinfast), '-', label='yinfast')
146     axes[1, 1].plot(cumdiff(yinfft), label='yinfft')
147     axes[1, 1].legend()
148
149     fig.tight_layout()
150
151 testfreqs = [441., 800., 10000., 40.]
152
153 if len(sys.argv) > 1:
154     testfreqs = map(float,sys.argv[1:])
155
156 for f in testfreqs:
157     print ("Comparing yin implementations for sine wave at %.fHz" % f)
158     samplerate = 44100.
159     win_s = 4096
160
161     x = np.cos(2.*np.pi * np.arange(win_s) * f / samplerate)
162
163     n_times = 1#00
164     for n in range(n_times):
165         yin, yinfast, yinfftslow, yinfft = compute_all(x)
166     if 0: # plot difference
167         plt.plot(yin-yinfast)
168         plt.tight_layout()
169         plt.show()
170     if 1:
171         plt.plot(yinfftslow-yinfft)
172         plt.tight_layout()
173         plt.show()
174     plot_all(yin, yinfast, yinfftslow, yinfft)
175 plt.show()