qena-shared-lib 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,599 @@
1
+ from contextlib import asynccontextmanager
2
+ from datetime import datetime
3
+ from typing import (
4
+ Annotated,
5
+ Any,
6
+ AsyncGenerator,
7
+ Generic,
8
+ TypeAlias,
9
+ TypeVar,
10
+ cast,
11
+ get_args,
12
+ overload,
13
+ )
14
+
15
+ from bson.objectid import ObjectId
16
+ from pydantic import BeforeValidator, Field, field_serializer
17
+ from pymongo import (
18
+ ASCENDING,
19
+ DESCENDING,
20
+ GEO2D,
21
+ GEOSPHERE,
22
+ HASHED,
23
+ TEXT,
24
+ AsyncMongoClient,
25
+ IndexModel,
26
+ )
27
+ from pymongo.asynchronous.client_session import AsyncClientSession
28
+ from pymongo.asynchronous.collection import AsyncCollection
29
+ from pymongo.asynchronous.database import AsyncDatabase
30
+ from typing_extensions import Self
31
+
32
+ from .alias import CamelCaseAliasedBaseModel
33
+ from .logging import LoggerFactory
34
+
35
# Public re-export surface: the local document/repository machinery plus the
# pymongo index-direction/type constants, so callers only import this module.
__all__ = [
    "AggregatedDocument",
    "ASCENDING",
    "AsyncClientSession",
    "DESCENDING",
    "Document",
    "EmbeddedDocument",
    "Field",
    "GEO2D",
    "GEOSPHERE",
    "HASHED",
    "IndexManager",
    "IndexModel",
    "MongoDBManager",
    "MongoDBObjectId",
    "ProjectedDocument",
    "RepositoryBase",
    "TEXT",
    "validate_object_id",
]
55
+
56
+
57
class MongoDBManager:
    """Owns an AsyncMongoClient plus the default database handle.

    Also acts as a collection resolver: indexing the manager with a
    Document subclass yields that document's AsyncCollection.
    """

    def __init__(self, connection_string: str, db: str | None = None):
        self._client = AsyncMongoClient(connection_string)
        self._db = self._client.get_database(db)
        self._logger = LoggerFactory.get_logger("mongodb_manager")

    async def connect(self) -> None:
        """Open the client connection and log the server address."""
        await self._client.aconnect()

        addr = await self._client.address
        # Fall back to the conventional local address when none is reported.
        host, port = addr if addr is not None else ("localhost", 27017)

        self._logger.info("connected to mongodb server `%s:%s`", host, port)

    async def disconnect(self) -> None:
        """Close the client and log the teardown."""
        await self._client.aclose()
        self._logger.info("disconnected from mongodb")

    @property
    def client(self) -> AsyncMongoClient:
        """The underlying AsyncMongoClient."""
        return self._client

    @property
    def db(self) -> AsyncDatabase:
        """The default AsyncDatabase chosen at construction time."""
        return self._db

    @asynccontextmanager
    async def transactional(self) -> AsyncGenerator[AsyncClientSession, None]:
        """Yield a client session with an active transaction."""
        async with self.client.start_session() as session:
            async with await session.start_transaction():
                yield session

    def __getitem__(self, document: type["Document"]) -> AsyncCollection:
        """Resolve the collection bound to a Document subclass."""
        return self._db.get_collection(document.get_collection_name())
95
+
96
+
97
class TimeStampMixin(CamelCaseAliasedBaseModel):
    """Mixin adding `created_at`/`updated_at` timestamps (naive local time)."""

    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)

    @field_serializer("updated_at", when_used="always")
    def serialize_updated_at(self, _: datetime) -> datetime:
        # Deliberately ignores the stored value: every serialization stamps
        # the document with the current time, so dumping acts as a "touch".
        return datetime.now()
104
+
105
+
106
def validate_object_id(value: Any) -> ObjectId:
    """Coerce `value` into an ObjectId.

    Raises:
        ValueError: when `value` is not a valid ObjectId representation.
    """
    if ObjectId.is_valid(value):
        return ObjectId(value)

    raise ValueError(f"{value} is not valid objectid")
111
+
112
+
113
# ObjectId annotated for pydantic: raw inputs are coerced/validated through
# validate_object_id before the field is populated.
MongoDBObjectId: TypeAlias = Annotated[
    ObjectId, BeforeValidator(validate_object_id)
]
116
+
117
+
118
class Document(CamelCaseAliasedBaseModel):
    """Base class for top-level MongoDB documents carrying an `_id`."""

    id: MongoDBObjectId = Field(alias="_id", default_factory=ObjectId)

    @classmethod
    def get_collection_name(cls) -> str:
        """Return `__collection_name__` when declared, else the class name."""
        declared = getattr(cls, "__collection_name__", None)
        return cls.__name__ if declared is None else declared

    @classmethod
    def get_indexes(cls) -> list[IndexModel] | None:
        """Return the optional `__indexes__` declaration, if any."""
        return getattr(cls, "__indexes__", None)

    @classmethod
    def from_raw_document(cls, document: Any, **kwargs: Any) -> Self:
        """Validate a raw mapping from the driver into this model."""
        validated = cls.model_validate(document, **kwargs)
        return cast(Self, validated)
137
+
138
+
139
class EmbeddedDocument(CamelCaseAliasedBaseModel):
    """Base class for documents nested inside another Document."""

    @classmethod
    def from_raw_embedded_document(
        cls, embedded_document: Any, **kwargs: Any
    ) -> Self:
        """Validate a raw nested mapping into this model."""
        validated = cls.model_validate(embedded_document, **kwargs)
        return cast(Self, validated)
145
+
146
+
147
class ProjectedDocument(CamelCaseAliasedBaseModel):
    """Base class for partial (projected) reads of a collection."""

    @classmethod
    def from_raw_projected_document(
        cls, projected_document: Any, **kwargs: Any
    ) -> Self:
        """Validate a raw projected mapping into this model."""
        validated = cls.model_validate(projected_document, **kwargs)
        return cast(Self, validated)

    @classmethod
    def get_projection(cls) -> list[str] | dict[str, Any]:
        """Return `__projection__`, deriving it from the model fields when
        the subclass did not declare one explicitly."""
        declared = getattr(cls, "__projection__", None)

        if declared is None:
            declared = cls._projection_from_field_info()

        return cast(list[str] | dict[str, Any], declared)

    @classmethod
    def _projection_from_field_info(cls) -> list[str]:
        # Prefer the wire alias when a field has one; cache the result on
        # the class so the list is only built once per subclass.
        names: list[str] = []
        for field_name, field_info in cls.model_fields.items():
            names.append(field_info.alias or field_name)

        cls.__projection__ = names
        return names
171
+
172
+
173
class AggregatedDocument(CamelCaseAliasedBaseModel):
    """Base class for models produced by an aggregation pipeline."""

    @classmethod
    def from_raw_aggregated_document(cls, obj: Any, **kwargs: Any) -> Self:
        """Validate a raw aggregation result into this model."""
        validated = cls.model_validate(obj, **kwargs)
        return cast(Self, validated)

    @classmethod
    def get_pipeline(cls) -> list[Any]:
        """Return the mandatory `__pipeline__` declaration.

        Raises:
            ValueError: when the subclass does not define `__pipeline__`.
        """
        pipeline = getattr(cls, "__pipeline__", None)

        if pipeline is None:
            raise ValueError(
                f"__pipeline__ is not defined for aggregated document {cls.__name__}"
            )

        return cast(list[Any], pipeline)
188
+
189
+
190
class IndexManager:
    """Creates, lists, and drops indexes for a set of Document classes."""

    def __init__(
        self, db: MongoDBManager, documents: list[type[Document]]
    ) -> None:
        self._db = db
        self._documents = documents

    async def create_indexes(self) -> None:
        """Create every index declared via `__indexes__` on the documents."""
        for document in self._documents:
            indexes = document.get_indexes()

            if indexes is None:
                continue

            await self._db[document].create_indexes(indexes)

    async def get_indexes(self, collection_name: str) -> list[str]:
        """Return the index names of the named collection.

        Raises:
            ValueError: when the collection is not registered.
        """
        document = self._get_document(collection_name)
        indexes = []

        async with await self._db[document].list_indexes() as cursor:
            async for index in cursor:
                indexes.append(index.get("name"))

        return indexes

    async def drop_indexes(
        self, collection_names: list[str] | None = None
    ) -> None:
        """Drop all indexes, optionally restricted to `collection_names`."""
        for document in self._documents:
            if (
                collection_names is not None
                and document.get_collection_name() not in collection_names
            ):
                continue

            await self._db[document].drop_indexes()

    async def drop_index(self, collection_name: str, index_name: str) -> None:
        """Drop a single index from the named collection.

        Raises:
            ValueError: when the collection is not registered.
        """
        document = self._get_document(collection_name)

        await self._db[document].drop_index(index_name)

    def _get_document(self, collection_name: str) -> type[Document]:
        """Find the registered document class for `collection_name`.

        BUGFIX: the original looped with `for document in self._documents`
        and only raised when `document` was still None afterwards — but a
        completed (break-less) loop leaves `document` bound to the *last*
        registered class, so an unknown name silently targeted the wrong
        collection. Now every miss raises.

        Raises:
            ValueError: when no registered document uses that collection.
        """
        for document in self._documents:
            if document.get_collection_name() == collection_name:
                return document

        raise ValueError(
            f"collection with name {collection_name} not found"
        )
246
+
247
+
248
# Type parameters used by RepositoryBase and its helpers:
#   T: the concrete Document subclass a repository manages
#   P: a ProjectedDocument type used for partial (projected) reads
#   A: an AggregatedDocument type describing a pipeline's output
#   S: unbound and not referenced in this module — presumably reserved for
#      subclasses; verify before removing
T = TypeVar("T", bound=Document)
P = TypeVar("P", bound=ProjectedDocument)
A = TypeVar("A", bound=AggregatedDocument)
S = TypeVar("S")
252
+
253
+
254
class RepositoryBase(Generic[T]):
    """Generic async repository over a single Document type ``T``.

    Subclasses bind the document type through the generic parameter
    (``class UserRepository(RepositoryBase[User])``); the concrete class is
    resolved lazily from ``__orig_bases__`` on first use.
    """

    def __init__(self, db: MongoDBManager) -> None:
        self._db = db
        # Fallback session applied when an operation receives none.
        self._session = None
        # Cache for the lazily-resolved generic parameter T.
        self._document_type = None

    @property
    def db(self) -> MongoDBManager:
        """The owning MongoDBManager."""
        return self._db

    @property
    def collection(self) -> AsyncCollection:
        """The collection backing the managed document type."""
        return self._db[self.document_type]

    @property
    def session(self) -> AsyncClientSession | None:
        """Default session used when an operation is given none."""
        return self._session

    @session.setter
    def session(self, session: AsyncClientSession) -> None:
        # Reject sessions that can no longer be used.
        if session.has_ended:
            raise RuntimeError(
                f"session with id {session.session_id} has already ended"
            )

        self._session = session

    @property
    def document_type(self) -> type[T]:
        """Resolve and cache the concrete class bound to ``T``.

        Raises:
            RuntimeError: when the subclass did not parameterize
                ``RepositoryBase`` with a document type.
        """
        document_type = self._document_type

        if document_type is None:
            orig_bases = getattr(self, "__orig_bases__", None)

            if not orig_bases:
                raise RuntimeError("generic variable T is not specified")

            # The last type argument of the last generic base is T.
            *_, orig_class = orig_bases
            *_, self._document_type = get_args(orig_class)

        return cast(type[T], self._document_type)

    async def insert(
        self, document: T, session: AsyncClientSession | None = None
    ) -> ObjectId | str:
        """Insert one document and return its `_id`."""
        inserted_one_result = await self.collection.insert_one(
            document=document.model_dump(by_alias=True),
            session=session or self.session,
        )

        return inserted_one_result.inserted_id

    async def insert_many(
        self, documents: list[T], session: AsyncClientSession | None = None
    ) -> list[ObjectId] | list[str]:
        """Insert several documents and return their `_id`s."""
        insert_many_result = await self.collection.insert_many(
            documents=[
                document.model_dump(by_alias=True) for document in documents
            ],
            session=session or self.session,
        )

        return cast(list[ObjectId] | list[str], insert_many_result.inserted_ids)

    @overload
    async def find_by_id(
        self,
        *,
        id: Any,
        skip: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> T | None:
        pass

    @overload
    async def find_by_id(
        self,
        *,
        id: Any,
        projection: type[P],
        skip: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> P | None:
        pass

    async def find_by_id(self, *_: Any, **kwargs: Any) -> Any:
        """Find one document by `_id`; see the overloads for typing."""
        return await self._find_one(
            filter={"_id": kwargs["id"]},
            projection=kwargs.get("projection"),
            skip=kwargs.get("skip", 0),
            sort=kwargs.get("sort"),
            session=kwargs.get("session"),
        )

    @overload
    async def find_by_filter(
        self,
        *,
        filter: dict[str, Any],
        skip: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> T | None:
        pass

    @overload
    async def find_by_filter(
        self,
        *,
        filter: dict[str, Any],
        projection: type[P],
        skip: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> P | None:
        pass

    async def find_by_filter(self, *_: Any, **kwargs: Any) -> Any:
        """Find one document by filter; see the overloads for typing."""
        return await self._find_one(
            filter=kwargs["filter"],
            projection=kwargs.get("projection"),
            skip=kwargs.get("skip", 0),
            sort=kwargs.get("sort"),
            session=kwargs.get("session"),
        )

    async def _find_one(
        self,
        filter: dict[str, Any],
        projection: type[P] | None = None,
        skip: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> T | P | None:
        """Shared find-one path: project into P when requested, else T."""
        if projection is not None:
            return projection.from_raw_projected_document(
                await self.collection.find_one(
                    filter=filter,
                    projection=projection.get_projection(),
                    skip=skip,
                    sort=sort,
                    session=session or self.session,
                )
            )

        return self.document_type.from_raw_document(
            await self.collection.find_one(
                filter=filter,
                projection=projection,
                skip=skip,
                sort=sort,
                session=session or self.session,
            )
        )

    async def replace(
        self, replacement: T, session: AsyncClientSession | None = None
    ) -> None:
        """Replace the stored document sharing `replacement.id`."""
        await self.collection.replace_one(
            filter={"_id": replacement.id},
            replacement=replacement.model_dump(by_alias=True),
            session=session or self.session,
        )

    @overload
    def find_all(
        self,
        *,
        skip: int = 0,
        limit: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[T, None]:
        pass

    @overload
    def find_all(
        self,
        *,
        projection: type[P],
        skip: int = 0,
        limit: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[P, None]:
        pass

    async def find_all(
        self, *_: Any, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """Stream every document; see the overloads for typing."""
        async for document in self._find(
            projection=kwargs.get("projection"),
            skip=kwargs.get("skip", 0),
            limit=kwargs.get("limit", 0),
            sort=kwargs.get("sort"),
            session=kwargs.get("session"),
        ):
            yield document

    @overload
    def find_all_by_filter(
        self,
        *,
        filter: dict[str, Any],
        skip: int = 0,
        limit: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[T, None]:
        pass

    # FIX: this overload originally declared `projection: type[P] | None =
    # None` yet still returned AsyncGenerator[T, None], so projected queries
    # were typed as yielding T. It now mirrors find_all's overload pair:
    # projection required, yields P.
    @overload
    def find_all_by_filter(
        self,
        *,
        filter: dict[str, Any],
        projection: type[P],
        skip: int = 0,
        limit: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[P, None]:
        pass

    async def find_all_by_filter(
        self, *_: Any, **kwargs: Any
    ) -> AsyncGenerator[Any, None]:
        """Stream documents matching a filter; see the overloads for typing."""
        async for document in self._find(
            filter=kwargs.get("filter"),
            projection=kwargs.get("projection"),
            skip=kwargs.get("skip", 0),
            limit=kwargs.get("limit", 0),
            sort=kwargs.get("sort"),
            session=kwargs.get("session"),
        ):
            yield document

    async def _find(
        self,
        filter: dict[str, Any] | None = None,
        projection: type[P] | None = None,
        skip: int = 0,
        limit: int = 0,
        sort: dict[str, int] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[T | P, None]:
        """Shared cursor path: project into P when requested, else T."""
        if projection:
            async with self.collection.find(
                filter=filter,
                projection=projection.get_projection(),
                skip=skip,
                limit=limit,
                sort=sort,
                session=session or self.session,
            ) as cursor:
                async for document in cursor:
                    yield projection.from_raw_projected_document(document)

            return

        async with self.collection.find(
            filter=filter,
            skip=skip,
            limit=limit,
            sort=sort,
            session=session or self.session,
        ) as cursor:
            async for document in cursor:
                yield self.document_type.from_raw_document(document)

    @overload
    async def exists(
        self,
        *,
        id: Any,
        session: AsyncClientSession | None = None,
    ) -> bool:
        pass

    @overload
    async def exists(
        self,
        *,
        filter: dict[str, Any],
        session: AsyncClientSession | None = None,
    ) -> bool:
        pass

    async def exists(self, *_: Any, **kwargs: Any) -> bool:
        """Return True when a document matches the given id or filter."""
        id = kwargs.get("id")
        filter = kwargs.get("filter")

        # `id` wins over `filter` when both are supplied.
        if id is not None:
            filter = {"_id": id}

        return (
            await self.collection.find_one(
                filter=filter,
                projection={"_id": True},
                session=kwargs.get("session") or self.session,
            )
            is not None
        )

    async def count(
        self,
        filter: dict[str, Any] | None = None,
        skip: int | None = None,
        limit: int | None = None,
        session: AsyncClientSession | None = None,
    ) -> int:
        """Count documents; with no constraints, use the cheap estimate."""
        if filter is not None or skip is not None or limit is not None:
            options = {}

            if skip is not None:
                options["skip"] = skip

            if limit is not None and limit > 0:
                options["limit"] = limit

            return cast(
                int,
                await self.collection.count_documents(
                    filter=filter or {},
                    **options,
                    session=session or self.session,
                ),
            )

        return cast(int, await self.collection.estimated_document_count())

    async def aggregate(
        self,
        aggregation: type[A],
        let: dict[str, Any] | None = None,
        session: AsyncClientSession | None = None,
    ) -> AsyncGenerator[A, None]:
        """Run `aggregation.__pipeline__` and stream typed results."""
        async with await self.collection.aggregate(
            pipeline=aggregation.get_pipeline(),
            let=let,
            session=session or self.session,
        ) as cursor:
            async for document in cursor:
                yield aggregation.from_raw_aggregated_document(document)
@@ -1,3 +1,4 @@
1
+ from . import message
1
2
  from ._base import AbstractRabbitMQService, RabbitMqManager
2
3
  from ._channel import BaseChannel
3
4
  from ._exception_handlers import (
@@ -40,6 +41,7 @@ __all__ = [
40
41
  "LISTENER_ATTRIBUTE",
41
42
  "ListenerBase",
42
43
  "ListenerContext",
44
+ "message",
43
45
  "Publisher",
44
46
  "RabbitMqGeneralExceptionHandler",
45
47
  "RabbitMqManager",
@@ -519,41 +519,17 @@ class Listener(AsyncEventLoopMixin):
519
519
 
520
520
  return
521
521
 
522
- listener_message_metadata = ListenerMessageMetadata(
523
- body=body,
524
- method=method,
525
- properties=properties,
526
- listener_name=listener_name,
527
- listener_method_container=listener_method_container,
528
- listener_start_time=time(),
529
- )
530
-
531
- self.loop.run_in_executor(
532
- executor=None,
533
- func=partial(self._parse_and_execute, listener_message_metadata),
534
- ).add_done_callback(
535
- partial(
536
- self._on_submitted_listener_error, listener_message_metadata
522
+ self._parse_and_execute(
523
+ ListenerMessageMetadata(
524
+ body=body,
525
+ method=method,
526
+ properties=properties,
527
+ listener_name=listener_name,
528
+ listener_method_container=listener_method_container,
529
+ listener_start_time=time(),
537
530
  )
538
531
  )
539
532
 
540
- def _on_submitted_listener_error(
541
- self,
542
- listener_message_metadata: ListenerMessageMetadata,
543
- future: Future[None],
544
- ) -> None:
545
- if future.cancelled():
546
- return
547
-
548
- exception = future.exception()
549
-
550
- if exception is not None:
551
- self._call_exception_callback(
552
- exception=exception,
553
- listener_message_metadata=listener_message_metadata,
554
- message=f"error occured while submitting listener callback on listener `{listener_message_metadata.listener_name}` and queue `{self._queue}`",
555
- )
556
-
557
533
  def _parse_and_execute(
558
534
  self, listener_message_metadata: ListenerMessageMetadata
559
535
  ) -> None:
@@ -0,0 +1,22 @@
1
+ from pydantic import Field
2
+
3
+ from ._inbound import (
4
+ CamelCaseInboundMessage,
5
+ InboundMessage,
6
+ SnakeCaseInboundMessage,
7
+ )
8
+ from ._outbound import (
9
+ CamelCaseOutboundMessage,
10
+ OutboundMessage,
11
+ SnakeCaseOutboundMessage,
12
+ )
13
+
14
# Public surface of the messaging package; `Field` is re-exported so message
# models can be declared importing only from this package.
__all__ = [
    "CamelCaseInboundMessage",
    "CamelCaseOutboundMessage",
    "Field",
    "InboundMessage",
    "OutboundMessage",
    "SnakeCaseInboundMessage",
    "SnakeCaseOutboundMessage",
]
@@ -0,0 +1,13 @@
1
+ from ...alias import CamelCaseAliasedBaseModel, SnakeCaseAliasedBaseModel
2
+
3
+
4
class SnakeCaseInboundMessage(SnakeCaseAliasedBaseModel):
    """Marker base for inbound messages aliased via SnakeCaseAliasedBaseModel."""

    pass
6
+
7
+
8
class CamelCaseInboundMessage(CamelCaseAliasedBaseModel):
    """Marker base for inbound messages aliased via CamelCaseAliasedBaseModel."""

    pass
10
+
11
+
12
class InboundMessage(CamelCaseAliasedBaseModel):
    """Default inbound message base (camelCase-aliased).

    NOTE(review): shares the exact base of CamelCaseInboundMessage — confirm
    the duplication is intentional.
    """

    pass
@@ -0,0 +1,13 @@
1
+ from ...alias import CamelCaseAliasedBaseModel, SnakeCaseAliasedBaseModel
2
+
3
+
4
class SnakeCaseOutboundMessage(SnakeCaseAliasedBaseModel):
    """Marker base for outbound messages aliased via SnakeCaseAliasedBaseModel."""

    pass
6
+
7
+
8
class CamelCaseOutboundMessage(CamelCaseAliasedBaseModel):
    """Marker base for outbound messages aliased via CamelCaseAliasedBaseModel."""

    pass
10
+
11
+
12
class OutboundMessage(CamelCaseOutboundMessage):
    """Default outbound message base, inheriting the camelCase variant.

    NOTE(review): the inbound counterpart subclasses the aliased base model
    directly while this inherits CamelCaseOutboundMessage — confirm the
    asymmetry is intentional.
    """

    pass