rmqaio 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rmqaio-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Roman Koshel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
rmqaio-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.1
2
+ Name: rmqaio
3
+ Version: 0.1.0
4
+ Summary:
5
+ License: MIT
6
+ Keywords: rabbitmq
7
+ Author: Roman Koshel
8
+ Author-email: roma.koshel@gmail.com
9
+ Requires-Python: >=3.10,<4.0
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3 :: Only
18
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
19
+ Classifier: Topic :: Utilities
20
+ Requires-Dist: aiormq (>=6.7.6)
21
+ Requires-Dist: rich
22
+ Requires-Dist: yarl
23
+ Description-Content-Type: text/markdown
24
+
25
+ ### rmqaio
26
+
rmqaio-0.1.0/README.md ADDED
@@ -0,0 +1 @@
1
+ ### rmqaio
@@ -0,0 +1,48 @@
1
+ [tool.poetry]
2
+ name = "rmqaio"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Roman Koshel <roma.koshel@gmail.com>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ keywords = ["rabbitmq"]
9
+ classifiers = [
10
+ "Development Status :: 3 - Alpha",
11
+ "Programming Language :: Python",
12
+ "Programming Language :: Python :: 3 :: Only",
13
+ "Topic :: Software Development :: Libraries :: Python Modules",
14
+ "Topic :: Utilities",
15
+ "License :: OSI Approved :: MIT License",
16
+ ]
17
+
18
+ [tool.poetry.dependencies]
19
+ python = "^3.10"
20
+ aiormq = ">=6.7.6"
21
+ rich = "*"
22
+ yarl = "*"
23
+
24
+ [tool.poetry.group.test.dependencies]
25
+ docker = "*"
26
+ httpx = "*"
27
+ pytest = "*"
28
+ pytest_asyncio = "*"
29
+ pytest_timeout = "*"
30
+
31
+ [tool.poetry.group.docs.dependencies]
32
+ mkdocs = "*"
33
+ mkdocs-material = "*"
34
+ mkdocstrings = {extras = ["python"], version = "*"}
35
+
36
+ [build-system]
37
+ requires = ["poetry-core"]
38
+ build-backend = "poetry.core.masonry.api"
39
+
40
+ [tool.black]
41
+ line-length = 120
42
+
43
+ [tool.isort]
44
+ indent = 4
45
+ lines_after_imports = 2
46
+ lines_between_types = 1
47
+ src_paths = ["."]
48
+ profile = "black"
rmqaio-0.1.0/rmqaio.py ADDED
@@ -0,0 +1,820 @@
1
+ import asyncio
2
+ import importlib.metadata
3
+ import itertools
4
+ import logging
5
+ import os
6
+ import ssl
7
+
8
+ from asyncio import FIRST_COMPLETED, Future, Lock, Task, create_task, current_task, get_event_loop, sleep, wait
9
+ from collections.abc import Hashable
10
+ from dataclasses import dataclass, field
11
+ from enum import StrEnum
12
+ from functools import partial, wraps
13
+ from inspect import iscoroutine, iscoroutinefunction
14
+ from ssl import SSLContext
15
+ from typing import Callable, Coroutine, Iterable
16
+ from uuid import uuid4
17
+
18
+ import aiormq
19
+ import aiormq.exceptions
20
+ import yarl
21
+
22
+
23
# Version of the installed "rmqaio" distribution, read from package metadata at import time.
# importlib.metadata.version raises PackageNotFoundError when the distribution is not installed.
__version__ = importlib.metadata.version("rmqaio")

# Single logger shared by everything in this module.
logger = logging.getLogger("rmqaio")

# Connect timeout in seconds, used when the URL has no "connection_timeout" query parameter.
CONNECT_TIMEOUT = 15
# When true, publish debug logs print a placeholder instead of the message payload.
LOG_SANITIZE = True

# Short alias for the aiormq message properties class used when publishing.
BasicProperties = aiormq.spec.Basic.Properties
31
+
32
+
33
+ class QueueType(StrEnum):
34
+ """RabbitMQ queue type"""
35
+
36
+ CLASSIC = "classic"
37
+ QUORUM = "quorum"
38
+
39
+
40
+ def retry(
41
+ *,
42
+ retry_timeouts: Iterable[int],
43
+ exc_filter: Callable[[Exception], bool],
44
+ msg: str | None = None,
45
+ on_error: Callable | None = None,
46
+ ):
47
+ """Retry decorator.
48
+
49
+ Args:
50
+ retry_timeouts: Retry timeout as list of int, for example: `[1, 2, 3]` or `itertools.repeat(5)`.
51
+ exc_filter: Callable to determine whether or not to retry.
52
+ msg: Message to log on retry.
53
+ on_error: Callable to call on error.
54
+ reraise: Reraise exception or not.
55
+ """
56
+
57
+ def decorator(fn):
58
+ @wraps(fn)
59
+ async def wrapper(*args, **kwds):
60
+ timeouts = iter(retry_timeouts)
61
+ attempt = 0
62
+ while True:
63
+ try:
64
+ return await fn(*args, **kwds)
65
+ except Exception as e:
66
+ if not exc_filter(e):
67
+ raise e
68
+ try:
69
+ t = next(timeouts)
70
+ attempt += 1
71
+ logger.warning(
72
+ "%s (%s %s) retry(%s) in %s second(s)",
73
+ msg or fn,
74
+ e.__class__,
75
+ e,
76
+ attempt,
77
+ t,
78
+ )
79
+ if on_error:
80
+ await on_error(e)
81
+ await sleep(t)
82
+ except StopIteration:
83
+ raise e
84
+
85
+ return wrapper
86
+
87
+ return decorator
88
+
89
+
90
+ def create_ssl_context(url: str, password: str | None = None, cwd: str | None = None) -> SSLContext | None:
91
+ """Create ssl context from url"""
92
+
93
+ if not url.startswith("amqps://"):
94
+ return
95
+
96
+ cwd = cwd or os.path.abspath(os.path.dirname(__file__))
97
+
98
+ query = yarl.URL(url).query # pylint: disable=no-member
99
+
100
+ capath = query.get("capath")
101
+ if capath and not capath.startswith("/"):
102
+ capath = os.path.join(cwd, capath)
103
+
104
+ cafile = query.get("cafile")
105
+ if cafile and not cafile.startswith("/"):
106
+ cafile = os.path.join(cwd, cafile)
107
+
108
+ context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, capath=capath, cafile=cafile)
109
+
110
+ cert = query.get("certfile")
111
+ if cert:
112
+ if not cert.startswith("/"):
113
+ cert = os.path.join(cwd, cert)
114
+ keyfile = query.get("keyfile")
115
+ if keyfile and not keyfile.startswith("/"):
116
+ keyfile = os.path.join(cwd, keyfile)
117
+ context.load_cert_chain(cert, keyfile=keyfile, password=password)
118
+
119
+ verify = query.get("no_verify_ssl", "0") == "0"
120
+ if not verify:
121
+ context.check_hostname = False
122
+ context.verify_mode = ssl.CERT_NONE
123
+
124
+ return context
125
+
126
+
127
class LoopIter:
    """Cyclic iterator that hands out at most one full pass over `data` per round.

    `next()` walks the list in order, wrapping around at the end. After as many
    items as the list holds have been returned in the current round, the next call
    raises StopIteration once and starts a new round, so iteration can be resumed
    from the position where it stopped.

    The list is referenced, not copied.
    """

    __slots__ = ("_data", "_i", "_j")

    def __init__(self, data: list):
        self._data = data
        # Index of the item returned last (-1 before the first call).
        self._i = -1
        # Number of items already handed out in the current round.
        self._j = 0

    def __iter__(self):
        """Return self so the object works with `iter()`, `for` and `list()`."""
        return self

    def __next__(self):
        size = len(self._data)
        # ">=" rather than "==": keeps empty data (or a reset() on empty data)
        # from reaching the modulo below and dividing by zero.
        if self._j >= size:
            self._j = 0
            raise StopIteration
        self._i = (self._i + 1) % size
        self._j += 1
        return self._data[self._i]

    def reset(self):
        """Start a new round that counts the current item as already used."""
        self._j = 1
148
+
149
+
150
class Connection:
    """RabbitMQ connection.

    Objects created on the same event loop with the same set of urls share one
    underlying aiormq connection. The shared record holds that connection, a connect
    lock, the number of objects currently holding it open ("refs") and the number of
    objects created ("objs"). Callbacks are stored per object.

    Supported callback types: "on_open", "on_lost", "on_close".
    """

    __shared: dict = {}

    def __init__(
        self,
        url: str | list[str],
        name: str | None = None,
        retry_timeouts: Iterable[int] | None = None,
        exc_filter: Callable[[Exception], bool] | None = None,
    ):
        """
        Args:
            url: Broker url, or several urls that are tried in turn when connecting fails.
            name: Label used in log messages; defaults to four random hex characters.
            retry_timeouts: Delays handed to `retry` around `aiormq.connect`; no retry when empty.
            exc_filter: Predicate selecting which connect errors are retried. Defaults to
                timeout, ConnectionError and aiormq AMQPConnectionError.
        """
        if not isinstance(url, (list, tuple, set)):
            self.urls = [url]
        else:
            self.urls = list(url)

        self._urls_iter = LoopIter(self.urls)

        self.url = next(self._urls_iter)

        self.name = name or uuid4().hex[-4:]

        # Relative certificate paths are resolved against the user's home directory.
        # expanduser is required here: abspath("~") would yield "<cwd>/~".
        self._ssl_contexts = {url: create_ssl_context(url, cwd=os.path.expanduser("~")) for url in self.urls}

        # Starts out as an already finished future so close() can always inspect it.
        self._open_task: Task | Future = Future()
        self._open_task.set_result(None)

        # Set while this object holds a reference on the shared connection.
        self._watcher_task: Task | None = None

        # Resolved by close(); replaced with a fresh future when reopened.
        self._closed: Future = Future()

        self._key: tuple = (get_event_loop(), tuple(sorted(self.urls)))

        if self._key not in self.__shared:
            self.__shared[self._key] = {
                "refs": 0,
                "objs": 0,
                "conn": None,
                "connect_lock": Lock(),
            }

        shared: dict = self.__shared[self._key]
        shared["objs"] += 1
        shared[self] = {
            "on_open": {},
            "on_lost": {},
            "on_close": {},
            "callback_tasks": {"on_open": {}, "on_lost": {}, "on_close": {}},
        }
        self._shared = shared

        self._channel: aiormq.abc.AbstractChannel | None = None

        self._retry_timeouts = retry_timeouts or []
        self._exc_filter = exc_filter or (
            lambda e: isinstance(e, (asyncio.TimeoutError, ConnectionError, aiormq.exceptions.AMQPConnectionError))
        )

    def __del__(self):
        # NOTE(review): the shared record uses `self` as a dict key, which looks like it keeps
        # the object alive until this entry is popped — confirm whether this finalizer can
        # actually run before interpreter shutdown.
        if getattr(self, "_key", None):
            if self._conn and not self.is_closed:
                logger.warning("%s unclosed", self)
            shared = self._shared
            shared["objs"] -= 1
            shared.pop(self, None)
            if shared["objs"] == 0:
                self.__shared.pop(self._key, None)

    @property
    def _conn(self) -> aiormq.abc.AbstractConnection:
        """Underlying aiormq connection from the shared record (None when not connected)."""
        return self._shared["conn"]

    @_conn.setter
    def _conn(self, value: aiormq.abc.AbstractConnection | None):
        self._shared["conn"] = value

    @property
    def _refs(self) -> int:
        """Number of objects currently holding the shared connection open."""
        return self._shared["refs"]

    @_refs.setter
    def _refs(self, value: int):
        self._shared["refs"] = value

    async def _execute_callbacks(self, tp: str, reraise: bool | None = None):
        """Run this object's callbacks of type `tp` one after another, each in its own task.

        Callback errors are logged; they propagate only when `reraise` is true.
        """

        async def fn(name, callback):
            logger.debug("%s execute callback %s[%s]", self, tp, name)

            # Remember the running task so remove_callback(cancel=True) can cancel it.
            self._shared[self]["callback_tasks"][tp][name] = current_task()
            try:
                if iscoroutinefunction(callback):
                    await callback()
                else:
                    res = callback()
                    if iscoroutine(res):
                        await res
            except Exception:
                logger.exception("%s callback %s[%s] %s", self, tp, name, callback)
                if reraise:
                    raise
            finally:
                self._shared[self]["callback_tasks"][tp].pop(name, None)

        # Iterate over a snapshot: callbacks may register or remove callbacks.
        for name, callback in tuple(self._shared[self][tp].items()):
            await create_task(fn(name, callback))

    def set_callback(self, tp: str, name: Hashable, callback: Callable):
        """Register (or replace) callback `name` of type `tp`.

        Raises:
            ValueError: `tp` is not a known callback type.
        """
        if shared := self._shared.get(self):
            if tp not in shared:
                raise ValueError("invalid callback type")
            shared[tp][name] = callback

    def remove_callback(self, tp: str, name: Hashable, cancel: bool | None = None):
        """Remove callback `name` of type `tp`; with `cancel`, also cancel it if it is running.

        Raises:
            ValueError: `tp` is not a known callback type.
        """
        if shared := self._shared.get(self):
            if tp not in shared:
                raise ValueError("invalid callback type")
            if name in shared[tp]:
                del shared[tp][name]
            if cancel:
                task = shared["callback_tasks"][tp].get(name)
                if task:
                    task.cancel()

    def remove_callbacks(self, cancel: bool | None = None):
        """Remove every callback of this object; with `cancel`, cancel the running ones first."""
        if self in self._shared:
            if cancel:
                for tp in ("on_open", "on_lost", "on_close"):
                    for task in self._shared[self]["callback_tasks"][tp].values():
                        task.cancel()
            self._shared[self] = {
                "on_open": {},
                "on_lost": {},
                "on_close": {},
                "callback_tasks": {"on_open": {}, "on_lost": {}, "on_close": {}},
            }

    def __str__(self):
        return f"{self.__class__.__name__}[{yarl.URL(self.url).host}]#{self.name}"

    def __repr__(self):
        return self.__str__()

    @property
    def is_open(self) -> bool:
        """True while this object is watching a live underlying connection."""
        return self._watcher_task is not None and not (self.is_closed or self._conn is None or self._conn.is_closed)

    @property
    def is_closed(self) -> bool:
        """True once close() has been called and the object has not been reopened."""
        return self._closed.done()

    async def _watcher(self):
        """Wait until the underlying connection closes or close() is called.

        When the connection went away without close(), release this object's reference
        and run the "on_lost" callbacks.
        """
        try:
            await wait([self._conn.closing, self._closed], return_when=FIRST_COMPLETED)
        except Exception as e:
            logger.warning("%s %s %s", self, e.__class__, e)

        self._watcher_task = None

        if not self._closed.done():
            logger.warning("%s connection lost", self)
            if self._channel:
                await self._channel.close()
            self._refs -= 1
            await self._execute_callbacks("on_lost")

    async def _connect(self):
        """Connect to the current url, moving on to the next url after a failure.

        The "connection_timeout" url query parameter is read as milliseconds; without it
        CONNECT_TIMEOUT applies. The last error is re-raised once the url iterator
        signals that a full round of urls has been tried.
        """
        while not self.is_closed:
            connect_timeout = yarl.URL(self.url).query.get("connection_timeout")
            if connect_timeout is not None:
                connect_timeout = int(connect_timeout) / 1000
            else:
                connect_timeout = CONNECT_TIMEOUT
            try:
                logger.info("%s connecting[timeout=%s]...", self, connect_timeout)

                # The timeout covers all retry attempts for this url, including their sleeps.
                async with asyncio.timeout(connect_timeout):
                    if self._retry_timeouts:
                        self._conn = await retry(
                            retry_timeouts=self._retry_timeouts,
                            exc_filter=self._exc_filter,
                        )(aiormq.connect)(
                            self.url,
                            context=self._ssl_contexts[self.url],
                        )
                    else:
                        self._conn = await aiormq.connect(self.url, context=self._ssl_contexts[self.url])
                self._urls_iter.reset()
                break
            except (asyncio.TimeoutError, ConnectionError, aiormq.exceptions.ConnectionClosed) as e:
                try:
                    url = next(self._urls_iter)
                except StopIteration:
                    raise e
                logger.warning("%s %s %s", self, e.__class__, e)
                self.url = url

        logger.info("%s connected", self)

    async def open(self):
        """Open connection.

        Connects the shared underlying connection if needed, takes a reference on it and
        runs the "on_open" callbacks. If a callback fails, the connection is closed and
        the error is re-raised.
        """

        if self.is_open:
            return

        if self.is_closed:
            self._closed = Future()

        async with self._shared["connect_lock"]:
            if self._conn is None or self._conn.is_closed:
                self._open_task = create_task(self._connect())
                await self._open_task

        if self._watcher_task is None:
            self._refs += 1
            self._watcher_task = create_task(self._watcher())
            try:
                await self._execute_callbacks("on_open", reraise=True)
            except Exception as e:
                logger.exception(e)
                await self.close()
                raise e

    async def close(self):
        """Close connection.

        Runs the "on_close" callbacks, releases this object's reference and closes the
        underlying connection once nobody holds it any more.
        """

        if self.is_closed:
            return

        if not self._open_task.done():
            self._open_task.cancel()
            self._open_task = Future()

        if self._conn:
            await self._execute_callbacks("on_close")

        self._closed.set_result(None)

        # Release a reference only if this object still holds one. The watcher task exists
        # exactly while the reference is held: open() creates both together, and the watcher
        # gives both up when the connection is lost. Decrementing unconditionally would let
        # an object that never opened (or already lost its connection) close the underlying
        # connection that other objects are still using.
        if self._watcher_task is not None:
            self._refs = max(0, self._refs - 1)

        if self._refs == 0:
            if self._conn:
                await self._conn.close()
                self._conn = None
                logger.info("%s close underlying connection", self)

        self.remove_callbacks(cancel=True)

        if self._watcher_task:
            await self._watcher_task
            self._watcher_task = None

        logger.info("%s closed", self)

    @retry(retry_timeouts=[0], exc_filter=lambda e: isinstance(e, aiormq.exceptions.ConnectionClosed))
    async def new_channel(self) -> aiormq.abc.AbstractChannel:
        """Create new channel, opening the connection first if necessary.

        Retried once, immediately, when the connection turns out to be closed.
        """

        await self.open()
        return await self._conn.channel()

    async def channel(self) -> aiormq.abc.AbstractChannel:
        """Get or create the channel cached on this object."""

        if self._channel is None or self._channel.is_closed:
            await self.open()
            # Checked again: callbacks run by open() may already have created the channel.
            if self._channel is None or self._channel.is_closed:
                self._channel = await self.new_channel()
        return self._channel
419
+
420
+
421
@dataclass(slots=True, frozen=True)
class SimpleExchange:
    """Simple exchange. Only publish method allowed.

    Exactly one of `conn` and `conn_factory` must be given. With `conn_factory` the
    exchange creates its own connection and closes it in `close()`.
    """

    name: str = ""
    timeout: int | None = None
    conn: Connection = None  # type: ignore
    conn_factory: Callable = field(default=None, repr=False)  # type: ignore

    def __post_init__(self):
        if all((self.conn, self.conn_factory)):
            raise Exception("conn and conn_factory are incompatible")
        if not any((self.conn, self.conn_factory)):
            raise Exception("conn or conn_factory is required")
        if self.conn_factory:
            # The dataclass is frozen, so the attribute is set through object.__setattr__.
            object.__setattr__(self, "conn", self.conn_factory())

    async def close(self):
        """Close the connection, but only when it was created by `conn_factory`."""
        logger.debug("Close %s", self)
        if not self.conn_factory:
            return
        try:
            self.conn.remove_callbacks(cancel=True)
        finally:
            await self.conn.close()

    async def publish(
        self,
        data: bytes,
        routing_key: str,
        properties: dict | None = None,
        timeout: int | None = None,
    ):
        """
        Publish data to exchange with routing key
        Args:
            data: data to publish.
            routing_key: routing key.
            properties: RabbitMQ message properties.
            timeout: publish operation timeout; falls back to the exchange timeout.
        """

        channel = await self.conn.channel()

        logger.debug(
            "Exchange[name='%s'] channel[%s] publish[routing_key='%s'] %s",
            self.name,
            channel,
            routing_key,
            data if not LOG_SANITIZE else "<hidden>",
        )

        await channel.basic_publish(
            data,
            exchange=self.name,
            routing_key=routing_key,
            properties=BasicProperties(**(properties or {})),
            timeout=timeout or self.timeout,
        )
480
+
481
+
482
@dataclass(slots=True, frozen=True)
class Exchange:
    """RabbitMQ exchange class.

    Exactly one of `conn` and `conn_factory` must be given. With `conn_factory` the
    exchange creates its own connection and closes it in `close()`.
    An empty `name` is never declared or deleted.
    """

    name: str = ""
    type: str = "direct"
    durable: bool = False
    auto_delete: bool = False
    timeout: int | None = None
    conn: Connection = None  # type: ignore
    conn_factory: Callable = field(default=None, repr=False)  # type: ignore

    def __post_init__(self):
        if all((self.conn, self.conn_factory)):
            raise Exception("conn and conn_factory are incompatible")
        if not any((self.conn, self.conn_factory)):
            raise Exception("conn or conn_factory is required")
        if self.conn_factory:
            # The dataclass is frozen, so the attribute is set through object.__setattr__.
            object.__setattr__(self, "conn", self.conn_factory())

    async def close(self, delete: bool | None = None, timeout: int | None = None):
        """Detach from the connection and optionally delete the exchange.

        Args:
            delete: Delete the exchange on the broker (ignored for the empty name).
            timeout: Delete operation timeout; falls back to the exchange timeout.

        Raises:
            Exception: The connection is already closed.
        """
        if self.conn.is_closed:
            raise Exception("already closed")

        logger.debug("Close %s delete[%s]", self, delete)

        try:
            if self.conn_factory:
                self.conn.remove_callbacks(cancel=True)
            else:
                # Shared connection: drop only the callback this exchange registered.
                self.conn.remove_callback("on_open", f"on_open_exchange_{self.name}_declare", cancel=True)
            if delete and self.name != "":
                channel = await self.conn.channel()
                try:
                    await channel.exchange_delete(self.name, timeout=timeout or self.timeout)
                except aiormq.exceptions.AMQPError as e:
                    # Deleting is best effort; log it instead of dropping the error silently.
                    logger.warning(e)
        finally:
            if self.conn_factory:
                await self.conn.close()

    async def declare(
        self,
        timeout: int | None = None,
        restore: bool | None = None,
        force: bool | None = None,
    ):
        """Declare the exchange on the broker.

        Args:
            timeout: Operation timeout; falls back to the exchange timeout.
            restore: Register an "on_open" callback that declares the exchange again
                whenever the connection is opened.
            force: When the declaration fails with ChannelPreconditionFailed, delete the
                existing exchange and declare once more.
        """
        if self.name == "":
            return

        logger.debug("Declare[force=%s, restore=%s] %s", force, restore, self)

        async def fn():
            channel = await self.conn.channel()
            await channel.exchange_declare(
                self.name,
                exchange_type=self.type,
                durable=self.durable,
                auto_delete=self.auto_delete,
                timeout=timeout or self.timeout,
            )

        if force:

            async def on_error(e):
                channel = await self.conn.channel()
                await channel.exchange_delete(self.name, timeout=timeout or self.timeout)

            await retry(
                retry_timeouts=[0],
                exc_filter=lambda e: isinstance(e, aiormq.ChannelPreconditionFailed),
                on_error=on_error,
            )(fn)()

        else:
            await fn()

        if restore:
            self.conn.set_callback(
                "on_open",
                f"on_open_exchange_{self.name}_declare",
                partial(self.declare, timeout=timeout),
            )

    async def publish(
        self,
        data: bytes,
        routing_key: str,
        properties: dict | None = None,
        timeout: int | None = None,
    ):
        """
        Publish data to exchange with routing key
        Args:
            data: data to publish.
            routing_key: routing key.
            properties: RabbitMQ message properties.
            timeout: publish operation timeout; falls back to the exchange timeout.
        """
        channel = await self.conn.channel()

        logger.debug(
            "Exchange[name='%s'] channel[%s] publish[routing_key='%s'] %s",
            self.name,
            channel,
            routing_key,
            data if not LOG_SANITIZE else "<hidden>",
        )

        await channel.basic_publish(
            data,
            exchange=self.name,
            routing_key=routing_key,
            properties=BasicProperties(**(properties or {})),
            timeout=timeout or self.timeout,
        )
590
+
591
+
592
@dataclass(slots=True, frozen=True)
class Consumer:
    """Pairs a channel with the consumer tag of a subscription started on it."""

    channel: aiormq.abc.AbstractChannel
    consumer_tag: str

    async def close(self):
        """Close the channel held by this consumer."""
        logger.debug("Close %s", self)
        await self.channel.close()
602
+
603
+
604
@dataclass(slots=True, frozen=True)
class Queue:
    """RabbitMQ queue class.

    Exactly one of `conn` and `conn_factory` must be given. With `conn_factory` the
    queue creates its own connection and closes it in `close()`.
    `expires` and `msg_ttl` are multiplied by 1000 before being sent to the broker.
    """

    name: str
    type: QueueType = QueueType.CLASSIC
    durable: bool = False
    auto_delete: bool = False
    prefetch_count: int | None = 1
    max_priority: int | None = None
    expires: int | None = None
    msg_ttl: int | None = None
    timeout: int | None = None
    conn: Connection = None  # type: ignore
    conn_factory: Callable = field(default=None, repr=False)  # type: ignore
    consumer: Consumer | None = field(default=None, init=False)
    bindings: list[tuple[Exchange, str]] = field(default_factory=list, init=False)

    def __post_init__(self):
        if all((self.conn, self.conn_factory)):
            raise Exception("conn and conn_factory are incompatible")
        if not any((self.conn, self.conn_factory)):
            raise Exception("conn or conn_factory is required")
        if self.conn_factory:
            # The dataclass is frozen, so attributes are set through object.__setattr__.
            object.__setattr__(self, "conn", self.conn_factory())
        # Forget the consumer whenever the connection is lost or closed.
        self.conn.set_callback(
            "on_lost",
            f"on_lost_queue_{self.name}_cleanup_consumer",
            lambda: object.__setattr__(self, "consumer", None),
        )
        self.conn.set_callback(
            "on_close",
            f"on_close_queue_{self.name}_cleanup_consumer",
            lambda: object.__setattr__(self, "consumer", None),
        )

    async def close(self, delete: bool | None = None, timeout: int | None = None):
        """Stop consuming, remove all bindings and optionally delete the queue.

        Args:
            delete: Delete the queue on the broker.
            timeout: Operation timeout; falls back to the queue timeout.

        Raises:
            Exception: The connection is already closed.
        """
        if self.conn.is_closed:
            raise Exception("already closed")

        logger.debug("Close %s delete[%s]", self, delete)

        try:
            await self.stop_consume(timeout=timeout)
            # unbind() removes entries from self.bindings, so walk a snapshot;
            # iterating the live list would skip every other binding.
            for exchange, routing_key in tuple(self.bindings):
                await self.unbind(exchange, routing_key, timeout=timeout)
            if self.conn_factory:
                self.conn.remove_callbacks(cancel=True)
            else:
                self.conn.remove_callback("on_open", f"on_open_queue_{self.name}_declare", cancel=True)
            if delete:
                channel = await self.conn.channel()
                try:
                    await channel.queue_delete(self.name, timeout=timeout or self.timeout)
                except aiormq.exceptions.AMQPError as e:
                    logger.warning(e)
        finally:
            if self.conn_factory:
                await self.conn.close()

    async def declare(self, timeout: int | None = None, restore: bool | None = None, force: bool | None = None):
        """Declare the queue on the broker.

        Args:
            timeout: Operation timeout; falls back to the queue timeout.
            restore: Register an "on_open" callback that declares the queue again
                whenever the connection is opened.
            force: When the declaration fails with ChannelPreconditionFailed, delete the
                existing queue and declare once more.
        """
        logger.debug("Declare[force=%s, restore=%s] %s", force, restore, self)

        async def fn():
            channel = await self.conn.channel()
            arguments = {
                "x-queue-type": self.type,
            }
            if self.max_priority:
                arguments["x-max-priority"] = self.max_priority
            if self.expires:
                arguments["x-expires"] = int(self.expires) * 1000
            if self.msg_ttl:
                arguments["x-message-ttl"] = int(self.msg_ttl) * 1000
            await channel.queue_declare(
                self.name,
                durable=self.durable,
                auto_delete=self.auto_delete,
                arguments=arguments,
                timeout=timeout or self.timeout,
            )

        if force:

            async def on_error(e):
                channel = await self.conn.channel()
                await channel.queue_delete(self.name, timeout=timeout or self.timeout)

            await retry(
                retry_timeouts=[0],
                exc_filter=lambda e: isinstance(e, aiormq.ChannelPreconditionFailed),
                on_error=on_error,
            )(fn)()

        else:
            await fn()

        if restore:
            self.conn.set_callback(
                "on_open",
                f"on_open_queue_{self.name}_declare",
                partial(self.declare, timeout=timeout),
            )

    async def bind(
        self,
        exchange: Exchange,
        routing_key: str,
        timeout: int | None = None,
        restore: bool | None = None,
    ):
        """Bind the queue to `exchange` with `routing_key` and remember the binding.

        Args:
            exchange: Exchange to bind to.
            routing_key: Routing key of the binding.
            timeout: Operation timeout; falls back to the queue timeout.
            restore: Register an "on_open" callback that binds again whenever the
                connection is opened.
        """
        logger.debug(
            "Bind queue '%s' to exchange '%s' with routing_key '%s'",
            self.name,
            exchange.name,
            routing_key,
        )

        channel = await self.conn.channel()
        await channel.queue_bind(
            self.name,
            exchange.name,
            routing_key=routing_key,
            timeout=timeout or self.timeout,
        )

        if (exchange, routing_key) not in self.bindings:
            self.bindings.append((exchange, routing_key))

        if restore:
            self.conn.set_callback(
                "on_open",
                f"on_open_queue_{self.name}_bind_{exchange.name}_{routing_key}",
                partial(self.bind, exchange, routing_key, timeout=timeout),
            )

    async def unbind(self, exchange: Exchange, routing_key: str, timeout: int | None = None):
        """Remove the binding to `exchange` for `routing_key` and its restore callback.

        Args:
            exchange: Exchange to unbind from.
            routing_key: Routing key of the binding.
            timeout: Operation timeout; falls back to the queue timeout.
        """
        logger.debug(
            "Unbind queue '%s' from exchange '%s' for routing_key '%s'",
            self.name,
            exchange.name,
            routing_key,
        )

        if (exchange, routing_key) in self.bindings:
            self.bindings.remove((exchange, routing_key))

        channel = await self.conn.channel()
        await channel.queue_unbind(
            self.name,
            exchange.name,
            routing_key=routing_key,
            timeout=timeout or self.timeout,
        )

        self.conn.remove_callback(
            "on_open",
            f"on_open_queue_{self.name}_bind_{exchange.name}_{routing_key}",
            cancel=True,
        )

    async def consume(
        self,
        callback: Callable[[aiormq.Channel, aiormq.abc.DeliveredMessage], Coroutine],
        prefetch_count: int | None = None,
        timeout: int | None = None,
        retry_timeout: int = 5,
    ):
        """Start consuming on a dedicated channel; does nothing new if already consuming.

        Args:
            callback: Called with the consumer channel and each delivered message.
            prefetch_count: Channel prefetch count; falls back to the queue setting.
            timeout: Operation timeout; falls back to the queue timeout.
            retry_timeout: Delay in seconds between attempts to resume consuming after
                the connection is lost.

        Returns:
            The active consumer.
        """
        if self.consumer is None:
            channel = await self.conn.new_channel()
            await channel.basic_qos(
                prefetch_count=prefetch_count or self.prefetch_count,
                timeout=timeout or self.timeout,
            )

            consume_ok = await channel.basic_consume(
                self.name,
                partial(callback, channel),
                timeout=timeout or self.timeout,
            )
            object.__setattr__(
                self,
                "consumer",
                Consumer(channel=channel, consumer_tag=consume_ok.consumer_tag),
            )

            logger.info("Consuming %s", self)

            # Resume consuming after a lost connection, retrying forever on any error.
            # retry_timeout is forwarded so later re-registrations keep the caller's delay.
            self.conn.set_callback(
                "on_lost",
                f"on_lost_queue_{self.name}_consume",
                partial(
                    retry(
                        retry_timeouts=itertools.repeat(retry_timeout),
                        exc_filter=lambda e: True,
                    )(self.consume),
                    callback,
                    prefetch_count=prefetch_count,
                    timeout=timeout,
                    retry_timeout=retry_timeout,
                ),
            )

        return self.consumer

    async def stop_consume(self, timeout: int | None = None):
        """Cancel the active consumer, if any, and stop resuming it on connection loss.

        Args:
            timeout: Operation timeout; falls back to the queue timeout.
        """
        logger.debug("Stop consume %s", self)

        self.conn.remove_callback("on_lost", f"on_lost_queue_{self.name}_consume", cancel=True)

        consumer = self.consumer
        if consumer is None:
            return

        if not consumer.channel.is_closed:
            await consumer.channel.basic_cancel(consumer.consumer_tag, timeout=timeout or self.timeout)
            await consumer.close()
        # Also forget a consumer whose channel is already closed; keeping it
        # would turn the next consume() call into a no-op.
        object.__setattr__(self, "consumer", None)