simple-module-background-tasks 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- background_tasks/__init__.py +1 -0
- background_tasks/_signal_support.py +105 -0
- background_tasks/celery_app.py +95 -0
- background_tasks/constants.py +75 -0
- background_tasks/contracts/__init__.py +1 -0
- background_tasks/contracts/events.py +26 -0
- background_tasks/contracts/schemas.py +66 -0
- background_tasks/deps.py +19 -0
- background_tasks/endpoints/__init__.py +0 -0
- background_tasks/endpoints/api_admin.py +62 -0
- background_tasks/endpoints/views.py +71 -0
- background_tasks/locales/en.json +57 -0
- background_tasks/models.py +86 -0
- background_tasks/module.py +121 -0
- background_tasks/package.json +16 -0
- background_tasks/pages/Detail.tsx +180 -0
- background_tasks/pages/Index.tsx +181 -0
- background_tasks/pages/components/ExecutionRow.tsx +79 -0
- background_tasks/pages/components/RetryConfirmDialog.tsx +38 -0
- background_tasks/pages/constants.ts +49 -0
- background_tasks/pages/retry.ts +42 -0
- background_tasks/py.typed +0 -0
- background_tasks/service.py +138 -0
- background_tasks/services.py +25 -0
- background_tasks/settings.py +83 -0
- background_tasks/signals.py +286 -0
- background_tasks/sync_db.py +82 -0
- background_tasks/tasks.py +105 -0
- simple_module_background_tasks-0.0.1.dist-info/METADATA +92 -0
- simple_module_background_tasks-0.0.1.dist-info/RECORD +33 -0
- simple_module_background_tasks-0.0.1.dist-info/WHEEL +4 -0
- simple_module_background_tasks-0.0.1.dist-info/entry_points.txt +2 -0
- simple_module_background_tasks-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
// Shared retry plumbing for the Index and Detail pages. Both POST to the
|
|
2
|
+
// same endpoint; they differ only in what they do with the result (reload
|
|
3
|
+
// the list vs. navigate to the new row), so the fetch + toast live here.
|
|
4
|
+
|
|
5
|
+
import { toast } from 'sonner';
|
|
6
|
+
import { API_BASE, type TaskStatus } from './constants';
|
|
7
|
+
|
|
8
|
+
/**
 * One task-execution record as exchanged with the background-tasks API.
 * `retryExecution` casts the retry endpoint's JSON response to this shape.
 */
export interface Execution {
  /** Identifier of the execution; `retryExecution` interpolates it into the retry URL. */
  id: string;
  celery_task_id: string | null;
  /** Task name; shown in the success toast after a retry. */
  task_name: string;
  status: TaskStatus;
  queue: string;
  retries: number;
  worker: string | null;
  // Timestamps arrive as strings — presumably ISO-8601; confirm against the API schema.
  queued_at: string | null;
  started_at: string | null;
  finished_at: string | null;
  exception_type: string | null;
  /** Presumably the id of the execution this one was retried from, else null — confirm against the API. */
  retried_from_id: string | null;
}
|
|
22
|
+
|
|
23
|
+
/**
 * POST a retry request for `execution` and report the outcome with a toast.
 *
 * @param execution - Only `id` (used in the URL) and `task_name` (used in the
 *   success toast) are read.
 * @returns The newly created execution parsed from the response body, or
 *   `null` when the request failed. Failures are surfaced as an error toast
 *   and are never thrown to the caller.
 */
export async function retryExecution(execution: {
  id: string;
  task_name: string;
}): Promise<Execution | null> {
  try {
    // Encode the id so a value containing `/`, `?` or `#` cannot alter the path.
    const url = `${API_BASE}/executions/${encodeURIComponent(execution.id)}/retry`;
    const res = await fetch(url, { method: 'POST' });
    if (!res.ok) {
      // The error body may be absent, non-JSON, or JSON `null`.
      const body: unknown = await res.json().catch(() => null);
      const detail =
        body !== null && typeof body === 'object'
          ? (body as { detail?: unknown }).detail
          : undefined;
      // `detail` is not always a string (e.g. FastAPI validation errors send
      // a list); only show it verbatim when it is a non-empty string so the
      // toast never reads "[object Object]".
      throw new Error(
        typeof detail === 'string' && detail ? detail : `HTTP ${res.status}`,
      );
    }
    const created = (await res.json()) as Execution;
    toast.success(`Task "${execution.task_name}" re-enqueued`);
    return created;
  } catch (err) {
    toast.error(err instanceof Error ? err.message : 'Failed to retry task');
    return null;
  }
}
|
|
File without changes
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""BackgroundTaskService — admin listing, detail, and retry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import uuid
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
from fastapi import HTTPException, status
|
|
9
|
+
from simple_module_core.events import EventBus
|
|
10
|
+
from sqlalchemy import func, select
|
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
12
|
+
|
|
13
|
+
from background_tasks.constants import RETRYABLE_STATUSES, TaskStatus
|
|
14
|
+
from background_tasks.contracts.events import TaskRetried
|
|
15
|
+
from background_tasks.contracts.schemas import (
|
|
16
|
+
TaskExecutionDetail,
|
|
17
|
+
TaskExecutionListItem,
|
|
18
|
+
TaskExecutionListResponse,
|
|
19
|
+
)
|
|
20
|
+
from background_tasks.models import TaskExecution
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from celery import Celery
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BackgroundTaskService:
    """List, fetch, and retry task executions.

    The database is the system of record and the Celery app is used purely
    as transport: ``retry`` loads the original row, enqueues a new Celery
    task with the same args/kwargs, and inserts a fresh ``TaskExecution``
    linked to the original via ``retried_from_id``.
    """

    def __init__(
        self,
        db: AsyncSession,
        celery: Celery,
        event_bus: EventBus,
    ) -> None:
        """Store the session, Celery app and event bus; performs no I/O."""
        self.db = db
        self.celery = celery
        self.event_bus = event_bus

    async def list(
        self,
        *,
        status: TaskStatus | None = None,
        task_name: str | None = None,
        page: int = 1,
        per_page: int = 20,
    ) -> TaskExecutionListResponse:
        """Return a paginated listing, newest-first, with optional filters.

        Args:
            status: When given, only rows with exactly this status are listed.
            task_name: When non-empty, a case-insensitive *literal* substring
                that the task name must contain.
            page: 1-based page number; values below 1 are clamped to 1.
            per_page: Page size; clamped to the range 1..200.

        Returns:
            The page of items together with the total number of rows that
            match the filters and the (clamped) paging/filter values used.
        """
        page = max(page, 1)
        per_page = max(1, min(per_page, 200))

        filters = []
        if status is not None:
            filters.append(TaskExecution.status == status)
        if task_name:
            # Escape LIKE metacharacters so user input such as "100%" or
            # "send_email" is matched literally instead of as a wildcard.
            escaped = task_name.replace("/", "//").replace("%", "/%").replace("_", "/_")
            filters.append(TaskExecution.task_name.ilike(f"%{escaped}%", escape="/"))

        # A window-function total lets us fetch the page and the count in a
        # single round trip. SQLAlchemy's ``AsyncSession`` serialises calls on
        # one connection, so ``asyncio.gather`` wouldn't help here.
        total_col = func.count().over().label("_total")
        query = select(TaskExecution, total_col)
        if filters:
            query = query.where(*filters)
        query = (
            query.order_by(TaskExecution.queued_at.desc().nulls_last())
            .offset((page - 1) * per_page)
            .limit(per_page)
        )

        result = (await self.db.execute(query)).all()
        items = [TaskExecutionListItem.model_validate(row[0]) for row in result]

        if result:
            total = int(result[0][1])
        elif page > 1:
            # The window total rides on the returned rows, so a page past the
            # end carries no count. Ask for it explicitly instead of
            # reporting 0 while matching rows exist on earlier pages.
            count_query = select(func.count()).select_from(TaskExecution)
            if filters:
                count_query = count_query.where(*filters)
            total = int((await self.db.execute(count_query)).scalar_one())
        else:
            total = 0

        return TaskExecutionListResponse(
            items=items,
            total=total,
            page=page,
            per_page=per_page,
            status=status,
            task_name=task_name,
        )

    async def get(self, execution_id: uuid.UUID) -> TaskExecutionDetail | None:
        """Return the detail view of one execution, or ``None`` if it does not exist."""
        row = await self.db.get(TaskExecution, execution_id)
        if row is None:
            return None
        return TaskExecutionDetail.model_validate(row)

    async def retry(self, execution_id: uuid.UUID) -> TaskExecutionDetail:
        """Re-enqueue a failed or stuck task.

        The original row is immutable; we insert a new ``TaskExecution`` row
        carrying ``retried_from_id`` so the detail page can show the chain.

        Args:
            execution_id: Primary key of the execution to retry.

        Returns:
            The detail view of the newly inserted execution row.

        Raises:
            HTTPException: 404 when the row does not exist; 409 when its
                status is not in ``RETRYABLE_STATUSES``.
        """
        row = await self.db.get(TaskExecution, execution_id)
        if row is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Task execution not found",
            )

        if row.status not in RETRYABLE_STATUSES:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=(
                    f"Task execution status is {str(row.status)!r}; "
                    "only failed or stuck tasks can be retried."
                ),
            )

        # NOTE(review): the message is published before the new row is
        # flushed, and the package's ``before_task_publish`` handler may also
        # write a row for the same ``celery_task_id`` when it is connected in
        # this process — confirm the two writes cannot collide.
        async_result = self.celery.send_task(
            row.task_name,
            args=list(row.args or []),
            kwargs=dict(row.kwargs or {}),
            queue=row.queue,
        )

        new_row = TaskExecution(
            celery_task_id=async_result.id,
            task_name=row.task_name,
            status=TaskStatus.PENDING,
            queue=row.queue,
            args=list(row.args or []),
            kwargs=dict(row.kwargs or {}),
            retried_from_id=row.id,
        )
        self.db.add(new_row)
        # Flush + refresh so DB-assigned columns (e.g. ``id``) are populated
        # before they are read below.
        await self.db.flush()
        await self.db.refresh(new_row)

        await self.event_bus.publish(
            TaskRetried(
                original_id=row.id,
                new_id=new_row.id,
                task_name=row.task_name,
            )
        )

        return TaskExecutionDetail.model_validate(new_row)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Module-scoped state container for background_tasks.
|
|
2
|
+
|
|
3
|
+
Stored as ``app.state.background_tasks`` by
|
|
4
|
+
:meth:`BackgroundTasksModule.register_settings`. Not frozen — ``on_startup``
|
|
5
|
+
populates :attr:`celery` once the Celery app is built. Convention: set once
|
|
6
|
+
during boot, treat as read-only after.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import TYPE_CHECKING
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from celery import Celery
|
|
16
|
+
|
|
17
|
+
from background_tasks.settings import BackgroundTasksSettings
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=False)
class BackgroundTasksServices:
    """Mutable holder for the background_tasks singletons.

    Deliberately not frozen: ``celery`` starts out as ``None`` and is
    assigned after construction.
    """

    # Module settings; the only field required at construction time.
    settings: BackgroundTasksSettings
    # Celery application; ``None`` until it has been built and assigned.
    celery: Celery | None = None
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""BackgroundTasks module settings (DB-backed).
|
|
2
|
+
|
|
3
|
+
Construction no longer reads ``SM_BG_TASKS_*`` environment variables. Values
|
|
4
|
+
come from pydantic defaults at boot, then get hydrated from the DB by the
|
|
5
|
+
hosting lifespan before module ``on_startup`` runs. Runtime changes go
|
|
6
|
+
through ``settings.reload.apply_changes_and_reload``.
|
|
7
|
+
|
|
8
|
+
The one remaining env read is ``SM_ENVIRONMENT``, consulted by the
|
|
9
|
+
``@model_validator`` to refuse a localhost broker in production — that's a
|
|
10
|
+
host-level setting, not a background_tasks-module field.
|
|
11
|
+
|
|
12
|
+
The Celery-critical fields (``broker_url``, ``result_backend``,
|
|
13
|
+
``task_default_queue``) are marked ``requires_restart=True`` via
|
|
14
|
+
``json_schema_extra`` because workers read these once at process start, so
|
|
15
|
+
DB changes can't be hot-reloaded: the admin UI should surface that bumping
|
|
16
|
+
these values requires a worker restart.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import os
|
|
22
|
+
|
|
23
|
+
from pydantic import Field, model_validator
|
|
24
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
25
|
+
from simple_module_core.environments import NON_PROD_ENVIRONMENTS
|
|
26
|
+
|
|
27
|
+
from background_tasks.constants import (
|
|
28
|
+
DEFAULT_BROKER_URL,
|
|
29
|
+
DEFAULT_MAX_RETRIES,
|
|
30
|
+
DEFAULT_PURGE_INTERVAL_SECONDS,
|
|
31
|
+
DEFAULT_QUEUE,
|
|
32
|
+
DEFAULT_RESULT_BACKEND,
|
|
33
|
+
DEFAULT_RETENTION_DAYS,
|
|
34
|
+
DEFAULT_STUCK_AFTER_SECONDS,
|
|
35
|
+
DEFAULT_STUCK_SWEEP_INTERVAL_SECONDS,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
_CELERY_RESTART = {"requires_restart": True, "group": "Celery"}
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BackgroundTasksSettings(BaseSettings):
    """Configuration for the Celery + Redis task runner."""

    model_config = SettingsConfigDict(extra="ignore")

    # Celery-critical fields: tagged ``requires_restart`` in the JSON schema
    # through ``_CELERY_RESTART``.
    broker_url: str = Field(default=DEFAULT_BROKER_URL, json_schema_extra=_CELERY_RESTART)
    result_backend: str = Field(default=DEFAULT_RESULT_BACKEND, json_schema_extra=_CELERY_RESTART)
    task_default_queue: str = Field(default=DEFAULT_QUEUE, json_schema_extra=_CELERY_RESTART)

    # A task that has been ``running`` longer than this without a heartbeat is
    # flipped to ``stuck`` by the beat sweep. 5 min is long enough to cover
    # normal slow jobs while still surfacing wedged workers within one UI
    # refresh.
    stuck_after_seconds: int = DEFAULT_STUCK_AFTER_SECONDS
    stuck_sweep_interval_seconds: int = DEFAULT_STUCK_SWEEP_INTERVAL_SECONDS
    purge_interval_seconds: int = DEFAULT_PURGE_INTERVAL_SECONDS

    retention_days: int = DEFAULT_RETENTION_DAYS
    max_retries: int = DEFAULT_MAX_RETRIES

    @model_validator(mode="after")
    def _forbid_localhost_broker_in_production(self) -> BackgroundTasksSettings:
        """Fail boot if production is still pointed at a loopback broker.

        A localhost broker in prod means the web container is talking to its
        own 6379 instead of the shared Redis service — tasks would silently
        queue to a broker no worker reads.

        The check inspects the *host* of each URL rather than searching the
        whole string, so ``[::1]`` and any ``127.x.y.z`` address are caught,
        while "localhost" appearing only in credentials or inside another
        hostname is not a false positive.

        Raises:
            ValueError: when ``SM_ENVIRONMENT`` is not a non-production
                environment and ``broker_url`` and/or ``result_backend``
                targets a loopback host.
        """
        env = os.environ.get("SM_ENVIRONMENT", "development")
        if env in NON_PROD_ENVIRONMENTS:
            return self

        # Function-scope imports: only this validator needs them.
        import ipaddress
        from urllib.parse import urlsplit

        def _targets_loopback(url: str) -> bool:
            """Return True when any ';'-separated URL in ``url`` has a loopback host."""
            # Celery accepts several broker URLs in one ';'-delimited string.
            for candidate in url.split(";"):
                try:
                    host = urlsplit(candidate.strip()).hostname
                except ValueError:
                    host = None
                if host is None:
                    # No parsable host (scheme-less or malformed value): keep
                    # the conservative substring check for these shapes.
                    if "localhost" in candidate or "127.0.0.1" in candidate:
                        return True
                    continue
                if host in ("localhost", "localhost.localdomain") or host.endswith(".localhost"):
                    return True
                try:
                    if ipaddress.ip_address(host).is_loopback:
                        return True
                except ValueError:
                    # Not an IP literal — an ordinary hostname.
                    pass
            return False

        bad = [
            name
            for name, url in (
                ("broker_url", self.broker_url),
                ("result_backend", self.result_backend),
            )
            if _targets_loopback(url)
        ]
        if bad:
            names = ", ".join(bad)
            raise ValueError(
                f"{names} must not point at localhost when SM_ENVIRONMENT={env!r}. "
                "Set these to the Redis service host (e.g. redis://redis:6379/0)."
            )
        return self
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
"""Celery signal handlers that keep the ``TaskExecution`` table in sync.
|
|
2
|
+
|
|
3
|
+
Each signal writes one row: publish creates ``pending``, prerun flips to
|
|
4
|
+
``running`` with a heartbeat, postrun/success/failure/retry/revoked update
|
|
5
|
+
the terminal columns. We update an existing row (matched by
|
|
6
|
+
``celery_task_id``) so a single task's lifecycle stays on one row instead
|
|
7
|
+
of spawning a new row per signal.
|
|
8
|
+
|
|
9
|
+
Signals are sync — see :mod:`.sync_db` for why we maintain a separate sync
|
|
10
|
+
engine, and :mod:`._signal_support` for the shared helpers.
|
|
11
|
+
|
|
12
|
+
``TaskFailed`` is also dispatched from :func:`on_task_failure` when an event
|
|
13
|
+
bus has been bound via :func:`bind_event_bus` — typically from the web
|
|
14
|
+
process's ``on_startup``. In a standalone Celery worker (separate process)
|
|
15
|
+
no bus is bound and the publish becomes a no-op; subscribers that need
|
|
16
|
+
cross-process notification should consume Celery events or poll the DB
|
|
17
|
+
directly.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import asyncio
|
|
23
|
+
import logging
|
|
24
|
+
from collections.abc import Callable
|
|
25
|
+
from typing import Any
|
|
26
|
+
|
|
27
|
+
from celery import signals
|
|
28
|
+
from simple_module_core.events import EventBus
|
|
29
|
+
|
|
30
|
+
from background_tasks._signal_support import (
|
|
31
|
+
coerce_args_kwargs,
|
|
32
|
+
jsonable_result,
|
|
33
|
+
now_utc,
|
|
34
|
+
render_traceback,
|
|
35
|
+
task_id_from,
|
|
36
|
+
task_name_of,
|
|
37
|
+
upsert_by_celery_id,
|
|
38
|
+
)
|
|
39
|
+
from background_tasks.constants import DEFAULT_QUEUE, TaskStatus
|
|
40
|
+
from background_tasks.contracts.events import TaskFailed
|
|
41
|
+
from background_tasks.models import TaskExecution
|
|
42
|
+
from background_tasks.sync_db import sync_session
|
|
43
|
+
|
|
44
|
+
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
_bus: EventBus | None = None
|
|
48
|
+
_loop: asyncio.AbstractEventLoop | None = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def bind_event_bus(bus: EventBus, loop: asyncio.AbstractEventLoop) -> None:
    """Remember ``bus`` and the loop it runs on for later signal publishes.

    Signal handlers are synchronous; events are handed to ``loop`` with
    ``asyncio.run_coroutine_threadsafe`` so subscribers run on that loop
    no matter which thread the signal fired on.
    """
    global _bus, _loop
    _bus, _loop = bus, loop
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def unbind_event_bus() -> None:
    """Forget the bound bus and loop; later signal publishes become no-ops."""
    global _bus, _loop
    _bus = _loop = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _publish_from_signal(event: Any) -> None:
    """Dispatch ``event`` onto the bound bus without blocking the signal thread.

    Does nothing when no bus/loop pair is bound. Scheduling failures are
    logged at debug level and swallowed; exceptions raised by the publish
    itself are logged by ``_log_publish_failure`` once the future completes.
    """
    # Snapshot the globals once: ``unbind_event_bus`` may run on another
    # thread between the ``None`` check and the use.
    bus, loop = _bus, _loop
    if bus is None or loop is None:
        return
    coro = bus.publish(event)
    try:
        future = asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # Loop is closed (shutdown race). The DB row is already written.
        # The coroutine was never scheduled, so close it explicitly to avoid
        # a "coroutine ... was never awaited" RuntimeWarning.
        coro.close()
        logger.debug("Event bus loop is not running; skipping %s", type(event).__name__)
        return
    # Surface subscriber exceptions — run_coroutine_threadsafe otherwise only
    # logs them when the Future is GC'd, which happens far from the failure.
    future.add_done_callback(_log_publish_failure)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _log_publish_failure(future: asyncio.Future[Any]) -> None:
    """Done-callback: log the exception carried by a finished publish future.

    Cancelled futures and futures that completed without an exception are
    ignored.
    """
    error = None if future.cancelled() else future.exception()
    if error is None:
        return
    logger.error("Event publish raised: %s", error, exc_info=error)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _apply(
    handler: str,
    *,
    celery_task_id: str | None,
    defaults: dict[str, Any],
    after: Callable[[TaskExecution], None] | None = None,
) -> TaskExecution | None:
    """Upsert one row by ``celery_task_id`` inside a short-lived sync session.

    Args:
        handler: Name of the calling signal handler; used only in the log
            message when something goes wrong.
        celery_task_id: Key passed to ``upsert_by_celery_id``.
        defaults: Column values passed to ``upsert_by_celery_id``.
        after: Optional hook invoked with the upserted row while the session
            is still open.

    Returns:
        The upserted row, or ``None`` when anything inside the session
        (including leaving it) raised — the error is logged, never
        propagated, so callers can read fields like ``id`` without a second
        round trip.
    """
    try:
        with sync_session() as db:
            record = upsert_by_celery_id(db, celery_task_id=celery_task_id, defaults=defaults)
            if after is not None:
                after(record)
    except Exception:
        logger.exception("%s failed for task_id=%s", handler, celery_task_id)
        return None
    return record
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# ── Enqueue ─────────────────────────────────────────────────────
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@signals.before_task_publish.connect
def on_task_publish(
    sender: str | None = None,
    headers: dict[str, Any] | None = None,
    body: Any = None,
    routing_key: str | None = None,
    **_kwargs: Any,
) -> None:
    """Write a ``pending`` row as soon as a task is handed to the broker.

    The task id comes from ``headers["id"]``; the name is ``sender``, falling
    back to ``headers["task"]`` and finally ``"unknown"``.
    """
    meta = headers or {}
    task_id = meta.get("id")
    task_name = sender or meta.get("task") or "unknown"

    # For the standard Celery protocol ``body`` is ``(args, kwargs, options)``;
    # any other shape is treated as "no arguments".
    if isinstance(body, (list, tuple)) and len(body) >= 2:
        raw_args, raw_kwargs = body[0], body[1]
    else:
        raw_args, raw_kwargs = [], {}
    args, kwargs = coerce_args_kwargs(raw_args, raw_kwargs)

    pending_row = {
        "task_name": task_name,
        "status": TaskStatus.PENDING,
        "queue": routing_key or DEFAULT_QUEUE,
        "args": args,
        "kwargs": kwargs,
        "queued_at": now_utc(),
    }
    _apply("on_task_publish", celery_task_id=task_id, defaults=pending_row)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ── Execution lifecycle ─────────────────────────────────────────
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any = None,
    task_id: str | None = None,
    task: Any = None,
    args: Any = None,
    kwargs: Any = None,
    **_k: Any,
) -> None:
    """Mark the row ``running`` and stamp start time and first heartbeat.

    ``started_at`` and ``heartbeat_at`` receive the same timestamp.
    """
    norm_args, norm_kwargs = coerce_args_kwargs(args, kwargs)
    stamp = now_utc()
    running_row = {
        "task_name": task_name_of(sender, task),
        "status": TaskStatus.RUNNING,
        "args": norm_args,
        "kwargs": norm_kwargs,
        "started_at": stamp,
        "heartbeat_at": stamp,
    }
    _apply("on_task_prerun", celery_task_id=task_id, defaults=running_row)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any = None,
    task_id: str | None = None,
    task: Any = None,
    **_k: Any,
) -> None:
    """Bump the heartbeat once the task body has finished.

    No status is written here: the terminal status comes from the
    ``task_success`` / ``task_failure`` / ``task_retry`` handlers, which
    fire *before* postrun. Refreshing the heartbeat keeps the sweep from
    immediately flipping a row that has only just finished.
    """
    heartbeat_only = {
        "task_name": task_name_of(sender, task),
        "heartbeat_at": now_utc(),
    }
    _apply("on_task_postrun", celery_task_id=task_id, defaults=heartbeat_only)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@signals.task_success.connect
def on_task_success(sender: Any = None, result: Any = None, **_k: Any) -> None:
    """Record a successful finish: status, result and finish time.

    Also clears ``traceback`` and ``exception_type`` on the row.
    """
    celery_id = task_id_from(sender=sender)
    success_row = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.SUCCESS,
        "result": jsonable_result(result),
        "finished_at": now_utc(),
        "traceback": None,
        "exception_type": None,
    }
    _apply("on_task_success", celery_task_id=celery_id, defaults=success_row)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@signals.task_failure.connect
def on_task_failure(
    sender: Any = None,
    task_id: str | None = None,
    exception: BaseException | None = None,
    einfo: Any = None,
    **_k: Any,
) -> None:
    """Record a failure and, if the row was written, publish ``TaskFailed``.

    The event is skipped when the DB write failed, because there is then no
    row id to reference.
    """
    task_name = task_name_of(sender)
    exception_type = None if exception is None else type(exception).__name__

    failed_row = {
        "task_name": task_name,
        "status": TaskStatus.FAILED,
        "traceback": render_traceback(einfo, exception),
        "exception_type": exception_type,
        "finished_at": now_utc(),
    }
    row = _apply("on_task_failure", celery_task_id=task_id, defaults=failed_row)
    if row is None:
        return

    failure_event = TaskFailed(
        task_execution_id=row.id,
        task_name=task_name,
        exception_type=exception_type,
    )
    _publish_from_signal(failure_event)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@signals.task_retry.connect
def on_task_retry(
    sender: Any = None,
    request: Any = None,
    reason: Any = None,
    **_k: Any,
) -> None:
    """Mark the row ``retrying`` and record the retry count and heartbeat.

    ``retries`` is read from ``request.retries`` and defaults to 0 when the
    attribute is missing or falsy.
    """

    def _record_reason(execution: TaskExecution) -> None:
        # Only write the traceback column when a reason was supplied; with no
        # reason, whatever traceback the row already holds is left untouched.
        if reason is None:
            return
        execution.traceback = f"retry: {reason!r}"

    celery_id = task_id_from(request=request)
    retrying_row = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.RETRYING,
        "retries": int(getattr(request, "retries", 0) or 0),
        "heartbeat_at": now_utc(),
    }
    _apply(
        "on_task_retry",
        celery_task_id=celery_id,
        defaults=retrying_row,
        after=_record_reason,
    )
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@signals.task_revoked.connect
def on_task_revoked(sender: Any = None, request: Any = None, **_k: Any) -> None:
    """Mark the row ``revoked`` and stamp its finish time."""
    celery_id = task_id_from(request=request)
    revoked_row = {
        "task_name": task_name_of(sender),
        "status": TaskStatus.REVOKED,
        "finished_at": now_utc(),
    }
    _apply("on_task_revoked", celery_task_id=celery_id, defaults=revoked_row)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Sync SQLAlchemy session factory for Celery signal handlers.
|
|
2
|
+
|
|
3
|
+
Celery signals are called synchronously from inside the worker / web-process
|
|
4
|
+
hot path. Building an event-loop just to await a query is both ugly and
|
|
5
|
+
deadlock-prone (web-process signals can fire from inside an already-running
|
|
6
|
+
loop). We instead maintain a second, sync engine pointed at the same DB URL
|
|
7
|
+
and borrow a short-lived session from it whenever a signal fires.
|
|
8
|
+
|
|
9
|
+
The engine is lazily built on first use and process-global. It's cheap —
|
|
10
|
+
SQLAlchemy's connection pool is per-process so the signal path reuses
|
|
11
|
+
connections after the first signal.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
from collections.abc import Iterator
|
|
19
|
+
from contextlib import contextmanager
|
|
20
|
+
|
|
21
|
+
from sqlalchemy import create_engine
|
|
22
|
+
from sqlalchemy.engine import Engine
|
|
23
|
+
from sqlalchemy.orm import Session, sessionmaker
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
_engine: Engine | None = None
|
|
28
|
+
_session_factory: sessionmaker[Session] | None = None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _sync_url(async_url: str) -> str:
    """Rewrite an async SQLAlchemy URL to use a sync driver.

    ``+aiosqlite`` is dropped and ``+asyncpg`` becomes ``+psycopg2``; any
    other URL is returned unchanged.
    """
    driver_swaps = (("+aiosqlite", ""), ("+asyncpg", "+psycopg2"))
    sync_url = async_url
    for async_marker, sync_marker in driver_swaps:
        sync_url = sync_url.replace(async_marker, sync_marker)
    return sync_url
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _build_engine() -> Engine:
    """Create the sync engine used by the signal handlers.

    The URL is read from ``SM_DATABASE_URL`` (default
    ``sqlite:///./app.db``) and converted to its sync-driver form with
    ``_sync_url``.
    """
    url = os.environ.get("SM_DATABASE_URL", "sqlite:///./app.db")
    sync_url = _sync_url(url)
    if sync_url.startswith("sqlite"):
        # SQLite engines are not always backed by QueuePool (for example an
        # in-memory database), and ``create_engine`` raises ``TypeError`` for
        # pool arguments the selected pool class does not accept. Let the
        # dialect pick its own pool settings.
        return create_engine(sync_url, pool_pre_ping=True)
    # Small pool — signals fire sequentially per worker process.
    return create_engine(sync_url, pool_pre_ping=True, pool_size=2, max_overflow=3)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_sync_session_factory() -> sessionmaker[Session]:
    """Return the cached sync session factory, creating it on first call.

    The first call also builds and caches the engine. Sessions from the
    factory are created with ``expire_on_commit=False``.
    """
    global _engine, _session_factory
    if _session_factory is not None:
        return _session_factory
    _engine = _build_engine()
    _session_factory = sessionmaker(bind=_engine, expire_on_commit=False)
    return _session_factory
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def dispose_sync_engine() -> None:
    """Dispose the cached engine, if any, and forget engine and factory.

    Afterwards the next ``get_sync_session_factory`` call builds a fresh
    engine, so lifespan restarts within one process (test runners, uvicorn
    dev reload) don't keep accumulating engines against an old DB URL.
    Safe to call when nothing has been built.
    """
    global _engine, _session_factory
    stale_engine = _engine
    if stale_engine is not None:
        stale_engine.dispose()
    _engine = _session_factory = None
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@contextmanager
def sync_session() -> Iterator[Session]:
    """Yield a short-lived sync session with commit/rollback handling.

    The session is committed when the ``with`` body finishes normally. If
    the body or the commit raises an ``Exception``, the session is rolled
    back and the exception re-raised. The session is closed in every case.
    """
    db = get_sync_session_factory()()
    try:
        yield db
        # Commit stays inside the ``try`` so a failing commit is rolled back too.
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
|