converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,620 @@
|
|
|
1
|
+
"""Turn orchestration for speech-to-speech applications."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import base64
|
|
7
|
+
import logging
|
|
8
|
+
import time
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from converse_framework.events import EventSink
|
|
14
|
+
from converse_framework.provider_events import (
|
|
15
|
+
provider_loaded_event,
|
|
16
|
+
provider_loading_event,
|
|
17
|
+
)
|
|
18
|
+
from converse_framework.registry import ProviderBundle
|
|
19
|
+
|
|
20
|
+
# Module-scoped logger; used for non-fatal diagnostics such as a failed
# predecessor in the TTS task chain.
logger = logging.getLogger(__name__)

# Signature of the optional system-prompt hook accepted by SpeechPipeline:
# ``(mode, manual_prompt, messages) -> effective system prompt``.
SystemPromptBuilder = Callable[[str, str, list[dict[str, str]]], str]

# Callable taking one string and returning a dict of sampler settings.
# NOTE(review): not referenced elsewhere in this module -- presumably
# consumed by app code or providers; confirm before removing.
SamplerBuilder = Callable[[str], dict[str, Any]]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass(frozen=True)
class PipelineConfig:
    """Immutable bundle of tunables consumed by :class:`SpeechPipeline`.

    Nothing here is read from a config file or the environment: the
    caller builds a :class:`ProviderBundle` and a
    :class:`PipelineConfig` and passes both to the pipeline. The
    pipeline copies ``tts_chunk_chars`` and ``min_tts_chars`` onto
    itself at construction, and those copies can be replaced later via
    :meth:`SpeechPipeline.update_turn_config`.

    Attributes:
        tts_chunk_chars: Length (in stripped characters) at which the
            buffered LLM output is flushed to TTS regardless of
            punctuation. Buffers that end in ``.``, ``!``, ``?``, ``;``
            or ``:`` are flushed earlier, so typical chunks are
            shorter.
        min_tts_chars: Minimum stripped length a buffer needs before a
            punctuation-triggered flush is allowed. Shorter buffers
            keep accumulating until they reach this length and end in
            punctuation, or until ``tts_chunk_chars`` is reached.
            ``0`` disables the minimum.
        default_mode: Key of the conversation state the pipeline
            creates and selects at construction; also the fallback
            used whenever an empty mode value is given. Modes are
            opaque string keys; ``"chat"`` is the conventional default.
    """

    # Flush-to-TTS threshold, in stripped characters.
    tts_chunk_chars: int = 120
    # Minimum stripped length for a punctuation-triggered flush.
    min_tts_chars: int = 0
    # Conversation mode selected at construction / used for empty modes.
    default_mode: str = "chat"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class _TurnState:
    """Mutable per-mode conversation state held by :class:`SpeechPipeline`.

    One instance exists per mode key; every instance gets its own
    message list and task set (``default_factory``), so modes never
    share history or TTS bookkeeping.
    """

    # Chat history as ``{"role": ..., "content": ...}`` dicts, oldest first.
    messages: list[dict[str, str]] = field(default_factory=list)
    # TTS tasks that have been started and have not finished yet.
    active_tts_tasks: set[asyncio.Task] = field(default_factory=set)
    # Manually configured system prompt for this mode ("" means none).
    system_prompt: str = ""
    # Counter incremented once per turn; the current value tags TTS events.
    turn_id: int = 0
    # Most recently scheduled TTS task; the next chunk waits on it so
    # chunks are synthesized in order.
    tts_tail: asyncio.Task | None = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SpeechPipeline:
    """Turn orchestrator for a speech-to-speech conversation.

    The pipeline is a single async object that owns the active
    provider bundle, an :class:`EventSink` for outbound events, and
    the per-mode conversation state (message history, active TTS
    tasks, system prompt, turn id). It exposes three entry
    points -- :meth:`handle_text_turn`, :meth:`handle_audio_turn`
    and :meth:`handle_continue` -- and drives the ASR -> LLM -> TTS
    flow internally.

    The pipeline is the only place that knows about mode
    switching, TTS cancellation, barge-in coordination, and the
    chunking heuristics that decide when the LLM token stream is
    handed off to TTS. App policy (UI, profile loading, memory,
    sampler configuration) is supplied through the optional
    ``system_prompt_builder`` and the registered
    :class:`ProviderBundle` -- the framework never imports app
    code.

    Every public method that takes ``mode`` falls back to
    ``config.default_mode`` when the argument is omitted, ``None`` or
    empty.

    Args:
        providers: Active provider bundle (VAD, ASR, LLM, TTS).
        sink: Event sink that receives every turn-related event.
        config: Optional :class:`PipelineConfig`; defaults are
            used if omitted.
        system_prompt_builder: Optional callable with the signature
            ``(mode, manual_prompt, messages) -> str`` that the
            pipeline calls to compute the effective system prompt
            for each turn. Apps use this to inject character / mode
            / memory policy without leaking that policy into the
            framework.
    """

    def __init__(
        self,
        providers: ProviderBundle,
        sink: EventSink,
        config: PipelineConfig | None = None,
        system_prompt_builder: SystemPromptBuilder | None = None,
    ) -> None:
        self.providers = providers
        self.sink = sink
        self.config = config or PipelineConfig()
        # Runtime-adjustable copies of the (frozen) config values.
        self.tts_chunk_chars = self.config.tts_chunk_chars
        self.min_tts_chars = self.config.min_tts_chars
        self._default_mode = self.config.default_mode
        self._system_prompt_builder = system_prompt_builder
        self._states: dict[str, _TurnState] = {self._default_mode: _TurnState()}
        self.state = self._states[self._default_mode]
        # Strong references to fire-and-forget tasks; the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._background_tasks: set[asyncio.Future] = set()

    def update_turn_config(self, *, tts_chunk_chars: int, min_tts_chars: int) -> None:
        """Replace the TTS chunking thresholds used by subsequent flush checks."""
        self.tts_chunk_chars = tts_chunk_chars
        self.min_tts_chars = min_tts_chars

    async def update_providers(
        self,
        providers: ProviderBundle,
        *,
        cancel_active_tts: bool = True,
        reason: str = "provider_reload",
    ) -> None:
        """Swap the active provider bundle at runtime.

        Cancels any active TTS synthesis by default so the next turn
        picks up the new TTS provider. Does **not** clear
        conversation history -- callers that want a fresh slate
        should call :meth:`clear_conversation` separately.

        Emits ``providers.updated`` with the serialized statuses of
        the new provider bundle so downstream consumers (UI layers,
        session helpers) can react. The replaced providers are
        unloaded in a background task that this method does not await.

        Args:
            providers: The new provider bundle to activate.
            cancel_active_tts: If True (default), cancel any
                in-flight TTS synthesis before swapping.
            reason: Label emitted in the event payload for
                diagnostic/debug use.
        """
        if cancel_active_tts:
            # Clear TTS for all known modes without cancelling
            # active recording -- that is the collector's job.
            # Iterate a snapshot: the awaits below yield to other
            # coroutines, which may register new modes in ``_states``.
            for state in list(self._states.values()):
                await self._cancel_state_tts(state)

        old_providers = self.providers
        self.providers = providers

        await self.sink.emit(
            "providers.updated",
            reason=reason,
            statuses=self.providers.statuses(),
        )

        # Unload replaced providers in the background so the
        # caller does not have to wait for heavyweight cleanup.
        unload = asyncio.ensure_future(
            ProviderBundle.unload_replaced(old_providers, providers)
        )
        self._background_tasks.add(unload)
        unload.add_done_callback(self._background_tasks.discard)

    async def clear_conversation(self, mode: str | None = None) -> None:
        """Cancel TTS for ``mode``, drop its history, emit ``conversation.cleared``."""
        self._select_mode(mode)
        # Pin the selected state and its key now: ``self.state`` can be
        # re-pointed by another turn while we await below.
        state = self.state
        cleared_mode = self._mode
        await self.cancel_tts("conversation_clear")
        state.messages.clear()
        await self.sink.emit("conversation.cleared", mode=cleared_mode)

    def set_system_prompt(self, prompt: str, mode: str | None = None) -> None:
        """Select ``mode`` and store the whitespace-stripped manual system prompt."""
        self._select_mode(mode)
        self.state.system_prompt = prompt.strip()

    async def cancel_tts(self, reason: str) -> None:
        """Cancel unfinished TTS tasks of the currently selected mode.

        Emits ``tts.cancelled`` with ``reason`` only when at least one
        task was actually cancelled. The mode's ``tts_tail`` is reset
        either way.
        """
        # Capture before awaiting so the reset lands on the state whose
        # tasks were cancelled even if the selected mode changes meanwhile.
        state = self.state
        if await self._cancel_state_tts(state):
            await self.sink.emit("tts.cancelled", reason=reason)

    @staticmethod
    async def _cancel_state_tts(state: _TurnState) -> bool:
        """Cancel and await ``state``'s unfinished TTS tasks; return True if any existed."""
        pending = [task for task in state.active_tts_tasks if not task.done()]
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True absorbs the CancelledError of each task.
            await asyncio.gather(*pending, return_exceptions=True)
        state.tts_tail = None
        return bool(pending)

    async def handle_text_turn(self, text: str, mode: str | None = None) -> None:
        """Run one user turn whose input is already text.

        The text is passed through the ASR provider's
        ``transcribe_text_input`` so the same ``asr.transcript`` events
        are emitted as for audio. Synthetic ``vad.speech_start`` /
        ``vad.speech_end`` events are emitted with ``text_only=True``.
        Exceptions raised by the ASR provider here are not caught and
        propagate to the caller.
        """
        self._select_mode(mode)
        turn_state = self.state
        turn_mode = self._mode
        started = time.perf_counter()
        turn_id = self._next_turn_id(turn_state)
        await self.cancel_tts("new_user_turn")
        await self.sink.emit("turn.started", mode=turn_mode, turn_id=turn_id)
        await self.sink.emit(
            "vad.speech_start", mode=turn_mode, source="text", text_only=True
        )

        final_transcript = ""
        async for transcript in self.providers.asr.transcribe_text_input(text):
            await self.sink.emit(
                "asr.transcript",
                mode=turn_mode,
                text=transcript.text,
                final=transcript.final,
                latency_ms=elapsed_ms(started),
            )
            if transcript.final:
                final_transcript = transcript.text

        await self.sink.emit(
            "vad.speech_end",
            mode=turn_mode,
            source="text",
            text_only=True,
            latency_ms=elapsed_ms(started),
        )
        if not final_transcript:
            await self.sink.emit(
                "turn.finished",
                mode=turn_mode,
                reason="empty_transcript",
                latency_ms=elapsed_ms(started),
            )
            return

        await self._respond_to_transcript(
            final_transcript, started, turn_id, turn_state, turn_mode
        )

    async def handle_audio_turn(
        self, pcm_s16le: bytes, sample_rate: int, mode: str | None = None
    ) -> None:
        """Run one user turn from recorded audio.

        Transcribes ``pcm_s16le`` with the ASR provider, forwarding its
        progress callbacks to the sink. An ASR failure is reported as
        ``asr.error``, ``provider.error`` and ``turn.finished``
        (``reason="asr_error"``) instead of being raised. An empty final
        transcript ends the turn with ``reason="empty_transcript"``.
        """
        self._select_mode(mode)
        turn_state = self.state
        turn_mode = self._mode
        started = time.perf_counter()
        turn_id = self._next_turn_id(turn_state)
        await self.cancel_tts("new_audio_turn")
        await self.sink.emit(
            "turn.started", mode=turn_mode, source="audio", turn_id=turn_id
        )
        await self.sink.emit(
            "asr.started", mode=turn_mode, sample_rate=sample_rate, bytes=len(pcm_s16le)
        )

        async def progress(event_type: str, payload: dict) -> None:
            lat = elapsed_ms(started)
            # Merge instead of passing latency_ms as a second keyword: a
            # payload that already has the key would raise TypeError.
            await self.sink.emit(event_type, **{**payload, "latency_ms": lat})
            # Emit provider lifecycle events alongside progress.
            await self._emit_provider_lifecycle(event_type, payload, lat)

        final_transcript = ""
        try:
            async for transcript in self.providers.asr.transcribe_audio(
                pcm_s16le, sample_rate, progress
            ):
                await self.sink.emit(
                    "asr.transcript",
                    mode=turn_mode,
                    text=transcript.text,
                    final=transcript.final,
                    latency_ms=elapsed_ms(started),
                )
                if transcript.final:
                    final_transcript = transcript.text
        except Exception as exc:
            lat = elapsed_ms(started)
            payload = exception_payload(
                exc, fallback=f"ASR provider failed with {type(exc).__name__}."
            )
            await self.sink.emit("asr.error", latency_ms=lat, **payload)
            await self.sink.emit(
                "provider.error",
                kind="asr",
                provider=self.providers.asr.status.name,
                **payload,
                latency_ms=lat,
            )
            await self.sink.emit(
                "turn.finished",
                mode=turn_mode,
                reason="asr_error",
                latency_ms=lat,
            )
            return

        if not final_transcript:
            await self.sink.emit(
                "turn.finished",
                mode=turn_mode,
                reason="empty_transcript",
                latency_ms=elapsed_ms(started),
            )
            return

        await self._respond_to_transcript(
            final_transcript, started, turn_id, turn_state, turn_mode
        )

    async def _emit_provider_lifecycle(
        self, event_type: str, payload: dict, lat: int
    ) -> None:
        """Translate a ``loading`` / ``loaded`` progress stage into a provider event.

        Only ``asr.progress`` and ``tts.progress`` events are considered;
        any other event type or stage emits nothing.
        """
        stage = payload.get("stage", "")
        if event_type == "asr.progress":
            kind = "asr"
            provider_name = self.providers.asr.status.name
        elif event_type == "tts.progress":
            kind = "tts"
            provider_name = self.providers.tts.status.name
        else:
            return

        msg = payload.get("message", "")
        if stage == "loading":
            await self.sink.emit(
                **provider_loading_event(
                    kind=kind,
                    provider=provider_name,
                    message=msg,
                ),
                latency_ms=lat,
            )
        elif stage == "loaded":
            await self.sink.emit(
                **provider_loaded_event(
                    kind=kind,
                    provider=provider_name,
                    message=msg,
                    latency_ms=lat,
                ),
            )

    async def handle_continue(self, mode: str | None = None) -> None:
        """Ask the LLM to extend the last assistant message of ``mode``.

        Emits ``turn.error`` and returns if the history is empty or does
        not end with an assistant message. Otherwise the previous reply is
        used as the prefix of the streamed response and is replaced in
        history by the extended, stripped text once streaming completes.
        """
        self._select_mode(mode)
        turn_state = self.state
        turn_mode = self._mode
        if not turn_state.messages or turn_state.messages[-1]["role"] != "assistant":
            await self.sink.emit(
                "turn.error",
                mode=turn_mode,
                message="No previous assistant message to continue.",
            )
            return

        started = time.perf_counter()
        turn_id = self._next_turn_id(turn_state)
        await self.cancel_tts("continue_turn")
        await self.sink.emit(
            "turn.started", mode=turn_mode, source="continue", turn_id=turn_id
        )
        prefix = turn_state.messages[-1]["content"]
        # Swap the last entry for a fresh role/content-only dict.
        turn_state.messages[-1] = {"role": "assistant", "content": prefix}

        try:
            response_text = await self._stream_llm_and_tts(
                prefix, started, turn_id, turn_state, turn_mode
            )
            turn_state.messages[-1] = {
                "role": "assistant",
                "content": response_text.strip(),
            }
            await self.sink.emit(
                "turn.finished", mode=turn_mode, latency_ms=elapsed_ms(started)
            )
        except Exception as exc:
            await self._emit_llm_failure(exc, started, turn_mode)

    def messages_for_mode(self, mode: str) -> list[dict[str, str]]:
        """Return a shallow copy of ``mode``'s history (creating the mode if unknown)."""
        return list(self._state_for_mode(mode).messages)

    async def _respond_to_transcript(
        self,
        final_transcript: str,
        started: float,
        turn_id: int,
        turn_state: _TurnState,
        turn_mode: str,
    ) -> None:
        """Record the user message, stream the reply, and record the assistant message.

        Failures while streaming are reported as ``turn.error`` and
        ``provider.error`` (kind ``llm``); in that case the user message
        stays in history and no assistant message is appended.
        """
        turn_state.messages.append({"role": "user", "content": final_transcript})
        try:
            response_text = await self._stream_llm_and_tts(
                "", started, turn_id, turn_state, turn_mode
            )
            turn_state.messages.append(
                {"role": "assistant", "content": response_text.strip()}
            )
            await self.sink.emit(
                "turn.finished", mode=turn_mode, latency_ms=elapsed_ms(started)
            )
        except Exception as exc:
            await self._emit_llm_failure(exc, started, turn_mode)

    async def _emit_llm_failure(
        self, exc: Exception, started: float, turn_mode: str
    ) -> None:
        """Report a failed LLM turn as ``turn.error`` followed by ``provider.error``."""
        lat = elapsed_ms(started)
        payload = exception_payload(
            exc, fallback=f"LLM provider failed with {type(exc).__name__}."
        )
        await self.sink.emit(
            "turn.error",
            mode=turn_mode,
            latency_ms=lat,
            **payload,
        )
        await self.sink.emit(
            "provider.error",
            kind="llm",
            provider=self.providers.llm.status.name,
            **payload,
            latency_ms=lat,
        )

    async def _stream_llm_and_tts(
        self,
        response_text: str,
        started: float,
        turn_id: int,
        turn_state: _TurnState,
        turn_mode: str,
    ) -> str:
        """Stream LLM tokens, handing buffered text to TTS as it becomes speakable.

        Args:
            response_text: Text the response starts with (``""`` for a new
                reply, the previous reply when continuing). Tokens are
                appended to it; it is not itself sent to TTS.

        Returns:
            ``response_text`` with every streamed token appended.
        """
        first_token_seen = False
        sentence_buffer = ""
        async for token in self.providers.llm.stream_response(
            self._llm_messages(turn_state, turn_mode)
        ):
            if not first_token_seen:
                first_token_seen = True
                await self.sink.emit(
                    "llm.first_token", mode=turn_mode, latency_ms=elapsed_ms(started)
                )
            response_text += token
            sentence_buffer += token
            await self.sink.emit(
                "llm.token", mode=turn_mode, text=token, accumulated=response_text
            )

            if should_flush_tts(
                sentence_buffer, self.tts_chunk_chars, self.min_tts_chars
            ):
                await self._start_tts_chunk(
                    sentence_buffer.strip(), started, turn_id, turn_state, turn_mode
                )
                sentence_buffer = ""

        # Speak whatever is left once the stream ends, regardless of the
        # chunking thresholds.
        if sentence_buffer.strip():
            await self._start_tts_chunk(
                sentence_buffer.strip(), started, turn_id, turn_state, turn_mode
            )
        return response_text

    async def _start_tts_chunk(
        self,
        text: str,
        turn_started: float,
        turn_id: int,
        turn_state: _TurnState,
        turn_mode: str,
    ) -> None:
        """Schedule synthesis of ``text`` after the previously scheduled chunk.

        The new task becomes ``turn_state.tts_tail`` and is tracked in
        ``active_tts_tasks`` until it completes.
        """
        previous = turn_state.tts_tail
        task = asyncio.create_task(
            self._stream_tts_after(previous, text, turn_started, turn_id, turn_mode)
        )
        turn_state.tts_tail = task
        turn_state.active_tts_tasks.add(task)
        task.add_done_callback(turn_state.active_tts_tasks.discard)

    async def _stream_tts_after(
        self,
        previous: asyncio.Task | None,
        text: str,
        turn_started: float,
        turn_id: int,
        turn_mode: str,
    ) -> None:
        """Wait for ``previous`` (if any), then synthesize ``text``.

        A failed predecessor is logged and does not stop this chunk; a
        cancellation is propagated.
        """
        if previous is not None:
            try:
                await previous
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                logger.warning("Previous TTS task failed: %s", exc)
        await self._stream_tts(text, turn_started, turn_id, turn_mode)

    async def _stream_tts(
        self, text: str, turn_started: float, turn_id: int, turn_mode: str
    ) -> None:
        """Synthesize ``text`` and emit each audio chunk as a ``tts.audio`` event.

        Chunk bytes are base64-encoded into the ``data`` field.
        ``tts.first_chunk`` is emitted once, before the first
        ``tts.audio``. Provider failures are reported as ``tts.error`` and
        ``provider.error`` rather than raised; cancellation is re-raised.
        """
        first_chunk_seen = False
        chunk_index = 0

        async def progress(event_type: str, payload: dict) -> None:
            # Merge so a payload carrying its own latency_ms cannot cause a
            # duplicate-keyword TypeError.
            await self.sink.emit(
                event_type, **{**payload, "latency_ms": elapsed_ms(turn_started)}
            )

        try:
            async for chunk in self.providers.tts.stream_audio_with_progress(
                text, progress
            ):
                chunk_index += 1
                if not first_chunk_seen:
                    first_chunk_seen = True
                    await self.sink.emit(
                        "tts.first_chunk",
                        mode=turn_mode,
                        latency_ms=elapsed_ms(turn_started),
                        text=text,
                        turn_id=turn_id,
                    )
                encoded = base64.b64encode(chunk.data).decode("ascii")
                await self.sink.emit(
                    "tts.audio",
                    mode=turn_mode,
                    mime_type=chunk.mime_type,
                    sample_rate=chunk.sample_rate,
                    channels=chunk.channels,
                    encoding=chunk.encoding,
                    duration_ms=chunk.duration_ms,
                    data=encoded,
                    final=chunk.final,
                    text=text,
                    turn_id=turn_id,
                    chunk_index=chunk_index,
                    text_chars=len(text),
                    byte_length=len(chunk.data),
                    latency_ms=elapsed_ms(turn_started),
                )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            lat = elapsed_ms(turn_started)
            payload = exception_payload(
                exc, fallback=f"TTS provider failed with {type(exc).__name__}."
            )
            await self.sink.emit(
                "tts.error",
                mode=turn_mode,
                latency_ms=lat,
                text=text,
                **payload,
            )
            await self.sink.emit(
                "provider.error",
                kind="tts",
                provider=self.providers.tts.status.name,
                **payload,
                latency_ms=lat,
            )

    def _llm_messages(
        self, turn_state: _TurnState, turn_mode: str
    ) -> list[dict[str, str]]:
        """Return the message list sent to the LLM.

        This is a copy of the history, preceded by a ``system`` message
        when the effective system prompt is non-empty.
        """
        prompt = self._effective_system_prompt(turn_state, turn_mode)
        if not prompt:
            return list(turn_state.messages)
        return [{"role": "system", "content": prompt}, *turn_state.messages]

    def _effective_system_prompt(self, turn_state: _TurnState, turn_mode: str) -> str:
        """Return the builder's stripped output, or the stored prompt if there is no builder."""
        if self._system_prompt_builder is not None:
            return self._system_prompt_builder(
                turn_mode, turn_state.system_prompt, list(turn_state.messages)
            ).strip()
        return turn_state.system_prompt

    def _next_turn_id(self, turn_state: _TurnState | None = None) -> int:
        """Increment and return the turn counter of ``turn_state`` (default: selected state)."""
        selected = turn_state or self.state
        selected.turn_id += 1
        return selected.turn_id

    @property
    def _mode(self) -> str:
        """Key of the currently selected state (the default mode if it is not registered)."""
        for mode, state in self._states.items():
            if state is self.state:
                return mode
        return self._default_mode

    def _select_mode(self, mode: str | None) -> None:
        """Make ``mode``'s state the selected one, creating it on first use."""
        self.state = self._state_for_mode(mode)

    def _state_for_mode(self, mode: str | None) -> _TurnState:
        """Return the state for ``mode``; empty/None maps to the default mode."""
        selected = mode or self._default_mode
        if selected not in self._states:
            self._states[selected] = _TurnState()
        return self._states[selected]
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def elapsed_ms(started: float) -> int:
    """Return whole milliseconds elapsed since ``started``.

    ``started`` must be a reading of :func:`time.perf_counter`; the
    fractional millisecond is truncated.
    """
    delta_seconds = time.perf_counter() - started
    return int(delta_seconds * 1000)
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def exception_payload(exc: Exception, *, fallback: str) -> dict[str, str]:
    """Build a structured error dict with a guaranteed non-blank message.

    Args:
        exc: The exception being reported.
        fallback: Message used when ``str(exc)`` is empty or consists
            only of whitespace.

    Returns:
        A dict with the keys ``message`` (``str(exc)`` unchanged, or
        *fallback*), ``error_type`` (the exception class name) and
        ``repr`` (``repr(exc)``).
    """
    text = str(exc)
    return {
        # A whitespace-only message is as uninformative as an empty one,
        # so it also yields the fallback; real messages are kept verbatim.
        "message": text if text.strip() else fallback,
        "error_type": type(exc).__name__,
        "repr": repr(exc),
    }
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def should_flush_tts(text: str, limit: int, minimum: int = 0) -> bool:
    """Decide whether the buffered LLM output should be handed to TTS now.

    Lengths are measured on ``text`` with surrounding whitespace removed.

    Args:
        text: Text buffered since the last flush.
        limit: Length at or above which the buffer is always flushed.
        minimum: Length below which a punctuation-triggered flush is
            suppressed.

    Returns:
        ``False`` for blank text; ``True`` once ``limit`` is reached;
        otherwise ``True`` only if the text is at least ``minimum`` long
        and ends with one of ``. ! ? ; :``.
    """
    candidate = text.strip()
    size = len(candidate)
    if size == 0:
        return False
    if size >= limit:
        return True
    # candidate is non-empty here, so indexing the last character is safe.
    return size >= minimum and candidate[-1] in ".!?;:"
|