vention_communication-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- communication/__init__.py +0 -0
- communication/app.py +88 -0
- communication/codegen.py +392 -0
- communication/connect_router.py +325 -0
- communication/decorators.py +110 -0
- communication/entries.py +42 -0
- communication/errors.py +59 -0
- communication/typing_utils.py +90 -0
- vention_communication-0.1.0.dist-info/METADATA +302 -0
- vention_communication-0.1.0.dist-info/RECORD +11 -0
- vention_communication-0.1.0.dist-info/WHEEL +4 -0
communication/connect_router.py
ADDED
@@ -0,0 +1,325 @@
from __future__ import annotations

import asyncio
import json
import time
import logging
from dataclasses import dataclass, field
from typing import Any, Dict

from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse, StreamingResponse

from .entries import ActionEntry, StreamEntry
from .errors import error_envelope, to_connect_error

logger = logging.getLogger(__name__)

CONTENT_TYPE = "application/connect+json"


def _frame(payload: Dict[str, Any], *, trailer: bool = False) -> bytes:
    """Create a Connect protocol frame for streaming responses.

    Connect uses a binary framing format for streaming JSON over HTTP. Each frame
    consists of:
    - 1 byte: Flags (0x00 for data, 0x80 for trailer/end-of-stream)
    - 4 bytes: Body length in big-endian format
    - N bytes: JSON-encoded payload
    """
    body = json.dumps(payload, separators=(",", ":"), ensure_ascii=False).encode(
        "utf-8"
    )
    flag = 0x80 if trailer else 0x00
    header = bytes([flag]) + len(body).to_bytes(4, byteorder="big", signed=False)
    return header + body


# ---------- Subscriber ----------


@dataclass(eq=False, unsafe_hash=True)
class _Subscriber:
    queue: asyncio.Queue[Any]
    joined_at: float = field(default_factory=lambda: time.time())
    last_send_at: float = field(default_factory=lambda: time.time())


class StreamManager:
    """Topic-oriented fan-out with a distributor task per stream.

    Supports configurable delivery policies: "latest" (drops old items when queue is full)
    or "fifo" (waits for space to ensure all items are delivered).
    """

    def __init__(self) -> None:
        """Initialize the StreamManager with an empty topic registry."""
        self._topics: Dict[str, Dict[str, Any]] = {}

    def ensure_topic(self, entry: StreamEntry) -> None:
        """Create a topic if it doesn't exist (synchronous, safe before loop)."""
        if entry.name in self._topics:
            return

        publish_queue: asyncio.Queue[Any] = asyncio.Queue()
        subscribers: set[_Subscriber] = set()
        topic = {
            "entry": entry,
            "publish_queue": publish_queue,
            "subscribers": subscribers,
            "last_value": None,
            "task": None,
        }
        self._topics[entry.name] = topic

        # If an event loop is running, start distributor immediately
        try:
            loop = asyncio.get_running_loop()
            topic["task"] = loop.create_task(self._distributor(entry.name))
        except RuntimeError:
            pass

    def start_distributor_if_needed(self, stream_name: str) -> None:
        """Start the distributor task if it doesn't exist or has completed.

        Args:
            stream_name: Name of the stream topic
        """
        topic = self._topics[stream_name]
        if topic["task"] is None or topic["task"].done():
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.get_event_loop()
            topic["task"] = loop.create_task(self._distributor(stream_name))

    def add_stream(self, entry: StreamEntry) -> None:
        """Register a stream entry and ensure its topic exists.

        Args:
            entry: Stream entry to register
        """
        self.ensure_topic(entry)

    def publish(self, stream_name: str, item: Any) -> None:
        """Publish an item to a stream topic.

        Args:
            stream_name: Name of the stream topic
            item: Item to publish

        Raises:
            KeyError: If the stream topic doesn't exist
        """
        topic = self._topics.get(stream_name)
        if topic is None:
            raise KeyError(f"Unknown stream: {stream_name}")
        self.start_distributor_if_needed(stream_name)
        topic["publish_queue"].put_nowait(item)

    def subscribe(self, stream_name: str) -> _Subscriber:
        """Subscribe to a stream topic and return a subscriber.

        If replay is enabled, the last published value will be enqueued
        for the new subscriber.

        Args:
            stream_name: Name of the stream topic

        Returns:
            A subscriber instance with its own queue
        """
        topic = self._topics[stream_name]
        entry: StreamEntry = topic["entry"]
        subscriber_queue: asyncio.Queue[Any] = asyncio.Queue(
            maxsize=entry.queue_maxsize
        )
        subscriber = _Subscriber(queue=subscriber_queue)
        topic["subscribers"].add(subscriber)
        self.start_distributor_if_needed(stream_name)

        if entry.replay and topic["last_value"] is not None:
            try:
                subscriber_queue.put_nowait(topic["last_value"])
            except asyncio.QueueFull:
                _ = subscriber_queue.get_nowait()
                subscriber_queue.put_nowait(topic["last_value"])
        return subscriber

    def unsubscribe(self, stream_name: str, subscriber: _Subscriber) -> None:
        """Remove a subscriber from a stream topic.

        Args:
            stream_name: Name of the stream topic
            subscriber: Subscriber to remove
        """
        topic = self._topics.get(stream_name)
        if topic:
            topic["subscribers"].discard(subscriber)

    async def _distributor(self, stream_name: str) -> None:
        """Distribute published items to all subscribers of a stream.

        For FIFO policy, waits for queue space. For latest-wins policy,
        drops oldest items when queues are full.

        Args:
            stream_name: Name of the stream topic to distribute
        """
        topic = self._topics[stream_name]
        publish_queue: asyncio.Queue[Any] = topic["publish_queue"]
        subscribers: set[_Subscriber] = topic["subscribers"]

        while True:
            item = await publish_queue.get()
            topic["last_value"] = item
            dead_subscribers: list[_Subscriber] = []
            entry: StreamEntry = topic["entry"]

            for subscriber in list(subscribers):
                try:
                    if entry.policy == "fifo":
                        await subscriber.queue.put(item)
                    else:
                        subscriber.queue.put_nowait(item)
                except asyncio.QueueFull:
                    try:
                        _ = subscriber.queue.get_nowait()
                    except Exception:
                        pass
                    try:
                        subscriber.queue.put_nowait(item)
                    except Exception:
                        dead_subscribers.append(subscriber)

            for dead_subscriber in dead_subscribers:
                subscribers.discard(dead_subscriber)


class ConnectRouter:
    """Router for Connect-style RPC endpoints."""

    def __init__(self) -> None:
        """Initialize the ConnectRouter with empty registries."""
        self.router = APIRouter()
        self._unaries: Dict[str, ActionEntry] = {}
        self._streams: Dict[str, StreamEntry] = {}
        self.manager = StreamManager()

    def add_unary(self, entry: ActionEntry, service_fqn: str) -> None:
        """Register a unary RPC action endpoint.

        Args:
            entry: Action entry containing the handler function and types
            service_fqn: Fully qualified service name for the path
        """
        self._unaries[entry.name] = entry
        path = f"/{service_fqn}/{entry.name}"

        @self.router.post(path)
        async def _handler(request: Request) -> JSONResponse:
            try:
                request_body = await request.json()
            except Exception:
                request_body = {}

            try:
                if entry.input_type is None:
                    result = await _maybe_await(entry.func())
                else:
                    input_type = entry.input_type
                    if hasattr(input_type, "model_validate"):
                        validated_arg = input_type.model_validate(request_body)
                    else:
                        validated_arg = request_body
                    result = await _maybe_await(entry.func(validated_arg))

                if hasattr(result, "model_dump"):
                    result = result.model_dump()
                return JSONResponse(result or {})
            except Exception as exc:
                return JSONResponse(error_envelope(exc))

    def add_stream(self, entry: StreamEntry, service_fqn: str) -> None:
        """Register a streaming RPC endpoint.

        Args:
            entry: Stream entry containing configuration
            service_fqn: Fully qualified service name for the path
        """
        self._streams[entry.name] = entry
        self.manager.add_stream(entry)

        path = f"/{service_fqn}/{entry.name}"

        async def _stream_iter(subscriber: _Subscriber) -> Any:
            """Async generator yielding Connect frames for each stream item."""
            try:
                while True:
                    stream_item = await subscriber.queue.get()
                    payload = _serialize_stream_item(stream_item)
                    yield _frame(payload)
                    await asyncio.sleep(0)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                error_trailer = error_envelope(to_connect_error(exc))
                yield _frame(error_trailer, trailer=True)
            finally:
                self.manager.unsubscribe(entry.name, subscriber)

        @self.router.post(path)
        async def _handler(_: Request) -> StreamingResponse:
            self.manager.start_distributor_if_needed(entry.name)
            subscriber = self.manager.subscribe(entry.name)
            logger.debug(f"Stream subscribed: {entry.name}")
            return StreamingResponse(
                _stream_iter(subscriber),
                media_type=CONTENT_TYPE,
                headers={"Transfer-Encoding": "chunked"},
            )

    def publish(self, stream_name: str, item: Any) -> None:
        """Publish an item to a stream.

        Creates the stream topic if it doesn't exist yet.

        Args:
            stream_name: Name of the stream
            item: Item to publish

        Raises:
            KeyError: If the stream doesn't exist and can't be found
        """
        if stream_name not in self.manager._topics:
            entry = self._streams.get(stream_name)
            if not entry:
                raise KeyError(f"Unknown stream: {stream_name}")
            self.manager.add_stream(entry)
        self.manager.publish(stream_name, item)


def _serialize_stream_item(item: Any) -> Dict[str, Any]:
    if hasattr(item, "model_dump"):
        dumped = item.model_dump()
        if isinstance(dumped, dict):
            return dumped
        return {"value": dumped}
    if isinstance(item, dict):
        return item
    if isinstance(item, list):
        return {"value": item}
    return {"value": item}


async def _maybe_await(value: Any) -> Any:
    """Await a value if it's a coroutine or Future, otherwise return it.

    Args:
        value: Value that may or may not be awaitable

    Returns:
        The result of awaiting or the value itself
    """
    if asyncio.iscoroutine(value) or isinstance(value, asyncio.Future):
        return await value
    return value

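For orientation, here is a self-contained sketch (not part of the wheel) of the wire layout that _frame produces; the payload values are made up:

import json

def frame(payload, *, trailer=False):
    # mirrors _frame above: 1 flag byte, 4-byte big-endian length, JSON body
    body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
    return bytes([0x80 if trailer else 0x00]) + len(body).to_bytes(4, "big") + body

data_frame = frame({"value": 42})                 # b'\x00\x00\x00\x00\x0c{"value":42}'
end_frame = frame({"error": None}, trailer=True)  # flag byte 0x80 marks end-of-stream

assert data_frame[0] == 0x00
assert int.from_bytes(data_frame[1:5], "big") == len(data_frame) - 5
assert end_frame[0] == 0x80

And a hypothetical wiring of ConnectRouter onto a FastAPI app; the service and action names are invented, and the package's own app module presumably performs the equivalent registration:

from fastapi import FastAPI
from communication.connect_router import ConnectRouter
from communication.entries import ActionEntry

def ping() -> dict:
    return {"ok": True}

connect_router = ConnectRouter()
connect_router.add_unary(ActionEntry("Ping", ping, None, dict), "vention.example.Robot")

app = FastAPI()
app.include_router(connect_router.router)
# POST /vention.example.Robot/Ping -> {"ok": true}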
communication/decorators.py
ADDED
@@ -0,0 +1,110 @@
from __future__ import annotations
from typing import Any, Callable, Literal, Optional, Type, List, TYPE_CHECKING

if TYPE_CHECKING:
    from .app import VentionApp
    from .entries import RpcBundle

from .entries import ActionEntry, StreamEntry
from .typing_utils import infer_input_type, infer_output_type, is_pydantic_model

_actions: List[ActionEntry] = []
_streams: List[StreamEntry] = []
_GLOBAL_APP: Optional["VentionApp"] = None


def set_global_app(app: Any) -> None:
    """Set the global app instance for use by stream publishers.

    Args:
        app: VentionApp instance to make globally available
    """
    global _GLOBAL_APP
    _GLOBAL_APP = app


def action(
    name: Optional[str] = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator to register a function as an RPC action.

    Args:
        name: Optional name for the action. If not provided, uses the function name.

    Returns:
        Decorator function that registers the action
    """

    def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
        input_type = infer_input_type(function)
        output_type = infer_output_type(function)
        entry = ActionEntry(
            name or function.__name__, function, input_type, output_type
        )
        _actions.append(entry)
        return function

    return decorator


def stream(
    name: str,
    *,
    payload: Type[Any],
    replay: bool = True,
    queue_maxsize: int = 1,
    policy: Literal["latest", "fifo"] = "latest",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Register a server-broadcast stream.

    The decorated function becomes a publisher that publishes its return value
    to the stream when called.

    Args:
        name: Name of the stream
        payload: Type of the payload (Pydantic model or JSON-serializable type)
        replay: Whether to replay the last value to new subscribers
        queue_maxsize: Maximum size of the per-subscriber queue
        policy: Delivery policy, either "latest" or "fifo"

    Returns:
        Decorator function that registers the stream
    """
    if not (is_pydantic_model(payload) or payload in (int, float, str, bool, dict)):
        raise ValueError(
            "payload must be a pydantic BaseModel or a JSON-serializable scalar/dict"
        )

    def decorator(function: Callable[..., Any]) -> Callable[..., Any]:
        entry = StreamEntry(
            name=name,
            func=None,
            payload_type=payload,
            replay=replay,
            queue_maxsize=queue_maxsize,
            policy=policy,
        )
        _streams.append(entry)

        async def publisher_wrapper(*args: Any, **kwargs: Any) -> Any:
            if _GLOBAL_APP is None or _GLOBAL_APP.connect_router is None:
                raise RuntimeError("Stream publish called before app.finalize()")
            result = await function(*args, **kwargs)
            _GLOBAL_APP.connect_router.publish(name, result)
            return None

        entry.func = publisher_wrapper
        return publisher_wrapper

    return decorator


def collect_bundle() -> RpcBundle:
    """Collect all registered actions and streams into an RpcBundle.

    Returns:
        RpcBundle containing all actions and streams registered via decorators
    """
    from .entries import RpcBundle

    return RpcBundle(actions=list(_actions), streams=list(_streams))

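A hypothetical usage sketch of the decorators; the Pose model and names are invented, and the import path assumes the module layout listed above:

from pydantic import BaseModel
from communication.decorators import action, stream, collect_bundle

class Pose(BaseModel):
    x: float
    y: float

@action()                       # registered under the function name "get_pose"
def get_pose() -> Pose:
    return Pose(x=0.0, y=0.0)

@stream("pose_updates", payload=Pose, policy="latest")
async def publish_pose(pose: Pose) -> Pose:
    return pose                 # the return value is published to "pose_updates"

bundle = collect_bundle()
print([a.name for a in bundle.actions])   # ['get_pose']
print([s.name for s in bundle.streams])   # ['pose_updates']

Per publisher_wrapper above, awaiting publish_pose() before the global app is set raises RuntimeError.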
communication/entries.py
ADDED
@@ -0,0 +1,42 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable, Literal, Optional, Type


@dataclass
class ActionEntry:
    """Entry for a unary RPC action."""

    name: str
    func: Callable[..., Any]
    input_type: Optional[Type[Any]]
    output_type: Optional[Type[Any]]


@dataclass
class StreamEntry:
    """Entry for a streaming RPC."""

    name: str
    func: Optional[Callable[..., Any]]
    payload_type: Type[Any]
    replay: bool = True
    queue_maxsize: int = 1
    policy: Literal["latest", "fifo"] = "latest"


@dataclass
class RpcBundle:
    """Bundle of RPC actions and streams."""

    actions: list[ActionEntry] = field(default_factory=list)
    streams: list[StreamEntry] = field(default_factory=list)

    def extend(self, other: "RpcBundle") -> None:
        """Extend this bundle with actions and streams from another bundle.

        Args:
            other: RPC bundle to merge into this one
        """
        self.actions.extend(other.actions)
        self.streams.extend(other.streams)

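A minimal sketch of building and merging bundles by hand (entry values are illustrative):

from communication.entries import ActionEntry, StreamEntry, RpcBundle

base = RpcBundle(actions=[ActionEntry("ping", lambda: "pong", None, str)])
extra = RpcBundle(streams=[StreamEntry("ticks", None, int, queue_maxsize=8, policy="fifo")])

base.extend(extra)
assert [a.name for a in base.actions] == ["ping"]
assert [s.name for s in base.streams] == ["ticks"]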
communication/errors.py
ADDED
@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional


class ConnectError(Exception):
    """Application-level error to send over Connect transport."""

    def __init__(
        self, code: str, message: str, *, details: Optional[List[Any]] = None
    ) -> None:
        super().__init__(message)
        self.code = code
        self.message = message
        self.details = details or []


def to_connect_error(exception: BaseException) -> ConnectError:
    """Map arbitrary exceptions to a ConnectError.

    Args:
        exception: Exception to convert

    Returns:
        ConnectError with appropriate error code based on exception type
    """
    if isinstance(exception, ConnectError):
        return exception

    import asyncio

    if isinstance(exception, asyncio.TimeoutError):
        return ConnectError("deadline_exceeded", str(exception) or "Deadline exceeded")

    if isinstance(exception, (KeyError, ValueError)):
        return ConnectError("invalid_argument", str(exception) or "Invalid argument")

    if isinstance(exception, PermissionError):
        return ConnectError("permission_denied", str(exception) or "Permission denied")

    return ConnectError("internal", str(exception) or exception.__class__.__name__)


def error_envelope(exception: BaseException) -> Dict[str, Any]:
    """Wrap an exception in a Connect error envelope format.

    Args:
        exception: Exception to wrap

    Returns:
        Dictionary with error code, message, and details in Connect format
    """
    connect_error = to_connect_error(exception)
    return {
        "error": {
            "code": connect_error.code,
            "message": connect_error.message,
            "details": connect_error.details,
        }
    }

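A quick sketch of the resulting envelopes (the exception messages are invented):

from communication.errors import ConnectError, error_envelope

print(error_envelope(ValueError("joint index out of range")))
# {'error': {'code': 'invalid_argument', 'message': 'joint index out of range', 'details': []}}

print(error_envelope(ConnectError("unavailable", "robot controller offline")))
# {'error': {'code': 'unavailable', 'message': 'robot controller offline', 'details': []}}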
communication/typing_utils.py
ADDED
@@ -0,0 +1,90 @@
from __future__ import annotations
import inspect
from typing import Any, Optional, Type, get_type_hints, cast

from pydantic import BaseModel


class TypingError(Exception):
    """Raised when type inference fails."""


def _strip_self(params: list[inspect.Parameter]) -> list[inspect.Parameter]:
    if not params:
        return params
    first = params[0]
    if first.name in ("self", "cls"):
        return params[1:]
    return params


def infer_input_type(function: Any) -> Optional[Type[Any]]:
    """Infer the input type annotation from a function's first parameter.

    Args:
        function: Function to inspect

    Returns:
        Type annotation of the first parameter, or None if no parameters or type is Any

    Raises:
        TypingError: If the first parameter lacks a type annotation
    """
    signature = inspect.signature(function)
    parameters = _strip_self(list(signature.parameters.values()))
    if not parameters:
        return None

    first_param = parameters[0]
    type_hints = get_type_hints(function)
    if first_param.name not in type_hints:
        raise TypingError(f"First parameter '{first_param.name}' must be annotated")

    hint_type = type_hints[first_param.name]
    if hint_type is Any:
        return None

    if isinstance(hint_type, type):
        return cast(Type[Any], hint_type)

    return None


def infer_output_type(function: Any) -> Optional[Type[Any]]:
    """Infer the return type annotation from a function.

    Args:
        function: Function to inspect

    Returns:
        Return type annotation, or None if not annotated or type is Any
    """
    type_hints = get_type_hints(function)
    return_type = type_hints.get("return")
    if return_type is None or return_type is Any:
        return None

    if isinstance(return_type, type) and return_type in (type(None),):
        return None

    if isinstance(return_type, type):
        return cast(Type[Any], return_type)

    return None


def is_pydantic_model(type_annotation: Any) -> bool:
    """Check if a type annotation is a Pydantic BaseModel.

    Args:
        type_annotation: Type to check

    Returns:
        True if the type is a Pydantic BaseModel subclass, False otherwise
    """
    try:
        return isinstance(type_annotation, type) and issubclass(
            type_annotation, BaseModel
        )
    except Exception:
        return False
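A sketch of the inference rules on an invented handler (assumes the package is importable as communication):

from pydantic import BaseModel
from communication.typing_utils import infer_input_type, infer_output_type, is_pydantic_model

class MoveRequest(BaseModel):
    distance_mm: float

def move(req: MoveRequest) -> bool:
    return req.distance_mm > 0

def status() -> dict:
    return {"state": "idle"}

assert infer_input_type(move) is MoveRequest    # first annotated parameter
assert infer_output_type(move) is bool          # return annotation
assert infer_input_type(status) is None         # no parameters
assert is_pydantic_model(MoveRequest) and not is_pydantic_model(dict)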