heed-wakeword 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- heed/__init__.py +14 -0
- heed/audio.py +457 -0
- heed/augment.py +428 -0
- heed/cli.py +963 -0
- heed/eval.py +102 -0
- heed/export.py +550 -0
- heed/gate.py +77 -0
- heed/infer.py +204 -0
- heed/model.py +77 -0
- heed/trainer.py +1014 -0
- heed/tts.py +564 -0
- heed/tts_kokoro.py +430 -0
- heed/web.py +3483 -0
- heed_wakeword-0.1.0.dist-info/METADATA +251 -0
- heed_wakeword-0.1.0.dist-info/RECORD +20 -0
- heed_wakeword-0.1.0.dist-info/WHEEL +5 -0
- heed_wakeword-0.1.0.dist-info/entry_points.txt +2 -0
- heed_wakeword-0.1.0.dist-info/licenses/LICENSE +201 -0
- heed_wakeword-0.1.0.dist-info/licenses/NOTICE +8 -0
- heed_wakeword-0.1.0.dist-info/top_level.txt +1 -0
heed/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
__version__ = "0.1.0"

# Audio framing constants shared across the package.
SAMPLE_RATE = 16000  # samples per second
WINDOW_SECONDS = 1.0  # length of one analysis window, in seconds
HOP_SAMPLES = 160  # 10 ms at 16 kHz
WIN_SAMPLES = 400  # 25 ms at 16 kHz - STFT analysis window (Hann), zero-padded to N_FFT
N_FFT = 512  # FFT size: nearest power of two >= WIN_SAMPLES. A power-of-two
             # transform is a fast radix-2 FFT in every deployment runtime
             # (JS/Swift/Kotlin/C) instead of a slow Bluestein/DFT for N=400.
             # The 25 ms analysis window is preserved (win_length=400); the
             # window is just zero-padded to 512 before the FFT (librosa's
             # standard convention). Mel features come out 40 x 101 either way.
N_MELS = 40  # number of mel bands
# Whole hops per window: 1.0 s * 16000 / 160 = 100. (The mel features noted
# above have 101 columns, i.e. one more frame than this hop count.)
WINDOW_FRAMES = int(WINDOW_SECONDS * SAMPLE_RATE / HOP_SAMPLES)  # 100 frames
|
heed/audio.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
1
|
+
"""Audio I/O, mel features, normalization, trimming.
|
|
2
|
+
|
|
3
|
+
Pure torch + scipy + soundfile. No torchaudio C-extension required - this
|
|
4
|
+
was a portability issue on Windows where torchaudio's binary often mismatches
|
|
5
|
+
the installed torch.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Optional
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
import scipy.signal
|
|
16
|
+
import soundfile as sf
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
from . import HOP_SAMPLES, N_FFT, N_MELS, SAMPLE_RATE, WIN_SAMPLES, WINDOW_FRAMES
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# ----- pure-torch mel feature extractor (matches torchaudio at <1e-3 max-abs) ---
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _hz_to_mel(hz: float) -> float:
    """Map a frequency in Hz onto the mel scale: ``2595 * log10(1 + hz / 700)``."""
    ratio = 1.0 + hz / 700.0
    return 2595.0 * math.log10(ratio)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _mel_to_hz(m: float) -> float:
    """Map a mel-scale value back to Hz: ``700 * (10 ** (m / 2595) - 1)``."""
    exponent = m / 2595.0
    return 700.0 * (10.0 ** exponent - 1.0)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _build_mel_filterbank(
    n_mels: int = N_MELS,
    n_fft: int = N_FFT,
    sample_rate: int = SAMPLE_RATE,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
) -> torch.Tensor:
    """Build a bank of triangular mel filters.

    Args:
        n_mels: Number of mel bands (rows of the result).
        n_fft: FFT size; the result has ``n_fft // 2 + 1`` frequency columns.
        sample_rate: Sampling rate in Hz; bin frequencies span 0..sample_rate/2.
        fmin: Lower edge of the lowest filter, in Hz.
        fmax: Upper edge of the highest filter, in Hz. ``None`` means
            ``sample_rate / 2``.

    Returns:
        Tensor of shape ``(n_mels, n_fft // 2 + 1)``. Each row is a triangle
        that rises from its left edge to 1 at its center and falls back to 0
        at its right edge; no per-filter area normalization is applied.
    """
    upper_hz = sample_rate / 2 if fmax is None else fmax
    # n_mels + 2 edge points, evenly spaced on the mel scale, then mapped back
    # to Hz one Python float at a time.
    mel_edges = torch.linspace(_hz_to_mel(fmin), _hz_to_mel(upper_hz), n_mels + 2)
    edges_hz = torch.tensor([_mel_to_hz(edge.item()) for edge in mel_edges])
    bin_hz = torch.linspace(0.0, sample_rate / 2, n_fft // 2 + 1).unsqueeze(0)

    # Column vectors so each filter broadcasts against every FFT bin.
    lower = edges_hz[:-2].unsqueeze(1)
    peak = edges_hz[1:-1].unsqueeze(1)
    upper = edges_hz[2:].unsqueeze(1)

    # The 1e-9 terms keep the slopes finite if two edges coincide.
    up_slope = (bin_hz - lower) / (peak - lower + 1e-9)
    down_slope = (upper - bin_hz) / (upper - peak + 1e-9)
    return torch.clamp(torch.minimum(up_slope, down_slope), min=0.0)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Module-level caches, filled on first use by the accessors below.
_MEL_FB: Optional[torch.Tensor] = None
_HANN: Optional[torch.Tensor] = None


def _mel_fb() -> torch.Tensor:
    """Return the default mel filterbank, building and caching it on first call."""
    global _MEL_FB
    if _MEL_FB is not None:
        return _MEL_FB
    _MEL_FB = _build_mel_filterbank()
    return _MEL_FB


def _hann() -> torch.Tensor:
    """Return the ``WIN_SAMPLES``-long Hann window, creating and caching it on first call."""
    global _HANN
    if _HANN is not None:
        return _HANN
    _HANN = torch.hann_window(WIN_SAMPLES)
    return _HANN
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def log_mel(audio: torch.Tensor, apply_cmn: bool = True) -> torch.Tensor:
    """Compute a log-mel spectrogram.

    Args:
        audio: Waveform of shape ``(T,)`` or ``(B, T)``. A 1-D input is given
            a leading batch axis.
        apply_cmn: When True (the default), subtract each mel bin's mean over
            time, per clip. This mean normalization makes the features
            insensitive to mic frequency response and channel gain, which are
            multiplicative in the power domain and therefore additive
            constants after the log. It must be used consistently at train
            and inference time: a model trained with it will not work without
            it, and vice versa.

    Returns:
        Tensor of shape ``(B, n_mels, F)`` holding the natural log of the
        mel-band power (floored at 1e-9 before the log).
    """
    batch = audio.unsqueeze(0) if audio.ndim == 1 else audio
    device = batch.device

    stft = torch.stft(
        batch,
        n_fft=N_FFT,
        hop_length=HOP_SAMPLES,
        win_length=WIN_SAMPLES,
        window=_hann().to(device),
        center=True,
        pad_mode="reflect",
        return_complex=True,
    )
    # Power spectrum |X|^2 from the real and imaginary parts.
    power = stft.real.pow(2) + stft.imag.pow(2)
    mel_power = torch.matmul(_mel_fb().to(device), power)  # (B, n_mels, F)
    log_power = torch.log(torch.clamp(mel_power, min=1e-9))

    if not apply_cmn:
        return log_power
    # Remove the per-clip, per-bin mean across the time axis.
    return log_power - log_power.mean(dim=-1, keepdim=True)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def expected_frames(window_samples: int = SAMPLE_RATE) -> int:
    """Return the frame count for ``window_samples`` samples: full hops plus one."""
    full_hops = window_samples // HOP_SAMPLES
    return full_hops + 1
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# ----- WAV I/O with scipy-based resample ------------------------------------
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _resample(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample ``audio`` from ``src_sr`` to ``dst_sr`` with scipy's polyphase filter.

    When the two rates are equal the input array is returned untouched (same
    object, same dtype). Otherwise the result is a new float32 array.
    """
    if src_sr != dst_sr:
        # Reduce the rate ratio to lowest terms for resample_poly.
        common = math.gcd(src_sr, dst_sr)
        resampled = scipy.signal.resample_poly(audio, dst_sr // common, src_sr // common)
        return resampled.astype(np.float32)
    return audio
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# Paths already warned about for a low sample rate (one warning per file).
_LOW_SR_WARNED: set[str] = set()


def load_wav(path: str | Path) -> torch.Tensor:
    """Read an audio file and return it as a mono float32 tensor at SAMPLE_RATE.

    The file is read as float32 via soundfile, multi-channel audio is averaged
    down to mono, and anything not already at SAMPLE_RATE is resampled with
    ``_resample``. Files recorded below 14 kHz trigger a printed warning, once
    per path, because the missing high band removes most fricatives
    (s, sh, f).
    """
    key = str(path)
    samples, rate = sf.read(key, dtype="float32", always_2d=False)

    # Channels are on axis 1; average them into a single mono signal.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    already_warned = key in _LOW_SR_WARNED
    if rate < 14000 and not already_warned:
        _LOW_SR_WARNED.add(key)
        print(
            f"[warn] {path} is only {sr} Hz - high-frequency consonants ".replace("{sr}", "{sr}")
            if False else
            f"[warn] {path} is only {rate} Hz - high-frequency consonants "
            f"(s/sh/f) will be missing. Re-record at 16 kHz or higher if "
            f"your wake word contains them."
        )

    if rate != SAMPLE_RATE:
        samples = _resample(samples, rate, SAMPLE_RATE)
    return torch.from_numpy(np.ascontiguousarray(samples))
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def save_wav(path: str | Path, audio: torch.Tensor | np.ndarray) -> None:
    """Write mono audio to ``path`` as a 16-bit PCM WAV at SAMPLE_RATE.

    Tensors are detached and moved to the CPU first; the data is converted to
    float32 and singleton dimensions are squeezed away before writing.
    """
    data = audio.detach().cpu().numpy() if isinstance(audio, torch.Tensor) else audio
    mono = np.asarray(data, dtype=np.float32).squeeze()
    sf.write(str(path), mono, SAMPLE_RATE, subtype="PCM_16")
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# ----- normalization + windowing -------------------------------------------
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def peak_normalize(audio: torch.Tensor, target_dbfs: float = -3.0) -> torch.Tensor:
    """Scale audio so its absolute peak sits at ``target_dbfs``.

    Args:
        audio: Waveform tensor of any shape.
        target_dbfs: Desired peak level in dB relative to full scale (1.0).

    Returns:
        The rescaled audio. Empty input and effectively silent input (peak
        below 1e-6) are returned unchanged, so silence is never amplified.
    """
    # max() over a zero-length tensor raises, so empty input is passed through.
    if audio.numel() == 0:
        return audio
    peak = audio.abs().max()
    if peak < 1e-6:
        return audio
    target = 10 ** (target_dbfs / 20.0)
    return audio * (target / peak)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
# Filter coefficients are memoized by parameter tuple so they are not
# recomputed for every clip.
_HPF_CACHE: dict[tuple[float, int, int], np.ndarray] = {}
_NOTCH_CACHE: dict[tuple[float, float, int], np.ndarray] = {}


def _hpf_coefficients(cutoff_hz: float, sample_rate: int, order: int) -> np.ndarray:
    """Return cached Butterworth high-pass coefficients in second-order-section form.

    The cache key is ``(cutoff_hz, sample_rate, order)`` normalized to
    ``(float, int, int)``; repeated calls with the same parameters return the
    same array object.
    """
    cache_key = (float(cutoff_hz), int(sample_rate), int(order))
    sections = _HPF_CACHE.get(cache_key)
    if sections is None:
        sections = scipy.signal.butter(
            order, cutoff_hz, btype="highpass", fs=sample_rate, output="sos"
        )
        _HPF_CACHE[cache_key] = sections
    return sections
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _notch_coefficients(freq_hz: float, q: float, sample_rate: int) -> np.ndarray:
    """Return cached IIR notch (biquad) coefficients as second-order sections.

    A higher ``q`` gives a narrower notch. Results are memoized in
    ``_NOTCH_CACHE`` under ``(freq_hz, q, sample_rate)``.
    """
    cache_key = (float(freq_hz), float(q), int(sample_rate))
    try:
        return _NOTCH_CACHE[cache_key]
    except KeyError:
        pass
    numerator, denominator = scipy.signal.iirnotch(freq_hz, q, fs=sample_rate)
    # Convert to SOS so the notch can be run through the same sosfilt cascade
    # as the high-pass filter.
    sections = scipy.signal.tf2sos(numerator, denominator)
    _NOTCH_CACHE[cache_key] = sections
    return sections
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def highpass_filter(
    audio: torch.Tensor,
    cutoff_hz: float = 100.0,
    sample_rate: int = SAMPLE_RATE,
    order: int = 8,
    apply_mains_notch: bool = True,
) -> torch.Tensor:
    """Aggressively remove sub-cutoff_hz content + mains-hum notches.

    Default 8th-order Butterworth at 100 Hz: -48 dB/octave rolloff, so 60 Hz
    is at ~-32 dB and 30 Hz at ~-56 dB. Loses the very bottom of male voice
    fundamentals (typical adult male 85-180 Hz; we sacrifice 85-100 Hz),
    but speech intelligibility lives in the formants (300-3500 Hz), so the
    model still has plenty to work with - and the gain in mic-noise
    rejection is large.

    Plus optional notch filters at 50 Hz and 60 Hz, killing mains hum
    that bleeds into low-end harmonics. Notches are narrow (Q=30) so they
    don't audibly affect anything but the offending frequencies.

    CAUSAL single-pass (``scipy.signal.sosfilt``), with every stage
    initialized to the steady-state ``sosfilt_zi * x[0]`` so there is no
    startup DC step. Because the filter is causal, the output for a given
    sample never changes once later samples arrive, so a streaming deployment
    can filter each new chunk once (with retained state via
    :class:`StreamingHighpass`) and never recompute old audio; a model
    trained on clips filtered here sees consistent features when deployed
    with the streaming filter.

    (The previous implementation used zero-phase ``sosfiltfilt``, which runs
    the filter forwards *and* backwards; the backward pass makes every output
    sample depend on all future samples, which is impossible to stream and
    forced the JS preprocessor to recompute the entire 1-s window every hop.)

    Args:
        audio: Waveform tensor. It is flattened before filtering and the
            result is reshaped to the input's shape.
        cutoff_hz: High-pass cutoff frequency in Hz.
        sample_rate: Sampling rate of ``audio`` in Hz.
        order: Butterworth filter order.
        apply_mains_notch: Also apply the 50 Hz and 60 Hz notches.

    Returns:
        A float32 tensor with the input's shape, on the input's device. Empty
        input is returned unchanged, and so is the input if scipy raises
        ``ValueError`` while designing or applying a filter.
    """
    if audio.numel() < 1:
        return audio
    # NOTE(review): multi-dimensional input is filtered as ONE flattened
    # signal, so filter state carries across rows - confirm callers only pass
    # 1-D clips.
    signal = audio.flatten().detach().cpu().numpy().astype(np.float64)
    try:
        stages = [_hpf_coefficients(cutoff_hz, sample_rate, order)]
        if apply_mains_notch:
            # Both 50 Hz (Europe) and 60 Hz (Americas) - we don't know which.
            # Applying both is cheap and harmless: each is a narrow notch.
            stages.extend(
                _notch_coefficients(f, q=30.0, sample_rate=sample_rate)
                for f in (50.0, 60.0)
            )
        for sos in stages:
            # Steady-state initial conditions scaled by this stage's first
            # input sample.
            zi = scipy.signal.sosfilt_zi(sos)
            signal, _ = scipy.signal.sosfilt(sos, signal, zi=zi * signal[0])
    except ValueError:
        return audio  # falls through on any unexpected scipy edge case
    filtered = torch.from_numpy(np.ascontiguousarray(signal.astype(np.float32)))
    if audio.ndim != 1:
        filtered = filtered.view_as(audio)
    # scipy runs on the CPU; return the result on the caller's device so it
    # matches the early-return paths above, which hand back the input tensor.
    return filtered.to(audio.device)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class StreamingHighpass:
    """Stateful, causal version of :func:`highpass_filter` for real-time use.

    The per-section biquad state of every stage is kept between calls, so each
    incoming chunk is filtered exactly once and earlier audio is never
    revisited. Feeding a stream chunk by chunk is intended to give the same
    output as one ``highpass_filter`` call over the concatenated chunks, which
    keeps inference consistent with the offline filtering used to prepare
    training clips.

    This is the algorithm mirrored in the JS demos
    (``examples/*/preprocessing.js``); keeping the two in lock-step is what
    lets the browser / React-Native runtimes match Python.
    """

    def __init__(
        self,
        cutoff_hz: float = 100.0,
        sample_rate: int = SAMPLE_RATE,
        order: int = 8,
        apply_mains_notch: bool = True,
    ) -> None:
        """Build the filter cascade: the high-pass, then optional 50/60 Hz notches."""
        stages = [_hpf_coefficients(cutoff_hz, sample_rate, order)]
        if apply_mains_notch:
            for mains_hz in (50.0, 60.0):
                stages.append(
                    _notch_coefficients(mains_hz, q=30.0, sample_rate=sample_rate)
                )
        self._sos = stages
        # Unit steady-state initial conditions per stage. They are scaled by
        # the stage's first input sample when the first chunk arrives (the
        # same sosfilt_zi * x[0] convention highpass_filter uses).
        self._zi_unit = [scipy.signal.sosfilt_zi(stage) for stage in stages]
        self._zi: list[Optional[np.ndarray]] = [None] * len(stages)

    def reset(self) -> None:
        """Forget all filter state; the next chunk starts a fresh stream."""
        self._zi = [None for _ in self._sos]

    def __call__(self, chunk: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
        """Filter one chunk and carry the filter state over to the next call.

        The chunk is flattened and processed in float64. The result is a 1-D
        float32 tensor when a tensor was passed in, otherwise a 1-D float32
        numpy array. An empty chunk is returned unchanged and leaves the
        state untouched.
        """
        is_tensor = isinstance(chunk, torch.Tensor)
        raw = chunk.detach().cpu().numpy() if is_tensor else np.asarray(chunk)
        samples = raw.astype(np.float64).flatten()
        if samples.size == 0:
            return chunk

        for idx, (stage, unit_zi) in enumerate(zip(self._sos, self._zi_unit)):
            if self._zi[idx] is None:
                # First chunk for this stage: start from steady state.
                self._zi[idx] = unit_zi * samples[0]
            samples, self._zi[idx] = scipy.signal.sosfilt(
                stage, samples, zi=self._zi[idx]
            )

        result = np.ascontiguousarray(samples.astype(np.float32))
        return torch.from_numpy(result) if is_tensor else result
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def trim_silence(
    audio: torch.Tensor,
    frame_ms: float = 20.0,
    threshold_db: float = -40.0,
) -> torch.Tensor:
    """Cut leading and trailing silence from a 1-D waveform.

    The signal is split into non-overlapping ``frame_ms`` frames (the last one
    zero-padded) and the RMS of each frame is compared with the loudest
    frame's RMS scaled by ``threshold_db``. Everything before the first and
    after the last frame above that threshold is removed.

    The input is returned unchanged when it is empty, no longer than one
    frame, or when no frame exceeds the threshold.
    """
    total = audio.numel()
    if total == 0:
        return audio
    frame_len = int(SAMPLE_RATE * frame_ms / 1000)
    if frame_len <= 0 or total <= frame_len:
        return audio

    # Zero-pad up to a whole number of frames, then view as (n_frames, frame_len).
    remainder = (-total) % frame_len
    framed = torch.cat([audio, audio.new_zeros(remainder)]).unfold(0, frame_len, frame_len)
    frame_rms = framed.pow(2).mean(dim=-1).sqrt()

    # The threshold is relative to the loudest frame, not to full scale.
    floor = frame_rms.max().clamp(min=1e-9) * (10 ** (threshold_db / 20.0))
    active = (frame_rms > floor).nonzero(as_tuple=False).squeeze(-1)
    if active.numel() == 0:
        return audio

    first_frame = active.min().item()
    last_frame = active.max().item()
    # The end index may run into the padding; slicing clamps it to the audio.
    return audio[first_frame * frame_len : (last_frame + 1) * frame_len]
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def center_in_window(audio: torch.Tensor, window_samples: int = SAMPLE_RATE) -> torch.Tensor:
    """Fit audio into exactly ``window_samples`` samples, keeping it centered.

    Shorter audio is zero-padded on both sides (the extra sample, if any, goes
    on the right); longer audio is cropped around its middle; audio that
    already fits is returned as-is.
    """
    deficit = window_samples - audio.numel()
    if deficit == 0:
        return audio
    if deficit < 0:
        # Too long: drop half of the excess from the front, the rest from the end.
        offset = (-deficit) // 2
        return audio[offset : offset + window_samples]
    lead = deficit // 2
    tail = deficit - lead
    return torch.cat([audio.new_zeros(lead), audio, audio.new_zeros(tail)])
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def prepare_clip(audio: torch.Tensor, window_samples: int = SAMPLE_RATE) -> torch.Tensor:
    """High-pass filter, normalize, trim silence, and center in a fixed window.

    Args:
        audio: Mono waveform tensor.
        window_samples: Length of the returned clip in samples.

    Returns:
        A clip of exactly ``window_samples`` samples. Empty input yields an
        all-zero clip.
    """
    # Handle empty input up front so the later stages never see a zero-length
    # tensor (a max() reduction over an empty tensor raises).
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)  # remove low-frequency rumble (100 Hz default cutoff) first
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    # Trimming may have removed the loudest region's context; re-normalize.
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return center_in_window(audio, window_samples)
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def end_align_in_window(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    trailing_silence_ms: float = 50.0,
) -> torch.Tensor:
    """Place audio against the right edge of a fixed-length window.

    A short run of trailing zeros (``trailing_silence_ms``, capped by the free
    space) is kept after the audio so it does not end on the buffer's very
    last sample; the remaining space is zero-padding on the left. Audio that
    is at least as long as the window is cropped to its last
    ``window_samples`` samples.

    Used to create end-aligned positive variants that train the model to fire
    on phrase completion at the buffer's right edge, matching how the sliding
    inference buffer captures fresh audio.
    """
    n = audio.numel()
    if n >= window_samples:
        return audio[-window_samples:]
    free = window_samples - n
    wanted_tail = int(SAMPLE_RATE * trailing_silence_ms / 1000)
    tail = min(wanted_tail, free)
    head = free - tail
    return torch.cat([audio.new_zeros(head), audio, audio.new_zeros(tail)])
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def prepare_clip_end_aligned(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    trailing_silence_ms: float = 50.0,
) -> torch.Tensor:
    """prepare_clip variant: HPF + normalize + trim silence + end-align (right edge).

    Args:
        audio: Mono waveform tensor.
        window_samples: Length of the returned clip in samples.
        trailing_silence_ms: Silence kept after the audio, passed through to
            ``end_align_in_window``.

    Returns:
        A clip of exactly ``window_samples`` samples. Empty input yields an
        all-zero clip.
    """
    # Handle empty input up front so the later stages never see a zero-length
    # tensor (a max() reduction over an empty tensor raises).
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return end_align_in_window(audio, window_samples, trailing_silence_ms)
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def random_align_in_window(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    rng: "random.Random | None" = None,
) -> torch.Tensor:
    """Drop audio at a random offset inside a fixed-length window.

    Streaming inference sees the wake phrase at every alignment as the rolling
    buffer scrolls. Centered and end-aligned clips cover the two extremes;
    randomly aligned clips fill the positions in between so the model is not
    tied to particular offsets.

    Audio shorter than the window gets a random amount of left zero-padding
    (the rest goes on the right). Audio at least as long as the window is
    randomly cropped instead. ``rng`` supplies ``randint``; when it is falsy
    the global ``random`` module is used.
    """
    import random as _r

    source = rng or _r
    slack = window_samples - audio.numel()
    if slack <= 0:
        # Too long (or an exact fit): take a random crop instead of padding.
        offset = source.randint(0, -slack)
        return audio[offset : offset + window_samples]
    lead = source.randint(0, slack)
    tail = slack - lead
    return torch.cat([audio.new_zeros(lead), audio, audio.new_zeros(tail)])
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def prepare_clip_random_aligned(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    rng: "random.Random | None" = None,
) -> torch.Tensor:
    """prepare_clip variant: HPF + normalize + trim silence + random-align.

    Adds variance to where speech appears within the 1-s window. Combined
    with the existing centered + end-aligned variants, the model learns to
    detect the wake phrase regardless of its alignment in the buffer -
    closing the train/infer mismatch from streaming buffers scrolling
    audio through every possible position.

    Args:
        audio: Mono waveform tensor.
        window_samples: Length of the returned clip in samples.
        rng: Optional random source, passed through to
            ``random_align_in_window``.

    Returns:
        A clip of exactly ``window_samples`` samples. Empty input yields an
        all-zero clip.
    """
    # Handle empty input up front so the later stages never see a zero-length
    # tensor (a max() reduction over an empty tensor raises).
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return random_align_in_window(audio, window_samples, rng=rng)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def load_dir_clips(directory: str | Path, window_samples: int = SAMPLE_RATE) -> list[torch.Tensor]:
    """Load and prepare every ``*.wav`` directly inside ``directory``.

    Files are visited in sorted order. Each one is read with ``load_wav`` and
    passed through ``prepare_clip``. A file that fails to load is reported
    with a printed warning and skipped. A path that is not a directory yields
    an empty list.
    """
    root = Path(directory)
    clips: list[torch.Tensor] = []
    if not root.is_dir():
        return clips
    for p in sorted(root.glob("*.wav")):
        try:
            waveform = load_wav(p)
        except Exception as exc:  # noqa: BLE001
            # Best-effort: one unreadable file must not abort the whole load.
            print(f"[warn] skipping {p}: {exc}")
        else:
            clips.append(prepare_clip(waveform, window_samples))
    return clips
|