AudioMlSpecTools 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- AudioMlSpecTools/__init__.py +6 -0
- AudioMlSpecTools/_common.py +502 -0
- AudioMlSpecTools/efficient_features.py +176 -0
- AudioMlSpecTools/flexible_features.py +218 -0
- AudioMlSpecTools/preproc.py +50 -0
- AudioMlSpecTools/stft_features.py +62 -0
- AudioMlSpecTools/wav.py +108 -0
- audiomlspectools-0.5.0.dist-info/METADATA +74 -0
- audiomlspectools-0.5.0.dist-info/RECORD +12 -0
- audiomlspectools-0.5.0.dist-info/WHEEL +5 -0
- audiomlspectools-0.5.0.dist-info/licenses/LICENSE +7 -0
- audiomlspectools-0.5.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from ._common import MelType, ScalingType, SpecType, AmplitudeToDB, ExportableSTFT # noqa
|
|
2
|
+
from .efficient_features import ChannelConfig, EfficientFeatureSource # noqa
|
|
3
|
+
from .flexible_features import FeatureSource, FeatureChannel # noqa
|
|
4
|
+
from .preproc import AudioPreprocessor, HighPassFilter # noqa
|
|
5
|
+
from .stft_features import FullRangeStftFeatureSource # noqa
|
|
6
|
+
from .wav import load_wav, set_audio_length, list_audio_files, WavReader # noqa
|
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from enum import Enum
|
|
6
|
+
import json
|
|
7
|
+
import math
|
|
8
|
+
from typing import Optional, Sequence
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
###############################################################################
|
|
12
|
+
# 3PP Imports
|
|
13
|
+
###############################################################################
|
|
14
|
+
import numpy as np
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
###############################################################################
|
|
19
|
+
# Enumerated Types
|
|
20
|
+
###############################################################################
|
|
21
|
+
class SpecType(Enum):
    '''
    Possible spectrum types for feature extraction.

    The literature is conflicted on which, if any, provides the best results.

    Each member's value is its own name as a string. `determine_spec_type`
    maps a (mel, log, cepstrum) flag triple onto one of these members.
    '''
    STFT = "STFT"          # linear frequency axis, no log scaling, no cepstrum
    LOG_STFT = "LOG_STFT"  # linear frequency axis with log scaling
    LFCC = "LFCC"          # cepstrum computed on the linear frequency axis
    MEL = "MEL"            # mel frequency axis, no log scaling, no cepstrum
    LOG_MEL = "LOG_MEL"    # mel frequency axis with log scaling
    MFCC = "MFCC"          # cepstrum computed on the mel frequency axis
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MelType(Enum):
    """
    Selects the Hz <-> mel conversion formula used by `hz_to_mel` and `mel_to_hz`.

    The member values are display names; the conversion functions compare
    against the members themselves, not against these strings.
    """
    OSHAUGHNESSY = "O'Shaughnessy"       # mel = 2595 * log10(1 + f / 700)
    FANT = "Fant"                        # mel = (1000 / log10(2)) * log10(1 + f / 1000)
    LINDSAY_NORMAN = "Lindsay & Norman"  # mel = 2410 * log10(1 + f / 625)
    SLANEY = "Slaney"                    # linear below 1000 Hz, logarithmic at and above it
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ScalingType(Enum):
    """
    Selects how `create_scaler` compresses a spectrogram's dynamic range.

    POWER and MAGNITUDE are forwarded by value as the `stype` of
    `AmplitudeToDB`; LOG selects the plain `log_scale` function instead.
    """
    POWER = "power"          # AmplitudeToDB with a 10 * log10 multiplier
    MAGNITUDE = "magnitude"  # AmplitudeToDB with a 20 * log10 multiplier
    LOG = "log"              # natural log with a small offset (see log_scale)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
###############################################################################
|
|
49
|
+
# New Classes
|
|
50
|
+
###############################################################################
|
|
51
|
+
class ExportableSTFT(torch.nn.Module):
    """
    Power spectrogram computed with explicit DFT matrices instead of torch.stft.

    Exportable to Executorch. Created by Claude with supervision.

    Results are very slightly different than Torch but show no noticeable difference in model accuracy or training.

    Args:
        n_fft: Frame size; the output has ``n_fft // 2 + 1`` frequency bins.
        hop_length: Number of samples between the starts of consecutive frames.
        win_length: Length of the analysis window. Defaults to ``n_fft``. A
            window shorter than ``n_fft`` is zero-padded symmetrically.
        window: Window tensor. Defaults to a Hann window of ``win_length``
            samples.
    """
    def __init__(self,
                 n_fft: int,
                 hop_length: int,
                 *,
                 win_length: Optional[int] = None,
                 window: Optional[torch.Tensor] = None
                 ):
        super().__init__()

        _window_len = win_length or n_fft
        if window is not None:
            _w = window
        else:
            # The default window must be built at the window length, not at
            # n_fft: otherwise the padding below would stretch it beyond n_fft
            # and the multiplication with the DFT matrix could not broadcast.
            # min() keeps the previous n_fft-long window when win_length > n_fft.
            _w = torch.hann_window(min(_window_len, n_fft))

        if _window_len < n_fft:
            # Centre the short window inside the n_fft-sample frame.
            pad_left = (n_fft - _window_len) // 2
            pad_right = n_fft - _window_len - pad_left
            _w = torch.nn.functional.pad(_w, (pad_left, pad_right))

        # Only compute onesided bins: n_fft//2+1
        k = torch.arange(n_fft // 2 + 1).unsqueeze(1)  # (freq, 1)
        n = torch.arange(n_fft).unsqueeze(0)  # (1, n_fft)
        angles = -2 * torch.pi * k * n / n_fft  # (freq, n_fft)

        self.hop_length = hop_length
        self.n_fft = n_fft
        self.pad = n_fft // 2

        # Windowed DFT basis, shape: (n_fft//2+1, n_fft). Registered as buffers
        # so they follow the module across devices and are saved with it.
        self.register_buffer("dft_real", torch.cos(angles) * _w)
        self.register_buffer("dft_imag", torch.sin(angles) * _w)

    def forward(self, x: torch.Tensor):
        """
        Return the power spectrogram of ``x`` with shape (B, freq, frames).

        ``x`` may be (T,), (B, T) or (B, 1, T); a 1-D input is treated as a
        batch of one, so the result always carries a leading batch dimension.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(0)  # (1, 1, T)
        elif x.dim() == 2:
            x = x.unsqueeze(1)  # (B, 1, T)

        # Reflect-pad by n_fft // 2 on both sides so frames are centred on
        # their hop positions.
        x = torch.nn.functional.pad(x, (self.pad, self.pad), mode="reflect")
        x = x.squeeze(1)  # (B, T_padded)

        frames = x.unfold(-1, self.n_fft, self.hop_length)  # (B, frames, n_fft)

        real = torch.matmul(frames, self.dft_real.T)  # (B, frames, freq)
        imag = torch.matmul(frames, self.dft_imag.T)

        power = real ** 2 + imag ** 2  # (B, frames, freq)

        # Match torch.stft output layout: (B, freq, frames)
        return power.permute(0, 2, 1)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class AudioPreprocessor(torch.nn.Module, ABC):
    """
    Abstract base class for a waveform preprocessing stage.

    Feature sources hold a sequence of these and call each one on the waveform,
    in order, before computing the spectrogram.
    """
    # NOTE(review): the abstract hook is __call__ itself rather than forward(),
    # so a subclass implementation replaces torch.nn.Module.__call__ — confirm
    # this is intended.
    @abstractmethod
    def __call__(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
        """Transform ``wav`` and return the processed waveform as a tensor."""
        ...
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class BaseFeatureSource(torch.nn.Module, ABC):
    """
    Common base for feature-extraction modules.

    Its only job is to keep the sequence of preprocessors that a subclass
    applies to the waveform before extracting features.
    """
    def __init__(self, preprocessors: Sequence[AudioPreprocessor]):
        super().__init__()
        # NOTE(review): the sequence is stored exactly as given. A plain
        # list/tuple attribute is not registered as a submodule container, so
        # the preprocessors presumably do not follow .to()/state_dict() unless
        # the caller passes a torch.nn.ModuleList — confirm against callers.
        self.preprocessors = preprocessors
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
###############################################################################
|
|
119
|
+
# New Functions
|
|
120
|
+
###############################################################################
|
|
121
|
+
def load_params(params: str | list | dict):
|
|
122
|
+
if not isinstance(params, str):
|
|
123
|
+
return params
|
|
124
|
+
else:
|
|
125
|
+
with open(params, "r") as infile:
|
|
126
|
+
return json.loads(infile.read())
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def write_params(filename: str, params: list | dict):
|
|
130
|
+
with open(filename, "r") as outfile:
|
|
131
|
+
return outfile.write(json.dumps(params, indent=2))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def power_of_two(n: int):
    """Return True when ``n`` is a power of two (1, 2, 4, ...); zero is not."""
    if n == 0:
        return False
    # Clearing the lowest set bit leaves zero only if exactly one bit was set.
    lowest_bit_cleared = n & (n - 1)
    return lowest_bit_cleared == 0
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def generate_filters(n_fft: int, n_filters: int, sample_rate: int, mel_type: Optional[MelType]):
    """
    Build a triangular filterbank for spectrograms with ``n_fft // 2 + 1`` bins.

    The filters span 20 Hz up to ``sample_rate // 2``. With a ``mel_type`` the
    filter edges are spaced on that mel scale; without one they are spaced
    linearly in Hz.
    """
    bin_count = n_fft // 2 + 1
    lowest_hz = 20.0
    highest_hz = float(sample_rate // 2)

    if not mel_type:
        return linear_fbanks(bin_count, lowest_hz, highest_hz, n_filters, sample_rate)

    return melscale_fbanks(mel_type, lowest_hz, highest_hz, bin_count, n_filters, sample_rate)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def create_scaler(scaling_type: ScalingType):
    """
    Return the callable that applies the requested scaling to a spectrogram.

    LOG yields the `log_scale` function; every other type yields an
    `AmplitudeToDB` module configured with that type's value and an 80 dB
    dynamic-range limit.
    """
    if scaling_type != ScalingType.LOG:
        return AmplitudeToDB(stype=scaling_type.value, top_db=80.0)
    return log_scale
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def scale_spec(spec: torch.Tensor) -> torch.Tensor:
    """
    Min-max normalize ``spec`` into the range [0, 1].

    The minimum and maximum are taken over the whole tensor (all batch items
    and channels together). A constant input, whose span is zero, maps to all
    zeros instead of NaN.
    """
    min_in_val = torch.min(spec)
    max_in_val = torch.max(spec)
    in_span = max_in_val - min_in_val

    # Guard the division for a flat spectrogram. torch.where (rather than a
    # Python branch) keeps this free of data-dependent control flow, and
    # ones_like keeps every operand on the input's device.
    safe_span = torch.where(in_span > 0, in_span, torch.ones_like(in_span))

    scale_factor = 1.0 / safe_span
    return (spec - min_in_val) * scale_factor
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def determine_spec_type(calc_mels: bool, calc_logs: bool, calc_cepstrum: bool):
    """
    Map the three feature flags onto the matching `SpecType`.

    The cepstrum flag takes precedence over the log flag; the mel flag picks
    between the mel and the linear-frequency variant of each kind.
    """
    if calc_cepstrum:
        return SpecType.MFCC if calc_mels else SpecType.LFCC
    if calc_logs:
        return SpecType.LOG_MEL if calc_mels else SpecType.LOG_STFT
    return SpecType.MEL if calc_mels else SpecType.STFT
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
###############################################################################
|
|
200
|
+
# Imported Legacy Code
|
|
201
|
+
# This code was extracted from v2.8.0 of [torchaudio](https://github.com/pytorch/audio)
|
|
202
|
+
# torchaudio is licensed under the BSD 2-Clause License, reprinted below
|
|
203
|
+
###############################################################################
|
|
204
|
+
#
|
|
205
|
+
# Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
|
|
206
|
+
# All rights reserved.
|
|
207
|
+
#
|
|
208
|
+
# Redistribution and use in source and binary forms, with or without
|
|
209
|
+
# modification, are permitted provided that the following conditions are met:
|
|
210
|
+
#
|
|
211
|
+
# * Redistributions of source code must retain the above copyright notice, this
|
|
212
|
+
# list of conditions and the following disclaimer.
|
|
213
|
+
#
|
|
214
|
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
215
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
216
|
+
# and/or other materials provided with the distribution.
|
|
217
|
+
#
|
|
218
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
219
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
220
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
221
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
222
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
223
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
224
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
225
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
226
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
227
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
228
|
+
###############################################################################
|
|
229
|
+
|
|
230
|
+
###############################################################################
|
|
231
|
+
# Generate features
|
|
232
|
+
###############################################################################
|
|
233
|
+
|
|
234
|
+
class AmplitudeToDB(torch.nn.Module):
    """
    Convert a power or magnitude spectrogram to decibels.

    Args:
        stype: ``"power"`` uses a 10 * log10 multiplier; any other value uses
            20 * log10.
        top_db: Optional dynamic-range limit. When set to a non-zero value, the
            output is floored at ``top_db`` below the maximum of the whole
            output tensor.

    Raises:
        ValueError: if ``top_db`` is negative.
    """
    def __init__(self, stype: str = "power", top_db: Optional[float] = None) -> None:
        super().__init__()
        self.stype = stype
        if top_db is not None and top_db < 0:
            raise ValueError("top_db must be positive value")
        self.top_db = top_db
        self.multiplier = 10.0 if stype == "power" else 20.0
        # Inputs are floored at amin before the log to avoid log10(0).
        self.amin = 1e-10
        self.ref_value = 1.0
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` in dB relative to ``ref_value``."""
        floored = torch.clamp(x, min=self.amin)
        x_db = self.multiplier * torch.log10(floored)
        x_db = x_db - self.multiplier * self.db_multiplier

        if not self.top_db:
            return x_db

        # Limit the dynamic range relative to the loudest value in the tensor.
        floor_db = x_db.max() - self.top_db
        return torch.max(x_db, floor_db)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def log_scale(waveform: torch.Tensor) -> torch.Tensor:
    """Return the natural log of the input, offset by 1e-6 so zeros stay finite."""
    shifted = waveform + 1e-6
    return torch.log(shifted)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def hz_to_mel(freq: float, mel_type: MelType) -> float:
    """
    Convert a frequency in Hz to mels using the formula selected by ``mel_type``.

    Any ``mel_type`` other than O'Shaughnessy, Fant or Lindsay & Norman is
    handled as Slaney: linear below 1000 Hz and logarithmic from there on.
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 2595.0 * math.log10(1.0 + (freq / 700.0))
    if mel_type == MelType.FANT:
        return (1000 / math.log10(2)) * math.log10(1.0 + (freq / 1000.0))
    if mel_type == MelType.LINDSAY_NORMAN:
        return 2410.0 * math.log10(1.0 + (freq / 625.0))

    # MelType.SLANEY
    knee_hz = 1000.0
    hz_per_mel = 200.0 / 3
    if freq < knee_hz:
        return freq / hz_per_mel

    knee_mel = knee_hz / hz_per_mel
    log_step = math.log(6.4) / 27.0
    return knee_mel + math.log(freq / knee_hz) / log_step
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def mel_to_hz(mels: torch.Tensor, mel_type: MelType) -> torch.Tensor:
    """
    Convert a tensor of mel values back to Hz; the inverse of `hz_to_mel`.

    Any ``mel_type`` other than O'Shaughnessy, Fant or Lindsay & Norman is
    handled as Slaney.
    """
    if mel_type == MelType.OSHAUGHNESSY:
        return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
    if mel_type == MelType.FANT:
        mul = math.log10(2) / 1000
        return 1000.0 * (10 ** (mels * mul) - 1.0)
    if mel_type == MelType.LINDSAY_NORMAN:
        return 625.0 * (10.0 ** (mels / 2410.0) - 1.0)

    # Slaney: linear region below the 1000 Hz knee, exponential at and above it.
    knee_hz = 1000.0
    hz_per_mel = 200.0 / 3
    knee_mel = knee_hz / hz_per_mel
    log_step = math.log(6.4) / 27.0

    linear_part = hz_per_mel * mels
    log_part = knee_hz * torch.exp(log_step * (mels - knee_mel))
    return torch.where(mels >= knee_mel, log_part, linear_part)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
def create_triangular_filterbank(
    all_freqs: torch.Tensor,
    f_pts: torch.Tensor,
) -> torch.Tensor:
    """
    Build overlapping triangular filters.

    Args:
        all_freqs: Frequency of every spectrogram bin, shape (n_freqs,).
        f_pts: Filter edge/centre frequencies, shape (n_filter + 2,). Filter i
            rises from ``f_pts[i]`` to a peak at ``f_pts[i + 1]`` and falls to
            zero at ``f_pts[i + 2]``.

    Returns:
        Filterbank of shape (n_freqs, n_filter).
    """
    # Adapted from Librosa
    spacing = f_pts[1:] - f_pts[:-1]  # (n_filter + 1)
    # signed distance from every bin to every filter point, in hertz
    offsets = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)  # (n_freqs, n_filter + 2)

    rising = (-1.0 * offsets[:, :-2]) / spacing[:-1]  # (n_freqs, n_filter)
    falling = offsets[:, 2:] / spacing[1:]  # (n_freqs, n_filter)

    # each triangle is the lower of its two edges, floored at zero
    triangles = torch.min(rising, falling)
    return torch.max(torch.zeros(1), triangles)
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def melscale_fbanks(
    mel_type: MelType,
    f_min: float,
    f_max: float,
    n_freqs: int,
    n_mels: int,
    sample_rate: int,
):
    """
    Build a mel-spaced triangular filterbank of shape (n_freqs, n_mels).

    Filter edges are spaced evenly on the ``mel_type`` scale between ``f_min``
    and ``f_max``; the spectrogram bins are taken to cover 0 Hz up to
    ``sample_rate // 2``. Warns when any filter ends up all zero.
    """
    # frequency of every spectrogram bin
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # filter points: evenly spaced in mels, then mapped back to hertz
    mel_lo = hz_to_mel(f_min, mel_type=mel_type)
    mel_hi = hz_to_mel(f_max, mel_type=mel_type)
    mel_points = torch.linspace(mel_lo, mel_hi, n_mels + 2)
    hz_points = mel_to_hz(mel_points, mel_type=mel_type)

    fb = create_triangular_filterbank(bin_freqs, hz_points)

    # a filter whose peak is zero fell between two bins and captures nothing
    filter_peaks = fb.max(dim=0).values
    if (filter_peaks == 0.0).any():
        warnings.warn(
            "At least one mel filterbank has all zero values. "
            f"The value for `n_mels` ({n_mels}) may be set too high. "
            f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
        )

    return fb
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def linear_fbanks(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_filter: int,
    sample_rate: int,
) -> torch.Tensor:
    """
    Build a linearly spaced triangular filterbank of shape (n_freqs, n_filter).

    Filter points are spaced evenly in hertz between ``f_min`` and ``f_max``;
    the spectrogram bins are taken to cover 0 Hz up to ``sample_rate // 2``.
    """
    bin_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
    filter_points = torch.linspace(f_min, f_max, n_filter + 2)
    return create_triangular_filterbank(bin_freqs, filter_points)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def create_dct(n_cepstrum: int, n_filters: int) -> torch.Tensor:
    """
    Build an orthonormally scaled DCT-II matrix of shape (n_filters, n_cepstrum).

    Right-multiplying filterbank energies laid out as (..., n_filters) by this
    matrix yields (..., n_cepstrum) cepstral coefficients.
    """
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    filter_idx = torch.arange(float(n_filters))
    coeff_idx = torch.arange(float(n_cepstrum)).unsqueeze(1)
    basis = torch.cos(math.pi / float(n_filters) * (filter_idx + 0.5) * coeff_idx)  # (n_cepstrum, n_filters)

    # orthonormal scaling: the DC row gets an extra 1/sqrt(2)
    basis[0] = basis[0] * (1.0 / math.sqrt(2.0))
    basis = basis * math.sqrt(2.0 / float(n_filters))
    return basis.t()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
###############################################################################
|
|
383
|
+
# Load audio
|
|
384
|
+
###############################################################################
|
|
385
|
+
|
|
386
|
+
def _get_sinc_resample_kernel(
    orig_freq: int,
    new_freq: int,
    gcd: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
    device: torch.device = torch.device("cpu"),
    dtype: Optional[torch.dtype] = None,
):
    """
    Build the bank of windowed-sinc filters used to resample by new_freq / orig_freq.

    Args:
        orig_freq: Source sample rate.
        new_freq: Target sample rate.
        gcd: Greatest common divisor of the two rates; both are divided by it
            before any filter is built.
        lowpass_filter_width: Number of sinc zero crossings kept on each side.
        rolloff: Fraction of the lower of the two (reduced) rates used as the
            low-pass cutoff.
        device: Device on which the kernel is created.
        dtype: Kernel dtype. When ``None`` the index grid is built in float64
            and the result is converted to float32.

    Returns:
        ``(kernels, width)``: ``kernels`` has shape
        (new_freq // gcd, 1, 2 * width + orig_freq // gcd), one filter per
        output phase; ``width`` is the padding needed on each side of the
        signal.

    Raises:
        ValueError: if ``lowpass_filter_width`` is not positive.
    """
    orig_freq = int(orig_freq) // gcd
    new_freq = int(new_freq) // gcd

    if lowpass_filter_width <= 0:
        raise ValueError("Low pass filter width should be positive.")
    base_freq = min(orig_freq, new_freq)
    # This will perform antialiasing filtering by removing the highest frequencies.
    # At first I thought I only needed this when downsampling, but when upsampling
    # you will get edge artifacts without this, as the edge is equivalent to zero padding,
    # which will add high freq artifacts.
    base_freq *= rolloff

    # The key idea of the algorithm is that x(t) can be exactly reconstructed from x[i] (tensor)
    # using the sinc interpolation formula:
    #   x(t) = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - t))
    # We can then sample the function x(t) with a different sample rate:
    #    y[j] = x(j / new_freq)
    # or,
    #    y[j] = sum_i x[i] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))

    # We see here that y[j] is the convolution of x[i] with a specific filter, for which
    # we take an FIR approximation, stopping when we see at least `lowpass_filter_width` zeros crossing.
    # But y[j+1] is going to have a different set of weights and so on, until y[j + new_freq].
    # Indeed:
    # y[j + new_freq] = sum_i x[i] sinc(pi * orig_freq * ((i / orig_freq - (j + new_freq) / new_freq))
    #                 = sum_i x[i] sinc(pi * orig_freq * ((i - orig_freq) / orig_freq - j / new_freq))
    #                 = sum_i x[i + orig_freq] sinc(pi * orig_freq * (i / orig_freq - j / new_freq))
    # so y[j+new_freq] uses the same filter as y[j], but on a shifted version of x by `orig_freq`.
    # This will explain the F.conv1d after, with a stride of orig_freq.
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    # If orig_freq is still big after GCD reduction, most filters will be very unbalanced, i.e.,
    # they will have a lot of almost zero values to the left or to the right...
    # There is probably a way to evaluate those filters more efficiently, but this is kept for
    # future work.
    idx_dtype = dtype if dtype is not None else torch.float64

    # Input sample positions covered by one filter, in units of reduced input periods.
    idx = torch.arange(-width, width + orig_freq, dtype=idx_dtype, device=device)[None, None] / orig_freq

    # One row per output phase: the offset between that output sample and each input sample.
    t = torch.arange(0, -new_freq, -1, dtype=dtype, device=device)[:, None, None] / new_freq + idx
    t *= base_freq
    # Clamping pins everything beyond the filter width to the edge, where the window below is zero.
    t = t.clamp_(-lowpass_filter_width, lowpass_filter_width)

    # we do not use built in torch windows here as we need to evaluate the window
    # at specific positions, not over a regular grid.
    window = torch.cos(t * math.pi / lowpass_filter_width / 2) ** 2

    t *= math.pi

    scale = base_freq / orig_freq
    # sinc(t) = sin(t) / t, with the removable singularity at t == 0 filled in as 1.
    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels *= window * scale

    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)

    return kernels, width
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def _apply_sinc_resample_kernel(
|
|
454
|
+
waveform: torch.Tensor,
|
|
455
|
+
orig_freq: int,
|
|
456
|
+
new_freq: int,
|
|
457
|
+
gcd: int,
|
|
458
|
+
kernel: torch.Tensor,
|
|
459
|
+
width: int,
|
|
460
|
+
):
|
|
461
|
+
orig_freq = int(orig_freq) // gcd
|
|
462
|
+
new_freq = int(new_freq) // gcd
|
|
463
|
+
|
|
464
|
+
# pack batch
|
|
465
|
+
shape = waveform.size()
|
|
466
|
+
waveform = waveform.view(-1, shape[-1])
|
|
467
|
+
|
|
468
|
+
num_wavs, length = waveform.shape
|
|
469
|
+
waveform = torch.nn.functional.pad(waveform, (width, width + orig_freq))
|
|
470
|
+
resampled = torch.nn.functional.conv1d(waveform[:, None], kernel, stride=orig_freq)
|
|
471
|
+
resampled = resampled.transpose(1, 2).reshape(num_wavs, -1)
|
|
472
|
+
target_length = torch.ceil(torch.as_tensor(new_freq * length / orig_freq)).long()
|
|
473
|
+
resampled = resampled[..., :target_length]
|
|
474
|
+
|
|
475
|
+
# unpack batch
|
|
476
|
+
resampled = resampled.view(shape[:-1] + resampled.shape[-1:])
|
|
477
|
+
return resampled
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def resample(
    waveform: torch.Tensor,
    orig_freq: int,
    new_freq: int,
    lowpass_filter_width: int = 6,
    rolloff: float = 0.99,
) -> torch.Tensor:
    """
    Resample ``waveform`` from ``orig_freq`` to ``new_freq`` by sinc interpolation.

    When the two rates are equal the input tensor itself is returned.
    Otherwise the kernel is built on the waveform's device and in its dtype.
    """
    if orig_freq != new_freq:
        gcd = math.gcd(int(orig_freq), int(new_freq))
        kernel, width = _get_sinc_resample_kernel(
            orig_freq,
            new_freq,
            gcd,
            lowpass_filter_width=lowpass_filter_width,
            rolloff=rolloff,
            device=waveform.device,
            dtype=waveform.dtype,
        )
        return _apply_sinc_resample_kernel(waveform, orig_freq, new_freq, gcd, kernel, width)

    return waveform
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
###############################################################################
|
|
8
|
+
# 3PP Imports
|
|
9
|
+
###############################################################################
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
###############################################################################
|
|
13
|
+
# Local Imports
|
|
14
|
+
###############################################################################
|
|
15
|
+
from ._common import MelType, ScalingType, ExportableSTFT, AudioPreprocessor, BaseFeatureSource
|
|
16
|
+
from ._common import power_of_two, create_dct, create_scaler, generate_filters, scale_spec
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
###############################################################################
|
|
20
|
+
# Helper Classes
|
|
21
|
+
###############################################################################
|
|
22
|
+
class ChannelConfig:
    """
    Describes one output channel of a feature source.

    The three flags select which spectrum the channel carries; ``key`` is the
    derived name under which that spectrum is looked up.
    """
    def __init__(self,
                 *,
                 is_mel: bool,
                 is_log: bool,
                 is_cepstrum: bool,
                 ):
        self.is_mel = is_mel
        self.is_log = is_log
        self.is_cepstrum = is_cepstrum

        self.key = self._make_channel_key()

    def get_key(self):
        """Return the channel's lookup key."""
        return self.key

    def _make_channel_key(self):
        """Derive the key; the cepstrum flag wins over the log flag."""
        if self.is_cepstrum:
            mel_key, lin_key = "mfcc", "lfcc"
        elif self.is_log:
            mel_key, lin_key = "mel_log_spec", "lin_log_spec"
        else:
            mel_key, lin_key = "mel_freq_spec", "lin_freq_spec"
        return mel_key if self.is_mel else lin_key
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
###############################################################################
|
|
48
|
+
# Classes
|
|
49
|
+
###############################################################################
|
|
50
|
+
class EfficientFeatureSource(BaseFeatureSource):
    """
    Multi-channel spectral feature extractor that shares one STFT across channels.

    Every requested channel (linear/mel spectrum, their dB versions, LFCC/MFCC)
    is derived from a single power spectrogram, so intermediate results are
    computed at most once per forward pass.

    Args:
        sample_rate: Sample rate of the incoming audio in Hz.
        channels: Channels to produce, in output order.
        preprocessors: Callables applied to the waveform, in order, before the STFT.
        n_fft: FFT size; must be a power of two. Defaults to 1024.
        hop_length: Hop between frames in samples. Defaults to ``n_fft // 4``.
        n_filters: Number of triangular filters. Defaults to ``n_fft // 8``.
        cepstral_coefficients: Number of cepstral coefficients. Defaults to
            ``n_filters``.

    Raises:
        ValueError: if ``n_fft`` is not a power of two, or if
            ``cepstral_coefficients`` exceeds the number of filters.
    """
    def __init__(self,
                 sample_rate: int,
                 channels: list[ChannelConfig],
                 preprocessors: Sequence[AudioPreprocessor] = (),
                 *,
                 # For all spectra
                 n_fft: Optional[int] = None,
                 hop_length: Optional[int] = None,

                 # Shared
                 n_filters: Optional[int] = None,

                 # For all cepstra
                 cepstral_coefficients: Optional[int] = None,
                 ):
        super().__init__(preprocessors)

        self.channels = channels

        self.has_lin_freq = any(not c.is_mel for c in channels)
        self.has_mel_freq = any(c.is_mel for c in channels)

        # A cepstrum flag only counts for the frequency scale of the channel
        # that actually requests it.
        self.has_lfcc = any(c.is_cepstrum and not c.is_mel for c in channels)
        self.has_mfcc = any(c.is_cepstrum and c.is_mel for c in channels)
        self.has_cepstrum = self.has_lfcc or self.has_mfcc

        self.has_lin_log = any(not c.is_mel and c.is_log for c in channels)
        self.has_mel_log = any(c.is_mel and c.is_log for c in channels)
        self.has_log_scale = any(c.is_log for c in channels) or self.has_cepstrum

        # Required configs
        self.sample_rate = sample_rate

        # Universal configs
        if n_fft is not None and not power_of_two(n_fft):
            raise ValueError("n_fft must be a power of 2")
        self.n_fft = n_fft or 1024

        if hop_length is not None and hop_length > (self.n_fft // 2):
            warnings.warn(f"hop_length should be set to no more than 1/2 the FFT window size, or {self.n_fft // 2} samples for n_fft = {self.n_fft} (currently {hop_length})")
        self.hop_length = hop_length or self.n_fft // 4

        # Shared configs (mels, MFCC, LFCC)
        if n_filters is not None and n_filters > (self.n_fft // 8):
            warnings.warn(f"n_filters should be set to no more than 1/8 the FFT window size, or {self.n_fft // 8} filters for n_fft = {self.n_fft} (currently {n_filters})")
        self.n_filters = n_filters or self.n_fft // 8

        # Cepstral configs
        if cepstral_coefficients is not None and cepstral_coefficients > self.n_filters:
            raise ValueError(f"cepstral_coefficients must be no greater than n_filters (currently {cepstral_coefficients}/{self.n_filters})")
        self.cepstral_coefficients = cepstral_coefficients or self.n_filters

        ###################
        # Spec gen code
        ###################

        # Basic spectrogram
        self._stft = ExportableSTFT(self.n_fft, self.hop_length)

        if self.has_lin_freq:
            self.register_buffer("lin_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, None))
        else:
            self.lin_filt = None

        if self.has_mel_freq:
            self.register_buffer("mel_filt", generate_filters(self.n_fft, self.n_filters, self.sample_rate, MelType.OSHAUGHNESSY))
        else:
            self.mel_filt = None

        # DB scaling, if necessary
        self.amplitude_to_DB = create_scaler(ScalingType.POWER) if self.has_log_scale else None

        # Cepstrum, if necessary
        if self.has_cepstrum:
            self.register_buffer("dct", create_dct(self.cepstral_coefficients, self.n_filters))
        else:
            self.dct = None

    @staticmethod
    def _project(spec: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
        """Apply ``matrix`` along the frequency axis (second to last) of ``spec``."""
        return torch.matmul(spec.transpose(-1, -2), matrix).transpose(-1, -2)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """
        Run the preprocessors and return all configured channels stacked on dim 1.

        Each channel is min-max scaled independently before stacking.
        """
        for preproc in self.preprocessors:
            wav = preproc(wav)

        spec = self._stft(wav)

        lin_freq_spec = self._project(spec, self.lin_filt) if self.lin_filt is not None else None
        mel_freq_spec = self._project(spec, self.mel_filt) if self.mel_filt is not None else None

        # The dB spectrum is needed both for log channels and as the input of
        # the cepstrum, so a cepstral channel must trigger it on its own.
        need_lin_log = self.has_lin_log or self.has_lfcc
        if need_lin_log and (lin_freq_spec is not None) and (self.amplitude_to_DB is not None):
            lin_log_spec = self.amplitude_to_DB(lin_freq_spec)
        else:
            lin_log_spec = None

        need_mel_log = self.has_mel_log or self.has_mfcc
        if need_mel_log and (mel_freq_spec is not None) and (self.amplitude_to_DB is not None):
            mel_log_spec = self.amplitude_to_DB(mel_freq_spec)
        else:
            mel_log_spec = None

        if self.has_lfcc and (lin_log_spec is not None) and (self.dct is not None):
            lfcc = self._project(lin_log_spec, self.dct)
        else:
            lfcc = None

        if self.has_mfcc and (mel_log_spec is not None) and (self.dct is not None):
            mfcc = self._project(mel_log_spec, self.dct)
        else:
            mfcc = None

        # Keys must match ChannelConfig.key.
        return self._collate(**{
            "lin_freq_spec": lin_freq_spec,
            "mel_freq_spec": mel_freq_spec,
            "lin_log_spec": lin_log_spec,
            "mel_log_spec": mel_log_spec,
            "lfcc": lfcc,
            "mfcc": mfcc,
        })

    def _collate(self, **kwargs):
        """Scale the spectrum of every configured channel and stack them on dim 1."""
        specs = [scale_spec(kwargs[c.key]) for c in self.channels]
        return torch.stack(specs, dim=1)
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
###############################################################################
|
|
8
|
+
# 3PP Imports
|
|
9
|
+
###############################################################################
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
###############################################################################
|
|
13
|
+
# Local Imports
|
|
14
|
+
###############################################################################
|
|
15
|
+
from ._common import MelType, ScalingType, ExportableSTFT, AudioPreprocessor, BaseFeatureSource
|
|
16
|
+
from ._common import power_of_two, create_dct, create_scaler, generate_filters, determine_spec_type, scale_spec, load_params, write_params
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
###############################################################################
|
|
20
|
+
# Export Classes
|
|
21
|
+
###############################################################################
|
|
22
|
+
class FeatureChannel(torch.nn.Module):
    """One configurable spectral feature computed from a waveform.

    ``forward`` runs the stages prepared in ``__init__`` in this order:
    STFT -> filterbank (when one was generated) -> amplitude-to-dB conversion
    (log spectra and cepstra only) -> DCT (cepstra only) -> ``scale_spec``
    post-processing (unless disabled).
    """

    def __init__(self,
                 sample_rate: int,
                 *,
                 # For all spectra
                 n_fft: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 scale_spec: Optional[bool] = None,

                 # Shared
                 n_filters: Optional[int] = None,

                 # For all mel spectra
                 is_mel: Optional[bool] = None,
                 mel_type: Optional[MelType] = None,

                 # For all log spectra
                 is_logarithmic: Optional[bool] = None,
                 scaling_type: ScalingType = ScalingType.POWER,

                 # For all cepstra
                 is_cepstrum: Optional[bool] = None,
                 cepstral_coefficients: Optional[int] = None,
                 ):
        """Configure the channel.

        Args:
            sample_rate: Sample rate handed to ``generate_filters``.
            n_fft: FFT window size; must be a power of two. Defaults to 1024.
            hop_length: STFT hop size. Defaults to ``n_fft // 4``; a warning is
                issued when it exceeds ``n_fft // 2``.
            scale_spec: Whether ``forward`` applies ``scale_spec`` to the
                result. Defaults to True.
            n_filters: Number of filters. Defaults to ``n_fft // 8``; a warning
                is issued when it exceeds that value.
            is_mel: Request a mel-scale spectrum. Implied by ``mel_type``.
            mel_type: Mel scale variant. Defaults to ``MelType.OSHAUGHNESSY``
                when a mel spectrum is requested, otherwise None.
            is_logarithmic: Request dB scaling. Implied for cepstra.
            scaling_type: Passed to ``create_scaler`` when dB scaling is used.
            is_cepstrum: Request a cepstrum. Implied by
                ``cepstral_coefficients``.
            cepstral_coefficients: Number of cepstral coefficients. Defaults to
                ``n_filters``.

        Raises:
            ValueError: If ``n_fft`` is not a power of two, or
                ``cepstral_coefficients`` exceeds the number of filters.
        """
        # NOTE: the ``scale_spec`` parameter shadows the module-level
        # ``scale_spec`` function inside this method only.
        super().__init__()

        # Required configs
        self.sample_rate = sample_rate

        # Universal configs
        if n_fft is not None and not power_of_two(n_fft):
            raise ValueError("n_fft must be a power of 2")
        self.n_fft = n_fft or 1024

        if hop_length is not None and hop_length > (self.n_fft // 2):
            # The limit is a hop size in samples, not a mel count.
            # stacklevel=2 attributes the warning to the code constructing the channel.
            warnings.warn(
                f"hop_length should be set to no more than 1/2 the FFT window size, or {self.n_fft // 2} samples for n_fft = {self.n_fft} (currently {hop_length})",
                stacklevel=2,
            )
        self.hop_length = hop_length or self.n_fft // 4

        self.scale_spec = scale_spec if scale_spec is not None else True

        # Shared configs (mels, MFCC, LFCC)
        if n_filters is not None and n_filters > (self.n_fft // 8):
            warnings.warn(
                f"n_filters should be set to no more than 1/8 the FFT window size, or {self.n_fft // 8} filters for n_fft = {self.n_fft} (currently {n_filters})",
                stacklevel=2,
            )
        self.n_filters = n_filters or self.n_fft // 8

        # Mel configs
        calc_mels = is_mel or (mel_type is not None)
        self.mel_type = mel_type or (MelType.OSHAUGHNESSY if calc_mels else None)

        # Cepstral configs
        calc_cepstrum = is_cepstrum or (cepstral_coefficients is not None)

        if cepstral_coefficients is not None and cepstral_coefficients > self.n_filters:
            # The bound is the n_filters setting; this class has no "n_mels" parameter.
            raise ValueError(f"cepstral_coefficients must be no greater than n_filters (currently {cepstral_coefficients}/{self.n_filters})")
        self.cepstral_coefficients = cepstral_coefficients or self.n_filters

        # Log configs: a cepstrum is always computed from a log spectrum
        calc_logs = is_logarithmic or calc_cepstrum

        ###################
        # Spec gen code
        ###################

        # Basic spectrogram
        self._stft = ExportableSTFT(self.n_fft, self.hop_length)

        # Filterbank; registered as a buffer so it follows the module across devices.
        # NOTE(review): forward() tolerates a None filterbank, so generate_filters
        # presumably may return None -- confirm against _common.
        fb = generate_filters(self.n_fft, self.n_filters, self.sample_rate, self.mel_type)
        self.register_buffer("fb", fb)

        # DB scaling, if necessary
        self.scaling_type = scaling_type
        self.amplitude_to_DB = create_scaler(scaling_type) if calc_logs else None

        # Cepstrum, if necessary
        if calc_cepstrum:
            dct_mat = create_dct(self.cepstral_coefficients, self.n_filters)
            self.register_buffer("dct_mat", dct_mat)
        else:
            self.dct_mat = None

        self.spec_type = determine_spec_type(calc_mels, calc_logs, calc_cepstrum)

    def get_spec_type(self):
        """Return the spec type computed by ``determine_spec_type`` at construction."""
        return self.spec_type

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute this channel's feature for ``wav``.

        Each matrix stage multiplies along the second-to-last dimension of the
        spectrum (hence the transpose before and after each ``matmul``).
        """
        spec = self._stft(wav)

        if self.fb is not None:
            spec = torch.matmul(spec.transpose(-1, -2), self.fb).transpose(-1, -2)

        if self.amplitude_to_DB is not None:
            spec = self.amplitude_to_DB(spec)

        if self.dct_mat is not None:
            spec = torch.matmul(spec.transpose(-1, -2), self.dct_mat).transpose(-1, -2)

        if self.scale_spec:
            # self.scale_spec is the boolean flag; scale_spec is the module-level function
            spec = scale_spec(spec)

        return spec

    @staticmethod
    def from_json(params: str | dict):
        """Build a ``FeatureChannel`` from parameters accepted by ``load_params``.

        ``sample_rate`` is mandatory (a missing key raises ``KeyError``); every
        other key is optional and falls back to the constructor default.

        Raises:
            ValueError: If the loaded parameters are not a dict.
        """
        loaded_params = load_params(params)
        if not isinstance(loaded_params, dict):
            raise ValueError(f"Invalid {FeatureChannel.__name__} parameters")

        mel_type_raw = loaded_params.get("mel_type", None)
        scaling_type_raw = loaded_params.get("scaling_type", None)

        return FeatureChannel(sample_rate=loaded_params["sample_rate"],
                              # For all spectra
                              n_fft=loaded_params.get("n_fft", None),
                              hop_length=loaded_params.get("hop_length", None),
                              scale_spec=loaded_params.get("scale_spec", None),

                              # Shared
                              n_filters=loaded_params.get("n_filters", None),

                              # For all mel spectra
                              is_mel=loaded_params.get("is_mel", None),
                              mel_type=MelType(mel_type_raw) if mel_type_raw else None,

                              # For all log spectra
                              is_logarithmic=loaded_params.get("is_logarithmic", None),
                              scaling_type=ScalingType(scaling_type_raw) if scaling_type_raw else ScalingType.POWER,

                              # For all cepstra
                              is_cepstrum=loaded_params.get("is_cepstrum", None),
                              cepstral_coefficients=loaded_params.get("cepstral_coefficients", None),
                              )

    def get_params(self):
        """Return the effective configuration as a dict mirroring ``from_json``'s keys.

        The boolean entries are derived from the stages actually built, so
        implied settings (e.g. a cepstrum implying a log spectrum) are reported.
        """
        # NOTE(review): "mel_type" and "scaling_type" are returned as stored
        # (enum members, not raw values); JSON output relies on write_params
        # coping with that -- confirm against _common.
        return {
            "sample_rate": self.sample_rate,
            "n_fft": self.n_fft,
            "hop_length": self.hop_length,
            "scale_spec": self.scale_spec,

            # Shared
            "n_filters": self.n_filters,

            # For all mel spectra
            "is_mel": self.mel_type is not None,
            "mel_type": self.mel_type,

            # For all log spectra
            "is_logarithmic": self.amplitude_to_DB is not None,
            "scaling_type": self.scaling_type,

            # For all cepstra
            "is_cepstrum": self.dct_mat is not None,
            "cepstral_coefficients": self.cepstral_coefficients if self.dct_mat is not None else None,
        }

    def to_json(self, filename: str):
        """Write ``get_params()`` to ``filename`` via ``write_params``."""
        params = self.get_params()
        write_params(filename, params)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class FeatureSource(BaseFeatureSource):
    """Feature source that runs several ``FeatureChannel`` modules on one waveform."""

    def __init__(self,
                 feature_channels: list[FeatureChannel],
                 preprocessors: Sequence[AudioPreprocessor] = [],
                 ):
        """Store the channels and preprocessors.

        Raises:
            ValueError: If ``feature_channels`` is empty.
        """
        super().__init__(preprocessors)

        if len(feature_channels) < 1:
            raise ValueError("Must include at least one spec type")

        # Plain-list view of the channels, used for counting
        self.fc = feature_channels
        # Registered containers so the sub-modules are tracked by torch
        self.feature_channels = torch.nn.ModuleList(feature_channels)
        self.preprocessors = torch.nn.ModuleList(preprocessors)

    def num_channels(self):
        """Return how many feature channels were supplied."""
        return len(self.fc)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Preprocess ``wav``, run every channel on it, and stack the results on dim 1."""
        for stage in self.preprocessors:
            wav = stage(wav)

        per_channel = []
        for channel in self.feature_channels:
            per_channel.append(channel(wav))
        return torch.stack(per_channel, dim=1)

    @staticmethod
    def from_json(params: str | list):
        """Build a ``FeatureSource`` (without preprocessors) from a list of channel parameters.

        Raises:
            ValueError: If the loaded parameters are not a list.
        """
        loaded_params = load_params(params)
        if not isinstance(loaded_params, list):
            raise ValueError(f"Invalid {FeatureSource.__name__} parameters")

        channels = []
        for channel_params in loaded_params:
            channels.append(FeatureChannel.from_json(channel_params))
        return FeatureSource(channels)

    def to_json(self, filename: str):
        """Write the parameters of every ``FeatureChannel`` child to ``filename``."""
        params = []
        for child in self.feature_channels.children():
            if isinstance(child, FeatureChannel):
                params.append(child.get_params())
        write_params(filename, params)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# 3PP Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
import numpy as np
|
|
5
|
+
from scipy.signal import butter, sosfilt, sosfilt_zi
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
###############################################################################
|
|
9
|
+
# Local Imports
|
|
10
|
+
###############################################################################
|
|
11
|
+
from ._common import AudioPreprocessor
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
###############################################################################
|
|
15
|
+
# Classes
|
|
16
|
+
###############################################################################
|
|
17
|
+
class HighPassFilter(AudioPreprocessor):
    """Butterworth high-pass filter applied to a one-dimensional waveform."""

    def __init__(self,
                 *,
                 sample_rate: int,
                 cutoff_freq: int,
                 rolloff_db: int,
                 ):
        """Design the filter.

        Args:
            sample_rate: Sample rate of the audio to be filtered, in Hz.
            cutoff_freq: Cutoff frequency in Hz; must lie strictly between 0
                and the Nyquist rate (``sample_rate / 2``).
            rolloff_db: Positive multiple of 6; the filter order passed to
                ``butter`` is ``rolloff_db // 6``.

        Raises:
            ValueError: If ``cutoff_freq`` or ``rolloff_db`` is out of range.
        """
        # True division keeps the limit exact for odd sample rates.
        nyquist = sample_rate / 2
        # butter() requires 0 < Wn < fs/2, so both boundaries are rejected here
        # with a message naming the actual constraint.
        if cutoff_freq <= 0 or cutoff_freq >= nyquist:
            raise ValueError(f"Cutoff freq must be >0 Hz and less than the Nyquist rate ({nyquist:g})")
        elif rolloff_db < 6 or rolloff_db % 6 != 0:
            raise ValueError("Rolloff dB must be >0 and a multiple of 6")

        self.cutoff_freq = cutoff_freq
        self.rolloff_db = rolloff_db

        self.sos = butter(rolloff_db // 6,
                          cutoff_freq,
                          btype="highpass",
                          analog=False,
                          fs=sample_rate,
                          output="sos",
                          )

        # Initial filter state; scaled by the first sample on each call
        self.zi = sosfilt_zi(self.sos)

    def __call__(self, wav: torch.Tensor | np.ndarray) -> torch.Tensor:
        """Filter ``wav`` and return the result as a float32 CPU tensor.

        Args:
            wav: One-dimensional tensor or array of samples.

        Raises:
            ValueError: If ``wav`` is not one-dimensional.
        """
        if isinstance(wav, torch.Tensor):
            # detach()/cpu() let grad-tracking or non-CPU tensors be converted;
            # both are no-ops for a plain CPU tensor.
            _wav = wav.detach().cpu().numpy()
        else:
            _wav = wav
        if _wav.ndim != 1:
            raise ValueError(f"Improper input dim; must be 1 but is {_wav.ndim}")

        # Nothing to filter, and there is no first sample to seed the state with
        if _wav.size == 0:
            return torch.from_numpy(_wav.astype(np.float32))

        # Seed the filter state with the first sample
        processed_samples, _ = sosfilt(self.sos, _wav, zi=self.zi * _wav[0])
        processed_samples = processed_samples.astype(np.float32)
        return torch.from_numpy(processed_samples)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
from typing import Optional, Sequence
|
|
5
|
+
import warnings
|
|
6
|
+
|
|
7
|
+
###############################################################################
|
|
8
|
+
# 3PP Imports
|
|
9
|
+
###############################################################################
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
###############################################################################
|
|
13
|
+
# Local Imports
|
|
14
|
+
###############################################################################
|
|
15
|
+
from ._common import ScalingType, ExportableSTFT, AudioPreprocessor, BaseFeatureSource
|
|
16
|
+
from ._common import power_of_two, create_scaler, scale_spec
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
###############################################################################
|
|
20
|
+
# Export Classes
|
|
21
|
+
###############################################################################
|
|
22
|
+
class FullRangeStftFeatureSource(BaseFeatureSource):
    """Feature source producing a single STFT channel, optionally converted to dB."""

    def __init__(self,
                 sample_rate: int,
                 preprocessors: Sequence[AudioPreprocessor] = [],
                 *,
                 # For all spectra
                 n_fft: Optional[int] = None,
                 hop_length: Optional[int] = None,

                 # For all log spectra
                 is_logarithmic: bool = True,
                 ):
        """Configure the source.

        Args:
            sample_rate: Stored on the instance; not otherwise used here.
            preprocessors: Forwarded to the base class constructor (the default
                list is not mutated in this class).
            n_fft: FFT window size; must be a power of two. Defaults to 1024.
            hop_length: STFT hop size. Defaults to ``n_fft // 4``; a warning is
                issued when it exceeds ``n_fft // 2``.
            is_logarithmic: Whether to apply power-type dB scaling.

        Raises:
            ValueError: If ``n_fft`` is not a power of two.
        """
        super().__init__(preprocessors)

        # Internal configs
        self.sample_rate = sample_rate

        if n_fft is not None and not power_of_two(n_fft):
            raise ValueError("n_fft must be a power of 2")
        self.n_fft = n_fft or 1024

        if hop_length is not None and hop_length > (self.n_fft // 2):
            # The limit is a hop size in samples, not a mel count.
            # stacklevel=2 attributes the warning to the code constructing the source.
            warnings.warn(
                f"hop_length should be set to no more than 1/2 the FFT window size, or {self.n_fft // 2} samples for n_fft = {self.n_fft} (currently {hop_length})",
                stacklevel=2,
            )
        self.hop_length = hop_length or self.n_fft // 4

        # Basic spectrogram
        self._stft = ExportableSTFT(self.n_fft, self.hop_length)

        # DB scaling, if necessary
        self.amplitude_to_DB = create_scaler(ScalingType.POWER) if is_logarithmic else None

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Preprocess ``wav`` and return its scaled spectrum with a new leading dimension."""
        # self.preprocessors is presumably set by BaseFeatureSource.__init__ -- confirm
        for preproc in self.preprocessors:
            wav = preproc(wav)

        spec = self._stft(wav)

        if self.amplitude_to_DB is not None:
            spec = self.amplitude_to_DB(spec)

        # NOTE(review): the new dimension is added at position 0 here, whereas
        # FeatureSource stacks channels on dim 1 -- confirm this is intended.
        return scale_spec(spec).unsqueeze(0)
|
AudioMlSpecTools/wav.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
###############################################################################
|
|
2
|
+
# Global Imports
|
|
3
|
+
###############################################################################
|
|
4
|
+
import os
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
###############################################################################
|
|
8
|
+
# 3PP Imports
|
|
9
|
+
###############################################################################
|
|
10
|
+
import torch
|
|
11
|
+
import torchcodec
|
|
12
|
+
|
|
13
|
+
###############################################################################
|
|
14
|
+
# Local Imports
|
|
15
|
+
###############################################################################
|
|
16
|
+
from ._common import resample
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
###############################################################################
|
|
20
|
+
# Helpers
|
|
21
|
+
###############################################################################
|
|
22
|
+
def _is_multichannel(wave: torch.Tensor) -> bool:
|
|
23
|
+
return wave.size(0) > 1
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
###############################################################################
|
|
27
|
+
# Functions
|
|
28
|
+
###############################################################################
|
|
29
|
+
def set_audio_length(wave: torch.Tensor, sr: int, duration_secs: Optional[int]):
    """Force ``wave`` to last exactly ``duration_secs`` seconds.

    Longer audio is truncated; shorter audio is zero-padded at the end. The
    padding matches the channel count, dtype and device of ``wave``, so
    multi-channel input is supported.

    Args:
        wave: Two-dimensional tensor whose second dimension is sliced and
            padded (i.e. indexed as ``wave[:, samples]``).
        sr: Sample rate used to convert seconds to a sample count.
        duration_secs: Target duration in seconds, or None to return ``wave``
            unchanged.

    Returns:
        The input tensor itself when ``duration_secs`` is None, otherwise a
        tensor with ``int(sr * duration_secs)`` entries in dimension 1.

    Raises:
        ValueError: If ``duration_secs`` is negative.
    """
    if duration_secs is None:
        return wave
    if duration_secs < 0:
        raise ValueError(f"duration_secs must be non-negative (currently {duration_secs})")

    num_samples = int(sr * duration_secs)

    # Truncate as needed
    wave = wave[:, :num_samples]

    # Zero padding, built from `wave` so channels/dtype/device always agree
    padding_size = num_samples - wave.size(1)
    if padding_size > 0:
        padding = wave.new_zeros((wave.size(0), padding_size))
        wave = torch.cat([wave, padding], dim=1)

    return wave
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_wav(path: str,
             *,
             target_sr: Optional[int] = None,
             duration_secs: Optional[int] = None,
             ) -> torch.Tensor:
    '''
    Decode an audio file into a mono waveform tensor.

    Only WAV input has been tested; other formats may or may not decode.

    Positional arguments:
    path -- Location of the audio file

    Keyword arguments:
    target_sr -- Sample rate to resample to; when omitted the file's own rate is kept
    duration_secs -- Fixed output duration, enforced by truncating or zero-padding
    '''

    decoded = torchcodec.decoders.AudioDecoder(path).get_all_samples()
    samples = decoded.data
    source_sr = decoded.sample_rate

    # Rate used for the length adjustment below
    output_sr = target_sr if target_sr else source_sr

    # Resample whenever a target rate was given
    if target_sr is not None:
        samples = resample(samples, orig_freq=source_sr, new_freq=target_sr)

    # Average all channels down to one
    if _is_multichannel(samples):
        samples = torch.mean(samples, dim=0, keepdim=True)

    return set_audio_length(samples, output_sr, duration_secs)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def list_audio_files(dir: str) -> list[str]:
    """Return the names of the entries in ``dir`` that end in ``.wav``, sorted."""
    wav_names = []
    for entry in os.listdir(dir):
        if entry.endswith(".wav"):
            wav_names.append(entry)
    wav_names.sort()
    return wav_names
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
###############################################################################
|
|
92
|
+
# Classes
|
|
93
|
+
###############################################################################
|
|
94
|
+
class WavReader:
    """Callable loader that applies a fixed sample rate and optional duration."""

    def __init__(self, target_sr: int, duration_secs: Optional[int] = None):
        """Remember the sample rate and duration passed to ``load_wav`` on every load."""
        self.target_sr = target_sr
        self.duration_secs = duration_secs

    def load(self, path: str) -> torch.Tensor:
        """Load ``path`` with ``load_wav`` using the configured settings."""
        return load_wav(path, target_sr=self.target_sr, duration_secs=self.duration_secs)

    def __call__(self, path: str) -> torch.Tensor:
        """Alias for ``load``."""
        return self.load(path)

    def clip(self, wav: torch.Tensor, start_sec: int, end_sec: int) -> torch.Tensor:
        """Return the part of ``wav`` between ``start_sec`` and ``end_sec``.

        Times are converted to sample offsets with ``target_sr`` and applied to
        dimension 1 of ``wav``. Offsets are truncated to integers so fractional
        seconds are accepted, matching ``set_audio_length``.
        """
        start_frame = int(start_sec * self.target_sr)
        end_frame = int(end_sec * self.target_sr)
        return wav[:, start_frame:end_frame]
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: AudioMlSpecTools
|
|
3
|
+
Version: 0.5.0
|
|
4
|
+
Summary: Convenience functions for generating ML features from audio data
|
|
5
|
+
Author-email: Ryan Quinn <ryan.quinn@certusinnovations.com>
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/Stonewall-Defense/team-ml-audio-features
|
|
8
|
+
Classifier: Development Status :: 4 - Beta
|
|
9
|
+
Classifier: Environment :: Console
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Natural Language :: English
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Requires-Python: >=3.12
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: torch==2.10.0
|
|
18
|
+
Requires-Dist: torchcodec==0.10.0
|
|
19
|
+
Dynamic: license-file
|
|
20
|
+
|
|
21
|
+
# Audio ML Spec Tools
|
|
22
|
+
|
|
23
|
+
Convenience functions for generating ML features from audio data. Breaks audio ML dependencies on `torchaudio`. Unlike `pytorch` features, these functions can be exported to ExecuTorch and ONNX with no issues.
|
|
24
|
+
|
|
25
|
+
## Motivation
|
|
26
|
+
|
|
27
|
+
Except in specific circumstances like `wav2vec`, raw audio has proven to be a much worse input for ML models than spectrogram-based features across a wide variety of problem domains, including environmental sound classification ([Guzhov et al. (2021)](https://arxiv.org/pdf/2104.11587)), singing technique classification ([Yamamoto et al. (2021)](https://www.slis.tsukuba.ac.jp/lspc/0000890.pdf)), and ship classification ([Xie, Ren, and Xu (2024)](https://arxiv.org/pdf/2306.01002)).
|
|
28
|
+
|
|
29
|
+
There is no scientific consensus on the relative benefits of mel-scale spectrograms, linear spectrograms, and MFCCs. Different researchers have shown good results with each type of spectrogram; see respectively [Raponi, Oligeri, and Ali (2021)](https://arxiv.org/pdf/2004.07948), [Jung et al. (2021)](https://www.mdpi.com/2075-4418/11/4/732), and [Razani et al. (2017)](https://www.ece.mcgill.ca/~bchamp/Papers/Conference/ISSPIT2017.pdf).
|
|
30
|
+
|
|
31
|
+
With this library, you can easily try as many feature extraction methods as you want to see what works for your use case.
|
|
32
|
+
|
|
33
|
+
## Prerequisites
|
|
34
|
+
|
|
35
|
+
- Python 3.12 or newer
|
|
36
|
+
- `pip` for package installation
|
|
37
|
+
- Note that `torchcodec` depends on a system installation of FFmpeg
|
|
38
|
+
|
|
39
|
+
## Installation
|
|
40
|
+
|
|
41
|
+
Install the dependencies into the environment with [pip](https://pypi.org/project/pip/):
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install -r requirements.txt
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Then install the package itself locally:
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
pip install .
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
## Usage
|
|
54
|
+
|
|
55
|
+
See `examples/features.py`.
|
|
56
|
+
|
|
57
|
+
## Testing
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
python3 -m coverage run -m unittest discover -s test -p "*_test.py" && python -m coverage report --skip-covered
|
|
61
|
+
python -m coverage html
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## Versioning
|
|
65
|
+
|
|
66
|
+
We use [SemVer](http://semver.org/) for versioning. For the versions available, see the [tags on this repository](https://github.com/Stonewall-Defense/team-ml-audio-features/tags).
|
|
67
|
+
|
|
68
|
+
## Authors
|
|
69
|
+
|
|
70
|
+
- **Ryan Quinn** - *Initial work*
|
|
71
|
+
|
|
72
|
+
## License
|
|
73
|
+
|
|
74
|
+
MIT.
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
AudioMlSpecTools/__init__.py,sha256=kQxQCIXzMWWCAu8tXNTggOyrFGdvkFBrgNprjQobnG8,438
|
|
2
|
+
AudioMlSpecTools/_common.py,sha256=GzHnubmVsMN9IW-OM-GkzWobxqrY_z2EEu7N8pBZ74s,17367
|
|
3
|
+
AudioMlSpecTools/efficient_features.py,sha256=F_wfXWWMQtdWudKUPc9VBnjslhz6WelQpa_cH0SMH-w,6966
|
|
4
|
+
AudioMlSpecTools/flexible_features.py,sha256=BSOucAPF687lecOzgDoO1fsk5sys9HnOa9IyyxaJg3Q,8662
|
|
5
|
+
AudioMlSpecTools/preproc.py,sha256=HyIzkz2_NoImvY0_brjtLGkuGofmnB-lNKzOJMbw5Cw,2010
|
|
6
|
+
AudioMlSpecTools/stft_features.py,sha256=UZXSZHjDFDUjmPzYBBq4DSmEz1x_e-YWXAITS30xMwo,2474
|
|
7
|
+
AudioMlSpecTools/wav.py,sha256=3kxu336FFYYnSH949VECPOKKguXhFSaia9kkh4__m2o,3605
|
|
8
|
+
audiomlspectools-0.5.0.dist-info/licenses/LICENSE,sha256=4MM0VuiftnP2VCMsqBWR9luSTrr9kX06xVn6VPisv5I,1069
|
|
9
|
+
audiomlspectools-0.5.0.dist-info/METADATA,sha256=ikkDfP1EzDK_198YV2Y1SRUZhyBwrCMiEhF9CZ-N4MQ,2851
|
|
10
|
+
audiomlspectools-0.5.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
11
|
+
audiomlspectools-0.5.0.dist-info/top_level.txt,sha256=DIiCB2FGOgLnCb0oekQqlbrV9mX6k6EoAdZCnK8cncw,17
|
|
12
|
+
audiomlspectools-0.5.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
Copyright © 2025 Certus Innovations
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
|
4
|
+
|
|
5
|
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
|
6
|
+
|
|
7
|
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
AudioMlSpecTools
|