heed-wakeword 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
heed/__init__.py ADDED
@@ -0,0 +1,14 @@
1
__version__ = "0.1.0"

# Every clip and stream is processed as 16 kHz mono audio.
SAMPLE_RATE = 16000
# Duration of the fixed detection window, in seconds.
WINDOW_SECONDS = 1.0
# STFT hop length: 10 ms at SAMPLE_RATE (160 samples at 16 kHz).
HOP_SAMPLES = SAMPLE_RATE * 10 // 1000
# STFT analysis window (Hann) length: 25 ms at SAMPLE_RATE (400 samples at
# 16 kHz). The window is zero-padded up to N_FFT before the transform.
WIN_SAMPLES = SAMPLE_RATE * 25 // 1000
# FFT size: the smallest power of two >= WIN_SAMPLES (512 for 400 samples).
# A power-of-two length lets deployment runtimes (JS/Swift/Kotlin/C) use a
# plain radix-2 FFT instead of a slow Bluestein/DFT path for N=400. The 25 ms
# analysis window itself is unchanged (win_length=WIN_SAMPLES); it is only
# zero-padded to N_FFT, librosa's standard convention. Mel features come out
# 40 x 101 either way.
N_FFT = 1 << (WIN_SAMPLES - 1).bit_length()
# Number of mel bands in the log-mel features.
N_MELS = 40
# Nominal frames per window: WINDOW_SECONDS * SAMPLE_RATE / HOP_SAMPLES = 100.
# NOTE(review): a centered STFT yields one extra frame (the 40 x 101 shape
# mentioned above); confirm which count each consumer of this constant expects.
WINDOW_FRAMES = int(WINDOW_SECONDS * SAMPLE_RATE / HOP_SAMPLES)
heed/audio.py ADDED
@@ -0,0 +1,457 @@
1
+ """Audio I/O, mel features, normalization, trimming.
2
+
3
+ Pure torch + scipy + soundfile. No torchaudio C-extension required - this
4
+ was a portability issue on Windows where torchaudio's binary often mismatches
5
+ the installed torch.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ import numpy as np
15
+ import scipy.signal
16
+ import soundfile as sf
17
+ import torch
18
+
19
+ from . import HOP_SAMPLES, N_FFT, N_MELS, SAMPLE_RATE, WIN_SAMPLES, WINDOW_FRAMES
20
+
21
+
22
+ # ----- pure-torch mel feature extractor (matches torchaudio at <1e-3 max-abs) ---
23
+
24
+
25
def _hz_to_mel(hz: float) -> float:
    """Convert a frequency in Hz to the HTK mel scale: 2595 * log10(1 + hz / 700)."""
    ratio = 1.0 + hz / 700.0
    return 2595.0 * math.log10(ratio)
27
+
28
+
29
def _mel_to_hz(m: float) -> float:
    """Map an HTK mel value back to Hz: 700 * (10 ** (m / 2595) - 1)."""
    power = 10.0 ** (m / 2595.0)
    return 700.0 * (power - 1.0)
31
+
32
+
33
def _build_mel_filterbank(
    n_mels: int = N_MELS,
    n_fft: int = N_FFT,
    sample_rate: int = SAMPLE_RATE,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
) -> torch.Tensor:
    """Build a triangular mel filterbank of shape (n_mels, n_fft // 2 + 1).

    Filter edges are ``n_mels + 2`` points spaced evenly on the HTK mel scale
    between ``fmin`` and ``fmax`` (Nyquist when ``fmax`` is None). Filter m
    rises linearly from edge m to edge m+1 and falls back to zero at edge m+2.
    The triangles are not area-normalized, so each has a height of at most 1.
    """
    upper_hz = sample_rate / 2 if fmax is None else fmax
    mel_points = torch.linspace(_hz_to_mel(fmin), _hz_to_mel(upper_hz), n_mels + 2)
    edges_hz = torch.tensor([_mel_to_hz(float(point)) for point in mel_points])
    bin_hz = torch.linspace(0.0, sample_rate / 2, n_fft // 2 + 1)

    # Column vectors of per-filter (left, center, right) edges, broadcast
    # against the row of FFT-bin frequencies to form all filters at once.
    lo = edges_hz[:-2].unsqueeze(1)
    mid = edges_hz[1:-1].unsqueeze(1)
    hi = edges_hz[2:].unsqueeze(1)

    # The tiny epsilon guards against zero-width triangle sides.
    up_slope = (bin_hz - lo) / (mid - lo + 1e-9)
    down_slope = (hi - bin_hz) / (hi - mid + 1e-9)
    return torch.clamp(torch.minimum(up_slope, down_slope), min=0.0)
54
+
55
+
56
# Lazily built module-level caches, filled by _hann() and _mel_fb() on first
# use; None means "not built yet".
_HANN: Optional[torch.Tensor] = None
_MEL_FB: Optional[torch.Tensor] = None
58
+
59
+
60
def _mel_fb() -> torch.Tensor:
    """Return the default mel filterbank, building and caching it on first call."""
    global _MEL_FB
    if _MEL_FB is not None:
        return _MEL_FB
    _MEL_FB = _build_mel_filterbank()
    return _MEL_FB
65
+
66
+
67
def _hann() -> torch.Tensor:
    """Return the WIN_SAMPLES-long Hann window, creating and caching it on first call."""
    global _HANN
    if _HANN is not None:
        return _HANN
    _HANN = torch.hann_window(WIN_SAMPLES)
    return _HANN
72
+
73
+
74
def log_mel(audio: torch.Tensor, apply_cmn: bool = True) -> torch.Tensor:
    """Compute log-mel spectrogram. Input (T,) or (B, T); output (B, n_mels, F).

    F is the STFT frame count; with ``center=True`` torch.stft produces
    ``1 + T // HOP_SAMPLES`` frames.

    When `apply_cmn` is True (default), subtracts the per-clip mean across
    time from each mel bin. This is **cepstral mean normalization**, a
    standard ASR trick that makes the representation invariant to mic
    frequency response and channel gain (these become additive constants
    in log-mel, eliminated by mean subtraction). It is critical for
    cross-mic robustness: a model trained without it is locked to the
    trainer's mic spectrum.

    Must be applied **consistently** at train and inference: a model trained
    with CMN won't work without it and vice versa. Hence default-on
    everywhere in this codebase.

    The cached Hann window and mel filterbank are float32. They are cast to
    the input's device and dtype here, so float64 audio works instead of
    failing in the filterbank matmul with a dtype mismatch. For float32
    input the casts are no-ops.
    """
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    window = _hann().to(audio.device)
    if audio.is_floating_point():
        # Keep the analysis window in the input's precision.
        window = window.to(audio.dtype)
    spec = torch.stft(
        audio,
        n_fft=N_FFT,
        hop_length=HOP_SAMPLES,
        win_length=WIN_SAMPLES,
        window=window,
        center=True,
        pad_mode="reflect",
        return_complex=True,
    )
    power = spec.real.pow(2) + spec.imag.pow(2)
    # torch.matmul does not type-promote, so the float32 filterbank must match
    # the power spectrum's dtype (float64 when the input was float64).
    fb = _mel_fb().to(device=power.device, dtype=power.dtype)
    mel = fb @ power  # (n_mels, n_freq) @ (B, n_freq, F) -> (B, n_mels, F)
    out = torch.log(mel.clamp(min=1e-9))
    if apply_cmn:
        # Per-clip mean across time, per mel bin. Subtraction removes the
        # multiplicative mic-spectrum contribution that became additive
        # after log.
        out = out - out.mean(dim=-1, keepdim=True)
    return out
111
+
112
+
113
def expected_frames(window_samples: int = SAMPLE_RATE) -> int:
    """Return how many STFT frames a centered STFT yields for `window_samples` samples.

    One frame per full hop, plus one: ``window_samples // HOP_SAMPLES + 1``.
    """
    hops = window_samples // HOP_SAMPLES
    return hops + 1
115
+
116
+
117
+ # ----- WAV I/O with scipy-based resample ------------------------------------
118
+
119
+
120
def _resample(audio: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    """Resample `audio` from `src_sr` to `dst_sr` with scipy's polyphase filter.

    The rate ratio is reduced by its GCD to get the smallest up/down factors.
    Returns the input object untouched when the rates already match;
    otherwise returns a new float32 array.
    """
    if src_sr != dst_sr:
        divisor = math.gcd(src_sr, dst_sr)
        resampled = scipy.signal.resample_poly(audio, dst_sr // divisor, src_sr // divisor)
        return resampled.astype(np.float32)
    return audio
128
+
129
+
130
# str(path) of every file already warned about for a low sample rate, so each
# path prints the warning at most once per process.
_LOW_SR_WARNED: set[str] = set()


def load_wav(path: str | Path) -> torch.Tensor:
    """Read an audio file and return mono float32 samples at SAMPLE_RATE.

    soundfile decodes straight to float32 (``dtype="float32"``). Multi-channel
    audio is averaged down to mono. Any rate other than SAMPLE_RATE is
    converted with the scipy polyphase resampler. Rates below 14 kHz trigger
    a one-time printed warning per path: e.g. 8 kHz telephony audio carries
    nothing above 4 kHz and so loses most fricatives (s, sh, f).
    """
    key = str(path)
    samples, rate = sf.read(key, dtype="float32", always_2d=False)
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if rate < 14000 and key not in _LOW_SR_WARNED:
        _LOW_SR_WARNED.add(key)
        print(
            f"[warn] {path} is only {rate} Hz - high-frequency consonants "
            f"(s/sh/f) will be missing. Re-record at 16 kHz or higher if "
            f"your wake word contains them."
        )
    if rate != SAMPLE_RATE:
        samples = _resample(samples, rate, SAMPLE_RATE)
    return torch.from_numpy(np.ascontiguousarray(samples))
155
+
156
+
157
def save_wav(path: str | Path, audio: torch.Tensor | np.ndarray) -> None:
    """Write `audio` to `path` as 16-bit PCM at SAMPLE_RATE.

    Tensors are detached and copied to CPU first, and singleton dimensions
    are squeezed away. The container format is whatever soundfile infers from
    `path` (presumably a .wav extension).
    """
    data = audio.detach().cpu().numpy() if isinstance(audio, torch.Tensor) else audio
    samples = np.asarray(data, dtype=np.float32).squeeze()
    sf.write(str(path), samples, SAMPLE_RATE, subtype="PCM_16")
163
+
164
+
165
+ # ----- normalization + windowing -------------------------------------------
166
+
167
+
168
def peak_normalize(audio: torch.Tensor, target_dbfs: float = -3.0) -> torch.Tensor:
    """Scale `audio` so its absolute peak sits at `target_dbfs` (dB full scale).

    Args:
        audio: waveform tensor of any shape.
        target_dbfs: desired peak level; -3.0 maps the peak to ~0.708.

    Returns:
        The rescaled tensor. Empty input and near-silent input (peak below
        1e-6) are returned unchanged rather than amplified or crashing.
    """
    if audio.numel() == 0:
        # torch's .max() on an empty tensor raises; there is nothing to scale.
        return audio
    peak = audio.abs().max()
    if peak < 1e-6:
        # Effectively silent: scaling would only blow up noise / divide by ~0.
        return audio
    target = 10 ** (target_dbfs / 20.0)
    return audio * (target / peak)
175
+
176
+
177
# Memoized filter designs keyed by their design parameters, so scipy's
# butter()/iirnotch() run once per configuration rather than once per clip.
_HPF_CACHE: dict[tuple[float, int, int], np.ndarray] = {}
_NOTCH_CACHE: dict[tuple[float, float, int], np.ndarray] = {}


def _hpf_coefficients(cutoff_hz: float, sample_rate: int, order: int) -> np.ndarray:
    """Return (and memoize) Butterworth high-pass coefficients in SOS form."""
    key = (float(cutoff_hz), int(sample_rate), int(order))
    design = _HPF_CACHE.get(key)
    if design is None:
        design = _HPF_CACHE[key] = scipy.signal.butter(
            order, cutoff_hz, btype="highpass", fs=sample_rate, output="sos"
        )
    return design


def _notch_coefficients(freq_hz: float, q: float, sample_rate: int) -> np.ndarray:
    """Return (and memoize) a biquad IIR notch at `freq_hz`, as SOS.

    Higher `q` gives a narrower notch. The (b, a) design from iirnotch is
    converted to second-order sections so it can be run with sosfilt in the
    same way as the high-pass cascade.
    """
    key = (float(freq_hz), float(q), int(sample_rate))
    design = _NOTCH_CACHE.get(key)
    if design is None:
        numerator, denominator = scipy.signal.iirnotch(freq_hz, q, fs=sample_rate)
        design = _NOTCH_CACHE[key] = scipy.signal.tf2sos(numerator, denominator)
    return design
200
+
201
+
202
def _highpass_1d(
    x: np.ndarray,
    cutoff_hz: float,
    sample_rate: int,
    order: int,
    apply_mains_notch: bool,
) -> np.ndarray:
    """Run the causal HPF cascade (+ optional 50/60 Hz notches) over one 1-D float64 signal."""
    sos_hpf = _hpf_coefficients(cutoff_hz, sample_rate, order)
    zi = scipy.signal.sosfilt_zi(sos_hpf)
    # Steady-state initial conditions scaled by the first sample: no startup step.
    x, _ = scipy.signal.sosfilt(sos_hpf, x, zi=zi * x[0])
    if apply_mains_notch:
        # Both 50 Hz (Europe) and 60 Hz (Americas) - we don't know which.
        # Applying both is cheap and harmless: each is a narrow notch.
        for f in (50.0, 60.0):
            sos_notch = _notch_coefficients(f, q=30.0, sample_rate=sample_rate)
            zin = scipy.signal.sosfilt_zi(sos_notch)
            x, _ = scipy.signal.sosfilt(sos_notch, x, zi=zin * x[0])
    return x


def highpass_filter(
    audio: torch.Tensor,
    cutoff_hz: float = 100.0,
    sample_rate: int = SAMPLE_RATE,
    order: int = 8,
    apply_mains_notch: bool = True,
) -> torch.Tensor:
    """Causally remove sub-``cutoff_hz`` content, plus optional mains-hum notches.

    Default: 8th-order Butterworth high-pass at 100 Hz (-48 dB/octave
    asymptotic rolloff). The Butterworth stage alone gives roughly -35 dB at
    60 Hz and -84 dB at 30 Hz. This sacrifices the very bottom of adult male
    voice fundamentals (~85-100 Hz), but speech intelligibility lives mostly
    in the formants (~300-3500 Hz), and the gain in low-frequency noise
    rejection is large.

    With ``apply_mains_notch`` the signal also passes through narrow (Q=30)
    IIR notches at 50 Hz and 60 Hz, which suppress mains hum without audibly
    affecting anything but those frequencies.

    Filtering is CAUSAL and single-pass (``scipy.signal.sosfilt``). Each
    stage is initialized to the steady state ``sosfilt_zi * x[0]``, so there
    is no startup DC step. A causal filter's output for a given sample never
    changes once later samples arrive. That lets :class:`StreamingHighpass`
    filter a live stream chunk by chunk with retained state and match this
    function applied to the concatenated signal, so a model trained on clips
    filtered here sees consistent features when deployed with the streaming
    filter. (An earlier zero-phase ``sosfiltfilt`` version could not be
    streamed, because its backward pass makes every output depend on all
    future samples.)

    Args:
        audio: 1-D signal, or an N-D tensor whose last axis is time. N-D input
            is filtered row by row: each row is an independent signal with its
            own initial state, not part of one concatenated stream.
        cutoff_hz: high-pass cutoff frequency in Hz.
        sample_rate: sample rate of ``audio`` in Hz.
        order: Butterworth filter order.
        apply_mains_notch: also apply the 50 Hz and 60 Hz notches.

    Returns:
        A float32 CPU tensor with the same shape as ``audio``. Empty input is
        returned as-is. If scipy raises ``ValueError`` (e.g. an invalid design
        such as a cutoff at or above Nyquist), the input is returned
        unfiltered as a best-effort fallback.
    """
    if audio.numel() < 1:
        return audio
    samples = audio.detach().cpu().numpy().astype(np.float64)
    # One row per independent signal: a 0-D/1-D input is a single row; an N-D
    # input is split along its last axis so filter state never carries over
    # from the end of one clip into the start of the next.
    rows = samples.reshape(-1, samples.shape[-1] if samples.ndim else 1)
    try:
        filtered = np.stack(
            [
                _highpass_1d(row, cutoff_hz, sample_rate, order, apply_mains_notch)
                for row in rows
            ]
        )
    except ValueError:
        return audio  # falls through on any unexpected scipy edge case
    out = torch.from_numpy(np.ascontiguousarray(filtered.astype(np.float32)))
    return out.reshape(-1) if audio.ndim == 1 else out.view_as(audio)
257
+
258
+
259
class StreamingHighpass:
    """Stateful, causal filter for real-time streams.

    It applies the same stage cascade as :func:`highpass_filter`: Butterworth
    high-pass, then optional 50 Hz / 60 Hz notches. The per-stage
    second-order-section state is carried from one call to the next, so each
    incoming chunk is filtered exactly once and earlier audio is never
    revisited. Feeding a signal through in chunks therefore reproduces
    ``highpass_filter`` run once over the concatenated chunks (same stages,
    same ``sosfilt_zi * x[0]`` start-up convention). That keeps deployment
    features consistent with offline-prepared training clips.

    Per the original notes, this algorithm is mirrored by the JS demos
    (``examples/*/preprocessing.js``) so browser / React-Native runtimes
    match Python.
    """

    def __init__(
        self,
        cutoff_hz: float = 100.0,
        sample_rate: int = SAMPLE_RATE,
        order: int = 8,
        apply_mains_notch: bool = True,
    ) -> None:
        stages = [_hpf_coefficients(cutoff_hz, sample_rate, order)]
        if apply_mains_notch:
            for mains_hz in (50.0, 60.0):
                stages.append(_notch_coefficients(mains_hz, q=30.0, sample_rate=sample_rate))
        self._sos = stages
        # Unit-step steady-state conditions per stage. On the first chunk each
        # is scaled by that stage's first input sample, matching the
        # sosfilt_zi * x[0] convention of highpass_filter.
        self._zi_unit = list(map(scipy.signal.sosfilt_zi, stages))
        # None marks a stage that has not seen any audio yet.
        self._zi: list[Optional[np.ndarray]] = [None for _ in stages]

    def reset(self) -> None:
        """Forget all filter state; the next chunk starts a fresh stream."""
        self._zi = [None for _ in self._sos]

    def __call__(self, chunk: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
        """Filter one chunk, continuing from the state left by the previous call.

        The chunk is flattened and processed in float64. The result is a
        float32 array, or a float32 CPU tensor when ``chunk`` was a tensor.
        An empty chunk is returned as-is and leaves the state untouched.
        """
        as_torch = isinstance(chunk, torch.Tensor)
        raw = chunk.detach().cpu().numpy() if as_torch else np.asarray(chunk)
        signal = raw.astype(np.float64).flatten()
        if not signal.size:
            return chunk
        for stage, (sos, unit_zi) in enumerate(zip(self._sos, self._zi_unit)):
            if self._zi[stage] is None:
                self._zi[stage] = unit_zi * signal[0]
            signal, self._zi[stage] = scipy.signal.sosfilt(sos, signal, zi=self._zi[stage])
        signal = np.ascontiguousarray(signal.astype(np.float32))
        if as_torch:
            return torch.from_numpy(signal)
        return signal
308
+
309
+
310
def trim_silence(
    audio: torch.Tensor,
    frame_ms: float = 20.0,
    threshold_db: float = -40.0,
) -> torch.Tensor:
    """Cut leading and trailing frames that are quiet relative to the loudest frame.

    The 1-D signal is split into non-overlapping `frame_ms` frames (the last
    one zero-padded). A frame counts as voiced when its RMS exceeds the peak
    frame RMS scaled by `threshold_db`. The returned span runs from the first
    voiced frame to the end of the last voiced frame, so it is frame-aligned.
    Input is returned unchanged when it is empty, no longer than one frame,
    or contains no voiced frame.
    """
    total = audio.numel()
    if total == 0:
        return audio
    hop = int(SAMPLE_RATE * frame_ms / 1000)
    if hop <= 0 or total <= hop:
        return audio
    tail_pad = (-total) % hop
    framed = torch.cat([audio, audio.new_zeros(tail_pad)]).unfold(0, hop, hop)
    frame_rms = framed.pow(2).mean(dim=-1).sqrt()
    floor = frame_rms.max().clamp(min=1e-9) * (10 ** (threshold_db / 20.0))
    # nonzero() lists indices in ascending order: first/last voiced frames.
    active = (frame_rms > floor).nonzero(as_tuple=False).flatten()
    if active.numel() == 0:
        return audio
    first = int(active[0])
    last = int(active[-1])
    return audio[first * hop : (last + 1) * hop]
334
+
335
+
336
def center_in_window(audio: torch.Tensor, window_samples: int = SAMPLE_RATE) -> torch.Tensor:
    """Fit 1-D `audio` into exactly `window_samples` samples, centered.

    Longer input is cropped symmetrically (any odd sample is dropped from
    the end). Shorter input is zero-padded, with the extra sample of an odd
    pad going after the audio.
    """
    count = audio.numel()
    if count == window_samples:
        return audio
    if count > window_samples:
        offset = (count - window_samples) // 2
        return audio[offset : offset + window_samples]
    deficit = window_samples - count
    before = deficit // 2
    after = deficit - before
    return torch.cat([audio.new_zeros(before), audio, audio.new_zeros(after)])
349
+
350
+
351
def prepare_clip(audio: torch.Tensor, window_samples: int = SAMPLE_RATE) -> torch.Tensor:
    """Turn a raw clip into a fixed-length, centered window.

    Pipeline: high-pass filter (highpass_filter defaults: 100 Hz cutoff plus
    50/60 Hz notches), peak-normalize to -3 dBFS, trim leading/trailing
    silence, peak-normalize the trimmed clip again, then zero-pad or
    center-crop to `window_samples`.

    Empty input, or a clip trimmed to nothing, yields an all-zero window.
    """
    if audio.numel() == 0:
        # Nothing to filter or normalize: return the same silent window as
        # the trimmed-to-nothing case below instead of processing an empty tensor.
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)  # remove sub-cutoff (default 100 Hz) rumble first
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return center_in_window(audio, window_samples)
360
+
361
+
362
def end_align_in_window(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    trailing_silence_ms: float = 50.0,
) -> torch.Tensor:
    """Right-align 1-D `audio` in a fixed window, leaving a short silent tail.

    The clip ends `trailing_silence_ms` before the window's last sample (or
    as close as the available room allows), so it does not stop exactly on
    the final sample. Leading zeros fill the rest. Input at least as long as
    the window is cropped to its last `window_samples` samples. These
    end-aligned variants teach the model to fire on phrase completion at the
    buffer's right edge, which is where a sliding inference buffer sees
    fresh audio.
    """
    count = audio.numel()
    if count >= window_samples:
        return audio[-window_samples:]
    room = window_samples - count
    tail = min(int(SAMPLE_RATE * trailing_silence_ms / 1000), room)
    head = room - tail
    return torch.cat([audio.new_zeros(head), audio, audio.new_zeros(tail)])
380
+
381
+
382
def prepare_clip_end_aligned(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    trailing_silence_ms: float = 50.0,
) -> torch.Tensor:
    """prepare_clip variant that right-aligns the speech instead of centering it.

    Pipeline: high-pass filter, peak-normalize to -3 dBFS, trim silence,
    re-normalize, then end-align in `window_samples` with
    `trailing_silence_ms` of trailing silence.

    Empty input, or a clip trimmed to nothing, yields an all-zero window.
    """
    if audio.numel() == 0:
        # Nothing to filter or normalize: return a silent window directly.
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return end_align_in_window(audio, window_samples, trailing_silence_ms)
395
+
396
+
397
def random_align_in_window(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    rng: "random.Random | None" = None,
) -> torch.Tensor:
    """Place 1-D `audio` at a uniformly random offset inside a fixed window.

    Short clips get a random amount of leading silence, with the remainder
    as trailing silence. Clips at least as long as the window are randomly
    cropped. Randomness comes from `rng` if given, otherwise from the global
    `random` module.

    Streaming inference sees the wake phrase at every alignment as the
    rolling buffer scrolls. Centered and end-aligned training cover the two
    extremes; random alignment fills the space in between, so the model is
    not pinned to specific positions.
    """
    import random as _random

    chooser = rng if rng else _random
    length = audio.numel()
    if length < window_samples:
        head = chooser.randint(0, window_samples - length)
        tail = window_samples - length - head
        return torch.cat([audio.new_zeros(head), audio, audio.new_zeros(tail)])
    offset = chooser.randint(0, length - window_samples)
    return audio[offset : offset + window_samples]
420
+
421
+
422
def prepare_clip_random_aligned(
    audio: torch.Tensor,
    window_samples: int = SAMPLE_RATE,
    rng: "random.Random | None" = None,
) -> torch.Tensor:
    """prepare_clip variant that places the speech at a random offset.

    Pipeline: high-pass filter, peak-normalize to -3 dBFS, trim silence,
    re-normalize, then random-align in `window_samples` using `rng`.

    This adds variance to where speech appears within the window. Combined
    with the centered and end-aligned variants, the model learns to detect
    the wake phrase regardless of its alignment in the buffer. That closes
    the train/infer mismatch from streaming buffers, which scroll audio
    through every possible position.

    Empty input, or a clip trimmed to nothing, yields an all-zero window.
    """
    if audio.numel() == 0:
        # Nothing to filter or normalize: return a silent window directly.
        return torch.zeros(window_samples)
    audio = highpass_filter(audio)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    audio = trim_silence(audio)
    if audio.numel() == 0:
        return torch.zeros(window_samples)
    audio = peak_normalize(audio, target_dbfs=-3.0)
    return random_align_in_window(audio, window_samples, rng=rng)
442
+
443
+
444
def load_dir_clips(directory: str | Path, window_samples: int = SAMPLE_RATE) -> list[torch.Tensor]:
    """Load every ``*.wav`` directly inside `directory` as a prepared clip.

    Files are visited in sorted path order. A file that fails to load is
    reported with a printed warning and skipped; errors raised while
    preparing a loaded clip are not caught. A missing or non-directory path
    yields an empty list.
    """
    root = Path(directory)
    if not root.is_dir():
        return []
    prepared: list[torch.Tensor] = []
    for wav_path in sorted(root.glob("*.wav")):
        try:
            waveform = load_wav(wav_path)
        except Exception as exc:  # noqa: BLE001
            print(f"[warn] skipping {wav_path}: {exc}")
        else:
            prepared.append(prepare_clip(waveform, window_samples))
    return prepared