converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
"""Text-only chat example showing the framework as a second consumer.
|
|
2
|
+
|
|
3
|
+
This example runs a real conversation against :class:`SpeechPipeline`
|
|
4
|
+
using only the framework's public API. No harness modules, FastAPI,
|
|
5
|
+
WebSocket, profile files, or browser UI are involved. It is intended
|
|
6
|
+
as a quick smoke test that the extracted package can drive a complete
|
|
7
|
+
turn loop on its own.
|
|
8
|
+
|
|
9
|
+
The example exposes two surfaces:
|
|
10
|
+
|
|
11
|
+
* :func:`run_text_chat` — async driver that runs a list of scripted
|
|
12
|
+
inputs through the pipeline. Used by the test suite.
|
|
13
|
+
* ``__main__`` — a small CLI that reads lines from stdin and, once each
  turn has finished, prints that turn's LLM tokens and other events.
|
|
15
|
+
|
|
16
|
+
The CLI uses mock providers by default. To try a real provider, pass
``--provider KIND=NAME`` (for example ``--provider asr=faster-whisper``)
with a registered provider name and install the matching extra. If the
extra is missing, the framework falls back to :class:`UnavailableProvider`
and the status message will tell the user which extra to install.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import argparse
|
|
26
|
+
import asyncio
|
|
27
|
+
import base64
|
|
28
|
+
import sys
|
|
29
|
+
from collections.abc import Iterable
|
|
30
|
+
from dataclasses import dataclass, field
|
|
31
|
+
from typing import Any
|
|
32
|
+
|
|
33
|
+
from converse_framework.events import EventSink, QueueEventSink
|
|
34
|
+
from converse_framework.pipeline import PipelineConfig, SpeechPipeline
|
|
35
|
+
from converse_framework.registry import build_provider_bundle
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class TextChatExampleConfig:
    """Settings for the text chat example.

    :func:`build_example_pipeline` forwards these values to the provider
    registry and to :class:`PipelineConfig`.
    """

    # Provider name per kind (vad/asr/llm/tts); each kind defaults to the
    # mock provider. A fresh dict is built for every instance.
    providers: dict[str, str] = field(
        default_factory=lambda: dict(vad="mock", asr="mock", llm="mock", tts="mock")
    )
    # Forwarded to PipelineConfig.tts_chunk_chars.
    tts_chunk_chars: int = 80
    # Forwarded to PipelineConfig.min_tts_chars.
    min_tts_chars: int = 0
    # Conversation mode key; also used as PipelineConfig.default_mode.
    mode: str = "chat"
    # Applied with set_system_prompt() only when non-empty.
    system_prompt: str = "You are a helpful assistant. Be concise."
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def build_example_bundle(config: TextChatExampleConfig) -> Any:
    """Translate ``config.providers`` into a framework provider bundle.

    Each ``kind -> name`` pair becomes ``{kind: {"provider": name}}`` and the
    resulting mapping is handed to :func:`build_provider_bundle`. Real
    provider names are honored when the matching optional dependency is
    installed; otherwise the registry returns an unavailable provider whose
    status message tells the user which extra to install.
    """
    return build_provider_bundle(
        {kind: {"provider": name} for kind, name in config.providers.items()}
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def build_example_pipeline(
    config: TextChatExampleConfig | None = None,
    *,
    sink: EventSink | None = None,
) -> tuple[SpeechPipeline, Any]:
    """Construct a :class:`SpeechPipeline` ready for the example.

    Args:
        config: Example settings. ``None`` uses the
            :class:`TextChatExampleConfig` defaults.
        sink: Destination for pipeline events. ``None`` creates a
            :class:`QueueEventSink` over a fresh :class:`asyncio.Queue`. That
            queue is not returned, so pass your own sink to observe events.

    Returns:
        The pipeline and the provider bundle it was built with, so callers
        can inspect ``bundle.statuses()`` if they want to.
    """
    if config is None:
        config = TextChatExampleConfig()
    bundle = build_example_bundle(config)
    if sink is None:
        # Explicit ``is None`` check: a caller-supplied sink must never be
        # swapped out just because it happens to be falsy (e.g. defines
        # ``__len__``).
        sink = QueueEventSink(asyncio.Queue())
    pipeline = SpeechPipeline(
        providers=bundle,
        sink=sink,
        config=PipelineConfig(
            tts_chunk_chars=config.tts_chunk_chars,
            min_tts_chars=config.min_tts_chars,
            default_mode=config.mode,
        ),
    )
    # An empty prompt skips set_system_prompt() entirely.
    if config.system_prompt:
        pipeline.set_system_prompt(config.system_prompt, mode=config.mode)
    return pipeline, bundle
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
async def run_text_chat(
    inputs: Iterable[str],
    config: TextChatExampleConfig | None = None,
) -> dict[str, Any]:
    """Drive the example end-to-end and return a structured summary.

    Each input string is sent as one text turn. After every turn the
    pipeline's active TTS tasks are awaited and the event queue is drained,
    so each per-turn record holds everything emitted for that turn. The
    summary is designed to be assertion-friendly for the framework test
    suite.

    Args:
        inputs: Scripted user messages, processed in order.
        config: Example settings. ``None`` uses the defaults.

    Returns:
        A dict with these keys:

        * ``mode``: the conversation mode.
        * ``providers``: a copy of the configured provider names.
        * ``turns``: one entry per input, holding the event types, the
          joined LLM text and the number of ``tts.audio`` events.
        * ``messages``: the pipeline history for the mode.
    """
    if config is None:
        config = TextChatExampleConfig()
    queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    pipeline, _ = build_example_pipeline(config, sink=QueueEventSink(queue))

    turns: list[dict[str, Any]] = []
    for text in inputs:
        await pipeline.handle_text_turn(text, mode=config.mode)
        # TTS may still be running in background tasks when the turn call
        # returns; wait for them so their events land in this turn's record
        # rather than leaking into the next one.
        await _drain_in_flight_tts(pipeline)
        events = await _drain_queue(queue)
        event_types = [event.get("type", "") for event in events]
        turns.append(
            {
                "input": text,
                "events": event_types,
                "llm_text": _joined_llm_text(events),
                "tts_audio_chunks": event_types.count("tts.audio"),
            }
        )
    return {
        "mode": config.mode,
        # Copy so callers mutating the summary cannot alter the config.
        "providers": dict(config.providers),
        "turns": turns,
        "messages": pipeline.messages_for_mode(config.mode),
    }
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
async def _drain_in_flight_tts(pipeline: SpeechPipeline) -> None:
    """Wait until every TTS task the pipeline reports as active has finished.

    ``pipeline.state.active_tts_tasks`` is re-read after each wait, so tasks
    added while earlier ones were still running are awaited too. Each task
    is awaited at most once, so finished tasks that stay in the collection
    cannot cause an endless loop.

    Task exceptions and cancellations are collected by
    ``asyncio.gather(..., return_exceptions=True)`` and discarded. Draining
    is best-effort and only settles the event stream.
    """
    awaited: set[Any] = set()
    while True:
        # Snapshot first: the collection may change while we are waiting.
        pending = [
            task
            for task in list(pipeline.state.active_tts_tasks)
            if task not in awaited
        ]
        if not pending:
            return
        awaited.update(pending)
        await asyncio.gather(*pending, return_exceptions=True)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
async def _drain_queue(queue: asyncio.Queue[dict[str, Any]]) -> list[dict[str, Any]]:
    """Remove and return every event currently buffered in *queue*.

    Events come back oldest first and the queue is left empty. The function
    never awaits, so nothing can be enqueued while it runs.
    """
    return [queue.get_nowait() for _ in range(queue.qsize())]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _joined_llm_text(events: list[dict[str, Any]]) -> str:
    """Concatenate the ``text`` of every ``llm.token`` event and strip the result.

    These events are skipped rather than raising:

    * events without a ``type`` key;
    * events whose ``payload`` is missing, ``None`` or not a dict;
    * tokens whose ``text`` is not a string.
    """
    pieces: list[str] = []
    for event in events:
        if event.get("type") != "llm.token":
            continue
        # ``payload`` may be absent or explicitly ``None``; treat both as empty,
        # the same way the CLI event formatter does.
        payload = event.get("payload") or {}
        if not isinstance(payload, dict):
            continue
        token = payload.get("text")
        if isinstance(token, str):
            pieces.append(token)
    return "".join(pieces).strip()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _format_audio_summary(audio_b64: str, mime_type: str) -> str:
    """Return a best-effort one-line summary for a base64 TTS audio payload.

    Args:
        audio_b64: Base64-encoded audio bytes.
        mime_type: MIME type to show in the summary.

    Returns:
        ``"[tts.audio] <mime> <n> bytes"`` on success, or
        ``"[tts.audio] <mime> (decode-error)"`` when *audio_b64* is not
        ASCII or not valid base64 (for example, bad padding).
    """
    try:
        decoded = base64.b64decode(audio_b64.encode("ascii"))
    except ValueError:
        # binascii.Error (bad base64) and UnicodeEncodeError (non-ASCII input)
        # are both ValueError subclasses; anything else is a real bug and
        # should not be hidden behind "decode-error".
        return f"[tts.audio] {mime_type} (decode-error)"
    return f"[tts.audio] {mime_type} {len(decoded)} bytes"
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _format_event_for_cli(event: dict[str, Any]) -> str:
    """Render one pipeline event as a single indented line for the CLI.

    Most event types get a specific rendering:

    * ``llm.token``: the token text, as a repr.
    * ``asr.transcript``: the transcript, tagged final or partial.
    * ``tts.audio``: a byte-count summary.
    * ``turn.error``: the error message.
    * ``turn.finished``: the bare type name.

    Any other type is echoed as its type name. A missing or ``None``
    payload is treated as empty.
    """
    kind = event.get("type", "")
    data = event.get("payload", {}) or {}

    if kind == "turn.finished":
        return " turn.finished"
    if kind == "turn.error":
        return f" turn.error: {data.get('message', '')}"
    if kind == "llm.token":
        return f" token: {data.get('text', '')!r}"
    if kind == "asr.transcript":
        stage = "final" if data.get("final") else "partial"
        return f" asr({stage}): {data.get('text', '')!r}"
    if kind == "tts.audio":
        audio_field = str(data.get("data", ""))
        mime = str(data.get("mime_type", "audio/wav"))
        return " " + _format_audio_summary(audio_field, mime)
    return f" {kind}"
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# Provider name for every kind the CLI understands; all default to the
# in-process mock so the example runs with no optional extras installed.
_DEFAULT_PROVIDER_NAMES: dict[str, str] = dict.fromkeys(("vad", "asr", "llm", "tts"), "mock")
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _parse_provider_args(values: Iterable[str]) -> dict[str, str]:
    """Turn repeated ``--provider KIND=NAME`` values into a full provider map.

    The result always holds all four kinds (vad/asr/llm/tts). Any kind
    without an override keeps its ``_DEFAULT_PROVIDER_NAMES`` entry
    (``mock``), and later overrides for the same kind win. Whitespace
    around KIND and NAME is ignored. The text-only example does not
    exercise voice activity detection, so callers typically leave the VAD
    default alone.

    Raises:
        SystemExit: if an entry has no ``=``, names an unknown kind, or has
            an empty provider name.
    """
    known_kinds = {"vad", "asr", "llm", "tts"}
    result: dict[str, str] = dict(_DEFAULT_PROVIDER_NAMES)
    for raw in values:
        key, separator, value = raw.partition("=")
        if not separator:
            raise SystemExit(
                f"Expected --provider-style argument of the form KIND=NAME, got {raw!r}"
            )
        key, value = key.strip(), value.strip()
        if key not in known_kinds:
            raise SystemExit(f"Unknown provider kind: {key}")
        if not value:
            raise SystemExit(f"Provider name is empty for kind {key}")
        result[key] = value
    return result
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the text chat CLI.

    The options cover repeatable provider overrides, the TTS chunk size,
    the system prompt and the conversation mode. Defaults match
    :class:`TextChatExampleConfig`: chunk size 80, the concise-assistant
    prompt and mode ``chat``.
    """
    parser = argparse.ArgumentParser(
        prog="python -m converse_framework.examples.text_chat",
        description="Run a text conversation against Converse Framework with mock "
        "providers by default. Use KIND=NAME overrides to select a real "
        "provider when the matching extra is installed.",
    )
    add = parser.add_argument
    # Repeatable; the collected strings are validated by _parse_provider_args.
    add("--provider", action="append", default=[], metavar="KIND=NAME",
        help="Override a provider, e.g. --provider asr=faster-whisper")
    add("--tts-chunk-chars", type=int, default=80,
        help="Flush TTS once the LLM has produced this many characters.")
    add("--system-prompt", type=str, default="You are a helpful assistant. Be concise.",
        help="Initial system prompt.")
    add("--mode", type=str, default="chat",
        help="Conversation mode key (default: chat).")
    return parser
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
async def _run_cli(args: argparse.Namespace) -> int:
    """Run the interactive stdin loop, one pipeline turn per non-empty line.

    First prints the configured providers and each provider status. Then it
    prompts with ``you> `` and handles each line:

    * ``quit`` or ``exit``, or EOF / Ctrl-C at the prompt, ends the session
      with exit code 0.
    * ``clear`` resets the conversation history for the configured mode.
    * Any other text runs a text turn. After the turn, the queued events
      are printed one per line, followed by a separator.
    """
    cfg = TextChatExampleConfig(
        providers=_parse_provider_args(args.provider),
        tts_chunk_chars=args.tts_chunk_chars,
        min_tts_chars=0,
        mode=args.mode,
        system_prompt=args.system_prompt,
    )
    event_queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    pipeline, bundle = build_example_pipeline(cfg, sink=QueueEventSink(event_queue))
    rule = "-" * 60

    print(f"providers: {cfg.providers}")
    for status in bundle.statuses():
        print(f" - {status['kind']}/{status['name']}: {status['message']}")
    print("type 'quit' to exit, 'clear' to reset history")
    print(rule)

    def read_line() -> str:
        # input() blocks, so it runs in the default executor, off the event loop.
        return input("you> ")

    loop = asyncio.get_running_loop()
    while True:
        try:
            raw = await loop.run_in_executor(None, read_line)
        except (EOFError, KeyboardInterrupt):
            print()
            return 0

        command = raw.strip()
        if command in {"quit", "exit"}:
            return 0
        if command == "clear":
            await pipeline.clear_conversation(mode=cfg.mode)
            print("(conversation cleared)")
            continue
        if not command:
            continue

        await pipeline.handle_text_turn(command, mode=cfg.mode)
        await _drain_in_flight_tts(pipeline)
        for event in await _drain_queue(event_queue):
            print(_format_event_for_cli(event))
        print(rule)
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse *argv* and run the interactive chat loop.

    When *argv* is ``None``, argparse reads ``sys.argv[1:]``. Returns the
    exit code produced by the chat loop.
    """
    namespace = _build_arg_parser().parse_args(argv)
    return asyncio.run(_run_cli(namespace))
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
if __name__ == "__main__":  # pragma: no cover - exercised by the CLI smoke test
    # Propagate the CLI's return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
"""Voice chat example (manual).
|
|
2
|
+
|
|
3
|
+
This example shows the framework's voice flow: a :class:`AudioUtteranceCollector`
|
|
4
|
+
feeds PCM frames into the :class:`SpeechPipeline` after detecting speech
|
|
5
|
+
end. It is the recommended starting point for any consumer that wants
|
|
6
|
+
to build a voice assistant on top of the framework.
|
|
7
|
+
|
|
8
|
+
The example is **manual** by design — it reads a WAV file supplied with
|
|
9
|
+
``--input`` and is not exercised by the
|
|
10
|
+
automated test suite. The text-only example in
|
|
11
|
+
:mod:`converse_framework.examples.text_chat` is the one covered by
|
|
12
|
+
tests.
|
|
13
|
+
|
|
14
|
+
Usage
|
|
15
|
+
-----
|
|
16
|
+
|
|
17
|
+
Run the voice example from the repository root after installing the
optional extras for the default providers (``silero`` VAD,
``faster-whisper`` ASR, ``llamacpp`` LLM and ``kokoro`` TTS), or add
``--mock`` to run without them::
|
|
19
|
+
|
|
20
|
+
python -m converse_framework.examples.voice_chat --input path/to/16k_mono.wav
|
|
21
|
+
|
|
22
|
+
The example will:
|
|
23
|
+
|
|
24
|
+
1. Build a provider bundle (``silero`` VAD, ``faster-whisper`` ASR,
|
|
25
|
+
``llamacpp`` LLM, ``kokoro`` TTS by default).
|
|
26
|
+
2. Read 16 kHz mono PCM frames from the WAV file passed with
|
|
27
|
+
``--input path/to/file.wav``.
|
|
28
|
+
3. Feed 30 ms PCM frames into the utterance collector.
|
|
29
|
+
4. For each completed utterance, hand the PCM bytes to
|
|
30
|
+
:meth:`SpeechPipeline.handle_audio_turn`.
|
|
31
|
+
5. Emit pipeline events, including any ``tts.audio`` chunks, through the
|
|
32
|
+
configured event sink. Consumer apps own playback.
|
|
33
|
+
|
|
34
|
+
Implementation notes
|
|
35
|
+
--------------------
|
|
36
|
+
|
|
37
|
+
* The collector's :attr:`cancel_callback` is wired to
|
|
38
|
+
:meth:`SpeechPipeline.cancel_tts` so VAD-driven speech starts cancel
|
|
39
|
+
any in-flight TTS (barge-in).
|
|
40
|
+
* The collector's ``pre_speech_start_hook`` is not configured here, but
  it is where a real consumer would update a per-frame system prompt
  before the utterance is finalized.
|
|
43
|
+
* ``--mock`` swaps every provider (VAD, ASR, LLM and TTS) for the
  in-process mock providers, ignoring any ``--provider`` overrides, so
  the example can run without any heavy model dependencies.
|
|
46
|
+
"""
|
|
47
|
+
|
|
48
|
+
from __future__ import annotations
|
|
49
|
+
|
|
50
|
+
import argparse
|
|
51
|
+
import asyncio
|
|
52
|
+
import base64
|
|
53
|
+
import sys
|
|
54
|
+
|
|
55
|
+
from converse_framework.audio_utils import AudioFrame, AudioFrameStats, parse_audio_frame
|
|
56
|
+
from converse_framework.events import QueueEventSink
|
|
57
|
+
from converse_framework.pipeline import PipelineConfig, SpeechPipeline
|
|
58
|
+
from converse_framework.registry import build_provider_bundle
|
|
59
|
+
from converse_framework.utterance_collector import (
|
|
60
|
+
AudioUtteranceCollector,
|
|
61
|
+
UtteranceCollectorConfig,
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _parse_provider_args(values: list[str]) -> dict[str, str]:
    """Parse repeated ``--provider KIND=NAME`` flags into a provider map.

    Starts from the example's real-provider defaults: ``silero`` VAD,
    ``faster-whisper`` ASR, ``llamacpp`` LLM and ``kokoro`` TTS. Each
    override is applied in order, so a later flag for the same kind wins.
    Whitespace around KIND and NAME is ignored.

    Raises:
        SystemExit: if an entry has no ``=``, names an unknown kind, or has
            an empty provider name.
    """
    parsed: dict[str, str] = {"vad": "silero", "asr": "faster-whisper", "llm": "llamacpp", "tts": "kokoro"}
    for entry in values:
        if "=" not in entry:
            raise SystemExit(f"Expected KIND=NAME, got {entry!r}")
        kind, name = entry.split("=", 1)
        # Strip both halves so " tts = pocket-tts" means tts=pocket-tts
        # instead of being rejected as an unknown kind " tts ".
        kind = kind.strip()
        name = name.strip()
        if kind not in parsed:
            raise SystemExit(f"Unknown provider kind: {kind}")
        if not name:
            # Otherwise "" would be handed to the registry as a provider name.
            raise SystemExit(f"Provider name is empty for kind {kind}")
        parsed[kind] = name
    return parsed
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the manual voice example.

    ``--input`` is optional at the parser level. The collector driver exits
    with an explanatory error when it is missing, because live microphone
    capture is not implemented.
    """
    parser = argparse.ArgumentParser(
        prog="python -m converse_framework.examples.voice_chat",
        description=(
            "Manual voice example. Streams WAV-file frames "
            "through the framework's utterance collector and pipeline. "
            "Run --mock to avoid heavy provider dependencies."
        ),
    )
    parser.add_argument(
        "--provider",
        action="append",
        default=[],
        metavar="KIND=NAME",
        help="Override a provider, e.g. --provider tts=pocket-tts",
    )
    parser.add_argument(
        "--mock",
        action="store_true",
        # All four providers are replaced, not only VAD/ASR.
        help=(
            "Use mock providers for every stage (VAD, ASR, LLM and TTS) so the "
            "example runs without extras; overrides --provider."
        ),
    )
    parser.add_argument(
        "--input",
        type=str,
        default=None,
        help=(
            "Path to the 16 kHz mono 16-bit PCM WAV file to stream (required; "
            "live microphone capture is not supported)."
        ),
    )
    parser.add_argument(
        "--frame-ms",
        type=int,
        default=30,
        help="Frame size in milliseconds (default 30).",
    )
    parser.add_argument(
        "--pre-speech-ms",
        type=int,
        default=450,
        help="Pre-speech buffer size in milliseconds (default 450).",
    )
    return parser
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
async def _drive_collector(
    collector: AudioUtteranceCollector,
    pipeline: SpeechPipeline,
    config: UtteranceCollectorConfig,
    *,
    input_path: str | None = None,
) -> None:
    """Stream PCM frames from a WAV file into the utterance collector.

    Real deployments should plug a microphone capture coroutine in here.
    The framework only cares that ``parse_audio_frame`` succeeds and that
    the resulting :class:`AudioFrame` is passed to
    :meth:`AudioUtteranceCollector.ingest_frame`.

    The file is parsed with :mod:`wave`, so headers with extra chunks (for
    example ``LIST`` metadata) are handled. Its format must match *config*:
    16-bit PCM at ``config.sample_rate`` with ``config.channels`` channels.
    Audio is sent in frames of ``config.frame_ms`` milliseconds, and a
    short final frame is padded with silence so every frame has the
    expected size.

    *pipeline* is not used by this function.

    Raises:
        SystemExit: if *input_path* is missing, the file is not a readable
            WAV file, or its format does not match *config*.
    """
    import wave

    if input_path is None:
        raise SystemExit(
            "No --input file provided. Pass a 16 kHz mono WAV file to drive the "
            "collector. (Live microphone capture is intentionally out of scope "
            "for the example; the framework is platform-agnostic.)"
        )

    expected_frame_bytes = config.bytes_per_ms * config.frame_ms

    try:
        wav = wave.open(input_path, "rb")
    except (wave.Error, EOFError) as exc:
        raise SystemExit(f"{input_path}: not a readable WAV file ({exc})") from exc

    with wav:
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        channels = wav.getnchannels()
        # The wire encoding below is pcm_s16le, so anything else would be
        # silently misinterpreted by the collector.
        if sample_width != 2 or sample_rate != config.sample_rate or channels != config.channels:
            raise SystemExit(
                f"{input_path}: expected 16-bit PCM at {config.sample_rate} Hz with "
                f"{config.channels} channel(s), got {8 * sample_width}-bit at "
                f"{sample_rate} Hz with {channels} channel(s)"
            )

        stats = AudioFrameStats(
            expected_sample_rate=config.sample_rate,
            expected_channels=config.channels,
            expected_frame_ms=config.frame_ms,
        )
        # readframes() counts sample frames (one sample per channel).
        frames_per_chunk = expected_frame_bytes // (sample_width * channels)
        sequence = 0
        while True:
            chunk = wav.readframes(frames_per_chunk)
            if not chunk:
                break
            if len(chunk) < expected_frame_bytes:
                # Pad the trailing partial frame with silence so it matches
                # the advertised frame_ms instead of being a short frame.
                chunk = chunk.ljust(expected_frame_bytes, b"\x00")
            frame = parse_audio_frame(
                {
                    # The wire format is base64-encoded PCM bytes (see
                    # ``parse_audio_frame``); encode the raw chunk so
                    # the parser can decode it back to bytes.
                    "data": base64.b64encode(chunk).decode("ascii"),
                    "sample_rate": config.sample_rate,
                    "channels": config.channels,
                    "frame_ms": config.frame_ms,
                    "encoding": "pcm_s16le",
                    "sequence": sequence,
                },
                stats,
            )
            await collector.ingest_frame(frame)
            sequence += 1
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
async def _main_async(args: argparse.Namespace) -> int:
    """Build the providers, pipeline and collector, then stream the input file.

    ``--mock`` replaces every provider (VAD, ASR, LLM and TTS) with the mock
    implementation and takes precedence over ``--provider`` overrides.

    Once the file has been streamed, this waits for any TTS the pipeline
    still reports as active. It then prints the type of every event the
    pipeline and collector emitted.

    Returns:
        The process exit code (0).
    """
    providers = _parse_provider_args(args.provider)
    if args.mock:
        providers = {"vad": "mock", "asr": "mock", "llm": "mock", "tts": "mock"}

    queue: asyncio.Queue = asyncio.Queue()
    sink = QueueEventSink(queue)
    bundle = build_provider_bundle({kind: {"provider": name} for kind, name in providers.items()})
    pipeline = SpeechPipeline(
        providers=bundle,
        sink=sink,
        config=PipelineConfig(tts_chunk_chars=80),
    )
    collector_config = UtteranceCollectorConfig(
        sample_rate=16000,
        channels=1,
        frame_ms=args.frame_ms,
        pre_speech_ms=args.pre_speech_ms,
    )

    async def utterance_callback(pcm: bytes, sample_rate: int, mode: str) -> None:
        await pipeline.handle_audio_turn(pcm, sample_rate, mode=mode)

    async def cancel_callback(reason: str) -> None:
        # Barge-in: speech detected by the collector cancels in-flight TTS.
        await pipeline.cancel_tts(reason)

    collector = AudioUtteranceCollector(
        vad_provider=bundle.vad,
        event_sink=sink,
        utterance_callback=utterance_callback,
        config=collector_config,
        cancel_callback=cancel_callback,
    )

    print(f"providers: {providers}")
    print(f"frame_ms={args.frame_ms} pre_speech_ms={args.pre_speech_ms}")
    print("-" * 60)

    await _drive_collector(collector, pipeline, collector_config, input_path=args.input)

    # TTS for the last utterance may still be running in background tasks.
    # asyncio.run() cancels leftover tasks on exit, so let them finish first.
    pending = list(pipeline.state.active_tts_tasks)
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)

    # Nothing consumed the sink while streaming; report what was emitted.
    while not queue.empty():
        event = queue.get_nowait()
        print(f" {event.get('type', '')}")
    return 0
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse *argv* and run the voice example.

    When *argv* is ``None``, argparse reads ``sys.argv[1:]``. Returns the
    exit code produced by the example.
    """
    namespace = _build_arg_parser().parse_args(argv)
    return asyncio.run(_main_async(namespace))
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
if __name__ == "__main__":  # pragma: no cover - manual example
    # Propagate the example's return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""FastAPI WebSocket voice-chat recipe.
|
|
2
|
+
|
|
3
|
+
This example shows the browser-oriented wire shape: clients send JSON
|
|
4
|
+
messages containing ``audio.frame`` payloads, the framework validates
|
|
5
|
+
them with :func:`parse_audio_frame`, and all framework events are sent
|
|
6
|
+
back over the same WebSocket.
|
|
7
|
+
|
|
8
|
+
FastAPI is imported only by :func:`create_app`, so importing this module
|
|
9
|
+
does not add a dependency to the base framework package. To run it::
|
|
10
|
+
|
|
11
|
+
pip install fastapi uvicorn
|
|
12
|
+
uvicorn converse_framework.examples.websocket_voice_chat:create_app --factory
|
|
13
|
+
|
|
14
|
+
The mock providers are used by default. Pass a provider config to
|
|
15
|
+
:func:`build_websocket_voice_runtime` when embedding this recipe in a
|
|
16
|
+
real app.
|
|
17
|
+
|
|
18
|
+
.. seealso::
|
|
19
|
+
|
|
20
|
+
:class:`converse_framework.session.WebSocketSession` provides a
|
|
21
|
+
reusable message-dispatch loop that replaces the per-endpoint
|
|
22
|
+
routing in this recipe. See the WebSocket Session Helper section
|
|
23
|
+
in the README.
|
|
24
|
+
|
|
25
|
+
.. note::
|
|
26
|
+
|
|
27
|
+
Mobile browser microphone access (``getUserMedia``) requires a
|
|
28
|
+
**secure context** — HTTPS, ``localhost``, or ``127.0.0.1``.
|
|
29
|
+
Over a plain ``http://<lan-ip>`` page, ``getUserMedia`` will be
|
|
30
|
+
rejected on mobile browsers. See the "Mobile Browser Microphone
|
|
31
|
+
Testing" section in the README for tunnel and HTTPS recipes.
|
|
32
|
+
The WebSocket URL for tunneled setups changes from
|
|
33
|
+
``ws://<host>/ws`` to ``wss://<tunnel-host>/ws``.
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
from __future__ import annotations
|
|
37
|
+
|
|
38
|
+
import asyncio
|
|
39
|
+
from dataclasses import dataclass
|
|
40
|
+
from typing import Any
|
|
41
|
+
|
|
42
|
+
from converse_framework.audio_utils import AudioFrameStats, parse_audio_frame
|
|
43
|
+
from converse_framework.events import FrameworkEvent, TransportEventSink
|
|
44
|
+
from converse_framework.pipeline import PipelineConfig, SpeechPipeline
|
|
45
|
+
from converse_framework.registry import build_provider_bundle
|
|
46
|
+
from converse_framework.transport import Transport
|
|
47
|
+
from converse_framework.utterance_collector import (
|
|
48
|
+
AudioUtteranceCollector,
|
|
49
|
+
UtteranceCollectorConfig,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass()
class WebSocketVoiceRuntime:
    """Per-connection bundle of the objects a WebSocket handler needs.

    Built by :func:`build_websocket_voice_runtime`.
    """

    # Handles text/audio turns and conversation clearing.
    pipeline: SpeechPipeline
    # Receives validated audio frames and groups them into utterances.
    collector: AudioUtteranceCollector
    # Expected frame format, passed to ``parse_audio_frame`` for validation.
    frame_stats: AudioFrameStats
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class WebSocketTransport(Transport):
    """Minimal transport adapter for FastAPI-compatible WebSockets.

    Outgoing events are sent with ``websocket.send_json(event.to_json())``.
    Incoming JSON objects are converted back into :class:`FrameworkEvent`.
    """

    def __init__(self, websocket: Any) -> None:
        # Any object with async ``send_json``/``receive_json`` methods works.
        self.websocket = websocket

    async def send_event(self, event: FrameworkEvent) -> None:
        """Send *event* to the client as JSON."""
        await self.websocket.send_json(event.to_json())

    async def receive_event(self) -> FrameworkEvent:
        """Receive one JSON message and convert it into a :class:`FrameworkEvent`.

        Missing or ``null`` fields fall back to defaults: ``type`` to
        ``""``, ``payload`` to ``{}`` and ``ts`` to ``0.0``.

        Raises:
            ValueError: if the message is not a JSON object. Errors from
                converting a malformed ``ts`` (``float()``) or ``payload``
                (``dict()``) propagate unchanged.
        """
        message = await self.websocket.receive_json()
        if not isinstance(message, dict):
            raise ValueError(f"expected a JSON object, got {type(message).__name__}")
        raw_type = message.get("type")
        raw_ts = message.get("ts")
        return FrameworkEvent(
            # ``null`` should not turn into the literal string "None".
            type="" if raw_type is None else str(raw_type),
            payload=dict(message.get("payload", {}) or {}),
            # float(None) would raise TypeError for ``"ts": null``.
            ts=0.0 if raw_ts is None else float(raw_ts),
        )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def build_websocket_voice_runtime(
    transport: Transport,
    *,
    provider_config: dict[str, dict[str, Any]] | None = None,
    collector_config: UtteranceCollectorConfig | None = None,
    pipeline_config: PipelineConfig | None = None,
) -> WebSocketVoiceRuntime:
    """Wire a pipeline and an utterance collector to *transport*.

    Events from both the pipeline and the collector go through a single
    :class:`TransportEventSink` wrapping *transport*.

    Defaults apply when an argument is falsy:

    * ``provider_config`` selects the mock provider for vad/asr/llm/tts;
    * ``collector_config`` uses a default :class:`UtteranceCollectorConfig`;
    * ``pipeline_config`` uses a default :class:`PipelineConfig`.

    The collector hands finished utterances to
    :meth:`SpeechPipeline.handle_audio_turn` and forwards cancel requests to
    :meth:`SpeechPipeline.cancel_tts`.
    """
    providers = provider_config or {
        kind: {"provider": "mock"} for kind in ("vad", "asr", "llm", "tts")
    }
    collector_cfg = collector_config or UtteranceCollectorConfig()
    sink = TransportEventSink(transport)
    bundle = build_provider_bundle(providers)
    pipeline = SpeechPipeline(bundle, sink, pipeline_config or PipelineConfig())

    async def forward_utterance(pcm: bytes, sample_rate: int, mode: str) -> None:
        await pipeline.handle_audio_turn(pcm, sample_rate, mode=mode)

    async def forward_cancel(reason: str) -> None:
        await pipeline.cancel_tts(reason)

    collector = AudioUtteranceCollector(
        vad_provider=bundle.vad,
        event_sink=sink,
        utterance_callback=forward_utterance,
        config=collector_cfg,
        cancel_callback=forward_cancel,
    )
    return WebSocketVoiceRuntime(
        pipeline=pipeline,
        collector=collector,
        frame_stats=AudioFrameStats(
            expected_sample_rate=collector_cfg.sample_rate,
            expected_channels=collector_cfg.channels,
            expected_frame_ms=collector_cfg.frame_ms,
        ),
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def handle_websocket_message(
    runtime: WebSocketVoiceRuntime,
    transport: Transport,
    message: dict[str, Any],
) -> None:
    """Handle one client message from the WebSocket recipe.

    Supported ``type`` values:

    * ``audio.frame``: the payload is validated with
      :func:`parse_audio_frame` and fed to the utterance collector. A
      validation error is reported back as ``audio.frame_error``.
    * ``text.turn``: runs a text turn when ``payload.text`` is non-empty.
    * ``conversation.clear``: clears the history for the mode.

    ``payload.mode`` (default ``"chat"``) selects the conversation mode and
    is removed from the payload before dispatch. A message that is not a
    JSON object, a non-object payload, or an unknown type is answered with
    a ``turn.error`` event instead of raising.
    """
    # Client JSON is untrusted: reject shapes the dispatch below cannot
    # handle instead of letting AttributeError/ValueError escape.
    if not isinstance(message, dict):
        await transport.send_event(
            FrameworkEvent("turn.error", {"message": "message must be a JSON object"})
        )
        return
    message_type = str(message.get("type", ""))
    raw_payload = message.get("payload") or {}
    if not isinstance(raw_payload, dict):
        await transport.send_event(
            FrameworkEvent(
                "turn.error",
                {"message": f"payload for {message_type!r} must be a JSON object"},
            )
        )
        return
    # Copy before popping so the caller's message is left untouched.
    payload = dict(raw_payload)
    mode = str(payload.pop("mode", "chat"))

    if message_type == "audio.frame":
        try:
            frame = parse_audio_frame(payload, runtime.frame_stats)
        except ValueError as exc:
            await transport.send_event(
                FrameworkEvent("audio.frame_error", {"message": str(exc)})
            )
            return
        await runtime.collector.ingest_frame(frame, mode=mode)
        return

    if message_type == "text.turn":
        text = str(payload.get("text", ""))
        if text:
            await runtime.pipeline.handle_text_turn(text, mode=mode)
        return

    if message_type == "conversation.clear":
        await runtime.pipeline.clear_conversation(mode=mode)
        return

    await transport.send_event(
        FrameworkEvent(
            "turn.error", {"message": f"unknown message type: {message_type}"}
        )
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def create_app():
    """Create a tiny FastAPI app exposing ``/ws`` for voice chat.

    Each connection gets its own runtime (pipeline and collector), and
    messages are processed one at a time by :func:`handle_websocket_message`.
    A message that is not valid JSON is answered with a ``turn.error`` event
    and the connection stays open. On client disconnect, in-flight TTS for
    that connection is cancelled on a best-effort basis.

    FastAPI is imported here rather than at module level, so the base
    framework package does not depend on it.
    """
    import contextlib

    from fastapi import FastAPI, WebSocket, WebSocketDisconnect

    app = FastAPI()

    async def websocket_endpoint(websocket) -> None:
        await websocket.accept()
        transport = WebSocketTransport(websocket)
        runtime = build_websocket_voice_runtime(transport)
        try:
            while True:
                try:
                    message = await websocket.receive_json()
                except ValueError:
                    # json.JSONDecodeError is a ValueError; one malformed
                    # message should not end the whole session.
                    await transport.send_event(
                        FrameworkEvent("turn.error", {"message": "invalid JSON message"})
                    )
                    continue
                await handle_websocket_message(runtime, transport, message)
                await asyncio.sleep(0)
        except WebSocketDisconnect:
            # The client is gone: stop TTS that would otherwise keep running
            # and try to send on the closed socket. Emitting the cancellation
            # may itself fail on the closed socket, which is harmless here.
            with contextlib.suppress(Exception):
                await runtime.pipeline.cancel_tts("client disconnected")

    # This module uses ``from __future__ import annotations``, so writing
    # ``websocket: WebSocket`` would store the string "WebSocket". FastAPI
    # resolves string annotations against the endpoint's module globals,
    # where this locally imported class is not visible. Attach the real
    # class before registering the route, so FastAPI injects the WebSocket
    # instead of treating the parameter as a query parameter.
    websocket_endpoint.__annotations__["websocket"] = WebSocket
    app.add_api_websocket_route("/ws", websocket_endpoint)

    return app
|