jararaca 0.3.11a16__py3-none-any.whl → 0.4.0a19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. README.md +121 -0
  2. jararaca/__init__.py +189 -17
  3. jararaca/__main__.py +4 -0
  4. jararaca/broker_backend/__init__.py +4 -0
  5. jararaca/broker_backend/mapper.py +4 -0
  6. jararaca/broker_backend/redis_broker_backend.py +9 -3
  7. jararaca/cli.py +915 -51
  8. jararaca/common/__init__.py +3 -0
  9. jararaca/core/__init__.py +3 -0
  10. jararaca/core/providers.py +8 -0
  11. jararaca/core/uow.py +41 -7
  12. jararaca/di.py +4 -0
  13. jararaca/files/entity.py.mako +4 -0
  14. jararaca/helpers/__init__.py +3 -0
  15. jararaca/helpers/global_scheduler/__init__.py +3 -0
  16. jararaca/helpers/global_scheduler/config.py +21 -0
  17. jararaca/helpers/global_scheduler/controller.py +42 -0
  18. jararaca/helpers/global_scheduler/registry.py +32 -0
  19. jararaca/lifecycle.py +6 -2
  20. jararaca/messagebus/__init__.py +4 -0
  21. jararaca/messagebus/bus_message_controller.py +4 -0
  22. jararaca/messagebus/consumers/__init__.py +3 -0
  23. jararaca/messagebus/decorators.py +121 -61
  24. jararaca/messagebus/implicit_headers.py +49 -0
  25. jararaca/messagebus/interceptors/__init__.py +3 -0
  26. jararaca/messagebus/interceptors/aiopika_publisher_interceptor.py +62 -11
  27. jararaca/messagebus/interceptors/message_publisher_collector.py +62 -0
  28. jararaca/messagebus/interceptors/publisher_interceptor.py +29 -3
  29. jararaca/messagebus/message.py +4 -0
  30. jararaca/messagebus/publisher.py +6 -0
  31. jararaca/messagebus/worker.py +1002 -459
  32. jararaca/microservice.py +113 -2
  33. jararaca/observability/constants.py +7 -0
  34. jararaca/observability/decorators.py +170 -13
  35. jararaca/observability/fastapi_exception_handler.py +37 -0
  36. jararaca/observability/hooks.py +109 -0
  37. jararaca/observability/interceptor.py +4 -0
  38. jararaca/observability/providers/__init__.py +3 -0
  39. jararaca/observability/providers/otel.py +225 -16
  40. jararaca/persistence/base.py +39 -3
  41. jararaca/persistence/exports.py +4 -0
  42. jararaca/persistence/interceptors/__init__.py +3 -0
  43. jararaca/persistence/interceptors/aiosqa_interceptor.py +86 -73
  44. jararaca/persistence/interceptors/constants.py +5 -0
  45. jararaca/persistence/interceptors/decorators.py +50 -0
  46. jararaca/persistence/session.py +3 -0
  47. jararaca/persistence/sort_filter.py +4 -0
  48. jararaca/persistence/utilities.py +73 -20
  49. jararaca/presentation/__init__.py +3 -0
  50. jararaca/presentation/decorators.py +88 -86
  51. jararaca/presentation/exceptions.py +23 -0
  52. jararaca/presentation/hooks.py +4 -0
  53. jararaca/presentation/http_microservice.py +4 -0
  54. jararaca/presentation/server.py +97 -45
  55. jararaca/presentation/websocket/__init__.py +3 -0
  56. jararaca/presentation/websocket/base_types.py +4 -0
  57. jararaca/presentation/websocket/context.py +4 -0
  58. jararaca/presentation/websocket/decorators.py +8 -41
  59. jararaca/presentation/websocket/redis.py +280 -53
  60. jararaca/presentation/websocket/types.py +4 -0
  61. jararaca/presentation/websocket/websocket_interceptor.py +46 -19
  62. jararaca/reflect/__init__.py +3 -0
  63. jararaca/reflect/controller_inspect.py +16 -10
  64. jararaca/reflect/decorators.py +252 -0
  65. jararaca/reflect/helpers.py +18 -0
  66. jararaca/reflect/metadata.py +34 -25
  67. jararaca/rpc/__init__.py +3 -0
  68. jararaca/rpc/http/__init__.py +101 -0
  69. jararaca/rpc/http/backends/__init__.py +14 -0
  70. jararaca/rpc/http/backends/httpx.py +43 -9
  71. jararaca/rpc/http/backends/otel.py +4 -0
  72. jararaca/rpc/http/decorators.py +380 -115
  73. jararaca/rpc/http/httpx.py +3 -0
  74. jararaca/scheduler/__init__.py +3 -0
  75. jararaca/scheduler/beat_worker.py +521 -105
  76. jararaca/scheduler/decorators.py +15 -22
  77. jararaca/scheduler/types.py +4 -0
  78. jararaca/tools/app_config/__init__.py +3 -0
  79. jararaca/tools/app_config/decorators.py +7 -19
  80. jararaca/tools/app_config/interceptor.py +6 -2
  81. jararaca/tools/typescript/__init__.py +3 -0
  82. jararaca/tools/typescript/decorators.py +120 -0
  83. jararaca/tools/typescript/interface_parser.py +1077 -174
  84. jararaca/utils/__init__.py +3 -0
  85. jararaca/utils/env_parse_utils.py +133 -0
  86. jararaca/utils/rabbitmq_utils.py +112 -39
  87. jararaca/utils/retry.py +19 -14
  88. jararaca-0.4.0a19.dist-info/LICENSE +674 -0
  89. jararaca-0.4.0a19.dist-info/LICENSES/GPL-3.0-or-later.txt +232 -0
  90. {jararaca-0.3.11a16.dist-info → jararaca-0.4.0a19.dist-info}/METADATA +12 -7
  91. jararaca-0.4.0a19.dist-info/RECORD +96 -0
  92. {jararaca-0.3.11a16.dist-info → jararaca-0.4.0a19.dist-info}/WHEEL +1 -1
  93. pyproject.toml +132 -0
  94. jararaca-0.3.11a16.dist-info/RECORD +0 -74
  95. /jararaca-0.3.11a16.dist-info/LICENSE → /LICENSE +0 -0
  96. {jararaca-0.3.11a16.dist-info → jararaca-0.4.0a19.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,10 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
1
5
  import logging
2
6
  from contextlib import asynccontextmanager, contextmanager
3
- from typing import AsyncGenerator, Generator, Protocol
7
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Literal, Protocol
4
8
 
5
9
  from opentelemetry import metrics, trace
6
10
  from opentelemetry._logs import set_logger_provider
@@ -23,32 +27,165 @@ from opentelemetry.sdk.trace import TracerProvider
23
27
  from opentelemetry.sdk.trace.export import BatchSpanProcessor
24
28
  from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
25
29
 
26
- from jararaca.microservice import AppTransactionContext, Container, Microservice
30
+ from jararaca.messagebus.implicit_headers import (
31
+ ImplicitHeaders,
32
+ provide_implicit_headers,
33
+ use_implicit_headers,
34
+ )
35
+ from jararaca.microservice import (
36
+ AppTransactionContext,
37
+ Container,
38
+ Microservice,
39
+ use_app_transaction_context,
40
+ )
41
+ from jararaca.observability.constants import TRACEPARENT_KEY
27
42
  from jararaca.observability.decorators import (
43
+ AttributeMap,
44
+ AttributeValue,
28
45
  TracingContextProvider,
29
46
  TracingContextProviderFactory,
30
- get_tracing_ctx_provider,
47
+ TracingSpan,
48
+ TracingSpanContext,
31
49
  )
32
50
  from jararaca.observability.interceptor import ObservabilityProvider
33
51
 
52
+ if TYPE_CHECKING:
53
+ from opentelemetry.trace import Span as _Span
54
+ from typing_extensions import TypeIs
55
+
34
56
  tracer: trace.Tracer = trace.get_tracer(__name__)
35
57
 
36
58
 
59
+ def extract_context_attributes(ctx: AppTransactionContext) -> dict[str, Any]:
60
+ tx_data = ctx.transaction_data
61
+ extra_attributes: dict[str, Any] = {}
62
+
63
+ if tx_data.context_type == "http":
64
+ extra_attributes = {
65
+ "http.method": tx_data.request.method,
66
+ "http.url": str(tx_data.request.url),
67
+ "http.path": tx_data.request.url.path,
68
+ "http.route.path": tx_data.request.scope["route"].path,
69
+ "http.route.endpoint.name": tx_data.request["route"].endpoint.__qualname__,
70
+ "http.query": tx_data.request.url.query,
71
+ **{
72
+ f"http.request.path_param.{k}": v
73
+ for k, v in tx_data.request.path_params.items()
74
+ },
75
+ **{
76
+ f"http.request.query_param.{k}": v
77
+ for k, v in tx_data.request.query_params.items()
78
+ },
79
+ **{
80
+ f"http.request.header.{k}": v
81
+ for k, v in tx_data.request.headers.items()
82
+ },
83
+ "http.request.client.host": (
84
+ tx_data.request.client.host if tx_data.request.client else ""
85
+ ),
86
+ }
87
+ elif tx_data.context_type == "message_bus":
88
+ extra_attributes = {
89
+ "bus.message.id": tx_data.message_id,
90
+ "bus.message.name": tx_data.message_type.__qualname__,
91
+ "bus.message.module": tx_data.message_type.__module__,
92
+ "bus.message.category": tx_data.message_type.MESSAGE_CATEGORY,
93
+ "bus.message.type": tx_data.message_type.MESSAGE_TYPE,
94
+ "bus.message.topic": tx_data.message_type.MESSAGE_TOPIC,
95
+ "bus.message.processing_attempt": tx_data.processing_attempt,
96
+ }
97
+ elif tx_data.context_type == "websocket":
98
+ extra_attributes = {
99
+ "ws.url": str(tx_data.websocket.url),
100
+ }
101
+ elif tx_data.context_type == "scheduler":
102
+ extra_attributes = {
103
+ "sched.task_name": tx_data.task_name,
104
+ "sched.scheduled_to": tx_data.scheduled_to.isoformat(),
105
+ "sched.cron_expression": tx_data.cron_expression,
106
+ "sched.triggered_at": tx_data.triggered_at.isoformat(),
107
+ }
108
+ return {
109
+ "app.context_type": tx_data.context_type,
110
+ "controller_member_reflect.rest_controller.class_name": ctx.controller_member_reflect.controller_reflect.controller_class.__qualname__,
111
+ "controller_member_reflect.rest_controller.module": ctx.controller_member_reflect.controller_reflect.controller_class.__module__,
112
+ "controller_member_reflect.member_function.name": ctx.controller_member_reflect.member_function.__qualname__,
113
+ "controller_member_reflect.member_function.module": ctx.controller_member_reflect.member_function.__module__,
114
+ **extra_attributes,
115
+ }
116
+
117
+
118
+ class OtelTracingSpan(TracingSpan):
119
+
120
+ def __init__(self, span: trace.Span) -> None:
121
+ self.span = span
122
+
123
+
124
+ class OtelTracingSpanContext(TracingSpanContext):
125
+
126
+ def __init__(self, span_context: trace.SpanContext) -> None:
127
+ self.span_context = span_context
128
+
129
+
37
130
  class OtelTracingContextProvider(TracingContextProvider):
38
131
 
39
132
  def __init__(self, app_context: AppTransactionContext) -> None:
40
133
  self.app_context = app_context
41
134
 
42
135
  @contextmanager
43
- def __call__(
136
+ def start_span_context(
44
137
  self,
45
138
  trace_name: str,
46
- context_attributes: dict[str, str],
139
+ context_attributes: AttributeMap | None,
47
140
  ) -> Generator[None, None, None]:
48
141
 
49
142
  with tracer.start_as_current_span(trace_name, attributes=context_attributes):
50
143
  yield
51
144
 
145
+ def add_event(
146
+ self, event_name: str, event_attributes: AttributeMap | None = None
147
+ ) -> None:
148
+ trace.get_current_span().add_event(name=event_name, attributes=event_attributes)
149
+
150
+ def set_span_status(self, status_code: Literal["OK", "ERROR", "UNSET"]) -> None:
151
+ span = trace.get_current_span()
152
+ if status_code == "OK":
153
+ span.set_status(trace.Status(trace.StatusCode.OK))
154
+ elif status_code == "ERROR":
155
+ span.set_status(trace.Status(trace.StatusCode.ERROR))
156
+ else:
157
+ span.set_status(trace.Status(trace.StatusCode.UNSET))
158
+
159
+ def record_exception(
160
+ self,
161
+ exception: Exception,
162
+ attributes: AttributeMap | None = None,
163
+ escaped: bool = False,
164
+ ) -> None:
165
+ span = trace.get_current_span()
166
+ span.record_exception(exception, attributes=attributes, escaped=escaped)
167
+
168
+ def set_span_attribute(self, key: str, value: AttributeValue) -> None:
169
+ span = trace.get_current_span()
170
+ span.set_attribute(key, value)
171
+
172
+ def update_span_name(self, new_name: str) -> None:
173
+ span = trace.get_current_span()
174
+
175
+ span.update_name(new_name)
176
+
177
+ def add_link(self, span_context: TracingSpanContext) -> None:
178
+ if not isinstance(span_context, OtelTracingSpanContext):
179
+ return
180
+ span = trace.get_current_span()
181
+ span.add_link(span_context.span_context)
182
+
183
+ def get_current_span(self) -> TracingSpan | None:
184
+ return OtelTracingSpan(trace.get_current_span())
185
+
186
+ def get_current_span_context(self) -> TracingSpanContext | None:
187
+ return OtelTracingSpanContext(trace.get_current_span().get_span_context())
188
+
52
189
 
53
190
  class OtelTracingContextProviderFactory(TracingContextProviderFactory):
54
191
 
@@ -59,19 +196,31 @@ class OtelTracingContextProviderFactory(TracingContextProviderFactory):
59
196
 
60
197
  @asynccontextmanager
61
198
  async def root_setup(
62
- self, app_tx_ctx: AppTransactionContext
199
+ self, app_context: AppTransactionContext
63
200
  ) -> AsyncGenerator[None, None]:
64
201
 
65
202
  title: str = "Unmapped App Context Execution"
66
- headers = {}
67
- tx_data = app_tx_ctx.transaction_data
68
- if tx_data.context_type == "http":
203
+ headers: dict[str, Any] = {}
204
+ tx_data = app_context.transaction_data
205
+ extra_attributes = extract_context_attributes(app_context)
69
206
 
207
+ if tx_data.context_type == "http":
70
208
  headers = dict(tx_data.request.headers)
71
209
  title = f"HTTP {tx_data.request.method} {tx_data.request.url}"
210
+ extra_attributes["http.request.body"] = (await tx_data.request.body())[
211
+ :5000
212
+ ].decode(errors="ignore")
72
213
 
73
214
  elif tx_data.context_type == "message_bus":
74
- title = f"Message Bus {tx_data.topic}"
215
+ title = f"Att#{tx_data.processing_attempt} Message Bus {tx_data.topic}"
216
+ headers = use_implicit_headers() or {}
217
+
218
+ elif tx_data.context_type == "websocket":
219
+ headers = dict(tx_data.websocket.headers)
220
+ title = f"WebSocket {tx_data.websocket.url}"
221
+
222
+ elif tx_data.context_type == "scheduler":
223
+ title = f"Scheduler Task {tx_data.task_name}"
75
224
 
76
225
  carrier = {
77
226
  key: value
@@ -90,8 +239,28 @@ class OtelTracingContextProviderFactory(TracingContextProviderFactory):
90
239
 
91
240
  ctx2 = W3CBaggagePropagator().extract(b2, context=ctx)
92
241
 
93
- with tracer.start_as_current_span(name=title, context=ctx2):
94
- yield
242
+ with tracer.start_as_current_span(
243
+ name=title,
244
+ context=ctx2,
245
+ attributes={
246
+ **extra_attributes,
247
+ },
248
+ ) as root_span:
249
+ cx = root_span.get_span_context()
250
+ span_traceparent_id = hex(cx.trace_id)[2:].rjust(32, "0")
251
+ if app_context.transaction_data.context_type == "http":
252
+ app_context.transaction_data.request.scope[TRACEPARENT_KEY] = (
253
+ span_traceparent_id
254
+ )
255
+ elif app_context.transaction_data.context_type == "websocket":
256
+ app_context.transaction_data.websocket.scope[TRACEPARENT_KEY] = (
257
+ span_traceparent_id
258
+ )
259
+ tracing_headers: ImplicitHeaders = {}
260
+ TraceContextTextMapPropagator().inject(tracing_headers)
261
+ W3CBaggagePropagator().inject(tracing_headers)
262
+ with provide_implicit_headers(tracing_headers):
263
+ yield
95
264
 
96
265
 
97
266
  class LoggerHandlerCallback(Protocol):
@@ -99,6 +268,46 @@ class LoggerHandlerCallback(Protocol):
99
268
  def __call__(self, logger_handler: logging.Handler) -> None: ...
100
269
 
101
270
 
271
+ class SpanWithName(Protocol):
272
+
273
+ @property
274
+ def name(self) -> str: ...
275
+
276
+
277
+ def is_span_with_name(span: Any) -> "TypeIs[SpanWithName]":
278
+ return hasattr(span, "name")
279
+
280
+
281
+ class CustomLoggingHandler(LoggingHandler):
282
+
283
+ def _translate(self, record: logging.LogRecord) -> dict[str, Any]:
284
+ try:
285
+ ctx = use_app_transaction_context()
286
+ data = super()._translate(record)
287
+ extra_attributes = extract_context_attributes(ctx)
288
+
289
+ current_span: "_Span" = trace.get_current_span()
290
+
291
+ data["attributes"] = {
292
+ **data.get("attributes", {}),
293
+ **extra_attributes,
294
+ **(
295
+ {
296
+ "span_name": (
297
+ current_span.name if is_span_with_name(current_span) else ""
298
+ ),
299
+ }
300
+ if hasattr(current_span, "name")
301
+ and current_span.is_recording() is False
302
+ else {}
303
+ ),
304
+ }
305
+
306
+ return data
307
+ except LookupError:
308
+ return super()._translate(record)
309
+
310
+
102
311
  class OtelObservabilityProvider(ObservabilityProvider):
103
312
 
104
313
  def __init__(
@@ -107,7 +316,7 @@ class OtelObservabilityProvider(ObservabilityProvider):
107
316
  logs_exporter: LogExporter,
108
317
  span_exporter: SpanExporter,
109
318
  meter_exporter: MeterExporter,
110
- logging_handler_callback: LoggerHandlerCallback = lambda _: None,
319
+ logging_handler_callback: LoggerHandlerCallback = lambda logger_handler: None,
111
320
  meter_export_interval: int = 5000,
112
321
  ) -> None:
113
322
  self.app_name = app_name
@@ -143,11 +352,11 @@ class OtelObservabilityProvider(ObservabilityProvider):
143
352
  BatchLogRecordProcessor(self.logs_exporter)
144
353
  )
145
354
 
146
- logging_handler = LoggingHandler(
355
+ logging_handler = CustomLoggingHandler(
147
356
  level=logging.DEBUG, logger_provider=logger_provider
148
357
  )
149
358
 
150
- logging_handler.addFilter(lambda _: get_tracing_ctx_provider() is not None)
359
+ # logging_handler.addFilter(lambda _: get_tracing_ctx_provider() is not None)
151
360
 
152
361
  self.logging_handler_callback(logging_handler)
153
362
 
@@ -165,7 +374,7 @@ class OtelObservabilityProvider(ObservabilityProvider):
165
374
  def from_url(
166
375
  app_name: str,
167
376
  url: str,
168
- logging_handler_callback: LoggerHandlerCallback = lambda _: None,
377
+ logging_handler_callback: LoggerHandlerCallback = lambda logger_handler: None,
169
378
  meter_export_interval: int = 5000,
170
379
  ) -> "OtelObservabilityProvider":
171
380
  """
@@ -1,4 +1,8 @@
1
- from typing import Any, Self, Type, TypeVar
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
5
+ from typing import Any, Callable, Protocol, Self, Type, TypeVar
2
6
 
3
7
  from pydantic import BaseModel
4
8
  from sqlalchemy.ext.asyncio import AsyncAttrs
@@ -21,14 +25,46 @@ def recursive_get_dict(obj: Any) -> Any:
21
25
  return obj
22
26
 
23
27
 
28
+ RESULT_T = TypeVar("RESULT_T", covariant=True)
29
+ ENTITY_T_CONTRA = TypeVar("ENTITY_T_CONTRA", bound="BaseEntity", contravariant=True)
30
+
31
+
32
+ class EntityParserType(Protocol[ENTITY_T_CONTRA, RESULT_T]):
33
+
34
+ @classmethod
35
+ def parse_entity(cls, model: ENTITY_T_CONTRA) -> "RESULT_T": ...
36
+
37
+
38
+ EntityParserFunc = Callable[[ENTITY_T_CONTRA], RESULT_T]
39
+
40
+ BASED_BASE_ENTITY_T = TypeVar("BASED_BASE_ENTITY_T", bound="BaseEntity")
41
+
42
+
24
43
  class BaseEntity(AsyncAttrs, DeclarativeBase):
25
44
 
26
45
  @classmethod
27
- def from_basemodel(cls, mutation: T_BASEMODEL) -> "Self":
46
+ def from_basemodel(cls, mutation: BaseModel) -> "Self":
28
47
  intersection = set(cls.__annotations__.keys()) & set(
29
- mutation.model_fields.keys()
48
+ mutation.__class__.model_fields.keys()
30
49
  )
31
50
  return cls(**{k: getattr(mutation, k) for k in intersection})
32
51
 
33
52
  def to_basemodel(self, model: Type[T_BASEMODEL]) -> T_BASEMODEL:
34
53
  return model.model_validate(recursive_get_dict(self))
54
+
55
+ def parse_entity_with_func(
56
+ self, model_cls: EntityParserFunc["Self", RESULT_T]
57
+ ) -> RESULT_T:
58
+ return model_cls(self)
59
+
60
+ def parse_entity_with_type(
61
+ self: BASED_BASE_ENTITY_T,
62
+ model_cls: Type[EntityParserType[BASED_BASE_ENTITY_T, RESULT_T]],
63
+ ) -> RESULT_T:
64
+ return model_cls.parse_entity(self)
65
+
66
+ def __rshift__(self, model: EntityParserFunc["Self", RESULT_T]) -> RESULT_T:
67
+ return self.parse_entity_with_func(model)
68
+
69
+ def __and__(self, model: Type[EntityParserType["Self", RESULT_T]]) -> RESULT_T:
70
+ return self.parse_entity_with_type(model)
@@ -1,3 +1,7 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
1
5
  from .base import BaseEntity
2
6
 
3
7
  __all__ = [
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
@@ -1,7 +1,11 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
1
5
  from contextlib import asynccontextmanager, contextmanager, suppress
2
6
  from contextvars import ContextVar
3
7
  from dataclasses import dataclass
4
- from typing import Any, AsyncGenerator, Generator
8
+ from typing import Any, AsyncGenerator, Generator, Protocol
5
9
 
6
10
  from sqlalchemy.ext.asyncio import (
7
11
  AsyncSession,
@@ -12,9 +16,47 @@ from sqlalchemy.ext.asyncio import (
12
16
  from sqlalchemy.ext.asyncio.engine import AsyncEngine
13
17
 
14
18
  from jararaca.microservice import AppInterceptor, AppTransactionContext
15
- from jararaca.reflect.metadata import SetMetadata, get_metadata_value
19
+ from jararaca.persistence.interceptors.constants import DEFAULT_CONNECTION_NAME
20
+ from jararaca.persistence.interceptors.decorators import (
21
+ INJECT_PERSISTENCE_SESSION_METADATA_TEMPLATE,
22
+ )
23
+ from jararaca.reflect.metadata import get_metadata_value
24
+
25
+
26
+ class SessionManager(Protocol):
27
+ def spawn_session(self, connection_name: str | None = None) -> AsyncSession: ...
28
+
29
+
30
+ ctx_session_manager: ContextVar[SessionManager | None] = ContextVar(
31
+ "ctx_session_manager", default=None
32
+ )
33
+
34
+
35
+ @contextmanager
36
+ def providing_session_manager(
37
+ session_manager: SessionManager,
38
+ ) -> Generator[None, Any, None]:
39
+ """
40
+ Context manager to provide a session manager for the current context.
41
+ """
42
+ token = ctx_session_manager.set(session_manager)
43
+ try:
44
+ yield
45
+ finally:
46
+ with suppress(ValueError):
47
+ ctx_session_manager.reset(token)
48
+
49
+
50
+ def use_session_manager() -> SessionManager:
51
+ """
52
+ Retrieve the current session manager from the context variable.
53
+ Raises ValueError if no session manager is set.
54
+ """
55
+ session_manager = ctx_session_manager.get()
56
+ if session_manager is None:
57
+ raise ValueError("No session manager set in the context.")
58
+ return session_manager
16
59
 
17
- DEFAULT_CONNECTION_NAME = "default"
18
60
 
19
61
  ctx_default_connection_name: ContextVar[str] = ContextVar(
20
62
  "ctx_default_connection_name", default=DEFAULT_CONNECTION_NAME
@@ -69,13 +111,21 @@ async def providing_new_session(
69
111
  connection_name: str | None = None,
70
112
  ) -> AsyncGenerator[AsyncSession, None]:
71
113
 
72
- current_session = use_session(connection_name)
114
+ session_manager = use_session_manager()
115
+ current_session = session_manager.spawn_session(connection_name)
73
116
 
74
117
  async with AsyncSession(
75
118
  current_session.bind,
76
119
  ) as new_session, new_session.begin() as new_tx:
77
120
  with providing_session(new_session, new_tx, connection_name):
78
- yield new_session
121
+ try:
122
+ yield new_session
123
+ if new_tx.is_active:
124
+ await new_tx.commit()
125
+ except Exception:
126
+ if new_tx.is_active:
127
+ await new_tx.rollback()
128
+ raise
79
129
 
80
130
 
81
131
  def use_session(connection_name: str | None = None) -> AsyncSession:
@@ -129,50 +179,7 @@ class AIOSQAConfig:
129
179
  self.inject_default = inject_default
130
180
 
131
181
 
132
- INJECT_CONNECTION_METADATA = "inject_connection_metadata_{connection_name}"
133
-
134
-
135
- def set_inject_connection(
136
- inject: bool, connection_name: str = DEFAULT_CONNECTION_NAME
137
- ) -> SetMetadata:
138
- """
139
- Set whether to inject the connection metadata for the given connection name.
140
- This is useful when you want to control whether the connection metadata
141
- should be injected into the context or not.
142
- """
143
-
144
- return SetMetadata(
145
- INJECT_CONNECTION_METADATA.format(connection_name=connection_name), inject
146
- )
147
-
148
-
149
- def uses_connection(
150
- connection_name: str = DEFAULT_CONNECTION_NAME,
151
- ) -> SetMetadata:
152
- """
153
- Use connection metadata for the given connection name.
154
- This is useful when you want to inject the connection metadata into the context,
155
- for example, when you are using a specific connection for a specific operation.
156
- """
157
- return SetMetadata(
158
- INJECT_CONNECTION_METADATA.format(connection_name=connection_name), True
159
- )
160
-
161
-
162
- def dnt_uses_connection(
163
- connection_name: str = DEFAULT_CONNECTION_NAME,
164
- ) -> SetMetadata:
165
- """
166
- Do not use connection metadata for the given connection name.
167
- This is useful when you want to ensure that the connection metadata is not injected
168
- into the context, for example, when you are using a different connection for a specific operation.
169
- """
170
- return SetMetadata(
171
- INJECT_CONNECTION_METADATA.format(connection_name=connection_name), False
172
- )
173
-
174
-
175
- class AIOSqlAlchemySessionInterceptor(AppInterceptor):
182
+ class AIOSqlAlchemySessionInterceptor(AppInterceptor, SessionManager):
176
183
 
177
184
  def __init__(self, config: AIOSQAConfig):
178
185
  self.config = config
@@ -189,27 +196,33 @@ class AIOSqlAlchemySessionInterceptor(AppInterceptor):
189
196
  self, app_context: AppTransactionContext
190
197
  ) -> AsyncGenerator[None, None]:
191
198
 
192
- uses_connection_metadata = get_metadata_value(
193
- INJECT_CONNECTION_METADATA.format(
194
- connection_name=self.config.connection_name
195
- ),
196
- self.config.inject_default,
197
- )
198
-
199
- if not uses_connection_metadata:
200
- yield
201
- return
202
-
203
- async with self.sessionmaker() as session, session.begin() as tx:
204
- token = ctx_default_connection_name.set(self.config.connection_name)
205
- with providing_session(session, tx, self.config.connection_name):
206
- try:
207
- yield
208
- if tx.is_active:
209
- await tx.commit()
210
- except Exception as e:
211
- await tx.rollback()
212
- raise e
213
- finally:
214
- with suppress(ValueError):
215
- ctx_default_connection_name.reset(token)
199
+ with providing_session_manager(self):
200
+ uses_connection_metadata = get_metadata_value(
201
+ INJECT_PERSISTENCE_SESSION_METADATA_TEMPLATE.format(
202
+ connection_name=self.config.connection_name
203
+ ),
204
+ self.config.inject_default,
205
+ )
206
+
207
+ if not uses_connection_metadata:
208
+ yield
209
+ return
210
+
211
+ async with self.sessionmaker() as session, session.begin() as tx:
212
+ token = ctx_default_connection_name.set(self.config.connection_name)
213
+ with providing_session(session, tx, self.config.connection_name):
214
+ try:
215
+ yield
216
+ if tx.is_active:
217
+ await tx.commit()
218
+ except Exception as e:
219
+ await tx.rollback()
220
+ raise e
221
+ finally:
222
+ with suppress(ValueError):
223
+ ctx_default_connection_name.reset(token)
224
+
225
+ def spawn_session(self, connection_name: str | None = None) -> AsyncSession:
226
+ connection_name = ensure_name(connection_name)
227
+ session = self.sessionmaker()
228
+ return session
@@ -0,0 +1,5 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
5
+ DEFAULT_CONNECTION_NAME = "default"
@@ -0,0 +1,50 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
5
+
6
+ from jararaca.persistence.interceptors.constants import DEFAULT_CONNECTION_NAME
7
+ from jararaca.reflect.metadata import SetMetadata
8
+
9
+ INJECT_PERSISTENCE_SESSION_METADATA_TEMPLATE = (
10
+ "inject_persistence_template_{connection_name}"
11
+ )
12
+
13
+
14
+ def set_use_persistence_session(
15
+ inject: bool, connection_name: str = DEFAULT_CONNECTION_NAME
16
+ ) -> SetMetadata:
17
+ """
18
+ Set whether to inject the connection metadata for the given connection name.
19
+ This is useful when you want to control whether the connection metadata
20
+ should be injected into the context or not.
21
+ """
22
+
23
+ return SetMetadata(
24
+ INJECT_PERSISTENCE_SESSION_METADATA_TEMPLATE.format(
25
+ connection_name=connection_name
26
+ ),
27
+ inject,
28
+ )
29
+
30
+
31
+ def uses_persistence_session(
32
+ connection_name: str = DEFAULT_CONNECTION_NAME,
33
+ ) -> SetMetadata:
34
+ """
35
+ Use connection metadata for the given connection name.
36
+ This is useful when you want to inject the connection metadata into the context,
37
+ for example, when you are using a specific connection for a specific operation.
38
+ """
39
+ return set_use_persistence_session(True, connection_name=connection_name)
40
+
41
+
42
+ def skip_persistence_session(
43
+ connection_name: str = DEFAULT_CONNECTION_NAME,
44
+ ) -> SetMetadata:
45
+ """
46
+ Decorator to skip using connection metadata for the given connection name.
47
+ This is useful when you want to ensure that the connection metadata is not injected
48
+ into the context, for example, when you are using a different connection for a specific operation.
49
+ """
50
+ return set_use_persistence_session(False, connection_name=connection_name)
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
@@ -1,3 +1,7 @@
1
+ # SPDX-FileCopyrightText: 2025 Lucas S
2
+ #
3
+ # SPDX-License-Identifier: GPL-3.0-or-later
4
+
1
5
  import re
2
6
  from datetime import date, datetime
3
7
  from functools import reduce