simple-module-background-tasks 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,42 @@
1
+ // Shared retry plumbing for the Index and Detail pages. Both POST to the
2
+ // same endpoint; they differ only in what they do with the result (reload
3
+ // the list vs. navigate to the new row), so the fetch + toast live here.
4
+
5
+ import { toast } from 'sonner';
6
+ import { API_BASE, type TaskStatus } from './constants';
7
+
8
/**
 * Shape of a task-execution record as consumed by the Index and Detail pages.
 *
 * Nullable fields are `null` when the backend has no value for them yet.
 */
export interface Execution {
  // Identity
  id: string;
  celery_task_id: string | null;
  task_name: string;

  // Current state
  status: TaskStatus;
  queue: string;
  retries: number;
  worker: string | null;

  // Lifecycle timestamps, serialized as strings
  queued_at: string | null;
  started_at: string | null;
  finished_at: string | null;

  // Failure / retry-chain metadata
  exception_type: string | null;
  retried_from_id: string | null;
}
22
+
23
/**
 * POST a retry request for `execution` and report the outcome via a toast.
 *
 * @param execution - The row to retry; only `id` (for the URL) and
 *   `task_name` (for the success toast) are read.
 * @returns The newly created execution on success, or `null` when the
 *   request failed (an error toast has already been shown in that case).
 */
export async function retryExecution(execution: {
  id: string;
  task_name: string;
}): Promise<Execution | null> {
  try {
    const res = await fetch(`${API_BASE}/executions/${execution.id}/retry`, {
      method: 'POST',
    });
    if (!res.ok) {
      // The error body may be absent, non-JSON, literally `null`, or carry a
      // non-string `detail` (for example an array or object). Only a
      // non-empty string is a usable message; everything else falls back to
      // the HTTP status so the toast never shows "[object Object]" or a
      // property-access TypeError.
      const body: unknown = await res.json().catch(() => null);
      const detail =
        body !== null && typeof body === 'object'
          ? (body as { detail?: unknown }).detail
          : undefined;
      throw new Error(
        typeof detail === 'string' && detail !== ''
          ? detail
          : `HTTP ${res.status}`,
      );
    }
    const created = (await res.json()) as Execution;
    toast.success(`Task "${execution.task_name}" re-enqueued`);
    return created;
  } catch (err) {
    // Network failures, JSON parse failures and the HTTP error thrown above
    // all end up here and are surfaced to the user the same way.
    toast.error(err instanceof Error ? err.message : 'Failed to retry task');
    return null;
  }
}
File without changes
@@ -0,0 +1,138 @@
1
+ """BackgroundTaskService — admin listing, detail, and retry."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import uuid
6
+ from typing import TYPE_CHECKING
7
+
8
+ from fastapi import HTTPException, status
9
+ from simple_module_core.events import EventBus
10
+ from sqlalchemy import func, select
11
+ from sqlalchemy.ext.asyncio import AsyncSession
12
+
13
+ from background_tasks.constants import RETRYABLE_STATUSES, TaskStatus
14
+ from background_tasks.contracts.events import TaskRetried
15
+ from background_tasks.contracts.schemas import (
16
+ TaskExecutionDetail,
17
+ TaskExecutionListItem,
18
+ TaskExecutionListResponse,
19
+ )
20
+ from background_tasks.models import TaskExecution
21
+
22
+ if TYPE_CHECKING:
23
+ from celery import Celery
24
+
25
+
26
class BackgroundTaskService:
    """List, fetch, and retry task executions.

    The service treats the DB as the system of record and the Celery app
    as pure transport — ``retry`` loads the original row, enqueues a new
    Celery task with the same args/kwargs, and inserts a fresh
    ``TaskExecution`` linked to the original via ``retried_from_id``.
    """

    def __init__(
        self,
        db: AsyncSession,
        celery: Celery,
        event_bus: EventBus,
    ) -> None:
        """Store the collaborators; no I/O happens here.

        Args:
            db: Async session used for every query and insert.
            celery: Celery app used only to send tasks.
            event_bus: Bus on which ``TaskRetried`` is published.
        """
        self.db = db
        self.celery = celery
        self.event_bus = event_bus

    async def list(
        self,
        *,
        status: TaskStatus | None = None,
        task_name: str | None = None,
        page: int = 1,
        per_page: int = 20,
    ) -> TaskExecutionListResponse:
        """Return a paginated listing, newest-first, with optional filters.

        Args:
            status: When given, only rows with exactly this status.
            task_name: When non-empty, case-insensitive substring match on
                the task name.
            page: 1-based page number; values below 1 are clamped to 1.
            per_page: Page size; clamped to the range 1..200.

        Returns:
            The page of items plus the total number of rows matching the
            filters, echoing back the effective paging and filter values.
        """
        page = max(page, 1)
        per_page = max(1, min(per_page, 200))

        # Collect the filters once so the page query and the fallback count
        # query below are guaranteed to agree.
        filters = []
        if status is not None:
            filters.append(TaskExecution.status == status)
        if task_name:
            # NOTE(review): ``%`` and ``_`` inside ``task_name`` are not
            # escaped, so they act as LIKE wildcards — confirm intended.
            filters.append(TaskExecution.task_name.ilike(f"%{task_name}%"))

        # A window-function total lets us fetch the page and the count in a
        # single round trip. SQLAlchemy's ``AsyncSession`` serialises calls on
        # one connection, so ``asyncio.gather`` wouldn't help here.
        total_col = func.count().over().label("_total")
        query = select(TaskExecution, total_col)
        for clause in filters:
            query = query.where(clause)
        query = (
            query.order_by(TaskExecution.queued_at.desc().nulls_last())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )

        result = (await self.db.execute(query)).all()
        items = [TaskExecutionListItem.model_validate(row[0]) for row in result]

        if result:
            total = int(result[0][1])
        elif page > 1:
            # The window total rides on the returned rows, so a page past the
            # end yields no rows and therefore no total. Count explicitly so
            # the caller still learns how many rows (and pages) exist.
            count_query = select(func.count()).select_from(TaskExecution)
            for clause in filters:
                count_query = count_query.where(clause)
            total = int((await self.db.execute(count_query)).scalar_one())
        else:
            # First page is empty: nothing matches the filters.
            total = 0

        return TaskExecutionListResponse(
            items=items,
            total=total,
            page=page,
            per_page=per_page,
            status=status,
            task_name=task_name,
        )

    async def get(self, execution_id: uuid.UUID) -> TaskExecutionDetail | None:
        """Return the detail view for ``execution_id``, or ``None`` if absent."""
        row = await self.db.get(TaskExecution, execution_id)
        if row is None:
            return None
        return TaskExecutionDetail.model_validate(row)

    async def retry(self, execution_id: uuid.UUID) -> TaskExecutionDetail:
        """Re-enqueue a failed or stuck task.

        The original row is immutable; we insert a new ``TaskExecution`` row
        carrying ``retried_from_id`` so the detail page can show the chain.

        Args:
            execution_id: Primary key of the execution to retry.

        Returns:
            The detail view of the newly inserted execution row.

        Raises:
            HTTPException: 404 when the row does not exist; 409 when its
                status is not in ``RETRYABLE_STATUSES``.
        """
        row = await self.db.get(TaskExecution, execution_id)
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Task execution not found",
            )

        if row.status not in RETRYABLE_STATUSES:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=(
                    f"Task execution status is {str(row.status)!r}; "
                    "only failed or stuck tasks can be retried."
                ),
            )

        # NOTE(review): the task is sent before the new row is flushed; if
        # the flush fails the broker already holds the message.
        async_result = self.celery.send_task(
            row.task_name,
            args=list(row.args or []),
            kwargs=dict(row.kwargs or {}),
            queue=row.queue,
        )

        # Fresh copies of args/kwargs so the new row never shares mutable
        # state with the original row.
        new_row = TaskExecution(
            celery_task_id=async_result.id,
            task_name=row.task_name,
            status=TaskStatus.PENDING,
            queue=row.queue,
            args=list(row.args or []),
            kwargs=dict(row.kwargs or {}),
            retried_from_id=row.id,
        )
        self.db.add(new_row)
        # Flush + refresh so DB-assigned columns (e.g. ``id``) are populated
        # before they are read below.
        await self.db.flush()
        await self.db.refresh(new_row)

        await self.event_bus.publish(
            TaskRetried(
                original_id=row.id,
                new_id=new_row.id,
                task_name=row.task_name,
            )
        )

        return TaskExecutionDetail.model_validate(new_row)
@@ -0,0 +1,25 @@
1
+ """Module-scoped state container for background_tasks.
2
+
3
+ Stored as ``app.state.background_tasks`` by
4
+ :meth:`BackgroundTasksModule.register_settings`. Not frozen — ``on_startup``
5
+ populates :attr:`celery` once the Celery app is built. Convention: set once
6
+ during boot, treat as read-only after.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import TYPE_CHECKING
13
+
14
+ if TYPE_CHECKING:
15
+ from celery import Celery
16
+
17
+ from background_tasks.settings import BackgroundTasksSettings
18
+
19
+
20
@dataclass
class BackgroundTasksServices:
    """Mutable holder for the background_tasks singletons.

    ``settings`` must be supplied at construction; ``celery`` starts out as
    ``None`` and can be assigned afterwards because the dataclass is not
    frozen.
    """

    # Module settings object; required.
    settings: BackgroundTasksSettings
    # Celery application; ``None`` until one is assigned.
    celery: Celery | None = None
@@ -0,0 +1,83 @@
1
+ """BackgroundTasks module settings (DB-backed).
2
+
3
+ Construction no longer reads ``SM_BG_TASKS_*`` environment variables. Values
4
+ come from pydantic defaults at boot, then get hydrated from the DB by the
5
+ hosting lifespan before module ``on_startup`` runs. Runtime changes go
6
+ through ``settings.reload.apply_changes_and_reload``.
7
+
8
+ The one remaining env read is ``SM_ENVIRONMENT``, consulted by the
9
+ ``@model_validator`` to refuse a localhost broker in production — that's a
10
+ host-level setting, not a background_tasks-module field.
11
+
12
+ The Celery-critical fields (``broker_url``, ``result_backend``,
13
+ ``task_default_queue``) are marked ``requires_restart=True`` via
14
+ ``json_schema_extra`` because workers read these once at process start, so
15
+ DB changes can't be hot-reloaded: the admin UI should surface that bumping
16
+ these values requires a worker restart.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import os
22
+
23
+ from pydantic import Field, model_validator
24
+ from pydantic_settings import BaseSettings, SettingsConfigDict
25
+ from simple_module_core.environments import NON_PROD_ENVIRONMENTS
26
+
27
+ from background_tasks.constants import (
28
+ DEFAULT_BROKER_URL,
29
+ DEFAULT_MAX_RETRIES,
30
+ DEFAULT_PURGE_INTERVAL_SECONDS,
31
+ DEFAULT_QUEUE,
32
+ DEFAULT_RESULT_BACKEND,
33
+ DEFAULT_RETENTION_DAYS,
34
+ DEFAULT_STUCK_AFTER_SECONDS,
35
+ DEFAULT_STUCK_SWEEP_INTERVAL_SECONDS,
36
+ )
37
+
38
# ``json_schema_extra`` payload shared by the Celery connection fields.
_CELERY_RESTART = {"requires_restart": True, "group": "Celery"}

# Lower-case substrings that mark a URL as pointing at the local machine.
# ``[::1]`` is the bracketed IPv6 loopback form used inside URLs.
_LOOPBACK_MARKERS = ("localhost", "127.0.0.1", "[::1]")


class BackgroundTasksSettings(BaseSettings):
    """Configuration for the Celery + Redis task runner."""

    model_config = SettingsConfigDict(extra="ignore")

    broker_url: str = Field(default=DEFAULT_BROKER_URL, json_schema_extra=_CELERY_RESTART)
    result_backend: str = Field(default=DEFAULT_RESULT_BACKEND, json_schema_extra=_CELERY_RESTART)
    task_default_queue: str = Field(default=DEFAULT_QUEUE, json_schema_extra=_CELERY_RESTART)

    # A task that has been ``running`` longer than this without a heartbeat is
    # flipped to ``stuck`` by the beat sweep. 5 min is long enough to cover
    # normal slow jobs while still surfacing wedged workers within one UI
    # refresh.
    stuck_after_seconds: int = DEFAULT_STUCK_AFTER_SECONDS
    stuck_sweep_interval_seconds: int = DEFAULT_STUCK_SWEEP_INTERVAL_SECONDS
    purge_interval_seconds: int = DEFAULT_PURGE_INTERVAL_SECONDS

    retention_days: int = DEFAULT_RETENTION_DAYS
    max_retries: int = DEFAULT_MAX_RETRIES

    @model_validator(mode="after")
    def _forbid_localhost_broker_in_production(self) -> BackgroundTasksSettings:
        """Fail boot if production is still pointed at the dev default broker.

        A localhost broker in prod means the web container is talking to its
        own 6379 instead of the shared Redis service — tasks would silently
        queue to a broker no worker reads.

        The check is skipped when ``SM_ENVIRONMENT`` (default
        ``"development"``) is one of ``NON_PROD_ENVIRONMENTS``.

        Returns:
            ``self``, unchanged, when validation passes.

        Raises:
            ValueError: When ``broker_url`` and/or ``result_backend`` contain
                a loopback marker outside a non-prod environment.
        """
        env = os.environ.get("SM_ENVIRONMENT", "development")
        if env in NON_PROD_ENVIRONMENTS:
            return self

        # Host names are case-insensitive, so compare on the lower-cased URL.
        bad = [
            name
            for name in ("broker_url", "result_backend")
            if any(marker in getattr(self, name).lower() for marker in _LOOPBACK_MARKERS)
        ]
        if bad:
            names = ", ".join(bad)
            raise ValueError(
                f"{names} must not point at localhost when SM_ENVIRONMENT={env!r}. "
                "Set these to the Redis service host (e.g. redis://redis:6379/0)."
            )
        return self
@@ -0,0 +1,286 @@
1
+ """Celery signal handlers that keep the ``TaskExecution`` table in sync.
2
+
3
+ Each signal writes one row: publish creates ``pending``, prerun flips to
4
+ ``running`` with a heartbeat, postrun/success/failure/retry/revoked update
5
+ the terminal columns. We update an existing row (matched by
6
+ ``celery_task_id``) so a single task's lifecycle stays on one row instead
7
+ of spawning a new row per signal.
8
+
9
+ Signals are sync — see :mod:`.sync_db` for why we maintain a separate sync
10
+ engine, and :mod:`._signal_support` for the shared helpers.
11
+
12
+ ``TaskFailed`` is also dispatched from :func:`on_task_failure` when an event
13
+ bus has been bound via :func:`bind_event_bus` — typically from the web
14
+ process's ``on_startup``. In a standalone Celery worker (separate process)
15
+ no bus is bound and the publish becomes a no-op; subscribers that need
16
+ cross-process notification should consume Celery events or poll the DB
17
+ directly.
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import asyncio
23
+ import logging
24
+ from collections.abc import Callable
25
+ from typing import Any
26
+
27
+ from celery import signals
28
+ from simple_module_core.events import EventBus
29
+
30
+ from background_tasks._signal_support import (
31
+ coerce_args_kwargs,
32
+ jsonable_result,
33
+ now_utc,
34
+ render_traceback,
35
+ task_id_from,
36
+ task_name_of,
37
+ upsert_by_celery_id,
38
+ )
39
+ from background_tasks.constants import DEFAULT_QUEUE, TaskStatus
40
+ from background_tasks.contracts.events import TaskFailed
41
+ from background_tasks.models import TaskExecution
42
+ from background_tasks.sync_db import sync_session
43
+
44
# Module-level logger used by the helpers and signal handlers in this module.
logger = logging.getLogger(__name__)


# Event bus and the event loop it runs on. Both stay ``None`` until
# ``bind_event_bus`` assigns them; ``unbind_event_bus`` resets them.
_bus: EventBus | None = None
_loop: asyncio.AbstractEventLoop | None = None
49
+
50
+
51
def bind_event_bus(bus: EventBus, loop: asyncio.AbstractEventLoop) -> None:
    """Remember ``bus`` and the loop it runs on for later signal publishing.

    ``_publish_from_signal`` schedules ``bus.publish(...)`` onto ``loop`` with
    ``asyncio.run_coroutine_threadsafe``, so subscribers execute on that loop
    no matter which thread the Celery signal fired on.

    Args:
        bus: Event bus whose ``publish`` coroutine will be scheduled.
        loop: Event loop on which the publish coroutine should run.
    """
    global _bus, _loop
    _bus, _loop = bus, loop
61
+
62
+
63
def unbind_event_bus() -> None:
    """Forget the bound bus and loop (both module globals become ``None``).

    Intended to be called from ``on_shutdown`` so tests stay isolated.
    """
    global _bus, _loop
    _bus = _loop = None
68
+
69
+
70
def _publish_from_signal(event: Any) -> None:
    """Dispatch ``event`` onto the bound bus without blocking the signal thread.

    Does nothing when no bus/loop pair is bound. When the bound loop can no
    longer accept work, the event is dropped with a debug log.

    Args:
        event: Event object handed to ``bus.publish``.
    """
    # Snapshot both globals once: ``unbind_event_bus`` may reset them from
    # another thread between the ``None`` check and the actual use.
    bus, loop = _bus, _loop
    if bus is None or loop is None:
        return

    coro = bus.publish(event)
    try:
        future = asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # Loop has stopped (shutdown race). The DB row is already written.
        # Close the coroutine explicitly; otherwise it is garbage-collected
        # un-awaited and Python emits a "never awaited" RuntimeWarning.
        coro.close()
        logger.debug("Event bus loop is not running; skipping %s", type(event).__name__)
        return
    # Surface subscriber exceptions — run_coroutine_threadsafe otherwise only
    # logs them when the Future is GC'd, which happens far from the failure.
    future.add_done_callback(_log_publish_failure)
83
+
84
+
85
def _log_publish_failure(future: asyncio.Future[Any]) -> None:
    """Done-callback that logs the exception, if any, raised by a publish.

    Cancelled futures and futures that completed without an exception are
    ignored.

    Args:
        future: The completed future returned when the publish was scheduled.
    """
    # ``exception()`` raises on a cancelled future, so check that first.
    error = None if future.cancelled() else future.exception()
    if error is None:
        return
    logger.error("Event publish raised: %s", error, exc_info=error)
91
+
92
+
93
def _apply(
    handler: str,
    *,
    celery_task_id: str | None,
    defaults: dict[str, Any],
    after: Callable[[TaskExecution], None] | None = None,
) -> TaskExecution | None:
    """Upsert one ``TaskExecution`` row inside a short-lived sync session.

    Any exception raised while the session is open (including by ``after``)
    is logged and swallowed so a signal handler never raises.

    Args:
        handler: Name of the calling signal handler; used only in the log
            message on failure.
        celery_task_id: Key passed to ``upsert_by_celery_id``.
        defaults: Column values passed to ``upsert_by_celery_id``.
        after: Optional hook invoked with the upserted row while the session
            is still open.

    Returns:
        The upserted row, or ``None`` if anything raised — lets callers read
        DB-assigned fields like ``id`` without a second round trip.
    """
    try:
        with sync_session() as db:
            record = upsert_by_celery_id(
                db,
                celery_task_id=celery_task_id,
                defaults=defaults,
            )
            if after is not None:
                after(record)
    except Exception:
        logger.exception("%s failed for task_id=%s", handler, celery_task_id)
        return None
    return record
114
+
115
+
116
+ # ── Enqueue ─────────────────────────────────────────────────────
117
+
118
+
119
@signals.before_task_publish.connect
def on_task_publish(
    sender: str | None = None,
    headers: dict[str, Any] | None = None,
    body: Any = None,
    routing_key: str | None = None,
    **_kwargs: Any,
) -> None:
    """Record a ``pending`` row the moment a task is pushed onto the broker.

    Args:
        sender: Task name supplied by the signal, if any.
        headers: Message headers; ``id`` and ``task`` are read from it.
        body: Message body; only a list/tuple of at least two items is
            unpacked into args and kwargs.
        routing_key: Stored as the row's queue; ``DEFAULT_QUEUE`` when empty.
    """
    meta = headers or {}
    task_id = meta.get("id")
    task_name = sender or meta.get("task") or "unknown"

    # ``body`` on the publish signal is ``(args, kwargs, options)`` for the
    # standard Celery protocol; anything else is treated as "no arguments".
    if isinstance(body, list | tuple) and len(body) >= 2:
        raw_args, raw_kwargs = body[0], body[1]
    else:
        raw_args, raw_kwargs = [], {}
    args, kwargs = coerce_args_kwargs(raw_args, raw_kwargs)

    row_values = {
        "task_name": task_name,
        "status": TaskStatus.PENDING,
        "queue": routing_key or DEFAULT_QUEUE,
        "args": args,
        "kwargs": kwargs,
        "queued_at": now_utc(),
    }
    _apply("on_task_publish", celery_task_id=task_id, defaults=row_values)
150
+
151
+
152
+ # ── Execution lifecycle ─────────────────────────────────────────
153
+
154
+
155
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any = None,
    task_id: str | None = None,
    task: Any = None,
    args: Any = None,
    kwargs: Any = None,
    **_k: Any,
) -> None:
    """Mark the row ``running`` and stamp its start time and first heartbeat.

    ``started_at`` and ``heartbeat_at`` receive the same timestamp.
    """
    clean_args, clean_kwargs = coerce_args_kwargs(args, kwargs)
    stamp = now_utc()
    row_values = {
        "task_name": task_name_of(sender, task),
        "status": TaskStatus.RUNNING,
        "args": clean_args,
        "kwargs": clean_kwargs,
        "started_at": stamp,
        "heartbeat_at": stamp,
    }
    _apply("on_task_prerun", celery_task_id=task_id, defaults=row_values)
179
+
180
+
181
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any = None,
    task_id: str | None = None,
    task: Any = None,
    **_k: Any,
) -> None:
    """Bump the heartbeat once the task body has finished.

    This handler never touches ``status``: the terminal status is written by
    ``task_success`` / ``task_failure`` / ``task_retry``, which fire *before*
    postrun. Only the heartbeat is refreshed so the sweep doesn't immediately
    flip a just-finished row.
    """
    row_values = {
        "task_name": task_name_of(sender, task),
        "heartbeat_at": now_utc(),
    }
    _apply("on_task_postrun", celery_task_id=task_id, defaults=row_values)
199
+
200
+
201
@signals.task_success.connect
def on_task_success(sender: Any = None, result: Any = None, **_k: Any) -> None:
    """Mark the row ``success``, store the result, and clear failure details.

    The Celery task id is obtained via ``task_id_from(sender=...)``; the
    ``traceback`` and ``exception_type`` columns are reset to ``None``.
    """
    celery_id = task_id_from(sender=sender)
    row_values = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.SUCCESS,
        "result": jsonable_result(result),
        "finished_at": now_utc(),
        "traceback": None,
        "exception_type": None,
    }
    _apply("on_task_success", celery_task_id=celery_id, defaults=row_values)
215
+
216
+
217
@signals.task_failure.connect
def on_task_failure(
    sender: Any = None,
    task_id: str | None = None,
    exception: BaseException | None = None,
    einfo: Any = None,
    **_k: Any,
) -> None:
    """Mark the row ``failed`` and, if the write succeeded, emit ``TaskFailed``.

    The event is only published when ``_apply`` returned a row, because the
    event carries that row's ``id``.
    """
    task_name = task_name_of(sender)
    exception_type = None if exception is None else type(exception).__name__

    row_values = {
        "task_name": task_name,
        "status": TaskStatus.FAILED,
        "traceback": render_traceback(einfo, exception),
        "exception_type": exception_type,
        "finished_at": now_utc(),
    }
    row = _apply("on_task_failure", celery_task_id=task_id, defaults=row_values)
    if row is None:
        # The DB write failed (already logged); there is no row id to report.
        return

    failed_event = TaskFailed(
        task_execution_id=row.id,
        task_name=task_name,
        exception_type=exception_type,
    )
    _publish_from_signal(failed_event)
248
+
249
+
250
@signals.task_retry.connect
def on_task_retry(
    sender: Any = None,
    request: Any = None,
    reason: Any = None,
    **_k: Any,
) -> None:
    """Mark the row ``retrying`` and record the retry count and heartbeat.

    When ``reason`` is provided, the row's ``traceback`` is replaced with
    ``"retry: <repr(reason)>"``; otherwise the existing traceback is kept.
    """

    def _record_reason(row: TaskExecution) -> None:
        # Without a reason, leave whatever traceback was captured earlier.
        if reason is None:
            return
        row.traceback = f"retry: {reason!r}"

    celery_id = task_id_from(request=request)
    row_values = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.RETRYING,
        # ``retries`` may be missing or ``None`` on the request; store 0 then.
        "retries": int(getattr(request, "retries", 0) or 0),
        "heartbeat_at": now_utc(),
    }
    _apply(
        "on_task_retry",
        celery_task_id=celery_id,
        defaults=row_values,
        after=_record_reason,
    )
274
+
275
+
276
@signals.task_revoked.connect
def on_task_revoked(sender: Any = None, request: Any = None, **_k: Any) -> None:
    """Mark the row ``revoked`` and stamp ``finished_at``.

    The Celery task id is obtained via ``task_id_from(request=...)``.
    """
    celery_id = task_id_from(request=request)
    row_values = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.REVOKED,
        "finished_at": now_utc(),
    }
    _apply("on_task_revoked", celery_task_id=celery_id, defaults=row_values)
@@ -0,0 +1,82 @@
1
+ """Sync SQLAlchemy session factory for Celery signal handlers.
2
+
3
+ Celery signals are called synchronously from inside the worker / web-process
4
+ hot path. Building an event loop just to await a query is both ugly and
5
+ deadlock-prone (web-process signals can fire from inside an already-running
6
+ loop). We instead maintain a second, sync engine pointed at the same DB URL
7
+ and borrow a short-lived session from it whenever a signal fires.
8
+
9
+ The engine is lazily built on first use and process-global. It's cheap —
10
+ SQLAlchemy's connection pool is per-process so the signal path reuses
11
+ connections after the first signal.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ import os
18
+ from collections.abc import Iterator
19
+ from contextlib import contextmanager
20
+
21
+ from sqlalchemy import create_engine
22
+ from sqlalchemy.engine import Engine
23
+ from sqlalchemy.orm import Session, sessionmaker
24
+
25
# Module-level logger for the sync DB helpers.
logger = logging.getLogger(__name__)

# Process-global engine and session factory. Both start as ``None``;
# ``get_sync_session_factory`` fills them in on first use and
# ``dispose_sync_engine`` resets them.
_engine: Engine | None = None
_session_factory: sessionmaker[Session] | None = None
29
+
30
+
31
def _sync_url(async_url: str) -> str:
    """Rewrite an async SQLAlchemy URL so it names a sync driver.

    ``+aiosqlite`` is dropped and ``+asyncpg`` becomes ``+psycopg2``; any
    other URL is returned unchanged. Mirrors ``host/migrations/env.py`` so
    signals use the same URL shape Alembic does.

    Args:
        async_url: Database URL, possibly naming an async driver.

    Returns:
        The URL with the async driver markers replaced.
    """
    driver_swaps = (
        ("+aiosqlite", ""),
        ("+asyncpg", "+psycopg2"),
    )
    sync_url = async_url
    for async_marker, sync_marker in driver_swaps:
        sync_url = sync_url.replace(async_marker, sync_marker)
    return sync_url
38
+
39
+
40
def _build_engine() -> Engine:
    """Create the sync engine from ``SM_DATABASE_URL``.

    The URL (default ``sqlite:///./app.db``) is first converted to its sync
    driver form with ``_sync_url``.

    Returns:
        A new engine with ``pool_pre_ping`` enabled and, where the pool
        implementation accepts it, a small bounded pool.
    """
    url = os.environ.get("SM_DATABASE_URL", "sqlite:///./app.db")
    sync_url = _sync_url(url)
    try:
        # Small pool — signals fire sequentially per worker process.
        return create_engine(sync_url, pool_pre_ping=True, pool_size=2, max_overflow=3)
    except TypeError:
        # ``create_engine`` rejects ``pool_size`` / ``max_overflow`` with a
        # TypeError when the URL's default pool class does not take them
        # (as can happen for SQLite URLs). Fall back to that pool's defaults
        # rather than failing every signal handler.
        return create_engine(sync_url, pool_pre_ping=True)
45
+
46
+
47
def get_sync_session_factory() -> sessionmaker[Session]:
    """Return the process-global sync session factory, creating it if needed.

    On first call this builds the engine via ``_build_engine`` and binds a
    ``sessionmaker`` to it with ``expire_on_commit=False``; later calls
    return the cached factory.
    """
    global _engine, _session_factory
    if _session_factory is not None:
        return _session_factory

    _engine = _build_engine()
    _session_factory = sessionmaker(bind=_engine, expire_on_commit=False)
    return _session_factory
54
+
55
+
56
def dispose_sync_engine() -> None:
    """Dispose the cached engine, if any, and clear both module caches.

    Called from :meth:`BackgroundTasksModule.on_shutdown` so lifespan
    restarts within one process (test runners, uvicorn dev reload) don't
    accumulate engines against the old DB URL. Safe to call when nothing
    has been built yet.
    """
    global _engine, _session_factory
    if _engine is not None:
        _engine.dispose()
    # Reset both so the next ``get_sync_session_factory`` call rebuilds them.
    _engine = _session_factory = None
68
+
69
+
70
@contextmanager
def sync_session() -> Iterator[Session]:
    """Yield a short-lived sync session with commit/rollback handling.

    The session is committed when the ``with`` body finishes normally. If
    the body or the commit raises an ``Exception``, the session is rolled
    back and the exception re-raised. The session is closed in every case.
    """
    db = get_sync_session_factory()()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()