livekit-plugins-hush 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- livekit/plugins/hush/__init__.py +59 -0
- livekit/plugins/hush/_hush_model.py +407 -0
- livekit/plugins/hush/_libdf/__init__.py +361 -0
- livekit/plugins/hush/models/config.ini +18 -0
- livekit/plugins/hush/models/df_dec.onnx +0 -0
- livekit/plugins/hush/models/enc.onnx +0 -0
- livekit/plugins/hush/models/erb_dec.onnx +0 -0
- livekit/plugins/hush/noise_suppressor.py +314 -0
- livekit_plugins_hush-0.3.3.dist-info/METADATA +196 -0
- livekit_plugins_hush-0.3.3.dist-info/RECORD +13 -0
- livekit_plugins_hush-0.3.3.dist-info/WHEEL +5 -0
- livekit_plugins_hush-0.3.3.dist-info/licenses/LICENSE +201 -0
- livekit_plugins_hush-0.3.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from livekit.agents import Plugin
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
from .noise_suppressor import HushNoiseSuppressor
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class HushPlugin(Plugin):
    """Plugin descriptor for the Hush noise-suppression package.

    The class carries no behavior of its own: it only hands the package's
    identifying metadata (title, version, distribution name) and this
    module's logger to the ``Plugin`` base class.
    """

    def __init__(self) -> None:
        # Collected in one mapping so the metadata reads as a single record;
        # every value is forwarded to ``Plugin.__init__`` by keyword.
        metadata = {
            "title": "Hush",
            "version": "0.3.3",
            "package": "livekit-plugins-hush",
            "logger": logger,
        }
        super().__init__(**metadata)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def noise_suppression(
    model_path: Optional[str] = None,
    atten_lim_db: float = 100.0,
    strength: float = 0.5,
    debug_logging: bool = False,
) -> HushNoiseSuppressor:
    """Build a ``HushNoiseSuppressor`` from the given options.

    Intended usage:
    ``AudioInputOptions(noise_cancellation=hush.noise_suppression())``.

    Parameters
    ----------
    model_path : str, optional
        Directory holding ``enc.onnx``, ``erb_dec.onnx`` and ``df_dec.onnx``.
        Forwarded unchanged; ``None`` leaves the choice to the suppressor.
    atten_lim_db : float
        Maximum attenuation in dB (default 100.0). Negative values are
        replaced by 0.0 after a warning is logged.
    strength : float
        Wet/dry blend factor (default 0.5); 0.0 = bypass, 1.0 = full
        suppression. Forwarded unchanged, no range check is done here.
    debug_logging : bool
        Forwarded unchanged; per the original documentation it enables
        diagnostics every 10 chunks at DEBUG level.

    Returns
    -------
    HushNoiseSuppressor
        A newly constructed suppressor configured with the values above.
    """
    # NOTE(review): the warning text calls 0 dB "no attenuation limit", but
    # HushSession computes lim = 10 ** (-atten_lim_db / 20), i.e. lim == 1.0
    # for 0 dB, which blends 100% of the *input* spectrum. Confirm how
    # HushNoiseSuppressor forwards this value before relying on the wording.
    effective_limit_db = atten_lim_db
    if effective_limit_db < 0:
        logger.warning(
            "atten_lim_db=%g is negative; clamping to 0 (no attenuation limit). "
            "Negative values boost gain instead of limiting attenuation.",
            effective_limit_db,
        )
        effective_limit_db = 0.0

    suppressor = HushNoiseSuppressor(
        model_path=model_path,
        atten_lim_db=effective_limit_db,
        strength=strength,
        debug_logging=debug_logging,
    )
    return suppressor
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Import-time side effect: a HushPlugin instance is created and handed to
# Plugin.register_plugin as soon as this package is imported.
Plugin.register_plugin(HushPlugin())

# Names exported by ``from livekit.plugins.hush import *``.
__all__ = ["HushNoiseSuppressor", "noise_suppression"]
|
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
"""Hush model inference using DeepFilterLib + ONNX Runtime, per-frame streaming.
|
|
2
|
+
|
|
3
|
+
Matches the API shape of the upstream ``weya_nc`` C library: one 10 ms frame
|
|
4
|
+
in, one 10 ms frame out, with continuous GRU hidden state across calls.
|
|
5
|
+
|
|
6
|
+
The encoder, ERB decoder, and DF decoder each carry a SqueezedGRU whose
|
|
7
|
+
hidden state is exposed as an ONNX I/O. ``HushSession`` holds those three
|
|
8
|
+
states (and a 4-frame DF filter history) as plain numpy arrays, threading
|
|
9
|
+
them through every ``process_frame`` call.
|
|
10
|
+
|
|
11
|
+
Feature extraction uses the ``libdf`` C library, with ``reset=False`` so its
|
|
12
|
+
analysis and synthesis filter state is carried across frames. No PyTorch
|
|
13
|
+
required.
|
|
14
|
+
|
|
15
|
+
Re-exporting the ONNX sub-models with GRU state I/O: see
|
|
16
|
+
``scripts/export_onnx_stateful.py``.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
import threading
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import onnxruntime as ort
|
|
25
|
+
|
|
26
|
+
from ._libdf import DF, erb, erb_norm, unit_norm
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
# Model constants
|
|
32
|
+
# ---------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
# Audio / STFT geometry handed to the libdf ``DF`` state.
_SAMPLE_RATE = 16_000  # Hz
_FFT_SIZE = 320  # yields _FFT_SIZE // 2 + 1 = 161 frequency bins
_HOP_SIZE = 160
_FRAME_SAMPLES = _HOP_SIZE  # 160 samples = 10 ms at 16 kHz
_NB_ERB = 32  # number of ERB bands (``nb_bands`` of DF, width of the gain mask)
_NB_DF = 64  # number of lowest frequency bins handled by the DF filter
_NORM_TAU = 1.0  # time constant fed to _compute_alpha for feature normalization
_DF_ORDER = 5  # taps of the DF filter: 4 history frames + the current frame

# GRU hidden state dimensions (must match the ONNX export)
_EMB_HIDDEN = 256
_DF_HIDDEN = 256
_ENC_NUM_LAYERS = 1
_ERB_DEC_NUM_LAYERS = 1
_DF_DEC_NUM_LAYERS = 3

# Model directory used when the caller does not supply ``model_path``:
# the ``models`` folder next to this module.
_DEFAULT_MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _compute_alpha(sr, hop, tau):
|
|
54
|
+
return float(np.exp(-hop / (tau * sr)))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _build_erb_inv_fb():
    """Build the 0/1 matrix that expands per-band ERB gains to FFT bins.

    A throw-away ``DF`` state is created only to query its ERB band widths.
    Row ``b`` of the result holds 1.0 in the consecutive frequency bins that
    belong to band ``b`` and 0.0 elsewhere, so ``mask @ result`` repeats each
    band gain over the bins of that band.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(number of bands, _FFT_SIZE // 2 + 1)``.

    Raises
    ------
    RuntimeError
        If the band widths reported by libdf do not add up to the number of
        frequency bins.
    """
    n_bins = _FFT_SIZE // 2 + 1
    probe = DF(
        sr=_SAMPLE_RATE,
        fft_size=_FFT_SIZE,
        hop_size=_HOP_SIZE,
        nb_bands=_NB_ERB,
        min_nb_erb_freqs=2,
    )
    band_widths = np.asarray(probe.erb_widths(), dtype=np.int64)
    if band_widths.sum() != n_bins:
        raise RuntimeError(
            f"libdf ERB widths sum to {band_widths.sum()}, expected {n_bins}"
        )

    # Half-open bin range [start, stop) covered by each band.
    band_stop = np.cumsum(band_widths)
    band_start = band_stop - band_widths

    # Compare every bin index against every band range; building the result
    # band-major directly avoids a transpose followed by a copy.
    bin_index = np.arange(n_bins)[np.newaxis, :]
    membership = (bin_index >= band_start[:, np.newaxis]) & (
        bin_index < band_stop[:, np.newaxis]
    )
    return membership.astype(np.float32)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# ---------------------------------------------------------------------------
|
|
80
|
+
# Shared model (one per process)
|
|
81
|
+
# ---------------------------------------------------------------------------
|
|
82
|
+
|
|
83
|
+
_shared_model = None
# Directory the shared model was actually loaded from (None until loaded).
_shared_model_dir = None
_shared_model_lock = threading.Lock()


def _get_shared_model(model_path=None):
    """Return the process-wide ``HushModel``, creating it on first use.

    The model is built once, under ``_shared_model_lock``, from the
    ``model_path`` of the *first* call. Later calls always receive that same
    instance; if a later call asks for a different directory the request
    cannot be honored, so a warning is logged instead of silently ignoring
    the argument.

    Parameters
    ----------
    model_path : str, optional
        Directory with the ONNX sub-models. ``None`` (or any falsy value)
        means the bundled default directory.

    Returns
    -------
    HushModel
        The shared model instance.
    """
    global _shared_model, _shared_model_dir
    # Same fallback rule HushModel applies, so the comparison below looks at
    # the directory that is (or would be) really used.
    requested_dir = model_path or _DEFAULT_MODEL_DIR
    with _shared_model_lock:
        if _shared_model is None:
            _shared_model = HushModel(model_path)
            _shared_model_dir = requested_dir
        elif _shared_model_dir is not None and os.path.abspath(
            requested_dir
        ) != os.path.abspath(_shared_model_dir):
            logger.warning(
                "Hush model already loaded from %s; ignoring model_path=%s "
                "(only one shared model is kept per process)",
                _shared_model_dir,
                requested_dir,
            )
        return _shared_model
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _make_low_latency_session_options():
    """Create ORT session options for low-latency single-stream inference.

    Settings applied:

    * one intra-op and one inter-op thread,
    * sequential execution mode,
    * thread spinning disabled for both thread pools,
    * graph optimization level ``ORT_ENABLE_ALL``.

    The original author's notes say this mirrors the silero VAD plugin's
    configuration and gives roughly a 2x per-frame speedup on the Hush
    sub-models by avoiding per-op thread-pool overhead.

    NOTE(review): the previous docstring described ``ORT_ENABLE_ALL`` as an
    upgrade over a default of ``ORT_ENABLE_BASIC``. The ONNX Runtime
    documentation states that all optimizations are enabled by default, so
    the explicit assignment below mainly pins the level rather than raising
    it; verify against the onnxruntime version in use.

    Returns
    -------
    onnxruntime.SessionOptions
        A fresh options object; the caller may share it between sessions.
    """
    session_options = ort.SessionOptions()

    # Threading: a single worker, no parallel scheduling of graph nodes.
    session_options.intra_op_num_threads = 1
    session_options.inter_op_num_threads = 1
    session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

    # Do not busy-wait in either thread pool.
    spin_keys = (
        "session.intra_op.allow_spinning",
        "session.inter_op.allow_spinning",
    )
    for config_key in spin_keys:
        session_options.add_session_config_entry(config_key, "0")

    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    return session_options
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class HushModel:
    """Holder of the three ONNX inference sessions used by Hush.

    Instances are meant to be shared by every stream in a worker process.
    Public attributes: ``enc_sess``, ``erb_dec_sess``, ``df_dec_sess`` (ONNX
    Runtime sessions) and ``erb_inv_fb`` (ERB-band to FFT-bin expansion
    matrix).
    """

    def __init__(self, model_path=None):
        """Load the sub-models from ``model_path`` and run one warm-up pass.

        Parameters
        ----------
        model_path : str, optional
            Directory containing ``enc.onnx``, ``erb_dec.onnx`` and
            ``df_dec.onnx``. Falsy values select the bundled directory.

        Raises
        ------
        FileNotFoundError
            If any of the three files is missing (the first missing one, in
            the order listed above, is reported).
        """
        model_dir = model_path or _DEFAULT_MODEL_DIR

        onnx_paths = [
            os.path.join(model_dir, filename)
            for filename in ("enc.onnx", "erb_dec.onnx", "df_dec.onnx")
        ]
        for onnx_path in onnx_paths:
            if os.path.exists(onnx_path):
                continue
            raise FileNotFoundError(
                f"ONNX model not found: {onnx_path}\n"
                "Please ensure the sub-model files are present. "
                "Re-export with scripts/export_onnx_stateful.py if needed."
            )
        enc_path, erb_dec_path, df_dec_path = onnx_paths

        # All three sessions are CPU-only and share one options object.
        sess_opts = _make_low_latency_session_options()

        def open_session(path):
            return ort.InferenceSession(
                path, providers=["CPUExecutionProvider"], sess_options=sess_opts
            )

        self.enc_sess = open_session(enc_path)
        self.erb_dec_sess = open_session(erb_dec_path)
        self.df_dec_sess = open_session(df_dec_path)
        self.erb_inv_fb = _build_erb_inv_fb()

        # Warm-up: push one all-zero frame (zero hidden state) through each
        # session so first-run initialization cost is paid here rather than
        # on the first real audio frame. Outputs are discarded.
        def zeros(*shape):
            return np.zeros(shape, dtype=np.float32)

        n_frames = 1
        enc_feed = {
            "feat_erb": zeros(1, 1, n_frames, _NB_ERB),
            "feat_spec": zeros(1, 2, n_frames, _NB_DF),
            "h_enc_in": zeros(_ENC_NUM_LAYERS, 1, _EMB_HIDDEN),
        }
        self.enc_sess.run(None, enc_feed)

        erb_dec_feed = {
            "emb": zeros(1, n_frames, 128),
            "e3": zeros(1, 16, n_frames, 8),
            "e2": zeros(1, 16, n_frames, 8),
            "e1": zeros(1, 16, n_frames, 16),
            "e0": zeros(1, 16, n_frames, 32),
            "h_erb_dec_in": zeros(_ERB_DEC_NUM_LAYERS, 1, _EMB_HIDDEN),
        }
        self.erb_dec_sess.run(None, erb_dec_feed)

        df_dec_feed = {
            "emb": zeros(1, n_frames, 128),
            "c0": zeros(1, 16, n_frames, _NB_DF),
            "h_df_dec_in": zeros(_DF_DEC_NUM_LAYERS, 1, _DF_HIDDEN),
        }
        self.df_dec_sess.run(None, df_dec_feed)

        logger.debug("Hush ONNX models warm-up complete")
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# ---------------------------------------------------------------------------
|
|
188
|
+
# Per-session state
|
|
189
|
+
# ---------------------------------------------------------------------------
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class HushSession:
    """Per-stream denoising session, matching the C library's API shape.

    One frame of 160 samples (10 ms) in, one frame out. The session owns the
    state that must persist between frames: the three GRU hidden states, the
    4-frame DF filter history and the libdf analysis/synthesis state. The
    ONNX sessions themselves are borrowed from the shared ``HushModel``.
    ``reset_state()`` clears all per-stream state.
    """

    def __init__(self, model, atten_lim_db: float = 100.0):
        """Create a session bound to ``model``.

        Parameters
        ----------
        model
            Object exposing ``enc_sess``, ``erb_dec_sess``, ``df_dec_sess``
            and ``erb_inv_fb`` (a ``HushModel``).
        atten_lim_db : float
            Attenuation limit in dB. Values >= 100 disable the blend with
            the unprocessed spectrum entirely.
        """
        self._enc_sess = model.enc_sess
        self._erb_dec_sess = model.erb_dec_sess
        self._df_dec_sess = model.df_dec_sess
        self._erb_inv_fb = model.erb_inv_fb

        # libdf state — runs in streaming mode (reset=False) so the
        # analysis/synthesis filter state is carried across frames.
        self._df = DF(
            sr=_SAMPLE_RATE,
            fft_size=_FFT_SIZE,
            hop_size=_HOP_SIZE,
            nb_bands=_NB_ERB,
            min_nb_erb_freqs=2,
        )
        self._alpha = _compute_alpha(_SAMPLE_RATE, _HOP_SIZE, _NORM_TAU)

        # Precompute the linear-blend attenuation coefficient. The
        # upstream reference (`scripts/infer_single.py` in pulp-vision/Hush)
        # does: spec_out = spec_in * lim + spec_enh * (1.0 - lim)
        # where lim = 10**(-atten_lim_db / 20). lim=1.0 → passthrough of the
        # input (this is what atten_lim_db=0 produces), lim=0.0 → full model
        # output. 100 dB would give lim ≈ 1e-5, so from 100 dB upwards the
        # blend is skipped altogether by storing lim = 0.0.
        # NOTE(review): negative atten_lim_db is not rejected here and gives
        # lim > 1; callers are expected to clamp beforehand.
        if atten_lim_db < 100.0:
            self._lim = 10.0 ** (-atten_lim_db / 20.0)
        else:
            self._lim = 0.0

        # State: all zeroed on init / reset. Shapes:
        #   _h_enc:     [ENC_NUM_LAYERS, 1, EMB_HIDDEN]
        #   _h_erb_dec: [ERB_DEC_NUM_LAYERS, 1, EMB_HIDDEN]
        #   _h_df_dec:  [DF_DEC_NUM_LAYERS, 1, DF_HIDDEN]
        #   _prev_df:   [_DF_ORDER-1, _NB_DF, 2] float32 (DF filter history)
        self._reset_state()

    def _reset_state(self):
        """Zero the recurrent state, reallocate scratch buffers, reset libdf."""
        self._h_enc = np.zeros((_ENC_NUM_LAYERS, 1, _EMB_HIDDEN), dtype=np.float32)
        self._h_erb_dec = np.zeros(
            (_ERB_DEC_NUM_LAYERS, 1, _EMB_HIDDEN), dtype=np.float32
        )
        self._h_df_dec = np.zeros((_DF_DEC_NUM_LAYERS, 1, _DF_HIDDEN), dtype=np.float32)
        self._prev_df = np.zeros((_DF_ORDER - 1, _NB_DF, 2), dtype=np.float32)
        # Pre-allocated per-frame scratch buffers, so process_frame does not
        # allocate them on the hot path. Their contents are fully overwritten
        # on every frame, hence np.empty.
        self._spec_df_new = np.empty((1, _NB_DF, 2), dtype=np.float32)
        self._spec_df_p = np.empty((_DF_ORDER, _NB_DF, 2), dtype=np.float32)
        self._feat_spec = np.empty((1, 2, 1, _NB_DF), dtype=np.float32)
        # Reset libdf's analysis/synthesis filter state so the next
        # audio stream starts from a clean STFT. The first 10 ms
        # after reset will be the STFT warmup (output near zero).
        # _df is None only after close().
        if self._df is not None:
            self._df.reset()

    def reset_state(self) -> None:
        """Reset all per-stream state for a new audio source.

        Clears the encoder/decoder GRU hidden states, the DF filter
        history, and the libdf STFT filter state. The first 10 ms of
        audio after a reset will be the STFT warmup (output near zero);
        this matches the C library's ``weya_nc_reset`` behavior.
        """
        self._reset_state()

    def close(self) -> None:
        """Drop this session's references to libdf, the ONNX sessions and state.

        Only references are cleared; the ONNX sessions belong to the shared
        model and are not destroyed here. ``process_frame`` cannot be used
        after ``close()``.
        """
        self._df = None
        self._enc_sess = None
        self._erb_dec_sess = None
        self._df_dec_sess = None
        self._h_enc = None
        self._h_erb_dec = None
        self._h_df_dec = None
        self._prev_df = None

    # ------------------------------------------------------------------
    # Per-frame processing
    # ------------------------------------------------------------------

    def process_frame(self, audio: np.ndarray) -> np.ndarray:
        """Denoise a single 160-sample (10 ms) frame at 16 kHz.

        The libdf analysis state, the encoder/decoder GRU hidden states, and
        the DF filter polynomial history are all carried across calls.

        Parameters
        ----------
        audio : numpy.ndarray
            Either shape ``(160,)`` or ``(1, 160)``. Non-float32 input is
            converted to float32.

        Returns
        -------
        numpy.ndarray
            The processed frame, with the same shape as the input.

        Raises
        ------
        ValueError
            If the frame does not hold exactly ``_FRAME_SAMPLES`` samples or
            a 2-D input has more than one row.
        """
        if audio.ndim == 1:
            audio = audio[np.newaxis, :]
            squeezed = True
        else:
            squeezed = False

        if audio.shape[1] != _FRAME_SAMPLES:
            raise ValueError(
                f"process_frame requires {_FRAME_SAMPLES} samples, got {audio.shape[1]}"
            )
        # Only the first row is ever processed below, so extra rows would
        # advance the recurrent state and then fail at the final reshape.
        # Reject them before any state is touched.
        if audio.shape[0] != 1:
            raise ValueError(
                f"process_frame requires a single channel, got {audio.shape[0]}"
            )

        # ---- STFT (streaming) -----------------------------------------
        # libdf.analysis(audio, reset=False) keeps the analysis filter
        # state across calls. For 160 samples with FFT=320, it produces
        # 1 frame.
        # The previous check `audio.dtype is np.float32` compared a dtype
        # instance with a scalar type by identity and was therefore always
        # False, forcing a copy on every frame. ascontiguousarray copies only
        # when the input is not already C-contiguous float32.
        frame = np.ascontiguousarray(audio, dtype=np.float32)
        spec_new = self._df.analysis(frame, reset=False)
        # spec_new: (1, 1, 161) complex64

        # ---- Feature extraction (per frame) ----------------------------
        # The DF class owns the EMA state for both normalizations
        # (matches libdf's in-place semantics). On a fresh session
        # the state is None and gets initialized to the libdf defaults.
        erb_feat, self._df._erb_norm_state = erb_norm(
            erb(spec_new, self._df.erb_widths()),
            self._alpha,
            self._df._erb_norm_state,
        )  # (1, 1, 32)
        sf_feat, self._df._unit_norm_state = unit_norm(
            spec_new[..., :_NB_DF].copy(), self._alpha, self._df._unit_norm_state
        )  # (1, 1, 64) complex

        # ---- Encoder ----------------------------------------------------
        # Single-frame input. The GRU hidden state carries context.
        # The pre-allocated feat_spec buffer holds (real, imag) planes so no
        # fresh array is stacked per frame.
        self._feat_spec[0, 0, 0, :] = sf_feat.real[0, 0, :]
        self._feat_spec[0, 1, 0, :] = sf_feat.imag[0, 0, :]
        enc_out = self._enc_sess.run(
            None,
            {
                "feat_erb": erb_feat[:, np.newaxis, :, :],
                "feat_spec": self._feat_spec,
                "h_enc_in": self._h_enc,
            },
        )
        e0, e1, e2, e3, emb, c0, _lsnr, self._h_enc = enc_out
        # All outputs shape (1, 1, ...) for the single time step.

        # ---- ERB decoder ------------------------------------------------
        m, self._h_erb_dec = self._erb_dec_sess.run(
            None,
            {
                "emb": emb,
                "e3": e3,
                "e2": e2,
                "e1": e1,
                "e0": e0,
                "h_erb_dec_in": self._h_erb_dec,
            },
        )
        # m: (1, 1, 1, 32) — gain mask per ERB band

        # ---- DF decoder -------------------------------------------------
        coefs, self._h_df_dec = self._df_dec_sess.run(
            None,
            {
                "emb": emb,
                "c0": c0,
                "h_df_dec_in": self._h_df_dec,
            },
        )
        # coefs: (1, 1, 64, 10) — DF filter per freq bin

        # ---- Post-process spectrum --------------------------------------
        spec_in = spec_new[0, 0]  # (161,) complex64
        mask = m[0, 0, 0]  # (32,) float32
        coef = coefs[0, 0]  # (64, 10) float32

        # Project ERB mask to full spectrum. The product is a new array, so
        # writing into spec_masked below does not modify spec_in.
        spec_masked = spec_in * (mask @ self._erb_inv_fb)  # (161,) complex

        # Build DF filter window. The 5-tap polynomial prediction needs
        # 4 frames of "previous filter history" + the new frame.
        # _prev_df holds the 4 frames of (real, imag) saved from the
        # previous call's spec_df_p (or zeros on the first call).
        self._spec_df_new[0, :, 0] = spec_in[:_NB_DF].real
        self._spec_df_new[0, :, 1] = spec_in[:_NB_DF].imag
        # History occupies the first four slots, the new frame the last.
        self._spec_df_p[:-1] = self._prev_df
        self._spec_df_p[-1:] = self._spec_df_new
        # shape: (5, 64, 2)

        # Save the last 4 frames for the next call.
        np.copyto(self._prev_df, self._spec_df_p[1:])

        # Apply the 5-tap complex FIR.
        # coef from ONNX: (64, 10) = (F, O*2). PyTorch reference does
        # coefs.permute(0, 2, 1, 3, 4) to put the order axis first →
        # (B, T, O, F, 2). We need c with shape (O, F, 2).
        c = coef.reshape(_NB_DF, _DF_ORDER, 2).transpose(1, 0, 2)
        # Complex multiply-accumulate over the tap axis, written out on the
        # separate real/imag planes: (a + ib)(x + iy) = (ax - by) + i(bx + ay).
        taps_re = self._spec_df_p[..., 0]
        taps_im = self._spec_df_p[..., 1]
        df_re = (c[..., 0] * taps_re - c[..., 1] * taps_im).sum(axis=0)
        df_im = (c[..., 1] * taps_re + c[..., 0] * taps_im).sum(axis=0)

        # The DF output replaces the masked spectrum in the lowest _NB_DF
        # bins; higher bins keep the ERB-masked values.
        spec_masked[:_NB_DF] = df_re + 1j * df_im
        enhanced = spec_masked

        # ---- Attenuation limit (linear blend, per reference) -----------
        # _lim is 0.0 when atten_lim_db >= 100, in which case the blend is
        # skipped and the model output is used as is.
        if self._lim > 0.0:
            enhanced = spec_in * self._lim + enhanced * (1.0 - self._lim)

        # ---- STFT synthesis (streaming) --------------------------------
        # libdf.synthesis with 1 frame gives 160 samples with reset=False.
        audio_out = self._df.synthesis(enhanced[np.newaxis, np.newaxis, :], reset=False)
        # audio_out: (1, 160) float32

        if squeezed:
            return audio_out[0]
        return audio_out.reshape(audio.shape)
|