jararaca 0.3.10__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jararaca might be problematic. Click here for more details.
- jararaca/__init__.py +76 -5
- jararaca/cli.py +460 -116
- jararaca/core/uow.py +17 -12
- jararaca/messagebus/decorators.py +33 -30
- jararaca/messagebus/interceptors/aiopika_publisher_interceptor.py +30 -2
- jararaca/messagebus/interceptors/publisher_interceptor.py +7 -3
- jararaca/messagebus/publisher.py +14 -6
- jararaca/messagebus/worker.py +1102 -88
- jararaca/microservice.py +137 -34
- jararaca/observability/decorators.py +7 -3
- jararaca/observability/interceptor.py +4 -2
- jararaca/observability/providers/otel.py +14 -10
- jararaca/persistence/base.py +2 -1
- jararaca/persistence/interceptors/aiosqa_interceptor.py +167 -16
- jararaca/presentation/decorators.py +96 -10
- jararaca/presentation/server.py +31 -4
- jararaca/presentation/websocket/context.py +30 -4
- jararaca/presentation/websocket/types.py +2 -2
- jararaca/presentation/websocket/websocket_interceptor.py +28 -4
- jararaca/reflect/__init__.py +0 -0
- jararaca/reflect/controller_inspect.py +75 -0
- jararaca/{tools → reflect}/metadata.py +25 -5
- jararaca/scheduler/{scheduler_v2.py → beat_worker.py} +49 -53
- jararaca/scheduler/decorators.py +55 -20
- jararaca/tools/app_config/interceptor.py +4 -2
- jararaca/utils/rabbitmq_utils.py +259 -5
- jararaca/utils/retry.py +141 -0
- {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/METADATA +2 -1
- {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/RECORD +32 -31
- {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/WHEEL +1 -1
- jararaca/messagebus/worker_v2.py +0 -617
- jararaca/scheduler/scheduler.py +0 -161
- {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/LICENSE +0 -0
- {jararaca-0.3.10.dist-info → jararaca-0.3.11.dist-info}/entry_points.txt +0 -0
jararaca/microservice.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import logging
|
|
2
3
|
from contextlib import contextmanager, suppress
|
|
3
4
|
from contextvars import ContextVar
|
|
4
5
|
from dataclasses import dataclass, field
|
|
@@ -23,64 +24,105 @@ from fastapi import Request, WebSocket
|
|
|
23
24
|
from jararaca.core.providers import ProviderSpec, T, Token
|
|
24
25
|
from jararaca.messagebus import MessageOf
|
|
25
26
|
from jararaca.messagebus.message import Message
|
|
27
|
+
from jararaca.reflect.controller_inspect import ControllerMemberReflect
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)
|
|
26
30
|
|
|
27
31
|
if TYPE_CHECKING:
|
|
28
32
|
from typing_extensions import TypeIs
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
@dataclass
|
|
32
|
-
class
|
|
36
|
+
class SchedulerTransactionData:
|
|
33
37
|
triggered_at: datetime
|
|
34
38
|
scheduled_to: datetime
|
|
35
39
|
cron_expression: str
|
|
36
|
-
action: Callable[..., Any]
|
|
37
40
|
context_type: Literal["scheduler"] = "scheduler"
|
|
38
41
|
|
|
39
42
|
|
|
40
43
|
@dataclass
|
|
41
|
-
class
|
|
44
|
+
class HttpTransactionData:
|
|
42
45
|
request: Request
|
|
43
46
|
context_type: Literal["http"] = "http"
|
|
44
47
|
|
|
45
48
|
|
|
46
49
|
@dataclass
|
|
47
|
-
class
|
|
50
|
+
class MessageBusTransactionData:
|
|
48
51
|
topic: str
|
|
49
52
|
message: MessageOf[Message]
|
|
50
53
|
context_type: Literal["message_bus"] = "message_bus"
|
|
51
54
|
|
|
52
55
|
|
|
53
56
|
@dataclass
|
|
54
|
-
class
|
|
57
|
+
class WebSocketTransactionData:
|
|
55
58
|
websocket: WebSocket
|
|
56
59
|
context_type: Literal["websocket"] = "websocket"
|
|
57
60
|
|
|
58
61
|
|
|
59
|
-
|
|
60
|
-
|
|
62
|
+
TransactionData = (
|
|
63
|
+
MessageBusTransactionData
|
|
64
|
+
| HttpTransactionData
|
|
65
|
+
| SchedulerTransactionData
|
|
66
|
+
| WebSocketTransactionData
|
|
61
67
|
)
|
|
62
68
|
|
|
63
|
-
app_context_ctxvar = ContextVar[AppContext]("app_context")
|
|
64
69
|
|
|
70
|
+
@dataclass
|
|
71
|
+
class AppTransactionContext:
|
|
72
|
+
transaction_data: TransactionData
|
|
73
|
+
controller_member_reflect: ControllerMemberReflect
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
AppContext = AppTransactionContext
|
|
77
|
+
"""
|
|
78
|
+
Alias for AppTransactionContext, used for compatibility with existing code.
|
|
79
|
+
"""
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
app_transaction_context_var = ContextVar[AppTransactionContext]("app_context")
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def use_app_transaction_context() -> AppTransactionContext:
|
|
86
|
+
"""
|
|
87
|
+
Returns the current application transaction context.
|
|
88
|
+
This function is used to access the application transaction context in the context of an application transaction.
|
|
89
|
+
If no context is set, it raises a LookupError.
|
|
90
|
+
"""
|
|
65
91
|
|
|
66
|
-
|
|
67
|
-
|
|
92
|
+
return app_transaction_context_var.get()
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def use_app_tx_ctx_data() -> TransactionData:
|
|
96
|
+
"""
|
|
97
|
+
Returns the transaction data from the current app transaction context.
|
|
98
|
+
This function is used to access the transaction data in the context of an application transaction.
|
|
99
|
+
"""
|
|
100
|
+
|
|
101
|
+
return use_app_transaction_context().transaction_data
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
use_app_context = use_app_tx_ctx_data
|
|
105
|
+
"""Alias for use_app_tx_ctx_data, used for compatibility with existing code."""
|
|
68
106
|
|
|
69
107
|
|
|
70
108
|
@contextmanager
|
|
71
|
-
def provide_app_context(
|
|
72
|
-
|
|
109
|
+
def provide_app_context(
|
|
110
|
+
app_context: AppTransactionContext,
|
|
111
|
+
) -> Generator[None, None, None]:
|
|
112
|
+
token = app_transaction_context_var.set(app_context)
|
|
73
113
|
try:
|
|
74
114
|
yield
|
|
75
115
|
finally:
|
|
76
116
|
with suppress(ValueError):
|
|
77
|
-
|
|
117
|
+
app_transaction_context_var.reset(token)
|
|
78
118
|
|
|
79
119
|
|
|
80
120
|
@runtime_checkable
|
|
81
121
|
class AppInterceptor(Protocol):
|
|
82
122
|
|
|
83
|
-
def intercept(
|
|
123
|
+
def intercept(
|
|
124
|
+
self, app_context: AppTransactionContext
|
|
125
|
+
) -> AsyncContextManager[None]: ...
|
|
84
126
|
|
|
85
127
|
|
|
86
128
|
class AppInterceptorWithLifecycle(Protocol):
|
|
@@ -106,6 +148,49 @@ class Microservice:
|
|
|
106
148
|
)
|
|
107
149
|
|
|
108
150
|
|
|
151
|
+
@dataclass
|
|
152
|
+
class InstantiationNode:
|
|
153
|
+
property_name: str
|
|
154
|
+
parent: "InstantiationNode | None" = None
|
|
155
|
+
source_type: Any | None = None
|
|
156
|
+
target_type: Any | None = None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
instantiation_vector_ctxvar = ContextVar[list[InstantiationNode]](
|
|
160
|
+
"instantiation_vector", default=[]
|
|
161
|
+
)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def print_instantiation_vector(
|
|
165
|
+
instantiation_vector: list[InstantiationNode],
|
|
166
|
+
) -> None:
|
|
167
|
+
"""
|
|
168
|
+
Prints the instantiation vector for debugging purposes.
|
|
169
|
+
"""
|
|
170
|
+
for node in instantiation_vector:
|
|
171
|
+
print(
|
|
172
|
+
f"Property: {node.property_name}, Source: {node.source_type}, Target: {node.target_type}"
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@contextmanager
|
|
177
|
+
def span_instantiation_vector(
|
|
178
|
+
instantiation_node: InstantiationNode,
|
|
179
|
+
) -> Generator[None, None, None]:
|
|
180
|
+
"""
|
|
181
|
+
Context manager to track instantiation nodes in a vector.
|
|
182
|
+
This is useful for debugging and tracing instantiation paths.
|
|
183
|
+
"""
|
|
184
|
+
current_vector = list(instantiation_vector_ctxvar.get())
|
|
185
|
+
current_vector.append(instantiation_node)
|
|
186
|
+
token = instantiation_vector_ctxvar.set(current_vector)
|
|
187
|
+
try:
|
|
188
|
+
yield
|
|
189
|
+
finally:
|
|
190
|
+
with suppress(ValueError):
|
|
191
|
+
instantiation_vector_ctxvar.reset(token)
|
|
192
|
+
|
|
193
|
+
|
|
109
194
|
class Container:
|
|
110
195
|
|
|
111
196
|
def __init__(self, app: Microservice) -> None:
|
|
@@ -122,40 +207,54 @@ class Container:
|
|
|
122
207
|
if provider.use_value:
|
|
123
208
|
self.instances_map[provider.provide] = provider.use_value
|
|
124
209
|
elif provider.use_class:
|
|
125
|
-
self.
|
|
210
|
+
self._get_and_register(provider.use_class, provider.provide)
|
|
126
211
|
elif provider.use_factory:
|
|
127
|
-
self.
|
|
212
|
+
self._get_and_register(provider.use_factory, provider.provide)
|
|
128
213
|
else:
|
|
129
|
-
self.
|
|
214
|
+
self._get_and_register(provider, provider)
|
|
130
215
|
|
|
131
|
-
def
|
|
216
|
+
def _instantiate(self, type_: type[Any] | Callable[..., Any]) -> Any:
|
|
132
217
|
|
|
133
|
-
dependencies = self.
|
|
218
|
+
dependencies = self._parse_dependencies(type_)
|
|
134
219
|
|
|
135
|
-
evaluated_dependencies = {
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
220
|
+
evaluated_dependencies: dict[str, Any] = {}
|
|
221
|
+
for name, dependency in dependencies.items():
|
|
222
|
+
with span_instantiation_vector(
|
|
223
|
+
InstantiationNode(
|
|
224
|
+
property_name=name,
|
|
225
|
+
source_type=type_,
|
|
226
|
+
target_type=dependency,
|
|
227
|
+
)
|
|
228
|
+
):
|
|
229
|
+
evaluated_dependencies[name] = self.get_or_register_token_or_type(
|
|
230
|
+
dependency
|
|
231
|
+
)
|
|
139
232
|
|
|
140
233
|
instance = type_(**evaluated_dependencies)
|
|
141
234
|
|
|
142
235
|
return instance
|
|
143
236
|
|
|
144
|
-
def
|
|
237
|
+
def _parse_dependencies(
|
|
145
238
|
self, provider: type[Any] | Callable[..., Any]
|
|
146
239
|
) -> dict[str, type[Any]]:
|
|
147
240
|
|
|
148
|
-
|
|
241
|
+
vector = instantiation_vector_ctxvar.get()
|
|
242
|
+
try:
|
|
243
|
+
signature = inspect.signature(provider)
|
|
244
|
+
except ValueError:
|
|
245
|
+
print("VECTOR:", vector)
|
|
246
|
+
print_instantiation_vector(vector)
|
|
247
|
+
raise
|
|
149
248
|
|
|
150
249
|
parameters = signature.parameters
|
|
151
250
|
|
|
152
251
|
return {
|
|
153
|
-
name: self.
|
|
252
|
+
name: self._lookup_parameter_type(parameter)
|
|
154
253
|
for name, parameter in parameters.items()
|
|
155
254
|
if parameter.annotation != inspect.Parameter.empty
|
|
156
255
|
}
|
|
157
256
|
|
|
158
|
-
def
|
|
257
|
+
def _lookup_parameter_type(self, parameter: inspect.Parameter) -> Any:
|
|
159
258
|
if parameter.annotation == inspect.Parameter.empty:
|
|
160
259
|
raise Exception(f"Parameter {parameter.name} has no type annotation")
|
|
161
260
|
|
|
@@ -188,14 +287,14 @@ class Container:
|
|
|
188
287
|
item_type = bind_to = token_or_type
|
|
189
288
|
|
|
190
289
|
if token_or_type not in self.instances_map:
|
|
191
|
-
return self.
|
|
290
|
+
return self._get_and_register(item_type, bind_to)
|
|
192
291
|
|
|
193
292
|
return cast(T, self.instances_map[bind_to])
|
|
194
293
|
|
|
195
|
-
def
|
|
294
|
+
def _get_and_register(
|
|
196
295
|
self, item_type: Type[T] | Callable[..., T], bind_to: Any
|
|
197
296
|
) -> T:
|
|
198
|
-
instance = self.
|
|
297
|
+
instance = self._instantiate(item_type)
|
|
199
298
|
self.register(instance, bind_to)
|
|
200
299
|
return cast(T, instance)
|
|
201
300
|
|
|
@@ -227,17 +326,21 @@ def provide_container(container: Container) -> Generator[None, None, None]:
|
|
|
227
326
|
|
|
228
327
|
|
|
229
328
|
__all__ = [
|
|
230
|
-
"
|
|
329
|
+
"AppTransactionContext",
|
|
231
330
|
"AppInterceptor",
|
|
232
331
|
"AppInterceptorWithLifecycle",
|
|
233
332
|
"Container",
|
|
234
333
|
"Microservice",
|
|
235
|
-
"
|
|
236
|
-
"
|
|
237
|
-
"
|
|
334
|
+
"SchedulerTransactionData",
|
|
335
|
+
"WebSocketTransactionData",
|
|
336
|
+
"app_transaction_context_var",
|
|
238
337
|
"current_container_ctx",
|
|
239
338
|
"provide_app_context",
|
|
240
339
|
"provide_container",
|
|
241
340
|
"use_app_context",
|
|
242
341
|
"use_current_container",
|
|
342
|
+
"HttpTransactionData",
|
|
343
|
+
"MessageBusTransactionData",
|
|
344
|
+
"is_interceptor_with_lifecycle",
|
|
345
|
+
"AppContext",
|
|
243
346
|
]
|
|
@@ -13,7 +13,7 @@ from typing import (
|
|
|
13
13
|
TypeVar,
|
|
14
14
|
)
|
|
15
15
|
|
|
16
|
-
from jararaca.microservice import
|
|
16
|
+
from jararaca.microservice import AppTransactionContext
|
|
17
17
|
|
|
18
18
|
P = ParamSpec("P")
|
|
19
19
|
R = TypeVar("R")
|
|
@@ -28,9 +28,13 @@ class TracingContextProvider(Protocol):
|
|
|
28
28
|
|
|
29
29
|
class TracingContextProviderFactory(Protocol):
|
|
30
30
|
|
|
31
|
-
def root_setup(
|
|
31
|
+
def root_setup(
|
|
32
|
+
self, app_context: AppTransactionContext
|
|
33
|
+
) -> AsyncContextManager[None]: ...
|
|
32
34
|
|
|
33
|
-
def provide_provider(
|
|
35
|
+
def provide_provider(
|
|
36
|
+
self, app_context: AppTransactionContext
|
|
37
|
+
) -> TracingContextProvider: ...
|
|
34
38
|
|
|
35
39
|
|
|
36
40
|
tracing_ctx_provider_ctxv = ContextVar[TracingContextProvider]("tracing_ctx_provider")
|
|
@@ -2,9 +2,9 @@ from contextlib import asynccontextmanager
|
|
|
2
2
|
from typing import AsyncContextManager, AsyncGenerator, Protocol
|
|
3
3
|
|
|
4
4
|
from jararaca.microservice import (
|
|
5
|
-
AppContext,
|
|
6
5
|
AppInterceptor,
|
|
7
6
|
AppInterceptorWithLifecycle,
|
|
7
|
+
AppTransactionContext,
|
|
8
8
|
Container,
|
|
9
9
|
Microservice,
|
|
10
10
|
)
|
|
@@ -32,7 +32,9 @@ class ObservabilityInterceptor(AppInterceptor, AppInterceptorWithLifecycle):
|
|
|
32
32
|
self.observability_provider = observability_provider
|
|
33
33
|
|
|
34
34
|
@asynccontextmanager
|
|
35
|
-
async def intercept(
|
|
35
|
+
async def intercept(
|
|
36
|
+
self, app_context: AppTransactionContext
|
|
37
|
+
) -> AsyncGenerator[None, None]:
|
|
36
38
|
|
|
37
39
|
async with self.observability_provider.tracing_provider.root_setup(app_context):
|
|
38
40
|
|
|
@@ -23,7 +23,7 @@ from opentelemetry.sdk.trace import TracerProvider
|
|
|
23
23
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
24
24
|
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
|
25
25
|
|
|
26
|
-
from jararaca.microservice import
|
|
26
|
+
from jararaca.microservice import AppTransactionContext, Container, Microservice
|
|
27
27
|
from jararaca.observability.decorators import (
|
|
28
28
|
TracingContextProvider,
|
|
29
29
|
TracingContextProviderFactory,
|
|
@@ -36,7 +36,7 @@ tracer: trace.Tracer = trace.get_tracer(__name__)
|
|
|
36
36
|
|
|
37
37
|
class OtelTracingContextProvider(TracingContextProvider):
|
|
38
38
|
|
|
39
|
-
def __init__(self, app_context:
|
|
39
|
+
def __init__(self, app_context: AppTransactionContext) -> None:
|
|
40
40
|
self.app_context = app_context
|
|
41
41
|
|
|
42
42
|
@contextmanager
|
|
@@ -52,22 +52,26 @@ class OtelTracingContextProvider(TracingContextProvider):
|
|
|
52
52
|
|
|
53
53
|
class OtelTracingContextProviderFactory(TracingContextProviderFactory):
|
|
54
54
|
|
|
55
|
-
def provide_provider(
|
|
55
|
+
def provide_provider(
|
|
56
|
+
self, app_context: AppTransactionContext
|
|
57
|
+
) -> TracingContextProvider:
|
|
56
58
|
return OtelTracingContextProvider(app_context)
|
|
57
59
|
|
|
58
60
|
@asynccontextmanager
|
|
59
|
-
async def root_setup(
|
|
61
|
+
async def root_setup(
|
|
62
|
+
self, app_tx_ctx: AppTransactionContext
|
|
63
|
+
) -> AsyncGenerator[None, None]:
|
|
60
64
|
|
|
61
65
|
title: str = "Unmapped App Context Execution"
|
|
62
66
|
headers = {}
|
|
67
|
+
tx_data = app_tx_ctx.transaction_data
|
|
68
|
+
if tx_data.context_type == "http":
|
|
63
69
|
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
headers = dict(app_context.request.headers)
|
|
67
|
-
title = f"HTTP {app_context.request.method} {app_context.request.url}"
|
|
70
|
+
headers = dict(tx_data.request.headers)
|
|
71
|
+
title = f"HTTP {tx_data.request.method} {tx_data.request.url}"
|
|
68
72
|
|
|
69
|
-
elif
|
|
70
|
-
title = f"Message Bus {
|
|
73
|
+
elif tx_data.context_type == "message_bus":
|
|
74
|
+
title = f"Message Bus {tx_data.topic}"
|
|
71
75
|
|
|
72
76
|
carrier = {
|
|
73
77
|
key: value
|
jararaca/persistence/base.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Any, Self, Type, TypeVar
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
|
+
from sqlalchemy.ext.asyncio import AsyncAttrs
|
|
4
5
|
from sqlalchemy.orm import DeclarativeBase
|
|
5
6
|
|
|
6
7
|
IDENTIFIABLE_SCHEMA_T = TypeVar("IDENTIFIABLE_SCHEMA_T")
|
|
@@ -20,7 +21,7 @@ def recursive_get_dict(obj: Any) -> Any:
|
|
|
20
21
|
return obj
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
class BaseEntity(DeclarativeBase):
|
|
24
|
+
class BaseEntity(AsyncAttrs, DeclarativeBase):
|
|
24
25
|
|
|
25
26
|
@classmethod
|
|
26
27
|
def from_basemodel(cls, mutation: T_BASEMODEL) -> "Self":
|
|
@@ -3,21 +3,53 @@ from contextvars import ContextVar
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from typing import Any, AsyncGenerator, Generator
|
|
5
5
|
|
|
6
|
-
from sqlalchemy.ext.asyncio import
|
|
6
|
+
from sqlalchemy.ext.asyncio import (
|
|
7
|
+
AsyncSession,
|
|
8
|
+
AsyncSessionTransaction,
|
|
9
|
+
async_sessionmaker,
|
|
10
|
+
create_async_engine,
|
|
11
|
+
)
|
|
7
12
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|
8
13
|
|
|
9
|
-
from jararaca.microservice import
|
|
14
|
+
from jararaca.microservice import AppInterceptor, AppTransactionContext
|
|
15
|
+
from jararaca.reflect.metadata import SetMetadata, get_metadata_value
|
|
10
16
|
|
|
11
|
-
|
|
17
|
+
DEFAULT_CONNECTION_NAME = "default"
|
|
18
|
+
|
|
19
|
+
ctx_default_connection_name: ContextVar[str] = ContextVar(
|
|
20
|
+
"ctx_default_connection_name", default=DEFAULT_CONNECTION_NAME
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def ensure_name(name: str | None) -> str:
|
|
25
|
+
return ctx_default_connection_name.get()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
|
|
29
|
+
class PersistenceCtx:
|
|
30
|
+
session: AsyncSession
|
|
31
|
+
tx: AsyncSessionTransaction
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
ctx_session_map = ContextVar[dict[str, PersistenceCtx]]("ctx_session_map", default={})
|
|
12
35
|
|
|
13
36
|
|
|
14
37
|
@contextmanager
|
|
15
|
-
def
|
|
16
|
-
|
|
38
|
+
def providing_session(
|
|
39
|
+
session: AsyncSession,
|
|
40
|
+
tx: AsyncSessionTransaction,
|
|
41
|
+
connection_name: str | None = None,
|
|
17
42
|
) -> Generator[None, Any, None]:
|
|
43
|
+
"""
|
|
44
|
+
Context manager to provide a session and transaction for a specific connection name.
|
|
45
|
+
If no connection name is provided, it uses the default connection name from the context variable.
|
|
46
|
+
"""
|
|
47
|
+
connection_name = ensure_name(connection_name)
|
|
18
48
|
current_map = ctx_session_map.get({})
|
|
19
49
|
|
|
20
|
-
token = ctx_session_map.set(
|
|
50
|
+
token = ctx_session_map.set(
|
|
51
|
+
{**current_map, connection_name: PersistenceCtx(session, tx)}
|
|
52
|
+
)
|
|
21
53
|
|
|
22
54
|
try:
|
|
23
55
|
yield
|
|
@@ -26,18 +58,118 @@ def provide_session(
|
|
|
26
58
|
ctx_session_map.reset(token)
|
|
27
59
|
|
|
28
60
|
|
|
29
|
-
|
|
61
|
+
provide_session = providing_session
|
|
62
|
+
"""
|
|
63
|
+
Alias for `providing_session` to maintain backward compatibility.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@asynccontextmanager
|
|
68
|
+
async def providing_new_session(
|
|
69
|
+
connection_name: str | None = None,
|
|
70
|
+
) -> AsyncGenerator[AsyncSession, None]:
|
|
71
|
+
|
|
72
|
+
current_session = use_session(connection_name)
|
|
73
|
+
|
|
74
|
+
async with AsyncSession(
|
|
75
|
+
current_session.bind,
|
|
76
|
+
) as new_session, new_session.begin() as new_tx:
|
|
77
|
+
with providing_session(new_session, new_tx, connection_name):
|
|
78
|
+
yield new_session
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def use_session(connection_name: str | None = None) -> AsyncSession:
|
|
82
|
+
connection_name = ensure_name(connection_name)
|
|
30
83
|
current_map = ctx_session_map.get({})
|
|
31
84
|
if connection_name not in current_map:
|
|
32
|
-
raise ValueError(
|
|
85
|
+
raise ValueError(
|
|
86
|
+
f'Session not found for connection "{connection_name}" in context. Check if your interceptor is correctly set up.'
|
|
87
|
+
)
|
|
33
88
|
|
|
34
|
-
return current_map[connection_name]
|
|
89
|
+
return current_map[connection_name].session
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@contextmanager
|
|
93
|
+
def providing_transaction(
|
|
94
|
+
tx: AsyncSessionTransaction,
|
|
95
|
+
connection_name: str | None = None,
|
|
96
|
+
) -> Generator[None, Any, None]:
|
|
97
|
+
connection_name = ensure_name(connection_name)
|
|
98
|
+
|
|
99
|
+
current_map = ctx_session_map.get({})
|
|
100
|
+
|
|
101
|
+
if connection_name not in current_map:
|
|
102
|
+
raise ValueError(f"No session found for connection {connection_name}")
|
|
103
|
+
|
|
104
|
+
with providing_session(current_map[connection_name].session, tx, connection_name):
|
|
105
|
+
yield
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def use_transaction(connection_name: str | None = None) -> AsyncSessionTransaction:
|
|
109
|
+
current_map = ctx_session_map.get({})
|
|
110
|
+
if connection_name not in current_map:
|
|
111
|
+
raise ValueError(f"Transaction not found for connection {connection_name}")
|
|
112
|
+
|
|
113
|
+
return current_map[connection_name].tx
|
|
35
114
|
|
|
36
115
|
|
|
37
|
-
@dataclass
|
|
38
116
|
class AIOSQAConfig:
|
|
39
|
-
connection_name: str
|
|
40
117
|
url: str | AsyncEngine
|
|
118
|
+
connection_name: str
|
|
119
|
+
inject_default: bool
|
|
120
|
+
|
|
121
|
+
def __init__(
|
|
122
|
+
self,
|
|
123
|
+
url: str | AsyncEngine,
|
|
124
|
+
connection_name: str = DEFAULT_CONNECTION_NAME,
|
|
125
|
+
inject_default: bool = True,
|
|
126
|
+
):
|
|
127
|
+
self.url = url
|
|
128
|
+
self.connection_name = connection_name
|
|
129
|
+
self.inject_default = inject_default
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
INJECT_CONNECTION_METADATA = "inject_connection_metadata_{connection_name}"
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def set_inject_connection(
|
|
136
|
+
inject: bool, connection_name: str = DEFAULT_CONNECTION_NAME
|
|
137
|
+
) -> SetMetadata:
|
|
138
|
+
"""
|
|
139
|
+
Set whether to inject the connection metadata for the given connection name.
|
|
140
|
+
This is useful when you want to control whether the connection metadata
|
|
141
|
+
should be injected into the context or not.
|
|
142
|
+
"""
|
|
143
|
+
|
|
144
|
+
return SetMetadata(
|
|
145
|
+
INJECT_CONNECTION_METADATA.format(connection_name=connection_name), inject
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def uses_connection(
|
|
150
|
+
connection_name: str = DEFAULT_CONNECTION_NAME,
|
|
151
|
+
) -> SetMetadata:
|
|
152
|
+
"""
|
|
153
|
+
Use connection metadata for the given connection name.
|
|
154
|
+
This is useful when you want to inject the connection metadata into the context,
|
|
155
|
+
for example, when you are using a specific connection for a specific operation.
|
|
156
|
+
"""
|
|
157
|
+
return SetMetadata(
|
|
158
|
+
INJECT_CONNECTION_METADATA.format(connection_name=connection_name), True
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def dnt_uses_connection(
|
|
163
|
+
connection_name: str = DEFAULT_CONNECTION_NAME,
|
|
164
|
+
) -> SetMetadata:
|
|
165
|
+
"""
|
|
166
|
+
Do not use connection metadata for the given connection name.
|
|
167
|
+
This is useful when you want to ensure that the connection metadata is not injected
|
|
168
|
+
into the context, for example, when you are using a different connection for a specific operation.
|
|
169
|
+
"""
|
|
170
|
+
return SetMetadata(
|
|
171
|
+
INJECT_CONNECTION_METADATA.format(connection_name=connection_name), False
|
|
172
|
+
)
|
|
41
173
|
|
|
42
174
|
|
|
43
175
|
class AIOSqlAlchemySessionInterceptor(AppInterceptor):
|
|
@@ -53,12 +185,31 @@ class AIOSqlAlchemySessionInterceptor(AppInterceptor):
|
|
|
53
185
|
self.sessionmaker = async_sessionmaker(self.engine)
|
|
54
186
|
|
|
55
187
|
@asynccontextmanager
|
|
56
|
-
async def intercept(
|
|
57
|
-
|
|
58
|
-
|
|
188
|
+
async def intercept(
|
|
189
|
+
self, app_context: AppTransactionContext
|
|
190
|
+
) -> AsyncGenerator[None, None]:
|
|
191
|
+
|
|
192
|
+
uses_connection_metadata = get_metadata_value(
|
|
193
|
+
INJECT_CONNECTION_METADATA.format(
|
|
194
|
+
connection_name=self.config.connection_name
|
|
195
|
+
),
|
|
196
|
+
self.config.inject_default,
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
if not uses_connection_metadata:
|
|
200
|
+
yield
|
|
201
|
+
return
|
|
202
|
+
|
|
203
|
+
async with self.sessionmaker() as session, session.begin() as tx:
|
|
204
|
+
token = ctx_default_connection_name.set(self.config.connection_name)
|
|
205
|
+
with providing_session(session, tx, self.config.connection_name):
|
|
59
206
|
try:
|
|
60
207
|
yield
|
|
61
|
-
|
|
208
|
+
if tx.is_active:
|
|
209
|
+
await tx.commit()
|
|
62
210
|
except Exception as e:
|
|
63
|
-
await
|
|
211
|
+
await tx.rollback()
|
|
64
212
|
raise e
|
|
213
|
+
finally:
|
|
214
|
+
with suppress(ValueError):
|
|
215
|
+
ctx_default_connection_name.reset(token)
|