lt-tensor 0.0.1a35__py3-none-any.whl → 0.0.1a36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,519 @@
1
+ """
2
+ Module containing helper functions such as overlap sum and Fourier kernels generators
3
+ """
4
+ import sys
5
+
6
+ import torch
7
+ from torch.nn.functional import conv1d, fold
8
+
9
+ import numpy as np
10
+ from time import time
11
+ import math
12
+ from scipy.signal import get_window
13
+ from scipy import signal
14
+ from scipy.fftpack import fft
15
+ import warnings
16
+
17
+
18
sz_float = 4  # size of a float
epsilon = 1e-8  # fudge factor for normalization

# Parse the installed PyTorch version once at import time. Only the first two
# dot-separated components of ``torch.__version__`` are read.
split_version = torch.__version__.split(".")
major_version = int(split_version[0])
minor_version = int(split_version[1])

# True when the installed PyTorch is 1.7 or newer; the ``torch.fft`` module is
# imported below only in that case.
__TORCH_GTE_1_7 = (major_version, minor_version) >= (1, 7)

if __TORCH_GTE_1_7:
    import torch.fft

    # Sanity check that the import above really registered the submodule.
    if "torch.fft" not in sys.modules:
        raise RuntimeError("torch.fft module available but not imported")
32
+
33
+
34
def rfft_fn(x, n=None, onesided=False):
    """Fourier-transform ``x`` and return the result as a real tensor.

    Parameters
    ----------
    x : torch.Tensor
        The tensor to transform.

    n : int
        Forwarded as the second positional argument of the legacy ``torch.rfft``.
        NOTE(review): ignored when PyTorch >= 1.7 is installed.

    onesided : bool
        Forwarded to the legacy ``torch.rfft``.
        NOTE(review): ignored when PyTorch >= 1.7 is installed; that branch
        always calls ``torch.fft.fft``.

    Returns
    -------
    torch.Tensor
        On PyTorch >= 1.7, ``torch.view_as_real`` of ``torch.fft.fft(x)``;
        otherwise whatever ``torch.rfft`` returns.
    """
    if not __TORCH_GTE_1_7:
        # Legacy API, only present in old PyTorch releases.
        return torch.rfft(x, n, onesided=onesided)

    spectrum = torch.fft.fft(x)
    return torch.view_as_real(spectrum)
40
+
41
+ ## --------------------------- Filter Design ---------------------------##
42
def torch_window_sumsquare(w, n_frames, stride, n_fft, power=2):
    """Overlap-add ``n_frames`` shifted copies of the window ``w ** power``.

    Parameters
    ----------
    w : torch.Tensor
        The window; it is tiled once per frame before folding.
        (Assumed to be 1-D with ``n_fft`` samples -- TODO confirm with callers.)

    n_frames : int
        Number of frames to accumulate

    stride : int
        Hop between consecutive frames, passed to ``fold``

    n_fft : int
        Kernel width passed to ``fold``

    power : int or float
        Exponent applied to the window before accumulation

    Returns
    -------
    torch.Tensor
        The output of ``fold``; its last dimension has length
        ``window length + stride * (n_frames - 1)``
    """
    # One column per frame, with a leading batch axis for ``fold``.
    frames = w.unsqueeze(-1).repeat((1, n_frames)).unsqueeze(0)

    # Window length + stride * (frames - 1)
    total_len = frames.shape[1] + stride * (frames.shape[2] - 1)

    powered = frames ** power
    return fold(powered, (1, total_len), kernel_size=(1, n_fft), stride=stride)
49
+
50
+
51
def overlap_add(X, stride):
    """Overlap-add the frames stored along the last axis of ``X``.

    Parameters
    ----------
    X : torch.Tensor
        Frames to combine; dimension 1 is read as the frame length and
        dimension 2 as the number of frames.

    stride : int
        Hop between consecutive frames

    Returns
    -------
    torch.Tensor
        The folded signal flattened from dimension 1 onwards, with
        ``frame length + stride * (number of frames - 1)`` samples per item
    """
    frame_len = X.shape[1]
    n_frames = X.shape[2]
    signal_len = frame_len + stride * (n_frames - 1)

    folded = fold(X, (1, signal_len), kernel_size=(1, frame_len), stride=stride)
    return folded.flatten(1)
56
+
57
+
58
def uniform_distribution(r1, r2, *size, device):
    """Draw uniformly distributed samples spanning the range between ``r2`` and ``r1``.

    Parameters
    ----------
    r1, r2 : float
        The two ends of the range. A ``torch.rand`` sample of 0 maps to ``r2``
        and samples approaching 1 map towards ``r1``.

    *size : int
        Output shape, forwarded to ``torch.rand``

    device : torch.device or str
        Device on which the samples are created (keyword-only)

    Returns
    -------
    torch.Tensor
        Tensor of shape ``size`` holding the samples
    """
    unit_samples = torch.rand(*size, device=device)
    span = r1 - r2
    return span * unit_samples + r2
60
+
61
+
62
def extend_fbins(X):
    """Extend the number of frequency bins from ``n_fft//2+1`` back to ``n_fft``.

    All bins except the first and the last along dimension 1 (DC and Nyquist)
    are reversed, their imaginary part is negated, and the result is appended
    after the existing bins.

    Parameters
    ----------
    X : torch.Tensor
        4-D spectrogram whose dimension 1 indexes frequency bins and whose last
        dimension holds ``(real, imag)``

    Returns
    -------
    torch.Tensor
        ``X`` with the mirrored bins concatenated along dimension 1
    """
    # ``flip`` returns a copy, so the in-place write below leaves ``X`` untouched.
    mirrored = X[:, 1:-1].flip(1)

    # The imaginary part is an odd function of frequency, so its sign flips.
    mirrored[:, :, :, 1] = torch.neg(mirrored[:, :, :, 1])

    return torch.cat([X[:, :, :], mirrored], 1)
70
+
71
+
72
def downsampling_by_n(x, filterKernel, n):
    """Downsample the audio by an arbitrary integer factor ``n``.

    The input is convolved with ``filterKernel`` using a stride of ``n``.
    It is used in CQT2010 and CQT2010v2.

    Parameters
    ----------
    x : torch.Tensor
        The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``

    filterKernel : torch.Tensor
        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``

    n : int
        The downsampling factor

    Returns
    -------
    torch.Tensor
        The downsampled waveform

    Examples
    --------
    >>> x_down = downsampling_by_n(x, filterKernel, n)
    """

    kernel_len = filterKernel.shape[-1]
    # Pad by half the kernel width on both sides.
    half_width = int((kernel_len - 1) // 2)
    return conv1d(x, filterKernel, stride=(n,), padding=(half_width,))
100
+
101
+
102
def downsampling_by_2(x, filterKernel):
    """Downsample the audio by half. It is used in CQT2010 and CQT2010v2.

    This is ``downsampling_by_n`` with the factor fixed to 2.

    Parameters
    ----------
    x : torch.Tensor
        The input waveform in ``torch.Tensor`` type with shape ``(batch, 1, len_audio)``

    filterKernel : torch.Tensor
        Filter kernel in ``torch.Tensor`` type with shape ``(1, 1, len_kernel)``

    Returns
    -------
    torch.Tensor
        The downsampled waveform

    Examples
    --------
    >>> x_down = downsampling_by_2(x, filterKernel)
    """

    halved = downsampling_by_n(x, filterKernel, n=2)
    return halved
124
+
125
+
126
+ ## Basic tools for computation ##
127
def nextpow2(A):
    """Return the exponent of the smallest power of 2 that is not below ``A``.

    Parameters
    ----------
    A : float
        The number whose base-2 logarithm is rounded up

    Returns
    -------
    int
        ``ceil(log2(A))``; i.e. ``2 ** nextpow2(A)`` is the next power of 2

    Examples
    --------

    >>> nextpow2(6)
    3
    """

    exponent = np.log2(A)
    return int(np.ceil(exponent))
148
+
149
+
150
+ ## Basic tools for computation ##
151
def prepow2(A):
    """Return the exponent of the largest power of 2 that does not exceed ``A``.

    Parameters
    ----------
    A : float
        The number whose base-2 logarithm is rounded down

    Returns
    -------
    int
        ``floor(log2(A))``; i.e. ``2 ** prepow2(A)`` is the previous power of 2

    Examples
    --------

    >>> prepow2(6)
    2
    """

    exponent = np.log2(A)
    return int(np.floor(exponent))
172
+
173
+
174
def complex_mul(cqt_filter, stft):
    """Complex matrix product of a CQT filter with an STFT, on split real/imag parts.

    Both operands are given as separate real and imaginary tensors and the
    product is assembled from four real ``torch.matmul`` calls. This one is
    specially designed for CQT usage.

    Parameters
    ----------
    cqt_filter : tuple of torch.Tensor
        The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``

    stft : tuple of torch.Tensor
        The tuple is in the format of ``(real_torch_tensor, imag_torch_tensor)``

    Returns
    -------
    tuple of torch.Tensor
        The output is in the format of ``(real_torch_tensor, imag_torch_tensor)``
    """

    kernel_re = cqt_filter[0]
    kernel_im = cqt_filter[1]
    spec_re = stft[0]
    spec_im = stft[1]

    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    re_re = torch.matmul(kernel_re, spec_re)
    im_im = torch.matmul(kernel_im, spec_im)
    CQT_real = re_re - im_im

    re_im = torch.matmul(kernel_re, spec_im)
    im_re = torch.matmul(kernel_im, spec_re)
    CQT_imag = re_im + im_re

    return CQT_real, CQT_imag
203
+
204
+
205
def broadcast_dim(x):
    """
    Auto broadcast input so that it can fit into a Conv1d.

    A 1-D input gains two leading singleton axes, a 2-D input gains a singleton
    axis in the middle, and a 3-D input is returned unchanged. Any other
    dimensionality raises ``ValueError``.
    """

    ndim = x.dim()

    if ndim == 3:
        return x

    if ndim == 2:
        return x[:, None, :]

    if ndim == 1:
        # If nn.DataParallel is used, this broadcast doesn't work
        return x[None, None, :]

    raise ValueError(
        "Only support input with shape = (batch, len) or shape = (len)"
    )
222
+
223
+
224
def broadcast_dim_conv2d(x):
    """
    Auto broadcast input so that it can fit into a Conv2d.

    Parameters
    ----------
    x : torch.Tensor
        A 3-D tensor; a singleton axis is inserted at position 1.

    Returns
    -------
    torch.Tensor
        The 4-D view ``x[:, None, :, :]``

    Raises
    ------
    ValueError
        If ``x`` does not have exactly 3 dimensions
    """

    if x.dim() != 3:
        # The previous message was copied from the Conv1d helper and described
        # 1-D/2-D inputs, which this function never accepts.
        raise ValueError(
            "Only support input with 3 dimensions, i.e. shape = (batch, dim1, dim2); "
            f"got shape = {tuple(x.shape)}"
        )

    return x[:, None, :, :]
237
+
238
+ # Tools for CQT
239
+
240
+
241
def create_cqt_kernels(
    Q,
    fs,
    fmin,
    n_bins=84,
    bins_per_octave=12,
    norm=1,
    window="hann",
    fmax=None,
    topbin_check=True,
    gamma=0,
    pad_fft=True,
):
    """
    Automatically create CQT kernels in time domain

    Parameters
    ----------
    Q : float
        Quality factor; kernel ``k`` spans ``ceil(Q * fs / (freqs[k] + gamma / alpha))``
        samples, where ``alpha = 2 ** (1 / bins_per_octave) - 1``

    fs : float
        Sampling rate

    fmin : float
        Centre frequency of the lowest bin

    n_bins : int or None
        Number of bins. Recomputed from ``fmax`` whenever ``fmax`` is given.

    bins_per_octave : int
        Number of bins per octave

    norm : int or None
        Order passed to ``np.linalg.norm`` to normalise each kernel; a falsy
        value disables normalisation

    window : str, tuple or float
        Window specification forwarded to ``get_window_dispatch``

    fmax : float or None
        Centre frequency of the highest bin; overrides ``n_bins`` when given

    topbin_check : bool
        If true, raise when the top bin exceeds the Nyquist frequency ``fs / 2``

    gamma : float
        Offset added (as ``gamma / alpha``) to every bin frequency when
        computing the kernel lengths

    pad_fft : bool
        Not read by this function; kept in the signature for callers that pass it

    Returns
    -------
    tuple
        ``(tempKernel, fftLen, lengths, freqs)``: the complex64 kernels with
        shape ``(n_bins, fftLen)``, the kernel width (a power of 2), the kernel
        lengths as a float ``torch.Tensor`` and the bin frequencies as a
        NumPy array

    Raises
    ------
    ValueError
        If both ``n_bins`` and ``fmax`` are ``None``, or if the top bin exceeds
        the Nyquist frequency while ``topbin_check`` is true
    """

    if fmax is None and n_bins is None:
        # Previously this fell through to ``np.log2(None / fmin)`` and failed
        # with an unrelated TypeError.
        raise ValueError("Either n_bins or fmax must be given")

    if fmax is not None:
        if n_bins is not None:
            warnings.warn("If fmax is given, n_bins will be ignored", SyntaxWarning)
        # Number of bins needed to reach fmax starting from fmin
        n_bins = np.ceil(bins_per_octave * np.log2(fmax / fmin))
    n_bins = int(n_bins)

    freqs = fmin * 2.0 ** (np.r_[0:n_bins] / np.double(bins_per_octave))

    if topbin_check and np.max(freqs) > fs / 2:
        raise ValueError(
            "The top bin {}Hz has exceeded the Nyquist frequency, "
            "please reduce the n_bins".format(np.max(freqs))
        )

    alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
    lengths = np.ceil(Q * fs / (freqs + gamma / alpha))

    # The longest kernel (which depends on gamma) fixes the common width,
    # rounded up to a power of 2.
    max_len = int(np.max(lengths))
    fftLen = int(2 ** (np.ceil(np.log2(max_len))))

    tempKernel = np.zeros((n_bins, fftLen), dtype=np.complex64)

    for k in range(n_bins):
        freq = freqs[k]
        win_len = lengths[k]  # float-valued sample count
        n_samples = int(win_len)

        # Centre the kernel; odd lengths leave the extra zero on the right.
        start = int(np.ceil(fftLen / 2.0 - win_len / 2.0))
        if win_len % 2 == 1:
            start -= 1

        window_dispatch = get_window_dispatch(window, n_samples, fftbins=True)
        phase = np.r_[-win_len // 2 : win_len // 2] * 1j * 2 * np.pi * freq / fs
        sig = window_dispatch * np.exp(phase) / win_len

        if norm:  # Normalizing the filter, trying to normalize like librosa
            sig = sig / np.linalg.norm(sig, norm)
        tempKernel[k, start : start + n_samples] = sig

    return tempKernel, fftLen, torch.tensor(lengths).float(), freqs
316
+
317
+
318
def get_window_dispatch(window, N, fftbins=True):
    """Build a window of ``N`` samples with ``scipy.signal.get_window``.

    Parameters
    ----------
    window : str, tuple or float
        Window specification. A ``("gaussian", attenuation)`` tuple has its
        second element converted to a standard deviation before being handed
        to SciPy (the attenuation is presumably in dB -- TODO confirm). Any
        other tuple, and any float (which SciPy reads as a Kaiser ``beta``),
        is forwarded unchanged after a warning.

    N : int
        Number of samples in the window

    fftbins : bool
        Forwarded to ``scipy.signal.get_window``

    Returns
    -------
    numpy.ndarray
        The window returned by ``scipy.signal.get_window``

    Raises
    ------
    TypeError
        If ``window`` is not a string, tuple or float
    """
    if isinstance(window, str):
        return get_window(window, N, fftbins=fftbins)

    if isinstance(window, tuple):
        if window[0] == "gaussian":
            assert window[1] >= 0
            sigma = np.floor(-N / 2 / np.sqrt(-2 * np.log(10 ** (-window[1] / 20))))
            return get_window(("gaussian", sigma), N, fftbins=fftbins)

        # ``Warning(...)`` only built an exception object and this branch then
        # returned None; warn for real and return the window.
        warnings.warn("Tuple windows may have undesired behaviour regarding Q factor")
        return get_window(window, N, fftbins=fftbins)

    if isinstance(window, float):
        warnings.warn(
            "You are using Kaiser window with beta factor "
            + str(window)
            + ". Correct behaviour not checked."
        )
        return get_window(window, N, fftbins=fftbins)

    raise TypeError(
        "The function get_window from scipy only supports strings, tuples and floats."
    )
338
+
339
+
340
def get_cqt_complex(x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding):
    """Compute the complex CQT by convolving the padded input with the CQT kernels.

    Parameters
    ----------
    x : torch.Tensor
        Input passed to ``conv1d``

    cqt_kernels_real : torch.Tensor
        Real part of the CQT kernels, used as the ``conv1d`` weight

    cqt_kernels_imag : torch.Tensor
        Imaginary part of the CQT kernels, used as the ``conv1d`` weight

    hop_length : int
        Stride of the convolution

    padding : callable
        Applied to ``x`` before the convolution. If it raises, a warning is
        issued and ``x`` is zero-padded by half the kernel width on both sides.

    Returns
    -------
    torch.Tensor
        The real part and the negated imaginary convolution stacked along a
        new last dimension
    """

    try:
        # When center == True, we need padding at the beginning and ending
        x = padding(x)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
        # and SystemExit are not swallowed by the fallback.
        warnings.warn(
            f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
            "padding with reflection mode might not be the best choice, try using constant padding",
            UserWarning,
        )
        half_kernel = cqt_kernels_real.shape[-1] // 2
        x = torch.nn.functional.pad(x, (half_kernel, half_kernel))

    CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
    CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)

    return torch.stack((CQT_real, CQT_imag), -1)
364
+
365
+
366
def get_cqt_complex2(
    x, cqt_kernels_real, cqt_kernels_imag, hop_length, padding, wcos=None, wsin=None
):
    """Compute the complex CQT either in the time domain or through Fourier kernels.

    Without ``wcos``/``wsin`` the padded input is convolved directly with the
    CQT kernels. With both given, the input is first convolved with ``wcos``
    and ``wsin`` and the result is multiplied with the CQT kernels through
    ``complex_mul``; see Brown, Judith C. and Miller Puckette, "An efficient
    algorithm for the calculation of a constant Q transform." (1992).

    Parameters
    ----------
    x : torch.Tensor
        Input passed to ``conv1d``

    cqt_kernels_real : torch.Tensor
        Real part of the CQT kernels

    cqt_kernels_imag : torch.Tensor
        Imaginary part of the CQT kernels

    hop_length : int
        Stride of the convolutions

    padding : callable
        Applied to ``x`` before the convolution. If it raises, a warning is
        issued and ``x`` is zero-padded by half the kernel width on both sides.

    wcos, wsin : torch.Tensor or None
        Optional kernels used as ``conv1d`` weights to produce the real and
        imaginary Fourier terms; both must be given to take that path

    Returns
    -------
    torch.Tensor
        Real and imaginary parts stacked along a new last dimension
    """

    try:
        # When center == True, we need padding at the beginning and ending
        x = padding(x)
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
        # and SystemExit are not swallowed by the fallback.
        warnings.warn(
            f"\ninput size = {x.shape}\tkernel size = {cqt_kernels_real.shape[-1]}\n"
            "padding with reflection mode might not be the best choice, try using constant padding",
            UserWarning,
        )
        half_kernel = cqt_kernels_real.shape[-1] // 2
        x = torch.nn.functional.pad(x, (half_kernel, half_kernel))

    # Identity test: ``== None`` on a tensor goes through tensor comparison.
    if wcos is None or wsin is None:
        CQT_real = conv1d(x, cqt_kernels_real, stride=hop_length)
        CQT_imag = -conv1d(x, cqt_kernels_imag, stride=hop_length)
    else:
        fourier_real = conv1d(x, wcos, stride=hop_length)
        fourier_imag = conv1d(x, wsin, stride=hop_length)
        # Multiplying input with the CQT kernel in freq domain
        CQT_real, CQT_imag = complex_mul(
            (cqt_kernels_real, cqt_kernels_imag), (fourier_real, fourier_imag)
        )

    return torch.stack((CQT_real, CQT_imag), -1)
402
+
403
+
404
def create_lowpass_filter(band_center=0.5, kernelLength=256, transitionBandwidth=0.03):
    """
    Design a low-pass FIR kernel with ``scipy.signal.firwin2``.

    Frequencies are on a scale from 0 to 1 where 0 is 0 and 1 is the Nyquist
    frequency of the signal BEFORE downsampling.

    Parameters
    ----------
    band_center : float
        Centre of the transition band

    kernelLength : int
        Number of taps. The kernel length is specified instead of a tolerance:
        the longer the kernel, the more precisely it matches the requested gains.

    transitionBandwidth : float
        Relative width of the transition band: the passband ends at
        ``band_center / (1 + transitionBandwidth)`` and the stopband starts at
        ``band_center * (1 + transitionBandwidth)``

    Returns
    -------
    numpy.ndarray
        The filter coefficients as ``float32``
    """

    widen = 1 + transitionBandwidth
    passband_edge = band_center / widen  # highest frequency kept untouched
    stopband_edge = band_center * widen  # lowest frequency fully removed

    # Unit gain over [0, passband_edge], zero gain over [stopband_edge, 1].
    key_freqs = [0.0, passband_edge, stopband_edge, 1.0]
    key_gains = [1.0, 1.0, 0.0, 0.0]

    taps = signal.firwin2(kernelLength, key_freqs, key_gains)
    return taps.astype(np.float32)
439
+
440
+
441
def get_early_downsample_params(sr, hop_length, fmax_t, Q, n_octaves, verbose):
    """Work out whether the input can be downsampled before the transform.

    Used in CQT2010 and CQT2010v2.

    Parameters
    ----------
    sr : int or float
        Sampling rate; ``sr // 2`` is passed on as the Nyquist frequency

    hop_length : int
        Hop length before downsampling

    fmax_t : float
        Frequency from which the filter cutoff is derived

    Q : float
        Quality factor, used to widen the cutoff above ``fmax_t``

    n_octaves : int
        Number of octaves, forwarded to ``early_downsample``

    verbose : bool
        If true, print the decision

    Returns
    -------
    tuple
        ``(sr, hop_length, downsample_factor, early_downsample_filter, earlydownsample)``.
        The filter is a tensor with two leading singleton axes when
        downsampling applies and ``None`` otherwise; the last item is the
        matching boolean flag.
    """

    hann_bandwidth = 1.5  # for hann window
    cutoff = fmax_t * (1 + 0.5 * hann_bandwidth / Q)
    sr, hop_length, downsample_factor = early_downsample(
        sr, hop_length, n_octaves, sr // 2, cutoff
    )

    if downsample_factor == 1:
        if verbose:
            print(
                "No early downsampling is required, downsample_factor = ",
                downsample_factor,
            )
        return sr, hop_length, downsample_factor, None, False

    if verbose:
        print("Can do early downsample, factor = ", downsample_factor)

    taps = create_lowpass_filter(
        band_center=1 / downsample_factor,
        kernelLength=256,
        transitionBandwidth=0.03,
    )
    early_downsample_filter = torch.tensor(taps)[None, None, :]

    return sr, hop_length, downsample_factor, early_downsample_filter, True
472
+
473
+
474
def early_downsample(sr, hop_length, n_octaves, nyquist, filter_cutoff):
    """Return new sampling rate and hop length after early downsampling.

    The downsampling factor is ``2 ** early_downsample_count(...)``. The hop
    length is floor-divided by it and the sampling rate is divided by it.

    Returns
    -------
    tuple
        ``(new_sr, new_hop_length, downsample_factor)``
    """
    n_halvings = early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves)
    downsample_factor = 2 ** n_halvings

    hop_length //= downsample_factor  # new hop length
    sr = sr / float(downsample_factor)  # new sampling rate

    return sr, hop_length, downsample_factor
487
+
488
+
489
+ # The following two downsampling count functions are obtained from librosa CQT
490
+ # They are used to determine the number of pre resamplings if the starting and ending frequency
491
+ # are both in low frequency regions.
492
def early_downsample_count(nyquist, filter_cutoff, hop_length, n_octaves):
    """Compute the number of early downsampling operations.

    The result is the smaller of two non-negative limits: one derived from how
    far ``filter_cutoff`` lies below ``0.85 * nyquist``, and one derived from
    ``nextpow2(hop_length)`` and ``n_octaves``.

    Parameters
    ----------
    nyquist : float
        Nyquist frequency of the input

    filter_cutoff : float
        Cutoff frequency of the anti-aliasing filter

    hop_length : int
        Hop length before downsampling

    n_octaves : int
        Number of octaves

    Returns
    -------
    int
        Number of times the input can be downsampled by 2
    """

    bandwidth_limit = max(
        0, int(np.ceil(np.log2(0.85 * nyquist / filter_cutoff)) - 1) - 1
    )

    num_twos = nextpow2(hop_length)
    hop_limit = max(0, num_twos - n_octaves + 1)

    return min(bandwidth_limit, hop_limit)