lt-tensor 0.0.1a35__py3-none-any.whl → 0.0.1a36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lt_tensor/__init__.py +1 -1
- lt_tensor/model_zoo/audio_models/bigvgan/__init__.py +10 -10
- lt_tensor/model_zoo/audio_models/hifigan/__init__.py +6 -10
- lt_tensor/model_zoo/losses/CQT/__init__.py +0 -0
- lt_tensor/model_zoo/losses/CQT/transforms.py +336 -0
- lt_tensor/model_zoo/losses/CQT/utils.py +519 -0
- lt_tensor/model_zoo/losses/discriminators.py +232 -0
- lt_tensor/processors/audio.py +67 -57
- {lt_tensor-0.0.1a35.dist-info → lt_tensor-0.0.1a36.dist-info}/METADATA +1 -1
- {lt_tensor-0.0.1a35.dist-info → lt_tensor-0.0.1a36.dist-info}/RECORD +13 -10
- {lt_tensor-0.0.1a35.dist-info → lt_tensor-0.0.1a36.dist-info}/WHEEL +0 -0
- {lt_tensor-0.0.1a35.dist-info → lt_tensor-0.0.1a36.dist-info}/licenses/LICENSE +0 -0
- {lt_tensor-0.0.1a35.dist-info → lt_tensor-0.0.1a36.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,519 @@
|
|
1
|
+
"""
|
2
|
+
Module containing helper functions such as overlap sum and Fourier kernels generators
|
3
|
+
"""
|
4
|
+
import sys
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch.nn.functional import conv1d, fold
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
from time import time
|
11
|
+
import math
|
12
|
+
from scipy.signal import get_window
|
13
|
+
from scipy import signal
|
14
|
+
from scipy.fftpack import fft
|
15
|
+
import warnings
|
16
|
+
|
17
|
+
|
18
|
+
sz_float = 4  # size of a float
epsilon = 1e-8  # fudge factor for normalization

# Acquire and parse the installed PyTorch version; only the first two
# dot-separated components (major, minor) are inspected.
split_version = torch.__version__.split(".")
major_version, minor_version = int(split_version[0]), int(split_version[1])

# True for PyTorch >= 1.7; in that case the ``torch.fft`` module is imported
# explicitly so ``torch.fft.*`` functions can be used below.
__TORCH_GTE_1_7 = (major_version, minor_version) >= (1, 7)

if __TORCH_GTE_1_7:
    import torch.fft

    if "torch.fft" not in sys.modules:
        raise RuntimeError("torch.fft module available but not imported")
|
32
|
+
|
33
|
+
|
34
|
+
def rfft_fn(x, n=None, onesided=False):
    """FFT of a real signal over its last dimension, returned as a real view.

    Parameters
    ----------
    x : torch.Tensor
        Real-valued input; the transform is taken over the last dimension.
    n : int or None
        Only used on PyTorch < 1.7, where it is forwarded as the second
        positional argument of the legacy ``torch.rfft`` (its ``signal_ndim``).
        ``None`` is treated as 1 there. Ignored on PyTorch >= 1.7.
    onesided : bool
        If True, only the non-negative frequency bins are returned
        (``x.shape[-1] // 2 + 1`` of them); otherwise the full spectrum.

    Returns
    -------
    torch.Tensor
        Shape ``(..., n_freq, 2)``; the last axis holds (real, imag).
    """
    if __TORCH_GTE_1_7:
        # Previously ``onesided`` was ignored here and the full spectrum was
        # always returned; honor it so both code paths agree.
        y = torch.fft.rfft(x) if onesided else torch.fft.fft(x)
        return torch.view_as_real(y)
    signal_ndim = 1 if n is None else n
    return torch.rfft(x, signal_ndim, onesided=onesided)
|
40
|
+
|
41
|
+
## --------------------------- Filter Design ---------------------------##
|
42
|
+
def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
    """Overlap-add ``w ** power`` across ``n_frames`` frames spaced ``stride`` apart.

    ``w`` is presumably a 1-D window. It is tiled into ``n_frames`` columns,
    raised to ``power``, and summed with ``fold``. The output length is
    ``len(w) + stride * (n_frames - 1)``.

    Returns
    -------
    torch.Tensor
        Output of ``fold``, shaped ``(1, 1, 1, output_len)``.
    """
    # (len(w), n_frames) -> (1, len(w), n_frames): one copy of the window per frame.
    columns = w.unsqueeze(-1).repeat((1, n_frames)).unsqueeze(0)
    win_len, frame_count = columns.shape[1], columns.shape[2]
    total_len = win_len + stride * (frame_count - 1)
    powered = columns ** power
    return fold(powered, (1, total_len), kernel_size=(1, n_fft), stride=stride)
|
49
|
+
|
50
|
+
|
51
|
+
def overlap_add(X, stride):
    """Overlap-add framed data back into a single signal per batch item.

    Parameters
    ----------
    X : torch.Tensor
        Frames laid out as ``(batch, frame_len, n_frames)``.
    stride : int
        Hop between consecutive frames.

    Returns
    -------
    torch.Tensor
        Shape ``(batch, frame_len + stride * (n_frames - 1))``.
    """
    frame_len, n_frames = X.shape[1], X.shape[2]
    total_len = frame_len + stride * (n_frames - 1)
    summed = fold(X, (1, total_len), kernel_size=(1, frame_len), stride=stride)
    return summed.flatten(1)
|
56
|
+
|
57
|
+
|
58
|
+
def uniform_distribution(r1, r2, *size, device):
    """Sample a tensor of shape ``size`` uniformly between ``r2`` and ``r1``.

    ``torch.rand`` draws from [0, 1), so the result lies in ``[r2, r1)`` when
    ``r1 > r2`` and in ``(r1, r2]`` when ``r1 < r2``.
    """
    unit = torch.rand(*size, device=device)
    scale = r1 - r2
    offset = r2
    return scale * unit + offset
|
60
|
+
|
61
|
+
|
62
|
+
def extend_fbins(X):
    """Rebuild ``n_fft`` frequency bins from the one-sided ``n_fft // 2 + 1``.

    Presumably ``X`` is ``(batch, freq, time, 2)`` with (real, imag) in the
    last axis — the ``[:, :, :, 1]`` indexing requires at least 4 dims; TODO
    confirm with callers. Every bin except the first (DC) and last (Nyquist)
    is mirrored in reverse order. Its imaginary part is negated, since the
    imaginary part is odd. The result is appended after the original bins
    along dim 1.
    """
    # flip() returns a copy, so the in-place negation below never touches X.
    mirrored = X[:, 1:-1].flip(1)
    mirrored[:, :, :, 1] = mirrored[:, :, :, 1].neg()
    return torch.cat((X[:, :, :], mirrored), dim=1)
|
70
|
+
|
71
|
+
|
72
|
+
def downsampling_by_n(x, filterKernel, n):
    """Filter and downsample a waveform by an arbitrary integer factor ``n``.

    Used in CQT2010 and CQT2010v2. The waveform is convolved with
    ``filterKernel`` using stride ``n``. It is zero-padded by
    ``(len_kernel - 1) // 2`` samples on each side, so the output has
    ``floor((len_audio + 2 * pad - len_kernel) / n) + 1`` samples.

    Parameters
    ----------
    x : torch.Tensor
        The input waveform with shape ``(batch, 1, len_audio)``.
    filterKernel : torch.Tensor
        Filter kernel with shape ``(1, 1, len_kernel)``.
    n : int
        The downsampling factor (convolution stride).

    Returns
    -------
    torch.Tensor
        The downsampled waveform, shape ``(batch, 1, len_out)``.

    Examples
    --------
    >>> x_down = downsampling_by_n(x, filterKernel, 3)
    """
    padding = int((filterKernel.shape[-1] - 1) // 2)
    x = conv1d(x, filterKernel, stride=(n,), padding=(padding,))
    return x
|
100
|
+
|
101
|
+
|
102
|
+
def downsampling_by_2(x, filterKernel):
    """Filter and downsample a waveform by half. Used in CQT2010 and CQT2010v2.

    Equivalent to ``downsampling_by_n(x, filterKernel, 2)``.

    Parameters
    ----------
    x : torch.Tensor
        The input waveform with shape ``(batch, 1, len_audio)``.
    filterKernel : torch.Tensor
        Filter kernel with shape ``(1, 1, len_kernel)``.

    Returns
    -------
    torch.Tensor
        The downsampled waveform.

    Examples
    --------
    >>> x_down = downsampling_by_2(x, filterKernel)
    """
    return downsampling_by_n(x, filterKernel, 2)
|
124
|
+
|
125
|
+
|
126
|
+
## Basic tools for computation ##
|
127
|
+
def nextpow2(A):
    """Return the exponent of the smallest power of 2 that is >= ``A``.

    Note that this is the *exponent*, not the power itself: the power of two
    is ``2 ** nextpow2(A)``.

    Parameters
    ----------
    A : float
        A positive number.

    Returns
    -------
    int
        ``ceil(log2(A))``.

    Examples
    --------
    >>> nextpow2(6)
    3
    >>> 2 ** nextpow2(6)
    8
    >>> nextpow2(8)
    3
    """
    return int(np.ceil(np.log2(A)))
|
148
|
+
|
149
|
+
|
150
|
+
## Basic tools for computation ##
|
151
|
+
def prepow2(A):
    """Return the exponent of the largest power of 2 that is <= ``A``.

    Note that this is the *exponent*, not the power itself: the power of two
    is ``2 ** prepow2(A)``.

    Parameters
    ----------
    A : float
        A positive number.

    Returns
    -------
    int
        ``floor(log2(A))``.

    Examples
    --------
    >>> prepow2(6)
    2
    >>> 2 ** prepow2(6)
    4
    >>> prepow2(8)
    3
    """
    return int(np.floor(np.log2(A)))
|
172
|
+
|
173
|
+
|
174
|
+
def complex_mul(cqt_filter, stft):
    """Matrix-multiply two complex operands stored as separate (real, imag) tensors.

    With ``a, b = cqt_filter[0], cqt_filter[1]`` and ``c, d = stft[0], stft[1]``
    this computes ``(a + ib) @ (c + id) = (a@c - b@d) + i(a@d + b@c)``.
    It uses only real ``torch.matmul`` calls.

    Parameters
    ----------
    cqt_filter : tuple of torch.Tensor
        ``(real, imag)`` of the left operand.
    stft : tuple of torch.Tensor
        ``(real, imag)`` of the right operand.

    Returns
    -------
    tuple of torch.Tensor
        ``(real, imag)`` of the matrix product.
    """
    a, b = cqt_filter[0], cqt_filter[1]
    c, d = stft[0], stft[1]

    real_part = torch.matmul(a, c) - torch.matmul(b, d)
    imag_part = torch.matmul(a, d) + torch.matmul(b, c)
    return real_part, imag_part
|
203
|
+
|
204
|
+
|
205
|
+
def broadcast_dim(x):
    """Reshape ``x`` into the 3-D ``(batch, channel, len)`` layout ``conv1d`` expects.

    * 1-D ``(len,)``        -> ``(1, 1, len)``
    * 2-D ``(batch, len)``  -> ``(batch, 1, len)``
    * 3-D input is returned unchanged.

    Raises
    ------
    ValueError
        For input of any other dimensionality.
    """
    ndim = x.dim()
    if ndim == 3:
        return x
    if ndim == 2:
        return x[:, None, :]
    if ndim == 1:
        # If nn.DataParallel is used, this broadcast doesn't work
        return x[None, None, :]
    # The previous message omitted the accepted 3-D case.
    raise ValueError(
        "Only support input with shape = (batch, channel, len), "
        f"(batch, len) or (len); got a {ndim}-D input"
    )
|
222
|
+
|
223
|
+
|
224
|
+
def broadcast_dim_conv2d(x):
    """Insert a singleton channel axis so a 3-D tensor fits ``conv2d``.

    ``(batch, H, W)`` -> ``(batch, 1, H, W)``.

    Raises
    ------
    ValueError
        If ``x`` is not 3-D.
    """
    if x.dim() != 3:
        # The previous message described 1-D/2-D input, which is never accepted here.
        raise ValueError(
            f"Only support input with shape = (batch, H, W); got a {x.dim()}-D input"
        )
    return x[:, None, :, :]
|
237
|
+
|
238
|
+
# Tools for CQT
|
239
|
+
|
240
|
+
|
241
|
+
def create_cqt_kernels(
    Q,
    fs,
    fmin,
    n_bins=84,
    bins_per_octave=12,
    norm=1,
    window="hann",
    fmax=None,
    topbin_check=True,
    gamma=0,
    pad_fft=True,
):
    """Automatically create CQT kernels in the time domain.

    Parameters
    ----------
    Q : float
        Quality factor. Kernel ``k`` is
        ``ceil(Q * fs / (freqs[k] + gamma / alpha))`` samples long, where
        ``alpha = 2 ** (1 / bins_per_octave) - 1``.
    fs : float
        Sampling rate.
    fmin : float
        Centre frequency of the lowest bin.
    n_bins : int or None
        Number of bins. Ignored (with a warning) when ``fmax`` is also given.
    bins_per_octave : int
        ``freqs[k] = fmin * 2 ** (k / bins_per_octave)``.
    norm : int or falsy
        If truthy, each kernel is divided by ``np.linalg.norm(kernel, norm)``.
    window : str, tuple or float
        Window specification forwarded to ``get_window_dispatch``.
    fmax : float or None
        If given, ``n_bins = ceil(bins_per_octave * log2(fmax / fmin))``.
    topbin_check : bool
        If truthy, raise when the highest bin exceeds ``fs / 2``.
    gamma : float
        Offset whose ``gamma / alpha`` is added to each centre frequency when
        computing kernel lengths (shortening them).
    pad_fft : bool
        Not used by this function; kept for backward compatibility.

    Returns
    -------
    tempKernel : numpy.ndarray
        complex64 array of shape ``(n_bins, fftLen)``; kernel ``k`` is written
        into a centred slice of row ``k`` and the rest is zero.
    fftLen : int
        Smallest power of two >= the longest kernel.
    lengths : torch.Tensor
        Float tensor with each kernel's length.
    freqs : numpy.ndarray
        Centre frequency of each bin.

    Raises
    ------
    ValueError
        If neither ``fmax`` nor ``n_bins`` is given, or if ``topbin_check`` is
        set and the top bin exceeds the Nyquist frequency.
    """
    # Previously this case fell into the "fmax given" branch and failed with an
    # opaque TypeError on ``None / fmin``.
    if fmax is None and n_bins is None:
        raise ValueError("Either fmax or n_bins must be given")

    if fmax is not None:
        if n_bins is not None:
            warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning)
        # Calculate the number of bins from the frequency range
        n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin))
    freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.double(bins_per_octave))

    if topbin_check and np.max(freqs) > fs / 2:
        raise ValueError(
            f"The top bin {np.max(freqs)}Hz has exceeded the Nyquist frequency, "
            "please reduce the n_bins"
        )

    alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
    lengths = np.ceil(Q * fs / (freqs + gamma / alpha))

    # The kernel buffer is the next power of two above the longest kernel
    # (this depends on gamma, so it cannot be derived from fmin alone).
    max_len = int(max(lengths))
    fftLen = int(2 ** (np.ceil(np.log2(max_len))))

    n_bins = int(n_bins)
    tempKernel = np.zeros((n_bins, fftLen), dtype=np.complex64)

    for k in range(n_bins):
        freq = freqs[k]
        win_len = lengths[k]

        # Centre the kernel; for odd lengths shift one sample left so the
        # extra zero ends up on the right-hand side.
        start = int(np.ceil(fftLen / 2.0 - win_len / 2.0))
        if win_len % 2 == 1:
            start -= 1

        window_dispatch = get_window_dispatch(window, int(win_len), fftbins=True)
        sig = (
            window_dispatch
            * np.exp(np.r_[-win_len // 2 : win_len // 2] * 1j * 2 * np.pi * freq / fs)
            / win_len
        )

        if norm:  # Normalize the filter, similar to librosa
            tempKernel[k, start : start + int(win_len)] = sig / np.linalg.norm(sig, norm)
        else:
            tempKernel[k, start : start + int(win_len)] = sig

    return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
|
316
|
+
|
317
|
+
|
318
|
+
def get_window_dispatch(window, N, fftbins=True):
    """Build a window of length ``N`` via ``scipy.signal.get_window``.

    Parameters
    ----------
    window : str, tuple or float
        * str: a scipy window name (e.g. ``"hann"``), passed through.
        * ``("gaussian", x)``: ``x`` must be >= 0 (checked with ``assert``).
          It is converted to a Gaussian standard deviation via
          ``floor(-N / 2 / sqrt(-2 * ln(10 ** (-x / 20))))``; ``x`` is
          presumably an attenuation in dB given the ``10 ** (-x / 20)`` term.
        * any other tuple: passed through to scipy, with a warning.
        * float: passed through to scipy, which treats it as a Kaiser beta,
          with a warning.
    N : int
        Window length.
    fftbins : bool
        Forwarded to scipy (``True`` requests a periodic window).

    Returns
    -------
    numpy.ndarray
        The window samples.

    Raises
    ------
    ValueError
        If ``window`` is not a str, tuple or float.
    """
    if isinstance(window, str):
        return get_window(window, N, fftbins=fftbins)

    if isinstance(window, tuple):
        if window[0] == "gaussian":
            assert window[1] >= 0
            # NOTE(review): sigma is non-positive here; scipy squares it, so
            # only the magnitude matters — confirm the floor() rounding is intended.
            sigma = np.floor(-N / 2 / np.sqrt(-2 * np.log(10 ** (-window[1] / 20))))
            return get_window(("gaussian", sigma), N, fftbins=fftbins)
        # Previously a Warning object was created but never emitted, and the
        # function fell through to return None, breaking every caller.
        warnings.warn("Tuple windows may have undesired behaviour regarding Q factor")
        return get_window(window, N, fftbins=fftbins)

    if isinstance(window, float):
        warnings.warn(
            "You are using Kaiser window with beta factor "
            + str(window)
            + ". Correct behaviour not checked."
        )
        return get_window(window, N, fftbins=fftbins)

    # ValueError subclasses Exception, so existing ``except Exception`` handlers still work.
    raise ValueError(
        "The function get_window from scipy only supports strings, tuples and floats."
    )
|
338
|
+
|
339
|
+
|
340
|
+
def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
    """Compute the CQT by convolving ``x`` with time-domain CQT kernels.

    See [1] for how the CQT relates to multiplying an STFT with CQT kernels.

    [1] Brown, Judith C.C. and Miller Puckette. "An efficient algorithm for the
        calculation of a constant Q transform." (1992).

    Parameters
    ----------
    x : torch.Tensor
        Input fed to ``conv1d``; presumably ``(batch, 1, len)`` — confirm with callers.
    cqt_kernels_real, cqt_kernels_imag : torch.Tensor
        ``conv1d`` weights for the real and imaginary parts.
    hop_length : int
        Convolution stride.
    padding : callable
        Applied to ``x`` first (used when center == True). If it raises, a
        warning is emitted and ``x`` is instead zero-padded by
        ``kernel_len // 2`` samples on both sides.

    Returns
    -------
    torch.Tensor
        The real convolution and the *negated* imaginary convolution, stacked
        along a new last dimension of size 2.
    """
    # STFT, converting the audio input from time domain to frequency domain
    try:
        # When center == True, we need padding at the beginning and ending
        x = padding(x)
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
        warnings.warn(
            f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
            "padding with reflection mode might not be the best choice, try using constant padding",
            UserWarning,
        )
        half = cqt_kernels_real.shape[-1] // 2
        x = torch.nn.functional.pad(x, (half, half))

    CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
    CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)

    return torch.stack((CQT_real, CQT_imag), -1)
|
364
|
+
|
365
|
+
|
366
|
+
def get_cqt_complex2(
    x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None
):
    """Compute the CQT directly by convolution, or via an STFT plus ``complex_mul``.

    See [1] for how the CQT relates to multiplying an STFT with CQT kernels.

    [1] Brown, Judith C.C. and Miller Puckette. "An efficient algorithm for the
        calculation of a constant Q transform." (1992).

    Parameters
    ----------
    x : torch.Tensor
        Input fed to ``conv1d``; presumably ``(batch, 1, len)`` — confirm with callers.
    cqt_kernels_real, cqt_kernels_imag : torch.Tensor
        CQT kernels. They are used as ``conv1d`` weights when ``wcos`` or
        ``wsin`` is None; otherwise they are matrix-multiplied with the STFT
        via ``complex_mul``.
    hop_length : int
        Convolution stride.
    padding : callable
        Applied to ``x`` first (used when center == True). If it raises, a
        warning is emitted and ``x`` is instead zero-padded by
        ``kernel_len // 2`` samples on both sides.
    wcos, wsin : torch.Tensor or None
        Optional STFT ``conv1d`` kernels. Both must be given to use the STFT path.

    Returns
    -------
    torch.Tensor
        Real and imaginary parts stacked along a new last dimension of size 2.
        On the direct path, the imaginary convolution is negated.
    """
    # STFT, converting the audio input from time domain to frequency domain
    try:
        # When center == True, we need padding at the beginning and ending
        x = padding(x)
    except Exception:
        # Was a bare ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
        warnings.warn(
            f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
            "padding with reflection mode might not be the best choice, try using constant padding",
            UserWarning,
        )
        half = cqt_kernels_real.shape[-1] // 2
        x = torch.nn.functional.pad(x, (half, half))

    if wcos is None or wsin is None:
        CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
        CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
    else:
        fourier_real = conv1d(x, wcos, stride=hop_length)
        fourier_imag = conv1d(x, wsin, stride=hop_length)
        # Multiplying input with the CQT kernel in freq domain
        CQT_real, CQT_imag = complex_mul(
            (cqt_kernels_real, cqt_kernels_imag), (fourier_real, fourier_imag)
        )

    return torch.stack((CQT_real, CQT_imag), -1)
|
402
|
+
|
403
|
+
|
404
|
+
def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
    """Design a low-pass FIR filter with ``scipy.signal.firwin2``.

    Frequencies are normalised so that 0 is DC and 1 is the Nyquist frequency
    of the signal BEFORE downsampling.

    * The passband ``[0, band_center / (1 + transitionBandwidth)]`` gets gain 1.
    * The stopband ``[band_center * (1 + transitionBandwidth), 1]`` gets gain 0.
    * The response in between is interpolated by ``firwin2``.

    The design does not specify an error tolerance. Instead, ``kernelLength``
    sets the number of taps; more taps track the specification more closely.

    Returns
    -------
    numpy.ndarray
        ``kernelLength`` float32 filter taps.
    """
    edge = 1 + transitionBandwidth
    passband_max = band_center / edge
    stopband_min = band_center * edge

    bands = [0.0, passband_max, stopband_min, 1.0]
    gains = [1.0, 1.0, 0.0, 0.0]

    taps = signal.firwin2(kernelLength, bands, gains)
    return taps.astype(np.float32)
|
439
|
+
|
440
|
+
|
441
|
+
def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
    """Decide on early downsampling and build its anti-aliasing filter.

    Used in CQT2010 and CQT2010v2. The filter cutoff is
    ``fmax_t * (1 + 0.5 * 1.5 / Q)``, where 1.5 is the bandwidth used for a
    hann window. ``early_downsample`` then chooses the factor.

    Returns
    -------
    tuple
        ``(sr, hop_length, downsample_factor, early_downsample_filter,
        earlydownsample)``. When the factor is 1, the filter is ``None`` and
        ``earlydownsample`` is False. Otherwise the filter is a
        ``(1, 1, 256)`` tensor from ``create_lowpass_filter`` and
        ``earlydownsample`` is True.
    """
    hann_bandwidth = 1.5  # for hann window
    cutoff = fmax_t * (1 + 0.5 * hann_bandwidth / Q)
    sr, hop_length, downsample_factor = early_downsample(
        sr, hop_length, n_octaves, sr // 2, cutoff
    )

    if downsample_factor == 1:
        if verbose:
            print(
                "No early downsampling is required, downsample_factor = ",
                downsample_factor,
            )
        return sr, hop_length, downsample_factor, None, False

    if verbose:
        print("Can do early downsample, factor = ", downsample_factor)
    taps = create_lowpass_filter(
        band_center=1 / downsample_factor,
        kernelLength=256,
        transitionBandwidth=0.03,
    )
    filter_tensor = torch.tensor(taps)[None, None, :]
    return sr, hop_length, downsample_factor, filter_tensor, True
|
472
|
+
|
473
|
+
|
474
|
+
# NOTE: ``early_downsample`` is defined once, further down in this module,
# right after ``early_downsample_count``. A functionally identical earlier copy
# that used to sit here was always shadowed by that later definition (ruff
# F811), so it has been removed.
|
487
|
+
|
488
|
+
|
489
|
+
# The following two downsampling count functions are obtained from librosa CQT
|
490
|
+
# They are used to determine the number of pre resamplings if the starting and ending frequency
|
491
|
+
# are both in low frequency regions.
|
492
|
+
def _num_two_factors(x):
    """Return how many times ``x`` divides evenly by 2 (0 for ``x <= 0``)."""
    if x <= 0:
        return 0
    count = 0
    while x % 2 == 0:
        count += 1
        x //= 2
    return count


def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
    """Compute the number of early downsampling-by-2 operations.

    Follows librosa's CQT. The result is the minimum of two counts:

    * ``max(0, ceil(log2(0.85 * nyquist / filter_cutoff)) - 2)``, a bound set
      by the ratio of the Nyquist frequency to the filter cutoff;
    * ``max(0, twos(hop_length) - n_octaves + 1)``, where ``twos`` is the
      number of factors of two in ``hop_length``. This keeps the hop evenly
      divisible when ``early_downsample`` divides it by ``2 ** count``.

    Returns
    -------
    int
        The number of halvings (>= 0).
    """
    downsample_count1 = max(
        0, int(np.ceil(np.log2(0.85 * nyquist / filter_cutoff)) - 1) - 1
    )

    # Count the factors of two actually present in hop_length (as librosa does)
    # rather than ceil(log2(hop_length)). The latter over-counts for hops that
    # are not powers of two (e.g. 300 -> 9 instead of 2), and early_downsample's
    # ``hop_length //= 2 ** count`` then silently truncates the hop.
    num_twos = _num_two_factors(hop_length)
    downsample_count2 = max(0, num_twos - n_octaves + 1)

    return min(downsample_count1, downsample_count2)
|
504
|
+
|
505
|
+
|
506
|
+
def early_downsample(sr, hop_length, n_octaves, nyquist, filter_cutoff):
    """Return the sampling rate and hop length after early downsampling.

    The number of halvings comes from ``early_downsample_count``; the overall
    factor is ``2 ** count``.

    Returns
    -------
    tuple
        ``(sr / factor, hop_length // factor, factor)``. The new sampling rate
        is computed with float division.
    """
    count = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
    factor = 2 ** (count)
    new_hop_length = hop_length // factor
    new_sr = sr / float(factor)
    return new_sr, new_hop_length, factor
|