stackraise 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. stackraise/__init__.py +6 -0
  2. stackraise/ai/__init__.py +2 -0
  3. stackraise/ai/rpa.py +380 -0
  4. stackraise/ai/toolset.py +227 -0
  5. stackraise/app.py +23 -0
  6. stackraise/auth/__init__.py +2 -0
  7. stackraise/auth/model.py +24 -0
  8. stackraise/auth/service.py +240 -0
  9. stackraise/ctrl/__init__.py +4 -0
  10. stackraise/ctrl/change_stream.py +40 -0
  11. stackraise/ctrl/crud_controller.py +63 -0
  12. stackraise/ctrl/file_storage.py +68 -0
  13. stackraise/db/__init__.py +11 -0
  14. stackraise/db/adapter.py +60 -0
  15. stackraise/db/collection.py +292 -0
  16. stackraise/db/cursor.py +229 -0
  17. stackraise/db/document.py +282 -0
  18. stackraise/db/exceptions.py +9 -0
  19. stackraise/db/id.py +79 -0
  20. stackraise/db/index.py +84 -0
  21. stackraise/db/persistence.py +238 -0
  22. stackraise/db/pipeline.py +245 -0
  23. stackraise/db/protocols.py +141 -0
  24. stackraise/di.py +36 -0
  25. stackraise/event.py +150 -0
  26. stackraise/inflection.py +28 -0
  27. stackraise/io/__init__.py +3 -0
  28. stackraise/io/imap_client.py +400 -0
  29. stackraise/io/smtp_client.py +102 -0
  30. stackraise/logging.py +22 -0
  31. stackraise/model/__init__.py +11 -0
  32. stackraise/model/core.py +16 -0
  33. stackraise/model/dto.py +12 -0
  34. stackraise/model/email_message.py +88 -0
  35. stackraise/model/file.py +154 -0
  36. stackraise/model/name_email.py +45 -0
  37. stackraise/model/query_filters.py +231 -0
  38. stackraise/model/time_range.py +285 -0
  39. stackraise/model/validation.py +8 -0
  40. stackraise/templating/__init__.py +4 -0
  41. stackraise/templating/exceptions.py +23 -0
  42. stackraise/templating/image/__init__.py +2 -0
  43. stackraise/templating/image/model.py +51 -0
  44. stackraise/templating/image/processor.py +154 -0
  45. stackraise/templating/parser.py +156 -0
  46. stackraise/templating/pptx/__init__.py +3 -0
  47. stackraise/templating/pptx/pptx_engine.py +204 -0
  48. stackraise/templating/pptx/slide_renderer.py +181 -0
  49. stackraise/templating/tracer.py +57 -0
  50. stackraise-0.1.0.dist-info/METADATA +37 -0
  51. stackraise-0.1.0.dist-info/RECORD +52 -0
  52. stackraise-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,282 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Annotated, Any, Awaitable, Callable, ClassVar, Mapping, Optional, Self
5
+
6
+ from pydantic import Field
7
+ from pydantic_core import core_schema
8
+
9
+ from .adapter import Adapter, DocumentProtocol
10
+ from .collection import Collection
11
+ from .id import Id
12
+ from .exceptions import NotFoundError
13
+ from .protocols import QueryLike, ensure_mongo_query
14
+
15
+ import stackraise.model as model
16
+
17
+
18
class DocumentMeta(type(model.Dto)):
    # Metaclass hook point for Document. Currently inherits everything from
    # the Dto metaclass unchanged; exists so Document subclass machinery can
    # be extended here later without touching model.Dto.
    pass
20
+
21
+
22
class Document(model.Dto, metaclass=DocumentMeta):
    """
    Base class for all documents in the database.
    This class provides a common interface for all documents, including
    methods for storing, updating, and deleting documents.

    Subclassing wires up persistence automatically (see
    ``__init_subclass__``): every concrete subclass gets its own
    ``Collection`` instance and a dedicated nested ``Reference`` class.
    """

    class Reference[T: DocumentProtocol]:
        # A lightweight, hashable handle to a stored document: it wraps only
        # the document Id and delegates all database operations to the
        # class-level `collection`.

        __slots__ = ("_id",)
        # Collection bound to the referenced document class; assigned by
        # Document.__init_subclass__ when the subclass Reference is built.
        collection: ClassVar[Collection[T]]
        _id: Id
        # Optional pre-fetched item; when set, fetch()/sync_fetch() return it
        # without touching the database.
        _sync_item: ClassVar[T | None] = None

        @classmethod
        def __get_pydantic_core_schema__(cls, *_):
            # Pydantic core schema: references validate from / serialize to a
            # plain string in JSON mode, and from/to Id (or an already-built
            # Reference) in Python mode.

            def js_parse(value: str):
                assert isinstance(value, str), f"Bad value type {type(value)}"
                return cls(Id(value))

            def js_serial(ref: Document.Reference):
                return str(ref.id)

            def py_parse(value: Document.Reference | Id | str):
                if isinstance(value, Document.Reference):
                    # Only an exact-type reference is accepted as-is.
                    assert type(value) is cls, f"Bad type {type(value)} for {cls}"
                    return value
                if not isinstance(value, Id):
                    value = Id(value)
                return cls(value)

            def py_serial(ref: Document.Reference):
                return ref.id

            return core_schema.json_or_python_schema(
                # JSON
                json_schema=core_schema.no_info_plain_validator_function(
                    js_parse,
                    serialization=core_schema.plain_serializer_function_ser_schema(
                        js_serial
                    ),
                ),
                # PYTHON
                python_schema=core_schema.no_info_plain_validator_function(
                    py_parse,
                    serialization=core_schema.plain_serializer_function_ser_schema(
                        py_serial
                    ),
                ),
            )

        @classmethod
        def __get_pydantic_json_schema__(cls, _, handler):
            # References appear as plain strings in the JSON schema.
            return handler(core_schema.str_schema())

        def __repr__(self):
            return f"{type(self).__qualname__}({self._id!s})"

        def __init__(self, id: Id):
            assert isinstance(id, Id), f"Bad id type {type(id)}"
            self._id = id

        def __eq__(self, other) -> bool:
            # Equal only to references of the exact same type with the same id.
            if not isinstance(other, type(self)):
                return False
            return self._id == other._id

        def __hash__(self) -> int:
            return hash(self._id)

        @property
        def id(self):
            return self._id

        @property
        def created_at(self) -> datetime:
            # ObjectIds embed their creation timestamp.
            return self.id.generation_time

        def to_mongo_query(self) -> Mapping[str, Any]:
            """Return a MongoDB query for this reference."""
            return {"_id": self._id}

        def sync_fetch(self) -> T:
            """Fetch the document referenced by this reference.

            NOTE(review): this only returns the pre-attached ``_sync_item``;
            when none was attached it returns None — confirm this is the
            intended contract.
            """
            return self._sync_item

        async def fetch(self, not_found_error: bool = True) -> T:
            """Fetch the document referenced by this reference."""
            # Prefer the pre-fetched item when available.
            if (sync_item := self._sync_item) is not None:
                return sync_item
            return await self.collection.fetch_by_id(self._id, not_found_error=not_found_error)

        async def delete(self):
            """Delete the document referenced by this reference."""
            return await self.collection.delete_by_id(self._id)

        async def exists(self) -> bool:
            """ Check if the document exists in the database. """
            # The reference itself is a query (via to_mongo_query()).
            count = await self.collection.count(self)
            return count == 1

        async def complies(self, conditions: QueryLike) -> bool:
            """
            Check if the document referenced by this reference complies with the given conditions.
            Args:
                conditions (QueryLike): The conditions to check against the document.
            Returns:
                bool: True if the document complies with the conditions, False otherwise.
            """
            query = ensure_mongo_query(conditions)
            count = await self.collection.count(query | {"_id": self._id})
            return count == 1

        async def assign(self, **values):
            """
            Assign values to the document referenced by this reference.
            Args:
                **values: The values to assign to the document.
            """

            result = await self.collection._update_one(
                {"_id": self._id},
                {"$set": values},
                upsert=True,
            )

            # NOTE(review): when the filter matches an existing document,
            # `upserted_id` is None, so this raises on the normal update
            # path and only succeeds when the upsert *created* the document.
            # The check looks inverted — confirm intended semantics.
            if result.upserted_id != self.id:
                raise NotFoundError(self)

        async def update(self, **values):
            # Strict update: no upsert, the document must already exist.
            result = await self.collection._update_one(
                {"_id": self._id}, {"$set": values}, upsert=False
            )

            if result.matched_count != 1:
                raise NotFoundError(self)

    ## mutate() on a doc will create a mutation context
    ## the doc can be valued, read and written in different ways
    ## after the mutation context the operation will be performed

    def __init_subclass__(
        cls,
        abstract=False,
        collection: Optional[str] = None,
        **kwargs,
    ):
        # NOTE(review): `abstract` is accepted but never used in this body —
        # confirm whether abstract subclasses should skip collection setup.
        super().__init_subclass__(**kwargs)
        ##
        ## Subclassing Document has special behavior:
        ## - A reference class is created for the new document class
        ## - A repository (collection) is instantiated for the new document
        ##

        # Make a repository class for the new persistent class
        # Instance the repository

        collection = Collection(Adapter(cls), collection)

        # Make a reference class for the new persistent class.
        # Its bases are every Reference declared along the MRO, so the
        # reference hierarchy mirrors the document hierarchy.

        reference_cls = type(
            "Reference",
            tuple(
                vars(base)["Reference"]
                for base in cls.__mro__
                if "Reference" in vars(base)
                and issubclass(vars(base)["Reference"], Document.Reference)
            ),
            {
                "__module__": cls.__module__,
                "__qualname__": f"{cls.__qualname__}.Reference",
                "document_class": cls,
                "collection": collection,
            },
        )

        setattr(cls, "Reference", reference_cls)
        # Short alias for the reference class.
        setattr(cls, "Ref", reference_cls)
        setattr(cls, "collection", collection)

    # Type alias for annotating references to this (sub)class.
    type Ref = Reference[Self]
    # Collection bound to this document class (assigned in __init_subclass__).
    collection: ClassVar[Collection[Self]]

    # Stored under MongoDB's "_id" key; None until the document is inserted.
    id: Annotated[
        Optional[Id],
        Field(
            None,
            title="Unique object Id",
            alias="_id",
        ),
    ]

    # @computed_field
    # def kind(self) -> str:
    #     """Return the kind of the document."""
    #     return self.__class__.__name__

    @property
    def ref(self) -> Ref | None:
        """Return a reference to the document, or None if not yet stored."""
        if self.id is None:
            return None
        return self.Reference(self.id)

    async def __prepare_for_storage__(self):
        """
        Hook to be called before persisting the object.
        This method can be overridden in subclasses to perform custom validation
        or processing before the object is stored in the database.
        For example, you can check for uniqueness constraints or perform
        additional validation on the object's attributes.
        This method is called automatically by the `insert` and `update` methods.

        TODO: this is part of the document protocol
        """

    @classmethod
    async def __handle_post_deletion__(cls, ref: Ref):
        """
        Hook to be called after the document is deleted.
        This method can be overridden in subclasses to perform custom cleanup
        or processing after the object is deleted from the database.
        For example, you can delete related documents or perform additional
        cleanup tasks.
        This method is called automatically by the `delete` method.
        """

    async def store(self) -> Self:
        """Persist the object in the database.

        This method will use the insert procedure if the object does not have
        an identifier defined (id is None) or the update(upsert) procedure if
        the object already has an identifier.

        Returns:
            RefBase[Self]: A reference to the persisted object.
        """
        if self.id:
            return await self.update()
        else:
            return await self.insert()

    async def insert(self, *, with_id: Optional[Id] = None) -> Self:
        # Insert this document, optionally forcing a specific id.
        return await self.collection.insert_item(self, with_id=with_id)

    async def update(self):
        # Update the already-stored document through its collection.
        return await self.collection.update_item(self)

    async def delete(self):
        """
        Deletes the document from the database.
        """
        doc_id = self.id
        if doc_id is None:
            raise NotFoundError(self)

        # Clear the local id before the database delete so the in-memory
        # object no longer looks persisted.
        self.id = None
        await self.collection.delete_by_id(doc_id)
281
+
282
+
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations
2
+ import stackraise.db as db
3
+
4
class NotFoundError(Exception):
    """Raised when a lookup for a document reference yields no result."""

    def __init__(self, ref: db.Document.Ref):
        # Keep the offending reference around for callers to inspect.
        self.ref = ref
        super().__init__(f"Document with reference {ref} not found.")
stackraise/db/id.py ADDED
@@ -0,0 +1,79 @@
1
+ from __future__ import annotations
2
+ from typing_extensions import deprecated
3
+
4
+ from bson import ObjectId
5
+ from bson.errors import InvalidId
6
+ from pydantic_core import core_schema, SchemaSerializer
7
+ from stackraise.model.validation import validation_error
8
+
9
+ __all__ = ["Id"]
10
+
11
class Id(ObjectId):
    """BSON ObjectId subclass with stricter construction semantics.

    Unlike ``ObjectId()``, ``Id(None)`` raises a ``ValueError`` instead of
    silently generating a random id; fresh ids must be requested explicitly
    through :meth:`new`. The class also provides pydantic core-schema hooks
    so it can be used directly in models: string form in JSON mode,
    ``Id``/``ObjectId``/``str`` accepted in Python mode.
    """

    _GENERATE_NEW = object()  # Sentinel used to request generation of a new ID

    def __init__(self, value=None):
        """Build an Id from an existing ObjectId-compatible value.

        Raises:
            ValueError: if `value` is None or not a valid ObjectId.
        """
        # Allow explicit generation of a new ID only via the sentinel.
        if value is Id._GENERATE_NEW:
            super().__init__()
            return
        # Prevent None/empty from accidentally generating a random ObjectId
        # (the default ObjectId() behavior).
        if value is None:
            raise ValueError("Valor no seleccionado")
        try:
            super().__init__(value)
        except InvalidId as e:
            # Bug fix: chain the caught exception instance (`from e`), not the
            # InvalidId *class*, so the traceback shows the real cause.
            raise ValueError(*e.args) from e

    @classmethod
    def new(cls):
        """Generate a fresh, unique Id."""
        return cls(cls._GENERATE_NEW)

    @classmethod
    @deprecated("you must not use this method")
    def from_str(cls, val: str):
        assert isinstance(val, str), f"from_str receive {val}"
        return cls(val)

    @classmethod
    @deprecated("you must not use this method")
    def from_oid(cls, val: Id | ObjectId | str):
        # assert isinstance(val, ObjectId), f"Id from python is {val}"
        return val if isinstance(val, cls) else cls(val)

    @property
    def created_at(self):
        """Creation timestamp embedded in the ObjectId."""
        return self.generation_time

    def to_mongo_query(self):  # query protocol
        """Return a MongoDB filter selecting the document with this id."""
        return {"_id": self}

    @property
    def value(self) -> str:
        """Return the string representation of the Id"""
        return str(self)

    # Pydantic core schema: parse from str in JSON mode, from
    # Id/ObjectId/str in Python mode; serialize to str for JSON output only.
    SCHEMA = core_schema.json_or_python_schema(
        # JSON
        json_schema=core_schema.no_info_plain_validator_function(
            lambda s: Id.from_str(s),
        ),
        # PYTHON
        python_schema=core_schema.no_info_plain_validator_function(
            lambda v: Id.from_oid(v),
        ),
        serialization=core_schema.plain_serializer_function_ser_schema(
            str, when_used="json"  # as str
        ),
    )

    __pydantic_serializer__ = SchemaSerializer(SCHEMA)

    @classmethod
    def __get_pydantic_core_schema__(cls, *_):
        return cls.SCHEMA

    @classmethod
    def __get_pydantic_json_schema__(cls, _, handler):
        # return {"type": "string"}
        return handler(core_schema.str_schema())
stackraise/db/index.py ADDED
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from logging import getLogger as get_logger
5
+
6
+ #import stackraise.db as db
7
+ import stackraise.inflection as inflection
8
+ from pymongo.asynchronous.client_session import \
9
+ AsyncClientSession as MongoSession
10
+ from pymongo.asynchronous.collection import AsyncCollection as MongoCollection
11
+
12
+ from .protocols import DocumentProtocol
13
+
14
# Module logger.
log = get_logger(__name__)

# Registry mapping each document class to the indexes declared for it
# (populated by Index.__call__, consumed by _update_indices).
_indices: dict[type[DocumentProtocol], list[Index]] = {}
17
+
18
+
19
+ class Index:
20
+ def __init__(self, args: list[str], kwargs: dict[str, int] = {}, **options):
21
+ fields = {field: 1 for field in args}
22
+ fields.update(kwargs)
23
+ self.fields = {inflection.to_camelcase(field): value for field, value in fields.items()}
24
+ self.options = options
25
+
26
+ def __call__[T: type[DocumentProtocol]](self, cls: T) -> T:
27
+ _indices.setdefault(cls, []).append(self)
28
+ return cls
29
+
30
def index(*args, **kwargs: int) -> Index:
    """Declare a non-unique index; positional names default to ascending."""
    return Index(list(args), dict(kwargs), unique=False)
32
+
33
def unique_index(*args, **kwargs: int) -> Index:
    """Declare a unique index.

    Positional field names default to ascending order; keyword arguments
    give explicit directions, mirroring ``index()``.
    """
    # Bug fix: keyword fields were previously dropped (kwargs was never
    # forwarded to Index), silently ignoring explicit directions.
    return Index(args, kwargs, unique=True)
35
+
36
def text_index(*fields, **options) -> Index:
    """Declare a MongoDB text index over the given fields."""
    spec = dict.fromkeys(fields, 'text')
    return Index([], spec, **options)
38
+
39
+
40
async def _update_indices(
    document_class: type[DocumentProtocol],
    collection: MongoCollection,
    session: MongoSession,
):
    """
    Applies the indices to the collection.

    Reconciles the indexes registered for ``document_class`` (via the
    ``Index`` decorator) with those that actually exist on the MongoDB
    collection: obsolete indexes are dropped, missing ones are created.
    The built-in ``_id_`` index is never touched.

    NOTE(review): comparison is by key tuple only — an index whose keys
    match but whose *options* (e.g. ``unique``) changed will NOT be
    rebuilt; confirm this is acceptable.
    """

    desired_indices = _indices.get(document_class, [])

    existing_indices = await collection.index_information()
    # Map existing indices by their key tuple (excluding _id_)
    existing_keys = {
        tuple(idx['key']): name
        for name, idx in existing_indices.items()
        if name != '_id_'
    }

    # Map desired indices by their key tuple
    desired_keys = {}
    for idx in desired_indices:
        # Convert fields dict to tuple of (field, direction) preserving order
        key_tuple = tuple(idx.fields.items())
        desired_keys[key_tuple] = idx

    # Indices to drop: present in existing but not in desired
    to_drop = [existing_keys[k] for k in existing_keys if k not in desired_keys]
    # Indices to create: present in desired but not in existing
    to_create = [desired_keys[k] for k in desired_keys if k not in existing_keys]

    # Drop obsolete indices
    for name in to_drop:
        log.debug(f"Dropping index {name} from collection {collection.name}")
        await collection.drop_index(name, session=session)

    # Create new indices
    for idx in to_create:
        log.debug(f"Creating index {idx.fields} on collection {collection.name}")
        await collection.create_index(
            list(idx.fields.items()),
            session=session,
            **idx.options
        )
+ )
84
+
@@ -0,0 +1,238 @@
1
+ from __future__ import annotations
2
+
3
+ from anyio import create_task_group
4
+ from contextlib import asynccontextmanager
5
+ from contextvars import ContextVar
6
+ from dataclasses import dataclass, field as dc_field
7
+ from functools import cached_property, wraps
8
+ from typing import Annotated, Awaitable, Callable, Optional, Tuple, TypedDict
9
+
10
+ from gridfs import AsyncGridFS
11
+
12
+ from pydantic import BaseModel, Field, MongoDsn
13
+ from pymongo import AsyncMongoClient
14
+ from pymongo.asynchronous.database import AsyncDatabase
15
+
16
+ from pymongo.asynchronous.client_session import AsyncClientSession
17
+ from pymongo.errors import OperationFailure
18
+ from pymongo.read_concern import ReadConcern
19
+ from pymongo.read_preferences import ReadPreference
20
+ from pymongo.write_concern import WriteConcern
21
+ from pymongo import AsyncMongoClient
22
+
23
+ import stackraise.event as ev
24
+ import stackraise.model as model
25
+ import stackraise.db as db
26
+
27
+ from .protocols import get_collection_instances
28
+
29
# Per-task persistence context: the active Persistence instance paired with
# its AsyncClientSession. Set/reset by Persistence.session(); read by
# current_context() and friends.
_current_context: ContextVar[Optional[Tuple[Persistence, AsyncClientSession]]] = (
    ContextVar(
        f"{__name__}._current_context",
        default=None,
    )
)

# Awaitable factories to run at startup.
# NOTE(review): not referenced anywhere in this module — confirm before
# removing. (Annotation corrected: this is a *list* of callables.)
_startup_tasks: list[Callable[[], Awaitable]] = []
37
+
38
+
39
class ChangeEvent(TypedDict):
    """Payload emitted for each database change-stream event."""

    op: str  # MongoDB operation type ("insert", "update", "delete", ...)
    collection: str  # name of the affected collection
    refs: list[str]  # string ids of the affected documents


# Module-wide emitter through which persistence change events are published
# (fed by Persistence._watch_task).
change_event_emitter: ev.EventEmitter[ChangeEvent] = ev.EventEmitter(
    "persistence.change_event"
)
48
+
49
+
50
@dataclass(frozen=True)
class Persistence:
    """
    A class representing the persistence layer for the ctrl.

    Owns the Mongo client / database / GridFS handles (created lazily) and
    provides the application ``lifespan`` that runs the change-stream
    watcher and each collection's startup task.
    """

    class Settings(BaseModel):
        """Persistence layer settings"""

        mongo_dsn: Annotated[
            MongoDsn,
            Field(
                MongoDsn("mongodb://localhost/test"),
                title="Mongo database DSN",
                description="The DSN for the MongoDB database. ",
            ),
        ]

        direct_connection: Annotated[
            bool,
            Field(
                True,
                title="Direct connection",
                description="Whether to connect directly to the MongoDB instance.",
            ),
        ]

        causal_consistency: Annotated[
            bool,
            Field(
                True,
                title="Causal consistency",
                description="Whether to enable causal consistency.",
            ),
        ]

        # TODO: Default transaction options

    settings: Settings = dc_field(default_factory=Settings)

    @cached_property
    def client(self):
        """Lazily-created AsyncMongoClient built from settings."""
        # print(str(self.settings.mongo_dsn))
        return AsyncMongoClient(
            str(self.settings.mongo_dsn),
            directConnection=self.settings.direct_connection,
        )
        # return AsyncMongoClient()

    @cached_property
    def database(self) -> AsyncDatabase:
        """Default database named in the DSN."""
        # TODO: read / write preference and concern from settings
        return self.client.get_default_database()

    @cached_property
    def fs(self):
        """GridFS facade over the default database."""
        return AsyncGridFS(self.database)

    @asynccontextmanager
    async def lifespan(self):
        """Application-lifetime context.

        Opens the client and a root session, then runs the change-stream
        watcher plus each registered collection's startup task in a task
        group. On exit, the task group is cancelled so background tasks
        stop.
        """
        async with self.client:
            async with self.session():

                async with create_task_group() as tg:
                    tg.start_soon(self._watch_task)

                    for collection in get_collection_instances():
                        tg.start_soon(collection._startup_task, self)

                    yield

                    # Shut down the watcher and startup tasks.
                    tg.cancel_scope.cancel()

    async def _watch_task(self):
        """
        Background task to watch changes in the database.
        Only works with MongoDB replica sets. Gracefully disables on standalone instances.
        """
        try:
            async with await self.database.watch() as change_stream:
                async for change in change_stream:

                    # Emit change events
                    change_event = ChangeEvent(
                        op=change["operationType"],
                        collection=change["ns"]["coll"],
                        refs=[str(change["documentKey"]["_id"])],
                    )

                    await change_event_emitter.emit(change_event)
        except OperationFailure as e:
            # Change Streams require replica set (code 40573)
            if e.code == 40573:
                # NOTE(review): uses print instead of the logging machinery —
                # consider routing through a module logger.
                print(
                    "⚠️ MongoDB Change Streams disabled: running in standalone mode (replica set required)"
                )
                return  # Exit gracefully
            else:
                raise

    @asynccontextmanager
    async def session(self):
        """Enter the persistence session context."""
        async with self.client.start_session(
            causal_consistency=self.settings.causal_consistency,
            default_transaction_options=None,  # TODO: From settings
        ) as session:
            # Publish (self, session) for current_context() and friends;
            # always restore the previous value on exit.
            session_token = _current_context.set((self, session))
            try:
                yield session
            finally:
                _current_context.reset(session_token)
162
+
163
+
164
def current_context():
    """
    Get the current persistence context.

    Returns the (Persistence, AsyncClientSession) pair published by
    Persistence.session(); raises RuntimeError outside a session scope.
    """
    ctx = _current_context.get()
    if ctx is not None:
        return ctx
    raise RuntimeError("No persistence context is currently set.")
172
+
173
+
174
def current_database() -> AsyncDatabase:
    """
    Get the current database from the context.
    """
    return current_context()[0].database
180
+
181
+
182
def current_fs() -> AsyncGridFS:
    """
    Get the current GridFS instance from the context.
    """
    return current_context()[0].fs
188
+
189
+
190
def current_session() -> AsyncClientSession:
    """
    Get the current session from the context.

    Raises:
        RuntimeError: if no persistence context is currently set.
    """
    # Bug fix: previously unpacked _current_context.get() directly, which
    # raised an opaque TypeError ("cannot unpack non-sequence NoneType")
    # when no context was set. Route through current_context() so the
    # failure mode matches current_database()/current_fs().
    _, session = current_context()
    return session
196
+
197
+
198
def in_transaction() -> bool:
    """
    Check if the current session is in a transaction.
    """
    session = current_session()
    if session is None:
        return False
    return session.in_transaction
204
+
205
+
206
def transaction(
    read_concern: ReadConcern | None = None,
    write_concern: WriteConcern | None = None,
    read_preference: ReadPreference | None = None,
    max_commit_time_ms: int | None = None,
):
    """
    Decorator to run a coroutine function within a MongoDB transaction.

    Uses the context session (see ``current_session``). If a transaction is
    already open on that session, the function simply runs inside it;
    otherwise a new transaction is started via ``with_transaction`` with the
    given options.
    """

    def decorator(fn):

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            session = current_session()

            # Bug fix: the already-in-transaction path previously referenced
            # `coro` before it was defined (UnboundLocalError) and with the
            # wrong arguments; just invoke the wrapped function directly —
            # MongoDB does not nest transactions on the same session.
            if session.in_transaction:
                return await fn(*args, **kwargs)

            # with_transaction passes the session as the sole argument;
            # the wrapped function sees the original call arguments.
            async def coro(_session):
                return await fn(*args, **kwargs)

            return await session.with_transaction(
                coro,
                read_concern=read_concern,
                write_concern=write_concern,
                read_preference=read_preference,
                max_commit_time_ms=max_commit_time_ms,
            )

        return wrapper

    return decorator