converse-framework 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- converse_framework/__init__.py +108 -0
- converse_framework/audio_utils.py +412 -0
- converse_framework/cuda_utils.py +176 -0
- converse_framework/events.py +94 -0
- converse_framework/examples/__init__.py +20 -0
- converse_framework/examples/subprocess_provider.py +439 -0
- converse_framework/examples/text_chat.py +308 -0
- converse_framework/examples/voice_chat.py +223 -0
- converse_framework/examples/websocket_voice_chat.py +174 -0
- converse_framework/js/browser-voice-client.js +248 -0
- converse_framework/js/mic-frame-sender.js +445 -0
- converse_framework/js/speaker-echo-guard.js +308 -0
- converse_framework/js/tts-audio-player.js +237 -0
- converse_framework/pipeline.py +620 -0
- converse_framework/protocols.py +382 -0
- converse_framework/provider_events.py +159 -0
- converse_framework/providers/__init__.py +28 -0
- converse_framework/providers/faster_whisper.py +290 -0
- converse_framework/providers/kokoro_onnx.py +391 -0
- converse_framework/providers/llamacpp.py +264 -0
- converse_framework/providers/mock.py +171 -0
- converse_framework/providers/pocket_tts.py +409 -0
- converse_framework/providers/silero.py +161 -0
- converse_framework/providers/unavailable.py +137 -0
- converse_framework/providers/whisper_cpp.py +322 -0
- converse_framework/registry.py +397 -0
- converse_framework/session.py +315 -0
- converse_framework/transport.py +54 -0
- converse_framework/utterance_collector.py +336 -0
- converse_framework-0.2.0.dist-info/METADATA +992 -0
- converse_framework-0.2.0.dist-info/RECORD +33 -0
- converse_framework-0.2.0.dist-info/WHEEL +4 -0
- converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,315 @@
|
|
|
1
|
+
"""Reusable WebSocket session helper built on the framework transport protocol.
|
|
2
|
+
|
|
3
|
+
Provides the runtime state machine and message routing that the
|
|
4
|
+
``websocket_voice_chat`` recipe previously owned, packaged as a
|
|
5
|
+
framework-level component so apps do not need to reimplement the
|
|
6
|
+
message-dispatch loop.
|
|
7
|
+
|
|
8
|
+
The session module depends only on framework protocols and dataclasses.
|
|
9
|
+
It does **not** import FastAPI or any HTTP/Wire server library.
|
|
10
|
+
Apps serve the actual WebSocket endpoint and pass events to
|
|
11
|
+
:meth:`WebSocketSession.handle_message`.
|
|
12
|
+
|
|
13
|
+
Example usage in a FastAPI endpoint::
|
|
14
|
+
|
|
15
|
+
from fastapi import FastAPI, WebSocket
|
|
16
|
+
from converse_framework.transport import Transport
|
|
17
|
+
from converse_framework.session import WebSocketSession, WebSocketSessionConfig
|
|
18
|
+
|
|
19
|
+
app = FastAPI()
|
|
20
|
+
|
|
21
|
+
@app.websocket("/ws")
|
|
22
|
+
async def ws(websocket: WebSocket) -> None:
|
|
23
|
+
await websocket.accept()
|
|
24
|
+
transport = _as_transport(websocket)
|
|
25
|
+
session = WebSocketSession(transport)
|
|
26
|
+
async for message in websocket.iter_json():
|
|
27
|
+
await session.handle_message(message)
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
from collections.abc import Awaitable, Callable
|
|
33
|
+
from dataclasses import dataclass
|
|
34
|
+
from typing import Any
|
|
35
|
+
|
|
36
|
+
from converse_framework.audio_utils import AudioFrameStats, parse_audio_frame
|
|
37
|
+
from converse_framework.events import FrameworkEvent, TransportEventSink
|
|
38
|
+
from converse_framework.pipeline import PipelineConfig, SpeechPipeline
|
|
39
|
+
from converse_framework.registry import ProviderBundle, build_provider_bundle
|
|
40
|
+
from converse_framework.transport import Transport
|
|
41
|
+
from converse_framework.utterance_collector import (
|
|
42
|
+
AudioUtteranceCollector,
|
|
43
|
+
UtteranceCollectorConfig,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Signature shared by the ``(session, payload)`` hooks declared on
# WebSocketSessionHooks (on_unknown_message, on_settings_update,
# on_status_request): an async callable taking the session and a dict.
# "WebSocketSession" is quoted because this alias is evaluated at import
# time, before the class is defined further down in this module.
HookFn = Callable[["WebSocketSession", dict[str, Any]], Awaitable[None]]  # noqa: F821
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class WebSocketSessionConfig:
    """Construction-time settings for a :class:`WebSocketSession`.

    Every field has a default, so ``WebSocketSessionConfig()`` is a valid
    configuration. Fields left as ``None`` are replaced by the session
    with its own defaults when it is built.

    Attributes:
        provider_config: Mapping handed to :func:`build_provider_bundle`.
            When ``None`` (or empty) the session substitutes a config
            that selects the ``"mock"`` provider for ``vad``, ``asr``,
            ``llm`` and ``tts``.
        collector_config: Settings for the utterance collector. When
            ``None`` the session uses ``UtteranceCollectorConfig()``.
        pipeline_config: Settings for the speech pipeline. When ``None``
            the session uses ``PipelineConfig()``.
        default_mode: Conversation mode applied to inbound messages
            that do not specify one (``"chat"``, ``"custom"``, ...).
        auto_probe_status: Picks the status call used for a
            ``status.request`` that names no ``kind``: ``True`` selects
            ``probe_statuses()``, ``False`` selects the heavier
            ``check_statuses()``.
    """

    # Provider selection; None -> all-mock bundle.
    provider_config: dict[str, dict[str, Any]] | None = None
    # Collector / pipeline tuning; None -> default-constructed configs.
    collector_config: UtteranceCollectorConfig | None = None
    pipeline_config: PipelineConfig | None = None
    # Per-message fallbacks.
    default_mode: str = "chat"
    auto_probe_status: bool = True
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class WebSocketSessionHooks:
    """Optional async callbacks a :class:`WebSocketSession` invokes.

    All hooks default to ``None``. A missing hook means the session
    applies its built-in behaviour: an unknown message produces a
    ``turn.error`` event, a ``settings.update`` message is ignored, and
    the remaining hooks are simply skipped.

    The session is always the first positional argument; the remaining
    arguments differ per hook as listed below.

    Attributes:
        on_unknown_message: ``(session, message)``. Invoked when the
            message type matches no built-in handler; receives the
            complete inbound message dict rather than only its payload.
        on_settings_update: ``(session, payload)``. Invoked for
            ``settings.update`` messages with the message payload
            (the ``mode`` key has already been removed from it).
        on_status_request: ``(session, info)``. Invoked once provider
            statuses have been gathered and before the
            ``providers.status`` event is sent. ``info`` carries the
            keys ``kind`` and ``statuses``.
        on_before_provider_reload: ``(session, old_bundle, config)``.
            Invoked before the new provider bundle is built.
        on_after_provider_reload: ``(session, old_bundle, new_bundle)``.
            Invoked after the new bundle has been swapped in.
        on_event: ``(session, event)``. Invoked for each
            :class:`FrameworkEvent` the session sends itself, just
            before it is handed to the transport. Useful for logging.
            NOTE(review): events that the pipeline or collector emit
            through the session's event sink are not visibly routed
            through this hook -- confirm before relying on it for
            complete coverage.
    """

    # Hooks that share the (session, dict) HookFn signature.
    on_unknown_message: HookFn | None = None
    on_settings_update: HookFn | None = None
    on_status_request: HookFn | None = None
    # Hooks with their own positional signatures (see docstring).
    on_before_provider_reload: Callable[..., Awaitable[None]] | None = None
    on_after_provider_reload: Callable[..., Awaitable[None]] | None = None
    on_event: Callable[..., Awaitable[None]] | None = None
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class WebSocketSession:
    """Reusable WebSocket voice-chat session.

    Builds and owns the runtime objects for one connection (provider
    bundle, speech pipeline, utterance collector, event sink and audio
    frame statistics) and routes each inbound message through
    :meth:`handle_message`.

    A JSON ``null`` for the ``mode``, ``text``, ``reason``, ``kind`` or
    ``config`` payload fields is treated exactly like a missing key, so
    it falls back to the documented default instead of being turned into
    the literal string ``"None"`` (or crashing, for ``config``).

    Example::

        session = WebSocketSession(transport)

        async for message in websocket.iter_json():
            await session.handle_message(message)
    """

    def __init__(
        self,
        transport: Transport,
        *,
        config: WebSocketSessionConfig | None = None,
        hooks: WebSocketSessionHooks | None = None,
    ) -> None:
        """Build the session runtime on top of *transport*.

        Args:
            transport: Transport used for the session's own events and
                wrapped in the event sink given to pipeline/collector.
            config: Session configuration; defaults to
                ``WebSocketSessionConfig()``.
            hooks: Optional callbacks; defaults to
                ``WebSocketSessionHooks()`` (no hooks set).
        """
        self.config = config or WebSocketSessionConfig()
        self.hooks = hooks or WebSocketSessionHooks()

        self.transport = transport
        self.sink = TransportEventSink(transport)

        # Without an explicit provider config every stage uses the mock
        # provider, so a session can be constructed with no models.
        provider_config = self.config.provider_config or {
            "vad": {"provider": "mock"},
            "asr": {"provider": "mock"},
            "llm": {"provider": "mock"},
            "tts": {"provider": "mock"},
        }
        collector_config = self.config.collector_config or UtteranceCollectorConfig()
        self.bundle: ProviderBundle = build_provider_bundle(provider_config)
        self.pipeline = SpeechPipeline(
            self.bundle,
            self.sink,
            self.config.pipeline_config or PipelineConfig(),
        )

        # The callbacks read ``self.pipeline`` at call time rather than
        # capturing the pipeline object created above.
        async def utterance_callback(pcm: bytes, sample_rate: int, mode: str) -> None:
            await self.pipeline.handle_audio_turn(pcm, sample_rate, mode=mode)

        async def cancel_callback(reason: str) -> None:
            await self.pipeline.cancel_tts(reason)

        self.collector = AudioUtteranceCollector(
            vad_provider=self.bundle.vad,
            event_sink=self.sink,
            utterance_callback=utterance_callback,
            config=collector_config,
            cancel_callback=cancel_callback,
        )
        # Inbound frames are parsed against the collector's audio format.
        self.frame_stats = AudioFrameStats(
            expected_sample_rate=collector_config.sample_rate,
            expected_channels=collector_config.channels,
            expected_frame_ms=collector_config.frame_ms,
        )

    @staticmethod
    def _str_or_default(value: Any, default: Any) -> str:
        """Return ``str(value)``, or ``str(default)`` when *value* is ``None``."""
        return str(default if value is None else value)

    # ------------------------------------------------------------------
    # Message routing
    # ------------------------------------------------------------------

    async def handle_message(self, message: dict[str, Any]) -> None:
        """Route one inbound message to the appropriate handler.

        The ``mode`` key is removed from a copy of the payload before
        dispatch and falls back to ``config.default_mode`` when missing
        or ``None``.

        Built-in message types:

        * ``audio.frame`` -- audio data from the client mic.
        * ``text.turn`` -- text input (non-audio path); empty or missing
          text is ignored.
        * ``conversation.clear`` -- reset conversation history.
        * ``tts.cancel`` -- interrupt active TTS playback.
        * ``status.request`` -- request current provider statuses.
        * ``settings.update`` -- routed to ``on_settings_update`` hook.
        * ``providers.reload`` -- reload providers from ``payload.config``.

        Any other type goes to the ``on_unknown_message`` hook (which
        receives the full message) or, without a hook, produces a
        ``turn.error`` event.

        Args:
            message: Decoded message with ``type`` and optional
                ``payload`` keys.
        """
        message_type = str(message.get("type", ""))
        # Copy so popping ``mode`` never mutates the caller's dict.
        payload: dict[str, Any] = dict(message.get("payload", {}) or {})
        mode = self._str_or_default(
            payload.pop("mode", None), self.config.default_mode
        )

        if message_type == "audio.frame":
            await self._handle_audio_frame(payload, mode=mode)
            return

        if message_type == "text.turn":
            text = self._str_or_default(payload.get("text"), "")
            if text:
                await self.pipeline.handle_text_turn(text, mode=mode)
            return

        if message_type == "conversation.clear":
            await self.pipeline.clear_conversation(mode=mode)
            return

        if message_type == "tts.cancel":
            reason = self._str_or_default(payload.get("reason"), "client_cancelled")
            await self.pipeline.cancel_tts(reason)
            return

        if message_type == "status.request":
            await self._handle_status_request(payload)
            return

        if message_type == "settings.update":
            if self.hooks.on_settings_update:
                await self.hooks.on_settings_update(self, payload)
            return

        if message_type == "providers.reload":
            # ``or {}`` guards against an explicit ``"config": null``,
            # which would otherwise raise TypeError in ``dict(None)``.
            new_config = dict(payload.get("config") or {})
            load = bool(payload.get("load", False))
            await self._reload_providers(new_config, load=load)
            return

        # Unknown message type -- try hook, then emit error.
        if self.hooks.on_unknown_message:
            await self.hooks.on_unknown_message(self, message)
        else:
            await self._send_event(
                "turn.error", message=f"unknown message type: {message_type}"
            )

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    async def emit_status(self, kind: str = "probe") -> None:
        """Gather provider statuses of the given *kind* and emit them.

        Args:
            kind: ``"probe"``, ``"check"`` or ``"load"``; any other
                value is handled like ``"probe"``.
        """
        await self._handle_status_request({"kind": kind})

    async def _handle_status_request(self, payload: dict[str, Any]) -> None:
        """Collect statuses from the bundle and send ``providers.status``.

        ``payload["kind"]`` selects the bundle call; when it is missing
        or ``None`` the default follows ``config.auto_probe_status``.
        The ``on_status_request`` hook, if set, runs after the statuses
        are gathered and before the event is sent.
        """
        default_kind = "probe" if self.config.auto_probe_status else "check"
        kind = self._str_or_default(payload.get("kind"), default_kind)

        if kind == "check":
            statuses = await self.bundle.check_statuses()
        elif kind == "load":
            statuses = await self.bundle.load_statuses()
        else:
            # "probe" and every unrecognised kind use the probe call.
            statuses = await self.bundle.probe_statuses()

        if self.hooks.on_status_request:
            await self.hooks.on_status_request(
                self, {"kind": kind, "statuses": statuses}
            )

        await self._send_event("providers.status", statuses=statuses)

    # ------------------------------------------------------------------
    # Provider reload
    # ------------------------------------------------------------------

    async def reload_providers(
        self, config: dict[str, dict[str, Any]], *, load: bool = False
    ) -> None:
        """Rebuild provider bundle and swap into pipeline/collector.

        Args:
            config: Provider configuration dict for
                :func:`build_provider_bundle`.
            load: If True, run ``load_statuses()`` on the new bundle
                after swapping (loads heavy models).
        """
        await self._reload_providers(config, load=load)

    async def _reload_providers(
        self, config: dict[str, dict[str, Any]], *, load: bool = False
    ) -> None:
        """Build a bundle from *config* and install it in the session.

        Order of operations: ``on_before_provider_reload`` hook, build
        the new bundle, hand it to the pipeline, replace ``self.bundle``,
        swap the collector's VAD provider, optionally run
        ``load_statuses()``, then ``on_after_provider_reload`` hook.
        """
        old_bundle = self.bundle

        if self.hooks.on_before_provider_reload:
            await self.hooks.on_before_provider_reload(self, old_bundle, config)

        new_bundle = build_provider_bundle(config)
        await self.pipeline.update_providers(new_bundle, reason="provider_reload")
        self.bundle = new_bundle

        # Swap the VAD provider in the collector
        self.collector.update_vad_provider(new_bundle.vad)

        if load:
            await new_bundle.load_statuses()

        if self.hooks.on_after_provider_reload:
            await self.hooks.on_after_provider_reload(self, old_bundle, new_bundle)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _handle_audio_frame(self, payload: dict[str, Any], mode: str) -> None:
        """Parse one ``audio.frame`` payload and feed it to the collector.

        A ``ValueError`` from parsing is reported to the client as an
        ``audio.frame_error`` event and the frame is dropped.
        """
        try:
            frame = parse_audio_frame(payload, self.frame_stats)
        except ValueError as exc:
            await self._send_event("audio.frame_error", message=str(exc))
            return
        await self.collector.ingest_frame(frame, mode=mode)

    async def _send_event(self, event_type: str, **payload: Any) -> None:
        """Emit a framework event, routing through ``on_event`` hook if set.

        The hook runs first; the event is then sent on the transport.
        """
        event = FrameworkEvent(event_type, payload)
        if self.hooks.on_event:
            await self.hooks.on_event(self, event)
        await self.transport.send_event(event)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Transport protocol and in-memory queue transport for testing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import Protocol, runtime_checkable
|
|
7
|
+
|
|
8
|
+
from converse_framework.events import FrameworkEvent
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@runtime_checkable
class Transport(Protocol):
    """Structural interface for moving framework events across a wire.

    A :class:`Transport` marks the boundary between the framework and
    the host application's delivery mechanism (WebSocket, in-process
    queue, log file, ...). The framework only produces
    :class:`FrameworkEvent` instances; serialising and delivering them
    is entirely the transport's concern.

    Both methods are coroutines, so implementations must be usable from
    :mod:`asyncio` code. Because the protocol is ``runtime_checkable``,
    ``isinstance(obj, Transport)`` checks that both methods are present.
    Tests can use :class:`QueueTransport` to capture events without a
    real I/O stack.
    """

    async def send_event(self, event: FrameworkEvent) -> None:
        """Hand *event* to the transport for delivery."""

    async def receive_event(self) -> FrameworkEvent:
        """Return the next event received over the transport."""
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class QueueTransport(Transport):
    """Queue-backed :class:`Transport` that performs no real I/O.

    Two separate ``asyncio.Queue`` objects are kept:

    * ``_send_queue`` receives every event passed to
      :meth:`send_event` -- the "outbound" stream a fake client would
      read from.
    * ``_recv_queue`` is what :meth:`receive_event` reads from -- the
      stream a fake client pushes into to reach the framework.

    Both queues are created without a size limit, so sending never
    blocks and every emitted event stays observable until the test
    chooses to drain it.
    """

    def __init__(self) -> None:
        """Create the empty outbound and inbound queues."""
        # Outbound: filled by send_event().
        self._send_queue: asyncio.Queue[FrameworkEvent] = asyncio.Queue()
        # Inbound: consumed by receive_event().
        self._recv_queue: asyncio.Queue[FrameworkEvent] = asyncio.Queue()

    async def send_event(self, event: FrameworkEvent) -> None:
        """Append *event* to the outbound queue."""
        outbound = self._send_queue
        await outbound.put(event)

    async def receive_event(self) -> FrameworkEvent:
        """Wait for the next item on the inbound queue and return it."""
        next_event = await self._recv_queue.get()
        return next_event
|