converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
"""Pocket TTS provider.
|
|
2
|
+
|
|
3
|
+
The ``pocket_tts`` package is imported lazily inside :meth:`_ensure_model`
|
|
4
|
+
so the base :mod:`converse_framework` package stays light. Install with::
|
|
5
|
+
|
|
6
|
+
pip install 'converse-framework[pocket-tts]'
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import threading
|
|
13
|
+
import time
|
|
14
|
+
from collections.abc import AsyncIterator
|
|
15
|
+
|
|
16
|
+
from converse_framework.audio_utils import float_audio_to_pcm_s16le_bytes
|
|
17
|
+
from converse_framework.protocols import (
|
|
18
|
+
AudioChunk,
|
|
19
|
+
ProgressCallback,
|
|
20
|
+
ProviderCapabilities,
|
|
21
|
+
ProviderConfigResult,
|
|
22
|
+
ProviderStatus,
|
|
23
|
+
TTSProvider,
|
|
24
|
+
VoiceInfo,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PocketTTSProvider(TTSProvider):
    """Streaming TTS provider backed by the optional ``pocket_tts`` package.

    The backend model and the per-voice state are created lazily, either by
    :meth:`load` or by the first synthesis request, and cached on the
    instance. ``self._lock`` serialises model loading, synthesis and
    configuration changes; synthesis runs on a background thread that holds
    the lock for the whole generation.

    Recognised ``config`` keys (with defaults): ``voice`` (``"azelma"``),
    ``language`` (``None``), ``temp`` (``0.7``), ``max_tokens`` (``50``),
    ``quantize`` (``False``), ``coalesce_ms`` (``400``). ``_model`` and
    ``_voice_state`` allow injecting pre-built backend objects.
    """

    def __init__(self, config: dict):
        self.voice = str(config.get("voice", "azelma"))
        self.language = config.get("language")
        self.temp = float(config.get("temp", 0.7))
        self.max_tokens = int(config.get("max_tokens", 50))
        self.quantize = bool(config.get("quantize", False))
        # Minimum amount of audio (ms) gathered before a chunk is handed to
        # the consumer; larger values mean fewer, bigger chunks.
        self.coalesce_ms = int(config.get("coalesce_ms", 400))
        self._model = config.get("_model")
        self._voice_state = config.get("_voice_state")
        self._load_error: str | None = None
        self._lock = threading.Lock()
        # Known voice identifiers for pocket-tts; listed here so status
        # can advertise them without importing the heavy backend.
        self._known_voices = (
            {"id": "azelma", "label": "Azelma", "language": "en"},
            {"id": "bela", "label": "Bela", "language": "en"},
            {"id": "conrad", "label": "Conrad", "language": "en"},
            {"id": "demeter", "label": "Demeter", "language": "en"},
            {"id": "ebenezer", "label": "Ebenezer", "language": "en"},
            {"id": "ferdinand", "label": "Ferdinand", "language": "en"},
            {"id": "gaspard", "label": "Gaspard", "language": "en"},
            {"id": "horace", "label": "Horace", "language": "en"},
            {"id": "ivo", "label": "Ivo", "language": "en"},
            {"id": "jean", "label": "Jean", "language": "fr"},
            {"id": "kimber", "label": "Kimber", "language": "en"},
            {"id": "lobelia", "label": "Lobelia", "language": "en"},
            {"id": "marie", "label": "Marie", "language": "fr"},
            {"id": "nico", "label": "Nico", "language": "en"},
            {"id": "orion", "label": "Orion", "language": "en"},
            {"id": "pavel", "label": "Pavel", "language": "en"},
            {"id": "quito", "label": "Quito", "language": "en"},
            {"id": "river", "label": "River", "language": "en"},
            {"id": "sophia", "label": "Sophia", "language": "en"},
            {"id": "tom", "label": "Tom", "language": "en"},
            {"id": "xavier", "label": "Xavier", "language": "en"},
        )

    @property
    def status(self) -> ProviderStatus:
        """Describe the provider state without importing or loading the backend.

        ``ready`` is ``False`` only while a load error is recorded; a provider
        that is configured but not yet loaded reports ``ready=True`` because
        the model is loaded on demand.
        """
        if self._load_error:
            return ProviderStatus(
                name="pocket-tts",
                kind="tts",
                ready=False,
                message=f"Pocket TTS failed to load: {self._load_error}",
                capabilities=ProviderCapabilities(supports_streaming_tts=True),
                provider_id="pocket-tts",
                loaded=False,
                supports_model_management=True,
                supports_voice_selection=True,
                active_voice=self.voice,
                voices=self._known_voices,
                status_level="error",
            )
        mode = "int8" if self.quantize else "fp32"
        loaded = self._model is not None and self._voice_state is not None
        if loaded:
            message = f"Loaded Pocket TTS voice '{self.voice}' ({mode})."
            status_level = "ready"
        else:
            message = (
                f"Configured for Pocket TTS voice '{self.voice}' ({mode}). "
                "Model and voice load on first TTS request."
            )
            status_level = "configured"
        return ProviderStatus(
            name="pocket-tts",
            kind="tts",
            ready=True,
            message=message,
            capabilities=ProviderCapabilities(
                supports_streaming_tts=True,
                languages=("en", "fr", "de", "pt", "it", "es"),
            ),
            provider_id="pocket-tts",
            loaded=loaded,
            supports_model_management=True,
            supports_voice_selection=True,
            active_voice=self.voice,
            voices=self._known_voices,
            status_level=status_level,
        )

    async def check_status(self) -> ProviderStatus:
        """Alias for :meth:`probe_status`."""
        return await self.probe_status()

    async def probe_status(self) -> ProviderStatus:
        """Cheap probe: check import availability, no model load."""
        try:
            import pocket_tts  # type: ignore[import-not-found] # noqa: F401
        except Exception as exc:  # pragma: no cover - import path
            self._load_error = str(exc)
        return self.status

    async def load_status(self) -> ProviderStatus:
        """May load heavy resources."""
        return await self.load()

    async def load(self) -> ProviderStatus:
        """Load the model and voice state on the default executor.

        Takes ``self._lock`` so loading cannot race an in-flight synthesis or
        a configuration change. On failure the error is recorded (so
        :attr:`status` reports it) and the exception is re-raised.
        """

        def load_locked() -> None:
            with self._lock:
                self._ensure_model_tracked()

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, load_locked)
        return self.status

    async def unload(self) -> ProviderStatus:
        """Drop the cached model and voice state and clear any recorded error."""

        def release() -> None:
            with self._lock:
                self._model = None
                self._voice_state = None
                self._load_error = None

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, release)
        return self.status

    def set_quantize(self, quantize: bool) -> ProviderStatus:
        """Switch quantization mode and unload cached model state if needed.

        The next :meth:`load` or synthesis request reloads Pocket TTS
        with the updated mode. If the requested mode is already active,
        loaded model state is kept. Acquires ``self._lock`` on the calling
        thread, so it waits for any in-flight synthesis to finish.
        """
        requested = bool(quantize)
        with self._lock:
            if self.quantize == requested:
                return self.status
            self.quantize = requested
            self._model = None
            self._voice_state = None
            self._load_error = None
        return self.status

    async def stream_audio(self, text: str) -> AsyncIterator[AudioChunk]:
        """Synthesise *text*; same as :meth:`stream_audio_with_progress` without progress."""
        async for chunk in self.stream_audio_with_progress(text):
            yield chunk

    async def stream_audio_with_progress(
        self,
        text: str,
        progress: ProgressCallback | None = None,
    ) -> AsyncIterator[AudioChunk]:
        """Synthesise *text* and yield mono PCM s16le chunks as they are produced.

        Generation runs on a daemon thread that holds ``self._lock``. Backend
        audio is converted with ``float_audio_to_pcm_s16le_bytes`` and grouped
        into chunks of at least ``coalesce_ms`` milliseconds; a trailing
        remainder is sent with ``final=True``. Progress events
        (``loading``/``loaded``/``generating``/``chunk``/``complete``) are
        forwarded to *progress* when given. Exceptions from the worker are
        re-raised here.

        If the consumer stops iterating early (for example on barge-in), the
        worker stops at the next backend chunk and releases the lock instead
        of generating the rest of the utterance.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue[AudioChunk | Exception | None] = asyncio.Queue()
        # Set when the consumer is gone; checked by the worker between chunks.
        cancelled = threading.Event()

        def publish(item: AudioChunk | Exception | None) -> None:
            # The queue is unbounded, so put_nowait never raises QueueFull.
            try:
                loop.call_soon_threadsafe(queue.put_nowait, item)
            except RuntimeError:
                # Event loop closed: nobody can receive further items.
                cancelled.set()

        def pcm_chunk(
            data: bytearray, samples: int, rate: int, final: bool
        ) -> AudioChunk:
            # bytes() copies, so the caller may clear ``data`` afterwards.
            return AudioChunk(
                bytes(data),
                sample_rate=rate,
                channels=1,
                encoding="pcm_s16le",
                duration_ms=int(samples * 1000 / rate),
                final=final,
            )

        def worker() -> None:
            try:
                with self._lock:
                    if cancelled.is_set():
                        return
                    started = time.perf_counter()
                    self._emit_progress(
                        loop,
                        progress,
                        "loading",
                        f"Loading Pocket TTS voice '{self.voice}'.",
                    )
                    self._ensure_model_tracked()
                    self._emit_progress(
                        loop,
                        progress,
                        "loaded",
                        f"Pocket TTS ready after {round(time.perf_counter() - started, 1)}s.",
                    )
                    self._emit_progress(
                        loop, progress, "generating", "Generating speech."
                    )
                    model = self._model
                    voice_state = self._voice_state
                    assert model is not None
                    assert voice_state is not None
                    sample_rate = model.sample_rate
                    target_samples = max(
                        1, int(sample_rate * self.coalesce_ms / 1000)
                    )
                    pending = bytearray()
                    pending_samples = 0
                    chunks = model.generate_audio_stream(
                        voice_state,
                        text,
                        max_tokens=self.max_tokens,
                        copy_state=True,
                    )
                    for index, audio in enumerate(chunks):
                        if cancelled.is_set():
                            # Consumer stopped listening: stop generating and
                            # release the lock instead of finishing the text.
                            return
                        pcm_bytes = float_audio_to_pcm_s16le_bytes(audio)
                        if not pcm_bytes:
                            continue
                        pending.extend(pcm_bytes)
                        # 2 bytes per 16-bit sample (mono).
                        pending_samples += len(pcm_bytes) // 2
                        if pending_samples < target_samples:
                            continue
                        self._emit_progress(
                            loop,
                            progress,
                            "chunk",
                            f"Generated audio chunk {index + 1}.",
                        )
                        publish(
                            pcm_chunk(pending, pending_samples, sample_rate, False)
                        )
                        pending.clear()
                        pending_samples = 0
                    if pending:
                        publish(
                            pcm_chunk(pending, pending_samples, sample_rate, True)
                        )
                self._emit_progress(loop, progress, "complete", "TTS complete.")
                publish(None)
            except Exception as exc:  # pragma: no cover - threaded path
                # Load failures are recorded by _ensure_model_tracked; other
                # errors are only forwarded so status is not marked as a
                # (sticky) load failure.
                publish(exc)

        threading.Thread(target=worker, daemon=True).start()

        try:
            while True:
                item = await queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # Runs on normal completion, error, or early close of the
            # generator; tells the worker to stop producing.
            cancelled.set()

    def _ensure_model(self) -> None:
        """Create the backend model and voice state if they are missing.

        Imports ``pocket_tts`` lazily. Callers are expected to hold
        ``self._lock``.
        """
        if self._model is not None and self._voice_state is not None:
            return
        from pocket_tts import TTSModel  # type: ignore[import-not-found]

        if self._model is None:
            kwargs: dict = {"temp": self.temp, "quantize": self.quantize}
            if self.language:
                kwargs["language"] = self.language
            self._model = TTSModel.load_model(**kwargs)
        if self._voice_state is None:
            assert self._model is not None
            self._voice_state = self._model.get_state_for_audio_prompt(self.voice)

    def _ensure_model_tracked(self) -> None:
        """Run :meth:`_ensure_model` and keep ``_load_error`` in sync.

        Clears any previously recorded error on success; records the error
        text and re-raises on failure. Callers hold ``self._lock``.
        """
        try:
            self._ensure_model()
        except Exception as exc:
            self._load_error = str(exc)
            raise
        self._load_error = None

    def set_voice(self, voice: str) -> ProviderStatus:
        """Change the active voice without reloading the full model.

        Clears ``_voice_state`` (so the next synthesis reloads voice
        state) but keeps ``_model`` when the language / temp / quantize
        are unchanged. Acquires ``self._lock`` on the calling thread, so it
        waits for any in-flight synthesis to finish.
        """
        with self._lock:
            if self.voice == voice:
                return self.status
            self.voice = voice
            # Clear only voice state — model stays loaded
            self._voice_state = None
            self._load_error = None
        return self.status

    async def configure(self, **options) -> ProviderConfigResult:
        """Apply configuration changes.

        Supported options:

        * ``voice`` — changes voice, reloads voice state only.
        * ``quantize`` — changes quantization, unloads model and voice.
        * ``language`` — changes language, unloads model and voice.
        * ``temp`` — changes temperature, unloads model and voice.
        * ``max_tokens`` — changes max tokens, no unload.
        * ``coalesce_ms`` — changes coalesce window, no unload.

        The update runs on the default executor because acquiring
        ``self._lock`` can take as long as an in-flight synthesis, which
        would otherwise stall the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._apply_options, options)

    def _apply_options(self, options: dict) -> ProviderConfigResult:
        """Apply *options* under ``self._lock``; see :meth:`configure`."""
        changed = False
        requires_reload = False
        parts: list[str] = []

        with self._lock:
            if "voice" in options:
                v = str(options["voice"])
                if v != self.voice:
                    self.voice = v
                    self._voice_state = None
                    changed = True
                    requires_reload = True
                    parts.append(f"voice={v}")

            if "quantize" in options:
                q = bool(options["quantize"])
                if q != self.quantize:
                    self.quantize = q
                    self._model = None
                    self._voice_state = None
                    changed = True
                    requires_reload = True
                    parts.append(f"quantize={q}")

            if "language" in options:
                lang = options["language"]
                if lang != self.language:
                    self.language = lang
                    self._model = None
                    self._voice_state = None
                    changed = True
                    requires_reload = True
                    parts.append(f"language={lang}")

            if "temp" in options:
                t = float(options["temp"])
                # Tolerance avoids a reload for float round-trip noise.
                if abs(t - self.temp) > 1e-6:
                    self.temp = t
                    self._model = None
                    self._voice_state = None
                    changed = True
                    requires_reload = True
                    parts.append(f"temp={t}")

            if "max_tokens" in options:
                m = int(options["max_tokens"])
                if m != self.max_tokens:
                    self.max_tokens = m
                    changed = True
                    parts.append(f"max_tokens={m}")

            if "coalesce_ms" in options:
                c = int(options["coalesce_ms"])
                if c != self.coalesce_ms:
                    self.coalesce_ms = c
                    changed = True
                    parts.append(f"coalesce_ms={c}")

            self._load_error = None
            message = ", ".join(parts) if parts else "no changes"

        return ProviderConfigResult(
            status=self.status,
            changed=changed,
            requires_reload=requires_reload,
            message=message,
        )

    def list_voices(self) -> tuple[VoiceInfo, ...]:
        """Return structured voice metadata.

        Returns the known voice list without importing the heavy
        ``pocket_tts`` backend.
        """
        return tuple(
            VoiceInfo(
                id=v["id"],
                label=v["label"],
                language=v.get("language", "en"),
                description=v.get("description", ""),
                gender=v.get("gender", "neutral"),
            )
            for v in self._known_voices
        )

    def _emit_progress(
        self,
        loop: asyncio.AbstractEventLoop,
        progress: ProgressCallback | None,
        stage: str,
        message: str,
    ) -> None:
        """Schedule ``progress("tts.progress", {...})`` on *loop* from a worker thread.

        Fire-and-forget: the worker does not wait for the callback. Does
        nothing when *progress* is falsy or *loop* has already been closed.
        """
        if not progress:
            return

        async def _fire() -> None:
            await progress("tts.progress", {"stage": stage, "message": message})

        coro = _fire()
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # Loop closed: nobody is listening. Close the coroutine to avoid
            # a "coroutine was never awaited" warning.
            coro.close()
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""Silero VAD provider.
|
|
2
|
+
|
|
3
|
+
Heavy dependencies (``silero-vad``, ``torch``, ``onnxruntime``) are imported
|
|
4
|
+
lazily inside :meth:`_ensure_model` so the base :mod:`converse_framework`
|
|
5
|
+
package stays light. Install with::
|
|
6
|
+
|
|
7
|
+
pip install 'converse-framework[silero]'
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import struct
|
|
13
|
+
from typing import Protocol
|
|
14
|
+
|
|
15
|
+
from converse_framework.audio_utils import AudioFrame
|
|
16
|
+
from converse_framework.protocols import (
|
|
17
|
+
ProviderCapabilities,
|
|
18
|
+
ProviderStatus,
|
|
19
|
+
VADEvent,
|
|
20
|
+
VADProvider,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SileroModel(Protocol):
    """Structural type for the Silero VAD model object the provider drives."""

    def __call__(self, chunk, sample_rate: int):
        """Score one audio window at *sample_rate*.

        The provider calls ``.item()`` on the returned value to obtain the
        speech probability.
        """

    def reset_states(self) -> None:
        """Reset any state the model carries between calls."""
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SileroVADProvider(VADProvider):
    """Voice-activity detection backed by the Silero VAD ONNX model.

    Incoming frames are buffered as 16-bit little-endian PCM and scored in
    windows of ``window_samples`` samples (presumably mono — confirm with
    the frame producer). Every window yields a ``vad.probability`` event; a
    probability at or above ``threshold`` starts speech, and ``hangover_ms``
    worth of windows below ``neg_threshold`` ends it.

    Recognised ``config`` keys (with defaults): ``speech_threshold`` (0.6),
    ``neg_threshold`` (``max(0.15, speech_threshold - 0.15)``),
    ``hangover_ms`` (450), ``window_samples`` (512), ``sample_rate``
    (16000), and ``_model`` to inject a pre-built model.
    """

    def __init__(self, config: dict):
        self.threshold = float(config.get("speech_threshold", 0.6))
        # Lower exit threshold gives hysteresis between start and end.
        self.neg_threshold = float(
            config.get("neg_threshold", max(0.15, self.threshold - 0.15))
        )
        self.hangover_ms = int(config.get("hangover_ms", 450))
        self.window_samples = int(config.get("window_samples", 512))
        self.sample_rate = int(config.get("sample_rate", 16000))
        self._model: SileroModel | None = config.get("_model")
        self._torch = None
        self._buffer = bytearray()
        self._speaking = False
        self._silence_ms = 0
        self._audio_ms = 0
        self._load_error: str | None = None

    @property
    def status(self) -> ProviderStatus:
        """Describe the provider state; ``ready`` requires a loaded model and no error."""
        ready = self._model is not None and self._load_error is None
        if self._load_error:
            message = f"Silero VAD failed to load: {self._load_error}"
            status_level = "error"
        elif ready:
            message = "Silero VAD ONNX model loaded."
            status_level = "ready"
        else:
            message = "Silero VAD is configured and will load on first status check."
            status_level = "configured"
        return ProviderStatus(
            name="silero-vad",
            kind="vad",
            ready=ready,
            message=message,
            capabilities=ProviderCapabilities(supports_barge_in=True),
            provider_id="silero",
            loaded=ready,
            status_level=status_level,
        )

    async def check_status(self) -> ProviderStatus:
        """Legacy compat: probes and loads model."""
        self._ensure_model()
        return self.status

    async def probe_status(self) -> ProviderStatus:
        """Cheap probe: check import availability, no model load."""
        if self._model is None and self._load_error is None:
            try:
                import silero_vad  # type: ignore[import-not-found] # noqa: F401
            except Exception as exc:  # pragma: no cover - import path
                self._load_error = str(exc)
        return self.status

    async def load_status(self) -> ProviderStatus:
        """May load heavy resources."""
        self._ensure_model()
        return self.status

    async def process_frame(self, frame: AudioFrame) -> list[VADEvent]:
        """Buffer *frame* and score every complete window it makes available.

        Returns one ``vad.probability`` event per window plus a
        ``vad.speech_start`` / ``vad.speech_end`` event on transitions.
        Returns ``[]`` when the model or torch is unavailable.

        Raises:
            ValueError: if ``frame.sample_rate`` differs from ``sample_rate``.
        """
        self._ensure_model()
        if self._model is None or self._torch is None:
            return []
        if frame.sample_rate != self.sample_rate:
            raise ValueError(
                f"Silero VAD expected {self.sample_rate} Hz audio, "
                f"got {frame.sample_rate}"
            )

        self._buffer.extend(frame.data)
        events: list[VADEvent] = []
        window_bytes = self.window_samples * 2  # 2 bytes per s16 sample
        window_ms = int(self.window_samples * 1000 / self.sample_rate)
        while len(self._buffer) >= window_bytes:
            chunk_bytes = bytes(self._buffer[:window_bytes])
            del self._buffer[:window_bytes]
            probability = self._infer_probability(chunk_bytes)
            self._audio_ms += window_ms
            transition = self._update_state(probability)
            events.append(VADEvent("vad.probability", probability, self._audio_ms))
            if transition:
                events.append(VADEvent(transition, probability, self._audio_ms))
        return events

    def reset(self) -> None:
        """Clear buffered audio, speech state, the timeline, and model state."""
        self._reset_detection_state()
        self._audio_ms = 0
        if self._model is not None:
            self._model.reset_states()

    async def unload(self) -> ProviderStatus:
        """Release the model and torch handle and clear any recorded error.

        Buffered audio and the speaking/silence state are reset too, so a
        later reload does not resume mid-utterance (which would suppress the
        next ``vad.speech_start``). The ``_audio_ms`` timeline is kept.
        """
        self._model = None
        self._torch = None
        self._reset_detection_state()
        self._load_error = None
        return self.status

    def _reset_detection_state(self) -> None:
        """Drop buffered audio and return the speech state machine to idle."""
        self._buffer.clear()
        self._speaking = False
        self._silence_ms = 0

    def _ensure_model(self) -> None:
        """Import torch and load the Silero ONNX model if not yet available.

        Torch is imported even when a model was injected via ``config``,
        because :meth:`_infer_probability` needs it to build the input
        tensor. Any failure is recorded in ``_load_error`` and suppresses
        further attempts until :meth:`unload`.
        """
        if self._load_error:
            return
        if self._model is not None and self._torch is not None:
            return
        try:
            import torch  # type: ignore[import-not-found]

            self._torch = torch
            if self._model is None:
                from silero_vad import load_silero_vad  # type: ignore[import-not-found]

                self._model = load_silero_vad(onnx=True)
        except Exception as exc:  # pragma: no cover - import path
            self._load_error = str(exc)

    def _infer_probability(self, chunk_bytes: bytes) -> float:
        """Return the model's speech probability for one window, rounded to 4 dp."""
        assert self._model is not None and self._torch is not None
        samples = struct.unpack(f"<{self.window_samples}h", chunk_bytes)
        # Scale int16 samples into [-1.0, 1.0).
        tensor = self._torch.tensor(samples, dtype=self._torch.float32) / 32768.0
        result = self._model(tensor, self.sample_rate)
        return round(float(result.item()), 4)

    def _update_state(self, probability: float) -> str | None:
        """Advance the speech state machine; return a transition event name or None."""
        if probability >= self.threshold:
            self._silence_ms = 0
            if not self._speaking:
                self._speaking = True
                return "vad.speech_start"
            return None

        # Probabilities between neg_threshold and threshold neither extend
        # nor reset the accumulated silence.
        if self._speaking and probability < self.neg_threshold:
            self._silence_ms += int(self.window_samples * 1000 / self.sample_rate)
            if self._silence_ms >= self.hangover_ms:
                self._speaking = False
                self._silence_ms = 0
                return "vad.speech_end"
        return None
|