converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
"""faster-whisper ASR provider.
|
|
2
|
+
|
|
3
|
+
The ``faster_whisper`` package is imported lazily inside
|
|
4
|
+
:meth:`_ensure_model` and :meth:`check_status` so the base
|
|
5
|
+
:mod:`converse_framework` package stays light. Install with::
|
|
6
|
+
|
|
7
|
+
pip install 'converse-framework[faster-whisper]'
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import asyncio
|
|
13
|
+
import contextlib
|
|
14
|
+
import logging
|
|
15
|
+
import time
|
|
16
|
+
from collections.abc import AsyncIterator
|
|
17
|
+
|
|
18
|
+
from converse_framework.audio_utils import pcm_s16le_to_float32
|
|
19
|
+
from converse_framework.protocols import (
|
|
20
|
+
ASRProvider,
|
|
21
|
+
ProgressCallback,
|
|
22
|
+
ProviderCapabilities,
|
|
23
|
+
ProviderStatus,
|
|
24
|
+
TranscriptEvent,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FasterWhisperASRProvider(ASRProvider):
|
|
31
|
+
def __init__(self, config: dict):
|
|
32
|
+
self.model_name = str(config.get("model", "large-v3-turbo"))
|
|
33
|
+
self.device = str(config.get("device", "auto"))
|
|
34
|
+
self.compute_type = str(config.get("compute_type", "auto"))
|
|
35
|
+
self.language = config.get("language", "en")
|
|
36
|
+
self.beam_size = int(config.get("beam_size", 1))
|
|
37
|
+
self.vad_filter = bool(config.get("vad_filter", False))
|
|
38
|
+
self.initial_prompt = config.get("initial_prompt")
|
|
39
|
+
self.condition_on_previous_text = bool(
|
|
40
|
+
config.get("condition_on_previous_text", False)
|
|
41
|
+
)
|
|
42
|
+
self.temperature = config.get("temperature", 0)
|
|
43
|
+
self.compression_ratio_threshold = config.get(
|
|
44
|
+
"compression_ratio_threshold", 2.4
|
|
45
|
+
)
|
|
46
|
+
self.log_prob_threshold = config.get("log_prob_threshold", -0.5)
|
|
47
|
+
self.no_speech_threshold = config.get("no_speech_threshold", 0.2)
|
|
48
|
+
self.suppress_tokens = config.get("suppress_tokens")
|
|
49
|
+
self.timeout_s = float(config.get("timeout_s", 120))
|
|
50
|
+
self._model = config.get("_model")
|
|
51
|
+
self._load_error: str | None = None
|
|
52
|
+
self._auto_cuda_dll_dirs = bool(config.get("auto_cuda_dll_dirs", True))
|
|
53
|
+
self._cuda_dll_handles: list[object] = []
|
|
54
|
+
|
|
55
|
+
@property
def status(self) -> ProviderStatus:
    """Describe the provider's current state without touching the model.

    A recorded load error takes precedence and yields a not-ready status;
    otherwise the status is "ready" when a model object is present and
    "configured" when it is not.
    """
    provider = "faster-whisper"
    if self._load_error:
        return ProviderStatus(
            name=provider,
            kind="asr",
            ready=False,
            message=f"faster-whisper failed to load: {self._load_error}",
            capabilities=ProviderCapabilities(),
            status_level="error",
        )

    is_loaded = self._model is not None
    placement = f"{self.model_name} on {self.device}/{self.compute_type}"
    if is_loaded:
        level = "ready"
        text = f"Loaded {placement}."
    else:
        level = "configured"
        text = (
            f"Configured for {placement}. "
            "Model loads on first voice transcription and may download if not cached."
        )

    # A falsy language setting is advertised as "auto".
    languages = (str(self.language),) if self.language else ("auto",)
    return ProviderStatus(
        name=provider,
        kind="asr",
        ready=True,
        message=text,
        capabilities=ProviderCapabilities(languages=languages),
        provider_id=provider,
        loaded=is_loaded,
        active_model=self.model_name,
        models=({"id": self.model_name, "label": self.model_name},),
        status_level=level,
    )
|
|
89
|
+
|
|
90
|
+
async def check_status(self) -> ProviderStatus:
    """Return the provider status by delegating to ``probe_status``."""
    probed = await self.probe_status()
    return probed
|
|
92
|
+
|
|
93
|
+
async def probe_status(self) -> ProviderStatus:
    """Cheap probe: verify ``faster_whisper`` is importable, no model load.

    When a model object is already present the import check is skipped.
    An import failure is recorded in ``_load_error``.
    """
    if self._model is not None:
        return self.status
    try:
        import faster_whisper  # type: ignore[import-not-found]  # noqa: F401
    except Exception as exc:  # pragma: no cover - import path
        self._load_error = str(exc)
    return self.status
|
|
101
|
+
|
|
102
|
+
async def load_status(self) -> ProviderStatus:
    """Return the status after delegating to ``load`` (may load heavy resources)."""
    after_load = await self.load()
    return after_load
|
|
105
|
+
|
|
106
|
+
async def load(self) -> ProviderStatus:
    """Load the model on a worker thread, bounded by ``timeout_s``.

    Does nothing but report the status when a model is already present.

    Raises:
        asyncio.TimeoutError: the load did not finish within ``timeout_s``
            seconds; the timeout is also recorded in ``_load_error``.
    """
    if self._model is None:
        try:
            pending = asyncio.to_thread(self._ensure_model)
            await asyncio.wait_for(pending, timeout=self.timeout_s)
        except asyncio.TimeoutError:
            self._load_error = f"Model load timed out after {self.timeout_s}s"
            raise
    return self.status
|
|
117
|
+
|
|
118
|
+
async def transcribe_text_input(self, text: str) -> AsyncIterator[TranscriptEvent]:
    """Yield typed text as one final transcript event.

    The text is stripped first; input that is empty after stripping
    yields nothing.
    """
    cleaned = text.strip()
    if not cleaned:
        return
    yield TranscriptEvent(text=cleaned, final=True)
|
|
122
|
+
|
|
123
|
+
async def transcribe_audio(
    self,
    pcm_s16le: bytes,
    sample_rate: int,
    progress: ProgressCallback | None = None,
) -> AsyncIterator[TranscriptEvent]:
    """Transcribe one utterance and yield at most one final event.

    The blocking work runs on a worker thread under a ``timeout_s`` bound.
    ``progress``, when given, is awaited with ``"asr.progress"`` payloads
    for the "queued" and "complete" stages.

    Raises:
        ValueError: ``sample_rate`` is not 16000.
        asyncio.TimeoutError: load plus transcription exceeded ``timeout_s``.
    """
    if sample_rate != 16000:
        raise ValueError(
            f"faster-whisper expects 16000 Hz audio, got {sample_rate}"
        )

    samples = pcm_s16le_to_float32(pcm_s16le)
    if samples.size == 0:
        # Nothing to transcribe: finish without yielding.
        return

    if progress:
        seconds = round(samples.size / sample_rate, 2)
        queued = {
            "stage": "queued",
            "message": f"Queued {seconds}s utterance for faster-whisper.",
        }
        await progress("asr.progress", queued)

    running_loop = asyncio.get_running_loop()
    try:
        worker = asyncio.to_thread(
            self._transcribe_blocking, samples, progress, running_loop
        )
        pieces = await asyncio.wait_for(worker, timeout=self.timeout_s)
    except asyncio.TimeoutError:
        # Only record an error when no model object is present yet.
        if self._model is None:
            self._load_error = (
                f"Model load or transcription timed out after {self.timeout_s}s"
            )
        raise

    transcript = " ".join(filter(None, pieces)).strip()
    if progress:
        done = {"stage": "complete", "message": "ASR transcription complete."}
        await progress("asr.progress", done)
    if transcript:
        yield TranscriptEvent(text=transcript, final=True)
|
|
167
|
+
|
|
168
|
+
def _transcribe_blocking(
    self,
    audio,
    progress: ProgressCallback | None,
    loop: asyncio.AbstractEventLoop,
) -> list[str]:
    """Load the model if needed and transcribe ``audio`` synchronously.

    Intended to run off the event loop (``transcribe_audio`` calls it via
    ``asyncio.to_thread``); progress is reported back to ``loop`` through
    ``_emit_progress_threadsafe``.

    Args:
        audio: sample array exposing ``.size``; passed to the model as-is.
        progress: optional progress callback.
        loop: event loop on which progress callbacks are scheduled.

    Returns:
        The stripped, non-empty text of each segment, in order.

    Raises:
        RuntimeError: no model is available after the load attempt.  The
            exception that made the load fail, if any, is attached as
            ``__cause__``.
    """
    started = time.perf_counter()
    logger.info(
        "[ASR] transcribe_blocking called, audio length=%d samples (%.2fs)",
        audio.size,
        audio.size / 16000,
    )
    self._emit_progress_threadsafe(
        loop,
        progress,
        "loading",
        f"Loading faster-whisper model {self.model_name}.",
    )

    # Keep the load failure instead of discarding it, so the RuntimeError
    # below can carry the real cause.
    load_failure: Exception | None = None
    try:
        self._ensure_model()
    except Exception as exc:
        load_failure = exc
    if self._model is None:
        detail = (
            self._load_error
            or (str(load_failure) if load_failure is not None else "")
            or "unknown error"
        )
        raise RuntimeError(
            f"faster-whisper model did not load: {detail}"
        ) from load_failure

    self._emit_progress_threadsafe(
        loop,
        progress,
        "loaded",
        f"Model ready after {round(time.perf_counter() - started, 1)}s. "
        "Running inference.",
    )
    logger.info(
        "[ASR] model loaded in %.1fs, starting inference on %s/%s",
        time.perf_counter() - started,
        self.device,
        self.compute_type,
    )

    transcribe_options = {
        "language": self.language,
        "beam_size": self.beam_size,
        "vad_filter": self.vad_filter,
        "initial_prompt": self.initial_prompt,
        "condition_on_previous_text": self.condition_on_previous_text,
        "temperature": self.temperature,
        "compression_ratio_threshold": self.compression_ratio_threshold,
        "log_prob_threshold": self.log_prob_threshold,
        "no_speech_threshold": self.no_speech_threshold,
    }
    # Only pass suppress_tokens when configured, so the library default applies otherwise.
    if self.suppress_tokens is not None:
        transcribe_options["suppress_tokens"] = self.suppress_tokens

    segments, _info = self._model.transcribe(audio, **transcribe_options)
    logger.info("[ASR] inference call returned, iterating segments...")

    texts: list[str] = []
    for segment in segments:
        text = segment.text.strip()
        if not text:
            continue
        texts.append(text)
        start = getattr(segment, "start", None)
        end = getattr(segment, "end", None)
        prefix = ""
        if start is not None and end is not None:
            prefix = f"Segment {round(float(start), 2)}-{round(float(end), 2)}s: "
        self._emit_progress_threadsafe(loop, progress, "segment", f"{prefix}{text}")

    logger.info(
        "[ASR] all segments collected in %.1fs, %d segments with text",
        time.perf_counter() - started,
        len(texts),
    )
    return texts
|
|
241
|
+
|
|
242
|
+
def _ensure_model(self) -> None:
|
|
243
|
+
if self._model is not None:
|
|
244
|
+
return
|
|
245
|
+
# Windows CUDA DLL directory discovery (conservative, best-effort).
|
|
246
|
+
if self._auto_cuda_dll_dirs:
|
|
247
|
+
self._add_cuda_dll_dirs()
|
|
248
|
+
try:
|
|
249
|
+
from faster_whisper import WhisperModel # type: ignore[import-not-found]
|
|
250
|
+
|
|
251
|
+
self._model = WhisperModel(
|
|
252
|
+
self.model_name, device=self.device, compute_type=self.compute_type
|
|
253
|
+
)
|
|
254
|
+
except Exception as exc: # pragma: no cover - import path
|
|
255
|
+
self._load_error = str(exc)
|
|
256
|
+
raise
|
|
257
|
+
|
|
258
|
+
def _add_cuda_dll_dirs(self) -> None:
    """Register NVIDIA wheel DLL directories if on Windows.

    Best-effort: any failure is logged at debug level and otherwise
    ignored.  The value returned by the helper is kept in
    ``_cuda_dll_handles``.
    """
    try:
        # Imported locally to keep cuda_utils optional.
        from converse_framework.cuda_utils import (
            add_nvidia_dll_directories as register_dirs,
        )

        handles = register_dirs()
        self._cuda_dll_handles = handles
    except Exception:  # pragma: no cover
        logger.debug("CUDA DLL discovery skipped or failed.")
|
|
267
|
+
|
|
268
|
+
async def unload(self) -> ProviderStatus:
    """Drop the model reference and reset load state, then report status.

    The recorded load error and the stored CUDA DLL handles are cleared
    whether or not a model was present.
    """
    had_model = self._model is not None
    if had_model:
        logger.info(
            "[ASR] unloading faster-whisper model (%s/%s)",
            self.device,
            self.compute_type,
        )
    self._model = None
    self._load_error = None
    self._cuda_dll_handles.clear()
    return self.status
|
|
279
|
+
|
|
280
|
+
def _emit_progress_threadsafe(
|
|
281
|
+
self,
|
|
282
|
+
loop: asyncio.AbstractEventLoop,
|
|
283
|
+
progress: ProgressCallback | None,
|
|
284
|
+
stage: str,
|
|
285
|
+
message: str,
|
|
286
|
+
) -> None:
|
|
287
|
+
if not progress:
|
|
288
|
+
return
|
|
289
|
+
coro = progress("asr.progress", {"stage": stage, "message": message})
|
|
290
|
+
asyncio.run_coroutine_threadsafe(coro, loop) # type: ignore[arg-type]
|
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
"""Kokoro ONNX TTS provider.
|
|
2
|
+
|
|
3
|
+
The ``kokoro_onnx``, ``misaki``, and ``httpx`` packages are imported lazily
|
|
4
|
+
inside :meth:`_ensure_model` and :meth:`_download_asset` so the base
|
|
5
|
+
:mod:`converse_framework` package stays light. Install with::
|
|
6
|
+
|
|
7
|
+
pip install 'converse-framework[kokoro]'
|
|
8
|
+
|
|
9
|
+
The default cache directory is platform-aware and does not depend on the
|
|
10
|
+
harness ``PROJECT_ROOT``. The cache location can be overridden via the
|
|
11
|
+
``cache_dir`` config key, or by setting the ``CONVERSE_FRAMEWORK_CACHE_DIR``
|
|
12
|
+
environment variable (the provider appends ``/kokoro`` to that path).
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import asyncio
|
|
18
|
+
import os
|
|
19
|
+
import threading
|
|
20
|
+
import time
|
|
21
|
+
from collections.abc import AsyncIterator
|
|
22
|
+
from dataclasses import replace
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
from converse_framework.audio_utils import float_audio_to_pcm_s16le_bytes
|
|
26
|
+
from converse_framework.protocols import (
|
|
27
|
+
AudioChunk,
|
|
28
|
+
ProgressCallback,
|
|
29
|
+
ProviderCapabilities,
|
|
30
|
+
ProviderStatus,
|
|
31
|
+
TTSProvider,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Both default assets come from the same kokoro-onnx GitHub release.
_KOKORO_RELEASE_BASE = (
    "https://github.com/thewh1teagle/kokoro-onnx/releases/download/model-files-v1.0/"
)
DEFAULT_KOKORO_MODEL_URL = _KOKORO_RELEASE_BASE + "kokoro-v1.0.int8.onnx"
DEFAULT_KOKORO_VOICES_URL = _KOKORO_RELEASE_BASE + "voices-v1.0.bin"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _default_cache_dir() -> Path:
|
|
46
|
+
"""Return a platform-appropriate cache directory for Kokoro assets.
|
|
47
|
+
|
|
48
|
+
Resolution order:
|
|
49
|
+
1. ``CONVERSE_FRAMEWORK_CACHE_DIR`` environment variable, with ``kokoro``
|
|
50
|
+
appended.
|
|
51
|
+
2. ``~/.cache/converse-framework/kokoro`` on POSIX-likes.
|
|
52
|
+
3. ``%LOCALAPPDATA%/converse-framework/kokoro`` (or
|
|
53
|
+
``~/.cache/converse-framework/kokoro``) on Windows.
|
|
54
|
+
"""
|
|
55
|
+
env_root = os.environ.get("CONVERSE_FRAMEWORK_CACHE_DIR")
|
|
56
|
+
if env_root:
|
|
57
|
+
return Path(env_root) / "kokoro"
|
|
58
|
+
if os.name == "nt":
|
|
59
|
+
local_app_data = os.environ.get("LOCALAPPDATA")
|
|
60
|
+
if local_app_data:
|
|
61
|
+
return Path(local_app_data) / "converse-framework" / "kokoro"
|
|
62
|
+
return Path.home() / ".cache" / "converse-framework" / "kokoro"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class KokoroOnnxProvider(TTSProvider):
|
|
66
|
+
def __init__(self, config: dict):
    """Populate provider settings from ``config``.

    Every option is read with ``config.get`` and a fallback default; no
    assets are downloaded and no model is loaded here.  Already-built
    model and G2P objects may be supplied through the ``"_model"`` and
    ``"_g2p"`` keys.
    """
    option = config.get

    # Voice and synthesis parameters.
    self.voice = str(option("voice", "af_heart"))
    self.lang = str(option("lang", "en-us"))
    self.speed = float(option("speed", 1.0))
    self.trim = bool(option("trim", True))
    self.timeout_s = float(option("timeout_s", 300))

    # Asset cache: an explicit ``cache_dir`` wins over the default location.
    configured_cache = option("cache_dir")
    if configured_cache:
        self.cache_dir = Path(str(configured_cache))
    else:
        self.cache_dir = _default_cache_dir()
    self.model_filename = str(option("model_filename", "kokoro-v1.0.int8.onnx"))
    self.voices_filename = str(option("voices_filename", "voices-v1.0.bin"))
    self.model_url = str(option("model_url", DEFAULT_KOKORO_MODEL_URL))
    self.voices_url = str(option("voices_url", DEFAULT_KOKORO_VOICES_URL))

    # ONNX Runtime threading and G2P preloading.
    self.onnx_intra_op_num_threads = int(option("onnx_intra_op_num_threads", 4))
    self.onnx_inter_op_num_threads = int(option("onnx_inter_op_num_threads", 1))
    self.preload_g2p = bool(option("preload_g2p", True))

    # Runtime state.
    self._model = option("_model")
    self._g2p = option("_g2p")
    self._load_error: str | None = None
    self._g2p_error: str | None = None
    self._lock = threading.Lock()
    self._generation_lock = asyncio.Lock()
|
|
89
|
+
|
|
90
|
+
@property
def status(self) -> ProviderStatus:
    """Describe the provider's current state without loading anything.

    Precedence: a recorded load error, then a recorded G2P error (both
    not-ready, level "error"), then "ready" when a model object is
    present, else "configured".
    """
    loaded = self._model is not None

    def build(*, ready: bool, message: str, level: str, is_loaded: bool) -> ProviderStatus:
        # Fields shared by every status this provider reports.
        return ProviderStatus(
            name="kokoro-onnx",
            kind="tts",
            ready=ready,
            message=message,
            capabilities=ProviderCapabilities(
                supports_streaming_tts=True, languages=("en",)
            ),
            provider_id="kokoro-onnx",
            loaded=is_loaded,
            supports_model_management=True,
            supports_voice_selection=True,
            active_voice=self.voice,
            status_level=level,
        )

    if self._load_error:
        return build(
            ready=False,
            message=f"Kokoro ONNX failed to load: {self._load_error}",
            level="error",
            is_loaded=False,
        )
    if self._g2p_error:
        return build(
            ready=False,
            message=f"Kokoro English G2P failed: {self._g2p_error}",
            level="error",
            is_loaded=loaded,
        )

    description = f"Kokoro v1.0 ONNX voice '{self.voice}' ({self.lang})"
    if loaded:
        return build(
            ready=True,
            message=f"Loaded {description}.",
            level="ready",
            is_loaded=loaded,
        )
    return build(
        ready=True,
        message=f"Configured for {description}. Model loads on first TTS request.",
        level="configured",
        is_loaded=loaded,
    )
|
|
150
|
+
|
|
151
|
+
async def check_status(self) -> ProviderStatus:
    """Return the provider status by delegating to ``probe_status``."""
    probed = await self.probe_status()
    return probed
|
|
153
|
+
|
|
154
|
+
async def probe_status(self) -> ProviderStatus:
    """Cheap probe: check import availability, no model load.

    The two dependencies are probed separately so a failure is attributed
    correctly: a ``kokoro_onnx`` import failure is recorded in
    ``_load_error``, and a ``misaki`` import failure (only probed when
    Misaki would be used) is recorded in ``_g2p_error``.
    """
    try:
        import kokoro_onnx  # type: ignore[import-not-found]  # noqa: F401
    except Exception as exc:  # pragma: no cover - import path
        self._load_error = str(exc)
        return self.status

    if self._should_use_misaki():
        try:
            from misaki import en as _en  # type: ignore[import-not-found]  # noqa: F401
            from misaki import espeak as _espeak  # type: ignore[import-not-found]  # noqa: F401
        except Exception as exc:  # pragma: no cover - import path
            self._g2p_error = str(exc)
    return self.status
|
|
168
|
+
|
|
169
|
+
async def load_status(self) -> ProviderStatus:
    """Return the status after delegating to ``load`` (may load heavy resources)."""
    after_load = await self.load()
    return after_load
|
|
172
|
+
|
|
173
|
+
async def load(self) -> ProviderStatus:
    """Load the model, and optionally the G2P front end, off the event loop.

    Both steps run in the loop's default executor.  G2P is preloaded only
    when ``preload_g2p`` is set and Misaki applies to the configured
    language.
    """
    run_blocking = asyncio.get_running_loop().run_in_executor
    await run_blocking(None, self._ensure_model)
    wants_g2p = self.preload_g2p and self._should_use_misaki()
    if wants_g2p:
        await run_blocking(None, self._ensure_g2p)
    return self.status
|
|
179
|
+
|
|
180
|
+
async def unload(self) -> ProviderStatus:
    """Drop the model and G2P objects and clear recorded errors.

    The reset runs in the loop's default executor while holding
    ``self._lock``.
    """

    def _clear_state() -> None:
        with self._lock:
            self._model = self._g2p = None
            self._load_error = self._g2p_error = None

    await asyncio.get_running_loop().run_in_executor(None, _clear_state)
    return self.status
|
|
191
|
+
|
|
192
|
+
async def stream_audio(self, text: str) -> AsyncIterator[AudioChunk]:
    """Stream synthesized audio for ``text`` without progress reporting."""
    source = self.stream_audio_with_progress(text)
    async for piece in source:
        yield piece
|
|
195
|
+
|
|
196
|
+
async def stream_audio_with_progress(
    self,
    text: str,
    progress: ProgressCallback | None = None,
) -> AsyncIterator[AudioChunk]:
    """Synthesize ``text`` and yield PCM chunks, reporting progress.

    The model is ensured first; phonemization (when Misaki applies) and
    generation then run while holding ``self._generation_lock``.  Each
    chunk is held back one iteration so the last one can be re-emitted
    with ``final=True``.  Chunks whose PCM payload is empty are skipped.
    """
    loop = asyncio.get_running_loop()
    t0 = time.perf_counter()

    def report(stage: str, message: str, **extra) -> None:
        # Every progress event from this call shares the same start time.
        self._emit_progress(loop, progress, stage, message, started=t0, **extra)

    report("loading", f"Loading Kokoro voice '{self.voice}'.")
    await loop.run_in_executor(None, self._ensure_model)
    report("loaded", "Kokoro ready.")

    async with self._generation_lock:
        if self._should_use_misaki():
            report("phonemizing", "Preparing English phonemes with Misaki.")
            model_input = await loop.run_in_executor(
                None, self._phonemize_english, text
            )
            as_phonemes = True
        else:
            model_input = text
            as_phonemes = False

        report("generating", "Generating speech.")
        produced = 0
        held: AudioChunk | None = None
        assert self._model is not None
        async for samples, rate in self._model.create_stream(
            model_input,
            voice=self.voice,
            speed=self.speed,
            lang=self.lang,
            is_phonemes=as_phonemes,
            trim=self.trim,
        ):
            raw = float_audio_to_pcm_s16le_bytes(samples)
            if not raw:
                continue
            produced += 1
            # Two bytes per sample; duration is unknown when the rate is falsy.
            length_ms = int((len(raw) // 2) * 1000 / rate) if rate else None
            fresh = AudioChunk(
                raw,
                sample_rate=rate,
                channels=1,
                encoding="pcm_s16le",
                duration_ms=length_ms,
                final=False,
            )
            report(
                "chunk",
                f"Generated audio chunk {produced}.",
                chunk_index=produced,
                first_chunk=produced == 1,
                duration_ms=fresh.duration_ms,
            )
            if held is not None:
                yield held
            held = fresh

        if held is not None:
            yield replace(held, final=True)
        report("complete", "TTS complete.", chunks=produced)
|
|
287
|
+
|
|
288
|
+
def _ensure_model(self) -> None:
|
|
289
|
+
if self._model is not None:
|
|
290
|
+
return
|
|
291
|
+
model_path = self._download_asset(self.model_url, self.model_filename)
|
|
292
|
+
voices_path = self._download_asset(self.voices_url, self.voices_filename)
|
|
293
|
+
from kokoro_onnx import Kokoro # type: ignore[import-not-found]
|
|
294
|
+
|
|
295
|
+
self._model = Kokoro(str(model_path), str(voices_path))
|
|
296
|
+
self._apply_onnx_session_options(str(model_path))
|
|
297
|
+
self._load_error = None
|
|
298
|
+
self._g2p_error = None
|
|
299
|
+
|
|
300
|
+
def _apply_onnx_session_options(self, model_path: str) -> None:
|
|
301
|
+
if self.onnx_intra_op_num_threads <= 0 and self.onnx_inter_op_num_threads <= 0:
|
|
302
|
+
return
|
|
303
|
+
if self._model is None or not hasattr(self._model, "sess"):
|
|
304
|
+
return
|
|
305
|
+
import onnxruntime as ort # type: ignore[import-not-found]
|
|
306
|
+
|
|
307
|
+
options = ort.SessionOptions()
|
|
308
|
+
if self.onnx_intra_op_num_threads > 0:
|
|
309
|
+
options.intra_op_num_threads = self.onnx_intra_op_num_threads
|
|
310
|
+
if self.onnx_inter_op_num_threads > 0:
|
|
311
|
+
options.inter_op_num_threads = self.onnx_inter_op_num_threads
|
|
312
|
+
providers = ["CPUExecutionProvider"]
|
|
313
|
+
self._model.sess = ort.InferenceSession(
|
|
314
|
+
model_path, sess_options=options, providers=providers
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
def _ensure_g2p(self):
|
|
318
|
+
if self._g2p is not None:
|
|
319
|
+
return self._g2p
|
|
320
|
+
british = self._use_british_english()
|
|
321
|
+
from misaki import en, espeak # type: ignore[import-not-found]
|
|
322
|
+
|
|
323
|
+
self._g2p = en.G2P(
|
|
324
|
+
trf=False,
|
|
325
|
+
british=british,
|
|
326
|
+
fallback=espeak.EspeakFallback(british=british),
|
|
327
|
+
)
|
|
328
|
+
self._g2p_error = None
|
|
329
|
+
return self._g2p
|
|
330
|
+
|
|
331
|
+
def _phonemize_english(self, text: str) -> str:
|
|
332
|
+
try:
|
|
333
|
+
g2p = self._ensure_g2p()
|
|
334
|
+
phonemes, _tokens = g2p(text)
|
|
335
|
+
return str(phonemes).strip()
|
|
336
|
+
except Exception as exc: # pragma: no cover - exercised via tests
|
|
337
|
+
self._g2p_error = str(exc)
|
|
338
|
+
raise RuntimeError(f"Misaki English phonemization failed: {exc}") from exc
|
|
339
|
+
|
|
340
|
+
def _should_use_misaki(self) -> bool:
|
|
341
|
+
return self.lang.lower().startswith("en")
|
|
342
|
+
|
|
343
|
+
def _use_british_english(self) -> bool:
|
|
344
|
+
lang = self.lang.lower()
|
|
345
|
+
return lang.startswith("en-gb") or self.voice.lower().startswith("b")
|
|
346
|
+
|
|
347
|
+
def _download_asset(self, url: str, filename: str) -> Path:
|
|
348
|
+
try:
|
|
349
|
+
import httpx # type: ignore[import-not-found]
|
|
350
|
+
except Exception as exc: # pragma: no cover - import path
|
|
351
|
+
raise RuntimeError(
|
|
352
|
+
"kokoro-onnx provider requires httpx; install with "
|
|
353
|
+
"pip install 'converse-framework[kokoro]'."
|
|
354
|
+
) from exc
|
|
355
|
+
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
|
356
|
+
target = self.cache_dir / filename
|
|
357
|
+
if target.exists():
|
|
358
|
+
return target
|
|
359
|
+
temp = target.with_suffix(target.suffix + ".part")
|
|
360
|
+
with (
|
|
361
|
+
httpx.Client(follow_redirects=True, timeout=self.timeout_s) as client,
|
|
362
|
+
temp.open("wb") as handle,
|
|
363
|
+
):
|
|
364
|
+
with client.stream("GET", url) as response:
|
|
365
|
+
response.raise_for_status()
|
|
366
|
+
for chunk in response.iter_bytes():
|
|
367
|
+
if chunk:
|
|
368
|
+
handle.write(chunk)
|
|
369
|
+
temp.replace(target)
|
|
370
|
+
return target
|
|
371
|
+
|
|
372
|
+
def _emit_progress(
|
|
373
|
+
self,
|
|
374
|
+
loop: asyncio.AbstractEventLoop,
|
|
375
|
+
progress: ProgressCallback | None,
|
|
376
|
+
stage: str,
|
|
377
|
+
message: str,
|
|
378
|
+
*,
|
|
379
|
+
started: float | None = None,
|
|
380
|
+
**extra,
|
|
381
|
+
) -> None:
|
|
382
|
+
if not progress:
|
|
383
|
+
return
|
|
384
|
+
payload = {"stage": stage, "message": message, **extra}
|
|
385
|
+
if started is not None:
|
|
386
|
+
payload["elapsed_ms"] = int((time.perf_counter() - started) * 1000)
|
|
387
|
+
|
|
388
|
+
async def _fire() -> None:
|
|
389
|
+
await progress("tts.progress", payload)
|
|
390
|
+
|
|
391
|
+
loop.call_soon_threadsafe(asyncio.create_task, _fire())
|