4 """ Pure python implementation of the sum of squared difference
6 sqd_yin: original sum of squared difference [0]
8 sqd_yinfast: sum of squared diff using complex domain [0]
9 sqd_yintapered: tapered squared diff [1]
10 sqd_yinfft: modified squared diff using complex domain [1]
12 [0]:http://audition.ens.fr/adc/pdf/2002_JASA_YIN.pdf
13 [1]:https://aubio.org/phd/
18 import matplotlib.pyplot as plt
21 """ compute original sum of squared difference
23 Brute-force computation (cost o(N**2), slow)."""
28 for tau in range(1, W):
29 yin[tau] += (samples[j] - samples[j+tau])**2
32 def sqd_yinfast(samples):
33 """ compute approximate sum of squared difference
35 Using complex convolution (fast, cost o(n*log(n)) )"""
36 # yin_t(tau) = (r_t(0) + r_(t+tau)(0)) - 2r_t(tau)
42 # compute r_(t+tau)(0)
45 sqdiff[tau] = squares[tau:tau+W].sum()
48 # compute r_t(tau) using kernel convolution in complex domain
49 samples_fft = np.fft.fft(samples)
50 kernel[1:W+1] = samples[W-1::-1] # first half, reversed
51 kernel_fft = np.fft.fft(kernel)
52 r_t_tau = np.fft.ifft(samples_fft * kernel_fft).real[W:]
54 yin = sqdiff - 2 * r_t_tau
57 def sqd_yintapered(samples):
58 """ compute tappered sum of squared difference
60 Brute-force computation (cost o(N**2), slow)."""
64 for tau in range(1, W):
65 for j in range(W - tau):
66 yin[tau] += (samples[j] - samples[j+tau])**2
69 def sqd_yinfft(samples):
70 """ compute yinfft modified sum of squared differences
72 Very fast, improved performance in transients.
79 return .5 * (1. - np.cos(2. * np.pi * np.arange(W) / W))
83 fftout = np.fft.fft(win*samples)
84 sqrmag[0] = fftout[0].real**2
86 sqrmag[l] = fftout[l].real**2 + fftout[l].imag**2
87 sqrmag[B-l] = sqrmag[l]
88 sqrmag[W] = fftout[W].real**2
89 fftout = np.fft.fft(sqrmag)
90 sqrsum = 2.*sqrmag[:W + 1].sum()
92 yin[1:] = sqrsum - fftout.real[1:W]
96 """ compute the cumulative mean normalized difference """
100 for tau in range(1, W):
103 yin[tau] *= tau/cumsum
114 print ("yin took %.2fms" % ((t1-now) * 1000.))
116 yinfast = sqd_yinfast(x)
118 print ("yinfast took: %.2fms" % ((t2-t1) * 1000.))
120 yintapered = sqd_yintapered(x)
122 print ("yintapered took: %.2fms" % ((t3-t2) * 1000.))
124 yinfft = sqd_yinfft(x)
126 print ("yinfft took: %.2fms" % ((t4-t3) * 1000.))
128 return yin, yinfast, yintapered, yinfft
130 def plot_all(yin, yinfast, yintapered, yinfft):
131 fig, axes = plt.subplots(nrows=2, ncols=2, sharex=True, sharey='col')
133 axes[0, 0].plot(yin, label='yin')
134 axes[0, 0].plot(yintapered, label='yintapered')
135 axes[0, 0].set_ylim(bottom=0)
137 axes[1, 0].plot(yinfast, '-', label='yinfast')
138 axes[1, 0].plot(yinfft, label='yinfft')
141 axes[0, 1].plot(cumdiff(yin), label='yin')
142 axes[0, 1].plot(cumdiff(yintapered), label='yin tapered')
143 axes[0, 1].set_ylim(bottom=0)
145 axes[1, 1].plot(cumdiff(yinfast), '-', label='yinfast')
146 axes[1, 1].plot(cumdiff(yinfft), label='yinfft')
151 testfreqs = [441., 800., 10000., 40.]
153 if len(sys.argv) > 1:
154 testfreqs = map(float,sys.argv[1:])
157 print ("Comparing yin implementations for sine wave at %.fHz" % f)
161 x = np.cos(2.*np.pi * np.arange(win_s) * f / samplerate)
164 for n in range(n_times):
165 yin, yinfast, yinfftslow, yinfft = compute_all(x)
166 if 0: # plot difference
167 plt.plot(yin-yinfast)
171 plt.plot(yinfftslow-yinfft)
174 plot_all(yin, yinfast, yinfftslow, yinfft)