roomkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roomkit/AGENTS.md +362 -0
- roomkit/__init__.py +372 -0
- roomkit/_version.py +1 -0
- roomkit/ai_docs.py +93 -0
- roomkit/channels/__init__.py +194 -0
- roomkit/channels/ai.py +238 -0
- roomkit/channels/base.py +66 -0
- roomkit/channels/transport.py +115 -0
- roomkit/channels/websocket.py +85 -0
- roomkit/core/__init__.py +0 -0
- roomkit/core/_channel_ops.py +252 -0
- roomkit/core/_helpers.py +296 -0
- roomkit/core/_inbound.py +435 -0
- roomkit/core/_room_lifecycle.py +275 -0
- roomkit/core/circuit_breaker.py +84 -0
- roomkit/core/event_router.py +401 -0
- roomkit/core/framework.py +793 -0
- roomkit/core/hooks.py +232 -0
- roomkit/core/inbound_router.py +57 -0
- roomkit/core/locks.py +66 -0
- roomkit/core/rate_limiter.py +67 -0
- roomkit/core/retry.py +49 -0
- roomkit/core/router.py +24 -0
- roomkit/core/transcoder.py +85 -0
- roomkit/identity/__init__.py +0 -0
- roomkit/identity/base.py +27 -0
- roomkit/identity/mock.py +49 -0
- roomkit/llms.txt +52 -0
- roomkit/models/__init__.py +104 -0
- roomkit/models/channel.py +99 -0
- roomkit/models/context.py +35 -0
- roomkit/models/delivery.py +76 -0
- roomkit/models/enums.py +170 -0
- roomkit/models/event.py +203 -0
- roomkit/models/framework_event.py +19 -0
- roomkit/models/hook.py +68 -0
- roomkit/models/identity.py +81 -0
- roomkit/models/participant.py +34 -0
- roomkit/models/room.py +33 -0
- roomkit/models/task.py +36 -0
- roomkit/providers/__init__.py +0 -0
- roomkit/providers/ai/__init__.py +0 -0
- roomkit/providers/ai/base.py +140 -0
- roomkit/providers/ai/mock.py +33 -0
- roomkit/providers/anthropic/__init__.py +6 -0
- roomkit/providers/anthropic/ai.py +145 -0
- roomkit/providers/anthropic/config.py +14 -0
- roomkit/providers/elasticemail/__init__.py +6 -0
- roomkit/providers/elasticemail/config.py +16 -0
- roomkit/providers/elasticemail/email.py +97 -0
- roomkit/providers/email/__init__.py +0 -0
- roomkit/providers/email/base.py +46 -0
- roomkit/providers/email/mock.py +34 -0
- roomkit/providers/gemini/__init__.py +6 -0
- roomkit/providers/gemini/ai.py +153 -0
- roomkit/providers/gemini/config.py +14 -0
- roomkit/providers/http/__init__.py +15 -0
- roomkit/providers/http/base.py +33 -0
- roomkit/providers/http/config.py +14 -0
- roomkit/providers/http/mock.py +21 -0
- roomkit/providers/http/provider.py +105 -0
- roomkit/providers/http/webhook.py +33 -0
- roomkit/providers/messenger/__init__.py +15 -0
- roomkit/providers/messenger/base.py +33 -0
- roomkit/providers/messenger/config.py +17 -0
- roomkit/providers/messenger/facebook.py +95 -0
- roomkit/providers/messenger/mock.py +21 -0
- roomkit/providers/messenger/webhook.py +42 -0
- roomkit/providers/openai/__init__.py +6 -0
- roomkit/providers/openai/ai.py +155 -0
- roomkit/providers/openai/config.py +24 -0
- roomkit/providers/pydantic_ai/__init__.py +5 -0
- roomkit/providers/pydantic_ai/config.py +14 -0
- roomkit/providers/rcs/__init__.py +9 -0
- roomkit/providers/rcs/base.py +95 -0
- roomkit/providers/rcs/mock.py +78 -0
- roomkit/providers/sendgrid/__init__.py +5 -0
- roomkit/providers/sendgrid/config.py +13 -0
- roomkit/providers/sinch/__init__.py +6 -0
- roomkit/providers/sinch/config.py +22 -0
- roomkit/providers/sinch/sms.py +192 -0
- roomkit/providers/sms/__init__.py +15 -0
- roomkit/providers/sms/base.py +67 -0
- roomkit/providers/sms/meta.py +401 -0
- roomkit/providers/sms/mock.py +24 -0
- roomkit/providers/sms/phone.py +77 -0
- roomkit/providers/telnyx/__init__.py +21 -0
- roomkit/providers/telnyx/config.py +14 -0
- roomkit/providers/telnyx/rcs.py +352 -0
- roomkit/providers/telnyx/sms.py +231 -0
- roomkit/providers/twilio/__init__.py +18 -0
- roomkit/providers/twilio/config.py +19 -0
- roomkit/providers/twilio/rcs.py +183 -0
- roomkit/providers/twilio/sms.py +200 -0
- roomkit/providers/voicemeup/__init__.py +15 -0
- roomkit/providers/voicemeup/config.py +21 -0
- roomkit/providers/voicemeup/sms.py +374 -0
- roomkit/providers/whatsapp/__init__.py +0 -0
- roomkit/providers/whatsapp/base.py +44 -0
- roomkit/providers/whatsapp/mock.py +21 -0
- roomkit/py.typed +0 -0
- roomkit/realtime/__init__.py +17 -0
- roomkit/realtime/base.py +111 -0
- roomkit/realtime/memory.py +158 -0
- roomkit/sources/__init__.py +35 -0
- roomkit/sources/base.py +207 -0
- roomkit/sources/websocket.py +260 -0
- roomkit/store/__init__.py +0 -0
- roomkit/store/base.py +230 -0
- roomkit/store/memory.py +293 -0
- roomkit-0.1.0.dist-info/METADATA +567 -0
- roomkit-0.1.0.dist-info/RECORD +114 -0
- roomkit-0.1.0.dist-info/WHEEL +4 -0
- roomkit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
"""In-memory realtime backend using asyncio queues."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import logging
|
|
8
|
+
from collections import OrderedDict
|
|
9
|
+
from uuid import uuid4
|
|
10
|
+
|
|
11
|
+
from roomkit.realtime.base import EphemeralCallback, EphemeralEvent, RealtimeBackend
|
|
12
|
+
|
|
13
|
+
# Logger for the realtime package; the subscription delivery loop uses it to
# report exceptions raised by subscriber callbacks.
logger = logging.getLogger("roomkit.realtime")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class InMemoryRealtime(RealtimeBackend):
    """In-process realtime backend using asyncio queues.

    Suitable for single-process deployments. For multi-process or
    distributed setups, provide a custom ``RealtimeBackend`` backed by
    Redis pub/sub, NATS, or similar.

    Every subscription owns a bounded queue and a background delivery task,
    so ``publish()`` only buffers events and never waits on a subscriber's
    callback.
    """

    def __init__(self, max_queue_size: int = 100) -> None:
        """Initialize the in-memory realtime backend.

        Args:
            max_queue_size: Maximum number of undelivered events buffered per
                subscription. When a queue is full, the oldest pending event
                is dropped to make room for the new one (FIFO eviction).

        Raises:
            ValueError: If ``max_queue_size`` is less than 1.
        """
        if max_queue_size < 1:
            # A non-positive bound would make every enqueue try to evict from
            # an empty queue (KeyError from OrderedDict.popitem) during publish().
            raise ValueError(f"max_queue_size must be at least 1, got {max_queue_size!r}")
        self._max_queue_size = max_queue_size
        self._subscriptions: dict[str, _Subscription] = {}
        self._channels: dict[str, set[str]] = {}  # channel -> subscription_ids
        self._closed = False

    async def publish(self, channel: str, event: EphemeralEvent) -> None:
        """Publish an event to all subscribers on a channel.

        This is a no-op after ``close()`` or when the channel has no subscribers.
        """
        if self._closed:
            return

        sub_ids = self._channels.get(channel, set())
        for sub_id in sub_ids:
            sub = self._subscriptions.get(sub_id)
            if sub is not None:
                await sub.enqueue(event)

    async def subscribe(self, channel: str, callback: EphemeralCallback) -> str:
        """Subscribe to a channel with a callback.

        Creates a subscription with its own bounded queue and starts its
        background delivery task on the running event loop.

        Returns:
            A subscription ID that can be used to unsubscribe.
        """
        sub_id = uuid4().hex
        sub = _Subscription(
            sub_id=sub_id,
            channel=channel,
            callback=callback,
            max_queue_size=self._max_queue_size,
        )
        self._subscriptions[sub_id] = sub

        if channel not in self._channels:
            self._channels[channel] = set()
        self._channels[channel].add(sub_id)

        sub.start()
        return sub_id

    async def unsubscribe(self, subscription_id: str) -> bool:
        """Unsubscribe and stop the subscription task.

        The channel entry is removed once its last subscription is gone.

        Returns:
            True if the subscription existed and was removed, False otherwise.
        """
        sub = self._subscriptions.pop(subscription_id, None)
        if sub is None:
            return False

        channel_subs = self._channels.get(sub.channel)
        if channel_subs:
            channel_subs.discard(subscription_id)
            if not channel_subs:
                del self._channels[sub.channel]

        await sub.stop()
        return True

    async def close(self) -> None:
        """Stop all subscriptions and clean up.

        After this call, ``publish()`` ignores new events.
        """
        self._closed = True
        # Iterate over a copy: stopping awaits, so the dict must not be
        # iterated live while coroutines may run.
        for sub in list(self._subscriptions.values()):
            await sub.stop()
        self._subscriptions.clear()
        self._channels.clear()

    @property
    def subscription_count(self) -> int:
        """Return the number of active subscriptions."""
        return len(self._subscriptions)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class _Subscription:
    """A single subscriber: bounded FIFO of pending events plus its delivery task.

    Pending events are keyed by ``event.id`` in an ``OrderedDict``. Enqueueing
    an id that is still pending replaces the stored event in place rather than
    appending a second copy.
    """

    def __init__(
        self,
        sub_id: str,
        channel: str,
        callback: EphemeralCallback,
        max_queue_size: int,
    ) -> None:
        # Identity and target of this subscription.
        self.sub_id = sub_id
        self.channel = channel
        self.callback = callback
        # Pending events in arrival order, capped at _max_queue_size entries.
        self._max_queue_size = max_queue_size
        self._queue: OrderedDict[str, EphemeralEvent] = OrderedDict()
        # Signals the delivery task that there is work (or that it should stop).
        self._event = asyncio.Event()
        self._task: asyncio.Task[None] | None = None
        self._stopped = False

    async def enqueue(self, event: EphemeralEvent) -> None:
        """Buffer ``event`` for delivery, evicting the oldest pending entries when full.

        Events offered after ``stop()`` are silently ignored.
        """
        if not self._stopped:
            pending = self._queue
            # Evict from the front until there is room for one more entry.
            while len(pending) >= self._max_queue_size:
                pending.popitem(last=False)
            pending[event.id] = event
            self._event.set()

    def start(self) -> None:
        """Launch the background delivery task on the running event loop."""
        delivery = self._run()
        self._task = asyncio.create_task(delivery)

    async def stop(self) -> None:
        """Mark the subscription stopped, wake and cancel its task, and wait for it."""
        self._stopped = True
        self._event.set()
        task = self._task
        if task is None:
            return
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task
        self._task = None

    async def _run(self) -> None:
        """Deliver pending events to the callback, oldest first, until stopped.

        A callback exception is logged and swallowed, so one failing delivery
        does not end the subscription.
        """
        while not self._stopped:
            await self._event.wait()
            self._event.clear()

            while self._queue and not self._stopped:
                event = self._queue.popitem(last=False)[1]
                try:
                    await self.callback(event)
                except Exception:
                    logger.exception("Error in realtime callback for subscription %s", self.sub_id)
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Event-driven message sources for RoomKit."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from roomkit.sources.base import (
|
|
6
|
+
BaseSourceProvider,
|
|
7
|
+
EmitCallback,
|
|
8
|
+
SourceHealth,
|
|
9
|
+
SourceProvider,
|
|
10
|
+
SourceStatus,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
# Public API of ``roomkit.sources``.
__all__ = [
    "BaseSourceProvider",
    "EmitCallback",
    "SourceHealth",
    "SourceProvider",
    "SourceStatus",
    # Not imported eagerly: these two are resolved on first access by the
    # module-level ``__getattr__``, which imports ``roomkit.sources.websocket``.
    # ``from roomkit.sources import *`` therefore triggers that import too.
    "WebSocketSource",
    "default_json_parser",
]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def __getattr__(name: str) -> Any:
    """Resolve the optional WebSocket exports on first access (module ``__getattr__``).

    Raises:
        AttributeError: For any name other than the lazily exported ones.
    """
    if name not in {"WebSocketSource", "default_json_parser"}:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred import keeps roomkit.sources.websocket out of the import path
    # until one of its names is actually requested.
    from roomkit.sources import websocket as _websocket

    return getattr(_websocket, name)
|
roomkit/sources/base.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
"""Base abstraction for event-driven message sources."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from collections.abc import Awaitable, Callable
|
|
9
|
+
from datetime import UTC, datetime
|
|
10
|
+
from enum import StrEnum, unique
|
|
11
|
+
from typing import TYPE_CHECKING
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from roomkit.models.delivery import InboundMessage, InboundResult
|
|
17
|
+
|
|
18
|
+
# Logger for the sources package.
# NOTE(review): not referenced by the code in this module as shown — confirm
# whether it is still needed here.
logger = logging.getLogger("roomkit.sources")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@unique
class SourceStatus(StrEnum):
    """Connection status for an event source.

    Members are ``StrEnum`` values, so they compare equal to their plain string
    values. ``@unique`` rejects accidental duplicate values at class creation.
    """

    STOPPED = "stopped"  # default state; BaseSourceProvider also sets it on stop()
    CONNECTING = "connecting"  # start() called, connection not yet established
    CONNECTED = "connected"  # connection established and receiving
    RECONNECTING = "reconnecting"  # presumably for sources that re-establish dropped links; not set in this module
    ERROR = "error"  # failure; providers may report details via SourceHealth.error
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SourceHealth(BaseModel):
    """Health information for an event source.

    Returned by ``SourceProvider.healthcheck()``. Every field has a default,
    so ``SourceHealth()`` on its own describes a stopped source with no
    recorded activity.
    """

    # Current connection state.
    status: SourceStatus = SourceStatus.STOPPED
    # When the provider last reported CONNECTED; None if it never connected.
    connected_at: datetime | None = None
    # Arrival time of the most recently recorded message; None if none yet.
    last_message_at: datetime | None = None
    # Number of messages the provider has recorded.
    messages_received: int = 0
    # Human-readable description of the last failure, if any.
    error: str | None = None
    # Free-form string key/value details. BaseSourceProvider.healthcheck() does
    # not populate it — presumably for provider-specific extras.
    metadata: dict[str, str] = Field(default_factory=dict)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Type alias for the emit callback
|
|
44
|
+
# Signature of the callback passed to ``SourceProvider.start``: it takes one
# parsed ``InboundMessage`` and asynchronously returns an ``InboundResult``.
# The type names are quoted because they are only imported under TYPE_CHECKING.
EmitCallback = Callable[["InboundMessage"], Awaitable["InboundResult"]]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class SourceProvider(ABC):
    """Base class for event-driven message sources.

    A SourceProvider actively listens for inbound messages from an external
    system (WebSocket, NATS, SSE, WhatsApp via neonize, etc.) and emits them
    into RoomKit's inbound pipeline.

    Unlike webhook-based providers that receive HTTP POST requests, source
    providers maintain persistent connections and push messages as they arrive.

    Lifecycle:
        1. Create source instance with configuration
        2. Call `start(emit)` - source connects and begins listening
        3. Source calls `emit(message)` for each inbound message
        4. Call `stop()` to disconnect and cleanup

    Subclasses must implement ``name``, ``start`` and ``stop``. The default
    ``status`` property always reports ``SourceStatus.STOPPED``. A subclass
    that tracks its own state must therefore also override ``status``, or
    derive from ``BaseSourceProvider``, which does that bookkeeping.

    Example:
        class MyWebSocketSource(SourceProvider):
            def __init__(self, url: str):
                self._url = url
                self._ws = None
                self._status = SourceStatus.STOPPED

            @property
            def name(self) -> str:
                return f"websocket:{self._url}"

            @property
            def status(self) -> SourceStatus:
                # Without this override the base class reports STOPPED forever.
                return self._status

            async def start(self, emit: EmitCallback) -> None:
                self._status = SourceStatus.CONNECTING
                async with websockets.connect(self._url) as ws:
                    self._ws = ws
                    self._status = SourceStatus.CONNECTED
                    async for message in ws:
                        inbound = self._parse(message)  # your frame -> InboundMessage
                        await emit(inbound)

            async def stop(self) -> None:
                self._status = SourceStatus.STOPPED
                if self._ws:
                    await self._ws.close()
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique identifier for this source instance.

        Used for logging and framework events. Should be descriptive,
        e.g. "neonize:session.db" or "nats:events.inbound".
        """
        ...

    @abstractmethod
    async def start(self, emit: EmitCallback) -> None:
        """Start receiving messages and emit them via the callback.

        This method should:
        1. Establish connection to the external system
        2. Listen for incoming messages in a loop
        3. Parse each message into an InboundMessage
        4. Call `await emit(message)` for each message
        5. Handle reconnection internally if the connection drops

        The method should run until `stop()` is called or an unrecoverable
        error occurs. For recoverable errors (network issues), implement
        reconnection with backoff.

        Args:
            emit: Callback to emit messages into RoomKit. Returns InboundResult
                indicating whether the message was processed or blocked.
        """
        ...

    @abstractmethod
    async def stop(self) -> None:
        """Stop receiving messages and release resources.

        This method should:
        1. Signal the start() loop to exit
        2. Close any open connections
        3. Cancel any pending tasks
        4. Release any held resources

        After stop() returns, start() should be safe to call again.
        """
        ...

    @property
    def status(self) -> SourceStatus:
        """Current connection status.

        Subclasses should override this to return the actual status.
        The default implementation always returns STOPPED.
        """
        return SourceStatus.STOPPED

    async def healthcheck(self) -> SourceHealth:
        """Return health information for monitoring.

        The default reports only ``status``. Subclasses should override this
        to provide detailed health info, including message counts, timestamps,
        and any error details.
        """
        return SourceHealth(status=self.status)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class BaseSourceProvider(SourceProvider):
    """Reusable ``SourceProvider`` scaffolding.

    Supplies the bookkeeping most concrete sources need:

    - the current ``SourceStatus`` plus the last error string,
    - a received-message counter and the last-message timestamp,
    - the UTC time at which the source last became CONNECTED,
    - an ``asyncio.Event`` used as a cooperative stop flag.

    Subclasses still implement ``name`` and ``start``.
    """

    def __init__(self) -> None:
        self._status = SourceStatus.STOPPED
        self._error: str | None = None
        self._connected_at: datetime | None = None
        self._last_message_at: datetime | None = None
        self._messages_received: int = 0
        self._stop_event = asyncio.Event()

    @property
    def status(self) -> SourceStatus:
        """Status as tracked by this instance."""
        return self._status

    async def healthcheck(self) -> SourceHealth:
        """Snapshot the tracked status, timestamps and counters."""
        snapshot = {
            "status": self._status,
            "connected_at": self._connected_at,
            "last_message_at": self._last_message_at,
            "messages_received": self._messages_received,
            "error": self._error,
        }
        return SourceHealth(**snapshot)

    def _set_status(self, status: SourceStatus, error: str | None = None) -> None:
        """Record a new status; ``error`` is stored unless the new status is CONNECTED."""
        self._status = status
        if status == SourceStatus.CONNECTED:
            # A fresh connection clears any previous error and stamps the time.
            self._connected_at = datetime.now(UTC)
            self._error = None
        else:
            self._error = error

    def _record_message(self) -> None:
        """Count one received message and stamp its arrival time."""
        self._last_message_at = datetime.now(UTC)
        self._messages_received += 1

    async def stop(self) -> None:
        """Raise the stop flag and report STOPPED."""
        self._stop_event.set()
        self._status = SourceStatus.STOPPED

    def _should_stop(self) -> bool:
        """Whether ``stop()`` has been requested since the last reset."""
        return self._stop_event.is_set()

    def _reset_stop(self) -> None:
        """Lower the stop flag so the source can be started again."""
        self._stop_event.clear()
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""WebSocket event source for RoomKit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import contextlib
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
from collections.abc import Callable
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from roomkit.models.delivery import InboundMessage
|
|
13
|
+
from roomkit.models.event import TextContent
|
|
14
|
+
from roomkit.sources.base import BaseSourceProvider, EmitCallback, SourceStatus
|
|
15
|
+
|
|
16
|
+
# Optional dependency - import for type checking and availability check
|
|
17
|
+
try:
|
|
18
|
+
import websockets
|
|
19
|
+
from websockets import ClientConnection
|
|
20
|
+
|
|
21
|
+
HAS_WEBSOCKETS = True
|
|
22
|
+
except ImportError:
|
|
23
|
+
websockets = None # type: ignore[assignment]
|
|
24
|
+
ClientConnection = None # type: ignore[assignment, misc]
|
|
25
|
+
HAS_WEBSOCKETS = False
|
|
26
|
+
|
|
27
|
+
# Module logger: parse failures and blocked messages are logged at DEBUG,
# while connect/stop events are logged at INFO.
logger = logging.getLogger("roomkit.sources.websocket")
|
|
28
|
+
|
|
29
|
+
# Type alias for message parser
|
|
30
|
+
# A parser takes one raw frame (``str`` or ``bytes``, as returned by
# ``ws.recv()``). It returns the ``InboundMessage`` to emit, or ``None``
# to skip the frame.
MessageParser = Callable[[str | bytes], InboundMessage | None]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def default_json_parser(channel_id: str) -> MessageParser:
    """Create a default JSON message parser.

    Expects messages in format:
        {
            "sender_id": "user123",
            "text": "Hello world",
            "external_id": "msg-456",  # optional
            "metadata": {}  # optional
        }

    The parser returns ``None`` (and logs at DEBUG) instead of raising for the
    following kinds of frame:
    - frames that are not valid UTF-8 or JSON;
    - frames that are not a JSON object;
    - frames that lack ``sender_id``;
    - frames that are rejected when building the ``InboundMessage``.

    This way a single malformed frame cannot terminate the receive loop.

    Args:
        channel_id: Channel ID to use for parsed messages.

    Returns:
        A parser function that converts JSON to InboundMessage.
    """

    def parser(raw: str | bytes) -> InboundMessage | None:
        try:
            if isinstance(raw, bytes):
                raw = raw.decode("utf-8")

            data = json.loads(raw)

            # Skip non-message events (e.g., pings, acks)
            if not isinstance(data, dict) or "sender_id" not in data:
                return None

            return InboundMessage(
                channel_id=channel_id,
                sender_id=data["sender_id"],
                content=TextContent(body=data.get("text", "")),
                external_id=data.get("external_id"),
                metadata=data.get("metadata", {}),
            )
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # It presumably also covers model validation failures (pydantic v2's
        # ValidationError subclasses ValueError) — confirm InboundMessage's base.
        except (ValueError, KeyError, TypeError) as e:
            logger.debug("Failed to parse message: %s", e)
            return None

    return parser
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class WebSocketSource(BaseSourceProvider):
    """WebSocket client source for receiving messages.

    Connects to a WebSocket server and emits parsed messages into RoomKit.

    The source does not reconnect on its own. If connecting fails or the
    connection drops, ``start()`` records ``SourceStatus.ERROR`` with the error
    text and re-raises; any retry policy is left to the caller.

    Example:
        from roomkit import RoomKit
        from roomkit.sources.websocket import WebSocketSource

        # Simple usage with default JSON parser
        source = WebSocketSource(
            url="wss://chat.example.com/events",
            channel_id="websocket-chat",
        )
        await kit.attach_source("websocket-chat", source)

        # Custom parser for non-JSON messages
        def my_parser(raw: str | bytes) -> InboundMessage | None:
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            parts = text.split("|")
            if len(parts) < 2:
                return None
            return InboundMessage(
                channel_id="custom",
                sender_id=parts[0],
                content=TextContent(body=parts[1]),
            )

        source = WebSocketSource(
            url="wss://custom.example.com/stream",
            channel_id="custom",
            parser=my_parser,
        )
    """

    def __init__(
        self,
        url: str,
        channel_id: str,
        *,
        parser: MessageParser | None = None,
        headers: dict[str, str] | None = None,
        subprotocols: list[str] | None = None,
        ping_interval: float | None = 20.0,
        ping_timeout: float | None = 20.0,
        close_timeout: float = 10.0,
        max_size: int = 2**20,  # 1 MiB
        origin: str | None = None,
    ) -> None:
        """Initialize WebSocket source.

        Args:
            url: WebSocket URL to connect to (ws:// or wss://).
            channel_id: Channel ID for emitted messages.
            parser: Function to parse raw messages into InboundMessage.
                If None, uses the default JSON parser.
            headers: Additional HTTP headers for the connection.
            subprotocols: WebSocket subprotocols to request.
            ping_interval: Interval between ping frames in seconds.
                Set to None to disable pings.
            ping_timeout: Timeout for pong response in seconds.
                Set to None to disable the timeout.
            close_timeout: Timeout for close handshake in seconds.
            max_size: Maximum incoming message size in bytes.
            origin: Origin header value.
        """
        super().__init__()
        self._url = url
        self._channel_id = channel_id
        self._parser = parser or default_json_parser(channel_id)
        self._headers = headers or {}
        self._subprotocols = subprotocols
        self._ping_interval = ping_interval
        self._ping_timeout = ping_timeout
        self._close_timeout = close_timeout
        self._max_size = max_size
        self._origin = origin
        # Live connection while start() is running; None otherwise.
        self._ws: ClientConnection | None = None

    @property
    def name(self) -> str:
        return f"websocket:{self._url}"

    async def start(self, emit: EmitCallback) -> None:
        """Connect and receive messages until stopped or the connection fails.

        A single connection is opened and there is no automatic reconnection.

        Args:
            emit: Callback used to hand each parsed message to RoomKit.

        Raises:
            ImportError: If the optional ``websockets`` package is missing.
            Exception: Any connection or receive error. It is re-raised after
                the status has been set to ``SourceStatus.ERROR``.
        """
        if not HAS_WEBSOCKETS:
            raise ImportError(
                "websockets is required for WebSocketSource. "
                "Install it with: pip install roomkit[websocket]"
            )

        self._reset_stop()
        self._set_status(SourceStatus.CONNECTING)

        # None means "disable" for the ping settings, so always pass them
        # through. Dropping them would silently re-enable the library defaults.
        connect_kwargs: dict[str, Any] = {
            "uri": self._url,
            "ping_interval": self._ping_interval,
            "ping_timeout": self._ping_timeout,
            "close_timeout": self._close_timeout,
            "max_size": self._max_size,
        }
        # For these, None just means "not provided", so omit them.
        optional_kwargs = {
            "additional_headers": self._headers or None,
            "subprotocols": self._subprotocols,
            "origin": self._origin,
        }
        connect_kwargs.update({k: v for k, v in optional_kwargs.items() if v is not None})

        try:
            async with websockets.connect(**connect_kwargs) as ws:
                self._ws = ws
                self._set_status(SourceStatus.CONNECTED)
                logger.info("Connected to %s", self._url)

                await self._receive_loop(ws, emit)

        except asyncio.CancelledError:
            # The ``async with`` has closed the connection. Don't leave a
            # stale CONNECTING/CONNECTED status behind.
            self._status = SourceStatus.STOPPED
            raise
        except Exception as e:
            self._set_status(SourceStatus.ERROR, str(e))
            raise
        finally:
            # The connection is closed at this point. Drop the reference so
            # send() cannot reach a dead socket. Also honour a stop() that
            # arrived while connecting, which the CONNECTED update above
            # would otherwise have overwritten.
            self._ws = None
            if self._should_stop():
                self._status = SourceStatus.STOPPED

    async def _receive_loop(
        self,
        ws: ClientConnection,
        emit: EmitCallback,
    ) -> None:
        """Main receive loop - reads messages and emits them.

        ``recv()`` is polled with a 1-second timeout, so the stop flag is
        checked at least once per second. Receive, parse and emit errors
        propagate unless a stop has been requested.
        """
        while not self._should_stop():
            try:
                # Use wait_for to allow checking stop flag periodically
                raw = await asyncio.wait_for(ws.recv(), timeout=1.0)

                # Parse the message
                message = self._parser(raw)
                if message is not None:
                    result = await emit(message)
                    self._record_message()

                    if result.blocked:
                        logger.debug(
                            "Message blocked: %s",
                            result.reason,
                        )

            except TimeoutError:
                # Normal timeout - check stop flag and continue
                continue
            except Exception:
                # Connection error or other issue; expected once stop() closes the socket
                if self._should_stop():
                    break
                raise

    async def stop(self) -> None:
        """Stop receiving and close the connection."""
        await super().stop()

        if self._ws is not None:
            # Best effort: the connection may already be closed or broken.
            with contextlib.suppress(Exception):
                await self._ws.close()
            self._ws = None

        logger.info("WebSocket source stopped")

    async def send(self, message: str | bytes) -> None:
        """Send a message through the WebSocket connection.

        This allows bidirectional communication if needed.

        Args:
            message: Message to send (string or bytes).

        Raises:
            RuntimeError: If not connected.
        """
        if self._ws is None or self._status != SourceStatus.CONNECTED:
            raise RuntimeError("WebSocket not connected")

        await self._ws.send(message)
|
|
File without changes
|