jararaca 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jararaca might be problematic. Click here for more details.

Files changed (34)
  1. jararaca/__init__.py +76 -5
  2. jararaca/cli.py +460 -116
  3. jararaca/core/uow.py +17 -12
  4. jararaca/messagebus/decorators.py +33 -30
  5. jararaca/messagebus/interceptors/aiopika_publisher_interceptor.py +30 -2
  6. jararaca/messagebus/interceptors/publisher_interceptor.py +7 -3
  7. jararaca/messagebus/publisher.py +14 -6
  8. jararaca/messagebus/worker.py +1102 -88
  9. jararaca/microservice.py +137 -34
  10. jararaca/observability/decorators.py +7 -3
  11. jararaca/observability/interceptor.py +4 -2
  12. jararaca/observability/providers/otel.py +14 -10
  13. jararaca/persistence/base.py +2 -1
  14. jararaca/persistence/interceptors/aiosqa_interceptor.py +167 -16
  15. jararaca/presentation/decorators.py +96 -10
  16. jararaca/presentation/server.py +31 -4
  17. jararaca/presentation/websocket/context.py +30 -4
  18. jararaca/presentation/websocket/types.py +2 -2
  19. jararaca/presentation/websocket/websocket_interceptor.py +28 -4
  20. jararaca/reflect/__init__.py +0 -0
  21. jararaca/reflect/controller_inspect.py +75 -0
  22. jararaca/{tools → reflect}/metadata.py +25 -5
  23. jararaca/scheduler/{scheduler_v2.py → beat_worker.py} +49 -53
  24. jararaca/scheduler/decorators.py +55 -20
  25. jararaca/tools/app_config/interceptor.py +4 -2
  26. jararaca/utils/rabbitmq_utils.py +259 -5
  27. jararaca/utils/retry.py +141 -0
  28. {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/METADATA +2 -1
  29. {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/RECORD +32 -31
  30. {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/WHEEL +1 -1
  31. jararaca/messagebus/worker_v2.py +0 -617
  32. jararaca/scheduler/scheduler.py +0 -161
  33. {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/LICENSE +0 -0
  34. {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/entry_points.txt +0 -0
jararaca/microservice.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import logging
2
3
  from contextlib import contextmanager, suppress
3
4
  from contextvars import ContextVar
4
5
  from dataclasses import dataclass, field
@@ -23,64 +24,105 @@ from fastapi import Request, WebSocket
23
24
  from jararaca.core.providers import ProviderSpec, T, Token
24
25
  from jararaca.messagebus import MessageOf
25
26
  from jararaca.messagebus.message import Message
27
+ from jararaca.reflect.controller_inspect import ControllerMemberReflect
28
+
29
+ logger = logging.getLogger(__name__)
26
30
 
27
31
  if TYPE_CHECKING:
28
32
  from typing_extensions import TypeIs
29
33
 
30
34
 
31
35
  @dataclass
32
- class SchedulerAppContext:
36
+ class SchedulerTransactionData:
33
37
  triggered_at: datetime
34
38
  scheduled_to: datetime
35
39
  cron_expression: str
36
- action: Callable[..., Any]
37
40
  context_type: Literal["scheduler"] = "scheduler"
38
41
 
39
42
 
40
43
  @dataclass
41
- class HttpAppContext:
44
+ class HttpTransactionData:
42
45
  request: Request
43
46
  context_type: Literal["http"] = "http"
44
47
 
45
48
 
46
49
  @dataclass
47
- class MessageBusAppContext:
50
+ class MessageBusTransactionData:
48
51
  topic: str
49
52
  message: MessageOf[Message]
50
53
  context_type: Literal["message_bus"] = "message_bus"
51
54
 
52
55
 
53
56
  @dataclass
54
- class WebSocketAppContext:
57
+ class WebSocketTransactionData:
55
58
  websocket: WebSocket
56
59
  context_type: Literal["websocket"] = "websocket"
57
60
 
58
61
 
59
- AppContext = (
60
- MessageBusAppContext | HttpAppContext | SchedulerAppContext | WebSocketAppContext
62
+ TransactionData = (
63
+ MessageBusTransactionData
64
+ | HttpTransactionData
65
+ | SchedulerTransactionData
66
+ | WebSocketTransactionData
61
67
  )
62
68
 
63
- app_context_ctxvar = ContextVar[AppContext]("app_context")
64
69
 
70
+ @dataclass
71
+ class AppTransactionContext:
72
+ transaction_data: TransactionData
73
+ controller_member_reflect: ControllerMemberReflect
74
+
75
+
76
+ AppContext = AppTransactionContext
77
+ """
78
+ Alias for AppTransactionContext, used for compatibility with existing code.
79
+ """
80
+
81
+
82
+ app_transaction_context_var = ContextVar[AppTransactionContext]("app_context")
83
+
84
+
85
+ def use_app_transaction_context() -> AppTransactionContext:
86
+ """
87
+ Returns the current application transaction context.
88
+ This function is used to access the application transaction context in the context of an application transaction.
89
+ If no context is set, it raises a LookupError.
90
+ """
65
91
 
66
- def use_app_context() -> AppContext:
67
- return app_context_ctxvar.get()
92
+ return app_transaction_context_var.get()
93
+
94
+
95
+ def use_app_tx_ctx_data() -> TransactionData:
96
+ """
97
+ Returns the transaction data from the current app transaction context.
98
+ This function is used to access the transaction data in the context of an application transaction.
99
+ """
100
+
101
+ return use_app_transaction_context().transaction_data
102
+
103
+
104
+ use_app_context = use_app_tx_ctx_data
105
+ """Alias for use_app_tx_ctx_data, used for compatibility with existing code."""
68
106
 
69
107
 
70
108
  @contextmanager
71
- def provide_app_context(app_context: AppContext) -> Generator[None, None, None]:
72
- token = app_context_ctxvar.set(app_context)
109
+ def provide_app_context(
110
+ app_context: AppTransactionContext,
111
+ ) -> Generator[None, None, None]:
112
+ token = app_transaction_context_var.set(app_context)
73
113
  try:
74
114
  yield
75
115
  finally:
76
116
  with suppress(ValueError):
77
- app_context_ctxvar.reset(token)
117
+ app_transaction_context_var.reset(token)
78
118
 
79
119
 
80
120
  @runtime_checkable
81
121
  class AppInterceptor(Protocol):
82
122
 
83
- def intercept(self, app_context: AppContext) -> AsyncContextManager[None]: ...
123
+ def intercept(
124
+ self, app_context: AppTransactionContext
125
+ ) -> AsyncContextManager[None]: ...
84
126
 
85
127
 
86
128
  class AppInterceptorWithLifecycle(Protocol):
@@ -106,6 +148,49 @@ class Microservice:
106
148
  )
107
149
 
108
150
 
151
+ @dataclass
152
+ class InstantiationNode:
153
+ property_name: str
154
+ parent: "InstantiationNode | None" = None
155
+ source_type: Any | None = None
156
+ target_type: Any | None = None
157
+
158
+
159
+ instantiation_vector_ctxvar = ContextVar[list[InstantiationNode]](
160
+ "instantiation_vector", default=[]
161
+ )
162
+
163
+
164
+ def print_instantiation_vector(
165
+ instantiation_vector: list[InstantiationNode],
166
+ ) -> None:
167
+ """
168
+ Prints the instantiation vector for debugging purposes.
169
+ """
170
+ for node in instantiation_vector:
171
+ print(
172
+ f"Property: {node.property_name}, Source: {node.source_type}, Target: {node.target_type}"
173
+ )
174
+
175
+
176
+ @contextmanager
177
+ def span_instantiation_vector(
178
+ instantiation_node: InstantiationNode,
179
+ ) -> Generator[None, None, None]:
180
+ """
181
+ Context manager to track instantiation nodes in a vector.
182
+ This is useful for debugging and tracing instantiation paths.
183
+ """
184
+ current_vector = list(instantiation_vector_ctxvar.get())
185
+ current_vector.append(instantiation_node)
186
+ token = instantiation_vector_ctxvar.set(current_vector)
187
+ try:
188
+ yield
189
+ finally:
190
+ with suppress(ValueError):
191
+ instantiation_vector_ctxvar.reset(token)
192
+
193
+
109
194
  class Container:
110
195
 
111
196
  def __init__(self, app: Microservice) -> None:
@@ -122,40 +207,54 @@ class Container:
122
207
  if provider.use_value:
123
208
  self.instances_map[provider.provide] = provider.use_value
124
209
  elif provider.use_class:
125
- self.get_and_register(provider.use_class, provider.provide)
210
+ self._get_and_register(provider.use_class, provider.provide)
126
211
  elif provider.use_factory:
127
- self.get_and_register(provider.use_factory, provider.provide)
212
+ self._get_and_register(provider.use_factory, provider.provide)
128
213
  else:
129
- self.get_and_register(provider, provider)
214
+ self._get_and_register(provider, provider)
130
215
 
131
- def instantiate(self, type_: type[Any] | Callable[..., Any]) -> Any:
216
+ def _instantiate(self, type_: type[Any] | Callable[..., Any]) -> Any:
132
217
 
133
- dependencies = self.parse_dependencies(type_)
218
+ dependencies = self._parse_dependencies(type_)
134
219
 
135
- evaluated_dependencies = {
136
- name: self.get_or_register_token_or_type(dependency)
137
- for name, dependency in dependencies.items()
138
- }
220
+ evaluated_dependencies: dict[str, Any] = {}
221
+ for name, dependency in dependencies.items():
222
+ with span_instantiation_vector(
223
+ InstantiationNode(
224
+ property_name=name,
225
+ source_type=type_,
226
+ target_type=dependency,
227
+ )
228
+ ):
229
+ evaluated_dependencies[name] = self.get_or_register_token_or_type(
230
+ dependency
231
+ )
139
232
 
140
233
  instance = type_(**evaluated_dependencies)
141
234
 
142
235
  return instance
143
236
 
144
- def parse_dependencies(
237
+ def _parse_dependencies(
145
238
  self, provider: type[Any] | Callable[..., Any]
146
239
  ) -> dict[str, type[Any]]:
147
240
 
148
- signature = inspect.signature(provider)
241
+ vector = instantiation_vector_ctxvar.get()
242
+ try:
243
+ signature = inspect.signature(provider)
244
+ except ValueError:
245
+ print("VECTOR:", vector)
246
+ print_instantiation_vector(vector)
247
+ raise
149
248
 
150
249
  parameters = signature.parameters
151
250
 
152
251
  return {
153
- name: self.lookup_parameter_type(parameter)
252
+ name: self._lookup_parameter_type(parameter)
154
253
  for name, parameter in parameters.items()
155
254
  if parameter.annotation != inspect.Parameter.empty
156
255
  }
157
256
 
158
- def lookup_parameter_type(self, parameter: inspect.Parameter) -> Any:
257
+ def _lookup_parameter_type(self, parameter: inspect.Parameter) -> Any:
159
258
  if parameter.annotation == inspect.Parameter.empty:
160
259
  raise Exception(f"Parameter {parameter.name} has no type annotation")
161
260
 
@@ -188,14 +287,14 @@ class Container:
188
287
  item_type = bind_to = token_or_type
189
288
 
190
289
  if token_or_type not in self.instances_map:
191
- return self.get_and_register(item_type, bind_to)
290
+ return self._get_and_register(item_type, bind_to)
192
291
 
193
292
  return cast(T, self.instances_map[bind_to])
194
293
 
195
- def get_and_register(
294
+ def _get_and_register(
196
295
  self, item_type: Type[T] | Callable[..., T], bind_to: Any
197
296
  ) -> T:
198
- instance = self.instantiate(item_type)
297
+ instance = self._instantiate(item_type)
199
298
  self.register(instance, bind_to)
200
299
  return cast(T, instance)
201
300
 
@@ -227,17 +326,21 @@ def provide_container(container: Container) -> Generator[None, None, None]:
227
326
 
228
327
 
229
328
  __all__ = [
230
- "AppContext",
329
+ "AppTransactionContext",
231
330
  "AppInterceptor",
232
331
  "AppInterceptorWithLifecycle",
233
332
  "Container",
234
333
  "Microservice",
235
- "SchedulerAppContext",
236
- "WebSocketAppContext",
237
- "app_context_ctxvar",
334
+ "SchedulerTransactionData",
335
+ "WebSocketTransactionData",
336
+ "app_transaction_context_var",
238
337
  "current_container_ctx",
239
338
  "provide_app_context",
240
339
  "provide_container",
241
340
  "use_app_context",
242
341
  "use_current_container",
342
+ "HttpTransactionData",
343
+ "MessageBusTransactionData",
344
+ "is_interceptor_with_lifecycle",
345
+ "AppContext",
243
346
  ]
@@ -13,7 +13,7 @@ from typing import (
13
13
  TypeVar,
14
14
  )
15
15
 
16
- from jararaca.microservice import AppContext
16
+ from jararaca.microservice import AppTransactionContext
17
17
 
18
18
  P = ParamSpec("P")
19
19
  R = TypeVar("R")
@@ -28,9 +28,13 @@ class TracingContextProvider(Protocol):
28
28
 
29
29
  class TracingContextProviderFactory(Protocol):
30
30
 
31
- def root_setup(self, app_context: AppContext) -> AsyncContextManager[None]: ...
31
+ def root_setup(
32
+ self, app_context: AppTransactionContext
33
+ ) -> AsyncContextManager[None]: ...
32
34
 
33
- def provide_provider(self, app_context: AppContext) -> TracingContextProvider: ...
35
+ def provide_provider(
36
+ self, app_context: AppTransactionContext
37
+ ) -> TracingContextProvider: ...
34
38
 
35
39
 
36
40
  tracing_ctx_provider_ctxv = ContextVar[TracingContextProvider]("tracing_ctx_provider")
@@ -2,9 +2,9 @@ from contextlib import asynccontextmanager
2
2
  from typing import AsyncContextManager, AsyncGenerator, Protocol
3
3
 
4
4
  from jararaca.microservice import (
5
- AppContext,
6
5
  AppInterceptor,
7
6
  AppInterceptorWithLifecycle,
7
+ AppTransactionContext,
8
8
  Container,
9
9
  Microservice,
10
10
  )
@@ -32,7 +32,9 @@ class ObservabilityInterceptor(AppInterceptor, AppInterceptorWithLifecycle):
32
32
  self.observability_provider = observability_provider
33
33
 
34
34
  @asynccontextmanager
35
- async def intercept(self, app_context: AppContext) -> AsyncGenerator[None, None]:
35
+ async def intercept(
36
+ self, app_context: AppTransactionContext
37
+ ) -> AsyncGenerator[None, None]:
36
38
 
37
39
  async with self.observability_provider.tracing_provider.root_setup(app_context):
38
40
 
@@ -23,7 +23,7 @@ from opentelemetry.sdk.trace import TracerProvider
23
23
  from opentelemetry.sdk.trace.export import BatchSpanProcessor
24
24
  from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
25
25
 
26
- from jararaca.microservice import AppContext, Container, Microservice
26
+ from jararaca.microservice import AppTransactionContext, Container, Microservice
27
27
  from jararaca.observability.decorators import (
28
28
  TracingContextProvider,
29
29
  TracingContextProviderFactory,
@@ -36,7 +36,7 @@ tracer: trace.Tracer = trace.get_tracer(__name__)
36
36
 
37
37
  class OtelTracingContextProvider(TracingContextProvider):
38
38
 
39
- def __init__(self, app_context: AppContext) -> None:
39
+ def __init__(self, app_context: AppTransactionContext) -> None:
40
40
  self.app_context = app_context
41
41
 
42
42
  @contextmanager
@@ -52,22 +52,26 @@ class OtelTracingContextProvider(TracingContextProvider):
52
52
 
53
53
  class OtelTracingContextProviderFactory(TracingContextProviderFactory):
54
54
 
55
- def provide_provider(self, app_context: AppContext) -> TracingContextProvider:
55
+ def provide_provider(
56
+ self, app_context: AppTransactionContext
57
+ ) -> TracingContextProvider:
56
58
  return OtelTracingContextProvider(app_context)
57
59
 
58
60
  @asynccontextmanager
59
- async def root_setup(self, app_context: AppContext) -> AsyncGenerator[None, None]:
61
+ async def root_setup(
62
+ self, app_tx_ctx: AppTransactionContext
63
+ ) -> AsyncGenerator[None, None]:
60
64
 
61
65
  title: str = "Unmapped App Context Execution"
62
66
  headers = {}
67
+ tx_data = app_tx_ctx.transaction_data
68
+ if tx_data.context_type == "http":
63
69
 
64
- if app_context.context_type == "http":
65
-
66
- headers = dict(app_context.request.headers)
67
- title = f"HTTP {app_context.request.method} {app_context.request.url}"
70
+ headers = dict(tx_data.request.headers)
71
+ title = f"HTTP {tx_data.request.method} {tx_data.request.url}"
68
72
 
69
- elif app_context.context_type == "message_bus":
70
- title = f"Message Bus {app_context.topic}"
73
+ elif tx_data.context_type == "message_bus":
74
+ title = f"Message Bus {tx_data.topic}"
71
75
 
72
76
  carrier = {
73
77
  key: value
@@ -1,6 +1,7 @@
1
1
  from typing import Any, Self, Type, TypeVar
2
2
 
3
3
  from pydantic import BaseModel
4
+ from sqlalchemy.ext.asyncio import AsyncAttrs
4
5
  from sqlalchemy.orm import DeclarativeBase
5
6
 
6
7
  IDENTIFIABLE_SCHEMA_T = TypeVar("IDENTIFIABLE_SCHEMA_T")
@@ -20,7 +21,7 @@ def recursive_get_dict(obj: Any) -> Any:
20
21
  return obj
21
22
 
22
23
 
23
- class BaseEntity(DeclarativeBase):
24
+ class BaseEntity(AsyncAttrs, DeclarativeBase):
24
25
 
25
26
  @classmethod
26
27
  def from_basemodel(cls, mutation: T_BASEMODEL) -> "Self":
@@ -3,21 +3,53 @@ from contextvars import ContextVar
3
3
  from dataclasses import dataclass
4
4
  from typing import Any, AsyncGenerator, Generator
5
5
 
6
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
6
+ from sqlalchemy.ext.asyncio import (
7
+ AsyncSession,
8
+ AsyncSessionTransaction,
9
+ async_sessionmaker,
10
+ create_async_engine,
11
+ )
7
12
  from sqlalchemy.ext.asyncio.engine import AsyncEngine
8
13
 
9
- from jararaca.microservice import AppContext, AppInterceptor
14
+ from jararaca.microservice import AppInterceptor, AppTransactionContext
15
+ from jararaca.reflect.metadata import SetMetadata, get_metadata_value
10
16
 
11
- ctx_session_map = ContextVar[dict[str, AsyncSession]]("ctx_session_map", default={})
17
+ DEFAULT_CONNECTION_NAME = "default"
18
+
19
+ ctx_default_connection_name: ContextVar[str] = ContextVar(
20
+ "ctx_default_connection_name", default=DEFAULT_CONNECTION_NAME
21
+ )
22
+
23
+
24
+ def ensure_name(name: str | None) -> str:
25
+ return ctx_default_connection_name.get()
26
+
27
+
28
+ @dataclass
29
+ class PersistenceCtx:
30
+ session: AsyncSession
31
+ tx: AsyncSessionTransaction
32
+
33
+
34
+ ctx_session_map = ContextVar[dict[str, PersistenceCtx]]("ctx_session_map", default={})
12
35
 
13
36
 
14
37
  @contextmanager
15
- def provide_session(
16
- connection_name: str, session: AsyncSession
38
+ def providing_session(
39
+ session: AsyncSession,
40
+ tx: AsyncSessionTransaction,
41
+ connection_name: str | None = None,
17
42
  ) -> Generator[None, Any, None]:
43
+ """
44
+ Context manager to provide a session and transaction for a specific connection name.
45
+ If no connection name is provided, it uses the default connection name from the context variable.
46
+ """
47
+ connection_name = ensure_name(connection_name)
18
48
  current_map = ctx_session_map.get({})
19
49
 
20
- token = ctx_session_map.set({**current_map, connection_name: session})
50
+ token = ctx_session_map.set(
51
+ {**current_map, connection_name: PersistenceCtx(session, tx)}
52
+ )
21
53
 
22
54
  try:
23
55
  yield
@@ -26,18 +58,118 @@ def provide_session(
26
58
  ctx_session_map.reset(token)
27
59
 
28
60
 
29
- def use_session(connection_name: str = "default") -> AsyncSession:
61
+ provide_session = providing_session
62
+ """
63
+ Alias for `providing_session` to maintain backward compatibility.
64
+ """
65
+
66
+
67
+ @asynccontextmanager
68
+ async def providing_new_session(
69
+ connection_name: str | None = None,
70
+ ) -> AsyncGenerator[AsyncSession, None]:
71
+
72
+ current_session = use_session(connection_name)
73
+
74
+ async with AsyncSession(
75
+ current_session.bind,
76
+ ) as new_session, new_session.begin() as new_tx:
77
+ with providing_session(new_session, new_tx, connection_name):
78
+ yield new_session
79
+
80
+
81
+ def use_session(connection_name: str | None = None) -> AsyncSession:
82
+ connection_name = ensure_name(connection_name)
30
83
  current_map = ctx_session_map.get({})
31
84
  if connection_name not in current_map:
32
- raise ValueError(f"Session not found for connection {connection_name}")
85
+ raise ValueError(
86
+ f'Session not found for connection "{connection_name}" in context. Check if your interceptor is correctly set up.'
87
+ )
33
88
 
34
- return current_map[connection_name]
89
+ return current_map[connection_name].session
90
+
91
+
92
+ @contextmanager
93
+ def providing_transaction(
94
+ tx: AsyncSessionTransaction,
95
+ connection_name: str | None = None,
96
+ ) -> Generator[None, Any, None]:
97
+ connection_name = ensure_name(connection_name)
98
+
99
+ current_map = ctx_session_map.get({})
100
+
101
+ if connection_name not in current_map:
102
+ raise ValueError(f"No session found for connection {connection_name}")
103
+
104
+ with providing_session(current_map[connection_name].session, tx, connection_name):
105
+ yield
106
+
107
+
108
+ def use_transaction(connection_name: str | None = None) -> AsyncSessionTransaction:
109
+ current_map = ctx_session_map.get({})
110
+ if connection_name not in current_map:
111
+ raise ValueError(f"Transaction not found for connection {connection_name}")
112
+
113
+ return current_map[connection_name].tx
35
114
 
36
115
 
37
- @dataclass
38
116
  class AIOSQAConfig:
39
- connection_name: str
40
117
  url: str | AsyncEngine
118
+ connection_name: str
119
+ inject_default: bool
120
+
121
+ def __init__(
122
+ self,
123
+ url: str | AsyncEngine,
124
+ connection_name: str = DEFAULT_CONNECTION_NAME,
125
+ inject_default: bool = True,
126
+ ):
127
+ self.url = url
128
+ self.connection_name = connection_name
129
+ self.inject_default = inject_default
130
+
131
+
132
+ INJECT_CONNECTION_METADATA = "inject_connection_metadata_{connection_name}"
133
+
134
+
135
+ def set_inject_connection(
136
+ inject: bool, connection_name: str = DEFAULT_CONNECTION_NAME
137
+ ) -> SetMetadata:
138
+ """
139
+ Set whether to inject the connection metadata for the given connection name.
140
+ This is useful when you want to control whether the connection metadata
141
+ should be injected into the context or not.
142
+ """
143
+
144
+ return SetMetadata(
145
+ INJECT_CONNECTION_METADATA.format(connection_name=connection_name), inject
146
+ )
147
+
148
+
149
+ def uses_connection(
150
+ connection_name: str = DEFAULT_CONNECTION_NAME,
151
+ ) -> SetMetadata:
152
+ """
153
+ Use connection metadata for the given connection name.
154
+ This is useful when you want to inject the connection metadata into the context,
155
+ for example, when you are using a specific connection for a specific operation.
156
+ """
157
+ return SetMetadata(
158
+ INJECT_CONNECTION_METADATA.format(connection_name=connection_name), True
159
+ )
160
+
161
+
162
+ def dnt_uses_connection(
163
+ connection_name: str = DEFAULT_CONNECTION_NAME,
164
+ ) -> SetMetadata:
165
+ """
166
+ Do not use connection metadata for the given connection name.
167
+ This is useful when you want to ensure that the connection metadata is not injected
168
+ into the context, for example, when you are using a different connection for a specific operation.
169
+ """
170
+ return SetMetadata(
171
+ INJECT_CONNECTION_METADATA.format(connection_name=connection_name), False
172
+ )
41
173
 
42
174
 
43
175
  class AIOSqlAlchemySessionInterceptor(AppInterceptor):
@@ -53,12 +185,31 @@ class AIOSqlAlchemySessionInterceptor(AppInterceptor):
53
185
  self.sessionmaker = async_sessionmaker(self.engine)
54
186
 
55
187
  @asynccontextmanager
56
- async def intercept(self, app_context: AppContext) -> AsyncGenerator[None, None]:
57
- async with self.sessionmaker() as session:
58
- with provide_session(self.config.connection_name, session):
188
+ async def intercept(
189
+ self, app_context: AppTransactionContext
190
+ ) -> AsyncGenerator[None, None]:
191
+
192
+ uses_connection_metadata = get_metadata_value(
193
+ INJECT_CONNECTION_METADATA.format(
194
+ connection_name=self.config.connection_name
195
+ ),
196
+ self.config.inject_default,
197
+ )
198
+
199
+ if not uses_connection_metadata:
200
+ yield
201
+ return
202
+
203
+ async with self.sessionmaker() as session, session.begin() as tx:
204
+ token = ctx_default_connection_name.set(self.config.connection_name)
205
+ with providing_session(session, tx, self.config.connection_name):
59
206
  try:
60
207
  yield
61
- await session.commit()
208
+ if tx.is_active:
209
+ await tx.commit()
62
210
  except Exception as e:
63
- await session.rollback()
211
+ await tx.rollback()
64
212
  raise e
213
+ finally:
214
+ with suppress(ValueError):
215
+ ctx_default_connection_name.reset(token)