horsies 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- horsies/__init__.py +115 -0
- horsies/core/__init__.py +0 -0
- horsies/core/app.py +552 -0
- horsies/core/banner.py +144 -0
- horsies/core/brokers/__init__.py +5 -0
- horsies/core/brokers/listener.py +444 -0
- horsies/core/brokers/postgres.py +864 -0
- horsies/core/cli.py +624 -0
- horsies/core/codec/serde.py +575 -0
- horsies/core/errors.py +535 -0
- horsies/core/logging.py +90 -0
- horsies/core/models/__init__.py +0 -0
- horsies/core/models/app.py +268 -0
- horsies/core/models/broker.py +79 -0
- horsies/core/models/queues.py +23 -0
- horsies/core/models/recovery.py +101 -0
- horsies/core/models/schedule.py +229 -0
- horsies/core/models/task_pg.py +307 -0
- horsies/core/models/tasks.py +332 -0
- horsies/core/models/workflow.py +1988 -0
- horsies/core/models/workflow_pg.py +245 -0
- horsies/core/registry/tasks.py +101 -0
- horsies/core/scheduler/__init__.py +26 -0
- horsies/core/scheduler/calculator.py +267 -0
- horsies/core/scheduler/service.py +569 -0
- horsies/core/scheduler/state.py +260 -0
- horsies/core/task_decorator.py +615 -0
- horsies/core/types/status.py +38 -0
- horsies/core/utils/imports.py +203 -0
- horsies/core/utils/loop_runner.py +44 -0
- horsies/core/worker/current.py +17 -0
- horsies/core/worker/worker.py +1967 -0
- horsies/core/workflows/__init__.py +23 -0
- horsies/core/workflows/engine.py +2344 -0
- horsies/core/workflows/recovery.py +501 -0
- horsies/core/workflows/registry.py +97 -0
- horsies/py.typed +0 -0
- horsies-0.1.0a1.dist-info/METADATA +31 -0
- horsies-0.1.0a1.dist-info/RECORD +42 -0
- horsies-0.1.0a1.dist-info/WHEEL +5 -0
- horsies-0.1.0a1.dist-info/entry_points.txt +2 -0
- horsies-0.1.0a1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1967 @@
|
|
|
1
|
+
# app/core/worker.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
import os
|
|
6
|
+
import random
|
|
7
|
+
import signal
|
|
8
|
+
import socket
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from datetime import datetime, timezone, timedelta
|
|
14
|
+
from importlib import import_module
|
|
15
|
+
from typing import Any, Optional, Sequence, Tuple, cast
|
|
16
|
+
import atexit
|
|
17
|
+
import hashlib
|
|
18
|
+
from psycopg import Connection, Cursor, InterfaceError, OperationalError
|
|
19
|
+
from psycopg.errors import DeadlockDetected, SerializationFailure
|
|
20
|
+
from psycopg.types.json import Jsonb
|
|
21
|
+
from psycopg_pool import ConnectionPool
|
|
22
|
+
from sqlalchemy import text
|
|
23
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
24
|
+
from horsies.core.app import Horsies
|
|
25
|
+
from horsies.core.brokers.listener import PostgresListener
|
|
26
|
+
from horsies.core.codec.serde import (
|
|
27
|
+
loads_json,
|
|
28
|
+
json_to_args,
|
|
29
|
+
json_to_kwargs,
|
|
30
|
+
dumps_json,
|
|
31
|
+
SerializationError,
|
|
32
|
+
task_result_from_json,
|
|
33
|
+
)
|
|
34
|
+
from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
|
|
35
|
+
from horsies.core.errors import ConfigurationError, ErrorCode
|
|
36
|
+
from horsies.core.logging import get_logger
|
|
37
|
+
from horsies.core.worker.current import get_current_app, set_current_app
|
|
38
|
+
import sys
|
|
39
|
+
from horsies.core.models.recovery import RecoveryConfig
|
|
40
|
+
from horsies.core.utils.imports import import_file_path
|
|
41
|
+
|
|
42
|
+
# Logger shared by every helper in this module.
logger = get_logger('worker')

# ---------- Per-process connection pool (initialized in child processes) ----------
# Assigned by _initialize_worker_pool() and reset by _cleanup_worker_pool();
# None until the pool has been created in the current process.
_worker_pool: Optional[ConnectionPool] = None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _get_worker_pool() -> ConnectionPool:
    """Return this process's connection pool.

    Raises:
        RuntimeError: if the module-level ``_worker_pool`` is still None, i.e.
            ``_initialize_worker_pool`` has not run in this process.
    """
    pool = _worker_pool
    if pool is not None:
        return pool
    raise RuntimeError(
        'Worker connection pool not initialized. '
        'This function must be called from a child worker process.'
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _cleanup_worker_pool() -> None:
    """Close and forget the per-process connection pool (best effort).

    Registered with ``atexit`` by ``_initialize_worker_pool``. Any exception
    raised by ``close()`` is ignored on purpose: this runs on process exit and
    there is nothing useful left to do with the error. Calling it again once
    the pool is gone is a no-op.
    """
    global _worker_pool
    pool = _worker_pool
    if pool is None:
        return
    try:
        pool.close()
    except Exception:
        pass  # shutdown path: a failed close() must not raise out of atexit
    _worker_pool = None
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _initialize_worker_pool(database_url: str) -> None:
    """
    Create the per-process connection pool if it does not exist yet.

    In production this is invoked by ``_child_initializer`` inside each
    spawned worker process; tests may call it directly before exercising
    ``_run_task_entry``. While a pool already exists, further calls do nothing.

    The pool is opened immediately (``open=True``) with between 1 and 5
    connections and a ``max_lifetime`` of 300 seconds, and
    ``_cleanup_worker_pool`` is registered with ``atexit`` to close it.

    Args:
        database_url: connection string handed to ``ConnectionPool``.
    """
    global _worker_pool
    if _worker_pool is None:
        _worker_pool = ConnectionPool(
            database_url,
            min_size=1,
            max_size=5,
            max_lifetime=300.0,
            open=True,
        )
        atexit.register(_cleanup_worker_pool)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _default_str_list() -> list[str]:
    """Return a fresh empty ``list[str]`` (used as a dataclass ``default_factory``)."""
    return list()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _default_str_int_dict() -> dict[str, int]:
    """Return a fresh empty ``dict[str, int]`` (used as a dataclass ``default_factory``)."""
    return dict()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _dedupe_paths(paths: Sequence[str]) -> list[str]:
    """Return *paths* with empty entries and repeats removed, keeping first-seen order."""
    # dict keeps insertion order, so fromkeys() retains each path's first occurrence.
    return list(dict.fromkeys(p for p in paths if p))
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _debug_imports_enabled() -> bool:
    """Return True when HORSIES_DEBUG_IMPORTS equals '1' after stripping whitespace."""
    flag = os.getenv('HORSIES_DEBUG_IMPORTS', '')
    return flag.strip() == '1'
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _debug_imports_log(message: str) -> None:
    """Log *message* at DEBUG level, but only when import debugging is switched on."""
    if not _debug_imports_enabled():
        return
    logger.debug(message)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _derive_sys_path_roots_from_file(file_path: str) -> list[str]:
    """Derive sys.path roots from a file path.

    Returns just the directory that contains the (symlink-resolved) file.
    There is deliberately no upward search for pyproject.toml, to avoid
    monorepo collisions; sys_path_roots passed from the CLI remain the
    authoritative source.
    """
    parent_dir = os.path.dirname(os.path.realpath(file_path))
    if not parent_dir:
        return []
    return [parent_dir]
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _build_sys_path_roots(
    app_locator: str,
    imports: Sequence[str],
    extra_roots: Sequence[str],
) -> list[str]:
    """Assemble the de-duplicated list of directories to add to sys.path.

    Entries are collected in this order, and only the first occurrence of
    each is kept:
      1. every non-empty ``extra_roots`` entry, made absolute;
      2. the app module's directory, when the module part of ``app_locator``
         looks like a file (``.py`` suffix or contains a path separator);
      3. the directory of every file-like entry in ``imports``.
    """
    def looks_like_file(spec: str) -> bool:
        return spec.endswith('.py') or os.path.sep in spec

    roots = [os.path.abspath(r) for r in extra_roots if r]

    if app_locator and ':' in app_locator:
        app_module = app_locator.split(':', 1)[0]
        if looks_like_file(app_module):
            roots += _derive_sys_path_roots_from_file(app_module)

    for spec in imports:
        if looks_like_file(spec):
            roots += _derive_sys_path_roots_from_file(spec)

    return _dedupe_paths(roots)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def import_by_path(path: str, module_name: str | None = None) -> Any:
|
|
151
|
+
"""Import module from file path."""
|
|
152
|
+
return import_file_path(path, module_name)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ---------- helpers for child processes ----------
def _locate_app(app_locator: str) -> Horsies:
    """
    Resolve an app locator string to a Horsies instance.

    Accepted forms:
      - 'package.module:app'        -> import the module, take the attribute
      - 'package.module:create_app' -> import the module, call the factory
      - '/abs/path/to/file.py:app'  -> load the module from a file path

    The module part is loaded by file path when it ends in '.py' or contains
    a path separator, otherwise by dotted name. A callable attribute is called
    with no arguments and its return value used.

    Raises:
        ConfigurationError: if the locator has no ':' or does not resolve to a
            Horsies instance. Import errors and AttributeError (missing
            attribute) propagate unchanged.
    """
    logger.info(f'Locating app from {app_locator}')
    if not app_locator or ':' not in app_locator:
        raise ConfigurationError(
            message='invalid app locator format',
            code=ErrorCode.WORKER_INVALID_LOCATOR,
            notes=[f'got: {app_locator!r}'],
            help_text="use 'module.path:app' or '/path/to/file.py:app'",
        )

    module_spec, _, attr_name = app_locator.partition(':')
    is_file_path = module_spec.endswith('.py') or os.path.sep in module_spec
    module = import_by_path(module_spec) if is_file_path else import_module(module_spec)

    target = getattr(module, attr_name)
    resolved = target() if callable(target) else target
    if isinstance(resolved, Horsies):
        return resolved
    raise ConfigurationError(
        message='app locator did not resolve to Horsies instance',
        code=ErrorCode.WORKER_INVALID_LOCATOR,
        notes=[
            f'locator: {app_locator!r}',
            f'resolved to: {type(resolved).__name__}',
        ],
        help_text='ensure the locator points to a Horsies app instance or factory',
    )
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def _child_initializer(
    app_locator: str,
    imports: Sequence[str],
    sys_path_roots: Sequence[str],
    loglevel: int,
    database_url: str,
) -> None:
    """Prepare a freshly spawned worker process before it executes tasks.

    Steps, in order:
      1. ignore SIGINT so the parent process alone drives graceful shutdown;
      2. apply *loglevel* before anything else logs;
      3. set HORSIES_CHILD_PROCESS=1 so logging can adjust during imports;
      4. prepend the derived sys.path roots (skipping ones already present);
      5. locate the app and register it as the current app;
      6. import the task modules (explicit *imports* plus the app's discovered
         task modules) with sends suppressed, skipping the app file itself;
      7. open the per-process connection pool.
    """
    # The parent handles Ctrl-C; children must not be interrupted by it.
    signal.signal(signal.SIGINT, signal.SIG_IGN)

    # Configure this child's log level before any logging happens.
    from horsies.core.logging import set_default_level

    set_default_level(loglevel)

    # Mark the child process so logging behaves accordingly during module import.
    os.environ['HORSIES_CHILD_PROCESS'] = '1'

    app_module_spec = app_locator.split(':', 1)[0]
    if app_module_spec.endswith('.py'):
        app_file_abs = os.path.abspath(app_module_spec)
    else:
        app_file_abs = None

    resolved_roots = _build_sys_path_roots(app_locator, imports, sys_path_roots)
    _debug_imports_log(
        f'[child {os.getpid()}] app_locator={app_locator!r} sys_path_roots={resolved_roots}'
    )
    for root in resolved_roots:
        if root in sys.path:
            continue
        sys.path.insert(0, root)

    app = _locate_app(app_locator)  # file locators go through import_by_path
    set_current_app(app)

    # Best effort: keep accidental sends from firing while modules are imported.
    try:
        app.suppress_sends(True)
    except Exception:
        pass

    try:
        to_import = list(imports)
        try:
            to_import.extend(app.get_discovered_task_modules())
        except Exception:
            pass
        to_import = _dedupe_paths(to_import)
        _debug_imports_log(f'[child {os.getpid()}] import_modules={to_import}')
        for entry in to_import:
            if not (entry.endswith('.py') or os.path.sep in entry):
                import_module(entry)
                continue
            entry_abs = os.path.abspath(entry)
            if app_file_abs and os.path.samefile(entry_abs, app_file_abs):
                continue  # the app file was already loaded by _locate_app
            import_by_path(entry_abs)
    finally:
        try:
            app.suppress_sends(False)
        except Exception:
            pass

    # Optional sanity log of what ended up registered.
    try:
        registered = app.tasks.keys_list()
    except Exception:
        registered = list(app.tasks.keys())
    _debug_imports_log(f'[child {os.getpid()}] registered_tasks={registered}')

    # Create the pool only after every import has completed.
    _initialize_worker_pool(database_url)
    logger.debug(f'[child {os.getpid()}] Connection pool initialized')
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# ---------- Child-process execution ----------
def _heartbeat_worker(
    task_id: str,
    database_url: str,
    stop_event: threading.Event,
    sender_worker_id: str,
    heartbeat_interval_ms: int = 30_000,
) -> None:
    """
    Record 'runner' heartbeats for a task until told to stop.

    Meant to run in a background thread inside the task's process. One
    heartbeat is sent immediately, then one every ``heartbeat_interval_ms``
    until ``stop_event`` is set. A failed heartbeat is logged and retried at
    the next interval rather than ending the loop, so a transient database
    error does not silence heartbeats for the remainder of a long task.

    Uses the per-process connection pool (``_get_worker_pool``).

    Args:
        task_id: The task ID
        database_url: Database connection string (unused, kept for compatibility)
        stop_event: Event to signal thread termination
        sender_worker_id: Worker instance ID; ``:<pid>`` is appended to form
            the heartbeat's sender_id.
        heartbeat_interval_ms: Milliseconds between heartbeats (default: 30000 = 30s)
    """
    # threading.Event.wait() takes seconds.
    heartbeat_interval_seconds = heartbeat_interval_ms / 1000.0
    # Loop-invariant values, computed once instead of on every heartbeat.
    pid = os.getpid()
    sender_id = sender_worker_id + f':{pid}'
    hostname = socket.gethostname()

    def send_heartbeat() -> bool:
        try:
            pool = _get_worker_pool()
            with pool.connection() as conn:
                # Context manager closes the cursor even if execute() raises.
                with conn.cursor() as cursor:
                    cursor.execute(
                        """
                        INSERT INTO horsies_heartbeats (task_id, sender_id, role, sent_at, hostname, pid)
                        VALUES (%s, %s, 'runner', NOW(), %s, %s)
                        """,
                        (task_id, sender_id, hostname, pid),
                    )
                conn.commit()
            return True
        except Exception as e:
            logger.error(f'Heartbeat failed for task {task_id}: {e}')
            return False

    # Send an immediate heartbeat so freshly RUNNING tasks aren't considered stale
    send_heartbeat()

    # Event.wait() returns True as soon as stop_event is set, which ends the
    # loop without sleeping out the rest of the interval.
    while not stop_event.wait(timeout=heartbeat_interval_seconds):
        # Failures are logged inside send_heartbeat(); keep trying on the next
        # interval instead of abandoning heartbeats for the rest of the task.
        send_heartbeat()
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _is_retryable_db_error(exc: BaseException) -> bool:
    """Return True when *exc* is a database error worth retrying.

    Operational and interface errors (typically connection trouble) plus
    serialization failures and deadlocks are treated as transient; every
    other exception is not.
    """
    retryable_types = (
        OperationalError,
        InterfaceError,
        SerializationFailure,
        DeadlockDetected,
    )
    return isinstance(exc, retryable_types)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def _get_workflow_status_for_task(cursor: Cursor[Any], task_id: str) -> str | None:
    """Look up the status of the workflow that *task_id* belongs to.

    Returns the status string, or None when no workflow row matches or the
    status value is not a string.
    """
    cursor.execute(
        """
        SELECT w.status FROM horsies_workflows w
        JOIN horsies_workflow_tasks wt ON wt.workflow_id = w.id
        WHERE wt.task_id = %s
        """,
        (task_id,),
    )
    row = cursor.fetchone()
    status = None if row is None else row[0]
    return status if isinstance(status, str) else None
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def _mark_task_skipped_for_workflow_stop(
    cursor: Cursor[Any],
    conn: Connection[Any],
    task_id: str,
    workflow_status: str,
) -> Tuple[bool, str, Optional[str]]:
    """Record that *task_id* will not run because its workflow is stopped.

    Using the caller's cursor and connection, this:
      * sets the task's workflow_tasks row to SKIPPED if it is still ENQUEUED
        or READY;
      * stores a WORKFLOW_STOPPED error TaskResult on the task row and marks
        it COMPLETED;
    and then commits.

    Returns:
        ``(True, result_json, 'Workflow <workflow_status>')``.
    """
    logger.info(f'Skipping task {task_id} - workflow is {workflow_status}')
    cursor.execute(
        """
        UPDATE horsies_workflow_tasks
        SET status = 'SKIPPED'
        WHERE task_id = %s AND status IN ('ENQUEUED', 'READY')
        """,
        (task_id,),
    )

    skipped_error = TaskError(
        error_code='WORKFLOW_STOPPED',
        message=f'Task skipped - workflow is {workflow_status}',
    )
    result_json = dumps_json(TaskResult(err=skipped_error))

    cursor.execute(
        """
        UPDATE horsies_tasks
        SET status = 'COMPLETED',
            result = %s,
            completed_at = NOW(),
            updated_at = NOW()
        WHERE id = %s
        """,
        (result_json, task_id),
    )
    conn.commit()
    return True, result_json, f'Workflow {workflow_status}'
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def _update_workflow_task_running_with_retry(task_id: str) -> None:
    """Best effort: move the task's workflow_tasks row from ENQUEUED to RUNNING.

    The UPDATE only matches while the owning workflow is RUNNING. Up to three
    attempts are made (immediately, then after 0.25s and 0.75s); only errors
    that ``_is_retryable_db_error`` accepts are retried. Any ``Exception`` is
    logged and swallowed; this function never raises one.
    """
    backoff_seconds = (0.0, 0.25, 0.75)
    last_attempt = len(backoff_seconds) - 1
    for attempt_index, delay in enumerate(backoff_seconds):
        if delay > 0:
            time.sleep(delay)
        try:
            pool = _get_worker_pool()
            with pool.connection() as conn:
                # Context manager closes the cursor even when execute() raises.
                with conn.cursor() as cursor:
                    cursor.execute(
                        """
                        UPDATE horsies_workflow_tasks wt
                        SET status = 'RUNNING', started_at = NOW()
                        FROM horsies_workflows w
                        WHERE wt.task_id = %s
                          AND wt.status = 'ENQUEUED'
                          AND wt.workflow_id = w.id
                          AND w.status = 'RUNNING'
                        """,
                        (task_id,),
                    )
                conn.commit()
            return
        except Exception as exc:
            # Retry only transient DB errors, and only while attempts remain.
            if _is_retryable_db_error(exc) and attempt_index < last_attempt:
                logger.warning(
                    f'Retrying workflow_tasks RUNNING update for task {task_id}: {exc}'
                )
                continue
            logger.error(
                f'Failed to update workflow_tasks to RUNNING for task {task_id}: {exc}'
            )
            return
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def _preflight_workflow_check(
    task_id: str,
) -> Optional[Tuple[bool, str, Optional[str]]]:
    """Guard against running a task whose workflow has been stopped.

    Returns:
        None when execution may proceed (no workflow, or a workflow that is
        neither PAUSED nor CANCELLED); the result of
        ``_mark_task_skipped_for_workflow_stop`` when it is PAUSED/CANCELLED;
        ``(False, '', 'WORKFLOW_CHECK_FAILED')`` if the check itself failed.
    """
    try:
        with _get_worker_pool().connection() as conn:
            cursor = conn.cursor()
            workflow_status = _get_workflow_status_for_task(cursor, task_id)
            if workflow_status is not None and workflow_status in ('PAUSED', 'CANCELLED'):
                return _mark_task_skipped_for_workflow_stop(
                    cursor, conn, task_id, workflow_status
                )
            return None
    except Exception as e:
        logger.error(f'Failed to check workflow status for task {task_id}: {e}')
        return (False, '', 'WORKFLOW_CHECK_FAILED')
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def _confirm_ownership_and_set_running(
|
|
466
|
+
task_id: str,
|
|
467
|
+
worker_id: str,
|
|
468
|
+
) -> Optional[Tuple[bool, str, Optional[str]]]:
|
|
469
|
+
try:
|
|
470
|
+
pool = _get_worker_pool()
|
|
471
|
+
with pool.connection() as conn:
|
|
472
|
+
cursor = conn.cursor()
|
|
473
|
+
cursor.execute(
|
|
474
|
+
"""
|
|
475
|
+
UPDATE horsies_tasks
|
|
476
|
+
SET status = 'RUNNING',
|
|
477
|
+
claimed = FALSE,
|
|
478
|
+
claim_expires_at = NULL,
|
|
479
|
+
started_at = NOW(),
|
|
480
|
+
worker_pid = %s,
|
|
481
|
+
worker_hostname = %s,
|
|
482
|
+
worker_process_name = %s,
|
|
483
|
+
updated_at = NOW()
|
|
484
|
+
WHERE id = %s
|
|
485
|
+
AND status = 'CLAIMED'
|
|
486
|
+
AND claimed_by_worker_id = %s
|
|
487
|
+
AND (claim_expires_at IS NULL OR claim_expires_at > now())
|
|
488
|
+
AND NOT EXISTS (
|
|
489
|
+
SELECT 1
|
|
490
|
+
FROM horsies_workflow_tasks wt
|
|
491
|
+
JOIN horsies_workflows w ON w.id = wt.workflow_id
|
|
492
|
+
WHERE wt.task_id = %s
|
|
493
|
+
AND w.status IN ('PAUSED', 'CANCELLED')
|
|
494
|
+
)
|
|
495
|
+
RETURNING id
|
|
496
|
+
""",
|
|
497
|
+
(
|
|
498
|
+
os.getpid(),
|
|
499
|
+
socket.gethostname(),
|
|
500
|
+
f'worker-{os.getpid()}',
|
|
501
|
+
task_id,
|
|
502
|
+
worker_id,
|
|
503
|
+
task_id,
|
|
504
|
+
),
|
|
505
|
+
)
|
|
506
|
+
updated_row = cursor.fetchone()
|
|
507
|
+
if updated_row is None:
|
|
508
|
+
workflow_status = _get_workflow_status_for_task(cursor, task_id)
|
|
509
|
+
match workflow_status:
|
|
510
|
+
case 'PAUSED' | 'CANCELLED':
|
|
511
|
+
return _mark_task_skipped_for_workflow_stop(
|
|
512
|
+
cursor,
|
|
513
|
+
conn,
|
|
514
|
+
task_id,
|
|
515
|
+
workflow_status,
|
|
516
|
+
)
|
|
517
|
+
case _:
|
|
518
|
+
conn.rollback()
|
|
519
|
+
logger.warning(
|
|
520
|
+
f'Task {task_id} ownership lost - claim was reclaimed or status changed. '
|
|
521
|
+
f'Aborting execution to prevent double-execution.'
|
|
522
|
+
)
|
|
523
|
+
return (False, '', 'CLAIM_LOST')
|
|
524
|
+
|
|
525
|
+
conn.commit()
|
|
526
|
+
|
|
527
|
+
_update_workflow_task_running_with_retry(task_id)
|
|
528
|
+
return None
|
|
529
|
+
except Exception as e:
|
|
530
|
+
logger.error(f'Failed to transition task {task_id} to RUNNING: {e}')
|
|
531
|
+
return (False, '', 'OWNERSHIP_UNCONFIRMED')
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _start_heartbeat_thread(
    task_id: str,
    database_url: str,
    heartbeat_stop_event: threading.Event,
    worker_id: str,
    runner_heartbeat_interval_ms: int,
) -> threading.Thread:
    """Launch ``_heartbeat_worker`` for *task_id* on a daemon thread.

    The thread is named ``heartbeat-<first 8 chars of task_id>`` and is
    returned already started; set *heartbeat_stop_event* to make it exit.
    """
    worker_args = (
        task_id,
        database_url,
        heartbeat_stop_event,
        worker_id,
        runner_heartbeat_interval_ms,
    )
    thread = threading.Thread(
        target=_heartbeat_worker,
        args=worker_args,
        daemon=True,
        name=f'heartbeat-{task_id[:8]}',
    )
    thread.start()
    return thread
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def _deserialize_injected_task_results(kwargs: dict[str, Any]) -> None:
    """Replace workflow-injected TaskResult markers in *kwargs*, in place.

    Values injected via workflow ``args_from`` arrive as dicts with a truthy
    ``'__horsies_taskresult__'`` flag and a serialized TaskResult under
    ``'data'``. Each such value whose ``'data'`` is a string is replaced by
    the decoded TaskResult; all other values are left untouched.
    """
    for key, value in list(kwargs.items()):
        if not (
            isinstance(value, dict)
            and '__horsies_taskresult__' in value
            and value['__horsies_taskresult__']
        ):
            continue
        data_str = cast(dict[str, Any], value).get('data')
        if isinstance(data_str, str):
            kwargs[key] = task_result_from_json(loads_json(data_str))


def _task_declares_workflow_ctx(task: Any) -> bool:
    """Return True if the task's underlying function has a ``workflow_ctx`` parameter."""
    import inspect

    # TaskFunctionImpl stores the wrapped function in _original_fn; fall back
    # to _fn and finally to the task object itself.
    underlying_fn = getattr(task, '_original_fn', getattr(task, '_fn', task))
    return 'workflow_ctx' in inspect.signature(underlying_fn).parameters


def _rebuild_workflow_ctx_for_task(workflow_ctx_data: dict[str, Any]) -> Any:
    """Rebuild the WorkflowContext handed to a task from its serialized form.

    ``results_by_id`` and ``summaries_by_id`` map node ids to JSON strings.
    Non-string entries, and summaries that do not decode to a dict, are dropped.
    """
    from horsies.core.models.workflow import (
        WorkflowContext,
        SubWorkflowSummary,
    )

    # Reconstruct TaskResults from serialized data (node_id-based)
    results_by_id: dict[str, Any] = {}
    for node_id, result_json in workflow_ctx_data.get('results_by_id', {}).items():
        if isinstance(result_json, str):
            results_by_id[node_id] = task_result_from_json(loads_json(result_json))

    # Reconstruct SubWorkflowSummaries from serialized data (node_id-based)
    summaries_by_id: dict[str, Any] = {}
    for node_id, summary_json in workflow_ctx_data.get('summaries_by_id', {}).items():
        if not isinstance(summary_json, str):
            continue
        parsed = loads_json(summary_json)
        if isinstance(parsed, dict):
            summaries_by_id[node_id] = SubWorkflowSummary.from_json(parsed)

    return WorkflowContext.from_serialized(
        workflow_id=workflow_ctx_data.get('workflow_id', ''),
        task_index=workflow_ctx_data.get('task_index', 0),
        task_name=workflow_ctx_data.get('task_name', ''),
        results_by_id=results_by_id,
        summaries_by_id=summaries_by_id,
    )


def _run_task_entry(
    task_name: str,
    args_json: Optional[str],
    kwargs_json: Optional[str],
    task_id: str,
    database_url: str,
    master_worker_id: str,
    runner_heartbeat_interval_ms: int = 30_000,
) -> Tuple[bool, str, Optional[str]]:
    """
    Child-process entry.
    Returns:
        (ok, serialized_task_result_json, worker_failure_reason)

        ok=True: we produced a valid TaskResult JSON (success *or* failure).
        ok=False: worker couldn't even produce a TaskResult (infra error).

    Do not mistake the 'ok' here as the task's own success/failure.
    It's only a signal that the worker was able to produce a valid JSON.
    The task's own success/failure is determined by the TaskResult's ok/err fields.

    Sequence: workflow preflight check -> ownership confirmation (CLAIMED ->
    RUNNING) -> heartbeat thread start -> task resolution, argument decoding
    and execution. The heartbeat thread is always signalled to stop on exit.
    """
    logger.info(f'Starting task execution: {task_name}')

    # Pre-execute guard: if the workflow is PAUSED or CANCELLED, skip execution
    # (the task is recorded as skipped instead of running).
    preflight_result = _preflight_workflow_check(task_id)
    if preflight_result is not None:
        return preflight_result

    # Mark as RUNNING in DB at the actual start of execution (child process).
    # CRITICAL: the ownership check prevents double-execution when a claim lease expires.
    ownership_result = _confirm_ownership_and_set_running(task_id, master_worker_id)
    if ownership_result is not None:
        return ownership_result

    # Start heartbeat monitoring in a separate thread.
    heartbeat_stop_event = threading.Event()
    _start_heartbeat_thread(
        task_id,
        database_url,
        heartbeat_stop_event,
        master_worker_id,
        runner_heartbeat_interval_ms,
    )

    # Declared once; assigned on several error paths below.
    tr: TaskResult[Any, TaskError]
    try:
        # Resolve from TaskRegistry
        try:
            app = get_current_app()
            task = app.tasks[task_name]  # may raise NotRegistered(KeyError)
        except Exception as e:
            logger.error(f'Failed to resolve task {task_name}: {e}')
            tr = TaskResult(
                err=TaskError(
                    error_code=LibraryErrorCode.WORKER_RESOLUTION_ERROR,
                    message=f'Failed to resolve {task_name}: {type(e).__name__}: {e}',
                    data={'task_name': task_name},
                )
            )
            # Stop heartbeat thread since task failed to resolve
            heartbeat_stop_event.set()
            logger.debug(
                f'Heartbeat stop signal sent for task {task_id} (resolution failed)'
            )
            return (True, dumps_json(tr), f'{type(e).__name__}: {e}')

        args = json_to_args(loads_json(args_json))
        kwargs = json_to_kwargs(loads_json(kwargs_json))

        # Deserialize injected TaskResults from workflow args_from
        _deserialize_injected_task_results(kwargs)

        # WorkflowContext is injected only if the task declares a workflow_ctx
        # parameter; otherwise the serialized context is silently dropped.
        workflow_ctx_data = kwargs.pop('__horsies_workflow_ctx__', None)
        if workflow_ctx_data is not None and _task_declares_workflow_ctx(task):
            kwargs['workflow_ctx'] = _rebuild_workflow_ctx_for_task(workflow_ctx_data)

        out = task(*args, **kwargs)  # __call__ returns TaskResult
        logger.info(f'Task execution completed: [{task_id}] : [{task_name}]')

        if isinstance(out, TaskResult):
            return (True, dumps_json(out), None)

        if out is None:
            tr = TaskResult(
                err=TaskError(
                    error_code=LibraryErrorCode.TASK_EXCEPTION,
                    message=f'Task {task_name} returned None instead of TaskResult or value',
                    data={'task_name': task_name},
                )
            )
            return (True, dumps_json(tr), 'Task returned None')

        # Plain value → wrap into success
        return (True, dumps_json(TaskResult(ok=out)), None)

    except SerializationError as se:
        logger.error(f'Serialization error for task {task_name}: {se}')
        tr = TaskResult(
            err=TaskError(
                error_code=LibraryErrorCode.WORKER_SERIALIZATION_ERROR,
                message=str(se),
                data={'task_name': task_name},
            )
        )
        return (True, dumps_json(tr), f'SerializationError: {se}')

    except BaseException as e:
        # BaseException on purpose: anything the task raises (including
        # SystemExit) is reported as a TaskError rather than killing the child.
        logger.error(f'Task exception: {task_name}: {e}')
        tr = TaskResult(
            err=TaskError(
                error_code=LibraryErrorCode.TASK_EXCEPTION,
                message=f'{type(e).__name__}: {e}',
                data={'task_name': task_name},
                exception=e,  # pydantic will accept and we flatten on serialization if needed elsewhere
            )
        )
        return (True, dumps_json(tr), None)

    finally:
        # Always stop the heartbeat thread when task completes (success/failure/error)
        heartbeat_stop_event.set()
        logger.debug(f'Heartbeat stop signal sent for task {task_id}')
|
|
730
|
+
|
|
731
|
+
|
|
732
|
+
# ---------- Claim SQL (priority + sent_at) ----------
# Supports both hard cap mode (claim_expires_at = NULL) and soft cap mode with lease.
# In soft cap mode, also reclaims tasks with expired claim leases.
#
# Bind parameters: :queue (queue name), :lim (max rows to claim),
# :worker_id (stored in claimed_by_worker_id), :claim_expires_at (lease end,
# or NULL).
# Candidate rows are PENDING, or CLAIMED with an expired lease. They must also
# be due (sent_at and next_retry_at not in the future) and not past
# good_until. Candidates are ordered by priority ASC, then sent_at, then id.
# FOR UPDATE SKIP LOCKED lets concurrent workers claim disjoint rows without
# blocking on each other. The outer UPDATE marks the chosen rows CLAIMED and
# returns their ids.
CLAIM_SQL = text("""
WITH next AS (
    SELECT id
    FROM horsies_tasks
    WHERE queue_name = :queue
      AND (
        -- Fresh pending tasks
        status = 'PENDING'
        -- OR expired claims (soft cap mode: lease expired, reclaim for this worker)
        OR (status = 'CLAIMED' AND claim_expires_at IS NOT NULL AND claim_expires_at < now())
      )
      AND sent_at <= now()
      AND (next_retry_at IS NULL OR next_retry_at <= now())
      AND (good_until IS NULL OR good_until > now())
    ORDER BY priority ASC, sent_at ASC, id ASC
    FOR UPDATE SKIP LOCKED
    LIMIT :lim
)
UPDATE horsies_tasks t
SET status = 'CLAIMED',
    claimed = TRUE,
    claimed_at = now(),
    claimed_by_worker_id = :worker_id,
    claim_expires_at = :claim_expires_at,
    updated_at = now()
FROM next
WHERE t.id = next.id
RETURNING t.id;
""")
|
|
765
|
+
|
|
766
|
+
# Fetch the execution details for a list of task ids (:ids).
# The selected columns are task_name, args, kwargs, retry_count, max_retries
# and task_options.
# NOTE(review): presumably called with the ids returned by CLAIM_SQL; confirm
# against the caller.
LOAD_ROWS_SQL = text("""
SELECT id, task_name, args, kwargs, retry_count, max_retries, task_options
FROM horsies_tasks
WHERE id = ANY(:ids)
""")
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
# ---------- WorkerConfig ----------
@dataclass
class WorkerConfig:
    """Settings consumed by ``Worker``.

    Only ``dsn``, ``psycopg_dsn`` and ``queues`` are required; every other
    field has a default.
    """

    dsn: str  # SQLAlchemy async URL (e.g. postgresql+psycopg://...)
    psycopg_dsn: str  # plain psycopg URL for listener
    queues: list[str]  # which queues to serve
    # Size of the process pool. The default is evaluated once, when this class
    # is defined: os.cpu_count(), or 2 if the CPU count is unknown.
    processes: int = os.cpu_count() or 2
    # Claiming knobs
    # max_claim_batch: Top-level fairness limiter to prevent worker starvation in multi-worker setups.
    # Limits claims per queue per pass, regardless of available capacity. Increase for high-concurrency workloads.
    # Only applied when queue_priorities is empty; when queue_priorities is set
    # (strict priority mode) a queue may take the whole remaining budget.
    max_claim_batch: int = 2
    # max_claim_per_worker: Per-worker limit on tasks held in CLAIMED state (not
    # yet RUNNING); no claim pass starts while this worker is at the limit.
    # 0 = auto: processes in hard cap mode (prefetch_buffer == 0), or
    # processes + prefetch_buffer in soft cap mode.
    # Increase for deeper prefetch if tasks start very quickly.
    max_claim_per_worker: int = 0
    # Upper bound on queued notifications discarded after a wake-up, so a burst
    # of NOTIFYs results in a single claim cycle.
    coalesce_notifies: int = 100  # drain up to N notes after wake
    app_locator: str = ''  # locator of the app object (see _locate_app)
    # Input to _build_sys_path_roots; the resolved roots are prepended to
    # sys.path before the app and task modules are imported.
    sys_path_roots: list[str] = field(default_factory=_default_str_list)
    imports: list[str] = field(
        default_factory=_default_str_list
    )  # modules that contain @app.task defs
    # Per-queue settings, each a {queue_name: int} mapping:
    # - queue_priorities: when non-empty, queues are claimed from in ascending
    #   priority order, and queues missing from this mapping are not claimed from.
    # - queue_max_concurrency: cluster-wide in-flight cap per queue; only
    #   enforced when queue_priorities is non-empty.
    queue_priorities: dict[str, int] = field(default_factory=_default_str_int_dict)
    queue_max_concurrency: dict[str, int] = field(default_factory=_default_str_int_dict)
    # Cap on RUNNING + CLAIMED tasks across the whole cluster; None = unlimited.
    cluster_wide_cap: Optional[int] = None
    # Prefetch buffer: 0 = hard cap mode (count RUNNING + CLAIMED), >0 = soft cap with lease
    prefetch_buffer: int = 0
    # Claim lease duration in ms. Required when prefetch_buffer > 0.
    # None means claims are written with claim_expires_at = NULL (no lease).
    claim_lease_ms: Optional[int] = None
    # Recovery configuration from AppConfig
    recovery_config: Optional['RecoveryConfig'] = (
        None  # RecoveryConfig, avoid circular import
    )
    # Log level for worker processes (default: INFO)
    loglevel: int = 20  # logging.INFO
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
class Worker:
|
|
811
|
+
"""
|
|
812
|
+
Async master that:
|
|
813
|
+
- Subscribes to queue channels
|
|
814
|
+
- Claims tasks (priority + sent_at) with SKIP LOCKED
|
|
815
|
+
- Executes in a process pool
|
|
816
|
+
- On completion, writes result/failed, COMMITs, and NOTIFY task_done
|
|
817
|
+
"""
|
|
818
|
+
|
|
819
|
+
def __init__(
    self,
    session_factory: async_sessionmaker[AsyncSession],
    listener: PostgresListener,
    cfg: WorkerConfig,
):
    """Store collaborators and initialise per-instance state.

    No I/O happens here; ``start()`` connects, spawns the pool and subscribes.

    Args:
        session_factory: Factory for the AsyncSession objects used for DB access.
        listener: Postgres LISTEN/NOTIFY listener used to subscribe to channels.
        cfg: Worker configuration.
    """
    self.sf = session_factory
    self.listener = listener
    self.cfg = cfg
    # Identity written to claimed_by_worker_id when this worker claims tasks.
    self.worker_instance_id = str(uuid.uuid4())
    self._started_at = datetime.now(timezone.utc)
    self._app: Horsies | None = None
    # Delay creation of the process pool until after preloading modules so that
    # any import/validation errors surface in the main process at startup.
    self._executor: Optional[ProcessPoolExecutor] = None
    self._stop = asyncio.Event()
    # Notification queues, filled in by start(). Declared here so the
    # attributes exist before start() runs.
    self._queues: list[Any] = []
    self._global: Any = None
    # Strong references to fire-and-forget asyncio tasks. The event loop only
    # keeps weak references to tasks, so an unreferenced task can disappear
    # before it finishes. Long-lived loops go in _background_tasks; per-task
    # finalizers go in _finalize_tasks so stop() can wait for them.
    self._background_tasks: set[asyncio.Task[Any]] = set()
    self._finalize_tasks: set[asyncio.Task[Any]] = set()
|
|
835
|
+
|
|
836
|
+
def request_stop(self) -> None:
    """Ask the worker to wind down gracefully.

    Only sets the stop event; run_forever() then leaves its loop and awaits
    stop() for the actual teardown.
    """
    stop_event = self._stop
    stop_event.set()
|
|
839
|
+
|
|
840
|
+
# ----- lifecycle -----
|
|
841
|
+
|
|
842
|
+
async def start(self) -> None:
    """Bring the worker up.

    Steps, in order:

    1. Import the app and task modules in this process
       (``_preload_modules_main``) so import/validation errors fail fast here.
    2. Create the process pool. Its initializer (``_child_initializer``)
       receives the app locator, imports, sys.path roots, log level and a
       plain (driver-less) database URL.
    3. Start the listener, log the effective concurrency settings, and
       subscribe to ``task_queue_<name>`` for every queue plus ``task_new``.
    4. Launch the claimer heartbeat and worker-state heartbeat loops, plus the
       reaper loop when ``cfg.recovery_config`` is set. Strong references to
       these tasks are kept in ``self._background_tasks``, and a crash is
       logged as soon as the task ends.

    Exceptions from preloading, pool creation or listener startup propagate.
    """
    logger.debug('Starting worker')
    # Preload the app and task modules in the main process to fail fast
    self._preload_modules_main()

    # Create the process pool AFTER successful preload so initializer runs in children only
    # Compute the plain psycopg database URL for child processes
    child_database_url = self.cfg.dsn.replace('+asyncpg', '').replace('+psycopg', '')
    self._executor = ProcessPoolExecutor(
        max_workers=self.cfg.processes,
        initializer=_child_initializer,
        initargs=(
            self.cfg.app_locator,
            self.cfg.imports,
            self.cfg.sys_path_roots,
            self.cfg.loglevel,
            child_database_url,
        ),
    )
    await self.listener.start()

    # Surface concurrency configuration clearly for operators. The effective
    # per-worker claim limit follows the same rule as _claim_and_dispatch_all,
    # which adds prefetch_buffer in soft cap mode.
    if self.cfg.max_claim_per_worker > 0:
        max_claimed_effective = self.cfg.max_claim_per_worker
    elif self.cfg.prefetch_buffer > 0:
        max_claimed_effective = self.cfg.processes + self.cfg.prefetch_buffer
    else:
        max_claimed_effective = self.cfg.processes
    logger.info(
        'Concurrency config: processes=%s, cluster_wide_cap=%s, max_claim_per_worker=%s, max_claim_batch=%s',
        self.cfg.processes,
        (
            self.cfg.cluster_wide_cap
            if self.cfg.cluster_wide_cap is not None
            else 'unlimited'
        ),
        max_claimed_effective,
        self.cfg.max_claim_batch,
    )

    # Subscribe to each queue channel and to the global channel
    self._queues = [
        await self.listener.listen(f'task_queue_{q}') for q in self.cfg.queues
    ]
    logger.info(f'Subscribed to queues: {self.cfg.queues}')
    self._global = await self.listener.listen('task_new')
    logger.info('Subscribed to global queue')

    # The event loop only keeps weak references to tasks: hold strong ones here.
    background = getattr(self, '_background_tasks', None)
    if background is None:
        background = set()
        self._background_tasks = background

    def _on_background_done(task: 'asyncio.Task[Any]') -> None:
        # Drop our reference and surface a crash now rather than never.
        background.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f'Background task {task.get_name()} crashed: {exc!r}')

    def _spawn(coro: Any, name: str) -> None:
        task = asyncio.create_task(coro, name=name)
        background.add(task)
        task.add_done_callback(_on_background_done)

    # Start claimer heartbeat loop (CLAIMED coverage)
    _spawn(self._claimer_heartbeat_loop(), 'horsies-claimer-heartbeat')
    # Start worker state heartbeat loop for monitoring
    _spawn(self._worker_state_heartbeat_loop(), 'horsies-worker-state-heartbeat')
    logger.info('Worker state heartbeat loop started for monitoring')
    # Start reaper loop for automatic stale task handling
    if self.cfg.recovery_config:
        _spawn(self._reaper_loop(), 'horsies-reaper')
        logger.info('Reaper loop started for automatic stale task recovery')
|
|
896
|
+
|
|
897
|
+
async def stop(self) -> None:
    """Shut the worker down.

    Every step is best-effort: a failure is logged and shutdown continues.

    1. Set the stop event.
    2. Close the Postgres listener.
    3. Shut the process pool down: running tasks are waited for, submissions
       that never started are cancelled.
    4. Await the finalizer tasks still in flight (tracked in
       ``self._finalize_tasks``), so outcomes of tasks that finished during
       step 3 are written back before this method returns.
    """
    self._stop.set()
    # Close the Postgres listener early to avoid UNLISTEN races on dispatcher connection
    try:
        await self.listener.close()
        logger.info('Postgres listener closed')
    except Exception as e:
        logger.error(f'Error closing Postgres listener: {e}')
    # Shutdown executor
    if self._executor:
        loop = asyncio.get_running_loop()
        executor = self._executor
        self._executor = None
        try:
            # shutdown(wait=True) blocks until running tasks return, so run it in
            # the default thread pool to keep this event loop responsive (the
            # finalizers below need it). cancel_futures=True drops queued work;
            # those tasks stay CLAIMED in the database.
            await loop.run_in_executor(
                None, lambda: executor.shutdown(wait=True, cancel_futures=True)
            )
        except Exception as e:
            logger.error(f'Error shutting down executor: {e}')
        logger.info('Worker executor shutdown')

    # Let outstanding finalizers persist their results. After the pool is shut
    # down every underlying future is resolved or cancelled, so this is bounded.
    finalizers = [t for t in getattr(self, '_finalize_tasks', ()) if not t.done()]
    if finalizers:
        logger.info(f'Waiting for {len(finalizers)} task finalizer(s) to complete')
        await asyncio.gather(*finalizers, return_exceptions=True)
    logger.info('Worker stopped')
|
|
919
|
+
|
|
920
|
+
def _preload_modules_main(self) -> None:
    """Import the app and all task modules in the main process.

    This ensures Pydantic validations and module-level side effects run once
    and any configuration errors surface during startup rather than inside
    the child process initializer.

    Steps: resolve and prepend sys.path roots, locate the app and make it the
    current app, then import ``cfg.imports`` plus the app's discovered task
    modules (deduplicated). Sends are suppressed on the app while importing
    and re-enabled afterwards, even when an import fails.

    Raises:
        Any exception from path resolution, app lookup or a module import,
        after logging it.
    """
    try:
        sys_path_roots_resolved = _build_sys_path_roots(
            self.cfg.app_locator, self.cfg.imports, self.cfg.sys_path_roots
        )
        _debug_imports_log(
            f'[preload] app_locator={self.cfg.app_locator!r} sys_path_roots={sys_path_roots_resolved}'
        )
        for root in sys_path_roots_resolved:
            if root not in sys.path:
                sys.path.insert(0, root)

        # Load app object (variable or factory)
        app = _locate_app(self.cfg.app_locator)
        # Optionally set as current for consistency in main process
        set_current_app(app)
        self._app = app

        # Suppress accidental sends while importing modules for discovery
        try:
            app.suppress_sends(True)
        except Exception as e:
            logger.debug(f'suppress_sends(True) failed during preload (ignored): {e}')

        try:
            # Import declared modules that contain task definitions
            combined_imports = list(self.cfg.imports)
            try:
                combined_imports.extend(app.get_discovered_task_modules())
            except Exception as e:
                # Not fatal, but tasks from discovered modules will be missing.
                logger.warning(f'Could not list discovered task modules: {e}')
            combined_imports = _dedupe_paths(combined_imports)
            _debug_imports_log(f'[preload] import_modules={combined_imports}')
            for m in combined_imports:
                if m.endswith('.py') or os.path.sep in m:
                    import_by_path(os.path.abspath(m))
                else:
                    import_module(m)
        finally:
            # Re-enable sends after import completes (also when it failed)
            try:
                app.suppress_sends(False)
            except Exception as e:
                logger.debug(f'suppress_sends(False) failed during preload (ignored): {e}')
        _debug_imports_log(f'[preload] registered_tasks={app.list_tasks()}')
    except Exception as e:
        # Surface the error clearly and re-raise to stop startup
        logger.error(f'Failed during preload of task modules: {e}')
        raise
|
|
974
|
+
|
|
975
|
+
# ----- main loop -----
|
|
976
|
+
|
|
977
|
+
async def run_forever(self) -> None:
    """Start the worker and run the claim/dispatch loop until stop is requested.

    Each iteration runs a claim pass, waits for a NOTIFY on any subscribed
    channel (or for stop), and runs a second claim pass unless stop was
    requested during the wait. ``stop()`` is always awaited on the way out,
    including when startup fails or times out.

    Raises:
        RuntimeError: if ``start()`` does not finish within 30 seconds.
        Any other exception raised by ``start()``.
    """
    try:
        # Add timeout to startup to prevent hanging
        await asyncio.wait_for(self.start(), timeout=30.0)
        logger.info('Worker started')
    except asyncio.TimeoutError as exc:
        logger.error('Worker startup timed out after 30 seconds')
        # Release whatever start() managed to create (pool, listener).
        await self.stop()
        raise RuntimeError(
            'Worker startup timeout - likely database connection issue'
        ) from exc
    except Exception:
        await self.stop()
        raise
    try:
        while not self._stop.is_set():
            # Single budgeted claim pass, then wait for new NOTIFY
            await self._claim_and_dispatch_all()

            # Wait for a NOTIFY from any queue (coalesce bursts).
            await self._wait_for_any_notify()
            # The wait also ends when stop is requested; claiming then would
            # take work that shutdown immediately abandons.
            if self._stop.is_set():
                break
            await self._claim_and_dispatch_all()
    finally:
        await self.stop()
|
|
998
|
+
|
|
999
|
+
async def _wait_for_any_notify(self) -> None:
    """Block until a notification arrives on any channel or stop is requested.

    One ``get()`` waiter per channel queue (the per-queue ones plus the global
    one) is raced against the stop event. The losing waiters are always
    cancelled and reaped so none keep running. Unless stop was already set
    when the race ended, queued notifications are then discarded, at most
    ``cfg.coalesce_notifies`` in total, so a burst collapses into one wake-up.
    """
    import contextlib

    channels = self._queues + [self._global]
    waiters = [asyncio.create_task(chan.get()) for chan in channels]
    stop_waiter = asyncio.create_task(self._stop.wait())

    _, still_pending = await asyncio.wait(
        waiters + [stop_waiter], return_when=asyncio.FIRST_COMPLETED
    )
    stopping = self._stop.is_set()

    # Cancel every loser first, then await them so nothing is left dangling.
    for waiter in still_pending:
        waiter.cancel()
    for waiter in still_pending:
        with contextlib.suppress(asyncio.CancelledError):
            await waiter

    if stopping:
        return

    discarded = 0
    for chan in channels:
        while discarded < self.cfg.coalesce_notifies and not chan.empty():
            try:
                chan.get_nowait()
            except asyncio.QueueEmpty:
                break
            discarded += 1
|
|
1038
|
+
|
|
1039
|
+
# ----- claim & dispatch -----
|
|
1040
|
+
|
|
1041
|
+
async def _claim_and_dispatch_all(self) -> bool:
    """
    Run one claim pass and dispatch what was claimed to the process pool.

    Claims are bounded by:
    - max_claim_per_worker guard: if this worker already holds that many
      CLAIMED (not yet RUNNING) tasks, nothing is claimed. A setting of 0
      means processes (hard cap) or processes + prefetch_buffer (soft cap).
    - worker capacity: processes minus this worker's RUNNING + CLAIMED tasks
      (hard cap, prefetch_buffer == 0), or processes + prefetch_buffer minus
      this worker's RUNNING tasks (soft cap).
    - cluster_wide_cap (if set): minus RUNNING + CLAIMED tasks cluster-wide.
    - per-queue max_concurrency (only when queue_priorities is non-empty).
    - max_claim_batch per queue (only when queue_priorities is empty).

    Claiming happens in a single transaction that holds a transaction-scoped
    advisory lock, so workers using the same lock key (see
    _advisory_key_global) claim one at a time.

    Claimed tasks belonging to PAUSED workflows are handed back by
    _filter_paused_workflow_tasks and not dispatched.

    Returns True if at least one task was dispatched (claimed tasks that were
    handed back as PAUSED do not count).
    """
    # Guard: Check if we've already claimed too many tasks
    # Default depends on mode:
    # - Hard cap (prefetch_buffer=0): default to processes
    # - Soft cap (prefetch_buffer>0): default to processes + prefetch_buffer
    if self.cfg.max_claim_per_worker > 0:
        # User explicitly set a limit - use it
        max_claimed = self.cfg.max_claim_per_worker
    elif self.cfg.prefetch_buffer > 0:
        # Soft cap mode: allow claiming up to processes + prefetch_buffer
        max_claimed = self.cfg.processes + self.cfg.prefetch_buffer
    else:
        # Hard cap mode: limit to processes
        max_claimed = self.cfg.processes
    # Counts only this worker's CLAIMED tasks (RUNNING excluded), in its own session.
    claimed_count = await self._count_claimed_for_worker()
    if claimed_count >= max_claimed:
        return False

    # Cluster-wide, lock-guarded claim to avoid races. One short transaction.
    # Use local capacity + small prefetch to size claims fairly across workers.
    claimed_ids: list[str] = []

    # Queue order: if custom priorities provided, sort by priority (ascending:
    # lower values are claimed first); otherwise keep given order.
    # NOTE(review): queues absent from queue_priorities are filtered out, so the
    # `.get(q, 100)` default never applies and such queues are never claimed
    # from while queue_priorities is non-empty. Confirm this is intended.
    if self.cfg.queue_priorities:
        ordered_queues = sorted(
            [q for q in self.cfg.queues if q in self.cfg.queue_priorities],
            key=lambda q: self.cfg.queue_priorities.get(q, 100),
        )
    else:
        ordered_queues = list(self.cfg.queues)

    # Open one transaction, take a global advisory xact lock
    async with self.sf() as s:
        # Take a cluster-wide transaction-scoped advisory lock to serialize claiming.
        # It is held until this transaction ends (the commits below).
        await s.execute(
            text('SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))'),
            {'key': self._advisory_key_global()},
        )

        # Compute local budget and optional global remaining
        # Hard cap mode (prefetch_buffer=0): count RUNNING + CLAIMED for strict enforcement
        # Soft cap mode (prefetch_buffer>0): count only RUNNING, allow prefetch with lease
        hard_cap_mode = self.cfg.prefetch_buffer == 0

        # NOTE: these per-worker count helpers open their own sessions, so the
        # counts are read outside this locked transaction.
        if hard_cap_mode:
            # Hard cap: count both RUNNING and CLAIMED for this worker
            local_in_flight = await self._count_in_flight_for_worker()
            max_local_capacity = self.cfg.processes
        else:
            # Soft cap: count only RUNNING to allow prefetch beyond processes
            local_in_flight = await self._count_only_running_for_worker()
            max_local_capacity = self.cfg.processes + self.cfg.prefetch_buffer
        local_available = max(0, int(max_local_capacity) - int(local_in_flight))
        budget_remaining = local_available

        global_remaining: Optional[int] = None
        if self.cfg.cluster_wide_cap is not None:
            # Hard cap mode: count RUNNING + CLAIMED globally
            # (Note: prefetch_buffer must be 0 when cluster_wide_cap is set, enforced by config validation)
            res = await s.execute(
                text(
                    "SELECT COUNT(*) FROM horsies_tasks WHERE status IN ('RUNNING', 'CLAIMED')"
                )
            )
            row = res.fetchone()
            if row:
                in_flight_global = int(row[0])
            else:
                in_flight_global = 0
            global_remaining = max(
                0, int(self.cfg.cluster_wide_cap) - in_flight_global
            )

        # Total claim budget for this pass: local budget capped by global remaining (if any)
        total_remaining = (
            budget_remaining
            if global_remaining is None
            else min(budget_remaining, global_remaining)
        )
        if total_remaining <= 0:
            # Nothing to claim globally or locally; commit to release the lock.
            await s.commit()
            return False

        for qname in ordered_queues:
            if total_remaining <= 0:
                break

            # Compute queue remaining in cluster (only if custom-configured)
            q_remaining: Optional[int] = None
            if (
                self.cfg.queue_priorities
                and qname in self.cfg.queue_max_concurrency
            ):
                # Hard cap mode: count RUNNING + CLAIMED for this queue
                # Soft cap mode: count only RUNNING
                if hard_cap_mode:
                    resq = await s.execute(
                        text(
                            "SELECT COUNT(*) FROM horsies_tasks WHERE status IN ('RUNNING', 'CLAIMED') AND queue_name = :q"
                        ),
                        {'q': qname},
                    )
                else:
                    resq = await s.execute(
                        text(
                            "SELECT COUNT(*) FROM horsies_tasks WHERE status = 'RUNNING' AND queue_name = :q"
                        ),
                        {'q': qname},
                    )
                row = resq.fetchone()
                if row:
                    in_flight_q = int(row[0])
                else:
                    in_flight_q = 0
                max_q = int(self.cfg.queue_max_concurrency.get(qname, 0))
                q_remaining = max(0, max_q - in_flight_q)

            # Determine how many we may claim from this queue
            # Hierarchy: max_claim_batch (fairness) -> q_remaining (queue cap) -> total_remaining (worker budget)

            if self.cfg.queue_priorities:
                # Strict priority mode: try to fill remaining budget from this queue
                # Ignore max_claim_batch (which forces round-robin fairness)
                per_queue_cap = total_remaining
            else:
                # Default mode: use max_claim_batch to ensure fairness across queues
                per_queue_cap = self.cfg.max_claim_batch

            if q_remaining is not None:
                per_queue_cap = min(per_queue_cap, q_remaining)
            to_claim = min(total_remaining, per_queue_cap)
            if to_claim <= 0:
                continue

            ids = await self._claim_batch_locked(s, qname, to_claim)
            if not ids:
                continue
            claimed_ids.extend(ids)
            total_remaining -= len(ids)

        # Commit makes the CLAIMED rows visible and releases the advisory lock.
        await s.commit()

    if not claimed_ids:
        return False

    rows = await self._load_rows(claimed_ids)

    # PAUSE guard: filter out tasks belonging to PAUSED workflows and unclaim them
    rows = await self._filter_paused_workflow_tasks(rows)

    for row in rows:
        await self._dispatch_one(
            row['id'], row['task_name'], row['args'], row['kwargs']
        )
    return len(rows) > 0
|
|
1205
|
+
|
|
1206
|
+
async def _filter_paused_workflow_tasks(
    self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    """
    Filter out tasks belonging to PAUSED workflows and unclaim them.

    Post-claim guard. For every row whose task is linked (through
    horsies_workflow_tasks) to a workflow with status PAUSED:
    - the task row goes back to PENDING with all claim bookkeeping cleared
      (claimed, claimed_at, claimed_by_worker_id, claim_expires_at);
    - its workflow_task row is reset to READY and unlinked (task_id = NULL,
      started_at = NULL).
    Both updates are committed together.

    NOTE(review): after the unlink, the PENDING task row no longer joins to its
    workflow, so a later claim pass cannot recognise it as paused. Confirm the
    resume path accounts for (or cancels) such rows; otherwise they may run as
    standalone tasks.

    Args:
        rows: Claimed task rows; each must have an 'id' key.

    Returns:
        The rows whose tasks do not belong to a PAUSED workflow, in input order.
    """
    if not rows:
        return rows

    task_ids = [row['id'] for row in rows]

    # Find which tasks belong to PAUSED workflows
    async with self.sf() as s:
        res = await s.execute(
            text("""
SELECT t.id
FROM horsies_tasks t
JOIN horsies_workflow_tasks wt ON wt.task_id = t.id
JOIN horsies_workflows w ON w.id = wt.workflow_id
WHERE t.id = ANY(:ids)
AND w.status = 'PAUSED'
"""),
            {'ids': task_ids},
        )
        paused_task_ids = {row[0] for row in res.fetchall()}

        if paused_task_ids:
            # Unclaim these tasks: set back to PENDING so they can be picked up on resume.
            # claim_expires_at is cleared as well so no stale soft-cap lease is
            # left behind on a PENDING row.
            await s.execute(
                text("""
UPDATE horsies_tasks
SET status = 'PENDING',
claimed = FALSE,
claimed_at = NULL,
claimed_by_worker_id = NULL,
claim_expires_at = NULL,
updated_at = NOW()
WHERE id = ANY(:ids)
"""),
                {'ids': list(paused_task_ids)},
            )
            # Also reset workflow_tasks back to READY for consistency
            # (they were ENQUEUED, but the task is now unclaimed)
            await s.execute(
                text("""
UPDATE horsies_workflow_tasks
SET status = 'READY', task_id = NULL, started_at = NULL
WHERE task_id = ANY(:ids)
"""),
                {'ids': list(paused_task_ids)},
            )
            await s.commit()

    # Return only tasks not belonging to PAUSED workflows
    return [row for row in rows if row['id'] not in paused_task_ids]
|
|
1265
|
+
|
|
1266
|
+
def _advisory_key_global(self) -> int:
    """Derive the signed 64-bit key passed to pg_advisory_xact_lock when claiming.

    Key = first 8 bytes (big-endian, signed) of
    SHA-256(b'horsies-global:' + dsn), where dsn is ``cfg.psycopg_dsn``, else
    ``cfg.dsn``, else the literal 'horsies'. Only workers that hash the same
    DSN string compute the same key and so serialize against each other.
    """
    dsn_text = self.cfg.psycopg_dsn or self.cfg.dsn or 'horsies'
    digest = hashlib.sha256(
        b'horsies-global:' + dsn_text.encode('utf-8', errors='ignore')
    ).digest()
    return int.from_bytes(digest[:8], byteorder='big', signed=True)
|
|
1273
|
+
|
|
1274
|
+
# Stale detection is handled via heartbeat policy for RUNNING tasks.
|
|
1275
|
+
|
|
1276
|
+
def _compute_claim_expires_at(self) -> Optional[datetime]:
    """Return the lease deadline for a new claim, or None when leases are off.

    With ``cfg.claim_lease_ms`` set, the deadline is the current UTC time plus
    that many milliseconds. Otherwise None is returned, which CLAIM_SQL writes
    as a NULL claim_expires_at (no lease, hard cap mode).
    """
    lease_ms = self.cfg.claim_lease_ms
    if lease_ms is not None:
        return datetime.now(timezone.utc) + timedelta(milliseconds=lease_ms)
    return None
|
|
1283
|
+
|
|
1284
|
+
async def _claim_batch_locked(
    self, s: AsyncSession, queue: str, limit: int
) -> list[str]:
    """Execute CLAIM_SQL for ``queue`` inside the caller's open transaction.

    Does not commit: the caller owns the transaction (in _claim_and_dispatch_all
    it also holds the advisory lock). Returns the claimed ids, at most ``limit``.
    """
    params = {
        'queue': queue,
        'lim': limit,
        'worker_id': self.worker_instance_id,
        'claim_expires_at': self._compute_claim_expires_at(),
    }
    result = await s.execute(CLAIM_SQL, params)
    return [claimed[0] for claimed in result.fetchall()]
|
|
1298
|
+
|
|
1299
|
+
async def _count_claimed_for_worker(self) -> int:
    """Return how many tasks this worker holds in CLAIMED state.

    RUNNING tasks are not included. Uses its own short-lived session.
    """
    query = text(
        """
SELECT COUNT(*)
FROM horsies_tasks
WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
AND status = 'CLAIMED'
"""
    )
    async with self.sf() as s:
        res = await s.execute(query, {'wid': self.worker_instance_id})
        first = res.fetchone()
    if not first:
        return 0
    return int(first[0])
|
|
1315
|
+
|
|
1316
|
+
async def _count_only_running_for_worker(self) -> int:
    """Return how many tasks claimed by this worker are currently RUNNING.

    CLAIMED-but-not-started tasks are excluded (soft cap accounting).
    """
    query = text(
        """
SELECT COUNT(*)
FROM horsies_tasks
WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
AND status = 'RUNNING'
"""
    )
    async with self.sf() as s:
        res = await s.execute(query, {'wid': self.worker_instance_id})
        first = res.fetchone()
    if not first:
        return 0
    return int(first[0])
|
|
1332
|
+
|
|
1333
|
+
async def _count_in_flight_for_worker(self) -> int:
    """Return this worker's RUNNING plus CLAIMED task count (hard cap accounting)."""
    query = text(
        """
SELECT COUNT(*)
FROM horsies_tasks
WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
AND status IN ('RUNNING', 'CLAIMED')
"""
    )
    async with self.sf() as s:
        res = await s.execute(query, {'wid': self.worker_instance_id})
        first = res.fetchone()
    if not first:
        return 0
    return int(first[0])
|
|
1349
|
+
|
|
1350
|
+
async def _count_running_in_queue(self, queue_name: str) -> int:
    """Return the number of RUNNING tasks in ``queue_name`` across all workers."""
    query = text(
        """
SELECT COUNT(*)
FROM horsies_tasks
WHERE status = 'RUNNING'
AND queue_name = :q
"""
    )
    async with self.sf() as s:
        res = await s.execute(query, {'q': queue_name})
        first = res.fetchone()
    if not first:
        return 0
    return int(first[0])
|
|
1366
|
+
|
|
1367
|
+
async def _claim_batch(self, queue: str, limit: int) -> list[str]:
    """Claim up to ``limit`` tasks from ``queue`` in a self-contained transaction.

    Unlike _claim_batch_locked, this opens its own session, takes no advisory
    lock, and commits before returning. The commit makes the CLAIMED state
    visible and releases the row locks.
    """
    async with self.sf() as s:
        params = {
            'queue': queue,
            'lim': limit,
            'worker_id': self.worker_instance_id,
            'claim_expires_at': self._compute_claim_expires_at(),
        }
        result = await s.execute(CLAIM_SQL, params)
        claimed = [record[0] for record in result.fetchall()]
        await s.commit()
    return claimed
|
|
1382
|
+
|
|
1383
|
+
async def _load_rows(self, ids: Sequence[str]) -> list[dict[str, Any]]:
    """Load the stored columns of ``ids`` as dicts keyed by column name.

    LOAD_ROWS_SQL has no ORDER BY, so the rows are re-sorted here to follow the
    order of ``ids``. The caller's ordering (e.g. queue priority order from
    _claim_and_dispatch_all) is then preserved through dispatch. Ids without a
    matching row are simply absent from the result.

    Returns:
        One dict per found row; an empty list when ``ids`` is empty.
    """
    if not ids:
        return []
    id_list = list(ids)
    async with self.sf() as s:
        res = await s.execute(LOAD_ROWS_SQL, {'ids': id_list})
        cols = list(res.keys())
        rows = [dict(zip(cols, row)) for row in res.fetchall()]
    position = {task_id: idx for idx, task_id in enumerate(id_list)}
    # Rows whose id is not in the mapping (not expected) sort last, stably.
    rows.sort(key=lambda r: position.get(r['id'], len(position)))
    return rows
|
|
1390
|
+
|
|
1391
|
+
async def _dispatch_one(
    self,
    task_id: str,
    task_name: str,
    args_json: Optional[str],
    kwargs_json: Optional[str],
) -> None:
    """Submit one claimed task to the process pool and schedule its finalizer.

    The child runs ``_run_task_entry`` with the task's name and serialized
    args/kwargs, plus the task id, a plain (driver-less) database URL, this
    worker's id and the heartbeat interval, so it can heartbeat for itself.
    A ``_finalize_after`` task writes the outcome back when the child returns.
    A strong reference to that task is kept in ``self._finalize_tasks`` until
    it finishes.

    Raises:
        RuntimeError: if the process pool is not running (not started yet or
            already shut down).
    """
    executor = self._executor
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, and run_in_executor(None, ...) would then silently run the
    # task in the loop's default *thread* pool inside this process.
    if executor is None:
        raise RuntimeError(
            f'Cannot dispatch task {task_id}: worker process pool is not running'
        )
    loop = asyncio.get_running_loop()

    # Get heartbeat interval from recovery config (milliseconds)
    runner_heartbeat_interval_ms = 30_000  # default: 30 seconds
    if self.cfg.recovery_config:
        runner_heartbeat_interval_ms = (
            self.cfg.recovery_config.runner_heartbeat_interval_ms
        )

    # Pass task_id and database_url to task process for self-heartbeat
    database_url = self.cfg.dsn.replace('+asyncpg', '').replace('+psycopg', '')
    fut = loop.run_in_executor(
        executor,
        _run_task_entry,
        task_name,
        args_json,
        kwargs_json,
        task_id,
        database_url,
        self.worker_instance_id,
        runner_heartbeat_interval_ms,
    )

    # When done, record the outcome. The event loop only keeps weak references
    # to tasks, so hold one until the finalizer has finished.
    finalizers = getattr(self, '_finalize_tasks', None)
    if finalizers is None:
        finalizers = set()
        self._finalize_tasks = finalizers
    finalizer = asyncio.create_task(self._finalize_after(fut, task_id))
    finalizers.add(finalizer)
    finalizer.add_done_callback(finalizers.discard)
|
|
1425
|
+
|
|
1426
|
+
# ----- finalize (write back to DB + notify) -----
|
|
1427
|
+
|
|
1428
|
+
async def _finalize_after(
    self, fut: 'asyncio.Future[tuple[bool, str, Optional[str]]]', task_id: str
) -> None:
    """Wait for a dispatched task's outcome and persist it.

    The child returns ``(ok, result_json, failed_reason)``:

    - ``ok`` False and reason CLAIM_LOST / OWNERSHIP_UNCONFIRMED /
      WORKFLOW_CHECK_FAILED / WORKFLOW_STOPPED: the task was aborted and
      nothing is written.
    - ``ok`` False and any other reason: the row is marked FAILED with that
      reason (or 'Worker failure').
    - ``ok`` True: ``result_json`` is a serialized TaskResult.
      - An error with error_code WORKFLOW_STOPPED is skipped.
      - An error that ``_should_retry_task`` accepts is rescheduled via
        ``_schedule_retry``.
      - Any other error marks the row FAILED.
      - Success marks it COMPLETED.
      In the last two cases workflow bookkeeping runs, and capacity NOTIFYs
      are sent on ``task_new`` and ``task_queue_<queue>``.

    This runs as a fire-and-forget task, so failures are logged here instead
    of being raised (a raised exception would go unobserved).
    """
    try:
        ok, result_json_str, failed_reason = await fut
    except Exception:
        # e.g. BrokenProcessPool when a child process dies. There is no outcome
        # to persist; the row keeps its current status (presumably for the
        # reaper / recovery path - not verified here).
        logger.exception(
            f'Task {task_id}: executor raised before returning an outcome, skipping finalization'
        )
        return
    now = datetime.now(timezone.utc)

    # Note: Heartbeat thread in task process automatically dies when process completes

    try:
        async with self.sf() as s:
            if not ok:
                # Aborts that are not task failures (e.g. CLAIM_LOST: another
                # worker reclaimed the task and now owns it) - nothing to write.
                if failed_reason in (
                    'CLAIM_LOST',
                    'OWNERSHIP_UNCONFIRMED',
                    'WORKFLOW_CHECK_FAILED',
                    'WORKFLOW_STOPPED',
                ):
                    logger.debug(
                        f'Task {task_id} aborted with reason={failed_reason}, skipping finalization'
                    )
                    return

                # worker-level failure (rare): mark FAILED with reason
                await s.execute(
                    text("""
UPDATE horsies_tasks
SET status='FAILED',
failed_at = :now,
failed_reason = :reason,
updated_at = :now
WHERE id = :id
"""),
                    {
                        'now': now,
                        'reason': failed_reason or 'Worker failure',
                        'id': task_id,
                    },
                )
                # Trigger automatically sends NOTIFY on UPDATE
                await s.commit()
                return

            # Parse the TaskResult we produced
            tr = task_result_from_json(loads_json(result_json_str))
            if tr.is_err():
                task_error = tr.unwrap_err()
                error_code = task_error.error_code if task_error else None
                if error_code == 'WORKFLOW_STOPPED':
                    logger.debug(
                        f'Task {task_id} skipped due to workflow stop, skipping finalization'
                    )
                    return
                # Check if this task should be retried
                should_retry = await self._should_retry_task(task_id, task_error)
                if should_retry:
                    await self._schedule_retry(task_id, s)
                    await s.commit()
                    return

                # Mark as failed if no retry
                await s.execute(
                    text("""
UPDATE horsies_tasks
SET status='FAILED',
failed_at = :now,
result = :result_json, -- or result_json JSONB if you add that column
updated_at = :now
WHERE id = :id
"""),
                    {'now': now, 'result_json': result_json_str, 'id': task_id},
                )
            else:
                await s.execute(
                    text("""
UPDATE horsies_tasks
SET status='COMPLETED',
completed_at = :now,
result = :result_json, -- or result_json JSONB
updated_at = :now
WHERE id = :id
"""),
                    {'now': now, 'result_json': result_json_str, 'id': task_id},
                )

            # Handle workflow task completion (if this task is part of a workflow)
            await self._handle_workflow_task_if_needed(s, task_id, tr)

            # Proactively wake workers to re-check capacity/backlog.
            # Run inside a SAVEPOINT: a failing statement in PostgreSQL aborts the
            # enclosing transaction, and without the savepoint the status/result
            # UPDATE above would then not be committed. With it, only the
            # notification statements are rolled back.
            try:
                async with s.begin_nested():
                    # Fetch queue name for this task
                    resq = await s.execute(
                        text('SELECT queue_name FROM horsies_tasks WHERE id = :id'),
                        {'id': task_id},
                    )
                    rowq = resq.fetchone()
                    qname = str(rowq[0]) if rowq and rowq[0] else 'default'
                    payload = f'capacity:{task_id}'
                    # Notify workers globally and on the specific queue to wake claims
                    await s.execute(
                        text('SELECT pg_notify(:c1, :p)'),
                        {'c1': 'task_new', 'p': payload},
                    )
                    await s.execute(
                        text('SELECT pg_notify(:c2, :p)'),
                        {'c2': f'task_queue_{qname}', 'p': payload},
                    )
            except Exception as e:
                # Non-fatal if NOTIFY fails; continue
                logger.debug(f'Capacity NOTIFY for task {task_id} failed (non-fatal): {e}')

            # Trigger automatically sends NOTIFY on UPDATE; commit to flush NOTIFYs
            await s.commit()
    except Exception:
        logger.exception(f'Task {task_id}: failed to persist task outcome')
|
|
1544
|
+
|
|
1545
|
+
async def _handle_workflow_task_if_needed(
    self,
    session: 'AsyncSession',
    task_id: str,
    result: 'TaskResult[Any, TaskError]',
) -> None:
    """
    Hand a finished task's result to the workflow engine when the task belongs to a workflow.

    Invoked after a task reaches a terminal state (success or failure). If
    ``task_id`` is referenced by a row in ``horsies_workflow_tasks``, the
    result is passed to ``on_workflow_task_complete`` on the same session
    (together with the app's broker, or ``None`` when no app is attached).
    Otherwise nothing happens. This method does not commit; that is left to
    the caller.

    Args:
        session: Open session shared with the caller's transaction.
        task_id: Identifier of the task that just finished.
        result: The task's final result.
    """
    from horsies.core.workflows.engine import on_workflow_task_complete

    # Cheap existence probe before involving the workflow engine.
    membership_sql = text(
        'SELECT 1 FROM horsies_workflow_tasks WHERE task_id = :tid LIMIT 1'
    )
    membership = await session.execute(membership_sql, {'tid': task_id})
    linked_row = membership.fetchone()

    if linked_row is not None:
        app = self._app
        wf_broker = None if app is None else app.get_broker()
        await on_workflow_task_complete(session, task_id, result, wf_broker)
|
|
1571
|
+
|
|
1572
|
+
async def _should_retry_task(self, task_id: str, error: TaskError) -> bool:
    """Decide whether a failed task is eligible for an automatic retry.

    A retry is allowed only when every condition below holds:

    * the task row exists and its ``max_retries`` is positive;
    * its ``retry_count`` is still below ``max_retries``;
    * ``task_options`` decodes to a dict whose ``auto_retry_for`` entry is a
      non-empty list; and
    * either the error code (``LibraryErrorCode`` members are compared by
      ``.value``) or the exception's type name appears in that list. The
      exception may be a live ``Exception`` or a serialized dict carrying
      ``type`` or ``exception_type``.

    Returns:
        True to retry, False otherwise (including when ``task_options``
        cannot be decoded).
    """
    async with self.sf() as s:
        fetched = await s.execute(
            text(
                'SELECT retry_count, max_retries, task_options FROM horsies_tasks WHERE id = :id'
            ),
            {'id': task_id},
        )
        row = fetched.fetchone()
        if not row:
            return False

        attempts_so_far = row.retry_count or 0
        retry_limit = row.max_retries or 0
        if retry_limit == 0 or attempts_so_far >= retry_limit:
            return False

        # Decode the options blob; any decoding problem disables auto-retry.
        try:
            options = loads_json(row.task_options) if row.task_options else {}
            if not isinstance(options, dict):
                return False
            retry_on = options.get('auto_retry_for', [])
        except Exception:
            return False

        if not (isinstance(retry_on, list) and retry_on):
            return False

        # Match on the error code first (enum members compare by value).
        error_code = error.error_code
        if isinstance(error_code, LibraryErrorCode):
            error_code = error_code.value
        if error_code and error_code in retry_on:
            return True

        # Then on the exception type name, live or serialized.
        exc = error.exception
        if exc:
            exc_name: str | None = None
            if isinstance(exc, Exception):
                exc_name = type(exc).__name__
            elif isinstance(exc, dict):
                # "type" comes from _exception_to_json, "exception_type"
                # from the task_decorator data field.
                exc_name = exc.get('type') or exc.get('exception_type')
            if exc_name and exc_name in retry_on:
                return True

        return False
|
|
1630
|
+
|
|
1631
|
+
async def _schedule_retry(self, task_id: str, session: AsyncSession) -> None:
    """Schedule a task for retry by updating its status and next retry time.

    Steps:

    1. Increment ``retry_count`` and compute the back-off delay from the
       ``retry_policy`` entry of ``task_options``. Defaults are used when that
       entry is missing or malformed.
    2. Set the task back to ``PENDING`` with ``next_retry_at`` and ``sent_at``
       equal to now + delay.
    3. Start a background coroutine that sends ``pg_notify`` on the task's
       ``task_queue_<queue_name>`` channel once the delay elapses.

    The UPDATE runs on ``session``; this method does not commit.

    Args:
        task_id: Task to reschedule.
        session: Open session on which the UPDATE is executed.
    """
    # Read everything needed in one round-trip on the caller's session.
    # Previously queue_name was fetched through _get_task_queue_name, which
    # checks out a second pooled connection while this transaction is open.
    result = await session.execute(
        text(
            'SELECT retry_count, task_options, queue_name FROM horsies_tasks WHERE id = :id'
        ),
        {'id': task_id},
    )
    row = result.fetchone()

    if not row:
        return

    retry_count = (row.retry_count or 0) + 1
    queue_name = str(row.queue_name) if row.queue_name else 'default'

    # Parse retry policy from task options; fall back to defaults on any problem.
    retry_policy_data: dict[str, Any] = {}
    try:
        task_options_data = loads_json(row.task_options) if row.task_options else {}
        if isinstance(task_options_data, dict):
            candidate = task_options_data.get('retry_policy', {})
            if isinstance(candidate, dict):
                retry_policy_data = candidate
    except Exception:
        retry_policy_data = {}

    delay_seconds = self._calculate_retry_delay(retry_count, retry_policy_data)
    next_retry_at = datetime.now(timezone.utc) + timedelta(seconds=delay_seconds)

    await session.execute(
        text("""
            UPDATE horsies_tasks
            SET status = 'PENDING',
                retry_count = :retry_count,
                next_retry_at = :next_retry_at,
                sent_at = :next_retry_at,
                updated_at = now()
            WHERE id = :id
        """),
        {'id': task_id, 'retry_count': retry_count, 'next_retry_at': next_retry_at},
    )

    # Wake listeners on the task's own queue once the retry becomes due.
    notify_task = asyncio.create_task(
        self._schedule_delayed_notification(
            delay_seconds, f'task_queue_{queue_name}', f'retry:{task_id}'
        )
    )
    # The event loop holds only a weak reference to tasks. Without a strong
    # reference the delayed notification can be garbage-collected before it
    # runs, so keep it until it finishes.
    pending = getattr(self, '_retry_notify_tasks', None)
    if pending is None:
        pending = set()
        self._retry_notify_tasks = pending
    pending.add(notify_task)
    notify_task.add_done_callback(pending.discard)

    logger.info(
        f'Scheduled task {task_id} for retry #{retry_count} at {next_retry_at}'
    )
|
|
1686
|
+
|
|
1687
|
+
def _calculate_retry_delay(
    self, retry_attempt: int, retry_policy_data: dict[str, Any]
) -> int:
    """Calculate the delay in seconds for a retry attempt.

    ``retry_policy_data`` is the decoded ``retry_policy`` task option.
    Recognized keys:

    * ``intervals`` (default ``[60, 300, 900]``): delays in seconds. An
      empty or null value falls back to ``[60]``.
    * ``backoff_strategy`` (default ``'fixed'``):

      - ``'fixed'``: use ``intervals[retry_attempt - 1]``. Attempts beyond
        the end of the list reuse the last interval.
      - ``'exponential'``: ``intervals[0] * 2 ** (retry_attempt - 1)``.
      - anything else: ``intervals[0]``.

    * ``jitter`` (default ``True``): randomize the delay by up to +/-25%.

    Args:
        retry_attempt: 1-based retry number (1 for the first retry).
        retry_policy_data: Retry policy options; may be empty.

    Returns:
        Delay in whole seconds, never less than 1.
    """
    # Default retry policy values (these match RetryPolicy defaults)
    intervals = retry_policy_data.get('intervals', [60, 300, 900])  # 1min, 5min, 15min
    if not intervals:
        # Previously an empty list raised IndexError for fixed/exponential.
        intervals = [60]
    backoff_strategy = retry_policy_data.get('backoff_strategy', 'fixed')
    jitter = retry_policy_data.get('jitter', True)

    if backoff_strategy == 'fixed':
        # Clamp instead of trusting validation. When the policy is absent or
        # malformed, the 3-element default list is used while max_retries
        # (stored separately) may be larger, and indexing past the end used
        # to raise IndexError.
        slot = min(max(retry_attempt, 1), len(intervals)) - 1
        base_delay = intervals[slot]
    elif backoff_strategy == 'exponential':
        # intervals[0] is the base; double it for each subsequent attempt.
        base_delay = intervals[0] * (2 ** max(retry_attempt - 1, 0))
    else:
        # Fallback (shouldn't happen due to Literal type validation)
        base_delay = intervals[0]

    # Apply jitter (+/-25% randomization)
    if jitter:
        jitter_range = base_delay * 0.25
        base_delay += random.uniform(-jitter_range, jitter_range)

    return max(1, int(base_delay))  # Ensure at least 1 second delay
|
|
1718
|
+
|
|
1719
|
+
async def _schedule_delayed_notification(
    self, delay_seconds: int, channel: str, payload: str
) -> None:
    """Wait ``delay_seconds``, then send ``pg_notify(channel, payload)`` and commit.

    Used to wake up workers listening on ``channel`` once a retried task is
    due. Cancellation (``asyncio.CancelledError``) and any ``Exception`` are
    logged and not propagated.
    """
    try:
        notify_sql = text('SELECT pg_notify(:channel, :payload)')
        notify_params = {'channel': channel, 'payload': payload}

        await asyncio.sleep(delay_seconds)

        # A short-lived session just for the NOTIFY; commit delivers it.
        async with self.sf() as session:
            await session.execute(notify_sql, notify_params)
            await session.commit()

        logger.debug(f'Sent delayed notification for retry: {payload}')
    except asyncio.CancelledError:
        logger.debug(f'Delayed notification cancelled for: {payload}')
    except Exception as exc:
        logger.error(f'Error sending delayed notification for {payload}: {exc}')
|
|
1739
|
+
|
|
1740
|
+
async def _get_task_queue_name(self, task_id: str) -> str:
    """Return the ``queue_name`` stored for ``task_id``.

    Uses its own short-lived session. Falls back to ``'default'`` when the
    row does not exist or its queue_name is NULL/empty.
    """
    lookup_sql = text('SELECT queue_name FROM horsies_tasks WHERE id = :id')
    async with self.sf() as session:
        found = (await session.execute(lookup_sql, {'id': task_id})).fetchone()
    if not found or not found[0]:
        return 'default'
    return str(found[0])
|
|
1749
|
+
|
|
1750
|
+
async def _claimer_heartbeat_loop(self) -> None:
    """Emit claimer heartbeats for tasks we've claimed but not yet started.

    Every ``claimer_heartbeat_interval_ms`` (taken from ``recovery_config``;
    30 000 ms when none is configured), this inserts one ``'claimer'`` row
    into ``horsies_heartbeats`` for each task with status ``CLAIMED`` and
    ``claimed_by_worker_id`` equal to this worker instance.

    Database errors are logged and the loop keeps going. The wait between
    beats is cut short as soon as ``self._stop`` is set, so shutdown is not
    delayed by up to a full interval. Cancellation ends the loop quietly.
    """
    # Get heartbeat interval from recovery config (milliseconds)
    claimer_heartbeat_interval_ms = 30_000  # default: 30 seconds
    if self.cfg.recovery_config:
        claimer_heartbeat_interval_ms = (
            self.cfg.recovery_config.claimer_heartbeat_interval_ms
        )
    interval_seconds = claimer_heartbeat_interval_ms / 1000.0

    # Host identity is fixed for the life of this loop; compute it once.
    hostname = socket.gethostname()
    pid = os.getpid()

    try:
        while not self._stop.is_set():
            try:
                async with self.sf() as s:
                    await s.execute(
                        text(
                            """
                            INSERT INTO horsies_heartbeats (task_id, sender_id, role, sent_at, hostname, pid)
                            SELECT id, CAST(:wid AS VARCHAR), 'claimer', NOW(), :host, :pid
                            FROM horsies_tasks
                            WHERE status = 'CLAIMED' AND claimed_by_worker_id = CAST(:wid AS VARCHAR)
                            """
                        ),
                        {
                            'wid': self.worker_instance_id,
                            'host': hostname,
                            'pid': pid,
                        },
                    )
                    await s.commit()
            except Exception as e:
                logger.error(f'Claimer heartbeat error: {e}')

            # Wait for the next beat, but wake immediately on a stop signal
            # (same pattern as _reaper_loop).
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=interval_seconds)
            except asyncio.TimeoutError:
                continue
            break  # Stop signal received
    except asyncio.CancelledError:
        return
|
|
1785
|
+
|
|
1786
|
+
async def _update_worker_state(self) -> None:
    """Update worker state snapshot in database for monitoring.

    Inserts one row into ``horsies_worker_states`` containing:

    * identity: worker instance id, hostname, pid, ``_started_at``;
    * claim/concurrency settings from ``self.cfg``;
    * the serialized recovery config, if any;
    * task counts from ``_count_only_running_for_worker`` and
      ``_count_claimed_for_worker``;
    * process resource usage sampled with psutil (RSS in MB, memory percent,
      CPU percent over a 0.1 s window).

    Any exception is logged and swallowed.
    """
    import psutil

    try:
        process = psutil.Process()
        memory_info = process.memory_info()

        # Get current task counts
        running = await self._count_only_running_for_worker()
        claimed = await self._count_claimed_for_worker()

        # Serialize recovery config
        recovery_dict = None
        if self.cfg.recovery_config:
            recovery_dict = {
                'auto_requeue_stale_claimed': self.cfg.recovery_config.auto_requeue_stale_claimed,
                'claimed_stale_threshold_ms': self.cfg.recovery_config.claimed_stale_threshold_ms,
                'auto_fail_stale_running': self.cfg.recovery_config.auto_fail_stale_running,
                'running_stale_threshold_ms': self.cfg.recovery_config.running_stale_threshold_ms,
                'check_interval_ms': self.cfg.recovery_config.check_interval_ms,
                'runner_heartbeat_interval_ms': self.cfg.recovery_config.runner_heartbeat_interval_ms,
                'claimer_heartbeat_interval_ms': self.cfg.recovery_config.claimer_heartbeat_interval_ms,
            }

        mem_pct = process.memory_percent()
        # cpu_percent(interval=0.1) blocks for the whole sampling window.
        # Called directly from this coroutine it stalled the event loop (and
        # every other coroutine on it, such as the heartbeat loops) for 100 ms
        # on each snapshot. Sample in a worker thread instead; the window and
        # the meaning of the value are unchanged.
        cpu_pct = await asyncio.to_thread(process.cpu_percent, 0.1)

        async with self.sf() as s:
            await s.execute(
                text("""
                    INSERT INTO horsies_worker_states (
                        worker_id, snapshot_at, hostname, pid,
                        processes, max_claim_batch, max_claim_per_worker,
                        cluster_wide_cap, queues, queue_priorities, queue_max_concurrency,
                        recovery_config, tasks_running, tasks_claimed,
                        memory_usage_mb, memory_percent, cpu_percent,
                        worker_started_at
                    )
                    VALUES (
                        :wid, NOW(), :host, :pid, :procs, :mcb, :mcpw, :cwc,
                        :queues, :qp, :qmc, :recovery, :running, :claimed,
                        :mem_mb, :mem_pct, :cpu_pct, :started
                    )
                """),
                {
                    'wid': self.worker_instance_id,
                    'host': socket.gethostname(),
                    'pid': os.getpid(),
                    'procs': self.cfg.processes,
                    'mcb': self.cfg.max_claim_batch,
                    # A non-positive per-worker cap is reported as `processes`.
                    'mcpw': self.cfg.max_claim_per_worker
                    if self.cfg.max_claim_per_worker > 0
                    else self.cfg.processes,
                    'cwc': self.cfg.cluster_wide_cap,
                    'queues': self.cfg.queues,
                    'qp': Jsonb(self.cfg.queue_priorities)
                    if self.cfg.queue_priorities
                    else None,
                    'qmc': Jsonb(self.cfg.queue_max_concurrency)
                    if self.cfg.queue_max_concurrency
                    else None,
                    'recovery': Jsonb(recovery_dict) if recovery_dict else None,
                    'running': running,
                    'claimed': claimed,
                    'mem_mb': memory_info.rss / 1024 / 1024,
                    'mem_pct': mem_pct,
                    'cpu_pct': cpu_pct,
                    'started': self._started_at,
                },
            )
            await s.commit()
    except Exception as e:
        logger.error(f'Failed to update worker state: {e}')
|
|
1857
|
+
|
|
1858
|
+
async def _worker_state_heartbeat_loop(self) -> None:
    """Periodically update worker state for monitoring (every 5 seconds).

    Calls ``_update_worker_state`` right away and then every 5 s until
    ``self._stop`` is set. Errors from a snapshot are logged and the loop
    continues. The wait between snapshots ends as soon as stop is signalled
    rather than sleeping out the full interval. Cancellation ends the loop
    quietly.
    """
    worker_state_interval_ms = 5_000  # 5 seconds
    interval_seconds = worker_state_interval_ms / 1000.0

    try:
        while not self._stop.is_set():
            try:
                await self._update_worker_state()
            except Exception as e:
                logger.error(f'Worker state heartbeat error: {e}')

            # Wait for interval or stop signal. asyncio.sleep ignored the stop
            # event; wait on it with a timeout instead, as _reaper_loop does.
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=interval_seconds)
            except asyncio.TimeoutError:
                continue
            break  # Stop signal received
    except asyncio.CancelledError:
        return
|
|
1873
|
+
|
|
1874
|
+
async def _reaper_loop(self) -> None:
    """Automatic stale task handling loop.

    Periodically checks for and recovers stale tasks based on RecoveryConfig:
    - Requeues tasks stuck in CLAIMED (safe - user code never ran)
    - Marks stale RUNNING tasks as FAILED (not safe to requeue)

    Each pass also runs ``recover_stuck_workflows`` in its own session and
    commits it.

    Returns immediately when no ``recovery_config`` is set. The loop runs
    every ``check_interval_ms`` until ``self._stop`` is set or the task is
    cancelled.

    The work goes through a dedicated ``PostgresBroker`` built from
    ``self.cfg.dsn``. That broker is closed in ``finally`` whichever way the
    loop exits.

    Errors in one pass are logged and do not stop the loop. Workflow-recovery
    errors are caught separately so they do not hide the requeue/fail results.
    """
    from horsies.core.brokers.postgres import PostgresBroker

    if not self.cfg.recovery_config:
        return

    recovery_cfg = self.cfg.recovery_config
    check_interval_ms = recovery_cfg.check_interval_ms
    # Initialized before the try so the finally clause can test it even if
    # broker construction fails.
    temp_broker = None

    logger.info(
        f'Reaper configuration: auto_requeue_claimed={recovery_cfg.auto_requeue_stale_claimed}, '
        f'auto_fail_running={recovery_cfg.auto_fail_stale_running}, '
        f'check_interval={check_interval_ms}ms ({check_interval_ms/1000:.1f}s)'
    )

    try:
        from horsies.core.models.broker import PostgresConfig

        # Broker private to the reaper, built from the worker's DSN.
        temp_broker_config = PostgresConfig(database_url=self.cfg.dsn)
        temp_broker = PostgresBroker(temp_broker_config)
        if self._app is not None:
            temp_broker.app = self._app

        while not self._stop.is_set():
            try:
                # Auto-requeue stale CLAIMED tasks
                if recovery_cfg.auto_requeue_stale_claimed:
                    requeued = await temp_broker.requeue_stale_claimed(
                        stale_threshold_ms=recovery_cfg.claimed_stale_threshold_ms
                    )
                    if requeued > 0:
                        logger.info(
                            f'Reaper requeued {requeued} stale CLAIMED task(s)'
                        )

                # Auto-fail stale RUNNING tasks
                if recovery_cfg.auto_fail_stale_running:
                    failed = await temp_broker.mark_stale_tasks_as_failed(
                        stale_threshold_ms=recovery_cfg.running_stale_threshold_ms
                    )
                    if failed > 0:
                        logger.warning(
                            f'Reaper marked {failed} stale RUNNING task(s) as FAILED'
                        )

                # Recover stuck workflows (own try so a failure here is
                # reported separately from the task-level recovery above).
                try:
                    from horsies.core.workflows.recovery import (
                        recover_stuck_workflows,
                    )

                    async with temp_broker.session_factory() as s:
                        recovered = await recover_stuck_workflows(s, temp_broker)
                        if recovered > 0:
                            logger.info(
                                f'Reaper recovered {recovered} stuck workflow task(s)'
                            )
                        await s.commit()
                except Exception as wf_err:
                    logger.error(f'Workflow recovery error: {wf_err}')

            except Exception as e:
                logger.error(f'Reaper loop error: {e}')

            # Wait for check interval or stop signal (convert ms to seconds for asyncio)
            try:
                await asyncio.wait_for(
                    self._stop.wait(), timeout=check_interval_ms / 1000.0
                )
                break  # Stop signal received
            except asyncio.TimeoutError:
                continue  # Continue to next iteration

    except asyncio.CancelledError:
        # NOTE(review): cancellation is swallowed (not re-raised); whoever
        # awaits this task after cancel() sees a normal return.
        logger.info('Reaper loop cancelled')
        return
    finally:
        # Release the reaper's private broker; a close failure is only logged.
        if temp_broker is not None:
            try:
                await temp_broker.close_async()
            except Exception as e:
                logger.error(f'Error closing reaper broker: {e}')
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
"""
|
|
1966
|
+
horsies examples/instance.py worker --loglevel=info --processes=8
|
|
1967
|
+
"""
|