horsies-0.1.0a4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. horsies/__init__.py +117 -0
  2. horsies/core/__init__.py +0 -0
  3. horsies/core/app.py +552 -0
  4. horsies/core/banner.py +144 -0
  5. horsies/core/brokers/__init__.py +5 -0
  6. horsies/core/brokers/listener.py +444 -0
  7. horsies/core/brokers/postgres.py +993 -0
  8. horsies/core/cli.py +624 -0
  9. horsies/core/codec/serde.py +596 -0
  10. horsies/core/errors.py +535 -0
  11. horsies/core/logging.py +90 -0
  12. horsies/core/models/__init__.py +0 -0
  13. horsies/core/models/app.py +268 -0
  14. horsies/core/models/broker.py +79 -0
  15. horsies/core/models/queues.py +23 -0
  16. horsies/core/models/recovery.py +101 -0
  17. horsies/core/models/schedule.py +229 -0
  18. horsies/core/models/task_pg.py +307 -0
  19. horsies/core/models/tasks.py +358 -0
  20. horsies/core/models/workflow.py +1990 -0
  21. horsies/core/models/workflow_pg.py +245 -0
  22. horsies/core/registry/tasks.py +101 -0
  23. horsies/core/scheduler/__init__.py +26 -0
  24. horsies/core/scheduler/calculator.py +267 -0
  25. horsies/core/scheduler/service.py +569 -0
  26. horsies/core/scheduler/state.py +260 -0
  27. horsies/core/task_decorator.py +656 -0
  28. horsies/core/types/status.py +38 -0
  29. horsies/core/utils/imports.py +203 -0
  30. horsies/core/utils/loop_runner.py +44 -0
  31. horsies/core/worker/current.py +17 -0
  32. horsies/core/worker/worker.py +1967 -0
  33. horsies/core/workflows/__init__.py +23 -0
  34. horsies/core/workflows/engine.py +2344 -0
  35. horsies/core/workflows/recovery.py +501 -0
  36. horsies/core/workflows/registry.py +97 -0
  37. horsies/py.typed +0 -0
  38. horsies-0.1.0a4.dist-info/METADATA +35 -0
  39. horsies-0.1.0a4.dist-info/RECORD +42 -0
  40. horsies-0.1.0a4.dist-info/WHEEL +5 -0
  41. horsies-0.1.0a4.dist-info/entry_points.txt +2 -0
  42. horsies-0.1.0a4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1967 @@
1
+ # app/core/worker.py
2
+ from __future__ import annotations
3
+ import asyncio
4
+ import uuid
5
+ import os
6
+ import random
7
+ import signal
8
+ import socket
9
+ import threading
10
+ import time
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from dataclasses import dataclass, field
13
+ from datetime import datetime, timezone, timedelta
14
+ from importlib import import_module
15
+ from typing import Any, Optional, Sequence, Tuple, cast
16
+ import atexit
17
+ import hashlib
18
+ from psycopg import Connection, Cursor, InterfaceError, OperationalError
19
+ from psycopg.errors import DeadlockDetected, SerializationFailure
20
+ from psycopg.types.json import Jsonb
21
+ from psycopg_pool import ConnectionPool
22
+ from sqlalchemy import text
23
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
24
+ from horsies.core.app import Horsies
25
+ from horsies.core.brokers.listener import PostgresListener
26
+ from horsies.core.codec.serde import (
27
+ loads_json,
28
+ json_to_args,
29
+ json_to_kwargs,
30
+ dumps_json,
31
+ SerializationError,
32
+ task_result_from_json,
33
+ )
34
+ from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
35
+ from horsies.core.errors import ConfigurationError, ErrorCode
36
+ from horsies.core.logging import get_logger
37
+ from horsies.core.worker.current import get_current_app, set_current_app
38
+ import sys
39
+ from horsies.core.models.recovery import RecoveryConfig
40
+ from horsies.core.utils.imports import import_file_path
41
+
42
+ logger = get_logger('worker')
43
+
44
+ # ---------- Per-process connection pool (initialized in child processes) ----------
45
+ _worker_pool: ConnectionPool | None = None
46
+
47
+
48
+ def _get_worker_pool() -> ConnectionPool:
49
+ """Get the per-process connection pool. Raises if not initialized."""
50
+ if _worker_pool is None:
51
+ raise RuntimeError(
52
+ 'Worker connection pool not initialized. '
53
+ 'This function must be called from a child worker process.'
54
+ )
55
+ return _worker_pool
56
+
57
+
58
+ def _cleanup_worker_pool() -> None:
59
+ """Clean up the connection pool on process exit."""
60
+ global _worker_pool
61
+ if _worker_pool is not None:
62
+ try:
63
+ _worker_pool.close()
64
+ except Exception:
65
+ pass
66
+ _worker_pool = None
67
+
68
+
69
+ def _initialize_worker_pool(database_url: str) -> None:
70
+ """
71
+ Initialize the per-process connection pool.
72
+
73
+ In production: Called by _child_initializer in spawned worker processes.
74
+ In tests: Can be called directly to set up the pool for direct _run_task_entry calls.
75
+ """
76
+ global _worker_pool
77
+ if _worker_pool is not None:
78
+ return # Already initialized
79
+ _worker_pool = ConnectionPool(
80
+ database_url,
81
+ min_size=1,
82
+ max_size=5,
83
+ max_lifetime=300.0,
84
+ open=True,
85
+ )
86
+ atexit.register(_cleanup_worker_pool)
87
+
88
+
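The pool helpers above are the hook that the test note in `_initialize_worker_pool` refers to. A minimal sketch of driving them directly, assuming the import path implied by the file listing and a placeholder DSN:

from horsies.core.worker import worker as w

w._initialize_worker_pool('postgresql://app:app@localhost:5432/horsies')  # placeholder DSN
pool = w._get_worker_pool()            # raises RuntimeError if never initialized
with pool.connection() as conn:
    conn.execute('SELECT 1')           # pooled psycopg connection
w._cleanup_worker_pool()               # also registered via atexit in normal use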
89
+ def _default_str_list() -> list[str]:
90
+ return []
91
+
92
+
93
+ def _default_str_int_dict() -> dict[str, int]:
94
+ return {}
95
+
96
+
97
+ def _dedupe_paths(paths: Sequence[str]) -> list[str]:
98
+ unique: list[str] = []
99
+ seen: set[str] = set()
100
+ for path in paths:
101
+ if not path:
102
+ continue
103
+ if path in seen:
104
+ continue
105
+ seen.add(path)
106
+ unique.append(path)
107
+ return unique
108
+
109
+
110
+ def _debug_imports_enabled() -> bool:
111
+ return os.getenv('HORSIES_DEBUG_IMPORTS', '').strip() == '1'
112
+
113
+
114
+ def _debug_imports_log(message: str) -> None:
115
+ if _debug_imports_enabled():
116
+ logger.debug(message)
117
+
118
+
119
+ def _derive_sys_path_roots_from_file(file_path: str) -> list[str]:
120
+ """Derive sys.path roots from a file path.
121
+
122
+ Simple and explicit: just the file's parent directory.
123
+ No traversal for pyproject.toml to avoid monorepo collisions.
124
+ sys_path_roots from CLI are the authoritative source.
125
+ """
126
+ abs_path = os.path.realpath(file_path)
127
+ parent_dir = os.path.dirname(abs_path)
128
+ return [parent_dir] if parent_dir else []
129
+
130
+
131
+ def _build_sys_path_roots(
132
+ app_locator: str,
133
+ imports: Sequence[str],
134
+ extra_roots: Sequence[str],
135
+ ) -> list[str]:
136
+ roots: list[str] = []
137
+ for root in extra_roots:
138
+ if root:
139
+ roots.append(os.path.abspath(root))
140
+ if app_locator and ':' in app_locator:
141
+ mod_path = app_locator.split(':', 1)[0]
142
+ if mod_path.endswith('.py') or os.path.sep in mod_path:
143
+ roots.extend(_derive_sys_path_roots_from_file(mod_path))
144
+ for mod in imports:
145
+ if mod.endswith('.py') or os.path.sep in mod:
146
+ roots.extend(_derive_sys_path_roots_from_file(mod))
147
+ return _dedupe_paths(roots)
148
+
149
+
150
+ def import_by_path(path: str, module_name: str | None = None) -> Any:
151
+ """Import module from file path."""
152
+ return import_file_path(path, module_name)
153
+
154
+
155
+ # ---------- helpers for child processes ----------
156
+ def _locate_app(app_locator: str) -> Horsies:
157
+ """
158
+ app_locator examples:
159
+ - 'package.module:app' -> import module, take variable attr
160
+ - 'package.module:create_app' -> call factory
161
+ - '/abs/path/to/file.py:app' -> load from file path
162
+ """
163
+ logger.info(f'Locating app from {app_locator}')
164
+ if not app_locator or ':' not in app_locator:
165
+ raise ConfigurationError(
166
+ message='invalid app locator format',
167
+ code=ErrorCode.WORKER_INVALID_LOCATOR,
168
+ notes=[f'got: {app_locator!r}'],
169
+ help_text="use 'module.path:app' or '/path/to/file.py:app'",
170
+ )
171
+ mod_path, _, attr = app_locator.partition(':')
172
+ if mod_path.endswith('.py') or os.path.sep in mod_path:
173
+ mod = import_by_path(mod_path)
174
+ else:
175
+ mod = import_module(mod_path)
176
+ obj = getattr(mod, attr)
177
+ candidate = obj() if callable(obj) else obj
178
+ if not isinstance(candidate, Horsies):
179
+ raise ConfigurationError(
180
+ message='app locator did not resolve to Horsies instance',
181
+ code=ErrorCode.WORKER_INVALID_LOCATOR,
182
+ notes=[
183
+ f'locator: {app_locator!r}',
184
+ f'resolved to: {type(candidate).__name__}',
185
+ ],
186
+ help_text='ensure the locator points to a Horsies app instance or factory',
187
+ )
188
+ return candidate
189
+
190
+
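A minimal sketch of the three locator forms the docstring above describes; the module, factory, and file path names are hypothetical:

# Each form resolves to a Horsies instance (or a factory returning one);
# anything else raises ConfigurationError with WORKER_INVALID_LOCATOR.
app1 = _locate_app('myproj.tasks_app:app')          # module attribute
app2 = _locate_app('myproj.tasks_app:create_app')   # factory, called with no arguments
app3 = _locate_app('/srv/myproj/tasks_app.py:app')  # loaded from a file path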
191
+ def _child_initializer(
192
+ app_locator: str,
193
+ imports: Sequence[str],
194
+ sys_path_roots: Sequence[str],
195
+ loglevel: int,
196
+ database_url: str,
197
+ ) -> None:
198
+ global _worker_pool
199
+
200
+ # Ignore SIGINT in child processes - let the parent handle shutdown gracefully
201
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
202
+
203
+ # Set log level for this child process before any logging
204
+ from horsies.core.logging import set_default_level
205
+ set_default_level(loglevel)
206
+
207
+ # Mark child process to adjust logging behavior during module import
208
+ os.environ['HORSIES_CHILD_PROCESS'] = '1'
209
+ app_mod_path = app_locator.split(':', 1)[0]
210
+ app_mod_abs = (
211
+ os.path.abspath(app_mod_path) if app_mod_path.endswith('.py') else None
212
+ )
213
+ sys_path_roots_resolved = _build_sys_path_roots(
214
+ app_locator, imports, sys_path_roots
215
+ )
216
+ _debug_imports_log(
217
+ f'[child {os.getpid()}] app_locator={app_locator!r} sys_path_roots={sys_path_roots_resolved}'
218
+ )
219
+ for root in sys_path_roots_resolved:
220
+ if root not in sys.path:
221
+ sys.path.insert(0, root)
222
+
223
+ app = _locate_app(app_locator) # uses import_by_path -> loads as 'instance'
224
+ set_current_app(app)
225
+
226
+ # Suppress accidental sends while importing modules for discovery
227
+ try:
228
+ app.suppress_sends(True)
229
+ except Exception:
230
+ pass
231
+
232
+ try:
233
+ combined_imports = list(imports)
234
+ try:
235
+ combined_imports.extend(app.get_discovered_task_modules())
236
+ except Exception:
237
+ pass
238
+ combined_imports = _dedupe_paths(combined_imports)
239
+ _debug_imports_log(
240
+ f'[child {os.getpid()}] import_modules={combined_imports}'
241
+ )
242
+ for m in combined_imports:
243
+ if m.endswith('.py') or os.path.sep in m:
244
+ m_abs = os.path.abspath(m)
245
+ if app_mod_abs and os.path.samefile(m_abs, app_mod_abs):
246
+ continue # don't import the app file again
247
+ import_by_path(m_abs)
248
+ else:
249
+ import_module(m)
250
+ finally:
251
+ try:
252
+ app.suppress_sends(False)
253
+ except Exception:
254
+ pass
255
+
256
+ # optional: sanity log
257
+ try:
258
+ keys = app.tasks.keys_list()
259
+ except Exception:
260
+ keys = list(app.tasks.keys())
261
+ _debug_imports_log(f'[child {os.getpid()}] registered_tasks={keys}')
262
+
263
+ # Initialize per-process connection pool (after all imports complete)
264
+ _initialize_worker_pool(database_url)
265
+ logger.debug(f'[child {os.getpid()}] Connection pool initialized')
266
+
267
+
268
+ # ---------- Child-process execution ----------
269
+ def _heartbeat_worker(
270
+ task_id: str,
271
+ database_url: str,
272
+ stop_event: threading.Event,
273
+ sender_worker_id: str,
274
+ heartbeat_interval_ms: int = 30_000,
275
+ ) -> None:
276
+ """
277
+ Runs in a separate thread within the task process.
278
+ Sends heartbeats at configured interval until stopped or process dies.
279
+
280
+ Uses the per-process connection pool for efficient connection reuse.
281
+
282
+ Args:
283
+ task_id: The task ID
284
+ database_url: Database connection string (unused, kept for compatibility)
285
+ stop_event: Event to signal thread termination
286
+ sender_worker_id: Worker instance ID
287
+ heartbeat_interval_ms: Milliseconds between heartbeats (default: 30000 = 30s)
288
+ """
289
+ # Convert to seconds only for threading.Event.wait()
290
+ heartbeat_interval_seconds = heartbeat_interval_ms / 1000.0
291
+
292
+ def send_heartbeat() -> bool:
293
+ try:
294
+ pool = _get_worker_pool()
295
+ with pool.connection() as conn:
296
+ cursor = conn.cursor()
297
+ cursor.execute(
298
+ """
299
+ INSERT INTO horsies_heartbeats (task_id, sender_id, role, sent_at, hostname, pid)
300
+ VALUES (%s, %s, 'runner', NOW(), %s, %s)
301
+ """,
302
+ (
303
+ task_id,
304
+ sender_worker_id + f':{os.getpid()}',
305
+ socket.gethostname(),
306
+ os.getpid(),
307
+ ),
308
+ )
309
+ conn.commit()
310
+ cursor.close()
311
+ return True
312
+ except Exception as e:
313
+ logger.error(f'Heartbeat failed for task {task_id}: {e}')
314
+ return False
315
+
316
+ # Send an immediate heartbeat so freshly RUNNING tasks aren't considered stale
317
+ send_heartbeat()
318
+
319
+ while not stop_event.is_set():
320
+ # Wait for interval, but check stop_event periodically
321
+ if stop_event.wait(timeout=heartbeat_interval_seconds):
322
+ break # stop_event was set
323
+
324
+ # Use pooled connection for heartbeat
325
+ if not send_heartbeat():
326
+ break
327
+
328
+
329
+ def _is_retryable_db_error(exc: BaseException) -> bool:
330
+ match exc:
331
+ case OperationalError() | InterfaceError() | SerializationFailure() | DeadlockDetected():
332
+ return True
333
+ case _:
334
+ return False
335
+
336
+
337
+ def _get_workflow_status_for_task(cursor: Cursor[Any], task_id: str) -> str | None:
338
+ cursor.execute(
339
+ """
340
+ SELECT w.status FROM horsies_workflows w
341
+ JOIN horsies_workflow_tasks wt ON wt.workflow_id = w.id
342
+ WHERE wt.task_id = %s
343
+ """,
344
+ (task_id,),
345
+ )
346
+ row = cursor.fetchone()
347
+ if row is None:
348
+ return None
349
+ status = row[0]
350
+ if isinstance(status, str):
351
+ return status
352
+ return None
353
+
354
+
355
+ def _mark_task_skipped_for_workflow_stop(
356
+ cursor: Cursor[Any],
357
+ conn: Connection[Any],
358
+ task_id: str,
359
+ workflow_status: str,
360
+ ) -> Tuple[bool, str, Optional[str]]:
361
+ logger.info(f'Skipping task {task_id} - workflow is {workflow_status}')
362
+ cursor.execute(
363
+ """
364
+ UPDATE horsies_workflow_tasks
365
+ SET status = 'SKIPPED'
366
+ WHERE task_id = %s AND status IN ('ENQUEUED', 'READY')
367
+ """,
368
+ (task_id,),
369
+ )
370
+ result: TaskResult[Any, TaskError] = TaskResult(
371
+ err=TaskError(
372
+ error_code='WORKFLOW_STOPPED',
373
+ message=f'Task skipped - workflow is {workflow_status}',
374
+ )
375
+ )
376
+ result_json = dumps_json(result)
377
+ cursor.execute(
378
+ """
379
+ UPDATE horsies_tasks
380
+ SET status = 'COMPLETED',
381
+ result = %s,
382
+ completed_at = NOW(),
383
+ updated_at = NOW()
384
+ WHERE id = %s
385
+ """,
386
+ (result_json, task_id),
387
+ )
388
+ conn.commit()
389
+ return (True, result_json, f'Workflow {workflow_status}')
390
+
391
+
392
+ def _update_workflow_task_running_with_retry(task_id: str) -> None:
393
+ backoff_seconds = (0.0, 0.25, 0.75)
394
+ total_attempts = len(backoff_seconds)
395
+ for attempt_index, delay in enumerate(backoff_seconds):
396
+ if delay > 0:
397
+ time.sleep(delay)
398
+ try:
399
+ pool = _get_worker_pool()
400
+ with pool.connection() as conn:
401
+ cursor = conn.cursor()
402
+ cursor.execute(
403
+ """
404
+ UPDATE horsies_workflow_tasks wt
405
+ SET status = 'RUNNING', started_at = NOW()
406
+ FROM horsies_workflows w
407
+ WHERE wt.task_id = %s
408
+ AND wt.status = 'ENQUEUED'
409
+ AND wt.workflow_id = w.id
410
+ AND w.status = 'RUNNING'
411
+ """,
412
+ (task_id,),
413
+ )
414
+ conn.commit()
415
+ return
416
+ except Exception as exc:
417
+ retryable = _is_retryable_db_error(exc)
418
+ is_last_attempt = attempt_index == total_attempts - 1
419
+ match (retryable, is_last_attempt):
420
+ case (True, True):
421
+ logger.error(
422
+ f'Failed to update workflow_tasks to RUNNING for task {task_id}: {exc}'
423
+ )
424
+ return
425
+ case (True, False):
426
+ logger.warning(
427
+ f'Retrying workflow_tasks RUNNING update for task {task_id}: {exc}'
428
+ )
429
+ continue
430
+ case (False, _):
431
+ logger.error(
432
+ f'Failed to update workflow_tasks to RUNNING for task {task_id}: {exc}'
433
+ )
434
+ return
435
+ case _:
436
+ logger.error(
437
+ f'Failed to update workflow_tasks to RUNNING for task {task_id}: {exc}'
438
+ )
439
+ return
440
+
441
+
442
+ def _preflight_workflow_check(
443
+ task_id: str,
444
+ ) -> Optional[Tuple[bool, str, Optional[str]]]:
445
+ try:
446
+ pool = _get_worker_pool()
447
+ with pool.connection() as conn:
448
+ cursor = conn.cursor()
449
+ workflow_status = _get_workflow_status_for_task(cursor, task_id)
450
+ match workflow_status:
451
+ case 'PAUSED' | 'CANCELLED':
452
+ return _mark_task_skipped_for_workflow_stop(
453
+ cursor,
454
+ conn,
455
+ task_id,
456
+ workflow_status,
457
+ )
458
+ case _:
459
+ return None
460
+ except Exception as e:
461
+ logger.error(f'Failed to check workflow status for task {task_id}: {e}')
462
+ return (False, '', 'WORKFLOW_CHECK_FAILED')
463
+
464
+
465
+ def _confirm_ownership_and_set_running(
466
+ task_id: str,
467
+ worker_id: str,
468
+ ) -> Optional[Tuple[bool, str, Optional[str]]]:
469
+ try:
470
+ pool = _get_worker_pool()
471
+ with pool.connection() as conn:
472
+ cursor = conn.cursor()
473
+ cursor.execute(
474
+ """
475
+ UPDATE horsies_tasks
476
+ SET status = 'RUNNING',
477
+ claimed = FALSE,
478
+ claim_expires_at = NULL,
479
+ started_at = NOW(),
480
+ worker_pid = %s,
481
+ worker_hostname = %s,
482
+ worker_process_name = %s,
483
+ updated_at = NOW()
484
+ WHERE id = %s
485
+ AND status = 'CLAIMED'
486
+ AND claimed_by_worker_id = %s
487
+ AND (claim_expires_at IS NULL OR claim_expires_at > now())
488
+ AND NOT EXISTS (
489
+ SELECT 1
490
+ FROM horsies_workflow_tasks wt
491
+ JOIN horsies_workflows w ON w.id = wt.workflow_id
492
+ WHERE wt.task_id = %s
493
+ AND w.status IN ('PAUSED', 'CANCELLED')
494
+ )
495
+ RETURNING id
496
+ """,
497
+ (
498
+ os.getpid(),
499
+ socket.gethostname(),
500
+ f'worker-{os.getpid()}',
501
+ task_id,
502
+ worker_id,
503
+ task_id,
504
+ ),
505
+ )
506
+ updated_row = cursor.fetchone()
507
+ if updated_row is None:
508
+ workflow_status = _get_workflow_status_for_task(cursor, task_id)
509
+ match workflow_status:
510
+ case 'PAUSED' | 'CANCELLED':
511
+ return _mark_task_skipped_for_workflow_stop(
512
+ cursor,
513
+ conn,
514
+ task_id,
515
+ workflow_status,
516
+ )
517
+ case _:
518
+ conn.rollback()
519
+ logger.warning(
520
+ f'Task {task_id} ownership lost - claim was reclaimed or status changed. '
521
+ f'Aborting execution to prevent double-execution.'
522
+ )
523
+ return (False, '', 'CLAIM_LOST')
524
+
525
+ conn.commit()
526
+
527
+ _update_workflow_task_running_with_retry(task_id)
528
+ return None
529
+ except Exception as e:
530
+ logger.error(f'Failed to transition task {task_id} to RUNNING: {e}')
531
+ return (False, '', 'OWNERSHIP_UNCONFIRMED')
532
+
533
+
534
+ def _start_heartbeat_thread(
535
+ task_id: str,
536
+ database_url: str,
537
+ heartbeat_stop_event: threading.Event,
538
+ worker_id: str,
539
+ runner_heartbeat_interval_ms: int,
540
+ ) -> threading.Thread:
541
+ heartbeat_thread = threading.Thread(
542
+ target=_heartbeat_worker,
543
+ args=(
544
+ task_id,
545
+ database_url,
546
+ heartbeat_stop_event,
547
+ worker_id,
548
+ runner_heartbeat_interval_ms,
549
+ ),
550
+ daemon=True,
551
+ name=f'heartbeat-{task_id[:8]}',
552
+ )
553
+ heartbeat_thread.start()
554
+ return heartbeat_thread
555
+
556
+
557
+ def _run_task_entry(
558
+ task_name: str,
559
+ args_json: Optional[str],
560
+ kwargs_json: Optional[str],
561
+ task_id: str,
562
+ database_url: str,
563
+ master_worker_id: str,
564
+ runner_heartbeat_interval_ms: int = 30_000,
565
+ ) -> Tuple[bool, str, Optional[str]]:
566
+ """
567
+ Child-process entry.
568
+ Returns:
569
+ (ok, serialized_task_result_json, worker_failure_reason)
570
+
571
+ ok=True: we produced a valid TaskResult JSON (success *or* failure).
572
+ ok=False: worker couldn't even produce a TaskResult (infra error).
573
+
574
+ Do not mistake the 'ok' here as the task's own success/failure.
575
+ It's only a signal that the worker was able to produce a valid JSON.
576
+ The task's own success/failure is determined by the TaskResult's ok/err fields.
577
+ """
578
+ logger.info(f'Starting task execution: {task_name}')
579
+
580
+ # Pre-execute guard: check if workflow is PAUSED or CANCELLED
581
+ # If so, skip execution and mark task as SKIPPED
582
+ preflight_result = _preflight_workflow_check(task_id)
583
+ if preflight_result is not None:
584
+ return preflight_result
585
+
586
+ # Mark as RUNNING in DB at the actual start of execution (child process)
587
+ # CRITICAL: Include ownership check to prevent double-execution when claim lease expires
588
+ ownership_result = _confirm_ownership_and_set_running(task_id, master_worker_id)
589
+ if ownership_result is not None:
590
+ return ownership_result
591
+
592
+ # Start heartbeat monitoring in separate thread
593
+ heartbeat_stop_event = threading.Event()
594
+ _start_heartbeat_thread(
595
+ task_id,
596
+ database_url,
597
+ heartbeat_stop_event,
598
+ master_worker_id,
599
+ runner_heartbeat_interval_ms,
600
+ )
601
+
602
+ try:
603
+ # Resolve from TaskRegistry
604
+ try:
605
+ app = get_current_app()
606
+ task = app.tasks[task_name] # may raise NotRegistered(KeyError)
607
+ except Exception as e:
608
+ logger.error(f'Failed to resolve task {task_name}: {e}')
609
+ tr: TaskResult[Any, TaskError] = TaskResult(
610
+ err=TaskError(
611
+ error_code=LibraryErrorCode.WORKER_RESOLUTION_ERROR,
612
+ message=f'Failed to resolve {task_name}: {type(e).__name__}: {e}',
613
+ data={'task_name': task_name},
614
+ )
615
+ )
616
+ # Stop heartbeat thread since task failed to resolve
617
+ heartbeat_stop_event.set()
618
+ logger.debug(
619
+ f'Heartbeat stop signal sent for task {task_id} (resolution failed)'
620
+ )
621
+ return (True, dumps_json(tr), f'{type(e).__name__}: {e}')
622
+
623
+ args = json_to_args(loads_json(args_json))
624
+ kwargs = json_to_kwargs(loads_json(kwargs_json))
625
+
626
+ # Deserialize injected TaskResults from workflow args_from
627
+ for key, value in list(kwargs.items()):
628
+ if (
629
+ isinstance(value, dict)
630
+ and '__horsies_taskresult__' in value
631
+ and value['__horsies_taskresult__']
632
+ ):
633
+ # value is a dict with "data" key containing serialized TaskResult
634
+ task_result_dict = cast(dict[str, Any], value)
635
+ data_str = task_result_dict.get('data')
636
+ if isinstance(data_str, str):
637
+ kwargs[key] = task_result_from_json(loads_json(data_str))
638
+
639
+ # Handle WorkflowContext injection (only if task declares workflow_ctx param)
640
+ workflow_ctx_data = kwargs.pop('__horsies_workflow_ctx__', None)
641
+ if workflow_ctx_data is not None:
642
+ import inspect
643
+
644
+ # Access the underlying function: TaskFunctionImpl stores it in _original_fn
645
+ underlying_fn = getattr(task, '_original_fn', getattr(task, '_fn', task))
646
+ sig = inspect.signature(underlying_fn)
647
+ if 'workflow_ctx' in sig.parameters:
648
+ from horsies.core.models.workflow import (
649
+ WorkflowContext,
650
+ SubWorkflowSummary,
651
+ )
652
+
653
+ # Reconstruct TaskResults from serialized data (node_id-based)
654
+ results_by_id_raw = workflow_ctx_data.get('results_by_id', {})
655
+
656
+ results_by_id: dict[str, Any] = {}
657
+ for node_id, result_json in results_by_id_raw.items():
658
+ if isinstance(result_json, str):
659
+ results_by_id[node_id] = task_result_from_json(
660
+ loads_json(result_json)
661
+ )
662
+
663
+ # Reconstruct SubWorkflowSummaries from serialized data (node_id-based)
664
+ summaries_by_id_raw = workflow_ctx_data.get('summaries_by_id', {})
665
+ summaries_by_id: dict[str, Any] = {}
666
+ for node_id, summary_json in summaries_by_id_raw.items():
667
+ if isinstance(summary_json, str):
668
+ parsed = loads_json(summary_json)
669
+ if isinstance(parsed, dict):
670
+ summaries_by_id[node_id] = SubWorkflowSummary.from_json(
671
+ parsed
672
+ )
673
+
674
+ kwargs['workflow_ctx'] = WorkflowContext.from_serialized(
675
+ workflow_id=workflow_ctx_data.get('workflow_id', ''),
676
+ task_index=workflow_ctx_data.get('task_index', 0),
677
+ task_name=workflow_ctx_data.get('task_name', ''),
678
+ results_by_id=results_by_id,
679
+ summaries_by_id=summaries_by_id,
680
+ )
681
+ # If task doesn't declare workflow_ctx, silently skip injection
682
+
683
+ out = task(*args, **kwargs) # __call__ returns TaskResult
684
+ logger.info(f'Task execution completed: {task_id} : {task_name}')
685
+
686
+ if isinstance(out, TaskResult):
687
+ return (True, dumps_json(out), None)
688
+
689
+ if out is None:
690
+ tr: TaskResult[Any, TaskError] = TaskResult(
691
+ err=TaskError(
692
+ error_code=LibraryErrorCode.TASK_EXCEPTION,
693
+ message=f'Task {task_name} returned None instead of TaskResult or value',
694
+ data={'task_name': task_name},
695
+ )
696
+ )
697
+ return (True, dumps_json(tr), 'Task returned None')
698
+
699
+ # Plain value → wrap into success
700
+ return (True, dumps_json(TaskResult(ok=out)), None)
701
+
702
+ except SerializationError as se:
703
+ logger.error(f'Serialization error for task {task_name}: {se}')
704
+ tr = TaskResult(
705
+ err=TaskError(
706
+ error_code=LibraryErrorCode.WORKER_SERIALIZATION_ERROR,
707
+ message=str(se),
708
+ data={'task_name': task_name},
709
+ )
710
+ )
711
+ return (True, dumps_json(tr), f'SerializationError: {se}')
712
+
713
+ except BaseException as e:
714
+ logger.error(f'Task exception: {task_name}: {e}')
715
+ # If the task raised an exception, we wrap it in a TaskError
716
+ tr = TaskResult(
717
+ err=TaskError(
718
+ error_code=LibraryErrorCode.TASK_EXCEPTION,
719
+ message=f'{type(e).__name__}: {e}',
720
+ data={'task_name': task_name},
721
+ exception=e, # pydantic will accept and we flatten on serialization if needed elsewhere
722
+ )
723
+ )
724
+ return (True, dumps_json(tr), None)
725
+
726
+ finally:
727
+ # Always stop the heartbeat thread when task completes (success/failure/error)
728
+ heartbeat_stop_event.set()
729
+ logger.debug(f'Heartbeat stop signal sent for task {task_id}')
730
+
731
+
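For orientation, a condensed sketch of how the parent process reads this tuple; the real handling lives in Worker._finalize_after further down, and `future` here stands in for the executor future returned by run_in_executor:

ok, result_json, failed_reason = future.result()
if not ok:
    # Infra-level abort, e.g. 'CLAIM_LOST': another worker now owns the task,
    # so finalization is skipped rather than marking the task FAILED.
    pass
else:
    tr = task_result_from_json(loads_json(result_json))
    # tr.is_err() reflects the task's own outcome; ok=True only means a valid
    # TaskResult JSON was produced.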
732
+ # ---------- Claim SQL (priority + sent_at) ----------
733
+ # Supports both hard cap mode (claim_expires_at = NULL) and soft cap mode with lease.
734
+ # In soft cap mode, also reclaims tasks with expired claim leases.
735
+
736
+ CLAIM_SQL = text("""
737
+ WITH next AS (
738
+ SELECT id
739
+ FROM horsies_tasks
740
+ WHERE queue_name = :queue
741
+ AND (
742
+ -- Fresh pending tasks
743
+ status = 'PENDING'
744
+ -- OR expired claims (soft cap mode: lease expired, reclaim for this worker)
745
+ OR (status = 'CLAIMED' AND claim_expires_at IS NOT NULL AND claim_expires_at < now())
746
+ )
747
+ AND sent_at <= now()
748
+ AND (next_retry_at IS NULL OR next_retry_at <= now())
749
+ AND (good_until IS NULL OR good_until > now())
750
+ ORDER BY priority ASC, sent_at ASC, id ASC
751
+ FOR UPDATE SKIP LOCKED
752
+ LIMIT :lim
753
+ )
754
+ UPDATE horsies_tasks t
755
+ SET status = 'CLAIMED',
756
+ claimed = TRUE,
757
+ claimed_at = now(),
758
+ claimed_by_worker_id = :worker_id,
759
+ claim_expires_at = :claim_expires_at,
760
+ updated_at = now()
761
+ FROM next
762
+ WHERE t.id = next.id
763
+ RETURNING t.id;
764
+ """)
765
+
766
+ # Fetch rows for a list of ids (to get func_path/args/etc.)
767
+ LOAD_ROWS_SQL = text("""
768
+ SELECT id, task_name, args, kwargs, retry_count, max_retries, task_options
769
+ FROM horsies_tasks
770
+ WHERE id = ANY(:ids)
771
+ """)
772
+
773
+
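A sketch of one claim pass using CLAIM_SQL, the way the Worker methods below execute it; `session` and `worker_instance_id` are assumed to be in scope, and passing claim_expires_at=None selects the hard-cap behaviour described in the comment above:

res = await session.execute(
    CLAIM_SQL,
    {
        'queue': 'default',               # queue to pull from
        'lim': 2,                         # cap for this pass (max_claim_batch)
        'worker_id': worker_instance_id,  # recorded as claimed_by_worker_id
        'claim_expires_at': None,         # None = hard cap; a timestamp = soft-cap lease
    },
)
claimed_ids = [r[0] for r in res.fetchall()]
await session.commit()                    # make CLAIMED visible, release row locks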
774
+ # ---------- WorkerConfig ----------
775
+ @dataclass
776
+ class WorkerConfig:
777
+ dsn: str # SQLAlchemy async URL (e.g. postgresql+psycopg://...)
778
+ psycopg_dsn: str # plain psycopg URL for listener
779
+ queues: list[str] # which queues to serve
780
+ processes: int = os.cpu_count() or 2
781
+ # Claiming knobs
782
+ # max_claim_batch: Top-level fairness limiter to prevent worker starvation in multi-worker setups.
783
+ # Limits claims per queue per pass, regardless of available capacity. Increase for high-concurrency workloads.
784
+ max_claim_batch: int = 2
785
+ # max_claim_per_worker: Per-worker limit on total CLAIMED tasks to prevent over-claiming.
786
+ # 0 = auto (defaults to processes). Increase for deeper prefetch if tasks start very quickly.
787
+ max_claim_per_worker: int = 0
788
+ coalesce_notifies: int = 100 # drain up to N notes after wake
789
+ app_locator: str = '' # NEW (see _locate_app)
790
+ sys_path_roots: list[str] = field(default_factory=_default_str_list)
791
+ imports: list[str] = field(
792
+ default_factory=_default_str_list
793
+ ) # modules that contain @app.task defs
794
+ # When in CUSTOM mode, provide per-queue settings {name: {priority, max_concurrency}}
795
+ queue_priorities: dict[str, int] = field(default_factory=_default_str_int_dict)
796
+ queue_max_concurrency: dict[str, int] = field(default_factory=_default_str_int_dict)
797
+ cluster_wide_cap: Optional[int] = None
798
+ # Prefetch buffer: 0 = hard cap mode (count RUNNING + CLAIMED), >0 = soft cap with lease
799
+ prefetch_buffer: int = 0
800
+ # Claim lease duration in ms. Required when prefetch_buffer > 0.
801
+ claim_lease_ms: Optional[int] = None
802
+ # Recovery configuration from AppConfig
803
+ recovery_config: Optional['RecoveryConfig'] = (
804
+ None # RecoveryConfig, avoid circular import
805
+ )
806
+ # Log level for worker processes (default: INFO)
807
+ loglevel: int = 20 # logging.INFO
808
+
809
+
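A hedged example of wiring the knobs documented above; the DSNs, queue names, and module paths are placeholders, and the prefetch_buffer/claim_lease_ms pairing follows the comments on those fields:

cfg = WorkerConfig(
    dsn='postgresql+psycopg://app:app@localhost/horsies',  # SQLAlchemy async URL
    psycopg_dsn='postgresql://app:app@localhost/horsies',  # plain URL for LISTEN/NOTIFY
    queues=['default', 'emails'],
    processes=4,
    app_locator='myproj.tasks_app:app',                    # resolved by _locate_app
    imports=['myproj.tasks'],                              # modules holding @app.task defs
    prefetch_buffer=2,                                     # soft cap: claim ahead of capacity
    claim_lease_ms=60_000,                                 # required when prefetch_buffer > 0
)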
810
+ class Worker:
811
+ """
812
+ Async master that:
813
+ - Subscribes to queue channels
814
+ - Claims tasks (priority + sent_at) with SKIP LOCKED
815
+ - Executes in a process pool
816
+ - On completion, writes result/failed, COMMITs, and NOTIFY task_done
817
+ """
818
+
819
+ def __init__(
820
+ self,
821
+ session_factory: async_sessionmaker[AsyncSession],
822
+ listener: PostgresListener,
823
+ cfg: WorkerConfig,
824
+ ):
825
+ self.sf = session_factory
826
+ self.listener = listener
827
+ self.cfg = cfg
828
+ self.worker_instance_id = str(uuid.uuid4())
829
+ self._started_at = datetime.now(timezone.utc)
830
+ self._app: Horsies | None = None
831
+ # Delay creation of the process pool until after preloading modules so that
832
+ # any import/validation errors surface in the main process at startup.
833
+ self._executor: Optional[ProcessPoolExecutor] = None
834
+ self._stop = asyncio.Event()
835
+
836
+ def request_stop(self) -> None:
837
+ """Request worker to stop gracefully."""
838
+ self._stop.set()
839
+
840
+ # ----- lifecycle -----
841
+
842
+ async def start(self) -> None:
843
+ logger.debug('Starting worker')
844
+ # Preload the app and task modules in the main process to fail fast
845
+ self._preload_modules_main()
846
+
847
+ # Create the process pool AFTER successful preload so initializer runs in children only
848
+ # Compute the plain psycopg database URL for child processes
849
+ child_database_url = self.cfg.dsn.replace('+asyncpg', '').replace('+psycopg', '')
850
+ self._executor = ProcessPoolExecutor(
851
+ max_workers=self.cfg.processes,
852
+ initializer=_child_initializer,
853
+ initargs=(
854
+ self.cfg.app_locator,
855
+ self.cfg.imports,
856
+ self.cfg.sys_path_roots,
857
+ self.cfg.loglevel,
858
+ child_database_url,
859
+ ),
860
+ )
861
+ await self.listener.start()
862
+ # Surface concurrency configuration clearly for operators
863
+ max_claimed_effective = (
864
+ self.cfg.max_claim_per_worker
865
+ if self.cfg.max_claim_per_worker > 0
866
+ else self.cfg.processes
867
+ )
868
+ logger.info(
869
+ 'Concurrency config: processes=%s, cluster_wide_cap=%s, max_claim_per_worker=%s, max_claim_batch=%s',
870
+ self.cfg.processes,
871
+ (
872
+ self.cfg.cluster_wide_cap
873
+ if self.cfg.cluster_wide_cap is not None
874
+ else 'unlimited'
875
+ ),
876
+ max_claimed_effective,
877
+ self.cfg.max_claim_batch,
878
+ )
879
+
880
+ # Subscribe to each queue channel (and optionally a global)
881
+ self._queues = [
882
+ await self.listener.listen(f'task_queue_{q}') for q in self.cfg.queues
883
+ ]
884
+ logger.info(f'Subscribed to queues: {self.cfg.queues}')
885
+ self._global = await self.listener.listen('task_new')
886
+ logger.info('Subscribed to global queue')
887
+ # Start claimer heartbeat loop (CLAIMED coverage)
888
+ asyncio.create_task(self._claimer_heartbeat_loop())
889
+ # Start worker state heartbeat loop for monitoring
890
+ asyncio.create_task(self._worker_state_heartbeat_loop())
891
+ logger.info('Worker state heartbeat loop started for monitoring')
892
+ # Start reaper loop for automatic stale task handling
893
+ if self.cfg.recovery_config:
894
+ asyncio.create_task(self._reaper_loop())
895
+ logger.info('Reaper loop started for automatic stale task recovery')
896
+
897
+ async def stop(self) -> None:
898
+ self._stop.set()
899
+ # Close the Postgres listener early to avoid UNLISTEN races on dispatcher connection
900
+ try:
901
+ await self.listener.close()
902
+ logger.info('Postgres listener closed')
903
+ except Exception as e:
904
+ logger.error(f'Error closing Postgres listener: {e}')
905
+ # Shutdown executor
906
+ if self._executor:
907
+ # Offload blocking shutdown to a thread to avoid freezing the event loop
908
+ loop = asyncio.get_running_loop()
909
+ executor = self._executor
910
+ self._executor = None
911
+ try:
912
+ await loop.run_in_executor(
913
+ None, lambda: executor.shutdown(wait=True, cancel_futures=True)
914
+ ) # TODO:inspect this behaviour more in depth!
915
+ except Exception as e:
916
+ logger.error(f'Error shutting down executor: {e}')
917
+ logger.info('Worker executor shutdown')
918
+ logger.info('Worker stopped')
919
+
920
+ def _preload_modules_main(self) -> None:
921
+ """Import the app and all task modules in the main process.
922
+
923
+ This ensures Pydantic validations and module-level side effects run once
924
+ and any configuration errors surface during startup rather than inside
925
+ the child process initializer.
926
+ """
927
+ try:
928
+ sys_path_roots_resolved = _build_sys_path_roots(
929
+ self.cfg.app_locator, self.cfg.imports, self.cfg.sys_path_roots
930
+ )
931
+ _debug_imports_log(
932
+ f'[preload] app_locator={self.cfg.app_locator!r} sys_path_roots={sys_path_roots_resolved}'
933
+ )
934
+ for root in sys_path_roots_resolved:
935
+ if root not in sys.path:
936
+ sys.path.insert(0, root)
937
+
938
+ # Load app object (variable or factory)
939
+ app = _locate_app(self.cfg.app_locator)
940
+ # Optionally set as current for consistency in main process
941
+ set_current_app(app)
942
+ self._app = app
943
+
944
+ # Suppress accidental sends while importing modules for discovery
945
+ try:
946
+ app.suppress_sends(True)
947
+ except Exception:
948
+ pass
949
+
950
+ # Import declared modules that contain task definitions
951
+ combined_imports = list(self.cfg.imports)
952
+ try:
953
+ combined_imports.extend(app.get_discovered_task_modules())
954
+ except Exception:
955
+ pass
956
+ combined_imports = _dedupe_paths(combined_imports)
957
+ _debug_imports_log(f'[preload] import_modules={combined_imports}')
958
+ for m in combined_imports:
959
+ if m.endswith('.py') or os.path.sep in m:
960
+ import_by_path(os.path.abspath(m))
961
+ else:
962
+ import_module(m)
963
+
964
+ # Re-enable sends after import completes
965
+ try:
966
+ app.suppress_sends(False)
967
+ except Exception:
968
+ pass
969
+ _debug_imports_log(f'[preload] registered_tasks={app.list_tasks()}')
970
+ except Exception as e:
971
+ # Surface the error clearly and re-raise to stop startup
972
+ logger.error(f'Failed during preload of task modules: {e}')
973
+ raise
974
+
975
+ # ----- main loop -----
976
+
977
+ async def run_forever(self) -> None:
978
+ """Main orchestrator loop."""
979
+ try:
980
+ # Add timeout to startup to prevent hanging
981
+ await asyncio.wait_for(self.start(), timeout=30.0)
982
+ logger.info('Worker started')
983
+ except asyncio.TimeoutError:
984
+ logger.error('Worker startup timed out after 30 seconds')
985
+ raise RuntimeError(
986
+ 'Worker startup timeout - likely database connection issue'
987
+ )
988
+ try:
989
+ while not self._stop.is_set():
990
+ # Single budgeted claim pass, then wait for new NOTIFY
991
+ await self._claim_and_dispatch_all()
992
+
993
+ # Wait for a NOTIFY from any queue (coalesce bursts).
994
+ await self._wait_for_any_notify()
995
+ await self._claim_and_dispatch_all()
996
+ finally:
997
+ await self.stop()
998
+
999
+ async def _wait_for_any_notify(self) -> None:
1000
+ """Wait on any subscribed queue channel; coalesce a burst."""
1001
+ import contextlib
1002
+
1003
+ queue_tasks = [
1004
+ asyncio.create_task(q.get()) for q in (self._queues + [self._global])
1005
+ ]
1006
+ # Add only the stop event as an additional wait condition (no periodic polling)
1007
+ stop_task = asyncio.create_task(self._stop.wait())
1008
+ all_tasks = queue_tasks + [stop_task]
1009
+
1010
+ _, pending = await asyncio.wait(all_tasks, return_when=asyncio.FIRST_COMPLETED)
1011
+
1012
+ # Check if stop was signaled
1013
+ if self._stop.is_set():
1014
+ # Cancel all pending tasks and await them to avoid warnings
1015
+ for p in pending:
1016
+ p.cancel()
1017
+ for p in pending:
1018
+ with contextlib.suppress(asyncio.CancelledError):
1019
+ await p
1020
+ return
1021
+
1022
+ # cancel the rest to avoid background tasks piling up and await them
1023
+ for p in pending:
1024
+ p.cancel()
1025
+ for p in pending:
1026
+ with contextlib.suppress(asyncio.CancelledError):
1027
+ await p
1028
+
1029
+ # drain a burst
1030
+ drained = 0
1031
+ for q in self._queues + [self._global]:
1032
+ while drained < self.cfg.coalesce_notifies and not q.empty():
1033
+ try:
1034
+ q.get_nowait()
1035
+ drained += 1
1036
+ except asyncio.QueueEmpty:
1037
+ break
1038
+
1039
+ # ----- claim & dispatch -----
1040
+
1041
+ async def _claim_and_dispatch_all(self) -> bool:
1042
+ """
1043
+ Claim tasks subject to:
1044
+ - max_claim_per_worker guard (prevents over-claiming)
1045
+ - queue priorities (CUSTOM mode)
1046
+ - per-queue max_concurrency (CUSTOM mode)
1047
+ - worker global concurrency (processes)
1048
+ Returns True if anything was claimed.
1049
+ """
1050
+ # Guard: Check if we've already claimed too many tasks
1051
+ # Default depends on mode:
1052
+ # - Hard cap (prefetch_buffer=0): default to processes
1053
+ # - Soft cap (prefetch_buffer>0): default to processes + prefetch_buffer
1054
+ if self.cfg.max_claim_per_worker > 0:
1055
+ # User explicitly set a limit - use it
1056
+ max_claimed = self.cfg.max_claim_per_worker
1057
+ elif self.cfg.prefetch_buffer > 0:
1058
+ # Soft cap mode: allow claiming up to processes + prefetch_buffer
1059
+ max_claimed = self.cfg.processes + self.cfg.prefetch_buffer
1060
+ else:
1061
+ # Hard cap mode: limit to processes
1062
+ max_claimed = self.cfg.processes
1063
+ claimed_count = await self._count_claimed_for_worker()
1064
+ if claimed_count >= max_claimed:
1065
+ return False
1066
+
1067
+ # Cluster-wide, lock-guarded claim to avoid races. One short transaction.
1068
+ # Use local capacity + small prefetch to size claims fairly across workers.
1069
+ claimed_ids: list[str] = []
1070
+
1071
+ # Queue order: if custom priorities provided, sort by priority; otherwise keep given order
1072
+ if self.cfg.queue_priorities:
1073
+ ordered_queues = sorted(
1074
+ [q for q in self.cfg.queues if q in self.cfg.queue_priorities],
1075
+ key=lambda q: self.cfg.queue_priorities.get(q, 100),
1076
+ )
1077
+ else:
1078
+ ordered_queues = list(self.cfg.queues)
1079
+
1080
+ # Open one transaction, take a global advisory xact lock
1081
+ async with self.sf() as s:
1082
+ # Take a cluster-wide transaction-scoped advisory lock to serialize claiming
1083
+ await s.execute(
1084
+ text('SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))'),
1085
+ {'key': self._advisory_key_global()},
1086
+ )
1087
+
1088
+ # Compute local budget and optional global remaining
1089
+ # Hard cap mode (prefetch_buffer=0): count RUNNING + CLAIMED for strict enforcement
1090
+ # Soft cap mode (prefetch_buffer>0): count only RUNNING, allow prefetch with lease
1091
+ hard_cap_mode = self.cfg.prefetch_buffer == 0
1092
+
1093
+ if hard_cap_mode:
1094
+ # Hard cap: count both RUNNING and CLAIMED for this worker
1095
+ local_in_flight = await self._count_in_flight_for_worker()
1096
+ max_local_capacity = self.cfg.processes
1097
+ else:
1098
+ # Soft cap: count only RUNNING to allow prefetch beyond processes
1099
+ local_in_flight = await self._count_only_running_for_worker()
1100
+ max_local_capacity = self.cfg.processes + self.cfg.prefetch_buffer
1101
+ local_available = max(0, int(max_local_capacity) - int(local_in_flight))
1102
+ budget_remaining = local_available
1103
+
1104
+ global_remaining: Optional[int] = None
1105
+ if self.cfg.cluster_wide_cap is not None:
1106
+ # Hard cap mode: count RUNNING + CLAIMED globally
1107
+ # (Note: prefetch_buffer must be 0 when cluster_wide_cap is set, enforced by config validation)
1108
+ res = await s.execute(
1109
+ text(
1110
+ "SELECT COUNT(*) FROM horsies_tasks WHERE status IN ('RUNNING', 'CLAIMED')"
1111
+ )
1112
+ )
1113
+ row = res.fetchone()
1114
+ if row:
1115
+ in_flight_global = int(row[0])
1116
+ else:
1117
+ in_flight_global = 0
1118
+ global_remaining = max(
1119
+ 0, int(self.cfg.cluster_wide_cap) - in_flight_global
1120
+ )
1121
+
1122
+ # Total claim budget for this pass: local budget capped by global remaining (if any)
1123
+ total_remaining = (
1124
+ budget_remaining
1125
+ if global_remaining is None
1126
+ else min(budget_remaining, global_remaining)
1127
+ )
1128
+ if total_remaining <= 0:
1129
+ # Nothing to claim globally or locally
1130
+ await s.commit()
1131
+ return False
1132
+
1133
+ for qname in ordered_queues:
1134
+ if total_remaining <= 0:
1135
+ break
1136
+
1137
+ # Compute queue remaining in cluster (only if custom-configured)
1138
+ q_remaining: Optional[int] = None
1139
+ if (
1140
+ self.cfg.queue_priorities
1141
+ and qname in self.cfg.queue_max_concurrency
1142
+ ):
1143
+ # Hard cap mode: count RUNNING + CLAIMED for this queue
1144
+ # Soft cap mode: count only RUNNING
1145
+ if hard_cap_mode:
1146
+ resq = await s.execute(
1147
+ text(
1148
+ "SELECT COUNT(*) FROM horsies_tasks WHERE status IN ('RUNNING', 'CLAIMED') AND queue_name = :q"
1149
+ ),
1150
+ {'q': qname},
1151
+ )
1152
+ else:
1153
+ resq = await s.execute(
1154
+ text(
1155
+ "SELECT COUNT(*) FROM horsies_tasks WHERE status = 'RUNNING' AND queue_name = :q"
1156
+ ),
1157
+ {'q': qname},
1158
+ )
1159
+ row = resq.fetchone()
1160
+ if row:
1161
+ in_flight_q = int(row[0])
1162
+ else:
1163
+ in_flight_q = 0
1164
+ max_q = int(self.cfg.queue_max_concurrency.get(qname, 0))
1165
+ q_remaining = max(0, max_q - in_flight_q)
1166
+
1167
+ # Determine how many we may claim from this queue
1168
+ # Hierarchy: max_claim_batch (fairness) -> q_remaining (queue cap) -> total_remaining (worker budget)
1169
+
1170
+ if self.cfg.queue_priorities:
1171
+ # Strict priority mode: try to fill remaining budget from this queue
1172
+ # Ignore max_claim_batch (which forces round-robin fairness)
1173
+ per_queue_cap = total_remaining
1174
+ else:
1175
+ # Default mode: use max_claim_batch to ensure fairness across queues
1176
+ per_queue_cap = self.cfg.max_claim_batch
1177
+
1178
+ if q_remaining is not None:
1179
+ per_queue_cap = min(per_queue_cap, q_remaining)
1180
+ to_claim = min(total_remaining, per_queue_cap)
1181
+ if to_claim <= 0:
1182
+ continue
1183
+
1184
+ ids = await self._claim_batch_locked(s, qname, to_claim)
1185
+ if not ids:
1186
+ continue
1187
+ claimed_ids.extend(ids)
1188
+ total_remaining -= len(ids)
1189
+
1190
+ await s.commit()
1191
+
1192
+ if not claimed_ids:
1193
+ return False
1194
+
1195
+ rows = await self._load_rows(claimed_ids)
1196
+
1197
+ # PAUSE guard: filter out tasks belonging to PAUSED workflows and unclaim them
1198
+ rows = await self._filter_paused_workflow_tasks(rows)
1199
+
1200
+ for row in rows:
1201
+ await self._dispatch_one(
1202
+ row['id'], row['task_name'], row['args'], row['kwargs']
1203
+ )
1204
+ return len(rows) > 0
1205
+
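To make the budgeting in _claim_and_dispatch_all concrete, a worked example with hypothetical numbers, in hard cap mode since the code notes that cluster_wide_cap requires prefetch_buffer=0:

processes, prefetch_buffer = 4, 0
in_flight_local = 3                                     # RUNNING + CLAIMED owned by this worker
local_available = max(0, processes - in_flight_local)   # -> 1
cluster_wide_cap, in_flight_global = 10, 9
total_remaining = min(local_available, max(0, cluster_wide_cap - in_flight_global))  # -> 1
# At most one task is claimed this pass, split across queues by max_claim_batch.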
1206
+ async def _filter_paused_workflow_tasks(
1207
+ self, rows: list[dict[str, Any]]
1208
+ ) -> list[dict[str, Any]]:
1209
+ """
1210
+ Filter out tasks belonging to PAUSED workflows and unclaim them.
1211
+
1212
+ Post-claim guard: If a task belongs to a workflow that is PAUSED,
1213
+ we unclaim it (set back to PENDING) so it can be processed on resume.
1214
+
1215
+ Returns the filtered list of rows that should be dispatched.
1216
+ """
1217
+ if not rows:
1218
+ return rows
1219
+
1220
+ task_ids = [row['id'] for row in rows]
1221
+
1222
+ # Find which tasks belong to PAUSED workflows
1223
+ async with self.sf() as s:
1224
+ res = await s.execute(
1225
+ text("""
1226
+ SELECT t.id
1227
+ FROM horsies_tasks t
1228
+ JOIN horsies_workflow_tasks wt ON wt.task_id = t.id
1229
+ JOIN horsies_workflows w ON w.id = wt.workflow_id
1230
+ WHERE t.id = ANY(:ids)
1231
+ AND w.status = 'PAUSED'
1232
+ """),
1233
+ {'ids': task_ids},
1234
+ )
1235
+ paused_task_ids = {row[0] for row in res.fetchall()}
1236
+
1237
+ if paused_task_ids:
1238
+ # Unclaim these tasks: set back to PENDING so they can be picked up on resume
1239
+ await s.execute(
1240
+ text("""
1241
+ UPDATE horsies_tasks
1242
+ SET status = 'PENDING',
1243
+ claimed = FALSE,
1244
+ claimed_at = NULL,
1245
+ claimed_by_worker_id = NULL,
1246
+ updated_at = NOW()
1247
+ WHERE id = ANY(:ids)
1248
+ """),
1249
+ {'ids': list(paused_task_ids)},
1250
+ )
1251
+ # Also reset workflow_tasks back to READY for consistency
1252
+ # (they were ENQUEUED, but the task is now unclaimed)
1253
+ await s.execute(
1254
+ text("""
1255
+ UPDATE horsies_workflow_tasks
1256
+ SET status = 'READY', task_id = NULL, started_at = NULL
1257
+ WHERE task_id = ANY(:ids)
1258
+ """),
1259
+ {'ids': list(paused_task_ids)},
1260
+ )
1261
+ await s.commit()
1262
+
1263
+ # Return only tasks not belonging to PAUSED workflows
1264
+ return [row for row in rows if row['id'] not in paused_task_ids]
1265
+
1266
+ def _advisory_key_global(self) -> int:
1267
+ """Compute a stable 64-bit advisory lock key for this cluster."""
1268
+ basis = (self.cfg.psycopg_dsn or self.cfg.dsn or 'horsies').encode(
1269
+ 'utf-8', errors='ignore'
1270
+ )
1271
+ h = hashlib.sha256(b'horsies-global:' + basis).digest()
1272
+ return int.from_bytes(h[:8], byteorder='big', signed=True)
1273
+
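The advisory key is just a stable hash of the DSN; a small illustration (placeholder DSN) of why every worker pointed at the same database serializes on the same pg_advisory_xact_lock key:

import hashlib

def advisory_key(dsn: str) -> int:
    h = hashlib.sha256(b'horsies-global:' + dsn.encode('utf-8', errors='ignore')).digest()
    return int.from_bytes(h[:8], byteorder='big', signed=True)

# Workers built against the same DSN derive the same 64-bit key, so their
# claim transactions queue up behind the same advisory lock.
assert advisory_key('postgresql://app@db/horsies') == advisory_key('postgresql://app@db/horsies')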
1274
+ # Stale detection is handled via heartbeat policy for RUNNING tasks.
1275
+
1276
+ def _compute_claim_expires_at(self) -> Optional[datetime]:
1277
+ """Compute claim expiration timestamp for soft cap mode, or None for hard cap mode."""
1278
+ if self.cfg.claim_lease_ms is None:
1279
+ return None
1280
+ return datetime.now(timezone.utc) + timedelta(
1281
+ milliseconds=self.cfg.claim_lease_ms
1282
+ )
1283
+
1284
+ async def _claim_batch_locked(
1285
+ self, s: AsyncSession, queue: str, limit: int
1286
+ ) -> list[str]:
1287
+ """Claim up to limit tasks for a given queue within an open transaction/lock."""
1288
+ res = await s.execute(
1289
+ CLAIM_SQL,
1290
+ {
1291
+ 'queue': queue,
1292
+ 'lim': limit,
1293
+ 'worker_id': self.worker_instance_id,
1294
+ 'claim_expires_at': self._compute_claim_expires_at(),
1295
+ },
1296
+ )
1297
+ return [r[0] for r in res.fetchall()]
1298
+
1299
+ async def _count_claimed_for_worker(self) -> int:
1300
+ """Count only CLAIMED tasks for this worker (not yet RUNNING)."""
1301
+ async with self.sf() as s:
1302
+ res = await s.execute(
1303
+ text(
1304
+ """
1305
+ SELECT COUNT(*)
1306
+ FROM horsies_tasks
1307
+ WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
1308
+ AND status = 'CLAIMED'
1309
+ """
1310
+ ),
1311
+ {'wid': self.worker_instance_id},
1312
+ )
1313
+ row = res.fetchone()
1314
+ return int(row[0]) if row else 0
1315
+
1316
+ async def _count_only_running_for_worker(self) -> int:
1317
+ """Count only RUNNING tasks for this worker (excludes CLAIMED)."""
1318
+ async with self.sf() as s:
1319
+ res = await s.execute(
1320
+ text(
1321
+ """
1322
+ SELECT COUNT(*)
1323
+ FROM horsies_tasks
1324
+ WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
1325
+ AND status = 'RUNNING'
1326
+ """
1327
+ ),
1328
+ {'wid': self.worker_instance_id},
1329
+ )
1330
+ row = res.fetchone()
1331
+ return int(row[0]) if row else 0
1332
+
1333
+ async def _count_in_flight_for_worker(self) -> int:
1334
+ """Count RUNNING + CLAIMED tasks for this worker (hard cap mode)."""
1335
+ async with self.sf() as s:
1336
+ res = await s.execute(
1337
+ text(
1338
+ """
1339
+ SELECT COUNT(*)
1340
+ FROM horsies_tasks
1341
+ WHERE claimed_by_worker_id = CAST(:wid AS VARCHAR)
1342
+ AND status IN ('RUNNING', 'CLAIMED')
1343
+ """
1344
+ ),
1345
+ {'wid': self.worker_instance_id},
1346
+ )
1347
+ row = res.fetchone()
1348
+ return int(row[0]) if row else 0
1349
+
1350
+ async def _count_running_in_queue(self, queue_name: str) -> int:
1351
+ """Count RUNNING tasks in a given queue across the cluster."""
1352
+ async with self.sf() as s:
1353
+ res = await s.execute(
1354
+ text(
1355
+ """
1356
+ SELECT COUNT(*)
1357
+ FROM horsies_tasks
1358
+ WHERE status = 'RUNNING'
1359
+ AND queue_name = :q
1360
+ """
1361
+ ),
1362
+ {'q': queue_name},
1363
+ )
1364
+ row = res.fetchone()
1365
+ return int(row[0]) if row else 0
1366
+
1367
+ async def _claim_batch(self, queue: str, limit: int) -> list[str]:
1368
+ async with self.sf() as s:
1369
+ res = await s.execute(
1370
+ CLAIM_SQL,
1371
+ {
1372
+ 'queue': queue,
1373
+ 'lim': limit,
1374
+ 'worker_id': self.worker_instance_id,
1375
+ 'claim_expires_at': self._compute_claim_expires_at(),
1376
+ },
1377
+ )
1378
+ ids = [r[0] for r in res.fetchall()]
1379
+ # Make the CLAIMED state visible and release the row locks
1380
+ await s.commit()
1381
+ return ids
1382
+
1383
+ async def _load_rows(self, ids: Sequence[str]) -> list[dict[str, Any]]:
1384
+ if not ids:
1385
+ return []
1386
+ async with self.sf() as s:
1387
+ res = await s.execute(LOAD_ROWS_SQL, {'ids': list(ids)})
1388
+ cols = res.keys()
1389
+ return [dict(zip(cols, row)) for row in res.fetchall()]
1390
+
1391
+ async def _dispatch_one(
1392
+ self,
1393
+ task_id: str,
1394
+ task_name: str,
1395
+ args_json: Optional[str],
1396
+ kwargs_json: Optional[str],
1397
+ ) -> None:
1398
+ """Submit to process pool; attach completion handler."""
1399
+ assert self._executor is not None
1400
+ loop = asyncio.get_running_loop()
1401
+
1402
+ # Get heartbeat interval from recovery config (milliseconds)
1403
+ runner_heartbeat_interval_ms = 30_000 # default: 30 seconds
1404
+ if self.cfg.recovery_config:
1405
+ runner_heartbeat_interval_ms = (
1406
+ self.cfg.recovery_config.runner_heartbeat_interval_ms
1407
+ )
1408
+
1409
+ # Pass task_id and database_url to task process for self-heartbeat
1410
+ database_url = self.cfg.dsn.replace('+asyncpg', '').replace('+psycopg', '')
1411
+ fut = loop.run_in_executor(
1412
+ self._executor,
1413
+ _run_task_entry,
1414
+ task_name,
1415
+ args_json,
1416
+ kwargs_json,
1417
+ task_id,
1418
+ database_url,
1419
+ self.worker_instance_id,
1420
+ runner_heartbeat_interval_ms,
1421
+ )
1422
+
1423
+ # When done, record the outcome
1424
+ asyncio.create_task(self._finalize_after(fut, task_id))
1425
+
1426
+ # ----- finalize (write back to DB + notify) -----
1427
+
1428
+ async def _finalize_after(
1429
+ self, fut: 'asyncio.Future[tuple[bool, str, Optional[str]]]', task_id: str
1430
+ ) -> None:
1431
+ ok, result_json_str, failed_reason = await fut
1432
+ now = datetime.now(timezone.utc)
1433
+
1434
+ # Note: Heartbeat thread in task process automatically dies when process completes
1435
+
1436
+ async with self.sf() as s:
1437
+ if not ok:
1438
+ # CLAIM_LOST: Another worker reclaimed this task - do nothing
1439
+ # The task is not failed; it belongs to another worker now
1440
+ match failed_reason:
1441
+ case (
1442
+ 'CLAIM_LOST'
1443
+ | 'OWNERSHIP_UNCONFIRMED'
1444
+ | 'WORKFLOW_CHECK_FAILED'
1445
+ | 'WORKFLOW_STOPPED'
1446
+ ):
1447
+ logger.debug(
1448
+ f'Task {task_id} aborted with reason={failed_reason}, skipping finalization'
1449
+ )
1450
+ return
1451
+ case _:
1452
+ pass
1453
+
1454
+ # worker-level failure (rare): mark FAILED with reason
1455
+ await s.execute(
1456
+ text("""
1457
+ UPDATE horsies_tasks
1458
+ SET status='FAILED',
1459
+ failed_at = :now,
1460
+ failed_reason = :reason,
1461
+ updated_at = :now
1462
+ WHERE id = :id
1463
+ """),
1464
+ {
1465
+ 'now': now,
1466
+ 'reason': failed_reason or 'Worker failure',
1467
+ 'id': task_id,
1468
+ },
1469
+ )
1470
+ # Trigger automatically sends NOTIFY on UPDATE
1471
+ await s.commit()
1472
+ return
1473
+
1474
+ # Parse the TaskResult we produced
1475
+ tr = task_result_from_json(loads_json(result_json_str))
1476
+ if tr.is_err():
1477
+ # Check if this task should be retried
1478
+ task_error = tr.unwrap_err()
1479
+ match task_error.error_code if task_error else None:
1480
+ case 'WORKFLOW_STOPPED':
1481
+ logger.debug(
1482
+ f'Task {task_id} skipped due to workflow stop, skipping finalization'
1483
+ )
1484
+ return
1485
+ case _:
1486
+ pass
1487
+ should_retry = await self._should_retry_task(task_id, task_error)
1488
+ if should_retry:
1489
+ await self._schedule_retry(task_id, s)
1490
+ await s.commit()
1491
+ return
1492
+
1493
+ # Mark as failed if no retry
1494
+ await s.execute(
1495
+ text("""
1496
+ UPDATE horsies_tasks
1497
+ SET status='FAILED',
1498
+ failed_at = :now,
1499
+ result = :result_json, -- or result_json JSONB if you add that column
1500
+ updated_at = :now
1501
+ WHERE id = :id
1502
+ """),
1503
+ {'now': now, 'result_json': result_json_str, 'id': task_id},
1504
+ )
1505
+ else:
1506
+ await s.execute(
1507
+ text("""
1508
+ UPDATE horsies_tasks
1509
+ SET status='COMPLETED',
1510
+ completed_at = :now,
1511
+ result = :result_json, -- or result_json JSONB
1512
+ updated_at = :now
1513
+ WHERE id = :id
1514
+ """),
1515
+ {'now': now, 'result_json': result_json_str, 'id': task_id},
1516
+ )
1517
+
1518
+ # Handle workflow task completion (if this task is part of a workflow)
1519
+ await self._handle_workflow_task_if_needed(s, task_id, tr)
1520
+
1521
+ # Proactively wake workers to re-check capacity/backlog.
1522
+ try:
1523
+ # Notify workers globally and on the specific queue to wake claims
1524
+ # Fetch queue name for this task
1525
+ resq = await s.execute(
1526
+ text('SELECT queue_name FROM horsies_tasks WHERE id = :id'), {'id': task_id}
1527
+ )
1528
+ rowq = resq.fetchone()
1529
+ qname = str(rowq[0]) if rowq and rowq[0] else 'default'
1530
+ payload = f'capacity:{task_id}'
1531
+ await s.execute(
1532
+ text('SELECT pg_notify(:c1, :p)'), {'c1': 'task_new', 'p': payload}
1533
+ )
1534
+ await s.execute(
1535
+ text('SELECT pg_notify(:c2, :p)'),
1536
+ {'c2': f'task_queue_{qname}', 'p': payload},
1537
+ )
1538
+ except Exception:
1539
+ # Non-fatal if NOTIFY fails; continue
1540
+ pass
1541
+
1542
+ # Trigger automatically sends NOTIFY on UPDATE; commit to flush NOTIFYs
1543
+ await s.commit()
1544
+
1545
+ async def _handle_workflow_task_if_needed(
1546
+ self,
1547
+ session: 'AsyncSession',
1548
+ task_id: str,
1549
+ result: 'TaskResult[Any, TaskError]',
1550
+ ) -> None:
1551
+ """
1552
+ Check if task is part of a workflow and handle accordingly.
1553
+
1554
+ This method is called after a task completes (success or failure).
1555
+ It updates the workflow_task record and triggers dependency resolution.
1556
+ """
1557
+ from horsies.core.workflows.engine import on_workflow_task_complete
1558
+
1559
+ # Quick check: is this task linked to a workflow?
1560
+ check = await session.execute(
1561
+ text('SELECT 1 FROM horsies_workflow_tasks WHERE task_id = :tid LIMIT 1'),
1562
+ {'tid': task_id},
1563
+ )
1564
+
1565
+ if check.fetchone() is None:
1566
+ return # Not a workflow task
1567
+
1568
+ broker = self._app.get_broker() if self._app is not None else None
1569
+ # Handle workflow task completion
1570
+ await on_workflow_task_complete(session, task_id, result, broker)
1571
+
1572
+ async def _should_retry_task(self, task_id: str, error: TaskError) -> bool:
1573
+ """Check if a task should be retried based on its configuration and current retry count."""
1574
+ async with self.sf() as s:
1575
+ result = await s.execute(
1576
+ text(
1577
+ 'SELECT retry_count, max_retries, task_options FROM horsies_tasks WHERE id = :id'
1578
+ ),
1579
+ {'id': task_id},
1580
+ )
1581
+ row = result.fetchone()
1582
+
1583
+ if not row:
1584
+ return False
1585
+
1586
+ retry_count = row.retry_count or 0
1587
+ max_retries = row.max_retries or 0
1588
+
1589
+ if retry_count >= max_retries or max_retries == 0:
1590
+ return False
1591
+
1592
+ # Parse task options to check auto_retry_for
1593
+ try:
1594
+ task_options_data = loads_json(row.task_options) if row.task_options else {}
1595
+ if not isinstance(task_options_data, dict):
1596
+ return False
1597
+ auto_retry_for = task_options_data.get('auto_retry_for', [])
1598
+ except Exception:
1599
+ return False
1600
+
1601
+ if not auto_retry_for or not isinstance(auto_retry_for, list):
1602
+ return False
1603
+
1604
+ # Check if error matches auto_retry_for criteria (enum or string)
1605
+ code = (
1606
+ error.error_code.value
1607
+ if isinstance(error.error_code, LibraryErrorCode)
1608
+ else error.error_code
1609
+ )
1610
+ if code and code in auto_retry_for:
1611
+ return True
1612
+
1613
+ if error.exception:
1614
+ # Handle both Exception objects and serialized exception dicts
1615
+ exception_type: str | None = None
1616
+ if isinstance(error.exception, Exception):
1617
+ exception_type = type(error.exception).__name__
1618
+ elif isinstance(error.exception, dict):
1619
+ # Support both serialization formats:
1620
+ # - "type" from _exception_to_json
1621
+ # - "exception_type" from task_decorator data field
1622
+ exception_type = error.exception.get('type') or error.exception.get(
1623
+ 'exception_type'
1624
+ )
1625
+
1626
+ if exception_type and exception_type in auto_retry_for:
1627
+ return True
1628
+
1629
+ return False
1630
+
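+ # Illustrative sketch (hypothetical values, not taken from this module): a task_options
+ # payload that would let the check above approve a retry could look like
+ #   {
+ #       'auto_retry_for': ['ValueError', 'TIMEOUT'],  # exception type names and/or error codes
+ #       'retry_policy': {'intervals': [60, 300, 900], 'backoff_strategy': 'fixed', 'jitter': True},
+ #   }
+ # A failure whose exception type is ValueError, or whose error_code serializes to 'TIMEOUT',
+ # would then match, provided retry_count is still below max_retries.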
1631
+ async def _schedule_retry(self, task_id: str, session: AsyncSession) -> None:
1632
+ """Schedule a task for retry by updating its status and next retry time."""
1633
+ # Get current retry configuration
1634
+ result = await session.execute(
1635
+ text('SELECT retry_count, task_options FROM horsies_tasks WHERE id = :id'),
1636
+ {'id': task_id},
1637
+ )
1638
+ row = result.fetchone()
1639
+
1640
+ if not row:
1641
+ return
1642
+
1643
+ retry_count = (row.retry_count or 0) + 1
1644
+
1645
+ # Parse retry policy from task options
1646
+ try:
1647
+ task_options_data = loads_json(row.task_options) if row.task_options else {}
1648
+ if not isinstance(task_options_data, dict):
1649
+ retry_policy_data = {}
1650
+ else:
1651
+ retry_policy_data = task_options_data.get('retry_policy', {})
1652
+ if not isinstance(retry_policy_data, dict):
1653
+ retry_policy_data = {}
1654
+ except Exception:
1655
+ retry_policy_data = {}
1656
+
1657
+ # Calculate retry delay
1658
+ delay_seconds = self._calculate_retry_delay(retry_count, retry_policy_data)
1659
+ next_retry_at = datetime.now(timezone.utc) + timedelta(seconds=delay_seconds)
1660
+
1661
+ # Update task for retry
1662
+ await session.execute(
1663
+ text("""
1664
+ UPDATE horsies_tasks
1665
+ SET status = 'PENDING',
1666
+ retry_count = :retry_count,
1667
+ next_retry_at = :next_retry_at,
1668
+ sent_at = :next_retry_at,
1669
+ updated_at = now()
1670
+ WHERE id = :id
1671
+ """),
1672
+ {'id': task_id, 'retry_count': retry_count, 'next_retry_at': next_retry_at},
1673
+ )
1674
+
1675
+ # Schedule a delayed notification using asyncio for the task's actual queue
1676
+ queue_name = await self._get_task_queue_name(task_id)
1677
+ asyncio.create_task(
1678
+ self._schedule_delayed_notification(
1679
+ delay_seconds, f'task_queue_{queue_name}', f'retry:{task_id}'
1680
+ )
1681
+ )
1682
+
1683
+ logger.info(
1684
+ f'Scheduled task {task_id} for retry #{retry_count} at {next_retry_at}'
1685
+ )
1686
+
1687
+ def _calculate_retry_delay(
1688
+ self, retry_attempt: int, retry_policy_data: dict[str, Any]
1689
+ ) -> int:
1690
+ """Calculate the delay in seconds for a retry attempt."""
1691
+ # Default retry policy values (these match RetryPolicy defaults)
1692
+ intervals = retry_policy_data.get(
1693
+ 'intervals', [60, 300, 900]
1694
+ ) # 1min, 5min, 15min
1695
+ backoff_strategy = retry_policy_data.get('backoff_strategy', 'fixed')
1696
+ jitter = retry_policy_data.get('jitter', True)
1697
+
1698
+ # Calculate base delay based on strategy
1699
+ if backoff_strategy == 'fixed':
1700
+ # Fixed strategy: use intervals directly (validation ensures length matches max_retries)
1701
+ base_delay = intervals[retry_attempt - 1] # Safe due to validation
1702
+
1703
+ elif backoff_strategy == 'exponential':
1704
+ # Exponential strategy: use intervals[0] as base and apply exponential multiplier
1705
+ base_interval = intervals[0] # Safe due to validation (exactly 1 interval)
1706
+ base_delay = base_interval * (2 ** (retry_attempt - 1))
1707
+
1708
+ else:
1709
+ # Fallback (shouldn't happen due to Literal type validation)
1710
+ base_delay = intervals[0] if intervals else 60
1711
+
1712
+ # Apply jitter (±25% randomization)
1713
+ if jitter:
1714
+ jitter_range = base_delay * 0.25
1715
+ base_delay += random.uniform(-jitter_range, jitter_range)
1716
+
1717
+ return max(1, int(base_delay)) # Ensure at least 1 second delay
1718
+
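+ # Illustrative worked example of the policy above (jitter makes actual values vary by +/-25%):
+ #   backoff_strategy='fixed',       intervals=[60, 300, 900] -> attempts 1..3 wait ~60s, ~300s, ~900s
+ #   backoff_strategy='exponential', intervals=[60]           -> attempts 1..3 wait ~60s, ~120s, ~240s
+ # The result is always floored at 1 second after jitter is applied.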
1719
+ async def _schedule_delayed_notification(
1720
+ self, delay_seconds: int, channel: str, payload: str
1721
+ ) -> None:
1722
+ """Schedule a delayed notification to wake up the worker for retry."""
1723
+ try:
1724
+ await asyncio.sleep(delay_seconds)
1725
+
1726
+ # Send notification to trigger retry processing
1727
+ async with self.sf() as session:
1728
+ await session.execute(
1729
+ text('SELECT pg_notify(:channel, :payload)'),
1730
+ {'channel': channel, 'payload': payload},
1731
+ )
1732
+ await session.commit()
1733
+
1734
+ logger.debug(f'Sent delayed notification for retry: {payload}')
1735
+ except asyncio.CancelledError:
1736
+ logger.debug(f'Delayed notification cancelled for: {payload}')
1737
+ except Exception as e:
1738
+ logger.error(f'Error sending delayed notification for {payload}: {e}')
1739
+
1740
+ async def _get_task_queue_name(self, task_id: str) -> str:
1741
+ """Fetch the queue_name for a given task id."""
1742
+ async with self.sf() as session:
1743
+ res = await session.execute(
1744
+ text('SELECT queue_name FROM horsies_tasks WHERE id = :id'),
1745
+ {'id': task_id},
1746
+ )
1747
+ row = res.fetchone()
1748
+ return str(row[0]) if row and row[0] else 'default'
1749
+
1750
+ async def _claimer_heartbeat_loop(self) -> None:
1751
+ """Emit claimer heartbeats for tasks we've claimed but not yet started."""
1752
+ # Get heartbeat interval from recovery config (milliseconds)
1753
+ claimer_heartbeat_interval_ms = 30_000 # default: 30 seconds
1754
+ if self.cfg.recovery_config:
1755
+ claimer_heartbeat_interval_ms = (
1756
+ self.cfg.recovery_config.claimer_heartbeat_interval_ms
1757
+ )
1758
+
1759
+ try:
1760
+ while not self._stop.is_set():
1761
+ try:
1762
+ async with self.sf() as s:
1763
+ await s.execute(
1764
+ text(
1765
+ """
1766
+ INSERT INTO horsies_heartbeats (task_id, sender_id, role, sent_at, hostname, pid)
1767
+ SELECT id, CAST(:wid AS VARCHAR), 'claimer', NOW(), :host, :pid
1768
+ FROM horsies_tasks
1769
+ WHERE status = 'CLAIMED' AND claimed_by_worker_id = CAST(:wid AS VARCHAR)
1770
+ """
1771
+ ),
1772
+ {
1773
+ 'wid': self.worker_instance_id,
1774
+ 'host': socket.gethostname(),
1775
+ 'pid': os.getpid(),
1776
+ },
1777
+ )
1778
+ await s.commit()
1779
+ except Exception as e:
1780
+ logger.error(f'Claimer heartbeat error: {e}')
1781
+ # Interval is configured in milliseconds; convert to seconds only for asyncio.sleep
1782
+ await asyncio.sleep(claimer_heartbeat_interval_ms / 1000.0)
1783
+ except asyncio.CancelledError:
1784
+ return
1785
+
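+ # Illustrative sketch (ad-hoc SQL, not part of this module): the latest claimer heartbeat
+ # per task can be read back with
+ #   SELECT task_id, MAX(sent_at) AS last_claimer_heartbeat
+ #   FROM horsies_heartbeats
+ #   WHERE role = 'claimer'
+ #   GROUP BY task_id;
+ # which is the kind of signal a recovery pass can compare against a staleness threshold.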
1786
+ async def _update_worker_state(self) -> None:
1787
+ """Update worker state snapshot in database for monitoring."""
1788
+ import psutil
1789
+
1790
+ try:
1791
+ process = psutil.Process()
1792
+ memory_info = process.memory_info()
1793
+
1794
+ # Get current task counts
1795
+ running = await self._count_only_running_for_worker()
1796
+ claimed = await self._count_claimed_for_worker()
1797
+
1798
+ # Serialize recovery config
1799
+ recovery_dict = None
1800
+ if self.cfg.recovery_config:
1801
+ recovery_dict = {
1802
+ 'auto_requeue_stale_claimed': self.cfg.recovery_config.auto_requeue_stale_claimed,
1803
+ 'claimed_stale_threshold_ms': self.cfg.recovery_config.claimed_stale_threshold_ms,
1804
+ 'auto_fail_stale_running': self.cfg.recovery_config.auto_fail_stale_running,
1805
+ 'running_stale_threshold_ms': self.cfg.recovery_config.running_stale_threshold_ms,
1806
+ 'check_interval_ms': self.cfg.recovery_config.check_interval_ms,
1807
+ 'runner_heartbeat_interval_ms': self.cfg.recovery_config.runner_heartbeat_interval_ms,
1808
+ 'claimer_heartbeat_interval_ms': self.cfg.recovery_config.claimer_heartbeat_interval_ms,
1809
+ }
1810
+
1811
+ async with self.sf() as s:
1812
+ await s.execute(
1813
+ text("""
1814
+ INSERT INTO horsies_worker_states (
1815
+ worker_id, snapshot_at, hostname, pid,
1816
+ processes, max_claim_batch, max_claim_per_worker,
1817
+ cluster_wide_cap, queues, queue_priorities, queue_max_concurrency,
1818
+ recovery_config, tasks_running, tasks_claimed,
1819
+ memory_usage_mb, memory_percent, cpu_percent,
1820
+ worker_started_at
1821
+ )
1822
+ VALUES (
1823
+ :wid, NOW(), :host, :pid, :procs, :mcb, :mcpw, :cwc,
1824
+ :queues, :qp, :qmc, :recovery, :running, :claimed,
1825
+ :mem_mb, :mem_pct, :cpu_pct, :started
1826
+ )
1827
+ """),
1828
+ {
1829
+ 'wid': self.worker_instance_id,
1830
+ 'host': socket.gethostname(),
1831
+ 'pid': os.getpid(),
1832
+ 'procs': self.cfg.processes,
1833
+ 'mcb': self.cfg.max_claim_batch,
1834
+ 'mcpw': self.cfg.max_claim_per_worker
1835
+ if self.cfg.max_claim_per_worker > 0
1836
+ else self.cfg.processes,
1837
+ 'cwc': self.cfg.cluster_wide_cap,
1838
+ 'queues': self.cfg.queues,
1839
+ 'qp': Jsonb(self.cfg.queue_priorities)
1840
+ if self.cfg.queue_priorities
1841
+ else None,
1842
+ 'qmc': Jsonb(self.cfg.queue_max_concurrency)
1843
+ if self.cfg.queue_max_concurrency
1844
+ else None,
1845
+ 'recovery': Jsonb(recovery_dict) if recovery_dict else None,
1846
+ 'running': running,
1847
+ 'claimed': claimed,
1848
+ 'mem_mb': memory_info.rss / 1024 / 1024,
1849
+ 'mem_pct': process.memory_percent(),
1850
+ 'cpu_pct': process.cpu_percent(interval=0.1),
1851
+ 'started': self._started_at,
1852
+ },
1853
+ )
1854
+ await s.commit()
1855
+ except Exception as e:
1856
+ logger.error(f'Failed to update worker state: {e}')
1857
+
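+ # Illustrative sketch (ad-hoc SQL, not part of this module): the snapshots inserted above
+ # can be inspected for monitoring with, e.g.
+ #   SELECT worker_id, hostname, pid, tasks_running, tasks_claimed,
+ #          memory_usage_mb, cpu_percent, snapshot_at
+ #   FROM horsies_worker_states
+ #   ORDER BY snapshot_at DESC
+ #   LIMIT 10;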
1858
+ async def _worker_state_heartbeat_loop(self) -> None:
1859
+ """Periodically update worker state for monitoring (every 5 seconds)."""
1860
+ worker_state_interval_ms = 5_000 # 5 seconds
1861
+
1862
+ try:
1863
+ while not self._stop.is_set():
1864
+ try:
1865
+ await self._update_worker_state()
1866
+ except Exception as e:
1867
+ logger.error(f'Worker state heartbeat error: {e}')
1868
+
1869
+ # Sleep for the snapshot interval; the stop flag is re-checked at the top of the loop
1870
+ await asyncio.sleep(worker_state_interval_ms / 1000.0)
1871
+ except asyncio.CancelledError:
1872
+ return
1873
+
1874
+ async def _reaper_loop(self) -> None:
1875
+ """Automatic stale task handling loop.
1876
+
1877
+ Periodically checks for and recovers stale tasks based on RecoveryConfig:
1878
+ - Requeues tasks stuck in CLAIMED (safe - user code never ran)
1879
+ - Marks stale RUNNING tasks as FAILED (not safe to requeue)
1880
+ """
1881
+ from horsies.core.brokers.postgres import PostgresBroker
1882
+
1883
+ if not self.cfg.recovery_config:
1884
+ return
1885
+
1886
+ recovery_cfg = self.cfg.recovery_config
1887
+ check_interval_ms = recovery_cfg.check_interval_ms
1888
+ temp_broker = None
1889
+
1890
+ logger.info(
1891
+ f'Reaper configuration: auto_requeue_claimed={recovery_cfg.auto_requeue_stale_claimed}, '
1892
+ f'auto_fail_running={recovery_cfg.auto_fail_stale_running}, '
1893
+ f'check_interval={check_interval_ms}ms ({check_interval_ms/1000:.1f}s)'
1894
+ )
1895
+
1896
+ try:
1897
+ from horsies.core.models.broker import PostgresConfig
1898
+
1899
+ temp_broker_config = PostgresConfig(database_url=self.cfg.dsn)
1900
+ temp_broker = PostgresBroker(temp_broker_config)
1901
+ if self._app is not None:
1902
+ temp_broker.app = self._app
1903
+
1904
+ while not self._stop.is_set():
1905
+ try:
1906
+ # Auto-requeue stale CLAIMED tasks
1907
+ if recovery_cfg.auto_requeue_stale_claimed:
1908
+ requeued = await temp_broker.requeue_stale_claimed(
1909
+ stale_threshold_ms=recovery_cfg.claimed_stale_threshold_ms
1910
+ )
1911
+ if requeued > 0:
1912
+ logger.info(
1913
+ f'Reaper requeued {requeued} stale CLAIMED task(s)'
1914
+ )
1915
+
1916
+ # Auto-fail stale RUNNING tasks
1917
+ if recovery_cfg.auto_fail_stale_running:
1918
+ failed = await temp_broker.mark_stale_tasks_as_failed(
1919
+ stale_threshold_ms=recovery_cfg.running_stale_threshold_ms
1920
+ )
1921
+ if failed > 0:
1922
+ logger.warning(
1923
+ f'Reaper marked {failed} stale RUNNING task(s) as FAILED'
1924
+ )
1925
+
1926
+ # Recover stuck workflows
1927
+ try:
1928
+ from horsies.core.workflows.recovery import (
1929
+ recover_stuck_workflows,
1930
+ )
1931
+
1932
+ async with temp_broker.session_factory() as s:
1933
+ recovered = await recover_stuck_workflows(s, temp_broker)
1934
+ if recovered > 0:
1935
+ logger.info(
1936
+ f'Reaper recovered {recovered} stuck workflow task(s)'
1937
+ )
1938
+ await s.commit()
1939
+ except Exception as wf_err:
1940
+ logger.error(f'Workflow recovery error: {wf_err}')
1941
+
1942
+ except Exception as e:
1943
+ logger.error(f'Reaper loop error: {e}')
1944
+
1945
+ # Wait for check interval or stop signal (convert ms to seconds for asyncio)
1946
+ try:
1947
+ await asyncio.wait_for(
1948
+ self._stop.wait(), timeout=check_interval_ms / 1000.0
1949
+ )
1950
+ break # Stop signal received
1951
+ except asyncio.TimeoutError:
1952
+ continue # Continue to next iteration
1953
+
1954
+ except asyncio.CancelledError:
1955
+ logger.info('Reaper loop cancelled')
1956
+ return
1957
+ finally:
1958
+ if temp_broker is not None:
1959
+ try:
1960
+ await temp_broker.close_async()
1961
+ except Exception as e:
1962
+ logger.error(f'Error closing reaper broker: {e}')
1963
+
1964
+
1965
+ """
1966
+ horsies examples/instance.py worker --loglevel=info --processes=8
1967
+ """