AudioMlSpecTools 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- audiomlspectools-0.5.0/AudioMlSpecTools/__init__.py +6 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/_common.py +502 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/efficient_features.py +176 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/flexible_features.py +218 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/preproc.py +50 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/stft_features.py +62 -0
- audiomlspectools-0.5.0/AudioMlSpecTools/wav.py +108 -0
- audiomlspectools-0.5.0/AudioMlSpecTools.egg-info/PKG-INFO +74 -0
- audiomlspectools-0.5.0/AudioMlSpecTools.egg-info/SOURCES.txt +15 -0
- audiomlspectools-0.5.0/AudioMlSpecTools.egg-info/dependency_links.txt +1 -0
- audiomlspectools-0.5.0/AudioMlSpecTools.egg-info/requires.txt +2 -0
- audiomlspectools-0.5.0/AudioMlSpecTools.egg-info/top_level.txt +1 -0
- audiomlspectools-0.5.0/LICENSE +7 -0
- audiomlspectools-0.5.0/PKG-INFO +74 -0
- audiomlspectools-0.5.0/README.md +54 -0
- audiomlspectools-0.5.0/pyproject.toml +35 -0
- audiomlspectools-0.5.0/setup.cfg +4 -0
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from ._common import MelType, ScalingType, SpecType, AmplitudeToDB, ExportableSTFT # noqa
|
|
2
|
+
from .efficient_features import ChannelConfig, EfficientFeatureSource # noqa
|
|
3
|
+
from .flexible_features import FeatureSource, FeatureChannel # noqa
|
|
4
|
+
from .preproc import AudioPreprocessor, HighPassFilter # noqa
|
|
5
|
+
from .stft_features import FullRangeStftFeatureSource # noqa
|
|
6
|
+
from .wav import load_wav, set_audio_length, list_audio_files, WavReader # noqa
|
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from enum import Enum
|
|
6
|
+
import json
|
|
7
|
+
import math
|
|
8
|
+
from typing import Optional, Sequence
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
###############################################################################
|
|
12
|
+
# 3PP Imports
|
|
13
|
+
###############################################################################
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
###############################################################################
|
|
19
|
+
# Enumerated Types
|
|
20
|
+
###############################################################################
|
|
21
|
+
class SpecType(Enum):
    """
    Spectrum representations that feature extraction can produce.

    Every member's value is simply its own name as a string. The literature
    is conflicted on which representation, if any, gives the best results.
    """
    # Linear-frequency family
    STFT = "STFT"
    LOG_STFT = "LOG_STFT"
    LFCC = "LFCC"
    # Mel-frequency family
    MEL = "MEL"
    LOG_MEL = "LOG_MEL"
    MFCC = "MFCC"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MelType(Enum):
    """
    Named Hz <-> mel conversion formulas.

    The members select which set of constants the mel conversion helpers in
    this module use; the values are human-readable author names.
    """
    OSHAUGHNESSY = "O'Shaughnessy"
    FANT = "Fant"
    LINDSAY_NORMAN = "Lindsay & Norman"
    SLANEY = "Slaney"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ScalingType(Enum):
    """
    Amplitude scaling modes understood by ``create_scaler``.

    ``LOG`` selects the natural-log scaler; the other members' string values
    are passed through as the ``stype`` of ``AmplitudeToDB``.
    """
    POWER = "power"
    MAGNITUDE = "magnitude"
    LOG = "log"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
###############################################################################
|
|
49
|
+
# New Classes
|
|
50
|
+
###############################################################################
|
|
51
|
+
class ExportableSTFT(torch.nn.Module):
    """
    Power spectrogram computed with explicit one-sided DFT matrices.

    Exportable to Executorch. Created by Claude with supervision.

    Results are very slightly different than Torch but show no noticeable difference in model accuracy or training.

    Args:
        n_fft: DFT size; the output has ``n_fft // 2 + 1`` frequency bins.
        hop_length: Step, in samples, between consecutive frames.
        win_length: Length of the analysis window. Defaults to ``n_fft``.
            A shorter window is zero-padded symmetrically up to ``n_fft``.
        window: Optional window tensor. Defaults to a Hann window of
            ``win_length`` samples.
    """
    def __init__(self,
                 n_fft: int,
                 hop_length: int,
                 *,
                 win_length: Optional[int] = None,
                 window: Optional[torch.Tensor] = None
                 ):
        super().__init__()

        _window_len = win_length or n_fft
        # The default window must be `win_length` samples long: it is padded up
        # to n_fft below, so building it with n_fft samples would overshoot and
        # break the broadcast against the (freq, n_fft) DFT matrices.
        _w = window if window is not None else torch.hann_window(_window_len)
        if _window_len < n_fft:
            # Centre the short window inside the n_fft-sample frame.
            pad_left = (n_fft - _window_len) // 2
            pad_right = n_fft - _window_len - pad_left
            _w = torch.nn.functional.pad(_w, (pad_left, pad_right))

        # Only compute onesided bins: n_fft//2+1
        k = torch.arange(n_fft // 2 + 1).unsqueeze(1)  # (freq, 1)
        n = torch.arange(n_fft).unsqueeze(0)           # (1, n_fft)
        angles = -2 * torch.pi * k * n / n_fft         # (freq, n_fft)

        self.hop_length = hop_length
        self.n_fft = n_fft
        self.pad = n_fft // 2

        # Window is folded into the DFT matrices. Shape: (n_fft//2+1, n_fft)
        self.register_buffer("dft_real", torch.cos(angles) * _w)
        self.register_buffer("dft_imag", torch.sin(angles) * _w)

    def forward(self, x: torch.Tensor):
        """
        Return the power spectrogram of ``x``.

        1-D input ``(T,)`` and 2-D input ``(B, T)`` are expanded to
        ``(B, 1, T)`` so that reflect padding can be applied; the singleton
        channel axis is then squeezed away again. The result has shape
        ``(B, n_fft // 2 + 1, frames)``.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        elif x.dim() == 2:
            x = x.unsqueeze(1)  # (B, 1, T)

        # Pad n_fft // 2 samples on both sides by reflection.
        x = torch.nn.functional.pad(x, (self.pad, self.pad), mode="reflect")
        x = x.squeeze(1)  # (B, T_padded)

        frames = x.unfold(-1, self.n_fft, self.hop_length)  # (B, frames, n_fft)

        real = torch.matmul(frames, self.dft_real.T)  # (B, frames, freq)
        imag = torch.matmul(frames, self.dft_imag.T)

        power = real ** 2 + imag ** 2  # (B, frames, freq)

        # Match torch.stft output layout: (B, freq, frames)
        return power.permute(0, 2, 1)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AudioPreprocessor(torch.nn.Module, ABC):
|
|
107
|
+
@abstractmethod
|
|
108
|
+
def __call__(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
|
|
109
|
+
...
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class BaseFeatureSource(torch.nn.Module, ABC):
    """
    Common base for feature sources that run audio preprocessors first.

    The preprocessor sequence is kept as a plain attribute exactly as passed
    in (it is not wrapped in a ``ModuleList``).
    """

    def __init__(self, preprocessors: Sequence[AudioPreprocessor]):
        super().__init__()
        # Stored as given; subclasses iterate over it in order.
        self.preprocessors = preprocessors
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
###############################################################################
|
|
119
|
+
# New Functions
|
|
120
|
+
###############################################################################
|
|
121
|
+
def load_params(params: str | list | dict):
|
|
122
|
+
if not isinstance(params, str):
|
|
123
|
+
return params
|
|
124
|
+
else:
|
|
125
|
+
with open(params, "r") as infile:
|
|
126
|
+
return json.loads(infile.read())
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def write_params(filename: str, params: list | dict):
|
|
130
|
+
with open(filename, "r") as outfile:
|
|
131
|
+
return outfile.write(json.dumps(params, indent=2))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def power_of_two(n: int):
    """Return True when ``n`` is a non-zero integer with exactly one bit set."""
    if n == 0:
        return False
    # Clearing the lowest set bit leaves zero only for powers of two.
    return (n & (n - 1)) == 0
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def generate_filters(n_fft: int, n_filters: int, sample_rate: int, mel_type: Optional[MelType]):
    """
    Build a triangular filterbank spanning 20 Hz up to ``sample_rate // 2``.

    With a ``mel_type`` the filters are spaced on that mel scale via
    ``melscale_fbanks``; without one they are spaced linearly via
    ``linear_fbanks``. The bank covers ``n_fft // 2 + 1`` frequency bins.
    """
    bin_count = n_fft // 2 + 1
    lowest_hz = 20.0
    highest_hz = float(sample_rate // 2)

    if not mel_type:
        return linear_fbanks(bin_count, lowest_hz, highest_hz, n_filters, sample_rate)

    return melscale_fbanks(mel_type, lowest_hz, highest_hz, bin_count, n_filters, sample_rate)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def create_scaler(scaling_type: ScalingType):
    """
    Return the callable that applies the requested amplitude scaling.

    ``ScalingType.LOG`` yields the ``log_scale`` function; any other member
    yields an ``AmplitudeToDB`` module configured with the member's value as
    ``stype`` and a ``top_db`` of 80.
    """
    if scaling_type != ScalingType.LOG:
        return AmplitudeToDB(stype=scaling_type.value, top_db=80.0)
    return log_scale
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def scale_spec(spec: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize ``spec`` into the range [0, 1].

    The minimum and maximum are taken over the whole tensor. A constant
    input (zero span) maps to all zeros instead of NaN.
    """
    min_in_val = torch.min(spec)
    in_span = torch.max(spec) - min_in_val

    # Guard the division: a flat spectrogram has zero span. torch.where keeps
    # this branch-free, and ones_like keeps the constant on spec's device.
    safe_span = torch.where(in_span > 0, in_span, torch.ones_like(in_span))

    # Output span is 1 and output minimum is 0, so no further scale/offset.
    return (spec - min_in_val) / safe_span
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def determine_spec_type(calc_mels: bool, calc_logs: bool, calc_cepstrum: bool):
    """
    Map the three processing flags onto a ``SpecType`` member.

    ``calc_cepstrum`` takes priority over ``calc_logs``; ``calc_mels`` picks
    between the mel and linear variant of whichever stage was selected.
    """
    if calc_cepstrum:
        return SpecType.MFCC if calc_mels else SpecType.LFCC
    if calc_logs:
        return SpecType.LOG_MEL if calc_mels else SpecType.LOG_STFT
    return SpecType.MEL if calc_mels else SpecType.STFT
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
###############################################################################
|
|
200
|
+
# Imported Legacy Code
|
|
201
|
+
# This code was extracted from v2.8.0 of [torchaudio](https://github.com/pytorch/audio)
|
|
202
|
+
# torchaudio is licensed under the BSD 2-Clause License, reprinted below
|
|
203
|
+
###############################################################################
|
|
204
|
+
#
|
|
205
|
+
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
|
206
|
+
# All rights reserved.
|
|
207
|
+
#
|
|
208
|
+
# Redistribution and use in source and binary forms, with or without
|
|
209
|
+
# modification, are permitted provided that the following conditions are met:
|
|
210
|
+
#
|
|
211
|
+
# * Redistributions of source code must retain the above copyright notice, this
|
|
212
|
+
# list of conditions and the following disclaimer.
|
|
213
|
+
#
|
|
214
|
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
215
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
216
|
+
# and/or other materials provided with the distribution.
|
|
217
|
+
#
|
|
218
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
219
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
220
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
221
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
222
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
223
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
224
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
225
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
226
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
227
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
228
|
+
###############################################################################
|
|
229
|
+
|
|
230
|
+
###############################################################################
|
|
231
|
+
# Generate features
|
|
232
|
+
###############################################################################
|
|
233
|
+
|
|
234
|
+
class AmplitudeToDB(torch.nn.Module):
    """
    Convert a power or magnitude tensor to decibels.

    Args:
        stype: ``"power"`` uses a multiplier of 10; any other value
            (e.g. ``"magnitude"``) uses 20.
        top_db: Optional dynamic-range limit. When given, the output is
            floored at ``max(output) - top_db``, with the maximum taken over
            the whole tensor.

    Raises:
        ValueError: If ``top_db`` is negative.
    """

    def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
        super().__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError("top_db must be positive value")
        self.top_db = top_db
        self.multiplier = 10.0 if stype == "power" else 20.0
        self.amin = 1e-10  # floor applied before the log to avoid log10(0)
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` in decibels relative to ``ref_value``."""
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier

        # Compare against None: top_db == 0.0 passes validation and must
        # still clamp, which a truthiness test would skip.
        if self.top_db is not None:
            floor_db = x_db.max() - self.top_db
            x_db = torch.max(x_db, floor_db)

        return x_db
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def log_scale(waveform: torch.Tensor) -> torch.Tensor:
    """Return the natural log of ``waveform`` after adding a 1e-6 offset."""
    # The small offset keeps log() finite for zero-valued entries.
    return torch.log(waveform + 1e-6)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def hz_to_mel(freq: float, mel_type: MelType) -> float:
    """
    Convert a frequency in Hz to mels using the formula named by ``mel_type``.

    The first three variants are pure log10 formulas with different
    constants. Any other ``mel_type`` (i.e. Slaney) is linear below 1000 Hz
    and logarithmic from 1000 Hz upward.
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 2595.0 * math.log10(1.0 + (freq / 700.0))
    if mel_type == MelType.FANT:
        return (1000 / math.log10(2)) * math.log10(1.0 + (freq / 1000.0))
    if mel_type == MelType.LINDSAY_NORMAN:
        return 2410.0 * math.log10(1.0 + (freq / 625.0))

    # MelType.SLANEY
    knee_hz = 1000.0
    linear_step_hz = 200.0 / 3

    if freq < knee_hz:
        # Linear region.
        return freq / linear_step_hz

    # Log region, continuing from the mel value at the knee.
    knee_mel = knee_hz / linear_step_hz
    log_step = math.log(6.4) / 27.0
    return knee_mel + math.log(freq / knee_hz) / log_step
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def mel_to_hz(mels: torch.Tensor, mel_type: MelType) -> torch.Tensor:
    """
    Convert a tensor of mel values to Hz using the formula named by ``mel_type``.

    This is the inverse of the matching ``hz_to_mel`` variant. Any
    ``mel_type`` other than the three log10 variants is treated as Slaney:
    linear below the knee, exponential at and above it.
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
    if mel_type == MelType.FANT:
        mul = math.log10(2) / 1000
        return 1000.0 * (10 ** (mels * mul) - 1.0)
    if mel_type == MelType.LINDSAY_NORMAN:
        return 625.0 * (10.0 ** (mels / 2410.0) - 1.0)

    # Slaney
    knee_hz = 1000.0
    linear_step_hz = 200.0 / 3
    knee_mel = knee_hz / linear_step_hz
    log_step = math.log(6.4) / 27.0

    # Choose per element between the exponential and linear mapping.
    return torch.where(
        mels >= knee_mel,
        knee_hz * torch.exp(log_step * (mels - knee_mel)),
        linear_step_hz * mels,
    )
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def create_triangular_filterbank(
    all_freqs: torch.Tensor,
    f_pts: torch.Tensor,
) -> torch.Tensor:
    """
    Build overlapping triangular filters (adopted from Librosa).

    ``f_pts`` holds ``n_filter + 2`` edge/centre frequencies. The result has
    one row per entry of ``all_freqs`` and one column per filter.
    """
    # Width of each segment between neighbouring filter points: (n_filter + 1)
    seg_width = f_pts[1:] - f_pts[:-1]
    # Signed distance from every bin frequency to every filter point:
    # (n_freqs, n_filter + 2)
    dist = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)

    rising = (-1.0 * dist[:, :-2]) / seg_width[:-1]  # (n_freqs, n_filter)
    falling = dist[:, 2:] / seg_width[1:]            # (n_freqs, n_filter)

    # A triangle is the lower of the two edges, floored at zero.
    floor = torch.zeros(1)
    return torch.maximum(floor, torch.minimum(rising, falling))
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def melscale_fbanks(
    mel_type: MelType,
    f_min: float,
    f_max: float,
    n_freqs: int,
    n_mels: int,
    sample_rate: int,
):
    """
    Build a triangular filterbank whose filters are evenly spaced in mels.

    Bin frequencies run linearly from 0 to ``sample_rate // 2`` over
    ``n_freqs`` points. A warning is issued if any filter is all zeros.
    """
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Evenly spaced points on the mel axis, mapped back to Hz.
    mel_lo = hz_to_mel(f_min, mel_type=mel_type)
    mel_hi = hz_to_mel(f_max, mel_type=mel_type)
    mel_points = torch.linspace(mel_lo, mel_hi, n_mels + 2)
    edge_freqs = mel_to_hz(mel_points, mel_type=mel_type)

    fb = create_triangular_filterbank(bin_freqs, edge_freqs)

    # A column whose peak is zero never overlapped any frequency bin.
    has_empty_filter = (fb.max(dim=0).values == 0.0).any()
    if has_empty_filter:
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def linear_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_filter: int,
    sample_rate: int,
) -> torch.Tensor:
    """
    Build a triangular filterbank whose filters are evenly spaced in Hz.

    Bin frequencies run linearly from 0 to ``sample_rate // 2`` over
    ``n_freqs`` points; filter points run linearly from ``f_min`` to
    ``f_max`` over ``n_filter + 2`` points.
    """
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    edge_freqs = torch.linspace(f_min, f_max, n_filter + 2)
    return create_triangular_filterbank(bin_freqs, edge_freqs)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def create_dct(n_cepstrum: int, n_filters: int) -> torch.Tensor:
    """
    Build a scaled DCT-II matrix of shape ``(n_filters, n_cepstrum)``.

    See http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    """
    filter_idx = torch.arange(float(n_filters))
    coeff_idx = torch.arange(float(n_cepstrum)).unsqueeze(1)
    # Rows are cepstral coefficients, columns are filters.
    basis = torch.cos(math.pi / float(n_filters) * (filter_idx + 0.5) * coeff_idx)

    # The first row gets an extra 1/sqrt(2) so the scaled basis is orthonormal.
    basis[0] *= 1.0 / math.sqrt(2.0)
    basis *= math.sqrt(2.0 / float(n_filters))

    # Transposed so callers can right-multiply features by it.
    return basis.t()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
###############################################################################
|
|
383
|
+
# Load audio
|
|
384
|
+
###############################################################################
|
|
385
|
+
|
|
386
|
+
def _get_sinc_resample_kernel(
    orig_freq: int,
    new_freq: int,
    gcd: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    device: torch.device = torch.device("cpu"),
    dtype: Optional[torch.dtype] = None,
):
    """
    Build the bank of windowed-sinc filters used for sample-rate conversion.

    Both rates are first divided by ``gcd``. The returned kernel has shape
    ``(new_freq, 1, 2 * width + orig_freq)`` in terms of the reduced rates:
    one filter per output phase.

    Args:
        orig_freq: Source sample rate.
        new_freq: Target sample rate.
        gcd: Common divisor used to reduce both rates.
        lowpass_filter_width: Number of sinc zero crossings kept on each
            side; must be positive.
        rolloff: Fraction of the lower of the two rates used as cutoff.
        device: Device on which the kernel is created.
        dtype: Kernel dtype. When ``None`` the sample positions are built in
            float64 and the final kernel is cast to float32.

    Returns:
        Tuple ``(kernels, width)`` where ``width`` is the one-sided filter
        half-length in reduced input samples.

    Raises:
        ValueError: If ``lowpass_filter_width`` is not positive.
    """
    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd

    if lowpass_filter_width <= 0:
        raise ValueError("Low pass filter width should be positive.")
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= rolloff

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    # x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    # y[j] = x(j / new_freq)
    # or,
    # y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    # = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    # = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx_dtype = dtype if dtype is not None else torch.float64

    # Input sample positions covered by one filter, in units of reduced input periods.
    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq

    # One row per output phase: offset of each input sample from that output sample.
    t = torch.arange(0, -new_freq, -1, dtype=dtype, device=device)[:, None, None] / new_freq + idx
    t *= base_freq
    # Clamping makes the cos^2 window below evaluate to zero beyond the filter width.
    t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)

    # we do not use built in torch windows here as we need to evaluate the window
    # at specific positions, not over a regular grid.
    window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2

    t *= math.pi

    scale = base_freq / orig_freq
    # sinc(t) with the removable singularity at t == 0 filled in as 1.
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale

    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)

    return kernels, width
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def _apply_sinc_resample_kernel(
|
|
454
|
+
waveform: torch.Tensor,
|
|
455
|
+
orig_freq: int,
|
|
456
|
+
new_freq: int,
|
|
457
|
+
gcd: int,
|
|
458
|
+
kernel: torch.Tensor,
|
|
459
|
+
width: int,
|
|
460
|
+
):
|
|
461
|
+
orig_freq = int(orig_freq) // gcd
|
|
462
|
+
new_freq = int(new_freq) // gcd
|
|
463
|
+
|
|
464
|
+
# pack batch
|
|
465
|
+
shape = waveform.size()
|
|
466
|
+
waveform = waveform.view(-1, shape[-1])
|
|
467
|
+
|
|
468
|
+
num_wavs, length = waveform.shape
|
|
469
|
+
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
|
|
470
|
+
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
|
|
471
|
+
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
|
|
472
|
+
target_length = torch.ceil(torch.as_tensor(new_freq * length / orig_freq)).long()
|
|
473
|
+
resampled = resampled[..., :target_length]
|
|
474
|
+
|
|
475
|
+
# unpack batch
|
|
476
|
+
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
|
|
477
|
+
return resampled
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def resample(
    waveform: torch.Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
) -> torch.Tensor:
    """
    Resample ``waveform`` from ``orig_freq`` to ``new_freq`` by sinc interpolation.

    When both rates are equal the input tensor itself is returned. Otherwise
    a kernel is built on the waveform's device and dtype and applied along
    the last dimension.
    """
    if orig_freq != new_freq:
        common = math.gcd(int(orig_freq), int(new_freq))
        sinc_kernel, half_width = _get_sinc_resample_kernel(
            orig_freq,
            new_freq,
            common,
            lowpass_filter_width,
            rolloff,
            waveform.device,
            waveform.dtype,
        )
        return _apply_sinc_resample_kernel(
            waveform, orig_freq, new_freq, common, sinc_kernel, half_width
        )

    return waveform
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
###############################################################################
|
|
8
|
+
# 3PP Imports
|
|
9
|
+
###############################################################################
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
###############################################################################
|
|
13
|
+
# Local Imports
|
|
14
|
+
###############################################################################
|
|
15
|
+
from ._common import MelType, ScalingType, ExportableSTFT, AudioPreprocessor, BaseFeatureSource
|
|
16
|
+
from ._common import power_of_two, create_dct, create_scaler, generate_filters, scale_spec
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
###############################################################################
|
|
20
|
+
# Helper Classes
|
|
21
|
+
###############################################################################
|
|
22
|
+
class ChannelConfig:
    """
    Describes one output channel by three flags and derives its lookup key.

    The key is computed once at construction. ``is_cepstrum`` takes priority
    over ``is_log``; ``is_mel`` selects the mel or linear variant.
    """

    def __init__(self,
                 *,
                 is_mel: bool,
                 is_log: bool,
                 is_cepstrum: bool,
                 ):
        self.is_mel = is_mel
        self.is_log = is_log
        self.is_cepstrum = is_cepstrum

        self.key = self._make_channel_key()

    def get_key(self):
        """Return the key computed at construction time."""
        return self.key

    def _make_channel_key(self):
        """Pick the (linear, mel) name pair for the stage, then the variant."""
        if self.is_cepstrum:
            linear_name, mel_name = "lfcc", "mfcc"
        elif self.is_log:
            linear_name, mel_name = "lin_log_spec", "mel_log_spec"
        else:
            linear_name, mel_name = "lin_freq_spec", "mel_freq_spec"
        return mel_name if self.is_mel else linear_name
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
###############################################################################
|
|
48
|
+
# Classes
|
|
49
|
+
###############################################################################
|
|
50
|
+
class EfficientFeatureSource(BaseFeatureSource):
    """
    Feature source that computes one STFT and derives every channel from it.

    The power spectrogram is projected through a linear and/or mel
    triangular filterbank, optionally converted to decibels, and optionally
    transformed with a DCT to obtain cepstra. Each requested channel is then
    min-max scaled and the channels are stacked along dimension 1.

    Args:
        sample_rate: Sample rate of the incoming audio.
        channels: Channel descriptions; their keys select which intermediate
            result each output channel carries.
        preprocessors: Callables applied to the waveform, in order, before
            the STFT.
        n_fft: FFT size; must be a power of two. Defaults to 1024.
        hop_length: Hop between frames. Defaults to ``n_fft // 4``.
        n_filters: Filters per filterbank. Defaults to ``n_fft // 8``.
        cepstral_coefficients: Number of DCT coefficients; may not exceed
            ``n_filters``. Defaults to ``n_filters``.

    Raises:
        ValueError: If ``n_fft`` is not a power of two or
            ``cepstral_coefficients`` exceeds ``n_filters``.
    """

    def __init__(self,
                 sample_rate: int,
                 channels: list[ChannelConfig],
                 preprocessors: Sequence[AudioPreprocessor] = (),
                 *,
                 # For all spectra
                 n_fft: Optional[int] = None,
                 hop_length: Optional[int] = None,

                 # Shared
                 n_filters: Optional[int] = None,

                 # For all cepstra
                 cepstral_coefficients: Optional[int] = None,
                 ):
        super().__init__(preprocessors)

        self.channels = channels

        self.has_lin_freq = any(not c.is_mel for c in channels)
        self.has_mel_freq = any(c.is_mel for c in channels)

        # A cepstrum flag only counts for the frequency scale of the channel
        # that carries it; mixing e.g. a linear spectrum with an MFCC channel
        # must not switch on LFCC computation.
        self.has_lfcc = any(not c.is_mel and c.is_cepstrum for c in channels)
        self.has_mfcc = any(c.is_mel and c.is_cepstrum for c in channels)
        self.has_cepstrum = self.has_lfcc or self.has_mfcc

        self.has_lin_log = any(not c.is_mel and c.is_log for c in channels)
        self.has_mel_log = any(c.is_mel and c.is_log for c in channels)
        self.has_log_scale = any(c.is_log for c in channels) or self.has_cepstrum

        # Required configs
        self.sample_rate = sample_rate

        # Universal configs
        if n_fft is not None and not power_of_two(n_fft):
            raise ValueError("n_fft must be a power of 2")
        self.n_fft = n_fft or 1024

        if hop_length is not None and hop_length > (self.n_fft // 2):
            warnings.warn(f"hop_length should be set to no more than 1/2 the FFT window size, or {self.n_fft // 2} samples for n_fft = {self.n_fft} (currently {hop_length})")
        self.hop_length = hop_length or self.n_fft // 4

        # Shared configs (mels, MFCC, LFCC)
        if n_filters is not None and n_filters > (self.n_fft // 8):
            warnings.warn(f"n_filters should be set to no more than 1/8 the FFT window size, or {self.n_fft // 8} filters for n_fft = {self.n_fft} (currently {n_filters})")
        self.n_filters = n_filters or self.n_fft // 8

        # Cepstral configs
        if cepstral_coefficients is not None and cepstral_coefficients > self.n_filters:
            raise ValueError(f"cepstral_coefficients must be no greater than n_filters (currently {cepstral_coefficients}/{self.n_filters})")
        self.cepstral_coefficients = cepstral_coefficients or self.n_filters

        ###################
        # Spec gen code
        ###################

        # Basic spectrogram
        self._stft = ExportableSTFT(self.n_fft, self.hop_length)

        if self.has_lin_freq:
            self.register_buffer("lin_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, None))
        else:
            self.lin_filt = None

        if self.has_mel_freq:
            self.register_buffer("mel_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, MelType.OSHAUGHNESSY))
        else:
            self.mel_filt = None

        # DB scaling, if necessary
        self.amplitude_to_DB = create_scaler(ScalingType.POWER) if self.has_log_scale else None

        # Cepstrum, if necessary
        if self.has_cepstrum:
            self.register_buffer("dct", create_dct(self.cepstral_coefficients, self.n_filters))
        else:
            self.dct = None

    @staticmethod
    def _project(spec, matrix):
        """Right-multiply the frequency axis of ``spec`` by ``matrix``; None if either is missing."""
        if spec is None or matrix is None:
            return None
        return torch.matmul(spec.transpose(-1, -2), matrix).transpose(-1, -2)

    def _to_db(self, spec):
        """Apply the dB scaler to ``spec``; None if the input or the scaler is missing."""
        if spec is None or self.amplitude_to_DB is None:
            return None
        return self.amplitude_to_DB(spec)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the preprocessors and return the stacked, scaled feature channels."""
        for preproc in self.preprocessors:
            wav = preproc(wav)

        spec = self._stft(wav)

        lin_freq_spec = self._project(spec, self.lin_filt)
        mel_freq_spec = self._project(spec, self.mel_filt)

        # The log spectrum is needed both for log channels and as the input
        # to the DCT, so a cepstral channel implies it even without is_log.
        lin_log_spec = self._to_db(lin_freq_spec) if (self.has_lin_log or self.has_lfcc) else None
        mel_log_spec = self._to_db(mel_freq_spec) if (self.has_mel_log or self.has_mfcc) else None

        lfcc = self._project(lin_log_spec, self.dct) if self.has_lfcc else None
        mfcc = self._project(mel_log_spec, self.dct) if self.has_mfcc else None

        return self._collate(
            lin_freq_spec=lin_freq_spec,
            mel_freq_spec=mel_freq_spec,
            lin_log_spec=lin_log_spec,
            mel_log_spec=mel_log_spec,
            lfcc=lfcc,
            mfcc=mfcc,
        )

    def _collate(self, **kwargs):
        """Scale the result selected by each channel's key and stack them on dim 1."""
        specs = [scale_spec(kwargs[c.key]) for c in self.channels]
        return torch.stack(specs, dim=1)
|