qena-shared-lib 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,74 @@
+ from asyncio import Lock
+ from uuid import UUID
+
+ from pika.adapters.asyncio_connection import AsyncioConnection
+
+ from ..utils import AsyncEventLoopMixin
+ from ._channel import BaseChannel
+
+ __all__ = ["ChannelPool"]
+
+
+ class ChannelPool(AsyncEventLoopMixin):
+     def __init__(
+         self,
+         initial_pool_size: int = 50,
+     ):
+         self._initial_pool_size = initial_pool_size
+         self._pool: dict[UUID, BaseChannel] = {}
+         self._pool_lock = Lock()
+
+     async def fill(
+         self,
+         connection: AsyncioConnection,
+     ):
+         self._connection = connection
+
+         async with self._pool_lock:
+             for _ in range(self._initial_pool_size):
+                 await self._add_new_channel()
+
+     async def drain(self):
+         async with self._pool_lock:
+             for channel_base in self._pool.values():
+                 with channel_base as channel:
+                     if channel.is_closing or channel.is_closed:
+                         continue
+
+                     channel.close()
+
+             self._pool.clear()
+
+     async def _add_new_channel(self) -> BaseChannel:
+         channel = BaseChannel(
+             connection=self._connection,
+         )
+         channel_id = await channel.open()
+         self._pool[channel_id] = channel
+
+         return channel
+
+     async def get(self) -> BaseChannel:
+         async with self._pool_lock:
+             _channel_base = None
+
+             for channel_base in list(self._pool.values()):
+                 if not channel_base.healthy:
+                     _ = self._pool.pop(channel_base.channel_id)
+
+                     continue
+
+                 if not channel_base.reserved:
+                     _channel_base = channel_base
+
+                     break
+
+             if _channel_base is None:
+                 _channel_base = await self._add_new_channel()
+
+             _channel_base.reserve()
+
+             return _channel_base
+
+     def __len__(self) -> int:
+         return len(self._pool)
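
For illustration only (not part of the package diff): a minimal sketch of how ChannelPool might be used. It assumes an already-open pika AsyncioConnection and that BaseChannel (not shown in this diff) releases its reservation when its context manager exits, matching how Publisher uses it below.

    from pika.adapters.asyncio_connection import AsyncioConnection

    async def pool_example(connection: AsyncioConnection) -> None:
        pool = ChannelPool(initial_pool_size=10)
        await pool.fill(connection)      # opens 10 wrapped channels up front

        channel_base = await pool.get()  # reserves a healthy channel, evicting unhealthy ones
        with channel_base as channel:    # assumed to release the reservation on exit
            channel.basic_publish(exchange="", routing_key="demo", body=b"{}")

        print(len(pool))                 # current number of pooled channels
        await pool.drain()               # closes open channels and clears the pool
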
@@ -0,0 +1,73 @@
+ from typing import Any, Callable
+
+ from pika import BasicProperties
+ from prometheus_client import Counter
+ from pydantic_core import to_json
+
+ from ..logging import LoggerProvider
+ from ._pool import ChannelPool
+
+ __all__ = ["Publisher"]
+
+
+ class Publisher:
+     PUBLISHED_MESSAGES = Counter(
+         name="published_messages",
+         documentation="Published messages",
+         labelnames=["routing_key", "target"],
+     )
+
+     def __init__(
+         self,
+         routing_key: str,
+         channel_pool: ChannelPool,
+         blocked_connection_check_callback: Callable[[], bool],
+         exchange: str | None = None,
+         target: str | None = None,
+         headers: dict[str, str] | None = None,
+     ):
+         self._routing_key = routing_key
+         self._exchange = exchange or ""
+         self._headers = headers or {}
+         self._target = target if target is not None else "__default__"
+
+         self._headers.update({"target": self._target})
+
+         self._channel_pool = channel_pool
+         self._blocked_connection_check_callback = (
+             blocked_connection_check_callback
+         )
+         self._logger = LoggerProvider.default().get_logger("rabbitmq.publisher")
+
+     async def publish_as_arguments(self, *args, **kwargs):
+         await self._get_channel_and_publish({"args": args, "kwargs": kwargs})
+
+     async def publish(self, message: Any | None = None):
+         await self._get_channel_and_publish(message)
+
+     async def _get_channel_and_publish(self, message: Any):
+         if self._blocked_connection_check_callback():
+             raise RuntimeError(
+                 "rabbitmq broker is not able to accept messages right now"
+             )
+
+         with await self._channel_pool.get() as channel:
+             channel.basic_publish(
+                 exchange=self._exchange,
+                 routing_key=self._routing_key,
+                 body=to_json(message),
+                 properties=BasicProperties(
+                     content_type="application/json",
+                     headers=self._headers,
+                 ),
+             )
+
+         self._logger.debug(
+             "message published to exchange `%s`, routing key `%s` and target `%s`",
+             self._exchange,
+             self._routing_key,
+             self._target,
+         )
+         self.PUBLISHED_MESSAGES.labels(
+             routing_key=self._routing_key, target=self._target
+         ).inc()
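
For illustration only (not part of the package diff): a sketch of publishing through the pool. The `connection_is_blocked` callable is an assumed stand-in for whatever tracks Connection.Blocked notifications from the broker.

    async def publish_example(channel_pool: ChannelPool) -> None:
        def connection_is_blocked() -> bool:
            return False  # assumed: broker is currently accepting messages

        publisher = Publisher(
            routing_key="notifications",
            channel_pool=channel_pool,
            blocked_connection_check_callback=connection_is_blocked,
            target="send_email",
        )
        await publisher.publish({"user_id": 42})               # body serialized via pydantic_core.to_json
        await publisher.publish_as_arguments(42, urgent=True)  # wrapped as {"args": ..., "kwargs": ...}
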
@@ -0,0 +1,288 @@
+ from asyncio import Future, Lock
+ from functools import partial
+ from importlib import import_module
+ from time import time
+ from typing import Any, Callable
+ from uuid import uuid4
+
+ from pika import BasicProperties
+ from pika.channel import Channel
+ from pika.frame import Method
+ from pika.spec import Basic
+ from prometheus_client import Counter, Summary
+ from pydantic_core import from_json, to_json
+
+ from ..logging import LoggerProvider
+ from ..utils import AsyncEventLoopMixin
+ from ._exceptions import RabbitMQException
+ from ._pool import ChannelPool
+ from ._utils import TypeAdapterCache
+
+ __all__ = ["RpcClient"]
+
+
+ class ExitHandler:
+     _exiting = False
+     _rpc_futures = []
+     _original_exit_handler: Callable
+
+     @classmethod
+     def is_exiting(cls):
+         return cls._exiting
+
+     @classmethod
+     def add_rpc_future(cls, rpc_future: Future):
+         cls._rpc_futures.append(rpc_future)
+
+     @classmethod
+     def remove_rpc_future(cls, rpc_future: Future):
+         try:
+             cls._rpc_futures.remove(rpc_future)
+         except ValueError:
+             pass
+
+     @classmethod
+     def cancel_futures(cls):
+         cls._exiting = True
+
+         for rpc_future in cls._rpc_futures:
+             if not rpc_future.done():
+                 rpc_future.cancel()
+
+     @staticmethod
+     def patch_exit_handler():
+         try:
+             Server = import_module("uvicorn.server").Server
+         except ModuleNotFoundError:
+             return
+
+         ExitHandler._original_exit_handler = Server.handle_exit
+         Server.handle_exit = ExitHandler.handle_exit
+
+     @staticmethod
+     def notify_clients():
+         ExitHandler.cancel_futures()
+
+     @staticmethod
+     def handle_exit(*args, **kwargs):
+         ExitHandler.notify_clients()
+         ExitHandler._original_exit_handler(*args, **kwargs)
+
+
+ ExitHandler.patch_exit_handler()
+
+
+ class RpcClient(AsyncEventLoopMixin):
+     SUCCEEDED_RPC_CALLS = Counter(
+         name="succeeded_rpc_calls",
+         documentation="RPC calls made",
+         labelnames=["routing_key", "procedure"],
+     )
+     FAILED_RPC_CALL = Counter(
+         name="failed_rpc_call",
+         documentation="Failed RPC calls",
+         labelnames=["routing_key", "procedure", "exception"],
+     )
+     RPC_CALL_LATENCY = Summary(
+         name="rpc_call_latency",
+         documentation="Time it took for RPC calls",
+         labelnames=["routing_key", "procedure"],
+     )
+
+     def __init__(
+         self,
+         routing_key: str,
+         channel_pool: ChannelPool,
+         blocked_connection_check_callback: Callable[[], bool],
+         exchange: str | None = None,
+         procedure: str | None = None,
+         headers: dict[str, str] | None = None,
+         return_type: type | None = None,
+         timeout: float = 0,
+     ):
+         self._routing_key = routing_key
+         self._exchange = exchange or ""
+         self._headers = headers or {}
+         self._procedure = procedure if procedure is not None else "__default__"
+
+         self._headers.update({"procedure": self._procedure})
+
+         self._return_type = return_type
+         self._timeout = timeout
+         self._channel_pool = channel_pool
+         self._blocked_connection_check_callback = (
+             blocked_connection_check_callback
+         )
+         self._rpc_future = None
+         self._rpc_call_start_time = None
+         self._rpc_call_lock = Lock()
+         self._rpc_call_pending = False
+         self._logger = LoggerProvider.default().get_logger(
+             "rabbitmq.rpc_client"
+         )
+
+     async def call_with_arguments(self, *args, **kwargs) -> Any:
+         return await self._get_channel_and_call(
+             {"args": args, "kwargs": kwargs}
+         )
+
+     async def call(self, message: Any | None = None) -> Any:
+         return await self._get_channel_and_call(message)
+
+     async def _get_channel_and_call(self, message: Any) -> Any:
+         if self._blocked_connection_check_callback():
+             raise RuntimeError(
+                 "rabbitmq broker is not able to accept messages right now"
+             )
+
+         async with self._rpc_call_lock:
+             if self._rpc_call_pending:
+                 raise RuntimeError("previous rpc request not done yet")
+
+             self._rpc_call_pending = True
+
+         self._rpc_call_start_time = time()
+         self._channel_base = await self._channel_pool.get()
+         self._channel = self._channel_base.channel
+         self._rpc_future = self.loop.create_future()
+
+         ExitHandler.add_rpc_future(self._rpc_future)
+         self._channel.queue_declare(
+             queue="",
+             exclusive=True,
+             auto_delete=True,
+             callback=partial(self._on_queue_declared, message),
+         )
+
+         return await self._rpc_future
+
+     def _on_queue_declared(self, message: Any, method: Method):
+         try:
+             self._rpc_reply_consumer_tag = self._channel.basic_consume(
+                 queue=method.method.queue,
+                 on_message_callback=self._on_reply_message,
+                 auto_ack=True,
+             )
+         except Exception as e:
+             self._finalize_call(exception=e)
+
+             return
+
+         self._correlation_id = str(uuid4())
+
+         try:
+             self._channel.basic_publish(
+                 exchange=self._exchange,
+                 routing_key=self._routing_key,
+                 properties=BasicProperties(
+                     content_type="application/json",
+                     reply_to=method.method.queue,
+                     correlation_id=self._correlation_id,
+                     headers=self._headers,
+                 ),
+                 body=to_json(message),
+             )
+         except Exception as e:
+             self._finalize_call(exception=e)
+
+             return
+
+         if self._timeout > 0:
+             _ = self.loop.call_later(
+                 delay=self._timeout, callback=self._on_timeout
+             )
+
+         self._logger.debug(
+             "rpc request sent to exchange `%s`, routing key `%s` and procedure `%s`",
+             self._exchange,
+             self._routing_key,
+             self._procedure,
+         )
+
+     def _on_timeout(self):
+         self._finalize_call(
+             exception=TimeoutError(
+                 f"rpc worker didn't respond within `{self._timeout}` seconds"
+             )
+         )
+
+     def _on_reply_message(
+         self,
+         channel: Channel,
+         method: Basic.Deliver,
+         properties: BasicProperties,
+         body: bytes,
+     ):
+         del channel, method
+
+         if properties.correlation_id != self._correlation_id:
+             self._finalize_call(
+                 exception=ValueError(
+                     f"correlation id {properties.correlation_id} from rpc worker doesn't match {self._correlation_id}"
+                 )
+             )
+
+             return
+
+         try:
+             response = from_json(body)
+         except Exception as e:
+             self._finalize_call(exception=e)
+
+             return
+
+         if isinstance(response, dict) and "exception" in response:
+             self._finalize_call(
+                 exception=RabbitMQException(
+                     code=response.get("code") or 0,
+                     message=response.get("message")
+                     or "unknown error occurred from the rpc worker side",
+                 )
+             )
+
+             return
+
+         if self._return_type is not None:
+             type_adapter = TypeAdapterCache.get_type_adapter(self._return_type)
+
+             try:
+                 response = type_adapter.validate_python(response)
+             except Exception as e:
+                 self._finalize_call(exception=e)
+
+                 return
+
+         self._finalize_call(response=response)
+
+     def _finalize_call(
+         self,
+         response: Any | None = None,
+         exception: BaseException | None = None,
+     ):
+         self._rpc_call_pending = False
+         self._channel.basic_cancel(self._rpc_reply_consumer_tag)
+         self._channel_base.release()
+
+         if self._rpc_future is None:
+             return
+
+         ExitHandler.remove_rpc_future(self._rpc_future)
+
+         if self._rpc_future.done():
+             return
+         elif exception is not None:
+             self._rpc_future.set_exception(exception)
+             self.FAILED_RPC_CALL.labels(
+                 routing_key=self._routing_key,
+                 procedure=self._procedure,
+                 exception=exception.__class__.__name__,
+             ).inc()
+         else:
+             self._rpc_future.set_result(response)
+             self.SUCCEEDED_RPC_CALLS.labels(
+                 routing_key=self._routing_key, procedure=self._procedure
+             ).inc()
+
+         self.RPC_CALL_LATENCY.labels(
+             routing_key=self._routing_key, procedure=self._procedure
+         ).observe(time() - (self._rpc_call_start_time or time()))
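
For illustration only (not part of the package diff): a sketch of an RPC round trip. The worker side is not shown in this diff, so the routing key, procedure name, and reply type used here are assumptions; a client instance handles one outstanding call at a time, guarded by `_rpc_call_pending`.

    async def rpc_example(channel_pool: ChannelPool) -> None:
        def connection_is_blocked() -> bool:
            return False  # assumed: broker is currently accepting messages

        client = RpcClient(
            routing_key="calculator",
            channel_pool=channel_pool,
            blocked_connection_check_callback=connection_is_blocked,
            procedure="add",
            return_type=int,
            timeout=5.0,
        )
        # Resolves with the worker's reply validated as `int`, or raises
        # RabbitMQException / TimeoutError on worker error or timeout.
        result = await client.call_with_arguments(2, 3)
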
@@ -0,0 +1,18 @@
+ from pydantic import TypeAdapter
+
+ __all__ = ["TypeAdapterCache"]
+
+
+ class TypeAdapterCache:
+     _cache: dict[type, TypeAdapter] = {}
+
+     @classmethod
+     def cache_annotation(cls, annotation: type):
+         if annotation not in cls._cache:
+             cls._cache[annotation] = TypeAdapter(annotation)
+
+     @classmethod
+     def get_type_adapter(cls, annotation: type) -> TypeAdapter:
+         cls.cache_annotation(annotation)
+
+         return cls._cache[annotation]
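
For illustration only (not part of the package diff): the cache builds a pydantic TypeAdapter once per annotation and reuses it on later lookups.

    adapter = TypeAdapterCache.get_type_adapter(list[int])
    print(adapter.validate_python(["1", 2, 3.0]))                   # [1, 2, 3]
    print(TypeAdapterCache.get_type_adapter(list[int]) is adapter)  # True, served from the cache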