roomkit 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roomkit/__init__.py +45 -0
- roomkit/_version.py +1 -1
- roomkit/channels/voice.py +728 -0
- roomkit/core/_channel_ops.py +7 -0
- roomkit/core/_inbound.py +4 -0
- roomkit/core/framework.py +177 -1
- roomkit/core/hooks.py +32 -6
- roomkit/models/enums.py +12 -0
- roomkit/sources/__init__.py +4 -4
- roomkit/sources/sse.py +226 -0
- roomkit/voice/__init__.py +99 -0
- roomkit/voice/backends/__init__.py +1 -0
- roomkit/voice/backends/base.py +264 -0
- roomkit/voice/backends/fastrtc.py +467 -0
- roomkit/voice/backends/mock.py +302 -0
- roomkit/voice/base.py +115 -0
- roomkit/voice/events.py +140 -0
- roomkit/voice/stt/__init__.py +1 -0
- roomkit/voice/stt/base.py +58 -0
- roomkit/voice/stt/deepgram.py +214 -0
- roomkit/voice/stt/mock.py +40 -0
- roomkit/voice/tts/__init__.py +1 -0
- roomkit/voice/tts/base.py +58 -0
- roomkit/voice/tts/elevenlabs.py +329 -0
- roomkit/voice/tts/mock.py +51 -0
- {roomkit-0.1.0.dist-info → roomkit-0.2.0.dist-info}/METADATA +11 -2
- {roomkit-0.1.0.dist-info → roomkit-0.2.0.dist-info}/RECORD +29 -12
- {roomkit-0.1.0.dist-info → roomkit-0.2.0.dist-info}/WHEEL +1 -1
- {roomkit-0.1.0.dist-info → roomkit-0.2.0.dist-info}/licenses/LICENSE +0 -0
roomkit/core/_channel_ops.py
CHANGED
|
@@ -16,6 +16,7 @@ from roomkit.models.enums import (
|
|
|
16
16
|
|
|
17
17
|
if TYPE_CHECKING:
|
|
18
18
|
from roomkit.channels.base import Channel
|
|
19
|
+
from roomkit.channels.voice import VoiceChannel
|
|
19
20
|
from roomkit.core.event_router import EventRouter
|
|
20
21
|
from roomkit.core.locks import RoomLockManager
|
|
21
22
|
from roomkit.store.base import ConversationStore
|
|
@@ -31,9 +32,15 @@ class ChannelOpsMixin(HelpersMixin):
|
|
|
31
32
|
|
|
32
33
|
def register_channel(self, channel: Channel) -> None:
    """Register a channel implementation under its ``channel_id``.

    The channel is stored in the channel registry, replacing any previous
    entry with the same ID, and the cached event router is reset. When the
    channel is a ``VoiceChannel`` it is also handed a reference to this
    framework instance.

    Args:
        channel: The channel implementation to register.
    """
    # Function-scope import: at module level VoiceChannel is only imported
    # under TYPE_CHECKING, so the real class has to be loaded here for the
    # isinstance() check (presumably to avoid an import cycle — confirm).
    from roomkit.channels.voice import VoiceChannel

    self._channels[channel.channel_id] = channel
    # Reset router cache so it is not reused with a stale channel set.
    self._event_router = None

    if not isinstance(channel, VoiceChannel):
        return

    # Set framework reference on VoiceChannel for inbound routing.
    channel.set_framework(self)  # type: ignore[arg-type]
|
|
43
|
+
|
|
37
44
|
async def attach_channel(
|
|
38
45
|
self,
|
|
39
46
|
room_id: str,
|
roomkit/core/_inbound.py
CHANGED
|
@@ -369,6 +369,10 @@ class InboundMixin(HelpersMixin):
|
|
|
369
369
|
await self._store.add_event(blocked)
|
|
370
370
|
# Queue nested reentry events for further broadcasting
|
|
371
371
|
pending_reentries.extend(reentry_result.reentry_events)
|
|
372
|
+
# Run AFTER_BROADCAST hooks for reentry events (e.g., AI responses)
|
|
373
|
+
await self._hook_engine.run_async_hooks(
|
|
374
|
+
room_id, HookTrigger.AFTER_BROADCAST, reentry, reentry_ctx
|
|
375
|
+
)
|
|
372
376
|
|
|
373
377
|
# Persist side effects from hooks and broadcast
|
|
374
378
|
all_tasks = sync_result.tasks + broadcast_result.tasks
|
roomkit/core/framework.py
CHANGED
|
@@ -11,9 +11,15 @@ from typing import TYPE_CHECKING, Any
|
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from roomkit.models.delivery import InboundMessage, InboundResult
|
|
14
|
+
from roomkit.models.event import AudioContent
|
|
14
15
|
from roomkit.providers.sms.meta import WebhookMeta
|
|
16
|
+
from roomkit.voice.backends.base import VoiceBackend
|
|
17
|
+
from roomkit.voice.base import VoiceSession
|
|
18
|
+
from roomkit.voice.stt.base import STTProvider
|
|
19
|
+
from roomkit.voice.tts.base import TTSProvider
|
|
15
20
|
|
|
16
21
|
from roomkit.channels.base import Channel
|
|
22
|
+
from roomkit.channels.voice import VoiceChannel
|
|
17
23
|
from roomkit.channels.websocket import SendFn, WebSocketChannel
|
|
18
24
|
from roomkit.core._channel_ops import ChannelOpsMixin
|
|
19
25
|
from roomkit.core._helpers import FrameworkEventHandler, HelpersMixin, IdentityHookFn
|
|
@@ -69,6 +75,8 @@ __all__ = [
|
|
|
69
75
|
"RoomNotFoundError",
|
|
70
76
|
"SourceAlreadyAttachedError",
|
|
71
77
|
"SourceNotFoundError",
|
|
78
|
+
"VoiceBackendNotConfiguredError",
|
|
79
|
+
"VoiceNotConfiguredError",
|
|
72
80
|
]
|
|
73
81
|
|
|
74
82
|
|
|
@@ -104,6 +112,14 @@ class SourceNotFoundError(RoomKitError):
|
|
|
104
112
|
"""Source not found for channel."""
|
|
105
113
|
|
|
106
114
|
|
|
115
|
+
class VoiceNotConfiguredError(RoomKitError):
    """Signal a speech operation attempted with no matching provider.

    Raised by ``RoomKit.transcribe`` when no STT provider is configured and
    by ``RoomKit.synthesize`` when no TTS provider is configured.
    """
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class VoiceBackendNotConfiguredError(RoomKitError):
    """Signal a voice-session operation attempted with no voice backend.

    Raised by ``RoomKit.connect_voice`` and ``RoomKit.disconnect_voice`` when
    the framework was created without a ``voice`` backend.
    """
|
|
121
|
+
|
|
122
|
+
|
|
107
123
|
class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
108
124
|
"""Central orchestrator tying rooms, channels, hooks, and storage."""
|
|
109
125
|
|
|
@@ -118,6 +134,9 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
118
134
|
max_chain_depth: int = 5,
|
|
119
135
|
identity_timeout: float = 10.0,
|
|
120
136
|
process_timeout: float = 30.0,
|
|
137
|
+
stt: STTProvider | None = None,
|
|
138
|
+
tts: TTSProvider | None = None,
|
|
139
|
+
voice: VoiceBackend | None = None,
|
|
121
140
|
) -> None:
|
|
122
141
|
"""Initialise the RoomKit orchestrator.
|
|
123
142
|
|
|
@@ -138,6 +157,9 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
138
157
|
max_chain_depth: Maximum reentry chain depth to prevent infinite loops.
|
|
139
158
|
identity_timeout: Timeout in seconds for identity resolution calls.
|
|
140
159
|
process_timeout: Timeout in seconds for the locked processing phase.
|
|
160
|
+
stt: Optional speech-to-text provider for transcription.
|
|
161
|
+
tts: Optional text-to-speech provider for synthesis.
|
|
162
|
+
voice: Optional voice backend for real-time audio transport.
|
|
141
163
|
"""
|
|
142
164
|
self._store = store or InMemoryStore()
|
|
143
165
|
self._identity_resolver = identity_resolver
|
|
@@ -158,6 +180,10 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
158
180
|
# Event-driven sources
|
|
159
181
|
self._sources: dict[str, SourceProvider] = {}
|
|
160
182
|
self._source_tasks: dict[str, asyncio.Task[None]] = {}
|
|
183
|
+
# Voice support
|
|
184
|
+
self._stt = stt
|
|
185
|
+
self._tts = tts
|
|
186
|
+
self._voice = voice
|
|
161
187
|
|
|
162
188
|
@property
|
|
163
189
|
def store(self) -> ConversationStore:
|
|
@@ -174,6 +200,148 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
174
200
|
"""The realtime backend for ephemeral events."""
|
|
175
201
|
return self._realtime
|
|
176
202
|
|
|
203
|
+
@property
def stt(self) -> STTProvider | None:
    """Return the speech-to-text provider, or ``None`` if none was given."""
    return self._stt
|
|
207
|
+
|
|
208
|
+
@property
def tts(self) -> TTSProvider | None:
    """Return the text-to-speech provider, or ``None`` if none was given."""
    return self._tts
|
|
212
|
+
|
|
213
|
+
@property
def voice(self) -> VoiceBackend | None:
    """Return the real-time voice backend, or ``None`` if none was given."""
    return self._voice
|
|
217
|
+
|
|
218
|
+
async def connect_voice(
    self,
    room_id: str,
    participant_id: str,
    channel_id: str,
    *,
    metadata: dict[str, Any] | None = None,
) -> VoiceSession:
    """Connect a participant to a voice session.

    Creates a voice session via the configured VoiceBackend and binds it
    to the specified room and voice channel for message routing. A
    ``voice_connected`` framework event is emitted once the session is bound.

    Args:
        room_id: The room to join.
        participant_id: The participant's ID.
        channel_id: The voice channel ID.
        metadata: Optional session metadata, forwarded to the backend.

    Returns:
        A VoiceSession representing the connection.

    Raises:
        VoiceBackendNotConfiguredError: If no voice backend is configured.
        RoomNotFoundError: If the room doesn't exist.
        ChannelNotRegisteredError: If ``channel_id`` is not registered or
            the registered channel is not a VoiceChannel.
        ChannelNotFoundError: If the channel is not attached to the room.
    """
    backend = self._voice
    if backend is None:
        raise VoiceBackendNotConfiguredError("No voice backend configured")

    # Verify room exists; the returned room object itself is not needed.
    await self.get_room(room_id)

    # Get the voice channel. A missing entry yields None, which also fails
    # the isinstance check below.
    channel = self._channels.get(channel_id)
    if not isinstance(channel, VoiceChannel):
        raise ChannelNotRegisteredError(
            f"Channel {channel_id} is not a registered VoiceChannel"
        )

    # Get the binding that attaches this channel to the room.
    binding = await self._store.get_binding(room_id, channel_id)
    if binding is None:
        raise ChannelNotFoundError(f"Channel {channel_id} not attached to room {room_id}")

    # Create the session. All validation happens above so that no backend
    # connection is opened for a request that is going to be rejected.
    session = await backend.connect(
        room_id, participant_id, channel_id, metadata=metadata
    )

    # Bind session to channel for routing.
    try:
        channel.bind_session(session, room_id, binding)
    except Exception:
        # The backend session is already live at this point; tear it down so
        # a failed bind does not leave an orphaned connection, then surface
        # the original error to the caller.
        await backend.disconnect(session)
        raise

    await self._emit_framework_event(
        "voice_connected",
        room_id=room_id,
        channel_id=channel_id,
        data={
            "session_id": session.id,
            "participant_id": participant_id,
        },
    )

    return session
|
|
282
|
+
|
|
283
|
+
async def disconnect_voice(self, session: VoiceSession) -> None:
    """Disconnect a voice session.

    Unbinds the session from its voice channel (when that channel is still
    registered as a VoiceChannel), disconnects it at the backend, and emits
    a ``voice_disconnected`` framework event.

    Args:
        session: The session to disconnect.

    Raises:
        VoiceBackendNotConfiguredError: If no voice backend is configured.
    """
    backend = self._voice
    if backend is None:
        raise VoiceBackendNotConfiguredError("No voice backend configured")

    # Unbind first; a channel that is missing or of another type is skipped.
    registered = self._channels.get(session.channel_id)
    if isinstance(registered, VoiceChannel):
        registered.unbind_session(session)

    await backend.disconnect(session)

    event_data = {
        "session_id": session.id,
        "participant_id": session.participant_id,
    }
    await self._emit_framework_event(
        "voice_disconnected",
        room_id=session.room_id,
        channel_id=session.channel_id,
        data=event_data,
    )
|
|
311
|
+
|
|
312
|
+
async def transcribe(self, audio: AudioContent) -> str:
    """Transcribe audio to text using the configured STT provider.

    Args:
        audio: The audio content to transcribe; passed to the provider
            unchanged.

    Returns:
        The text returned by the provider's ``transcribe`` call.

    Raises:
        VoiceNotConfiguredError: If no STT provider is configured.
    """
    provider = self._stt
    if provider is None:
        raise VoiceNotConfiguredError("No STT provider configured")
    return await provider.transcribe(audio)
|
|
327
|
+
|
|
328
|
+
async def synthesize(self, text: str, *, voice: str | None = None) -> AudioContent:
    """Synthesize text to audio using the configured TTS provider.

    Args:
        text: Text to synthesize.
        voice: Optional voice ID, forwarded to the provider as ``voice``
            (``None`` leaves the choice to the provider).

    Returns:
        The audio content returned by the provider's ``synthesize`` call.

    Raises:
        VoiceNotConfiguredError: If no TTS provider is configured.
    """
    provider = self._tts
    if provider is None:
        raise VoiceNotConfiguredError("No TTS provider configured")
    return await provider.synthesize(text, voice=voice)
|
|
344
|
+
|
|
177
345
|
def _get_router(self) -> EventRouter:
|
|
178
346
|
if self._event_router is None:
|
|
179
347
|
self._event_router = EventRouter(
|
|
@@ -184,13 +352,16 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
184
352
|
return self._event_router
|
|
185
353
|
|
|
186
354
|
async def close(self) -> None:
    """Close all sources, channels, voice backend, and the realtime backend.

    Shutdown order: every attached event source is detached first, then each
    registered channel is closed, then the voice backend (when one is
    configured), and finally the realtime backend.
    """
    # Stop all event sources first. Iterate over a snapshot of the keys
    # because detach_source presumably removes entries from self._sources.
    for channel_id in list(self._sources):
        await self.detach_source(channel_id)

    # Then close channels.
    for channel in self._channels.values():
        await channel.close()

    # Close voice backend. Compare against None rather than relying on
    # truthiness: a backend object that defines __len__/__bool__ and happens
    # to be falsy must still be closed.
    if self._voice is not None:
        await self._voice.close()

    await self._realtime.close()
|
|
195
366
|
|
|
196
367
|
async def __aenter__(self) -> RoomKit:
|
|
@@ -277,6 +448,11 @@ class RoomKit(InboundMixin, ChannelOpsMixin, RoomLifecycleMixin, HelpersMixin):
|
|
|
277
448
|
router = self._get_router()
|
|
278
449
|
await router.broadcast(event, binding, context)
|
|
279
450
|
|
|
451
|
+
# Run AFTER_BROADCAST hooks for observability and fan-out
|
|
452
|
+
await self._hook_engine.run_async_hooks(
|
|
453
|
+
room_id, HookTrigger.AFTER_BROADCAST, event, context
|
|
454
|
+
)
|
|
455
|
+
|
|
280
456
|
return event
|
|
281
457
|
|
|
282
458
|
# -- WebSocket lifecycle --
|
roomkit/core/hooks.py
CHANGED
|
@@ -152,11 +152,24 @@ class HookEngine:
|
|
|
152
152
|
self,
|
|
153
153
|
room_id: str,
|
|
154
154
|
trigger: HookTrigger,
|
|
155
|
-
event: RoomEvent,
|
|
155
|
+
event: RoomEvent | Any,
|
|
156
156
|
context: RoomContext,
|
|
157
|
+
*,
|
|
158
|
+
skip_event_filter: bool = False,
|
|
157
159
|
) -> SyncPipelineResult:
|
|
158
|
-
"""Run sync hooks sequentially. Stops on block, passes modified events.
|
|
159
|
-
|
|
160
|
+
"""Run sync hooks sequentially. Stops on block, passes modified events.
|
|
161
|
+
|
|
162
|
+
Args:
|
|
163
|
+
room_id: The room ID to run hooks for.
|
|
164
|
+
trigger: The hook trigger type.
|
|
165
|
+
event: The event to pass to hooks. For voice hooks, this may be
|
|
166
|
+
a VoiceSession or str instead of RoomEvent.
|
|
167
|
+
context: The room context.
|
|
168
|
+
skip_event_filter: If True, skip channel-based event filtering.
|
|
169
|
+
Use this for voice hooks where event is not a RoomEvent.
|
|
170
|
+
"""
|
|
171
|
+
filter_event = None if skip_event_filter else event
|
|
172
|
+
hooks = self._get_hooks(room_id, trigger, HookExecution.SYNC, event=filter_event)
|
|
160
173
|
result = SyncPipelineResult(event=event)
|
|
161
174
|
|
|
162
175
|
for hook in hooks:
|
|
@@ -201,11 +214,24 @@ class HookEngine:
|
|
|
201
214
|
self,
|
|
202
215
|
room_id: str,
|
|
203
216
|
trigger: HookTrigger,
|
|
204
|
-
event: RoomEvent,
|
|
217
|
+
event: RoomEvent | Any,
|
|
205
218
|
context: RoomContext,
|
|
219
|
+
*,
|
|
220
|
+
skip_event_filter: bool = False,
|
|
206
221
|
) -> None:
|
|
207
|
-
"""Run async hooks concurrently. Errors are logged, never raised.
|
|
208
|
-
|
|
222
|
+
"""Run async hooks concurrently. Errors are logged, never raised.
|
|
223
|
+
|
|
224
|
+
Args:
|
|
225
|
+
room_id: The room ID to run hooks for.
|
|
226
|
+
trigger: The hook trigger type.
|
|
227
|
+
event: The event to pass to hooks. For voice hooks, this may be
|
|
228
|
+
a VoiceSession or str instead of RoomEvent.
|
|
229
|
+
context: The room context.
|
|
230
|
+
skip_event_filter: If True, skip channel-based event filtering.
|
|
231
|
+
Use this for voice hooks where event is not a RoomEvent.
|
|
232
|
+
"""
|
|
233
|
+
filter_event = None if skip_event_filter else event
|
|
234
|
+
hooks = self._get_hooks(room_id, trigger, HookExecution.ASYNC, event=filter_event)
|
|
209
235
|
if not hooks:
|
|
210
236
|
return
|
|
211
237
|
|
roomkit/models/enums.py
CHANGED
|
@@ -162,6 +162,18 @@ class HookTrigger(StrEnum):
|
|
|
162
162
|
ON_ERROR = "on_error"
|
|
163
163
|
# Delivery status (outbound message tracking)
|
|
164
164
|
ON_DELIVERY_STATUS = "on_delivery_status"
|
|
165
|
+
# Voice (RFC §18)
|
|
166
|
+
ON_SPEECH_START = "on_speech_start"
|
|
167
|
+
ON_SPEECH_END = "on_speech_end"
|
|
168
|
+
ON_TRANSCRIPTION = "on_transcription"
|
|
169
|
+
BEFORE_TTS = "before_tts"
|
|
170
|
+
AFTER_TTS = "after_tts"
|
|
171
|
+
# Voice - Enhanced (RFC §19)
|
|
172
|
+
ON_BARGE_IN = "on_barge_in"
|
|
173
|
+
ON_TTS_CANCELLED = "on_tts_cancelled"
|
|
174
|
+
ON_PARTIAL_TRANSCRIPTION = "on_partial_transcription"
|
|
175
|
+
ON_VAD_SILENCE = "on_vad_silence"
|
|
176
|
+
ON_VAD_AUDIO_LEVEL = "on_vad_audio_level"
|
|
165
177
|
|
|
166
178
|
|
|
167
179
|
@unique
|
roomkit/sources/__init__.py
CHANGED
|
@@ -18,7 +18,7 @@ __all__ = [
|
|
|
18
18
|
"SourceStatus",
|
|
19
19
|
# Lazy imports for optional sources
|
|
20
20
|
"WebSocketSource",
|
|
21
|
-
"
|
|
21
|
+
"SSESource",
|
|
22
22
|
]
|
|
23
23
|
|
|
24
24
|
|
|
@@ -28,8 +28,8 @@ def __getattr__(name: str) -> Any:
|
|
|
28
28
|
from roomkit.sources.websocket import WebSocketSource
|
|
29
29
|
|
|
30
30
|
return WebSocketSource
|
|
31
|
-
if name == "
|
|
32
|
-
from roomkit.sources.
|
|
31
|
+
if name == "SSESource":
|
|
32
|
+
from roomkit.sources.sse import SSESource
|
|
33
33
|
|
|
34
|
-
return
|
|
34
|
+
return SSESource
|
|
35
35
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
roomkit/sources/sse.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""Server-Sent Events (SSE) source for RoomKit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from roomkit.models.delivery import InboundMessage
|
|
13
|
+
from roomkit.models.event import TextContent
|
|
14
|
+
from roomkit.sources.base import BaseSourceProvider, EmitCallback, SourceStatus
|
|
15
|
+
|
|
16
|
+
# Optional dependency - import for availability check
|
|
17
|
+
try:
|
|
18
|
+
import httpx
|
|
19
|
+
from httpx_sse import aconnect_sse
|
|
20
|
+
|
|
21
|
+
HAS_SSE = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
httpx = None # type: ignore[assignment]
|
|
24
|
+
aconnect_sse = None # type: ignore[assignment]
|
|
25
|
+
HAS_SSE = False
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger("roomkit.sources.sse")
|
|
28
|
+
|
|
29
|
+
# Type alias for event parser
|
|
30
|
+
SSEEventParser = Callable[[str, str, str | None], InboundMessage | None]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def default_json_parser(channel_id: str) -> SSEEventParser:
    """Create a default JSON event parser.

    Expects the SSE data field to contain a JSON object:
        {
            "sender_id": "user123",
            "text": "Hello world",
            "external_id": "msg-456",  # optional
            "metadata": {}             # optional
        }

    Events whose type is not one of ``"message"``, ``"msg"``, ``"chat"`` or
    the empty string are skipped, as are payloads that are not a JSON object
    or that lack ``sender_id``.

    Args:
        channel_id: Channel ID to use for parsed messages.

    Returns:
        A parser function that converts SSE events to InboundMessage, or
        returns None for events that are skipped or cannot be parsed.
    """

    def parser(event: str, data: str, event_id: str | None) -> InboundMessage | None:
        # Skip non-message events (e.g., heartbeats, pings)
        if event not in ("message", "msg", "chat", ""):
            logger.debug("Skipping event type: %s", event)
            return None

        try:
            payload = json.loads(data)

            if not isinstance(payload, dict) or "sender_id" not in payload:
                return None

            return InboundMessage(
                channel_id=channel_id,
                sender_id=payload["sender_id"],
                content=TextContent(body=payload.get("text", "")),
                # Fall back to the SSE event ID when the payload has no ID.
                external_id=payload.get("external_id") or event_id,
                # `or {}` also covers an explicit JSON null, which
                # .get("metadata", {}) would have passed through as None.
                metadata=payload.get("metadata") or {},
            )
        except (ValueError, KeyError, TypeError) as e:
            # ValueError includes json.JSONDecodeError. It is caught broadly
            # so that a payload rejected while building the message
            # (presumably model validation — confirm against InboundMessage)
            # is dropped here instead of propagating to the caller.
            logger.debug("Failed to parse SSE data: %s", e)
            return None

    return parser
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class SSESource(BaseSourceProvider):
    """Server-Sent Events (SSE) source for receiving messages.

    Opens a streaming GET request to an SSE endpoint, runs every received
    event through a parser, and emits the resulting messages into RoomKit.
    The ID of the most recent event is remembered, so a later ``start()``
    sends it as the ``Last-Event-ID`` header to resume the stream.

    Example:
        from roomkit import RoomKit
        from roomkit.sources import SSESource

        # Default JSON parser, with an auth header
        source = SSESource(
            url="https://api.example.com/events",
            channel_id="sse-events",
            headers={"Authorization": "Bearer token123"},
        )
        await kit.attach_source("sse-events", source)

        # Custom parser for non-JSON events
        def my_parser(event: str, data: str, event_id: str | None) -> InboundMessage | None:
            if event != "chat":
                return None
            return InboundMessage(
                channel_id="custom",
                sender_id="system",
                content=TextContent(body=data),
                external_id=event_id,
            )

        source = SSESource(
            url="https://stream.example.com/chat",
            channel_id="custom",
            parser=my_parser,
        )
    """

    def __init__(
        self,
        url: str,
        channel_id: str,
        *,
        parser: SSEEventParser | None = None,
        headers: dict[str, str] | None = None,
        params: dict[str, str] | None = None,
        timeout: float = 30.0,
        last_event_id: str | None = None,
    ) -> None:
        """Initialize SSE source.

        Args:
            url: SSE endpoint URL.
            channel_id: Channel ID for emitted messages.
            parser: Callable taking (event_type, data, event_id) and
                returning an InboundMessage, or None to skip the event.
                When omitted, the default JSON parser for ``channel_id``
                is used.
            headers: HTTP headers for the request (e.g., Authorization).
            params: Query parameters for the URL.
            timeout: Timeout in seconds handed to the HTTP client.
            last_event_id: Resume from this event ID (sent as Last-Event-ID header).
        """
        super().__init__()
        self._url = url
        self._channel_id = channel_id
        if not parser:
            parser = default_json_parser(channel_id)
        self._parser = parser
        self._headers = headers if headers else {}
        self._params = params if params else {}
        self._timeout = timeout
        self._last_event_id = last_event_id
        # Holds the HTTP client only while start() is running.
        self._client: Any = None

    @property
    def name(self) -> str:
        """Identifier of this source: ``sse:`` followed by the endpoint URL."""
        return f"sse:{self._url}"

    async def start(self, emit: EmitCallback) -> None:
        """Connect to the endpoint and process events until the stream ends.

        Args:
            emit: Callback awaited with each parsed message.

        Raises:
            ImportError: If httpx / httpx-sse are not installed.
            Exception: Any connection or processing error is re-raised after
                the source status has been set to ERROR.
        """
        if not HAS_SSE:
            raise ImportError(
                "httpx and httpx-sse are required for SSESource. "
                "Install with: pip install roomkit[sse]"
            )

        self._reset_stop()
        self._set_status(SourceStatus.CONNECTING)

        # Work on a copy so the resume header never leaks into the
        # caller-supplied header mapping.
        request_headers = {**self._headers}
        resume_from = self._last_event_id
        if resume_from:
            request_headers["Last-Event-ID"] = resume_from

        try:
            async with httpx.AsyncClient(timeout=self._timeout) as http_client:
                self._client = http_client

                async with aconnect_sse(
                    http_client,
                    "GET",
                    self._url,
                    headers=request_headers,
                    params=self._params,
                ) as event_source:
                    self._set_status(SourceStatus.CONNECTED)
                    logger.info("Connected to SSE endpoint: %s", self._url)

                    await self._receive_loop(event_source, emit)

        except asyncio.CancelledError:
            # Cancellation is not an error condition: propagate it without
            # touching the status.
            raise
        except Exception as e:
            self._set_status(SourceStatus.ERROR, str(e))
            raise
        finally:
            self._client = None

    async def _receive_loop(self, event_source: Any, emit: EmitCallback) -> None:
        """Read SSE events from ``event_source`` and emit the parsed messages.

        The loop ends when the stream is exhausted or when a stop has been
        requested; the stop flag is checked once per received event.
        """
        async for sse_event in event_source.aiter_sse():
            if self._should_stop():
                break

            # Track last event ID for potential reconnection.
            if sse_event.id:
                self._last_event_id = sse_event.id

            parsed = self._parser(sse_event.event, sse_event.data, sse_event.id)
            if parsed is None:
                # The parser chose to skip this event.
                continue

            outcome = await emit(parsed)
            self._record_message()

            if outcome.blocked:
                logger.debug("Message blocked: %s", outcome.reason)

    async def stop(self) -> None:
        """Stop receiving: delegate to the base class, then log the stop."""
        await super().stop()
        logger.info("SSE source stopped")

    @property
    def last_event_id(self) -> str | None:
        """Get the last received event ID for resumption."""
        return self._last_event_id
|