roomkit 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,467 @@
+ """FastRTC VoiceBackend implementation for RoomKit.
+
+ This module provides a VoiceBackend that uses FastRTC for WebSocket audio transport
+ with built-in VAD (Voice Activity Detection).
+
+ The backend:
+ - Handles WebSocket connections from clients
+ - Uses FastRTC's ReplyOnPause for VAD
+ - Calls VoiceBackend callbacks (on_speech_start, on_speech_end)
+ - Sends TTS audio back to clients via FastRTC
+
+ Requires the ``fastrtc`` optional dependency::
+
+     pip install roomkit[fastrtc]
+
+ Usage::
+
+     from roomkit.voice.backends.fastrtc import FastRTCVoiceBackend, mount_fastrtc_voice
+
+     backend = FastRTCVoiceBackend()
+     voice_channel = VoiceChannel("voice", stt=stt, tts=tts, backend=backend)
+     kit.register_channel(voice_channel)
+
+     # Mount FastRTC endpoints on FastAPI app (in lifespan)
+     mount_fastrtc_voice(app, backend, path="/fastrtc")
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import base64
+ import logging
+ import struct
+ import uuid
+ from collections.abc import AsyncIterator
+ from typing import TYPE_CHECKING, Any
+
+ from roomkit.voice.backends.base import VoiceBackend
+ from roomkit.voice.base import (
+     AudioChunk,
+     SpeechEndCallback,
+     SpeechStartCallback,
+     VoiceCapability,
+     VoiceSession,
+     VoiceSessionState,
+ )
+
+ if TYPE_CHECKING:
+     import numpy as np
+     from fastapi import FastAPI
+
+ logger = logging.getLogger("roomkit.voice.fastrtc")
+
+ # ---------------------------------------------------------------------------
+ # Pure-Python mu-law encoder (replaces audioop.lin2ulaw, removed in Python 3.13)
+ # ---------------------------------------------------------------------------
+
+ # ITU-T G.711 mu-law compression bias and clip level
+ _MULAW_BIAS = 0x84
+ _MULAW_CLIP = 32635
+
+ # Precomputed lookup table: PCM-16 sample (unsigned magnitude) → mu-law byte.
+ # Built lazily on first use for O(1) encoding per sample.
+ _MULAW_TABLE: bytes | None = None
+
+
+
+ def _build_mulaw_table() -> bytes:
+     """Build a 16384-entry lookup table mapping 14-bit magnitude to mu-law 7-bit value.
+
+     Index ``i`` stands for the 15-bit sample magnitude shifted right once,
+     so each entry is computed from the reconstructed magnitude ``i << 1``.
+     Returns the lower 7 bits (exponent + mantissa) with bits inverted.
+     The caller must OR in the sign bit (0x80) separately.
+     """
+     table = bytearray(16384)
+     for i in range(16384):
+         # Reconstruct the magnitude this index stands for, then clip and
+         # bias it per G.711 before locating the segment and mantissa.
+         sample = min(i << 1, _MULAW_CLIP) + _MULAW_BIAS
+         exponent = 7
+         mask = 0x4000
+         while exponent > 0 and not (sample & mask):
+             exponent -= 1
+             mask >>= 1
+         mantissa = (sample >> (exponent + 3)) & 0x0F
+         table[i] = ~((exponent << 4) | mantissa) & 0x7F
+     return bytes(table)
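As a worked check of one table entry (assuming the ``i << 1`` reconstruction above):

    # magnitude 1000 -> index 1000 >> 1 = 500
    # biased sample = (500 << 1) + 132 = 1132
    # highest set bit among bits 8..14 is bit 10 -> exponent = 3
    # mantissa = (1132 >> 6) & 0xF = 1
    # stored low 7 bits = ~((3 << 4) | 1) & 0x7F = 0x4E
    # encoder output = 0x4E | 0x80 = 0xCE, the classic lin2ulaw value for +1000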
+
+
+ def _pcm16_to_mulaw(pcm_data: bytes) -> bytes:
+     """Convert PCM-16 LE bytes to mu-law bytes (pure Python).
+
+     Each pair of bytes in *pcm_data* is interpreted as a signed 16-bit
+     little-endian sample and encoded to one mu-law byte per the ITU-T
+     G.711 standard.
+     """
+     global _MULAW_TABLE  # noqa: PLW0603
+     if _MULAW_TABLE is None:
+         _MULAW_TABLE = _build_mulaw_table()
+
+     n_samples = len(pcm_data) // 2
+     samples = struct.unpack(f"<{n_samples}h", pcm_data[: n_samples * 2])
+     out = bytearray(n_samples)
+     table = _MULAW_TABLE
+     for i, s in enumerate(samples):
+         sign = 0x80 if s >= 0 else 0x00
+         magnitude = -s if s < 0 else s
+         magnitude = min(magnitude, _MULAW_CLIP)
+         # Shift right once to get a 14-bit index (15-bit magnitude → 14-bit)
+         out[i] = table[magnitude >> 1] | sign
+     return bytes(out)
+
+
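A quick sanity check of the encoder against the canonical G.711 extremes (mu-law digital silence is 0xFF, full-scale positive is 0x80, full-scale negative is 0x00) — a minimal sketch that imports the private helper directly:

    import struct

    from roomkit.voice.backends.fastrtc import _pcm16_to_mulaw

    pcm = struct.pack("<3h", 0, 32767, -32768)  # PCM-16 LE: zero, max, min
    assert _pcm16_to_mulaw(pcm) == bytes([0xFF, 0x80, 0x00])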
+ class FastRTCVoiceBackend(VoiceBackend):
+     """VoiceBackend implementation using FastRTC for WebSocket audio transport.
+
+     This backend uses FastRTC's ReplyOnPause for Voice Activity Detection.
+     When speech is detected:
+
+     1. on_speech_start callback is fired
+     2. Audio is accumulated until pause is detected
+     3. on_speech_end callback is fired with the audio bytes
+
+     The backend handles session management and audio streaming back to clients.
+     """
+
+     def __init__(
+         self,
+         *,
+         input_sample_rate: int = 48000,
+         output_sample_rate: int = 24000,
+     ) -> None:
+         """Initialize the FastRTC backend.
+
+         Args:
+             input_sample_rate: Expected sample rate of incoming audio (default 48kHz).
+             output_sample_rate: Sample rate for outgoing TTS audio (default 24kHz).
+         """
+         self._input_sample_rate = input_sample_rate
+         self._output_sample_rate = output_sample_rate
+
+         # Callbacks
+         self._speech_start_callback: SpeechStartCallback | None = None
+         self._speech_end_callback: SpeechEndCallback | None = None
+
+         # Session tracking: session_id -> VoiceSession
+         self._sessions: dict[str, VoiceSession] = {}
+
+         # FastRTC stream (set by mount_fastrtc_voice)
+         self._stream: Any = None
+
+         # Pending audio to send: session_id -> asyncio.Queue of audio chunks
+         self._audio_queues: dict[str, asyncio.Queue[AudioChunk | None]] = {}
+
+         # WebSocket references: session_id -> websocket
+         self._websockets: dict[str, Any] = {}
+
+     @property
+     def name(self) -> str:
+         return "FastRTC"
+
+     @property
+     def capabilities(self) -> VoiceCapability:
+         # VAD is handled internally by FastRTC's ReplyOnPause, so the
+         # backend advertises no additional capabilities here.
+         return VoiceCapability.NONE
+
+     async def connect(
+         self,
+         room_id: str,
+         participant_id: str,
+         channel_id: str,
+         *,
+         metadata: dict[str, Any] | None = None,
+     ) -> VoiceSession:
+         """Create a new voice session.
+
+         Note: For FastRTC, sessions are created when clients connect via WebSocket.
+         This method is called by the application layer after the WebSocket handshake.
+         """
+         session_id = str(uuid.uuid4())
+         # Include audio parameters in metadata for STT to use
+         session_metadata = {
+             "input_sample_rate": self._input_sample_rate,
+             "output_sample_rate": self._output_sample_rate,
+             **(metadata or {}),
+         }
+         session = VoiceSession(
+             id=session_id,
+             room_id=room_id,
+             participant_id=participant_id,
+             channel_id=channel_id,
+             state=VoiceSessionState.ACTIVE,
+             metadata=session_metadata,
+         )
+         self._sessions[session_id] = session
+         self._audio_queues[session_id] = asyncio.Queue()
+         logger.info(
+             "Voice session created: session=%s, room=%s, participant=%s",
+             session_id,
+             room_id,
+             participant_id,
+         )
+         return session
+
+     async def disconnect(self, session: VoiceSession) -> None:
+         """End a voice session."""
+         session.state = VoiceSessionState.ENDED
+         self._sessions.pop(session.id, None)
+         self._audio_queues.pop(session.id, None)
+         self._websockets.pop(session.id, None)
+         logger.info("Voice session ended: session=%s", session.id)
+
+     def on_speech_start(self, callback: SpeechStartCallback) -> None:
+         """Register callback for speech start events."""
+         self._speech_start_callback = callback
+
+     def on_speech_end(self, callback: SpeechEndCallback) -> None:
+         """Register callback for speech end events."""
+         self._speech_end_callback = callback
+
+     def _resolve_websocket(self, session: VoiceSession) -> Any | None:
+         """Resolve WebSocket for a session.
+
+         First checks the explicit registry (populated by _register_websocket).
+         Falls back to looking up the websocket from the FastRTC Stream's
+         connection registry using the ``websocket_id`` stored in session
+         metadata. This fallback allows TTS to work even when the user is
+         muted (no speech event has fired yet to trigger registration).
+         """
+         ws = self._websockets.get(session.id)
+         if ws is not None:
+             return ws
+
+         # Fallback: look up via FastRTC Stream connections
+         ws_id = session.metadata.get("websocket_id")
+         if ws_id and self._stream and hasattr(self._stream, "connections"):
+             handlers = self._stream.connections.get(ws_id)
+             if handlers and hasattr(handlers[0], "websocket") and handlers[0].websocket:
+                 # Auto-register so future lookups are O(1)
+                 ws = handlers[0].websocket
+                 self._websockets[session.id] = ws
+                 logger.info(
+                     "Websocket resolved from Stream connections: ws=%s session=%s",
+                     ws_id,
+                     session.id,
+                 )
+                 return ws
+
+         return None
+
+     async def send_audio(
+         self,
+         session: VoiceSession,
+         audio: bytes | AsyncIterator[AudioChunk],
+     ) -> None:
+         """Send audio to a voice session.
+
+         For FastRTC, audio is converted to mu-law and sent via WebSocket.
+         """
+         websocket = self._resolve_websocket(session)
+         if not websocket:
+             logger.warning("No WebSocket for session %s", session.id)
+             return
+
+         try:
+             if isinstance(audio, bytes):
+                 # Single chunk
+                 await self._send_mulaw_audio(websocket, audio)
+             else:
+                 # Streaming
+                 async for chunk in audio:
+                     if chunk.data:
+                         await self._send_mulaw_audio(websocket, chunk.data)
+         except Exception:
+             logger.exception("Error sending audio to session %s", session.id)
+
+     async def _send_mulaw_audio(self, websocket: Any, pcm_data: bytes) -> None:
+         """Convert PCM to mu-law and send via WebSocket."""
+         mulaw_data = _pcm16_to_mulaw(pcm_data)
+
+         # Send as base64 JSON
+         payload = base64.b64encode(mulaw_data).decode("utf-8")
+         await websocket.send_json({
+             "event": "media",
+             "media": {"payload": payload},
+         })
+
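The wire format above is a JSON envelope carrying base64-encoded mu-law. A hypothetical client-side counterpart (function name assumed, not part of the package) that recovers the raw mu-law bytes:

    import base64
    import json

    def extract_mulaw(raw: str) -> bytes | None:
        """Return the mu-law payload of a 'media' message, else None."""
        msg = json.loads(raw)
        if msg.get("event") == "media":
            return base64.b64decode(msg["media"]["payload"])
        return None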
+     async def send_transcription(
+         self, session: VoiceSession, text: str, role: str = "user"
+     ) -> None:
+         """Send transcription text to the UI via WebSocket."""
+         websocket = self._resolve_websocket(session)
+         logger.info(
+             "send_transcription: session=%s, role=%s, has_websocket=%s, text=%s",
+             session.id,
+             role,
+             websocket is not None,
+             text[:50] if text else "",
+         )
+         if websocket:
+             try:
+                 await websocket.send_json({
+                     "type": "transcription",
+                     "data": {"text": text, "role": role},
+                 })
+                 logger.info("Transcription sent to client")
+             except Exception:
+                 logger.exception("Error sending transcription")
+         else:
+             logger.warning(
+                 "No websocket for session %s, registered sessions: %s",
+                 session.id,
+                 list(self._websockets.keys()),
+             )
+
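Note that transcriptions use a different envelope than audio ("type"/"data" rather than "event"/"media"). A matching client-side sketch, again with an assumed name:

    import json

    def extract_transcript(raw: str) -> tuple[str, str] | None:
        """Return (role, text) from a 'transcription' message, else None."""
        msg = json.loads(raw)
        if msg.get("type") == "transcription":
            data = msg["data"]
            return data["role"], data["text"]
        return None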
+     def get_session(self, session_id: str) -> VoiceSession | None:
+         """Get a session by ID."""
+         return self._sessions.get(session_id)
+
+     def list_sessions(self, room_id: str) -> list[VoiceSession]:
+         """List all active sessions in a room."""
+         return [s for s in self._sessions.values() if s.room_id == room_id]
+
+     async def close(self) -> None:
+         """Release resources."""
+         for session in list(self._sessions.values()):
+             await self.disconnect(session)
+
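Together, list_sessions and send_audio make room-wide playback a short loop; a hedged sketch (helper name invented):

    async def broadcast_pcm(backend: FastRTCVoiceBackend, room_id: str, pcm: bytes) -> None:
        """Send one PCM-16 clip to every active session in a room."""
        for session in backend.list_sessions(room_id):
            await backend.send_audio(session, pcm)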
+     # -------------------------------------------------------------------------
+     # FastRTC integration methods (called by mount_fastrtc_voice)
+     # -------------------------------------------------------------------------
+
+     def _handle_speech_start(self, websocket_id: str) -> None:
+         """Called by FastRTC when VAD detects speech start."""
+         session = self._find_session_by_websocket_id(websocket_id)
+         if session and self._speech_start_callback:
+             self._speech_start_callback(session)
+
+     def _handle_speech_end(
+         self, websocket_id: str, audio_data: np.ndarray, sample_rate: int
+     ) -> None:
+         """Called by FastRTC when VAD detects speech end with audio."""
+         import numpy as _np
+
+         session = self._find_session_by_websocket_id(websocket_id)
+         if session and self._speech_end_callback:
+             # Convert numpy array to PCM-16 bytes; float input is assumed
+             # to already be normalized to [-1.0, 1.0] before scaling.
+             if audio_data.ndim > 1:
+                 audio_data = audio_data.flatten()
+             if audio_data.dtype != _np.int16:
+                 audio_data = (audio_data * 32767).astype(_np.int16)
+             audio_bytes = audio_data.tobytes()
+             self._speech_end_callback(session, audio_bytes)
+
+     def _register_websocket(
+         self, websocket_id: str, session_id: str, websocket: Any
+     ) -> None:
+         """Register a WebSocket connection for a session."""
+         self._websockets[session_id] = websocket
+         # Store websocket_id -> session_id mapping in session metadata
+         session = self._sessions.get(session_id)
+         if session:
+             session.metadata["websocket_id"] = websocket_id
+
+     def _find_session_by_websocket_id(self, websocket_id: str) -> VoiceSession | None:
+         """Find session by FastRTC websocket_id."""
+         for session in self._sessions.values():
+             if session.metadata.get("websocket_id") == websocket_id:
+                 return session
+         return None
+
+
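VoiceChannel normally registers the speech callbacks itself; wiring them by hand looks like the sketch below (handler names invented, signatures inferred from the call sites above):

    backend = FastRTCVoiceBackend()

    def handle_start(session: VoiceSession) -> None:
        logger.info("speech started: %s", session.participant_id)

    def handle_end(session: VoiceSession, audio_bytes: bytes) -> None:
        # audio_bytes is the mono PCM-16 produced by _handle_speech_end
        logger.info("utterance from %s: %d bytes", session.participant_id, len(audio_bytes))

    backend.on_speech_start(handle_start)
    backend.on_speech_end(handle_end)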
+ def mount_fastrtc_voice(
+     app: FastAPI,
+     backend: FastRTCVoiceBackend,
+     *,
+     path: str = "/fastrtc",
+     session_factory: Any = None,
+ ) -> None:
+     """Mount FastRTC voice endpoints on a FastAPI app.
+
+     This creates the WebSocket endpoint that FastRTC clients connect to.
+     The endpoint handles:
+     - WebSocket connection/disconnection
+     - Audio streaming with mu-law encoding
+     - VAD via FastRTC's ReplyOnPause
+
+     Args:
+         app: FastAPI application.
+         backend: The FastRTCVoiceBackend instance.
+         path: Base path for voice endpoints (default: /fastrtc).
+         session_factory: Async callable(websocket_id) -> VoiceSession that creates
+             sessions when clients connect. If not provided, sessions must be
+             created manually before clients connect.
+     """
+     import numpy as np
+     from fastrtc import ReplyOnPause, Stream
+
+     backend._session_factory = session_factory  # type: ignore[attr-defined]
+
+     async def voice_handler(audio: tuple[int, np.ndarray]):
+         """FastRTC handler that bridges to VoiceBackend callbacks.
+
+         This is called by ReplyOnPause when speech is detected and a pause occurs.
+         Instead of processing here, we call the backend's callbacks, which
+         trigger VoiceChannel's pipeline (STT -> hooks -> AI -> TTS).
+         """
+         from fastrtc.utils import current_context
+
+         sample_rate, audio_data = audio
+
+         # Get the websocket_id and websocket from the FastRTC context
+         ctx = current_context.get()
+         websocket_id = ctx.webrtc_id if ctx else None
+         websocket = ctx.websocket if ctx else None
+
+         if not websocket_id:
+             logger.warning("No websocket_id in context")
+             yield (sample_rate, np.zeros(sample_rate // 10, dtype=np.int16))
+             return
+
+         # Create the session if one doesn't exist yet and a factory is available
+         session = backend._find_session_by_websocket_id(websocket_id)
+         if not session and backend._session_factory:  # type: ignore[attr-defined]
+             try:
+                 session = await backend._session_factory(websocket_id)  # type: ignore[attr-defined]
+                 if session and websocket:
+                     backend._register_websocket(websocket_id, session.id, websocket)
+             except Exception:
+                 logger.exception("Error creating session")
+
+         if not session:
+             logger.warning("No session for websocket_id=%s", websocket_id)
+             yield (sample_rate, np.zeros(sample_rate // 10, dtype=np.int16))
+             return
+
+         # Register the websocket if not already registered
+         if websocket and session.id not in backend._websockets:
+             backend._register_websocket(websocket_id, session.id, websocket)
+
+         logger.info(
+             "Speech ended: session=%s, websocket_id=%s, samples=%d",
+             session.id,
+             websocket_id,
+             audio_data.size,
+         )
+
+         # Call the backend's speech end handler.
+         # This triggers VoiceChannel._on_speech_end -> STT -> hooks -> AI.
+         backend._handle_speech_end(websocket_id, audio_data, sample_rate)
+
+         # The response audio will be sent back via backend.send_audio(),
+         # which is called by VoiceChannel._deliver_voice().
+         # Yield silence - the actual response comes via a direct WebSocket send.
+         yield (sample_rate, np.zeros(sample_rate // 10, dtype=np.int16))
+
+     # Create FastRTC stream
+     stream = Stream(
+         handler=ReplyOnPause(
+             voice_handler,
+             input_sample_rate=backend._input_sample_rate,
+             output_sample_rate=backend._output_sample_rate,
+         ),
+         modality="audio",
+         mode="send-receive",
+     )
+
+     backend._stream = stream
+
+     # Mount the stream
+     stream.mount(app, path=path)
+     logger.info("FastRTC voice backend mounted at %s", path)
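Putting the pieces together, a minimal end-to-end wiring sketch (the room, participant, and channel values are placeholders; the lifespan placement follows the module docstring):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    backend = FastRTCVoiceBackend()

    async def session_factory(websocket_id: str) -> VoiceSession:
        # A real app would authenticate the connection and derive these values
        return await backend.connect("demo-room", f"caller-{websocket_id}", "voice")

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        mount_fastrtc_voice(app, backend, path="/fastrtc", session_factory=session_factory)
        yield
        await backend.close()

    app = FastAPI(lifespan=lifespan)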