horsies-0.1.0a4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. horsies/__init__.py +117 -0
  2. horsies/core/__init__.py +0 -0
  3. horsies/core/app.py +552 -0
  4. horsies/core/banner.py +144 -0
  5. horsies/core/brokers/__init__.py +5 -0
  6. horsies/core/brokers/listener.py +444 -0
  7. horsies/core/brokers/postgres.py +993 -0
  8. horsies/core/cli.py +624 -0
  9. horsies/core/codec/serde.py +596 -0
  10. horsies/core/errors.py +535 -0
  11. horsies/core/logging.py +90 -0
  12. horsies/core/models/__init__.py +0 -0
  13. horsies/core/models/app.py +268 -0
  14. horsies/core/models/broker.py +79 -0
  15. horsies/core/models/queues.py +23 -0
  16. horsies/core/models/recovery.py +101 -0
  17. horsies/core/models/schedule.py +229 -0
  18. horsies/core/models/task_pg.py +307 -0
  19. horsies/core/models/tasks.py +358 -0
  20. horsies/core/models/workflow.py +1990 -0
  21. horsies/core/models/workflow_pg.py +245 -0
  22. horsies/core/registry/tasks.py +101 -0
  23. horsies/core/scheduler/__init__.py +26 -0
  24. horsies/core/scheduler/calculator.py +267 -0
  25. horsies/core/scheduler/service.py +569 -0
  26. horsies/core/scheduler/state.py +260 -0
  27. horsies/core/task_decorator.py +656 -0
  28. horsies/core/types/status.py +38 -0
  29. horsies/core/utils/imports.py +203 -0
  30. horsies/core/utils/loop_runner.py +44 -0
  31. horsies/core/worker/current.py +17 -0
  32. horsies/core/worker/worker.py +1967 -0
  33. horsies/core/workflows/__init__.py +23 -0
  34. horsies/core/workflows/engine.py +2344 -0
  35. horsies/core/workflows/recovery.py +501 -0
  36. horsies/core/workflows/registry.py +97 -0
  37. horsies/py.typed +0 -0
  38. horsies-0.1.0a4.dist-info/METADATA +35 -0
  39. horsies-0.1.0a4.dist-info/RECORD +42 -0
  40. horsies-0.1.0a4.dist-info/WHEEL +5 -0
  41. horsies-0.1.0a4.dist-info/entry_points.txt +2 -0
  42. horsies-0.1.0a4.dist-info/top_level.txt +1 -0
horsies/core/brokers/postgres.py
@@ -0,0 +1,993 @@
+ # app/core/brokers/postgres.py
+ from __future__ import annotations
+ import uuid, asyncio, hashlib
+ from typing import Any, Optional, TYPE_CHECKING
+ from datetime import datetime, timezone
+ from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
+ from sqlalchemy import text
+ from horsies.core.brokers.listener import PostgresListener
+ from horsies.core.models.broker import PostgresConfig
+ from horsies.core.models.task_pg import TaskModel, Base
+ from horsies.core.models.workflow_pg import WorkflowModel, WorkflowTaskModel  # noqa: F401
+ from horsies.core.types.status import TaskStatus
+ from horsies.core.codec.serde import (
+     args_to_json,
+     kwargs_to_json,
+     loads_json,
+     task_result_from_json,
+ )
+ from horsies.core.models.tasks import TaskInfo
+ from horsies.core.utils.loop_runner import LoopRunner
+ from horsies.core.logging import get_logger
+
+ if TYPE_CHECKING:
+     from horsies.core.models.tasks import TaskResult, TaskError
+
+
+ class PostgresBroker:
+     """
+     PostgreSQL-based task broker with LISTEN/NOTIFY for real-time updates.
+
+     Provides both async and sync APIs:
+     - Async: enqueue_async(), get_result_async()
+     - Sync: enqueue(), get_result() (run in background event loop)
+
+     Features:
+     - Real-time notifications via PostgreSQL triggers
+     - Automatic task status tracking
+     - Connection pooling and health monitoring
+     - Operational monitoring (stale tasks, worker stats)
+     """
+
+     def __init__(self, config: PostgresConfig):
+         self.config = config
+         self.logger = get_logger('broker')
+         self._app: Any = None  # Set by Horsies.get_broker()
+
+         engine_cfg = self.config.model_dump(exclude={'database_url'}, exclude_none=True)
+         self.async_engine = create_async_engine(self.config.database_url, **engine_cfg)
+         self.session_factory = async_sessionmaker(
+             self.async_engine, expire_on_commit=False
+         )
+
+         psycopg_url = self.config.database_url.replace('+asyncpg', '').replace(
+             '+psycopg', ''
+         )
+         self.listener = PostgresListener(psycopg_url)
+
+         self._initialized = False
+         self._loop_runner = LoopRunner()  # for sync facades
+
+         self.logger.info('PostgresBroker initialized')
+
+     @property
+     def app(self) -> Any:
+         """Get the attached Horsies app instance (if any)."""
+         return self._app
+
+     @app.setter
+     def app(self, value: Any) -> None:
+         """Set the Horsies app instance."""
+         self._app = value
+
+     def _schema_advisory_key(self) -> int:
+         """
+         Compute a stable 64-bit advisory lock key for schema initialization.
+
+         Uses the database URL as a basis so that different clusters do not
+         contend on the same advisory lock key.
+         """
+         basis = self.config.database_url.encode('utf-8', errors='ignore')
+         h = hashlib.sha256(b'horsies-schema:' + basis).digest()
+         return int.from_bytes(h[:8], byteorder='big', signed=True)
+
+     async def _create_triggers(self) -> None:
+         """
+         Set up PostgreSQL triggers for automatic task notifications.
+
+         Creates triggers that send NOTIFY messages on:
+         - INSERT: Sends task_new + task_queue_{queue_name} notifications
+         - UPDATE to COMPLETED/FAILED: Sends task_done notification
+
+         This enables real-time task processing without polling.
+         """
+         async with self.async_engine.begin() as conn:
+             # Create trigger function
+             await conn.execute(
+                 text("""
+                     CREATE OR REPLACE FUNCTION horsies_notify_task_changes()
+                     RETURNS trigger AS $$
+                     BEGIN
+                         IF TG_OP = 'INSERT' AND NEW.status = 'PENDING' THEN
+                             -- New task notifications: wake up workers
+                             PERFORM pg_notify('task_new', NEW.id); -- Global worker notification
+                             PERFORM pg_notify('task_queue_' || NEW.queue_name, NEW.id); -- Queue-specific notification
+                         ELSIF TG_OP = 'UPDATE' AND OLD.status != NEW.status THEN
+                             -- Task completion notifications: wake up result waiters
+                             IF NEW.status IN ('COMPLETED', 'FAILED') THEN
+                                 PERFORM pg_notify('task_done', NEW.id); -- Send task_id as payload
+                             END IF;
+                         END IF;
+                         RETURN NEW;
+                     END;
+                     $$ LANGUAGE plpgsql;
+                 """)
+             )
+
+             # Create trigger
+             await conn.execute(
+                 text("""
+                     DROP TRIGGER IF EXISTS horsies_task_notify_trigger ON horsies_tasks;
+                     CREATE TRIGGER horsies_task_notify_trigger
+                     AFTER INSERT OR UPDATE ON horsies_tasks
+                     FOR EACH ROW
+                     EXECUTE FUNCTION horsies_notify_task_changes();
+                 """)
+             )
+
+     async def _create_workflow_schema(self) -> None:
+         """
+         Set up workflow-specific schema elements.
+
+         Creates:
+         - GIN index on workflow_tasks.dependencies for efficient dependency lookups
+         - Trigger for workflow completion notifications
+         - Migration: adds task_options column if missing (for existing installs)
+         """
+         async with self.async_engine.begin() as conn:
+             # GIN index for efficient dependency array lookups
+             await conn.execute(
+                 text("""
+                     CREATE INDEX IF NOT EXISTS idx_horsies_workflow_tasks_deps
+                     ON horsies_workflow_tasks USING GIN(dependencies);
+                 """)
+             )
+
+             # Migration: add task_options column for existing installs
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS task_options TEXT;
+                 """)
+             )
+
+             # Migration: add success_policy column for existing installs
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS success_policy JSONB;
+                 """)
+             )
+
+             # Migration: add join_type and min_success columns for existing installs
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS join_type VARCHAR(10) NOT NULL DEFAULT 'all';
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS min_success INTEGER;
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS node_id VARCHAR(128);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ALTER COLUMN workflow_ctx_from
+                     TYPE VARCHAR(128)[]
+                     USING workflow_ctx_from::VARCHAR(128)[];
+                 """)
+             )
+
+             # Subworkflow support columns
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS parent_workflow_id VARCHAR(36);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS parent_task_index INTEGER;
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS depth INTEGER NOT NULL DEFAULT 0;
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS root_workflow_id VARCHAR(36);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS workflow_def_module VARCHAR(512);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflows
+                     ADD COLUMN IF NOT EXISTS workflow_def_qualname VARCHAR(512);
+                 """)
+             )
+
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS is_subworkflow BOOLEAN NOT NULL DEFAULT FALSE;
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_id VARCHAR(36);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_name VARCHAR(255);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_retry_mode VARCHAR(50);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_summary TEXT;
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_module VARCHAR(512);
+                 """)
+             )
+             await conn.execute(
+                 text("""
+                     ALTER TABLE horsies_workflow_tasks
+                     ADD COLUMN IF NOT EXISTS sub_workflow_qualname VARCHAR(512);
+                 """)
+             )
+
+             # Workflow notification trigger function
+             await conn.execute(
+                 text("""
+                     CREATE OR REPLACE FUNCTION horsies_notify_workflow_changes()
+                     RETURNS trigger AS $$
+                     BEGIN
+                         IF TG_OP = 'UPDATE' AND OLD.status != NEW.status THEN
+                             -- Workflow completion notifications
+                             IF NEW.status IN ('COMPLETED', 'FAILED', 'CANCELLED', 'PAUSED') THEN
+                                 PERFORM pg_notify('workflow_done', NEW.id);
+                             END IF;
+                         END IF;
+                         RETURN NEW;
+                     END;
+                     $$ LANGUAGE plpgsql;
+                 """)
+             )
+
+             # Create workflow trigger
+             await conn.execute(
+                 text("""
+                     DROP TRIGGER IF EXISTS horsies_workflow_notify_trigger ON horsies_workflows;
+                     CREATE TRIGGER horsies_workflow_notify_trigger
+                     AFTER UPDATE ON horsies_workflows
+                     FOR EACH ROW
+                     EXECUTE FUNCTION horsies_notify_workflow_changes();
+                 """)
+             )
+
+     async def _ensure_initialized(self) -> None:
+         if self._initialized:
+             return
+         async with self.async_engine.begin() as conn:
+             # Take a short-lived, cluster-wide advisory lock to serialize
+             # schema creation across workers and producers.
+             await conn.execute(
+                 text('SELECT pg_advisory_xact_lock(CAST(:key AS BIGINT))'),
+                 {'key': self._schema_advisory_key()},
+             )
+             await conn.run_sync(Base.metadata.create_all)
+
+         await self._create_triggers()
+         await self._create_workflow_schema()
+         await self.listener.start()
+         self._initialized = True
+
+     async def ensure_schema_initialized(self) -> None:
+         """
+         Public entry point to ensure tables and triggers exist.
+
+         Safe to call multiple times and from multiple processes; internally
+         guarded by a PostgreSQL advisory lock to avoid DDL races.
+         """
+         await self._ensure_initialized()
+
+     # ----------------- Async API -----------------
+
+     async def enqueue_async(
+         self,
+         task_name: str,
+         args: tuple[Any, ...],
+         kwargs: dict[str, Any],
+         queue_name: str = 'default',
+         *,
+         priority: int = 100,
+         sent_at: Optional[datetime] = None,
+         good_until: Optional[datetime] = None,
+         task_options: Optional[str] = None,
+     ) -> str:
+         await self._ensure_initialized()
+
+         task_id = str(uuid.uuid4())
+         now = datetime.now(timezone.utc)
+         sent = sent_at or now
+
+         # Parse retry configuration from task_options
+         max_retries = 0
+         if task_options:
+             try:
+                 options_data = loads_json(task_options)
+                 if isinstance(options_data, dict):
+                     retry_policy = options_data.get('retry_policy')
+                     if isinstance(retry_policy, dict):
+                         max_retries = retry_policy.get('max_retries', 3)
+             except Exception:
+                 pass
+
+         async with self.session_factory() as session:
+             task = TaskModel(
+                 id=task_id,
+                 task_name=task_name,
+                 queue_name=queue_name,
+                 priority=priority,
+                 args=args_to_json(args) if args else None,
+                 kwargs=kwargs_to_json(kwargs) if kwargs else None,
+                 status=TaskStatus.PENDING,
+                 sent_at=sent,
+                 good_until=good_until,
+                 max_retries=max_retries,
+                 task_options=task_options,
+                 created_at=now,
+                 updated_at=now,
+             )
+             session.add(task)
+             await session.commit()
+
+         # PostgreSQL trigger automatically sends task_new + task_queue_{queue_name} notifications
+         return task_id
+
+     async def get_result_async(
+         self, task_id: str, timeout_ms: Optional[int] = None
+     ) -> 'TaskResult[Any, TaskError]':
+         """
+         Get task result, waiting if necessary.
+
+         Returns TaskResult for task completion and retrieval outcomes:
+         - Success: TaskResult(ok=value) from task execution
+         - Task error: TaskResult(err=TaskError) from task execution
+         - Retrieval error: TaskResult(err=TaskError) with WAIT_TIMEOUT, TASK_NOT_FOUND, or TASK_CANCELLED
+
+         Broker failures (e.g., database errors) are returned as BROKER_ERROR.
+         """
+         from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
+
+         try:
+             await self._ensure_initialized()
+
+             start_time = asyncio.get_event_loop().time()
+
+             # Convert milliseconds to seconds for internal use
+             timeout_seconds: Optional[float] = None
+             if timeout_ms is not None:
+                 timeout_seconds = timeout_ms / 1000.0
+
+             # Quick path - check if task is already completed
+             async with self.session_factory() as session:
+                 row = await session.get(TaskModel, task_id)
+                 if row is None:
+                     self.logger.error(f'Task {task_id} not found')
+                     return TaskResult(
+                         err=TaskError(
+                             error_code=LibraryErrorCode.TASK_NOT_FOUND,
+                             message=f'Task {task_id} not found in database',
+                             data={'task_id': task_id},
+                         )
+                     )
+                 if row.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
+                     self.logger.info(f'Task {task_id} already completed')
+                     return task_result_from_json(loads_json(row.result))
+                 if row.status == TaskStatus.CANCELLED:
+                     self.logger.info(f'Task {task_id} was cancelled')
+                     return TaskResult(
+                         err=TaskError(
+                             error_code=LibraryErrorCode.TASK_CANCELLED,
+                             message=f'Task {task_id} was cancelled before completion',
+                             data={'task_id': task_id},
+                         )
+                     )
+
+             # Listen for task completion notifications with polling fallback
+             # Multiple clients can listen to task_done; each filters for their specific task_id
+             q = await self.listener.listen('task_done')
+             try:
+                 poll_interval = (
+                     5.0  # Fallback polling interval (handles lost notifications)
+                 )
+
+                 while True:
+                     # Calculate remaining timeout
+                     remaining_timeout = None
+                     if timeout_seconds:
+                         elapsed = asyncio.get_event_loop().time() - start_time
+                         remaining_timeout = timeout_seconds - elapsed
+                         if remaining_timeout <= 0:
+                             return TaskResult(
+                                 err=TaskError(
+                                     error_code=LibraryErrorCode.WAIT_TIMEOUT,
+                                     message=f'Timed out waiting for task {task_id} after {timeout_ms}ms. Task may still be running.',
+                                     data={'task_id': task_id, 'timeout_ms': timeout_ms},
+                                 )
+                             )
+
+                     # Wait for NOTIFY or timeout (whichever comes first)
+                     wait_time = (
+                         min(poll_interval, remaining_timeout)
+                         if remaining_timeout
+                         else poll_interval
+                     )
+
+                     try:
+
+                         async def _wait_for_task() -> None:
+                             # Filter notifications: only process our specific task_id
+                             while True:
+                                 note = (
+                                     await q.get()
+                                 )  # Blocks until any task_done notification
+                                 if (
+                                     note.payload == task_id
+                                 ):  # Check if it's for our task
+                                     return  # Found our task completion!
+
+                         # Wait for our specific task notification with timeout
+                         await asyncio.wait_for(_wait_for_task(), timeout=wait_time)
+                         break  # Got our task completion notification
+
+                     except asyncio.TimeoutError:
+                         # Fallback polling: handles lost notifications or crashed workers
+                         # Ensures eventual consistency even if NOTIFY system fails
+                         async with self.session_factory() as session:
+                             row = await session.get(TaskModel, task_id)
+                             if row is None:
+                                 return TaskResult(
+                                     err=TaskError(
+                                         error_code=LibraryErrorCode.TASK_NOT_FOUND,
+                                         message=f'Task {task_id} not found in database',
+                                         data={'task_id': task_id},
+                                     )
+                                 )
+                             if row.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
+                                 self.logger.debug(
+                                     f'Task {task_id} completed, polling database'
+                                 )
+                                 return task_result_from_json(loads_json(row.result))
+                             if row.status == TaskStatus.CANCELLED:
+                                 self.logger.error(f'Task {task_id} was cancelled')
+                                 return TaskResult(
+                                     err=TaskError(
+                                         error_code=LibraryErrorCode.TASK_CANCELLED,
+                                         message=f'Task {task_id} was cancelled before completion',
+                                         data={'task_id': task_id},
+                                     )
+                                 )
+                             # Task still not done, continue waiting...
+
+             finally:
+                 # Clean up subscription (keeps server-side LISTEN active for other waiters)
+                 await self.listener.unsubscribe('task_done', q)
+
+             # Final database check to get actual task result after notification
+             async with self.session_factory() as session:
+                 row = await session.get(TaskModel, task_id)
+                 if row is None:
+                     self.logger.error(f'Task {task_id} not found')
+                     return TaskResult(
+                         err=TaskError(
+                             error_code=LibraryErrorCode.TASK_NOT_FOUND,
+                             message=f'Task {task_id} not found in database',
+                             data={'task_id': task_id},
+                         )
+                     )
+                 if row.status == TaskStatus.CANCELLED:
+                     self.logger.error(f'Task {task_id} was cancelled')
+                     return TaskResult(
+                         err=TaskError(
+                             error_code=LibraryErrorCode.TASK_CANCELLED,
+                             message=f'Task {task_id} was cancelled before completion',
+                             data={'task_id': task_id},
+                         )
+                     )
+                 if row.status not in (TaskStatus.COMPLETED, TaskStatus.FAILED):
+                     return TaskResult(
+                         err=TaskError(
+                             error_code=LibraryErrorCode.WAIT_TIMEOUT,
+                             message=f'Task {task_id} not completed after notification (status: {row.status}). Task may still be running.',
+                             data={'task_id': task_id, 'status': str(row.status)},
+                         )
+                     )
+                 return task_result_from_json(loads_json(row.result))
+         except asyncio.CancelledError:
+             raise
+         except Exception as exc:
+             self.logger.exception('Broker error while retrieving task result')
+             return TaskResult(
+                 err=TaskError(
+                     error_code=LibraryErrorCode.BROKER_ERROR,
+                     message='Broker error while retrieving task result',
+                     data={'task_id': task_id, 'timeout_ms': timeout_ms},
+                     exception=exc,
+                 )
+             )
+
+     async def close_async(self) -> None:
+         await self.listener.close()
+         await self.async_engine.dispose()
+
+     # ------------- Operational & Monitoring Methods -------------
+
+     async def get_stale_tasks(
+         self, stale_threshold_minutes: int = 2
+     ) -> list[dict[str, Any]]:
+         """
+         Identify potentially crashed tasks based on heartbeat absence.
+
+         Finds RUNNING tasks whose workers haven't sent heartbeats within the threshold,
+         indicating the worker process may have crashed or become unresponsive.
+
+         Args:
+             stale_threshold_minutes: Minutes without heartbeat to consider stale
+
+         Returns:
+             List of task info dicts: id, worker_hostname, worker_pid, last_heartbeat
+         """
+         async with self.session_factory() as session:
+             result = await session.execute(
+                 text("""
+                     SELECT
+                         t.id,
+                         t.worker_hostname,
+                         t.worker_pid,
+                         t.worker_process_name,
+                         hb.last_heartbeat,
+                         t.started_at,
+                         t.task_name
+                     FROM horsies_tasks t
+                     LEFT JOIN LATERAL (
+                         SELECT sent_at AS last_heartbeat
+                         FROM horsies_heartbeats h
+                         WHERE h.task_id = t.id AND h.role = 'runner'
+                         ORDER BY sent_at DESC
+                         LIMIT 1
+                     ) hb ON TRUE
+                     WHERE t.status = 'RUNNING'
+                       AND t.started_at IS NOT NULL
+                       AND COALESCE(hb.last_heartbeat, t.started_at) < NOW() - CAST(:stale_threshold || ' minutes' AS INTERVAL)
+                     ORDER BY hb.last_heartbeat NULLS FIRST
+                 """),
+                 {'stale_threshold': stale_threshold_minutes},
+             )
+             columns = result.keys()
+             return [dict(zip(columns, row)) for row in result.fetchall()]
+
+     async def get_worker_stats(self) -> list[dict[str, Any]]:
+         """
+         Gather statistics about active worker processes.
+
+         Groups RUNNING tasks by worker to show load distribution and health.
+         Useful for monitoring worker performance and identifying bottlenecks.
+
+         Returns:
+             List of worker stats: worker_hostname, worker_pid, active_tasks, oldest_task_start
+         """
+         async with self.session_factory() as session:
+             result = await session.execute(
+                 text("""
+                     SELECT
+                         t.worker_hostname,
+                         t.worker_pid,
+                         t.worker_process_name,
+                         COUNT(*) AS active_tasks,
+                         MIN(t.started_at) AS oldest_task_start,
+                         MAX(hb.last_heartbeat) AS latest_heartbeat
+                     FROM horsies_tasks t
+                     LEFT JOIN LATERAL (
+                         SELECT sent_at AS last_heartbeat
+                         FROM horsies_heartbeats h
+                         WHERE h.task_id = t.id AND h.role = 'runner'
+                         ORDER BY sent_at DESC
+                         LIMIT 1
+                     ) hb ON TRUE
+                     WHERE t.status = 'RUNNING'
+                       AND t.worker_hostname IS NOT NULL
+                     GROUP BY t.worker_hostname, t.worker_pid, t.worker_process_name
+                     ORDER BY active_tasks DESC
+                 """)
+             )
+
+             columns = result.keys()
+             return [dict(zip(columns, row)) for row in result.fetchall()]
+
+     async def get_expired_tasks(self) -> list[dict[str, Any]]:
+         """
+         Find tasks that expired before worker processing.
+
+         Identifies PENDING tasks that exceeded their good_until deadline,
+         indicating potential worker capacity issues or scheduling problems.
+
+         Returns:
+             List of expired task info: id, task_name, queue_name, good_until, expired_for
+         """
+         async with self.session_factory() as session:
+             result = await session.execute(
+                 text("""
+                     SELECT
+                         id,
+                         task_name,
+                         queue_name,
+                         priority,
+                         sent_at,
+                         good_until,
+                         NOW() - good_until as expired_for
+                     FROM horsies_tasks
+                     WHERE status = 'PENDING'
+                       AND good_until < NOW()
+                     ORDER BY good_until ASC
+                 """)
+             )
+
+             columns = result.keys()
+             return [dict(zip(columns, row)) for row in result.fetchall()]
+
+     async def mark_stale_tasks_as_failed(
+         self, stale_threshold_ms: int = 300_000
+     ) -> int:
+         """
+         Clean up crashed worker tasks by marking them as FAILED.
+
+         Updates RUNNING tasks that haven't received heartbeats within the threshold.
+         This is typically called by a cleanup process to handle worker crashes.
+         Creates a proper TaskResult with WORKER_CRASHED error code.
+
+         Args:
+             stale_threshold_ms: Milliseconds without heartbeat to consider crashed (default: 300000 = 5 minutes)
+
+         Returns:
+             Number of tasks marked as failed
+         """
+         from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
+         from horsies.core.codec.serde import dumps_json
+
+         # Convert milliseconds to seconds for PostgreSQL INTERVAL
+         stale_threshold_seconds = stale_threshold_ms / 1000.0
+
+         async with self.session_factory() as session:
+             # First, find stale tasks and get their metadata
+             stale_tasks_result = await session.execute(
+                 text("""
+                     SELECT t2.id, t2.worker_pid, t2.worker_hostname, t2.claimed_by_worker_id,
+                            t2.started_at, hb.last_heartbeat
+                     FROM horsies_tasks t2
+                     LEFT JOIN LATERAL (
+                         SELECT sent_at AS last_heartbeat
+                         FROM horsies_heartbeats h
+                         WHERE h.task_id = t2.id AND h.role = 'runner'
+                         ORDER BY sent_at DESC
+                         LIMIT 1
+                     ) hb ON TRUE
+                     WHERE t2.status = 'RUNNING'
+                       AND t2.started_at IS NOT NULL
+                       AND COALESCE(hb.last_heartbeat, t2.started_at) < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL)
+                 """),
+                 {'stale_threshold': stale_threshold_seconds},
+             )
+
+             stale_tasks = stale_tasks_result.fetchall()
+             if not stale_tasks:
+                 return 0
+
+             # Mark each stale task as failed with proper TaskResult
+             for task_row in stale_tasks:
+                 task_id = task_row[0]
+                 worker_pid = task_row[1]
+                 worker_hostname = task_row[2]
+                 worker_id = task_row[3]
+                 started_at = task_row[4]
+                 last_heartbeat = task_row[5]
+
+                 # Create TaskResult with WORKER_CRASHED error
+                 task_error = TaskError(
+                     error_code=LibraryErrorCode.WORKER_CRASHED,
+                     message=f'Worker process crashed (no runner heartbeat for {stale_threshold_ms}ms = {stale_threshold_ms/1000:.1f}s)',
+                     data={
+                         'stale_threshold_ms': stale_threshold_ms,
+                         'stale_threshold_seconds': stale_threshold_seconds,
+                         'worker_pid': worker_pid,
+                         'worker_hostname': worker_hostname,
+                         'worker_id': worker_id,
+                         'started_at': started_at.isoformat() if started_at else None,
+                         'last_heartbeat': last_heartbeat.isoformat()
+                         if last_heartbeat
+                         else None,
+                         'detected_at': datetime.now(timezone.utc).isoformat(),
+                     },
+                 )
+                 task_result: TaskResult[None, TaskError] = TaskResult(err=task_error)
+                 result_json = dumps_json(task_result)
+
+                 # Update task with proper result
+                 await session.execute(
+                     text("""
+                         UPDATE horsies_tasks
+                         SET status = 'FAILED',
+                             failed_at = NOW(),
+                             failed_reason = :failed_reason,
+                             result = :result,
+                             updated_at = NOW()
+                         WHERE id = :task_id
+                     """),
+                     {
+                         'task_id': task_id,
+                         'failed_reason': f'Worker process crashed (no runner heartbeat for {stale_threshold_ms}ms = {stale_threshold_ms/1000:.1f}s)',
+                         'result': result_json,
+                     },
+                 )
+
+             await session.commit()
+             return len(stale_tasks)
+
+     async def requeue_stale_claimed(self, stale_threshold_ms: int = 120_000) -> int:
+         """
+         Requeue tasks stuck in CLAIMED without recent claimer heartbeat.
+
+         Args:
+             stale_threshold_ms: Milliseconds without heartbeat to consider stale (default: 120000 = 2 minutes)
+
+         Returns:
+             Number of tasks requeued
+         """
+         # Convert milliseconds to seconds for PostgreSQL INTERVAL
+         stale_threshold_seconds = stale_threshold_ms / 1000.0
+
+         async with self.session_factory() as session:
+             result = await session.execute(
+                 text("""
+                     UPDATE horsies_tasks AS t
+                     SET status = 'PENDING',
+                         claimed = FALSE,
+                         claimed_at = NULL,
+                         claimed_by_worker_id = NULL,
+                         updated_at = NOW()
+                     FROM (
+                         SELECT t2.id, hb.last_heartbeat, t2.claimed_at
+                         FROM horsies_tasks t2
+                         LEFT JOIN LATERAL (
+                             SELECT sent_at AS last_heartbeat
+                             FROM horsies_heartbeats h
+                             WHERE h.task_id = t2.id AND h.role = 'claimer'
+                             ORDER BY sent_at DESC
+                             LIMIT 1
+                         ) hb ON TRUE
+                         WHERE t2.status = 'CLAIMED'
+                     ) s
+                     WHERE t.id = s.id
+                       AND (
+                         (s.last_heartbeat IS NULL AND s.claimed_at IS NOT NULL AND s.claimed_at < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL))
+                         OR (s.last_heartbeat IS NOT NULL AND s.last_heartbeat < NOW() - CAST(:stale_threshold || ' seconds' AS INTERVAL))
+                       )
+                 """),
+                 {'stale_threshold': stale_threshold_seconds},
+             )
+             await session.commit()
+             return getattr(result, 'rowcount', 0)
+
+     # ----------------- Sync API Facades -----------------
+
+     def enqueue(
+         self,
+         task_name: str,
+         args: tuple[Any, ...],
+         kwargs: dict[str, Any],
+         queue_name: str = 'default',
+         *,
+         priority: int = 100,
+         sent_at: Optional[datetime] = None,
+         good_until: Optional[datetime] = None,
+         task_options: Optional[str] = None,
+     ) -> str:
+         """
+         Synchronous task submission (runs enqueue_async in background loop).
+         """
+         return self._loop_runner.call(
+             self.enqueue_async,
+             task_name,
+             args,
+             kwargs,
+             queue_name,
+             priority=priority,
+             sent_at=sent_at,
+             good_until=good_until,
+             task_options=task_options,
+         )
+
+     def get_result(
+         self, task_id: str, timeout_ms: Optional[int] = None
+     ) -> 'TaskResult[Any, TaskError]':
+         """
+         Synchronous result retrieval (runs get_result_async in background loop).
+
+         Args:
+             task_id: The task ID to retrieve result for
+             timeout_ms: Maximum time to wait for result (milliseconds)
+
+         Returns:
+             TaskResult - success, task error, retrieval error, or broker error
+         """
+         return self._loop_runner.call(self.get_result_async, task_id, timeout_ms)
+
+     async def get_task_info_async(
+         self,
+         task_id: str,
+         *,
+         include_result: bool = False,
+         include_failed_reason: bool = False,
+     ) -> 'TaskInfo | None':
+         """Fetch metadata for a task by ID."""
+         await self._ensure_initialized()
+
+         async with self.session_factory() as session:
+             base_columns = [
+                 'id',
+                 'task_name',
+                 'status',
+                 'queue_name',
+                 'priority',
+                 'retry_count',
+                 'max_retries',
+                 'next_retry_at',
+                 'sent_at',
+                 'claimed_at',
+                 'started_at',
+                 'completed_at',
+                 'failed_at',
+                 'worker_hostname',
+                 'worker_pid',
+                 'worker_process_name',
+             ]
+             if include_result:
+                 base_columns.append('result')
+             if include_failed_reason:
+                 base_columns.append('failed_reason')
+
+             query = text(
+                 f"""
+                 SELECT {', '.join(base_columns)}
+                 FROM horsies_tasks
+                 WHERE id = :id
+                 """
+             )
+             result = await session.execute(query, {'id': task_id})
+             row = result.fetchone()
+             if row is None:
+                 return None
+
+             result_value = None
+             failed_reason = None
+
+             idx = 0
+             task_id_value = row[idx]
+             idx += 1
+             task_name = row[idx]
+             idx += 1
+             status = TaskStatus(row[idx])
+             idx += 1
+             queue_name = row[idx]
+             idx += 1
+             priority = row[idx]
+             idx += 1
+             retry_count = row[idx] or 0
+             idx += 1
+             max_retries = row[idx] or 0
+             idx += 1
+             next_retry_at = row[idx]
+             idx += 1
+             sent_at = row[idx]
+             idx += 1
+             claimed_at = row[idx]
+             idx += 1
+             started_at = row[idx]
+             idx += 1
+             completed_at = row[idx]
+             idx += 1
+             failed_at = row[idx]
+             idx += 1
+             worker_hostname = row[idx]
+             idx += 1
+             worker_pid = row[idx]
+             idx += 1
+             worker_process_name = row[idx]
+             idx += 1
+
+             if include_result:
+                 raw_result = row[idx]
+                 idx += 1
+                 if raw_result:
+                     result_value = task_result_from_json(loads_json(raw_result))
+
+             if include_failed_reason:
+                 failed_reason = row[idx]
+
+             return TaskInfo(
+                 task_id=task_id_value,
+                 task_name=task_name,
+                 status=status,
+                 queue_name=queue_name,
+                 priority=priority,
+                 retry_count=retry_count,
+                 max_retries=max_retries,
+                 next_retry_at=next_retry_at,
+                 sent_at=sent_at,
+                 claimed_at=claimed_at,
+                 started_at=started_at,
+                 completed_at=completed_at,
+                 failed_at=failed_at,
+                 worker_hostname=worker_hostname,
+                 worker_pid=worker_pid,
+                 worker_process_name=worker_process_name,
+                 result=result_value,
+                 failed_reason=failed_reason,
+             )
+
+     def get_task_info(
+         self,
+         task_id: str,
+         *,
+         include_result: bool = False,
+         include_failed_reason: bool = False,
+     ) -> 'TaskInfo | None':
+         """Synchronous wrapper for get_task_info_async()."""
+         return self._loop_runner.call(
+             self.get_task_info_async,
+             task_id,
+             include_result=include_result,
+             include_failed_reason=include_failed_reason,
+         )
+
+     def close(self) -> None:
+         """
+         Synchronous cleanup (runs close_async in background loop).
+         """
+         try:
+             self._loop_runner.call(self.close_async)
+         finally:
+             self._loop_runner.stop()
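
For orientation, the snippet below is a minimal usage sketch of the broker's sync facade, based only on the signatures visible in this diff (PostgresBroker, PostgresConfig, enqueue, get_result, close). It is not taken from the package: the connection URL and task name are placeholders, and the assumption that PostgresConfig can be constructed with just a database_url keyword is unverified here.

from horsies.core.brokers.postgres import PostgresBroker
from horsies.core.models.broker import PostgresConfig

# Hypothetical connection settings; PostgresConfig's full field set is not shown in this diff.
config = PostgresConfig(database_url='postgresql+asyncpg://user:pass@localhost:5432/horsies')
broker = PostgresBroker(config)

# Sync facade: enqueue a task by name and wait up to 30 seconds for its result.
task_id = broker.enqueue('example.tasks.add', args=(1, 2), kwargs={}, queue_name='default')
result = broker.get_result(task_id, timeout_ms=30_000)  # TaskResult with ok/err semantics per the docstrings above
print(result)

broker.close()  # runs close_async in the background loop, then stops the loop runner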