delphai-rpc 4.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
+ Metadata-Version: 2.1
2
+ Name: delphai-rpc
3
+ Version: 4.1.6
4
+ Summary: Queue-based RPC client and server
5
+ Author: Anton Ryzhov
6
+ Author-email: anton@delphai.com
7
+ Requires-Python: >=3.8,<4.0
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.8
10
+ Classifier: Programming Language :: Python :: 3.9
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Requires-Dist: aio-pika (>=9.1.4,<10.0.0)
15
+ Requires-Dist: msgpack (>=1.0.5,<2.0.0)
16
+ Requires-Dist: prometheus-client (>=0.20.0,<0.21.0)
17
+ Requires-Dist: pydantic (>=2.2,<3.0)
18
+ Description-Content-Type: text/markdown
19
+
20
+ # delphai-rpc
21
+ Queue-based RPC client and server
22
+
@@ -0,0 +1,2 @@
1
+ # delphai-rpc
2
+ Queue-based RPC client and server
@@ -0,0 +1,5 @@
1
+ from .client import Options, RpcClient
2
+ from .server import RpcServer
3
+
4
+
5
# Public API of the package: the RPC client (with its per-call options) and
# the RPC server.
__all__ = ["Options", "RpcClient", "RpcServer"]
@@ -0,0 +1,331 @@
1
+ import aio_pika
2
+ import aio_pika.exceptions
3
+ import asyncio
4
+ import functools
5
+ import logging
6
+ import socket
7
+ import time
8
+ import uuid
9
+ import weakref
10
+
11
+ from aio_pika.abc import (
12
+ AbstractChannel,
13
+ AbstractExchange,
14
+ AbstractRobustConnection,
15
+ AbstractQueue,
16
+ )
17
+ from typing import Any, Dict, Optional
18
+
19
+ from . import errors
20
+ from . import metrics
21
+ from .connection_manager import get_connection
22
+ from .models import Request, Response
23
+ from .server import request_context
24
+ from .types import AbstractOptions, IncomingMessage, Message, Priority
25
+ from .utils import clean_service_name, fix_message_timestamp
26
+
27
+
28
# Module-level logger for client-side warnings and debug traces.
logger = logging.getLogger(__name__)
29
+
30
+
31
class Options(AbstractOptions):
    """Per-call options for RPC requests.

    Instances are frozen pydantic models; derive modified copies with
    ``update()``.
    """

    # Seconds to wait for the response; also used as the request message
    # expiration. A falsy value (None or 0) disables both.
    timeout: Optional[float] = 60
    # Message priority. When None, `RpcClient.call` inherits it from the
    # server request currently being handled, or falls back to
    # `Priority.DEFAULT`.
    priority: Optional[Priority] = None
    # Fire-and-forget: publish without reply_to/correlation_id and return None.
    no_wait: bool = False
35
+
36
+
37
class RpcClient:
    """Queue-based RPC client.

    Requests are published to the ``service.<service_name>`` exchange with the
    routing key ``method.<method_name>``. Responses arrive on a private,
    exclusive, auto-deleted reply queue and are matched to the waiting caller
    by ``correlation_id``.

    Services and methods are reachable through attribute or item access::

        result = await client.some_service.some_method(a=1, b=2)
        result = await client["some_service"]["some_method"](a=1, b=2)
    """

    def __init__(
        self, client_name: str, connection_string: str, *args: Options, **kwargs: Any
    ) -> None:
        """
        Args:
            client_name: Name of this client. It is normalized with
                ``clean_service_name`` and used in the reply queue name, the
                ``app_id`` and the connection name.
            connection_string: AMQP connection URL.
            *args, **kwargs: Default ``Options`` for calls made through
                ``get_service``.
        """
        self._client_name = clean_service_name(client_name)
        self._connection_string = connection_string
        self._app_id = f"{self._client_name}@{socket.gethostname()}"

        self._options = Options(*args, **kwargs)
        self._client_lock = asyncio.Lock()
        self._declare_exchange_lock = asyncio.Lock()

        self._reset()

    def _reset(self) -> None:
        """Forget the connection, channel, reply queue, exchanges and futures."""
        self._connection: Optional[AbstractRobustConnection] = None
        self._channel: Optional[AbstractChannel] = None
        self._reply_queue: Optional[AbstractQueue] = None

        self._exchanges: Dict[str, AbstractExchange] = {}
        # Weak values: an entry disappears once the caller of `call` no longer
        # references its future (e.g. after a timeout), so abandoned calls
        # don't accumulate here.
        self._futures: weakref.WeakValueDictionary[str, asyncio.Future] = (
            weakref.WeakValueDictionary()
        )

    def get_service(
        self, service_name: str, *args: Options, **kwargs: Any
    ) -> "RpcService":
        """Return a proxy for ``service_name``.

        The client's default options are overridden by ``args``/``kwargs``.
        """
        options = self._options.update(*args, **kwargs)
        return RpcService(self, clean_service_name(service_name), options)

    # Item access and any attribute not found normally resolve to a service
    # proxy, e.g. `client.billing` or `client["billing"]`.
    __getitem__ = __getattr__ = get_service

    async def _ensure_connection(self) -> None:
        """Open the shared connection and a channel once, then reuse them."""
        async with self._client_lock:
            if self._connection and self._channel:
                return

            self._connection = await get_connection(
                self._connection_string, self._client_name
            )
            # Returned (unroutable) publishes raise on this channel;
            # `_send_request` maps PublishError to UnknownServiceError.
            self._channel = await self._connection.channel(on_return_raises=True)

    async def _ensure_reply_queue(self) -> None:
        """Declare the reply queue and start consuming it on first use."""
        await self._ensure_connection()

        assert self._channel

        async with self._client_lock:
            if self._reply_queue:
                return

            queue_name = f"client.{self._client_name}.{uuid.uuid4().hex}"
            queue = await self._channel.declare_queue(
                name=queue_name,
                exclusive=True,
                auto_delete=True,
            )
            await queue.consume(fix_message_timestamp(self._on_message))

            self._reply_queue = queue

    async def stop(self) -> None:
        """Close the channel and fail every call still awaiting a response.

        The connection itself is shared (see ``get_connection``) and is left
        open.
        """
        channel = self._channel
        if not channel:
            return

        pending = list(self._futures.values())
        self._reset()

        # No response can be received once the channel is closed. Release the
        # waiting callers now instead of letting them hang until their timeout,
        # or forever when no timeout is set.
        for future in pending:
            if not future.done():
                future.set_exception(errors.TemporaryError("RPC client was stopped"))

        await channel.close()

    async def call(
        self,
        service_name: str,
        method_name: str,
        arguments: Optional[Dict[str, Any]] = None,
        options: Optional[Options] = None,
    ) -> Any:
        """Call ``method_name`` of ``service_name`` and return its result.

        When called while an ``RpcServer`` request is being handled, the
        request's priority is inherited (unless set explicitly) and the
        timeout is capped by that request's deadline.

        Returns:
            The remote result, or ``None`` when ``options.no_wait`` is set.

        Raises:
            errors.UnknownServiceError: The service exchange was not found or
                publishing failed.
            errors.RpcError: The remote error class named in the response
                (``errors.UnknownError`` when the name is not recognized).
            asyncio.TimeoutError: ``options.timeout`` elapsed.
        """
        options = options or Options()

        current_request_context = request_context.get()
        if current_request_context:
            if options.priority is None:
                options = options.update(priority=current_request_context.priority)

            if current_request_context.deadline is not None:
                # Never wait longer than the request being served may run.
                timeout = current_request_context.deadline - time.time()
                if options.timeout is not None:
                    timeout = min(options.timeout, timeout)

                options = options.update(timeout=timeout)

        if options.priority is None:
            options = options.update(priority=Priority.DEFAULT)

        if options.timeout:
            # A single deadline covers both publishing and awaiting the reply.
            deadline = time.monotonic() + options.timeout
            waiter = lambda coro: asyncio.wait_for(  # noqa: E731
                coro, timeout=(deadline - time.monotonic())
            )
        else:
            waiter = lambda coro: coro  # noqa: E731

        request = Request(method_name=method_name, arguments=arguments or {})

        if options.no_wait:
            correlation_id = None
            future = None
        else:
            correlation_id = str(uuid.uuid1())
            future = asyncio.get_running_loop().create_future()
            self._futures[correlation_id] = future

        with metrics.client_requests_in_progress.labels(
            service=service_name,
            method=method_name,
        ).track_inprogress():
            elapsed = -time.perf_counter()
            await waiter(
                self._send_request(service_name, request, options, correlation_id)
            )

            try:
                if future:
                    return await waiter(future)
            except asyncio.CancelledError:
                logger.warning(
                    "Wait was cancelled but not the request itself. "
                    "Pass `timeout` option instead of using `asyncio.wait_for` or similar"
                )
                raise
            finally:
                elapsed += time.perf_counter()

                metrics.client_request_processed(
                    priority=options.priority or 0,
                    service=service_name,
                    method=method_name,
                    elapsed=elapsed,
                )

    async def _send_request(
        self,
        service_name: str,
        request: Request,
        options: Options,
        correlation_id: Optional[str] = None,
    ):
        """Publish ``request`` to the exchange of ``service_name``.

        When ``correlation_id`` is given, the message also carries ``reply_to``
        pointing at this client's reply queue.

        Raises:
            errors.UnknownServiceError: The exchange was not found, or
                publishing raised ``PublishError``.
        """
        await self._ensure_connection()

        assert self._channel

        request_message = Message(
            body=request,
            app_id=self._app_id,
            priority=options.priority,
            expiration=options.timeout or None,
            type="rpc.request",
        )

        if correlation_id:
            await self._ensure_reply_queue()
            assert self._reply_queue
            request_message.correlation_id = correlation_id
            request_message.reply_to = self._reply_queue.name

        async with self._declare_exchange_lock:
            if service_name not in self._exchanges:
                try:
                    self._exchanges[service_name] = await self._channel.get_exchange(
                        f"service.{service_name}"
                    )
                except aio_pika.exceptions.ChannelNotFoundEntity:
                    raise errors.UnknownServiceError("Exchange was not found")

        try:
            routing_key = f"method.{request.method_name}"
            await self._exchanges[service_name].publish(
                message=request_message,
                routing_key=routing_key,
            )
            metrics.message_published(
                exchange=self._exchanges[service_name].name,
                routing_key=routing_key,
                type=request_message.type or "",
                priority=request_message.priority or 0,
                payload_size=request_message.body_size,
            )

        except aio_pika.exceptions.PublishError:
            raise errors.UnknownServiceError("Request was not delivered to a queue")

    async def _on_message(self, message: IncomingMessage) -> None:
        """Handle a message from the reply queue.

        Resolves the matching future, then acks. Messages with the wrong type,
        without ``correlation_id``, or without a pending caller are rejected.
        """
        message_consumed = functools.partial(
            metrics.message_consumed,
            exchange=message.exchange,
            routing_key="",  # random, causes high cardinality metric
            type=message.type,
            priority=message.priority,
            redelivered=message.redelivered,
            payload_size=message.body_size,
        )

        if message.type != "rpc.response":
            logger.warning(
                "[MID:%s] Unexpected message type: `%s`",
                message.message_id,
                message.type,
            )
            await message.reject()
            message_consumed(error="WRONG_MESSAGE_TYPE")
            return

        if not message.correlation_id:
            logger.warning("[MID:%s] `correlation_id` is not set", message.message_id)
            await message.reject()
            message_consumed(error="NO_CORRELATION_ID")
            return

        future = self._futures.pop(message.correlation_id, None)
        if not future or future.done():
            logger.warning(
                "[MID:%s] [CID:%s] Response is not awaited (too late or duplicate)",
                message.message_id,
                message.correlation_id,
            )
            await message.reject()
            message_consumed(error="UNKNOWN_CORRELATION_ID")
            return

        try:
            future.set_result(await self._process_message(message))
        except Exception as error:
            future.set_exception(error)

        await message.ack()
        message_consumed()

        logger.debug(
            "[MID:%s] [CID:%s] Got `%s` from `%s` service",
            message.message_id,
            message.correlation_id,
            message.type,
            message.app_id or "unknown",
        )

    async def _process_message(self, message: IncomingMessage) -> Any:
        """Decode a response and return its result, or raise the remote error."""
        response = Response.model_validate_message(message)

        if response.error:
            # The error type name comes from the remote side. Only map it to
            # RpcError subclasses defined in `errors`. Any other module
            # attribute (e.g. `__name__`) falls back to UnknownError instead
            # of being called.
            error_class = getattr(errors, response.error.type, None)
            if not (
                isinstance(error_class, type) and issubclass(error_class, errors.RpcError)
            ):
                error_class = errors.UnknownError
            raise error_class(response.error.message)

        return response.result
287
+
288
+
289
class RpcService:
    """Proxy for one remote service.

    Attribute or item access yields ``RpcMethod`` proxies that share this
    service's options.
    """

    def __init__(self, client: RpcClient, service_name: str, options: Options):
        self._client, self._service_name, self._options = client, service_name, options

    def get_method(
        self, method_name: str, *args: Options, **kwargs: Dict[str, Any]
    ) -> "RpcMethod":
        """Return a callable proxy for ``method_name``.

        The service options are overridden by ``args``/``kwargs``.
        """
        return RpcMethod(
            self._client,
            self._service_name,
            method_name,
            self._options.update(*args, **kwargs),
        )

    __getattr__ = get_method
    __getitem__ = get_method

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return "<%s `%s`>" % (name, self._service_name)
306
+
307
+
308
class RpcMethod:
    """Callable proxy for one method of a remote service.

    Calling it returns the ``RpcClient.call`` coroutine. Positional arguments
    are ``Options`` overrides; keyword arguments become the RPC arguments.
    """

    def __init__(
        self,
        client: "RpcClient",
        service_name: str,
        method_name: str,
        options: "Options",
    ) -> None:
        self._client = client
        self._service_name = service_name
        self._method_name = method_name
        self._options = options

    def __repr__(self) -> str:
        class_ = self.__class__
        # Both halves must be f-strings, otherwise the service name is
        # rendered as the literal text "{self._service_name}".
        return (
            f"<{class_.__qualname__} `{self._method_name}` "
            f"of service `{self._service_name}`>"
        )

    def __call__(self, *args: "Options", **kwargs: Any) -> Any:
        options = self._options.update(*args)
        return self._client.call(
            service_name=self._service_name,
            method_name=self._method_name,
            arguments=kwargs,
            options=options,
        )
@@ -0,0 +1,38 @@
1
+ import aio_pika.exceptions
2
+ import asyncio
3
+ import importlib.metadata
4
+ import logging
5
+ import weakref
6
+
7
+ from aio_pika.connection import URL
8
+ from aio_pika.robust_connection import AbstractRobustConnection
9
+ from typing import Union
10
+
11
+
12
logger = logging.getLogger(__name__)

# Serializes connection creation so concurrent callers share one connection.
# NOTE(review): on Python 3.8/3.9 asyncio.Lock() binds to the event loop that
# is current when it is created — confirm only one event loop is used.
_connect_lock = asyncio.Lock()
# Connection cache keyed by the final URL (including the `name` query
# parameter). With weak values, an entry disappears once nothing references
# the connection.
# The annotation is quoted because a module-level annotation is evaluated at
# import time, and `WeakValueDictionary[...]` is not subscriptable before
# Python 3.9 (the package supports 3.8).
_connections: "weakref.WeakValueDictionary[URL, AbstractRobustConnection]" = (
    weakref.WeakValueDictionary()
)
18
+
19
+
20
async def get_connection(
    connection_string: Union[str, URL, None], service_name: str = "?"
) -> AbstractRobustConnection:
    """Return a robust connection shared per URL within the process.

    Unless the URL already has a ``name`` query parameter, one is added with
    the value ``"<service_name> (<package> v<version>)"``, presumably so the
    connection can be identified on the broker. Connections are cached by the
    resulting URL, so every client/server using the same URL and name shares
    one connection.

    Args:
        connection_string: AMQP URL (string or ``yarl.URL``).
        service_name: Label placed in the connection name.
    """
    connection_url = aio_pika.connection.make_url(connection_string)
    if not connection_url.query.get("name"):
        package_name = __package__.split(".")[0]
        package_name_version = importlib.metadata.version(package_name)

        connection_url %= {
            "name": f"{service_name} ({package_name} v{package_name_version})"
        }

    async with _connect_lock:
        connection = _connections.get(connection_url)
        # A cached connection that was explicitly closed will never come back.
        # Open a fresh one instead of handing out a dead connection.
        if connection is None or connection.is_closed:
            connection = await aio_pika.connect_robust(connection_url)
            _connections[connection_url] = connection

    return connection
@@ -0,0 +1,34 @@
1
class RpcError(Exception):
    """Root of every error raised by the RPC layer.

    Remote errors travel by class name and are re-raised on the client as the
    class with that name in this module.
    """


class TemporaryError(RpcError):
    """Category for transient failures (presumably worth retrying)."""


class FinalError(RpcError):
    """Category for failures that are not expected to go away on retry."""


class UnknownError(FinalError):
    """Remote error whose type name is not recognized by the client."""


class UnhandledError(FinalError):
    """Reported when a server-side step raised a non-``RpcError`` exception."""


class ParsingError(FinalError):
    """A message could not be decompressed, decoded or validated."""


class UnknownServiceError(FinalError):
    """The target service exchange is missing or the request was not delivered."""


class UnknownMethodError(FinalError):
    """No handler is bound under the requested method name."""


class ExecutionError(FinalError):
    """The bound handler raised; the message holds the original exception repr."""
@@ -0,0 +1,223 @@
1
+ import importlib.metadata
2
+
3
+ from .models import ResponseError
4
+
5
+ from prometheus_client import Counter, Gauge, Histogram, Info, Summary
6
+ from typing import Optional
7
+
8
+
9
# "application" info metric; its labels are filled by `set_application_info`.
application_info = Info("application", "Application info")
10
+
11
+
12
def set_application_info(**kwargs) -> None:
    """Publish the ``application`` info metric.

    The RPC package name and its installed version are always included. Extra
    labels are given as keyword arguments and may override them.

    NOTE(review): the label keys are spelled ``mertics_*`` (sic). They are
    exported metric labels, so renaming them would break existing queries.
    Fix only deliberately.
    """
    package = __package__.partition(".")[0]
    labels = dict(
        mertics_package=package,
        mertics_package_version=importlib.metadata.version(package),
    )
    labels.update(kwargs)
    application_info.info(labels)
21
+
22
+
23
+ set_application_info()
24
+
25
+
26
# Published messages, labelled by exchange, routing key, type and priority.
messages_published_count = Counter(
    name="queue_rpc_messages_published_total",
    documentation="Total number of published messages",
    labelnames=["exchange", "routing_key", "type", "priority"],
)

# Body size (bytes) of published messages; same labels except priority.
messages_published_payload_size = Summary(
    name="queue_rpc_messages_published_payload_size_bytes",
    documentation="Payload size of published messages",
    labelnames=["exchange", "routing_key", "type"],
)
37
+
38
+
39
def message_published(
    *, exchange: str, routing_key: str, type: str, priority: int, payload_size: int
) -> None:
    """Record one published message: bump the counter and observe its size."""
    common = {"exchange": exchange, "routing_key": routing_key, "type": type}
    messages_published_count.labels(**common, priority=priority).inc()
    messages_published_payload_size.labels(**common).observe(payload_size)
49
+
50
+
51
# Consumed messages, including rejected ones; `error` holds a short reason
# code such as "WRONG_MESSAGE_TYPE" ("" when the message was handled).
messages_consumed_count = Counter(
    name="queue_rpc_messages_consumed_total",
    documentation="Total number of consumed messages",
    labelnames=["exchange", "routing_key", "redelivered", "type", "priority", "error"],
)

# Body size (bytes) of consumed messages; no priority/error labels.
messages_consumed_payload_size = Summary(
    name="queue_rpc_messages_consumed_payload_size_bytes",
    documentation="Payload size of consumed messages",
    labelnames=["exchange", "routing_key", "redelivered", "type"],
)
62
+
63
+
64
def message_consumed(
    *,
    exchange: str,
    routing_key: str,
    redelivered: bool,
    type: str,
    priority: int,
    payload_size: int,
    error: Optional[str] = None,
) -> None:
    """Record one consumed message.

    ``error`` is a short reason code; a missing or empty value is exported as
    an empty label.
    """
    common = {
        "exchange": exchange,
        "routing_key": routing_key,
        "redelivered": redelivered,
        "type": type,
    }
    error_label = error if error else ""
    messages_consumed_count.labels(priority=priority, error=error_label, **common).inc()
    messages_consumed_payload_size.labels(**common).observe(payload_size)
82
+
83
+
84
# Requests handled by the server, by priority, method and error type
# ("" on success).
server_requests_count = Counter(
    name="queue_rpc_server_requests_total",
    documentation="Total number of requests",
    labelnames=["priority", "method", "error"],
)

# Requests currently being executed, per method.
server_requests_in_progress = Gauge(
    name="queue_rpc_server_requests_in_progress",
    documentation="Number of requests in progress",
    labelnames=["method"],
)

# Time between publishing and consumption of a request; buckets range from
# 0.1 s to 30 h.
server_request_waiting_time = Histogram(
    name="queue_rpc_server_request_waiting_seconds",
    documentation="Time request spent in queue",
    labelnames=["priority", "method"],
    buckets=(
        0.1, 0.3, 0.5, 1, 3, 5, 10, 30,  # seconds
        1 * 60, 3 * 60, 5 * 60, 10 * 60, 30 * 60,  # minutes
        1 * 3600, 3 * 3600, 5 * 3600, 10 * 3600, 30 * 3600,  # hours
    ),
)

# Handler execution time; buckets range from 0.1 s to 2500 s.
server_request_processing_time = Histogram(
    name="queue_rpc_server_request_processing_seconds",
    documentation="Time spent processing request",
    labelnames=["method"],
    buckets=(
        0.1, 0.25, 0.5, 0.75,
        1.0, 2.5, 5.0, 7.5,
        10, 25, 50, 75,
        100, 250, 500, 750,
        1000, 2500,
    ),
)
147
+
148
+
149
def server_request_processed(
    priority: int,
    method: str,
    error: Optional[ResponseError],
    queued_for: Optional[float],
    elapsed: float,
) -> None:
    """Record metrics for one request handled by the server.

    Args:
        priority: Message priority (used as a label).
        method: Called method name.
        error: Error of the response, if any. Its ``type`` becomes the
            ``error`` label ("" on success).
        queued_for: Seconds the request spent in the queue, or ``None`` when
            unknown. In that case the waiting-time histogram is not updated.
        elapsed: Seconds spent executing the handler.
    """
    server_requests_count.labels(
        priority=priority,
        method=method,
        error=error.type if error else "",
    ).inc()

    if queued_for is not None:
        server_request_waiting_time.labels(
            priority=priority,
            method=method,
        ).observe(queued_for)

    server_request_processing_time.labels(
        method=method,
    ).observe(elapsed)
171
+
172
+
173
# Calls completed by the client, by priority, target service and method.
client_requests_count = Counter(
    name="queue_rpc_client_requests_total",
    documentation="Total number of requests",
    labelnames=["priority", "service", "method"],
)

# Client calls currently in flight, per service and method.
client_requests_in_progress = Gauge(
    name="queue_rpc_client_requests_in_progress",
    documentation="Number of requests in progress",
    labelnames=["service", "method"],
)

# Client-observed call latency; buckets range from 0.1 s to 30 h.
client_request_time = Histogram(
    name="queue_rpc_client_request_seconds",
    documentation="Time request took",
    labelnames=["priority", "service", "method"],
    buckets=(
        0.1, 0.3, 0.5, 1, 3, 5, 10, 30,  # seconds
        1 * 60, 3 * 60, 5 * 60, 10 * 60, 30 * 60,  # minutes
        1 * 3600, 3 * 3600, 5 * 3600, 10 * 3600, 30 * 3600,  # hours
    ),
)
210
+
211
+
212
def client_request_processed(
    priority: int, service: str, method: str, elapsed: float
) -> None:
    """Record one finished client call: count it and observe its latency.

    A falsy ``priority`` is recorded as 0.
    """
    # Positional label values follow the declared labelnames order,
    # ["priority", "service", "method"], for both metrics.
    label_values = (int(priority or 0), service, method)
    client_requests_count.labels(*label_values).inc()
    client_request_time.labels(*label_values).observe(elapsed)
@@ -0,0 +1,89 @@
1
+ import aio_pika.message
2
+ import functools
3
+ import logging
4
+ import msgpack
5
+ import pydantic
6
+ import zlib
7
+
8
+ from typing import Callable, Dict, Any, List, Optional, Tuple, Type, TypeVar
9
+
10
+ from . import errors
11
+
12
+
13
logger = logging.getLogger(__name__)


# Lets `BaseModel.model_validate_message` be typed as returning the subclass
# it is called on.
TBaseModel = TypeVar("TBaseModel", bound="BaseModel")
17
+
18
+
19
class BaseModel(pydantic.BaseModel):
    """Pydantic model serialized as msgpack, optionally deflate-compressed."""

    @classmethod
    def model_validate_message(
        cls: Type[TBaseModel], message: aio_pika.message.Message
    ) -> "TBaseModel":
        """Decode and validate the body of an AMQP message.

        Supported format: ``content_type == "application/msgpack"``, with
        either no ``content_encoding`` or ``"deflate"`` (zlib).

        Raises:
            errors.ParsingError: Unknown encoding or content type, failed
                decompression, or failed decoding/validation. The original
                exception is chained as ``__cause__``.
        """
        body = message.body
        if message.content_encoding == "deflate":
            try:
                body = zlib.decompress(body)
            except Exception as error:
                raise errors.ParsingError(
                    f"Message decompression failed: `{error!r}`"
                ) from error

        elif message.content_encoding:
            raise errors.ParsingError(
                f"Unknown content_encoding: `{message.content_encoding}`"
            )

        if message.content_type != "application/msgpack":
            raise errors.ParsingError(
                f"Got a message with unknown content type: {message.content_type}"
            )

        try:
            return cls.model_validate(msgpack.loads(body))
        except ValueError as error:
            # pydantic's ValidationError and msgpack's unpack errors are
            # ValueError subclasses.
            raise errors.ParsingError(
                f"Message deserialization failed: `{error!r}`"
            ) from error

    def model_dump_msgpack(self, **kwargs) -> bytes:
        """Serialize to msgpack bytes.

        Fields equal to their defaults are omitted unless
        ``exclude_defaults=False`` is passed.
        """
        kwargs.setdefault("exclude_defaults", True)
        return msgpack.dumps(self.model_dump(**kwargs))
49
+
50
+
51
class Request(BaseModel):
    """RPC request payload sent from client to server."""

    method_name: str
    # Keyword arguments passed to the bound handler.
    arguments: Dict[str, Any] = {}
    # Opaque value that the server copies into `Response.context`.
    context: Optional[Any] = None
    # (label, unix timestamp) checkpoints; the server appends queue and
    # execution checkpoints and returns them in the response.
    timings: List[Tuple[str, float]] = []
56
+
57
+
58
class ResponseError(BaseModel):
    """Error carried in a `Response`."""

    # Name of a class in `errors`. The client re-raises the class with this
    # name and falls back to `UnknownError` for unknown names.
    type: str
    message: Optional[str] = None
61
+
62
+
63
class Response(BaseModel):
    """RPC response payload: ``result`` on success, ``error`` on failure."""

    result: Optional[Any] = None
    error: Optional[ResponseError] = None
    context: Optional[Any] = None
    timings: List[Tuple[str, float]] = []

    @classmethod
    def wrap_errors(cls, func: Callable) -> Callable:
        """Decorate an async function so that it always returns a ``Response``.

        - An ``errors.RpcError`` becomes an error response named after the
          exception class; its first argument (as ``str``) is the message.
        - Any other exception is logged and reported as ``UnhandledError``.
        - A return value that is not a ``Response`` is reported as an
          unhandled ``TypeError``.
        """

        @functools.wraps(func)
        async def inner(*args, **kwargs):
            try:
                response = await func(*args, **kwargs)
                if not isinstance(response, cls):
                    raise TypeError(f"Incorrect response type, got: {type(response)}")

            except errors.RpcError as error:
                # The error may have no arguments (e.g. `TemporaryError()`), or
                # a non-str first argument. Indexing/validating it blindly
                # would raise from inside this handler.
                message = str(error.args[0]) if error.args else None
                response = cls(
                    error={"type": type(error).__name__, "message": message}
                )

            except Exception as error:
                logger.exception("Unhandled error")
                response = cls(error={"type": "UnhandledError", "message": repr(error)})

            return response

        return inner
@@ -0,0 +1,364 @@
1
+ import aio_pika
2
+ import aio_pika.connection
3
+ import asyncio
4
+ import contextvars
5
+ import functools
6
+ import inspect
7
+ import logging
8
+ import pydantic
9
+ import socket
10
+ import time
11
+
12
+ from aio_pika.abc import (
13
+ AbstractChannel,
14
+ AbstractExchange,
15
+ AbstractRobustConnection,
16
+ AbstractQueue,
17
+ )
18
+ from typing import Any, Callable, Dict, Optional, cast
19
+
20
+ from . import errors
21
+ from . import metrics
22
+ from .connection_manager import get_connection
23
+ from .models import Request, Response
24
+ from .types import IncomingMessage, Message, Priority, RequestContext
25
+ from .utils import clean_service_name, fix_message_timestamp
26
+
27
+
28
logger = logging.getLogger(__name__)

# Context (deadline, priority) of the request currently being handled by
# `RpcServer`; None when no request is being handled. `RpcClient.call` reads
# it so nested calls inherit the priority and respect the deadline.
request_context: contextvars.ContextVar[Optional[RequestContext]] = (
    contextvars.ContextVar("request_context", default=None)
)
33
+
34
+
35
class RpcServer:
    """Queue-based RPC server.

    Bound coroutine handlers are exposed as methods of the service. ``start``
    declares a durable topic exchange and a durable priority queue, both named
    ``service.<service_name>``, binds them with ``#`` and consumes requests.
    Two built-in methods are always bound: ``_ping`` and ``_help``.
    """

    def __init__(self, service_name) -> None:
        self._service_name = clean_service_name(service_name)
        self._app_id = f"{self._service_name}@{socket.gethostname()}"

        self._handlers: Dict[str, Callable] = {}

        self._reset()

        self.bind(self._ping)
        self.bind(self._help)

    def _reset(self) -> None:
        """Forget the connection, channel, exchange and queue."""
        self._connection: Optional[AbstractRobustConnection] = None
        self._channel: Optional[AbstractChannel] = None
        self._exchange: Optional[AbstractExchange] = None
        self._queue: Optional[AbstractQueue] = None

    def bind(
        self, handler: Optional[Callable] = None, *, name: Optional[str] = None
    ) -> Callable:
        """
        Binds to be exposed handlers (coroutine functions) to RPC server instance:

            @server.bind
            async def add(*, a: float, b: float) -> float:
                ...

            # or

            async def sub(*, a: float, b: float) -> float:
                ...

            server.bind(sub)

            # or

            @server.bind(name="mul")
            async def multiply(*, a: float, b: float) -> float:
                ...

        The original handler is returned unchanged (or a decorator when no
        handler is given), so it can still be called directly.
        """

        def decorator(handler):
            self._bind_handler(handler=handler, name=name)
            return handler

        return decorator(handler) if handler else decorator

    def _bind_handler(self, *, handler: Callable, name: Optional[str] = None) -> None:
        """Validate ``handler`` and register it under ``name``.

        ``name`` defaults to ``handler.__name__``. The handler is wrapped in
        ``pydantic.validate_call`` so that arguments and return value are
        validated.

        Raises:
            KeyError: A handler with that name is already bound.
            TypeError: The handler is not a coroutine function or takes
                positional-only/``*args`` parameters.
        """
        handler_name = name or handler.__name__
        if handler_name in self._handlers:
            raise KeyError(f"Handler {handler_name} already defined")

        if hasattr(handler, "raw_function"):
            # Unwrap `pydantic.validate_call` decorator
            handler = handler.raw_function

        self._validate_handler(handler)

        self._handlers[handler_name] = pydantic.validate_call(validate_return=True)(
            handler
        )

    def _validate_handler(self, handler: Callable) -> None:
        """Ensure ``handler`` can be called with keyword arguments only.

        Positional-or-keyword parameters are allowed but produce a warning,
        because requests only carry keyword arguments.
        """
        if not inspect.iscoroutinefunction(handler):
            raise TypeError("Handlers must be coroutine functions")

        positional_only = []
        positional_or_keyword = []

        for parameter_name, parameter in inspect.signature(handler).parameters.items():
            if parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
                positional_or_keyword.append(parameter_name)

            elif parameter.kind in [
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.VAR_POSITIONAL,
            ]:
                positional_only.append(parameter_name)

        if positional_only:
            raise TypeError(
                "{} has positional-only parameters {} that are not supported in {}".format(
                    handler,
                    positional_only,
                    self.__class__,
                )
            )

        if positional_or_keyword:
            logger.warning(
                "%s has positional parameters %s, only keyword parameters are supported in %s",
                handler,
                positional_or_keyword,
                self.__class__,
            )

    async def start(self, connection_string: str, prefetch_count: int = 1) -> None:
        """Connect, declare the exchange/queue topology and start consuming.

        Args:
            connection_string: AMQP connection URL.
            prefetch_count: Channel QoS, i.e. the maximum number of
                unacknowledged requests delivered at a time.

        Raises:
            RuntimeError: The server has already been started.
        """
        if self._connection:
            raise RuntimeError("Already started")

        connection = self._connection = await get_connection(
            connection_string, self._service_name
        )
        channel = self._channel = await connection.channel()
        await channel.set_qos(prefetch_count=prefetch_count)

        exchange = self._exchange = await channel.declare_exchange(
            name=f"service.{self._service_name}",
            type=aio_pika.ExchangeType.TOPIC,
            durable=True,
        )

        queue = self._queue = await channel.declare_queue(
            name=f"service.{self._service_name}",
            durable=True,
            arguments={
                "x-max-priority": max(Priority),
                "x-dead-letter-exchange": f"service.{self._service_name}.dlx",
            },
        )

        await queue.bind(exchange, "#")
        await queue.consume(fix_message_timestamp(self._on_message))

        logger.info("RPC server is consuming messages from `%s`", queue)

    async def stop(self) -> None:
        """Close the channel; the shared connection is left open."""
        channel = self._channel
        if channel:
            self._reset()
            await channel.close()

    async def serve_forever(
        self, connection_string: str, prefetch_count: int = 1
    ) -> None:
        """Start the server and run until cancelled, then stop it."""
        await self.start(connection_string, prefetch_count=prefetch_count)
        try:
            await asyncio.Future()
        finally:
            await self.stop()

    async def _on_message(self, message: IncomingMessage) -> None:
        """Handle one request message end to end.

        Rejects messages that are not ``rpc.request``, and messages that expect
        a reply but have no ``correlation_id``. Otherwise runs the handler
        within the request deadline (publish timestamp + ``expiration``). The
        response is published to ``reply_to`` when present, and the message is
        then acked. A request that misses its deadline is acked without a reply.
        """
        consumed_timestamp = time.time()

        logger.debug(
            "[MID:%s] Got `%s` message from `%s` service",
            message.message_id,
            message.type or "[untyped]",
            message.app_id or "unknown",
        )

        message_consumed = functools.partial(
            metrics.message_consumed,
            exchange=message.exchange,
            routing_key=message.routing_key,
            type=message.type,
            priority=message.priority,
            redelivered=message.redelivered,
            payload_size=message.body_size,
        )

        if message.type != "rpc.request":
            logger.warning(
                "[MID:%s] Unexpected message type: `%s`",
                message.message_id,
                message.type,
            )
            await message.reject()
            message_consumed(error="WRONG_MESSAGE_TYPE")
            return

        if message.reply_to and not message.correlation_id:
            logger.warning("[MID:%s] `correlation_id` is not set", message.message_id)
            message_consumed(error="NO_CORRELATION_ID")
            await message.reject()
            return

        published_timestamp = None
        if message.timestamp:
            published_timestamp = message.timestamp.timestamp()

        deadline = None
        if published_timestamp and message.expiration:
            deadline = published_timestamp + cast(float, message.expiration)

        # Seen by handlers and by nested `RpcClient.call`s, which inherit this
        # priority and deadline.
        request_context.set(
            RequestContext(
                deadline=deadline,
                priority=Priority(message.priority or 0),
            )
        )

        try:
            response = await asyncio.wait_for(
                self._process_message(message, consumed_timestamp, published_timestamp),
                timeout=(deadline - time.time()) if deadline else None,
            )
        except asyncio.TimeoutError:
            logger.info(
                "[MID:%s] [CID:%s] Execution of `%s` from `%s` service timed out",  # noqa: E501
                message.message_id,
                message.correlation_id,
                message.type,
                message.app_id or "unknown",
            )

            await message.ack()
            message_consumed(error="TIMEOUT")
            return

        if message.reply_to:
            response_message = Message(
                body=response,
                app_id=self._app_id,
                priority=message.priority,
                correlation_id=message.correlation_id,
                expiration=(deadline - time.time()) if deadline else None,
                type="rpc.response",
            )

            await message.channel.basic_publish(
                body=response_message.body,
                routing_key=message.reply_to,
                properties=response_message.properties,
            )
            metrics.message_published(
                exchange="",
                routing_key="",  # random, causes high cardinality metric
                type=response_message.type or "",
                priority=response_message.priority or 0,
                payload_size=response_message.body_size,
            )

        await message.ack()
        message_consumed()

        logger.debug(
            "[MID:%s] [CID:%s] Handled `%s` from `%s` service, success: %s",  # noqa: E501
            message.message_id,
            message.correlation_id,
            message.type,
            message.app_id or "unknown",
            response.error is None,
        )

    @Response.wrap_errors
    async def _process_message(
        self,
        message: IncomingMessage,
        consumed_timestamp: float,
        published_timestamp: Optional[float],
    ) -> Response:
        """Decode the request, run its handler and build the response.

        Errors (e.g. undecodable bodies) become error responses via
        ``Response.wrap_errors``.
        """
        request = Request.model_validate_message(message)

        timings = request.timings
        queued_for = None
        if published_timestamp is not None:
            timings.append(("queue.published", published_timestamp))
            queued_for = consumed_timestamp - published_timestamp

        timings.append((f"queue.consumed by {self._app_id}", consumed_timestamp))

        with metrics.server_requests_in_progress.labels(
            method=request.method_name
        ).track_inprogress():
            elapsed = -time.perf_counter()
            response = await self._process_request(request)
            elapsed += time.perf_counter()
            timings.append(("execution.completed", consumed_timestamp + elapsed))

        response.context = request.context
        response.timings = timings

        metrics.server_request_processed(
            priority=message.priority or 0,
            method=request.method_name,
            error=response.error,
            # Keep None when the publish time is unknown, so the waiting-time
            # histogram is not filled with fake zero samples.
            queued_for=queued_for,
            elapsed=elapsed,
        )

        logger.info(
            "[MID:%s] Processed `%s` from `%s` service to method `%s`. "
            "In queue: %sms, execution: %ims, success: %s%s",
            message.message_id,
            message.type,
            message.app_id or "unknown",
            request.method_name,
            # `%i` cannot format None. Logging would report a formatting error
            # and drop this record whenever the message had no timestamp.
            "?" if queued_for is None else "%i" % max(queued_for * 1000, 0),
            elapsed * 1000,
            response.error is None,
            (f", error: {response.error.message}" if response.error else ""),
        )

        return response

    @Response.wrap_errors
    async def _process_request(self, request: Request) -> Response:
        """Call the bound handler with the request's keyword arguments.

        Raised (and turned into an error ``Response`` by ``wrap_errors``):
            errors.UnknownMethodError: No handler is bound under that name.
            errors.ExecutionError: The handler call raised, including argument
                validation errors from ``pydantic.validate_call``.
        """
        handler = self._handlers.get(request.method_name)
        if handler is None:
            raise errors.UnknownMethodError(request.method_name)

        try:
            result = await handler(**request.arguments)
        except Exception as error:
            raise errors.ExecutionError(repr(error))

        if isinstance(result, pydantic.BaseModel):
            result = result.model_dump()

        return Response(result=result)

    # Built-in health check. It has no docstring on purpose: handler docstrings
    # are returned by `_help` as the method description.
    async def _ping(self) -> None:
        return None

    # Built-in introspection. The docstring below is part of the `_help`
    # output (as this method's description), so keep it unchanged.
    async def _help(self) -> Dict[str, Any]:
        """
        Returns methods list
        """
        return {
            "methods": [
                {
                    "method_name": method_name,
                    "signature": f"{method_name}{inspect.signature(handler)}",
                    "description": handler.__doc__,
                }
                for method_name, handler in self._handlers.items()
            ]
        }
@@ -0,0 +1,112 @@
1
+ import aio_pika.message
2
+ import pydantic
3
+ import time
4
+ import uuid
5
+
6
+ from aio_pika.message import IncomingMessage
7
+ from dataclasses import dataclass
8
+ from enum import IntEnum
9
+ from typing import Dict, Any, Optional, TypeVar, Union
10
+
11
+ from .models import BaseModel
12
+
13
+
14
# Names exported by ``from <this module> import *``.
__all__ = [
    "AbstractOptions", "IncomingMessage", "Message",
    "Priority", "RequestContext",
]

# Self-type used by ``AbstractOptions.update`` so subclasses get their own
# type back (bound as a forward reference; the class is defined below).
TAbstractOptions = TypeVar("TAbstractOptions", bound="AbstractOptions")
24
+
25
+
26
class AbstractOptions(pydantic.BaseModel):
    """
    Immutable options model that can be built by layering other instances.

    ``model_config`` forbids unknown fields and freezes instances, so
    "changing" options means building a new instance via :meth:`update`.
    """

    model_config = pydantic.ConfigDict(extra="forbid", frozen=True)

    def __init__(self, *args: "AbstractOptions", **kwargs: Any) -> None:
        """
        Build options from zero or more base instances plus keyword overrides.

        Positional instances are merged left to right. Only the fields
        explicitly set on each instance (``model_dump(exclude_unset=True)``)
        are carried over, so fields never set anywhere still fall back to the
        model defaults. Keyword arguments are applied last and win.

        Raises:
            TypeError: a positional argument is not an instance of this
                class (or of a subclass).
            pydantic.ValidationError: the merged values fail validation,
                e.g. an unknown field under ``extra="forbid"``.
        """
        if args:
            self_class = type(self)
            merged: Dict[str, Any] = {}
            for options in args:
                if not isinstance(options, self_class):
                    raise TypeError(
                        f"Positional arguments must be {self_class} instances"
                    )
                merged.update(options.model_dump(exclude_unset=True))
            # Explicit keyword overrides take precedence over every base.
            merged.update(kwargs)
            kwargs = merged

        super().__init__(**kwargs)

    def update(
        self: TAbstractOptions, *args: "TAbstractOptions", **kwargs: Any
    ) -> "TAbstractOptions":
        """
        Return a copy of these options with ``args``/``kwargs`` layered on top.

        ``self`` is used as the first base, followed by ``args`` and then
        ``kwargs``, with the same merge rules as the constructor. With nothing
        to apply, ``self`` is returned unchanged. This is safe because
        instances are frozen.
        """
        if not args and not kwargs:
            return self

        return self.__class__(self, *args, **kwargs)
53
+
54
+
55
class Priority(IntEnum):
    """
    Integer priority levels, from lowest (``LOW``) to highest (``SYSTEM``).

    ``DEFAULT`` shares ``NORMAL``'s value, so it is an alias rather than a
    separate member: ``Priority.DEFAULT is Priority.NORMAL``.
    """

    LOW = 0
    NORMAL = 1
    DEFAULT = NORMAL  # alias of NORMAL (same value, same member)
    HIGH = 2
    INTERACTIVE = 3
    SYSTEM = 4
62
+
63
+
64
class Message(aio_pika.message.Message):
    """
    AMQP message abstraction with RPC-oriented defaults.

    Compared with ``aio_pika.message.Message``, this constructor:

    * accepts a project ``BaseModel`` as ``body``, serialising it with
      ``model_dump_msgpack()``; any caller-supplied ``content_type`` is then
      discarded in favour of the msgpack default;
    * defaults ``content_type`` to ``"application/msgpack"``;
    * defaults ``delivery_mode`` to ``aio_pika.DeliveryMode.PERSISTENT``;
    * defaults ``message_id`` to a fresh ``uuid.uuid1()`` string;
    * defaults ``timestamp`` to ``time.time()``.

    Defaults are applied with ``or``, so any falsy value (``None``, ``""``,
    ``0``) is replaced, not only ``None``.
    """

    __slots__ = ()

    def __init__(
        self,
        body: Union[bytes, BaseModel],
        *,
        headers: Optional[aio_pika.message.HeadersType] = None,
        content_type: Optional[str] = None,
        content_encoding: Optional[str] = None,
        delivery_mode: Union[aio_pika.message.DeliveryMode, int, None] = None,
        priority: Optional[int] = None,
        correlation_id: Optional[str] = None,
        reply_to: Optional[str] = None,
        expiration: Optional[aio_pika.message.DateType] = None,
        message_id: Optional[str] = None,
        timestamp: Optional[aio_pika.message.DateType] = None,
        type: Optional[str] = None,
        user_id: Optional[str] = None,
        app_id: Optional[str] = None,
    ) -> None:
        if isinstance(body, BaseModel):
            # A model body is always msgpack, whatever content type was asked for.
            payload = body.model_dump_msgpack()
            content_type = None
        else:
            payload = body

        resolved_content_type = content_type or "application/msgpack"
        resolved_delivery_mode = delivery_mode or aio_pika.DeliveryMode.PERSISTENT
        resolved_message_id = message_id or str(uuid.uuid1())
        resolved_timestamp = timestamp or time.time()

        super().__init__(
            body=payload,
            headers=headers,
            content_type=resolved_content_type,
            content_encoding=content_encoding,
            delivery_mode=resolved_delivery_mode,
            priority=priority,
            correlation_id=correlation_id,
            reply_to=reply_to,
            expiration=expiration,
            message_id=resolved_message_id,
            timestamp=resolved_timestamp,
            type=type,
            user_id=user_id,
            app_id=app_id,
        )
107
+
108
+
109
@dataclass(frozen=True)
class RequestContext:
    """
    Immutable metadata attached to a single RPC request.

    Attributes:
        deadline: optional float deadline, or ``None``.
            NOTE(review): clock/units are not visible here (presumably an
            epoch timestamp in seconds). Verify against callers.
        priority: the request's ``Priority`` level.
    """

    deadline: Optional[float]  # may be None
    priority: Priority
@@ -0,0 +1,21 @@
1
+ import datetime
2
+ import functools
3
+ import re
4
+
5
+ from typing import Callable
6
+
7
+
8
def fix_message_timestamp(func: Callable) -> Callable:
    """
    Decorate a message handler so ``message.timestamp`` is timezone-aware.

    Per the original note, ``pamqp`` yields naive timestamps. A naive
    timestamp is tagged as UTC in place before ``func`` runs. A timestamp
    that already carries a ``tzinfo`` is left untouched, because overwriting
    it would silently shift the instant. A falsy timestamp (e.g. ``None``)
    is also left as-is.

    The wrapper forwards any extra positional/keyword arguments to ``func``
    and returns whatever ``func`` returns.
    """

    @functools.wraps(func)
    def inner(message, *args, **kwargs):
        timestamp = message.timestamp
        if timestamp and timestamp.tzinfo is None:
            message.timestamp = timestamp.replace(tzinfo=datetime.timezone.utc)

        return func(message, *args, **kwargs)

    return inner
18
+
19
+
20
def clean_service_name(service_name: str) -> str:
    """
    Normalise a service name to lowercase ``[a-z0-9-]`` characters.

    Surrounding whitespace is stripped, the text is lowercased, and every run
    of characters outside ``a-z``, ``0-9`` and ``-`` collapses into a single
    ``-``. Existing hyphens are kept, and leading/trailing hyphens produced
    by the substitution are not removed.
    """
    normalized = service_name.strip().lower()
    return re.sub(r"[^a-z0-9-]+", "-", normalized)
@@ -0,0 +1,21 @@
1
+ [tool.poetry]
2
+ name = "delphai-rpc"
3
+ version = "4.1.6"
4
+ description = "Queue-based RPC client and server"
5
+ authors = ["Anton Ryzhov <anton@delphai.com>"]
6
+ readme = "README.md"
7
+ packages = [{include = "delphai_rpc"}]
8
+
9
+ [tool.poetry.dependencies]
10
+ python = "^3.8"
11
+ aio-pika = { version = "^9.1.4" }
12
+ pydantic = { version = "^2.2" }
13
+ msgpack = { version = "^1.0.5" }
14
+ prometheus-client = "^0.20.0"
15
+
16
+ [tool.poetry.group.dev.dependencies]
17
+ ruff = "^0.5.1"
18
+
19
+ [build-system]
20
+ requires = ["poetry-core"]
21
+ build-backend = "poetry.core.masonry.api"