converse_framework-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. converse_framework/__init__.py +108 -0
  2. converse_framework/audio_utils.py +412 -0
  3. converse_framework/cuda_utils.py +176 -0
  4. converse_framework/events.py +94 -0
  5. converse_framework/examples/__init__.py +20 -0
  6. converse_framework/examples/subprocess_provider.py +439 -0
  7. converse_framework/examples/text_chat.py +308 -0
  8. converse_framework/examples/voice_chat.py +223 -0
  9. converse_framework/examples/websocket_voice_chat.py +174 -0
  10. converse_framework/js/browser-voice-client.js +248 -0
  11. converse_framework/js/mic-frame-sender.js +445 -0
  12. converse_framework/js/speaker-echo-guard.js +308 -0
  13. converse_framework/js/tts-audio-player.js +237 -0
  14. converse_framework/pipeline.py +620 -0
  15. converse_framework/protocols.py +382 -0
  16. converse_framework/provider_events.py +159 -0
  17. converse_framework/providers/__init__.py +28 -0
  18. converse_framework/providers/faster_whisper.py +290 -0
  19. converse_framework/providers/kokoro_onnx.py +391 -0
  20. converse_framework/providers/llamacpp.py +264 -0
  21. converse_framework/providers/mock.py +171 -0
  22. converse_framework/providers/pocket_tts.py +409 -0
  23. converse_framework/providers/silero.py +161 -0
  24. converse_framework/providers/unavailable.py +137 -0
  25. converse_framework/providers/whisper_cpp.py +322 -0
  26. converse_framework/registry.py +397 -0
  27. converse_framework/session.py +315 -0
  28. converse_framework/transport.py +54 -0
  29. converse_framework/utterance_collector.py +336 -0
  30. converse_framework-0.2.0.dist-info/METADATA +992 -0
  31. converse_framework-0.2.0.dist-info/RECORD +33 -0
  32. converse_framework-0.2.0.dist-info/WHEEL +4 -0
  33. converse_framework-0.2.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,315 @@
1
+ """Reusable WebSocket session helper built on the framework transport protocol.
2
+
3
+ Provides the runtime state machine and message routing that the
4
+ ``websocket_voice_chat`` recipe previously owned, packaged as a
5
+ framework-level component so apps do not need to reimplement the
6
+ message-dispatch loop.
7
+
8
+ The session module depends only on framework protocols and dataclasses.
9
+ It does **not** import FastAPI or any HTTP/WebSocket server library.
10
+ Apps serve the actual WebSocket endpoint and pass each inbound message to
11
+ :meth:`WebSocketSession.handle_message`.
12
+
13
+ Example usage in a FastAPI endpoint::
14
+
15
+ from fastapi import FastAPI, WebSocket
16
+ from converse_framework.transport import Transport
17
+ from converse_framework.session import WebSocketSession, WebSocketSessionConfig
18
+
19
+ app = FastAPI()
20
+
21
+ @app.websocket("/ws")
22
+ async def ws(websocket: WebSocket) -> None:
23
+ await websocket.accept()
24
+ transport = _as_transport(websocket)  # app-provided adapter implementing Transport
25
+ session = WebSocketSession(transport)
26
+ async for message in websocket.iter_json():
27
+ await session.handle_message(message)
28
+ """
29
+
30
+ from __future__ import annotations
31
+
32
+ from collections.abc import Awaitable, Callable
33
+ from dataclasses import dataclass
34
+ from typing import Any
35
+
36
+ from converse_framework.audio_utils import AudioFrameStats, parse_audio_frame
37
+ from converse_framework.events import FrameworkEvent, TransportEventSink
38
+ from converse_framework.pipeline import PipelineConfig, SpeechPipeline
39
+ from converse_framework.registry import ProviderBundle, build_provider_bundle
40
+ from converse_framework.transport import Transport
41
+ from converse_framework.utterance_collector import (
42
+ AudioUtteranceCollector,
43
+ UtteranceCollectorConfig,
44
+ )
45
+
46
# Signature of the message-level hooks: ``await hook(session, payload)``.
# ``WebSocketSession`` is spelled as a string because that class is defined
# further down in this module.
HookFn = Callable[
    ["WebSocketSession", dict[str, Any]],  # noqa: F821
    Awaitable[None],
]
48
+
49
+
50
@dataclass
class WebSocketSessionConfig:
    """Settings consumed by :class:`WebSocketSession` when it is constructed.

    Every field is optional: a bare ``WebSocketSessionConfig()`` yields a
    session wired to mock providers with default collector and pipeline
    settings.
    """

    # Per-role provider settings forwarded to ``build_provider_bundle``.
    # ``None`` (or an empty dict) makes the session fall back to mock
    # providers for every role.
    provider_config: dict[str, dict[str, Any]] | None = None
    # VAD frame-collector settings; ``None`` means the collector defaults.
    collector_config: UtteranceCollectorConfig | None = None
    # Speech-pipeline settings; ``None`` means the pipeline defaults.
    pipeline_config: PipelineConfig | None = None
    # Conversation mode used when an inbound message does not name one
    # (for example ``"chat"`` or ``"custom"``).
    default_mode: str = "chat"
    # Status query run for a ``status.request`` without an explicit kind:
    # ``True`` selects ``probe_statuses()``, ``False`` the heavier
    # ``check_statuses()``.
    auto_probe_status: bool = True
71
+
72
+
73
@dataclass
class WebSocketSessionHooks:
    """Optional async callbacks that customise a :class:`WebSocketSession`.

    Every hook is awaited by the session. A hook left as ``None`` means the
    session uses its built-in behaviour: reply ``turn.error`` to unknown
    message types, ignore ``settings.update``, and so on.

    Attributes:
        on_unknown_message: ``(session, message)``. Awaited when no
            built-in handler matches the message type, with the full
            inbound message dict. When set, the default ``turn.error``
            reply is *not* sent.
        on_settings_update: ``(session, payload)``. Awaited for
            ``settings.update`` messages. The session has already removed
            the ``mode`` key from ``payload``.
        on_status_request: ``(session, {"kind": ..., "statuses": ...})``.
            Awaited after provider statuses are collected and before the
            ``providers.status`` event is sent.
        on_before_provider_reload: ``(session, old_bundle, new_config)``.
            Awaited before the new provider bundle is built.
        on_after_provider_reload: ``(session, old_bundle, new_bundle)``.
            Awaited once the new bundle has been installed (and loaded,
            when a load was requested).
        on_event: ``(session, event)``. Awaited for each
            :class:`FrameworkEvent` the session sends *itself* (e.g.
            ``turn.error``, ``audio.frame_error``, ``providers.status``),
            just before it is handed to the transport. Its return value is
            ignored, so it can observe (e.g. log) these events but cannot
            drop or replace them. Events emitted by the speech pipeline or
            the utterance collector go through the session's event sink and
            do not pass through this hook.
    """

    on_unknown_message: HookFn | None = None
    on_settings_update: HookFn | None = None
    on_status_request: HookFn | None = None
    # Called as (session, old_bundle, new_config).
    on_before_provider_reload: Callable[..., Awaitable[None]] | None = None
    # Called as (session, old_bundle, new_bundle).
    on_after_provider_reload: Callable[..., Awaitable[None]] | None = None
    # Called as (session, event); observation only.
    on_event: Callable[..., Awaitable[None]] | None = None
103
+
104
+
105
class WebSocketSession:
    """Reusable WebSocket voice-chat session.

    Owns the full runtime (provider bundle, speech pipeline, utterance
    collector, transport and event sink) and exposes
    :meth:`handle_message`, which the app awaits once per inbound
    WebSocket message.

    Example::

        session = WebSocketSession(transport)

        async for message in websocket.iter_json():
            await session.handle_message(message)
    """

    def __init__(
        self,
        transport: Transport,
        *,
        config: WebSocketSessionConfig | None = None,
        hooks: WebSocketSessionHooks | None = None,
    ) -> None:
        """Build the provider bundle, pipeline and collector for one client.

        Args:
            transport: Outbound channel. The session sends its own events
                through it directly; the pipeline and collector receive it
                wrapped in a :class:`TransportEventSink`.
            config: Session configuration. ``None`` uses the
                :class:`WebSocketSessionConfig` defaults.
            hooks: Optional app callbacks. ``None`` disables all hooks.
        """
        self.config = config or WebSocketSessionConfig()
        self.hooks = hooks or WebSocketSessionHooks()

        self.transport = transport
        self.sink = TransportEventSink(transport)

        # An unset (or empty) provider config falls back to mock providers
        # for every role, so the session works without any setup.
        provider_config = self.config.provider_config or {
            "vad": {"provider": "mock"},
            "asr": {"provider": "mock"},
            "llm": {"provider": "mock"},
            "tts": {"provider": "mock"},
        }
        collector_config = self.config.collector_config or UtteranceCollectorConfig()
        self.bundle: ProviderBundle = build_provider_bundle(provider_config)
        self.pipeline = SpeechPipeline(
            self.bundle,
            self.sink,
            self.config.pipeline_config or PipelineConfig(),
        )

        async def utterance_callback(pcm: bytes, sample_rate: int, mode: str) -> None:
            # A completed utterance from the collector becomes an audio turn.
            await self.pipeline.handle_audio_turn(pcm, sample_rate, mode=mode)

        async def cancel_callback(reason: str) -> None:
            # The collector asks to interrupt playback (e.g. on barge-in).
            await self.pipeline.cancel_tts(reason)

        self.collector = AudioUtteranceCollector(
            vad_provider=self.bundle.vad,
            event_sink=self.sink,
            utterance_callback=utterance_callback,
            config=collector_config,
            cancel_callback=cancel_callback,
        )
        self.frame_stats = AudioFrameStats(
            expected_sample_rate=collector_config.sample_rate,
            expected_channels=collector_config.channels,
            expected_frame_ms=collector_config.frame_ms,
        )

    # ------------------------------------------------------------------
    # Message routing
    # ------------------------------------------------------------------

    async def handle_message(self, message: dict[str, Any]) -> None:
        """Route one inbound message to the appropriate handler.

        Built-in message types:

        * ``audio.frame`` — audio data from the client mic.
        * ``text.turn`` — text input (non-audio path); empty text is ignored.
        * ``conversation.clear`` — reset conversation history.
        * ``tts.cancel`` — interrupt active TTS playback.
        * ``status.request`` — request current provider statuses.
        * ``settings.update`` — routed to ``on_settings_update`` hook.
        * ``providers.reload`` — reload providers from ``payload.config``.

        ``payload.mode``, when present and not ``None``, selects the
        conversation mode. Otherwise ``config.default_mode`` applies. Either
        way the key is removed from the payload before any handler or hook
        sees it.

        Malformed input is answered with a ``turn.error`` event instead of
        raising into the caller's receive loop. This covers a message or
        ``payload`` that is not a JSON object, and a ``providers.reload``
        whose ``config`` is not an object.
        """
        if not isinstance(message, dict):
            await self._send_event(
                "turn.error", message="invalid message: expected a JSON object"
            )
            return

        message_type = str(message.get("type", ""))
        raw_payload = message.get("payload") or {}
        if not isinstance(raw_payload, dict):
            await self._send_event(
                "turn.error",
                message=f"invalid payload for {message_type}: expected a JSON object",
            )
            return
        # Copy so popping ``mode`` never mutates the caller's message.
        payload: dict[str, Any] = dict(raw_payload)
        raw_mode = payload.pop("mode", None)
        mode = self.config.default_mode if raw_mode is None else str(raw_mode)

        if message_type == "audio.frame":
            await self._handle_audio_frame(payload, mode=mode)
            return

        if message_type == "text.turn":
            text = str(payload.get("text", ""))
            if text:
                await self.pipeline.handle_text_turn(text, mode=mode)
            return

        if message_type == "conversation.clear":
            await self.pipeline.clear_conversation(mode=mode)
            return

        if message_type == "tts.cancel":
            reason = str(payload.get("reason", "client_cancelled"))
            await self.pipeline.cancel_tts(reason)
            return

        if message_type == "status.request":
            await self._handle_status_request(payload)
            return

        if message_type == "settings.update":
            if self.hooks.on_settings_update:
                await self.hooks.on_settings_update(self, payload)
            return

        if message_type == "providers.reload":
            # NOTE(review): the provider config comes straight from the
            # client and is passed unchanged to ``build_provider_bundle``.
            # Apps serving untrusted clients may want to reject or filter
            # this message before it reaches the session.
            raw_config = payload.get("config") or {}
            if not isinstance(raw_config, dict):
                await self._send_event(
                    "turn.error",
                    message="invalid providers.reload: payload.config must be a JSON object",
                )
                return
            load = bool(payload.get("load", False))
            await self._reload_providers(dict(raw_config), load=load)
            return

        # Unknown message type: defer to the app hook, else report an error.
        if self.hooks.on_unknown_message:
            await self.hooks.on_unknown_message(self, message)
        else:
            await self._send_event(
                "turn.error", message=f"unknown message type: {message_type}"
            )

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    async def emit_status(self, kind: str = "probe") -> None:
        """Query provider statuses and emit a ``providers.status`` event.

        Args:
            kind: ``"probe"``, ``"check"`` or ``"load"``. Unrecognised
                kinds fall back to a probe.
        """
        await self._handle_status_request({"kind": kind})

    async def _handle_status_request(self, payload: dict[str, Any]) -> None:
        """Collect provider statuses, notify the hook, then send them."""
        default_kind = "probe" if self.config.auto_probe_status else "check"
        raw_kind = payload.get("kind")
        kind = default_kind if raw_kind is None else str(raw_kind)

        if kind == "check":
            statuses = await self.bundle.check_statuses()
        elif kind == "load":
            statuses = await self.bundle.load_statuses()
        else:
            # "probe" and any unrecognised kind both use the probe query.
            statuses = await self.bundle.probe_statuses()

        if self.hooks.on_status_request:
            await self.hooks.on_status_request(
                self, {"kind": kind, "statuses": statuses}
            )

        await self._send_event("providers.status", statuses=statuses)

    # ------------------------------------------------------------------
    # Provider reload
    # ------------------------------------------------------------------

    async def reload_providers(
        self, config: dict[str, dict[str, Any]], *, load: bool = False
    ) -> None:
        """Rebuild provider bundle and swap into pipeline/collector.

        Args:
            config: Provider configuration dict for
                :func:`build_provider_bundle`.
            load: If True, run ``load_statuses()`` on the new bundle
                after swapping (loads heavy models).
        """
        await self._reload_providers(config, load=load)

    async def _reload_providers(
        self, config: dict[str, dict[str, Any]], *, load: bool = False
    ) -> None:
        """Build a new bundle from *config* and install it.

        The pipeline is updated first. ``self.bundle`` and the collector's
        VAD are only replaced after that succeeds, so an exception from
        building or installing the bundle leaves the old providers in place.
        """
        old_bundle = self.bundle

        if self.hooks.on_before_provider_reload:
            await self.hooks.on_before_provider_reload(self, old_bundle, config)

        new_bundle = build_provider_bundle(config)
        await self.pipeline.update_providers(new_bundle, reason="provider_reload")
        self.bundle = new_bundle

        # Swap the VAD provider in the collector
        self.collector.update_vad_provider(new_bundle.vad)

        if load:
            await new_bundle.load_statuses()

        if self.hooks.on_after_provider_reload:
            await self.hooks.on_after_provider_reload(self, old_bundle, new_bundle)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _handle_audio_frame(self, payload: dict[str, Any], mode: str) -> None:
        """Parse one ``audio.frame`` payload and feed it to the collector.

        A frame that :func:`parse_audio_frame` rejects with ``ValueError``
        is reported as ``audio.frame_error`` and dropped.
        """
        try:
            frame = parse_audio_frame(payload, self.frame_stats)
        except ValueError as exc:
            await self._send_event("audio.frame_error", message=str(exc))
            return
        await self.collector.ingest_frame(frame, mode=mode)

    async def _send_event(self, event_type: str, **payload: Any) -> None:
        """Send a session-originated event, letting ``on_event`` observe it first.

        The hook's return value is ignored; the event is always sent.
        """
        event = FrameworkEvent(event_type, payload)
        if self.hooks.on_event:
            await self.hooks.on_event(self, event)
        await self.transport.send_event(event)
@@ -0,0 +1,54 @@
1
+ """Transport protocol and in-memory queue transport for testing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import Protocol, runtime_checkable
7
+
8
+ from converse_framework.events import FrameworkEvent
9
+
10
+
11
@runtime_checkable
class Transport(Protocol):
    """Boundary between the framework and the host application's wire.

    The framework only produces and consumes :class:`FrameworkEvent`
    objects and never reaches past this boundary. A transport decides how
    events are serialised and delivered: WebSocket, in-process queue, log
    file, and so on. Both methods are coroutines, so implementations must
    be usable from :mod:`asyncio` code. Tests use :class:`QueueTransport`
    to capture events without a real I/O stack.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    Transport)`` succeeds for any object that has ``send_event`` and
    ``receive_event`` attributes; signatures are not checked.
    """

    async def send_event(self, event: FrameworkEvent) -> None:
        """Deliver *event* to the other side of the transport."""

    async def receive_event(self) -> FrameworkEvent:
        """Return the next event arriving from the other side."""
30
+
31
+
32
class QueueTransport(Transport):
    """In-memory dual-queue transport for testing consumers without I/O.

    Keeps two independent ``asyncio.Queue`` instances:

    * the *outbound* queue collects every event passed to
      :meth:`send_event` (the stream a fake client would read). Tests read
      it back with :meth:`next_outbound` or :meth:`drain_outbound`.
    * the *inbound* queue feeds :meth:`receive_event`. Tests fill it with
      :meth:`push_inbound` (the events a fake client would send).

    Both queues are unbounded, so putting an event never blocks or drops
    it. Every emitted event stays observable, and the test controls when
    the consumer drains.
    """

    def __init__(self) -> None:
        self._send_queue: asyncio.Queue[FrameworkEvent] = asyncio.Queue()
        self._recv_queue: asyncio.Queue[FrameworkEvent] = asyncio.Queue()

    async def send_event(self, event: FrameworkEvent) -> None:
        """Record *event* on the outbound queue."""
        await self._send_queue.put(event)

    async def receive_event(self) -> FrameworkEvent:
        """Wait for and return the next event queued via :meth:`push_inbound`."""
        return await self._recv_queue.get()

    async def push_inbound(self, event: FrameworkEvent) -> None:
        """Queue *event* for a later :meth:`receive_event` (fake-client send)."""
        await self._recv_queue.put(event)

    async def next_outbound(self) -> FrameworkEvent:
        """Wait for and return the oldest event passed to :meth:`send_event`."""
        return await self._send_queue.get()

    def drain_outbound(self) -> list[FrameworkEvent]:
        """Remove and return every outbound event queued so far, oldest first.

        Returns an empty list when nothing is queued; never waits.
        """
        drained: list[FrameworkEvent] = []
        while not self._send_queue.empty():
            drained.append(self._send_queue.get_nowait())
        return drained