horsies 0.1.0a1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- horsies/__init__.py +115 -0
- horsies/core/__init__.py +0 -0
- horsies/core/app.py +552 -0
- horsies/core/banner.py +144 -0
- horsies/core/brokers/__init__.py +5 -0
- horsies/core/brokers/listener.py +444 -0
- horsies/core/brokers/postgres.py +864 -0
- horsies/core/cli.py +624 -0
- horsies/core/codec/serde.py +575 -0
- horsies/core/errors.py +535 -0
- horsies/core/logging.py +90 -0
- horsies/core/models/__init__.py +0 -0
- horsies/core/models/app.py +268 -0
- horsies/core/models/broker.py +79 -0
- horsies/core/models/queues.py +23 -0
- horsies/core/models/recovery.py +101 -0
- horsies/core/models/schedule.py +229 -0
- horsies/core/models/task_pg.py +307 -0
- horsies/core/models/tasks.py +332 -0
- horsies/core/models/workflow.py +1988 -0
- horsies/core/models/workflow_pg.py +245 -0
- horsies/core/registry/tasks.py +101 -0
- horsies/core/scheduler/__init__.py +26 -0
- horsies/core/scheduler/calculator.py +267 -0
- horsies/core/scheduler/service.py +569 -0
- horsies/core/scheduler/state.py +260 -0
- horsies/core/task_decorator.py +615 -0
- horsies/core/types/status.py +38 -0
- horsies/core/utils/imports.py +203 -0
- horsies/core/utils/loop_runner.py +44 -0
- horsies/core/worker/current.py +17 -0
- horsies/core/worker/worker.py +1967 -0
- horsies/core/workflows/__init__.py +23 -0
- horsies/core/workflows/engine.py +2344 -0
- horsies/core/workflows/recovery.py +501 -0
- horsies/core/workflows/registry.py +97 -0
- horsies/py.typed +0 -0
- horsies-0.1.0a1.dist-info/METADATA +31 -0
- horsies-0.1.0a1.dist-info/RECORD +42 -0
- horsies-0.1.0a1.dist-info/WHEEL +5 -0
- horsies-0.1.0a1.dist-info/entry_points.txt +2 -0
- horsies-0.1.0a1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,864 @@
# app/core/brokers/postgres.py
from __future__ import annotations
import uuid, asyncio, hashlib
from typing import Any, Optional, TYPE_CHECKING
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy import text
from horsies.core.brokers.listener import PostgresListener
from horsies.core.models.broker import PostgresConfig
from horsies.core.models.task_pg import TaskModel, Base
from horsies.core.models.workflow_pg import WorkflowModel, WorkflowTaskModel  # noqa: F401
from horsies.core.types.status import TaskStatus
from horsies.core.codec.serde import (
    args_to_json,
    kwargs_to_json,
    loads_json,
    task_result_from_json,
)
from horsies.core.utils.loop_runner import LoopRunner
from horsies.core.logging import get_logger

if TYPE_CHECKING:
    from horsies.core.models.tasks import TaskResult, TaskError


class PostgresBroker:
    """
    PostgreSQL-based task broker with LISTEN/NOTIFY for real-time updates.

    Provides both async and sync APIs:
    - Async: enqueue_async(), get_result_async()
    - Sync: enqueue(), get_result() (run in background event loop)

    Features:
    - Real-time notifications via PostgreSQL triggers
    - Automatic task status tracking
    - Connection pooling and health monitoring
    - Operational monitoring (stale tasks, worker stats)
    """

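    # Usage sketch (illustrative; assumes PostgresConfig is a Pydantic-style model
    # that accepts at least `database_url` -- any extra fields are forwarded to
    # create_async_engine via model_dump() in __init__):
    #
    #     config = PostgresConfig(database_url='postgresql+asyncpg://user:pass@localhost/horsies')
    #     broker = PostgresBroker(config)
    #     task_id = await broker.enqueue_async('emails.send_welcome', (42,), {})
    #     result = await broker.get_result_async(task_id, timeout_ms=30_000)
    #     await broker.close_async()
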
    def __init__(self, config: PostgresConfig):
        self.config = config
        self.logger = get_logger('broker')
        self._app: Any = None  # Set by Horsies.get_broker()

        engine_cfg = self.config.model_dump(exclude={'database_url'}, exclude_none=True)
        self.async_engine = create_async_engine(self.config.database_url, **engine_cfg)
        self.session_factory = async_sessionmaker(
            self.async_engine, expire_on_commit=False
        )

        psycopg_url = self.config.database_url.replace('+asyncpg', '').replace(
            '+psycopg', ''
        )
        self.listener = PostgresListener(psycopg_url)

        self._initialized = False
        self._loop_runner = LoopRunner()  # for sync facades

        self.logger.info('PostgresBroker initialized')

    @property
    def app(self) -> Any:
        """Get the attached Horsies app instance (if any)."""
        return self._app

    @app.setter
    def app(self, value: Any) -> None:
        """Set the Horsies app instance."""
        self._app = value

    def _schema_advisory_key(self) -> int:
        """
        Compute a stable 64-bit advisory lock key for schema initialization.

        Uses the database URL as a basis so that different clusters do not
        contend on the same advisory lock key.
        """
        basis = self.config.database_url.encode('utf-8', errors='ignore')
        h = hashlib.sha256(b'horsies-schema:' + basis).digest()
        return int.from_bytes(h[:8], byteorder='big', signed=True)

    async def _create_triggers(self) -> None:
        """
        Set up PostgreSQL triggers for automatic task notifications.

        Creates triggers that send NOTIFY messages on:
        - INSERT: Sends task_new + task_queue_{queue_name} notifications
        - UPDATE to COMPLETED/FAILED: Sends task_done notification

        This enables real-time task processing without polling.
        """
        async with self.async_engine.begin() as conn:
            # Create trigger function
            await conn.execute(
                text("""
                CREATE OR REPLACE FUNCTION horsies_notify_task_changes()
                RETURNS trigger AS $$
                BEGIN
                    IF TG_OP = 'INSERT' AND NEW.status = 'PENDING' THEN
                        -- New task notifications: wake up workers
                        PERFORM pg_notify('task_new', NEW.id); -- Global worker notification
                        PERFORM pg_notify('task_queue_' || NEW.queue_name, NEW.id); -- Queue-specific notification
                    ELSIF TG_OP = 'UPDATE' AND OLD.status != NEW.status THEN
                        -- Task completion notifications: wake up result waiters
                        IF NEW.status IN ('COMPLETED', 'FAILED') THEN
                            PERFORM pg_notify('task_done', NEW.id); -- Send task_id as payload
                        END IF;
                    END IF;
                    RETURN NEW;
                END;
                $$ LANGUAGE plpgsql;
                """)
            )

            # Create trigger
            await conn.execute(
                text("""
                DROP TRIGGER IF EXISTS horsies_task_notify_trigger ON horsies_tasks;
                CREATE TRIGGER horsies_task_notify_trigger
                    AFTER INSERT OR UPDATE ON horsies_tasks
                    FOR EACH ROW
                    EXECUTE FUNCTION horsies_notify_task_changes();
                """)
            )

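    # How a caller might observe these notifications through the broker's
    # listener (illustrative sketch; it mirrors the pattern in get_result_async
    # below, where each notification exposes the task id as `note.payload`,
    # and assumes the listener has been started via _ensure_initialized()):
    #
    #     q = await broker.listener.listen('task_queue_default')
    #     try:
    #         note = await q.get()  # blocks until a PENDING task hits the default queue
    #         print('new task:', note.payload)
    #     finally:
    #         await broker.listener.unsubscribe('task_queue_default', q)
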
    async def _create_workflow_schema(self) -> None:
        """
        Set up workflow-specific schema elements.

        Creates:
        - GIN index on workflow_tasks.dependencies for efficient dependency lookups
        - Trigger for workflow completion notifications
        - Migration: adds task_options column if missing (for existing installs)
        """
        async with self.async_engine.begin() as conn:
            # GIN index for efficient dependency array lookups
            await conn.execute(
                text("""
                CREATE INDEX IF NOT EXISTS idx_horsies_workflow_tasks_deps
                ON horsies_workflow_tasks USING GIN(dependencies);
                """)
            )

            # Migration: add task_options column for existing installs
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS task_options TEXT;
                """)
            )

            # Migration: add success_policy column for existing installs
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS success_policy JSONB;
                """)
            )

            # Migration: add join_type and min_success columns for existing installs
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS join_type VARCHAR(10) NOT NULL DEFAULT 'all';
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS min_success INTEGER;
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS node_id VARCHAR(128);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ALTER COLUMN workflow_ctx_from
                TYPE VARCHAR(128)[]
                USING workflow_ctx_from::VARCHAR(128)[];
                """)
            )

            # Subworkflow support columns
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS parent_workflow_id VARCHAR(36);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS parent_task_index INTEGER;
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS depth INTEGER NOT NULL DEFAULT 0;
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS root_workflow_id VARCHAR(36);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS workflow_def_module VARCHAR(512);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflows
                ADD COLUMN IF NOT EXISTS workflow_def_qualname VARCHAR(512);
                """)
            )

            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS is_subworkflow BOOLEAN NOT NULL DEFAULT FALSE;
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_id VARCHAR(36);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_name VARCHAR(255);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_retry_mode VARCHAR(50);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_summary TEXT;
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_module VARCHAR(512);
                """)
            )
            await conn.execute(
                text("""
                ALTER TABLE horsies_workflow_tasks
                ADD COLUMN IF NOT EXISTS sub_workflow_qualname VARCHAR(512);
                """)
            )

            # Workflow notification trigger function
            await conn.execute(
                text("""
                CREATE OR REPLACE FUNCTION horsies_notify_workflow_changes()
                RETURNS trigger AS $$
                BEGIN
                    IF TG_OP = 'UPDATE' AND OLD.status != NEW.status THEN
                        -- Workflow completion notifications
                        IF NEW.status IN ('COMPLETED', 'FAILED', 'CANCELLED', 'PAUSED') THEN
                            PERFORM pg_notify('workflow_done', NEW.id);
                        END IF;
                    END IF;
                    RETURN NEW;
                END;
                $$ LANGUAGE plpgsql;
                """)
            )

            # Create workflow trigger
            await conn.execute(
                text("""
                DROP TRIGGER IF EXISTS horsies_workflow_notify_trigger ON horsies_workflows;
                CREATE TRIGGER horsies_workflow_notify_trigger
                    AFTER UPDATE ON horsies_workflows
                    FOR EACH ROW
                    EXECUTE FUNCTION horsies_notify_workflow_changes();
                """)
            )

    async def _ensure_initialized(self) -> None:
        if self._initialized:
            return
        async with self.async_engine.begin() as conn:
            # Take a short-lived, cluster-wide advisory lock to serialize
            # schema creation across workers and producers.
            await conn.execute(
                text('SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))'),
                {'key': self._schema_advisory_key()},
            )
            await conn.run_sync(Base.metadata.create_all)

        await self._create_triggers()
        await self._create_workflow_schema()
        await self.listener.start()
        self._initialized = True

    async def ensure_schema_initialized(self) -> None:
        """
        Public entry point to ensure tables and triggers exist.

        Safe to call multiple times and from multiple processes; internally
        guarded by a PostgreSQL advisory lock to avoid DDL races.
        """
        await self._ensure_initialized()

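    # Typical startup call (illustrative sketch): initialize the schema once
    # before producing or consuming tasks; repeated calls are cheap no-ops
    # thanks to the _initialized flag and the advisory lock above.
    #
    #     await broker.ensure_schema_initialized()
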
    # ----------------- Async API -----------------

    async def enqueue_async(
        self,
        task_name: str,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        queue_name: str = 'default',
        *,
        priority: int = 100,
        sent_at: Optional[datetime] = None,
        good_until: Optional[datetime] = None,
        task_options: Optional[str] = None,
    ) -> str:
        await self._ensure_initialized()

        task_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        sent = sent_at or now

        # Parse retry configuration from task_options
        max_retries = 0
        if task_options:
            try:
                options_data = loads_json(task_options)
                if isinstance(options_data, dict):
                    retry_policy = options_data.get('retry_policy')
                    if isinstance(retry_policy, dict):
                        max_retries = retry_policy.get('max_retries', 3)
            except Exception:
                pass

        async with self.session_factory() as session:
            task = TaskModel(
                id=task_id,
                task_name=task_name,
                queue_name=queue_name,
                priority=priority,
                args=args_to_json(args) if args else None,
                kwargs=kwargs_to_json(kwargs) if kwargs else None,
                status=TaskStatus.PENDING,
                sent_at=sent,
                good_until=good_until,
                max_retries=max_retries,
                task_options=task_options,
                created_at=now,
                updated_at=now,
            )
            session.add(task)
            await session.commit()

        # PostgreSQL trigger automatically sends task_new + task_queue_{queue_name} notifications
        return task_id

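    # Enqueue with retry options (illustrative sketch; the broker only reads
    # retry_policy.max_retries out of task_options above, and the exact JSON
    # shape callers serialize is an assumption here):
    #
    #     import json
    #     task_id = await broker.enqueue_async(
    #         'reports.generate',
    #         ('2024-Q4',),
    #         {},
    #         queue_name='reports',
    #         priority=10,
    #         task_options=json.dumps({'retry_policy': {'max_retries': 5}}),
    #     )
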
    async def get_result_async(
        self, task_id: str, timeout_ms: Optional[int] = None
    ) -> 'TaskResult[Any, TaskError]':
        """
        Get task result, waiting if necessary.

        Returns TaskResult for task completion and retrieval outcomes:
        - Success: TaskResult(ok=value) from task execution
        - Task error: TaskResult(err=TaskError) from task execution
        - Retrieval error: TaskResult(err=TaskError) with WAIT_TIMEOUT, TASK_NOT_FOUND, or TASK_CANCELLED

        Broker failures (e.g., database errors) are returned as BROKER_ERROR.
        """
        from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode

        try:
            await self._ensure_initialized()

            start_time = asyncio.get_event_loop().time()

            # Convert milliseconds to seconds for internal use
            timeout_seconds: Optional[float] = None
            if timeout_ms is not None:
                timeout_seconds = timeout_ms / 1000.0

            # Quick path - check if task is already completed
            async with self.session_factory() as session:
                row = await session.get(TaskModel, task_id)
                if row is None:
                    self.logger.error(f'Task {task_id} not found')
                    return TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.TASK_NOT_FOUND,
                            message=f'Task {task_id} not found in database',
                            data={'task_id': task_id},
                        )
                    )
                if row.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                    self.logger.info(f'Task {task_id} already completed')
                    return task_result_from_json(loads_json(row.result))
                if row.status == TaskStatus.CANCELLED:
                    self.logger.info(f'Task {task_id} was cancelled')
                    return TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.TASK_CANCELLED,
                            message=f'Task {task_id} was cancelled before completion',
                            data={'task_id': task_id},
                        )
                    )

            # Listen for task completion notifications with polling fallback
            # Multiple clients can listen to task_done; each filters for their specific task_id
            q = await self.listener.listen('task_done')
            try:
                poll_interval = (
                    5.0  # Fallback polling interval (handles lost notifications)
                )

                while True:
                    # Calculate remaining timeout
                    remaining_timeout = None
                    if timeout_seconds:
                        elapsed = asyncio.get_event_loop().time() - start_time
                        remaining_timeout = timeout_seconds - elapsed
                        if remaining_timeout <= 0:
                            return TaskResult(
                                err=TaskError(
                                    error_code=LibraryErrorCode.WAIT_TIMEOUT,
                                    message=f'Timed out waiting for task {task_id} after {timeout_ms}ms. Task may still be running.',
                                    data={'task_id': task_id, 'timeout_ms': timeout_ms},
                                )
                            )

                    # Wait for NOTIFY or timeout (whichever comes first)
                    wait_time = (
                        min(poll_interval, remaining_timeout)
                        if remaining_timeout
                        else poll_interval
                    )

                    try:

                        async def _wait_for_task() -> None:
                            # Filter notifications: only process our specific task_id
                            while True:
                                note = (
                                    await q.get()
                                )  # Blocks until any task_done notification
                                if (
                                    note.payload == task_id
                                ):  # Check if it's for our task
                                    return  # Found our task completion!

                        # Wait for our specific task notification with timeout
                        await asyncio.wait_for(_wait_for_task(), timeout=wait_time)
                        break  # Got our task completion notification

                    except asyncio.TimeoutError:
                        # Fallback polling: handles lost notifications or crashed workers
                        # Ensures eventual consistency even if NOTIFY system fails
                        async with self.session_factory() as session:
                            row = await session.get(TaskModel, task_id)
                            if row is None:
                                return TaskResult(
                                    err=TaskError(
                                        error_code=LibraryErrorCode.TASK_NOT_FOUND,
                                        message=f'Task {task_id} not found in database',
                                        data={'task_id': task_id},
                                    )
                                )
                            if row.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                                self.logger.debug(
                                    f'Task {task_id} completed, polling database'
                                )
                                return task_result_from_json(loads_json(row.result))
                            if row.status == TaskStatus.CANCELLED:
                                self.logger.error(f'Task {task_id} was cancelled')
                                return TaskResult(
                                    err=TaskError(
                                        error_code=LibraryErrorCode.TASK_CANCELLED,
                                        message=f'Task {task_id} was cancelled before completion',
                                        data={'task_id': task_id},
                                    )
                                )
                        # Task still not done, continue waiting...

            finally:
                # Clean up subscription (keeps server-side LISTEN active for other waiters)
                await self.listener.unsubscribe('task_done', q)

            # Final database check to get actual task result after notification
            async with self.session_factory() as session:
                row = await session.get(TaskModel, task_id)
                if row is None:
                    self.logger.error(f'Task {task_id} not found')
                    return TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.TASK_NOT_FOUND,
                            message=f'Task {task_id} not found in database',
                            data={'task_id': task_id},
                        )
                    )
                if row.status == TaskStatus.CANCELLED:
                    self.logger.error(f'Task {task_id} was cancelled')
                    return TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.TASK_CANCELLED,
                            message=f'Task {task_id} was cancelled before completion',
                            data={'task_id': task_id},
                        )
                    )
                if row.status not in (TaskStatus.COMPLETED, TaskStatus.FAILED):
                    return TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.WAIT_TIMEOUT,
                            message=f'Task {task_id} not completed after notification (status: {row.status}). Task may still be running.',
                            data={'task_id': task_id, 'status': str(row.status)},
                        )
                    )
                return task_result_from_json(loads_json(row.result))
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            self.logger.exception('Broker error while retrieving task result')
            return TaskResult(
                err=TaskError(
                    error_code=LibraryErrorCode.BROKER_ERROR,
                    message='Broker error while retrieving task result',
                    data={'task_id': task_id, 'timeout_ms': timeout_ms},
                    exception=exc,
                )
            )

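    # Consuming the returned TaskResult (illustrative sketch; assumes the ok/err
    # fields are readable as they are constructed above, with LibraryErrorCode
    # imported from horsies.core.models.tasks):
    #
    #     result = await broker.get_result_async(task_id, timeout_ms=10_000)
    #     if result.err is not None:
    #         if result.err.error_code == LibraryErrorCode.WAIT_TIMEOUT:
    #             ...  # task may still be running; retry or report later
    #         else:
    #             ...  # task error, cancelled, not found, or broker error
    #     else:
    #         value = result.ok
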
    async def close_async(self) -> None:
        await self.listener.close()
        await self.async_engine.dispose()

    # ------------- Operational & Monitoring Methods -------------

    async def get_stale_tasks(
        self, stale_threshold_minutes: int = 2
    ) -> list[dict[str, Any]]:
        """
        Identify potentially crashed tasks based on heartbeat absence.

        Finds RUNNING tasks whose workers haven't sent heartbeats within the threshold,
        indicating the worker process may have crashed or become unresponsive.

        Args:
            stale_threshold_minutes: Minutes without heartbeat to consider stale

        Returns:
            List of task info dicts: id, worker_hostname, worker_pid, last_heartbeat
        """
        async with self.session_factory() as session:
            result = await session.execute(
                text("""
                SELECT
                    t.id,
                    t.worker_hostname,
                    t.worker_pid,
                    t.worker_process_name,
                    hb.last_heartbeat,
                    t.started_at,
                    t.task_name
                FROM horsies_tasks t
                LEFT JOIN LATERAL (
                    SELECT sent_at AS last_heartbeat
                    FROM horsies_heartbeats h
                    WHERE h.task_id = t.id AND h.role = 'runner'
                    ORDER BY sent_at DESC
                    LIMIT 1
                ) hb ON TRUE
                WHERE t.status = 'RUNNING'
                  AND t.started_at IS NOT NULL
                  AND COALESCE(hb.last_heartbeat, t.started_at) < NOW() - CAST(:stale_threshold || ' minutes' AS INTERVAL)
                ORDER BY hb.last_heartbeat NULLS FIRST
                """),
                {'stale_threshold': stale_threshold_minutes},
            )
            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]

    async def get_worker_stats(self) -> list[dict[str, Any]]:
        """
        Gather statistics about active worker processes.

        Groups RUNNING tasks by worker to show load distribution and health.
        Useful for monitoring worker performance and identifying bottlenecks.

        Returns:
            List of worker stats: worker_hostname, worker_pid, active_tasks, oldest_task_start
        """
        async with self.session_factory() as session:
            result = await session.execute(
                text("""
                SELECT
                    t.worker_hostname,
                    t.worker_pid,
                    t.worker_process_name,
                    COUNT(*) AS active_tasks,
                    MIN(t.started_at) AS oldest_task_start,
                    MAX(hb.last_heartbeat) AS latest_heartbeat
                FROM horsies_tasks t
                LEFT JOIN LATERAL (
                    SELECT sent_at AS last_heartbeat
                    FROM horsies_heartbeats h
                    WHERE h.task_id = t.id AND h.role = 'runner'
                    ORDER BY sent_at DESC
                    LIMIT 1
                ) hb ON TRUE
                WHERE t.status = 'RUNNING'
                  AND t.worker_hostname IS NOT NULL
                GROUP BY t.worker_hostname, t.worker_pid, t.worker_process_name
                ORDER BY active_tasks DESC
                """)
            )

            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]

    async def get_expired_tasks(self) -> list[dict[str, Any]]:
        """
        Find tasks that expired before worker processing.

        Identifies PENDING tasks that exceeded their good_until deadline,
        indicating potential worker capacity issues or scheduling problems.

        Returns:
            List of expired task info: id, task_name, queue_name, good_until, expired_for
        """
        async with self.session_factory() as session:
            result = await session.execute(
                text("""
                SELECT
                    id,
                    task_name,
                    queue_name,
                    priority,
                    sent_at,
                    good_until,
                    NOW() - good_until AS expired_for
                FROM horsies_tasks
                WHERE status = 'PENDING'
                  AND good_until < NOW()
                ORDER BY good_until ASC
                """)
            )

            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]

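    # Illustrative monitoring snippet built on the three queries above (sketch;
    # each helper returns plain dicts keyed by the selected column names):
    #
    #     for w in await broker.get_worker_stats():
    #         print(w['worker_hostname'], w['worker_pid'], w['active_tasks'])
    #     stale = await broker.get_stale_tasks(stale_threshold_minutes=5)
    #     expired = await broker.get_expired_tasks()
    #     print(f'{len(stale)} stale, {len(expired)} expired tasks')
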
    async def mark_stale_tasks_as_failed(
        self, stale_threshold_ms: int = 300_000
    ) -> int:
        """
        Clean up crashed worker tasks by marking them as FAILED.

        Updates RUNNING tasks that haven't received heartbeats within the threshold.
        This is typically called by a cleanup process to handle worker crashes.
        Creates a proper TaskResult with WORKER_CRASHED error code.

        Args:
            stale_threshold_ms: Milliseconds without heartbeat to consider crashed (default: 300000 = 5 minutes)

        Returns:
            Number of tasks marked as failed
        """
        from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
        from horsies.core.codec.serde import dumps_json

        # Convert milliseconds to seconds for PostgreSQL INTERVAL
        stale_threshold_seconds = stale_threshold_ms / 1000.0

        async with self.session_factory() as session:
            # First, find stale tasks and get their metadata
            stale_tasks_result = await session.execute(
                text("""
                SELECT t2.id, t2.worker_pid, t2.worker_hostname, t2.claimed_by_worker_id,
                       t2.started_at, hb.last_heartbeat
                FROM horsies_tasks t2
                LEFT JOIN LATERAL (
                    SELECT sent_at AS last_heartbeat
                    FROM horsies_heartbeats h
                    WHERE h.task_id = t2.id AND h.role = 'runner'
                    ORDER BY sent_at DESC
                    LIMIT 1
                ) hb ON TRUE
                WHERE t2.status = 'RUNNING'
                  AND t2.started_at IS NOT NULL
                  AND COALESCE(hb.last_heartbeat, t2.started_at) < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL)
                """),
                {'stale_threshold': stale_threshold_seconds},
            )

            stale_tasks = stale_tasks_result.fetchall()
            if not stale_tasks:
                return 0

            # Mark each stale task as failed with proper TaskResult
            for task_row in stale_tasks:
                task_id = task_row[0]
                worker_pid = task_row[1]
                worker_hostname = task_row[2]
                worker_id = task_row[3]
                started_at = task_row[4]
                last_heartbeat = task_row[5]

                # Create TaskResult with WORKER_CRASHED error
                task_error = TaskError(
                    error_code=LibraryErrorCode.WORKER_CRASHED,
                    message=f'Worker process crashed (no runner heartbeat for {stale_threshold_ms}ms = {stale_threshold_ms/1000:.1f}s)',
                    data={
                        'stale_threshold_ms': stale_threshold_ms,
                        'stale_threshold_seconds': stale_threshold_seconds,
                        'worker_pid': worker_pid,
                        'worker_hostname': worker_hostname,
                        'worker_id': worker_id,
                        'started_at': started_at.isoformat() if started_at else None,
                        'last_heartbeat': last_heartbeat.isoformat()
                        if last_heartbeat
                        else None,
                        'detected_at': datetime.now(timezone.utc).isoformat(),
                    },
                )
                task_result: TaskResult[None, TaskError] = TaskResult(err=task_error)
                result_json = dumps_json(task_result)

                # Update task with proper result
                await session.execute(
                    text("""
                    UPDATE horsies_tasks
                    SET status = 'FAILED',
                        failed_at = NOW(),
                        failed_reason = :failed_reason,
                        result = :result,
                        updated_at = NOW()
                    WHERE id = :task_id
                    """),
                    {
                        'task_id': task_id,
                        'failed_reason': f'Worker process crashed (no runner heartbeat for {stale_threshold_ms}ms = {stale_threshold_ms/1000:.1f}s)',
                        'result': result_json,
                    },
                )

            await session.commit()
            return len(stale_tasks)

    async def requeue_stale_claimed(self, stale_threshold_ms: int = 120_000) -> int:
        """
        Requeue tasks stuck in CLAIMED without recent claimer heartbeat.

        Args:
            stale_threshold_ms: Milliseconds without heartbeat to consider stale (default: 120000 = 2 minutes)

        Returns:
            Number of tasks requeued
        """
        # Convert milliseconds to seconds for PostgreSQL INTERVAL
        stale_threshold_seconds = stale_threshold_ms / 1000.0

        async with self.session_factory() as session:
            result = await session.execute(
                text("""
                UPDATE horsies_tasks AS t
                SET status = 'PENDING',
                    claimed = FALSE,
                    claimed_at = NULL,
                    claimed_by_worker_id = NULL,
                    updated_at = NOW()
                FROM (
                    SELECT t2.id, hb.last_heartbeat, t2.claimed_at
                    FROM horsies_tasks t2
                    LEFT JOIN LATERAL (
                        SELECT sent_at AS last_heartbeat
                        FROM horsies_heartbeats h
                        WHERE h.task_id = t2.id AND h.role = 'claimer'
                        ORDER BY sent_at DESC
                        LIMIT 1
                    ) hb ON TRUE
                    WHERE t2.status = 'CLAIMED'
                ) s
                WHERE t.id = s.id
                  AND (
                    (s.last_heartbeat IS NULL AND s.claimed_at IS NOT NULL AND s.claimed_at < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL))
                    OR (s.last_heartbeat IS NOT NULL AND s.last_heartbeat < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL))
                  )
                """),
                {'stale_threshold': stale_threshold_seconds},
            )
            await session.commit()
            return getattr(result, 'rowcount', 0)

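    # A periodic janitor could combine the two recovery helpers above
    # (illustrative sketch; the interval and thresholds are assumptions):
    #
    #     async def janitor(broker: 'PostgresBroker') -> None:
    #         while True:
    #             failed = await broker.mark_stale_tasks_as_failed(stale_threshold_ms=300_000)
    #             requeued = await broker.requeue_stale_claimed(stale_threshold_ms=120_000)
    #             if failed or requeued:
    #                 broker.logger.info(f'janitor: failed={failed}, requeued={requeued}')
    #             await asyncio.sleep(60)
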
    # ----------------- Sync API Facades -----------------

    def enqueue(
        self,
        task_name: str,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        queue_name: str = 'default',
        *,
        priority: int = 100,
        sent_at: Optional[datetime] = None,
        good_until: Optional[datetime] = None,
        task_options: Optional[str] = None,
    ) -> str:
        """
        Synchronous task submission (runs enqueue_async in background loop).
        """
        return self._loop_runner.call(
            self.enqueue_async,
            task_name,
            args,
            kwargs,
            queue_name,
            priority=priority,
            sent_at=sent_at,
            good_until=good_until,
            task_options=task_options,
        )

    def get_result(
        self, task_id: str, timeout_ms: Optional[int] = None
    ) -> 'TaskResult[Any, TaskError]':
        """
        Synchronous result retrieval (runs get_result_async in background loop).

        Args:
            task_id: The task ID to retrieve result for
            timeout_ms: Maximum time to wait for result (milliseconds)

        Returns:
            TaskResult - success, task error, retrieval error, or broker error
        """
        return self._loop_runner.call(self.get_result_async, task_id, timeout_ms)

    def close(self) -> None:
        """
        Synchronous cleanup (runs close_async in background loop).
        """
        try:
            self._loop_runner.call(self.close_async)
        finally:
            self._loop_runner.stop()
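
# Sync facade usage from non-async code (illustrative sketch; the calls above are
# proxied onto a background event loop by LoopRunner, and PostgresConfig is
# assumed to accept a `database_url` keyword):
#
#     broker = PostgresBroker(PostgresConfig(database_url='postgresql+asyncpg://user:pass@localhost/horsies'))
#     task_id = broker.enqueue('emails.send_welcome', (42,), {})
#     result = broker.get_result(task_id, timeout_ms=30_000)
#     broker.close()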