AudioMlSpecTools 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ from ._common import MelType, ScalingType, SpecType, AmplitudeToDB, ExportableSTFT # noqa
2
+ from .efficient_features import ChannelConfig, EfficientFeatureSource # noqa
3
+ from .flexible_features import FeatureSource, FeatureChannel # noqa
4
+ from .preproc import AudioPreprocessor, HighPassFilter # noqa
5
+ from .stft_features import FullRangeStftFeatureSource # noqa
6
+ from .wav import load_wav, set_audio_length, list_audio_files, WavReader # noqa
@@ -0,0 +1,502 @@
1
+ ###############################################################################
2
+ # Global Imports
3
+ ###############################################################################
4
+ from abc import ABC, abstractmethod
5
+ from enum import Enum
6
+ import json
7
+ import math
8
+ from typing import Optional, Sequence
9
+ import warnings
10
+
11
+ ###############################################################################
12
+ # 3PP Imports
13
+ ###############################################################################
14
+ import numpy as np
15
+ import torch
16
+
17
+
18
+ ###############################################################################
19
+ # Enumerated Types
20
+ ###############################################################################
21
class SpecType(Enum):
    """
    Spectrum representations that feature extraction can produce.

    Published results disagree on which representation, if any, works best.
    """

    # Linear-frequency representations
    STFT = "STFT"
    LOG_STFT = "LOG_STFT"
    LFCC = "LFCC"

    # Mel-frequency representations
    MEL = "MEL"
    LOG_MEL = "LOG_MEL"
    MFCC = "MFCC"
33
+
34
+
35
class MelType(Enum):
    """Named Hz <-> mel conversion formulas selectable for mel filterbanks."""

    OSHAUGHNESSY = "O'Shaughnessy"
    FANT = "Fant"
    LINDSAY_NORMAN = "Lindsay & Norman"
    SLANEY = "Slaney"
40
+
41
+
42
class ScalingType(Enum):
    """Ways a spectrum can be scaled; the values are lowercase string tags."""

    POWER = "power"
    MAGNITUDE = "magnitude"
    LOG = "log"
46
+
47
+
48
+ ###############################################################################
49
+ # New Classes
50
+ ###############################################################################
51
class ExportableSTFT(torch.nn.Module):
    """
    Exportable to Executorch. Created by Claude with supervision.

    Computes a one-sided *power* spectrogram as an explicit matrix product
    against precomputed, windowed DFT basis buffers instead of calling
    ``torch.stft``.

    Results are very slightly different than Torch but show no noticeable difference in model accuracy or training.

    Args:
        n_fft: Frame length and DFT size.
        hop_length: Step, in samples, between consecutive frames.
        win_length: Window length; falls back to ``n_fft`` when omitted (or 0).
            A window shorter than ``n_fft`` is zero-padded on both sides.
        window: Window tensor. When omitted, a Hann window is generated.
    """
    def __init__(self,
                 n_fft: int,
                 hop_length: int,
                 *,
                 win_length: Optional[int] = None,
                 window: Optional[torch.Tensor] = None
                 ):
        super().__init__()

        _window_len = win_length or n_fft

        if window is not None:
            _w = window
        else:
            # The generated window must be `_window_len` long so that the
            # centring pad below brings it to exactly n_fft samples. It is
            # capped at n_fft so an oversized win_length keeps behaving as it
            # did before (a plain n_fft-long Hann window).
            _w = torch.hann_window(min(_window_len, n_fft))

        if _window_len < n_fft:
            # Centre the short window inside the n_fft-long frame.
            pad_left = (n_fft - _window_len) // 2
            pad_right = n_fft - _window_len - pad_left
            _w = torch.nn.functional.pad(_w, (pad_left, pad_right))

        # Only compute onesided bins: n_fft//2+1
        k = torch.arange(n_fft // 2 + 1).unsqueeze(1)  # (freq, 1)
        n = torch.arange(n_fft).unsqueeze(0)           # (1, n_fft)
        angles = -2 * torch.pi * k * n / n_fft         # (freq, n_fft)

        self.hop_length = hop_length
        self.n_fft = n_fft
        self.pad = n_fft // 2

        # Windowed DFT basis, shape: (n_fft//2+1, n_fft)
        self.register_buffer("dft_real", torch.cos(angles) * _w)
        self.register_buffer("dft_imag", torch.sin(angles) * _w)

    def forward(self, x: torch.Tensor):
        """
        Return the power spectrogram of ``x``.

        1-D input is treated as a single waveform and 2-D input as a batch of
        waveforms; the result is laid out as (B, freq, frames).
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        elif x.dim() == 2:
            x = x.unsqueeze(1)  # (B, 1, T)

        # Reflect-pad by n_fft // 2 on each side so frames are centred.
        x = torch.nn.functional.pad(x, (self.pad, self.pad), mode="reflect")
        x = x.squeeze(1)  # (B, T_padded)

        frames = x.unfold(-1, self.n_fft, self.hop_length)  # (B, frames, n_fft)

        real = torch.matmul(frames, self.dft_real.T)  # (B, frames, freq)
        imag = torch.matmul(frames, self.dft_imag.T)

        power = real ** 2 + imag ** 2  # (B, frames, freq)

        # Match torch.stft output layout: (B, freq, frames)
        return power.permute(0, 2, 1)
104
+
105
+
106
+ class AudioPreprocessor(torch.nn.Module, ABC):
107
+ @abstractmethod
108
+ def __call__(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
109
+ ...
110
+
111
+
112
class BaseFeatureSource(torch.nn.Module, ABC):
    """Common base for feature sources; stores the preprocessors it is given."""

    def __init__(self, preprocessors: Sequence[AudioPreprocessor]):
        super().__init__()
        # Kept as a plain attribute, exactly as passed in.
        self.preprocessors = preprocessors
116
+
117
+
118
+ ###############################################################################
119
+ # New Functions
120
+ ###############################################################################
121
+ def load_params(params: str | list | dict):
122
+ if not isinstance(params, str):
123
+ return params
124
+ else:
125
+ with open(params, "r") as infile:
126
+ return json.loads(infile.read())
127
+
128
+
129
+ def write_params(filename: str, params: list | dict):
130
+ with open(filename, "r") as outfile:
131
+ return outfile.write(json.dumps(params, indent=2))
132
+
133
+
134
def power_of_two(n: int):
    """Return True when ``n`` is a non-zero integer with exactly one bit set."""
    # Clearing the lowest set bit leaves zero only if there was a single bit.
    lowest_bit_cleared = n & (n - 1)
    return lowest_bit_cleared == 0 and n != 0
136
+
137
+
138
def generate_filters(n_fft: int, n_filters: int, sample_rate: int, mel_type: Optional[MelType]):
    """
    Build a triangular filterbank for spectra with ``n_fft // 2 + 1`` bins.

    The filters span 20 Hz up to ``sample_rate // 2``. A truthy ``mel_type``
    selects mel-spaced filters using that formula; otherwise the filters are
    linearly spaced.
    """
    n_freqs = n_fft // 2 + 1
    f_min, f_max = 20.0, float(sample_rate // 2)

    if not mel_type:
        return linear_fbanks(n_freqs, f_min, f_max, n_filters, sample_rate)

    return melscale_fbanks(mel_type, f_min, f_max, n_freqs, n_filters, sample_rate)
160
+
161
+
162
def create_scaler(scaling_type: ScalingType):
    """
    Return the callable that applies ``scaling_type``.

    ``ScalingType.LOG`` maps to the ``log_scale`` function; every other type
    maps to an ``AmplitudeToDB`` module built with that type's value and an
    80 dB ``top_db``.
    """
    if scaling_type != ScalingType.LOG:
        return AmplitudeToDB(stype=scaling_type.value, top_db=80.0)
    return log_scale
167
+
168
+
169
def scale_spec(spec: torch.Tensor) -> torch.Tensor:
    """
    Min-max scale ``spec`` over all of its elements into the range [0, 1].

    A constant input (zero span) maps to all zeros instead of dividing by
    zero and producing NaN.
    """
    min_in_val = torch.min(spec)
    max_in_val = torch.max(spec)
    in_span = max_in_val - min_in_val

    # Replace a zero span by one. `spec - min_in_val` is zero everywhere in
    # that case, so the output is zeros. torch.where is used rather than a
    # Python `if` so no branch depends on tensor data.
    safe_span = torch.where(in_span == 0, torch.ones_like(in_span), in_span)

    # Output bounds live on the input's device so non-CPU inputs work.
    min_out_val = torch.zeros(1, device=spec.device)
    max_out_val = torch.ones(1, device=spec.device)
    out_span = max_out_val - min_out_val

    scale_factor = out_span / safe_span
    return (spec - min_in_val) * scale_factor
180
+
181
+
182
def determine_spec_type(calc_mels: bool, calc_logs: bool, calc_cepstrum: bool):
    """
    Map the three processing flags onto a ``SpecType``.

    ``calc_cepstrum`` takes precedence over ``calc_logs``; ``calc_mels``
    picks the mel variant over the linear one at each level.
    """
    if calc_cepstrum:
        return SpecType.MFCC if calc_mels else SpecType.LFCC
    if calc_logs:
        return SpecType.LOG_MEL if calc_mels else SpecType.LOG_STFT
    return SpecType.MEL if calc_mels else SpecType.STFT
197
+
198
+
199
+ ###############################################################################
200
+ # Imported Legacy Code
201
+ # This code was extracted from v2.8.0 of [torchaudio](https://github.com/pytorch/audio)
202
+ # torchaudio is licensed under the BSD 2-Clause License, reprinted below
203
+ ###############################################################################
204
+ #
205
+ # Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
206
+ # All rights reserved.
207
+ #
208
+ # Redistribution and use in source and binary forms, with or without
209
+ # modification, are permitted provided that the following conditions are met:
210
+ #
211
+ # * Redistributions of source code must retain the above copyright notice, this
212
+ # list of conditions and the following disclaimer.
213
+ #
214
+ # * Redistributions in binary form must reproduce the above copyright notice,
215
+ # this list of conditions and the following disclaimer in the documentation
216
+ # and/or other materials provided with the distribution.
217
+ #
218
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
219
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
220
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
221
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
222
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
223
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
224
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
225
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
226
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
227
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
228
+ ###############################################################################
229
+
230
+ ###############################################################################
231
+ # Generate features
232
+ ###############################################################################
233
+
234
class AmplitudeToDB(torch.nn.Module):
    """
    Convert a power or magnitude tensor to decibels.

    Args:
        stype: ``"power"`` uses a multiplier of 10; any other value uses 20.
        top_db: Optional dynamic-range limit. When given, every output value
            is raised to at least ``output.max() - top_db``. ``None`` disables
            the limit; ``0.0`` is a valid limit.

    Raises:
        ValueError: If ``top_db`` is negative.
    """

    def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
        super().__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError("top_db must be positive value")
        self.top_db = top_db
        self.multiplier = 10.0 if stype == "power" else 20.0
        self.amin = 1e-10        # floor applied before the logarithm
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` in decibels, limited by ``top_db`` when it is set."""
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier

        # Compare against None: a plain truthiness test would silently skip
        # the limit for the accepted value top_db == 0.0.
        if self.top_db is not None:
            max_ref = (x_db.max() - self.top_db)
            x_db = torch.max(x_db, max_ref)

        return x_db
255
+
256
+
257
def log_scale(waveform: torch.Tensor) -> torch.Tensor:
    """Return the natural log of ``waveform`` after adding a 1e-6 offset."""
    shifted = waveform + 1e-6
    return torch.log(shifted)
260
+
261
+
262
def hz_to_mel(freq: float, mel_type: MelType) -> float:
    """
    Convert a frequency in Hz to mels using the formula named by ``mel_type``.

    Any ``mel_type`` other than O'Shaughnessy, Fant or Lindsay & Norman is
    handled with the Slaney formula (linear below 1000 Hz, logarithmic above).
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 2595.0 * math.log10(1.0 + (freq / 700.0))
    if mel_type == MelType.FANT:
        return (1000 / math.log10(2)) * math.log10(1.0 + (freq / 1000.0))
    if mel_type == MelType.LINDSAY_NORMAN:
        return 2410.0 * math.log10(1.0 + (freq / 625.0))

    # MelType.SLANEY
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    if freq < min_log_hz:
        # Linear region
        return freq / f_sp

    # Logarithmic region
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0
    return min_log_mel + math.log(freq / min_log_hz) / logstep
279
+
280
+
281
def mel_to_hz(mels: torch.Tensor, mel_type: MelType) -> torch.Tensor:
    """
    Convert a tensor of mel values to Hz using the formula named by ``mel_type``.

    Any ``mel_type`` other than O'Shaughnessy, Fant or Lindsay & Norman is
    handled with the Slaney formula.
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
    if mel_type == MelType.FANT:
        mul = math.log10(2) / 1000
        return 1000.0 * (10 ** (mels * mul) - 1.0)
    if mel_type == MelType.LINDSAY_NORMAN:
        return 625.0 * (10.0 ** (mels / 2410.0) - 1.0)

    # Slaney: start from the linear mapping, then overwrite the log region.
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0

    freqs = f_sp * mels
    in_log_region = mels >= min_log_mel
    freqs[in_log_region] = min_log_hz * torch.exp(logstep * (mels[in_log_region] - min_log_mel))
    return freqs
302
+
303
+
304
def create_triangular_filterbank(
    all_freqs: torch.Tensor,
    f_pts: torch.Tensor,
) -> torch.Tensor:
    """
    Build overlapping triangular filters (adopted from Librosa).

    Args:
        all_freqs: Frequency of each spectrum bin, shape (n_freqs,).
        f_pts: Filter edge/centre points, shape (n_filter + 2,).

    Returns:
        Filterbank of shape (n_freqs, n_filter).
    """
    # Spacing between neighbouring filter points: (n_filter + 1)
    spacing = f_pts[1:] - f_pts[:-1]
    # Offset of every filter point from every frequency bin: (n_freqs, n_filter + 2)
    offsets = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)

    rising = (-1.0 * offsets[:, :-2]) / spacing[:-1]   # (n_freqs, n_filter)
    falling = offsets[:, 2:] / spacing[1:]             # (n_freqs, n_filter)

    # Each triangle is the lower of its two edges, floored at zero.
    triangles = torch.min(rising, falling)
    return torch.maximum(torch.zeros(1), triangles)
319
+
320
+
321
def melscale_fbanks(
    mel_type: MelType,
    f_min: float,
    f_max: float,
    n_freqs: int,
    n_mels: int,
    sample_rate: int,
):
    """
    Create a mel-spaced triangular filterbank of shape (n_freqs, n_mels).

    Filter points are spaced evenly in mels between ``f_min`` and ``f_max``
    using the ``mel_type`` formula. A warning is issued when any filter
    comes out entirely zero.
    """
    # Frequency of each spectrum bin
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Evenly spaced points on the mel axis, mapped back to Hz
    mel_lo = hz_to_mel(f_min, mel_type=mel_type)
    mel_hi = hz_to_mel(f_max, mel_type=mel_type)
    mel_points = torch.linspace(mel_lo, mel_hi, n_mels + 2)
    hz_points = mel_to_hz(mel_points, mel_type=mel_type)

    fb = create_triangular_filterbank(bin_freqs, hz_points)

    has_empty_filter = (fb.max(dim=0).values == 0.0).any()
    if has_empty_filter:
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
350
+
351
+
352
def linear_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_filter: int,
    sample_rate: int,
) -> torch.Tensor:
    """
    Create a linearly spaced triangular filterbank of shape (n_freqs, n_filter).

    Filter points are spaced evenly in Hz between ``f_min`` and ``f_max``.
    """
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)     # frequency of each bin
    filter_points = torch.linspace(f_min, f_max, n_filter + 2)   # edges and centres
    return create_triangular_filterbank(bin_freqs, filter_points)
369
+
370
+
371
def create_dct(n_cepstrum: int, n_filters: int) -> torch.Tensor:
    """
    Build a scaled DCT-II matrix of shape (n_filters, n_cepstrum).

    See http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    """
    filter_idx = torch.arange(float(n_filters))
    coeff_idx = torch.arange(float(n_cepstrum)).unsqueeze(1)
    basis = torch.cos(math.pi / float(n_filters) * (filter_idx + 0.5) * coeff_idx)  # (n_cepstrum, n_filters)

    # Extra 1/sqrt(2) on the first row, then the common sqrt(2/N) factor.
    basis[0] *= 1.0 / math.sqrt(2.0)
    basis *= math.sqrt(2.0 / float(n_filters))
    return basis.t()
380
+
381
+
382
+ ###############################################################################
383
+ # Load audio
384
+ ###############################################################################
385
+
386
+ def _get_sinc_resample_kernel(
387
+ orig_freq: int,
388
+ new_freq: int,
389
+ gcd: int,
390
+ lowpass_filter_width: int = 6,
391
+ rolloff: float = 0.99,
392
+ device: torch.device = torch.device("cpu"),
393
+ dtype: Optional[torch.dtype] = None,
394
+ ):
395
+ orig_freq = int(orig_freq) // gcd
396
+ new_freq = int(new_freq) // gcd
397
+
398
+ if lowpass_filter_width <= 0:
399
+ raise ValueError("Low pass filter width should be positive.")
400
+ base_freq = min(orig_freq, new_freq)
401
+ # This will perform antialiasing filtering by removing the highest frequencies.
402
+ # At first I thought I only needed this when downsampling, but when upsampling
403
+ # you will get edge artifacts without this, as the edge is equivalent to zero padding,
404
+ # which will add high freq artifacts.
405
+ base_freq *= rolloff
406
+
407
+ # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
408
+ # using the sinc interpolation formula:
409
+ # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
410
+ # We can then sample the function x(t) with a different sample rate:
411
+ # y[j] = x(j / new_freq)
412
+ # or,
413
+ # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
414
+
415
+ # We see here that y[j] is the convolution of x[i] with a specific filter, for which
416
+ # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
417
+ # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
418
+ # Indeed:
419
+ # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
420
+ # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
421
+ # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
422
+ # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
423
+ # This will explain the F.conv1d after, with a stride of orig_freq.
424
+ width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
425
+ # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
426
+ # they will have a lot of almost zero values to the left or to the right...
427
+ # There is probably a way to evaluate those filters more efficiently, but this is kept for
428
+ # future work.
429
+ idx_dtype = dtype if dtype is not None else torch.float64
430
+
431
+ idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq
432
+
433
+ t = torch.arange(0, -new_freq, -1, dtype=dtype, device=device)[:, None, None] / new_freq + idx
434
+ t *= base_freq
435
+ t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)
436
+
437
+ # we do not use built in torch windows here as we need to evaluate the window
438
+ # at specific positions, not over a regular grid.
439
+ window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2
440
+
441
+ t *= math.pi
442
+
443
+ scale = base_freq / orig_freq
444
+ kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
445
+ kernels *= window * scale
446
+
447
+ if dtype is None:
448
+ kernels = kernels.to(dtype=torch.float32)
449
+
450
+ return kernels, width
451
+
452
+
453
+ def _apply_sinc_resample_kernel(
454
+ waveform: torch.Tensor,
455
+ orig_freq: int,
456
+ new_freq: int,
457
+ gcd: int,
458
+ kernel: torch.Tensor,
459
+ width: int,
460
+ ):
461
+ orig_freq = int(orig_freq) // gcd
462
+ new_freq = int(new_freq) // gcd
463
+
464
+ # pack batch
465
+ shape = waveform.size()
466
+ waveform = waveform.view(-1, shape[-1])
467
+
468
+ num_wavs, length = waveform.shape
469
+ waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
470
+ resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
471
+ resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
472
+ target_length = torch.ceil(torch.as_tensor(new_freq * length / orig_freq)).long()
473
+ resampled = resampled[..., :target_length]
474
+
475
+ # unpack batch
476
+ resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
477
+ return resampled
478
+
479
+
480
def resample(
    waveform: torch.Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
) -> torch.Tensor:
    """
    Resample ``waveform`` from ``orig_freq`` to ``new_freq`` by sinc interpolation.

    When both rates are equal the input tensor itself is returned. The kernel
    is built on the waveform's device and in its dtype.
    """
    if orig_freq == new_freq:
        return waveform

    # Reduce both rates by their GCD so the filter bank is as small as possible.
    common = math.gcd(int(orig_freq), int(new_freq))

    kernel, width = _get_sinc_resample_kernel(
        orig_freq,
        new_freq,
        common,
        lowpass_filter_width,
        rolloff,
        waveform.device,
        waveform.dtype,
    )
    return _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, common, kernel, width)
@@ -0,0 +1,176 @@
1
+ ###############################################################################
2
+ # Global Imports
3
+ ###############################################################################
4
+ from typing import Optional, Sequence
5
+ import warnings
6
+
7
+ ###############################################################################
8
+ # 3PP Imports
9
+ ###############################################################################
10
+ import torch
11
+
12
+ ###############################################################################
13
+ # Local Imports
14
+ ###############################################################################
15
+ from ._common import MelType, ScalingType, ExportableSTFT, AudioPreprocessor, BaseFeatureSource
16
+ from ._common import power_of_two, create_dct, create_scaler, generate_filters, scale_spec
17
+
18
+
19
+ ###############################################################################
20
+ # Helper Classes
21
+ ###############################################################################
22
class ChannelConfig:
    """
    Describes one output channel by three flags and derives its lookup key.

    The key is one of "mfcc", "lfcc", "mel_log_spec", "lin_log_spec",
    "mel_freq_spec" or "lin_freq_spec". ``is_cepstrum`` takes precedence over
    ``is_log`` when the key is chosen.
    """

    def __init__(self,
                 *,
                 is_mel: bool,
                 is_log: bool,
                 is_cepstrum: bool,
                 ):
        self.is_mel = is_mel
        self.is_log = is_log
        self.is_cepstrum = is_cepstrum

        # Computed once at construction time.
        self.key = self._make_channel_key()

    def get_key(self):
        """Return the key computed at construction time."""
        return self.key

    def _make_channel_key(self):
        """Derive the channel key from the three flags."""
        scale = "mel" if self.is_mel else "lin"
        if self.is_cepstrum:
            return {"mel": "mfcc", "lin": "lfcc"}[scale]
        kind = "log_spec" if self.is_log else "freq_spec"
        return f"{scale}_{kind}"
45
+
46
+
47
+ ###############################################################################
48
+ # Classes
49
+ ###############################################################################
50
class EfficientFeatureSource(BaseFeatureSource):
    """
    Feature source that derives every requested channel from one shared STFT.

    Each ``ChannelConfig`` in ``channels`` selects one representation by key
    (linear or mel filtered spectrum, its dB-scaled version, or its cepstrum).
    Intermediate results are computed once and reused; ``forward`` scales each
    selected channel with ``scale_spec`` and stacks them along dim 1 in the
    order given by ``channels``.

    Args:
        sample_rate: Sample rate of the incoming audio.
        channels: Channels to produce, in output order.
        preprocessors: Callables applied to the waveform, in order, before the
            STFT. Defaults to none.
        n_fft: FFT size; must be a power of two. Defaults to 1024.
        hop_length: Hop between frames. Defaults to ``n_fft // 4``.
        n_filters: Number of filterbank filters. Defaults to ``n_fft // 8``.
        cepstral_coefficients: Number of cepstral coefficients. Defaults to
            ``n_filters``.

    Raises:
        ValueError: If ``n_fft`` is not a power of two, or if
            ``cepstral_coefficients`` exceeds the number of filters.
    """

    def __init__(self,
                 sample_rate: int,
                 channels: list[ChannelConfig],
                 preprocessors: Optional[Sequence[AudioPreprocessor]] = None,
                 *,
                 # For all spectra
                 n_fft: Optional[int] = None,
                 hop_length: Optional[int] = None,

                 # Shared
                 n_filters: Optional[int] = None,

                 # For all cepstra
                 cepstral_coefficients: Optional[int] = None,
                 ):
        # A fresh list per instance; a `[]` default would be shared by all.
        if preprocessors is None:
            preprocessors = []
        super().__init__(preprocessors)

        self.channels = channels

        self.has_lin_freq = any(not c.is_mel for c in channels)
        self.has_mel_freq = any(c.is_mel for c in channels)

        # Cepstrum flags are evaluated per channel, so a mel cepstrum next to
        # a plain linear channel does not count as an LFCC request.
        self.has_lfcc = any(c.is_cepstrum and not c.is_mel for c in channels)
        self.has_mfcc = any(c.is_cepstrum and c.is_mel for c in channels)
        self.has_cepstrum = self.has_lfcc or self.has_mfcc

        self.has_lin_log = any(not c.is_mel and c.is_log for c in channels)
        self.has_mel_log = any(c.is_mel and c.is_log for c in channels)
        self.has_log_scale = any(c.is_log for c in channels) or self.has_cepstrum

        # Required configs
        self.sample_rate = sample_rate

        # Universal configs
        if n_fft is not None and not power_of_two(n_fft):
            raise ValueError("n_fft must be a power of 2")
        self.n_fft = n_fft or 1024

        if hop_length is not None and hop_length > (self.n_fft // 2):
            warnings.warn(
                "hop_length should be set to no more than 1/2 the FFT window size, "
                f"or {self.n_fft // 2} samples for n_fft = {self.n_fft} (currently {hop_length})"
            )
        self.hop_length = hop_length or self.n_fft // 4

        # Shared configs (mels, MFCC, LFCC)
        if n_filters is not None and n_filters > (self.n_fft // 8):
            warnings.warn(
                "n_filters should be set to no more than 1/8 the FFT window size, "
                f"or {self.n_fft // 8} filters for n_fft = {self.n_fft} (currently {n_filters})"
            )
        self.n_filters = n_filters or self.n_fft // 8

        # Cepstral configs
        if cepstral_coefficients is not None and cepstral_coefficients > self.n_filters:
            raise ValueError(
                "cepstral_coefficients must be no greater than n_filters "
                f"(currently {cepstral_coefficients}/{self.n_filters})"
            )
        self.cepstral_coefficients = cepstral_coefficients or self.n_filters

        ###################
        # Spec gen code
        ###################

        # Basic spectrogram
        self._stft = ExportableSTFT(self.n_fft, self.hop_length)

        if self.has_lin_freq:
            self.register_buffer("lin_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, None))
        else:
            self.lin_filt = None

        if self.has_mel_freq:
            self.register_buffer("mel_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, MelType.OSHAUGHNESSY))
        else:
            self.mel_filt = None

        # DB scaling, if necessary
        self.amplitude_to_DB = create_scaler(ScalingType.POWER) if self.has_log_scale else None

        # Cepstrum, if necessary
        if self.has_cepstrum:
            self.register_buffer("dct", create_dct(self.cepstral_coefficients, self.n_filters))
        else:
            self.dct = None

    @staticmethod
    def _project(spec: torch.Tensor, matrix: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Multiply ``matrix`` along the second-to-last axis of ``spec``; None when there is no matrix."""
        if matrix is None:
            return None
        return torch.matmul(spec.transpose(-1, -2), matrix).transpose(-1, -2)

    def _to_db(self, spec: Optional[torch.Tensor], needed: bool) -> Optional[torch.Tensor]:
        """Return the dB-scaled ``spec`` when it is needed and available, else None."""
        if not needed or spec is None or self.amplitude_to_DB is None:
            return None
        return self.amplitude_to_DB(spec)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the preprocessors, compute the needed spectra and stack the selected channels."""
        for preproc in self.preprocessors:
            wav = preproc(wav)

        spec = self._stft(wav)

        lin_freq_spec = self._project(spec, self.lin_filt)
        mel_freq_spec = self._project(spec, self.mel_filt)

        # A cepstrum is computed from the log spectrum, so the log spectrum is
        # needed even when no channel asks for it directly.
        lin_log_spec = self._to_db(lin_freq_spec, self.has_lin_log or self.has_lfcc)
        mel_log_spec = self._to_db(mel_freq_spec, self.has_mel_log or self.has_mfcc)

        lfcc = self._project(lin_log_spec, self.dct) if (self.has_lfcc and lin_log_spec is not None) else None
        mfcc = self._project(mel_log_spec, self.dct) if (self.has_mfcc and mel_log_spec is not None) else None

        return self._collate(
            lin_freq_spec=lin_freq_spec,
            mel_freq_spec=mel_freq_spec,
            lin_log_spec=lin_log_spec,
            mel_log_spec=mel_log_spec,
            lfcc=lfcc,
            mfcc=mfcc,
        )

    def _collate(self, **kwargs):
        """Scale the spectrum selected by each channel's key and stack them along dim 1."""
        specs = [scale_spec(kwargs[c.key]) for c in self.channels]
        return torch.stack(specs, dim=1)