horsies 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. horsies/__init__.py +117 -0
  2. horsies/core/__init__.py +0 -0
  3. horsies/core/app.py +552 -0
  4. horsies/core/banner.py +144 -0
  5. horsies/core/brokers/__init__.py +5 -0
  6. horsies/core/brokers/listener.py +444 -0
  7. horsies/core/brokers/postgres.py +993 -0
  8. horsies/core/cli.py +624 -0
  9. horsies/core/codec/serde.py +596 -0
  10. horsies/core/errors.py +535 -0
  11. horsies/core/logging.py +90 -0
  12. horsies/core/models/__init__.py +0 -0
  13. horsies/core/models/app.py +268 -0
  14. horsies/core/models/broker.py +79 -0
  15. horsies/core/models/queues.py +23 -0
  16. horsies/core/models/recovery.py +101 -0
  17. horsies/core/models/schedule.py +229 -0
  18. horsies/core/models/task_pg.py +307 -0
  19. horsies/core/models/tasks.py +358 -0
  20. horsies/core/models/workflow.py +1990 -0
  21. horsies/core/models/workflow_pg.py +245 -0
  22. horsies/core/registry/tasks.py +101 -0
  23. horsies/core/scheduler/__init__.py +26 -0
  24. horsies/core/scheduler/calculator.py +267 -0
  25. horsies/core/scheduler/service.py +569 -0
  26. horsies/core/scheduler/state.py +260 -0
  27. horsies/core/task_decorator.py +656 -0
  28. horsies/core/types/status.py +38 -0
  29. horsies/core/utils/imports.py +203 -0
  30. horsies/core/utils/loop_runner.py +44 -0
  31. horsies/core/worker/current.py +17 -0
  32. horsies/core/worker/worker.py +1967 -0
  33. horsies/core/workflows/__init__.py +23 -0
  34. horsies/core/workflows/engine.py +2344 -0
  35. horsies/core/workflows/recovery.py +501 -0
  36. horsies/core/workflows/registry.py +97 -0
  37. horsies/py.typed +0 -0
  38. horsies-0.1.0a4.dist-info/METADATA +35 -0
  39. horsies-0.1.0a4.dist-info/RECORD +42 -0
  40. horsies-0.1.0a4.dist-info/WHEEL +5 -0
  41. horsies-0.1.0a4.dist-info/entry_points.txt +2 -0
  42. horsies-0.1.0a4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1990 @@
1
+ """Core workflow models for DAG-based task orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Sequence
6
+ from dataclasses import dataclass, field
7
+ import re
8
+ from datetime import datetime
9
+ from enum import Enum
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Any,
13
+ TypeVar,
14
+ Generic,
15
+ cast,
16
+ Callable,
17
+ Literal,
18
+ ClassVar,
19
+ )
20
+ import inspect
21
+ import time
22
+ from horsies.core.utils.loop_runner import LoopRunner
23
+ from horsies.core.errors import ErrorCode, SourceLocation, ValidationReport, raise_collected
24
+ from pydantic import BaseModel
25
+
26
+ if TYPE_CHECKING:
27
+ from horsies.core.task_decorator import TaskFunction
28
+ from horsies.core.brokers.postgres import PostgresBroker
29
+ from horsies.core.models.tasks import TaskResult, TaskError
30
+ from horsies.core.app import Horsies
31
+
32
+
33
+ # TypeVar for TaskNode generic parameter (the "ok" type of TaskResult)
34
+ OkT = TypeVar('OkT')
35
+ OkT_co = TypeVar('OkT_co', covariant=True)
36
+
37
+
38
+ # =============================================================================
39
+ # Enums
40
+ # =============================================================================
41
+
42
+
43
+ class WorkflowStatus(str, Enum):
44
+ """
45
+ Status of a workflow instance.
46
+
47
+ State machine:
48
+ PENDING → RUNNING → COMPLETED
49
+ → FAILED (on task failure with on_error=FAIL)
50
+ → PAUSED (on task failure with on_error=PAUSE)
51
+ → CANCELLED (user requested)
52
+ """
53
+
54
+ PENDING = 'PENDING'
55
+ """Created but not yet started"""
56
+
57
+ RUNNING = 'RUNNING'
58
+ """At least one task executing or ready"""
59
+
60
+ COMPLETED = 'COMPLETED'
61
+ """All tasks terminal and success criteria met"""
62
+
63
+ FAILED = 'FAILED'
64
+ """A task failed and on_error=FAIL (or no success case satisfied)"""
65
+
66
+ PAUSED = 'PAUSED'
67
+ """A task failed with on_error=PAUSE; awaiting resume() or cancel()"""
68
+
69
+ CANCELLED = 'CANCELLED'
70
+ """User cancelled via WorkflowHandle.cancel()"""
71
+
72
+ @property
73
+ def is_terminal(self) -> bool:
74
+ """Whether this status represents a final state (no further transitions)."""
75
+ return self in WORKFLOW_TERMINAL_STATES
76
+
77
+
78
+ WORKFLOW_TERMINAL_STATES: frozenset[WorkflowStatus] = frozenset({
79
+ WorkflowStatus.COMPLETED,
80
+ WorkflowStatus.FAILED,
81
+ WorkflowStatus.CANCELLED,
82
+ })
83
+
84
+
85
+ class WorkflowTaskStatus(str, Enum):
86
+ """
87
+ Status of a single task/node within a workflow.
88
+
89
+ State machine:
90
+ PENDING → READY → ENQUEUED → RUNNING → COMPLETED
91
+ → FAILED
92
+ → SKIPPED (deps failed and allow_failed_deps=False)
93
+ """
94
+
95
+ PENDING = 'PENDING'
96
+ """Waiting for dependencies to become terminal"""
97
+
98
+ READY = 'READY'
99
+ """Dependencies satisfied, waiting to be enqueued"""
100
+
101
+ ENQUEUED = 'ENQUEUED'
102
+ """Task created in tasks table, waiting for worker"""
103
+
104
+ RUNNING = 'RUNNING'
105
+ """Worker is executing the task (or child workflow is running)"""
106
+
107
+ COMPLETED = 'COMPLETED'
108
+ """Task succeeded (TaskResult.is_ok())"""
109
+
110
+ FAILED = 'FAILED'
111
+ """Task failed (TaskResult.is_err())"""
112
+
113
+ SKIPPED = 'SKIPPED'
114
+ """Skipped due to: upstream failure, condition, or quorum impossible"""
115
+
116
+ @property
117
+ def is_terminal(self) -> bool:
118
+ """Whether this status represents a final state (no further transitions)."""
119
+ return self in WORKFLOW_TASK_TERMINAL_STATES
120
+
121
+
122
+ WORKFLOW_TASK_TERMINAL_STATES: frozenset[WorkflowTaskStatus] = frozenset({
123
+ WorkflowTaskStatus.COMPLETED,
124
+ WorkflowTaskStatus.FAILED,
125
+ WorkflowTaskStatus.SKIPPED,
126
+ })
127
+
128
+
129
+ class OnError(str, Enum):
130
+ """
131
+ Error handling policy for workflows when a task fails.
132
+ """
133
+
134
+ FAIL = 'fail'
135
+ """Continue DAG resolution but mark workflow as will-fail. Skip tasks without allow_failed_deps."""
136
+
137
+ PAUSE = 'pause'
138
+ """Pause workflow immediately. No new tasks enqueued until resume()."""
139
+
140
+
141
+ class SubWorkflowRetryMode(str, Enum):
142
+ """
143
+ Retry behavior for subworkflows (only RERUN_FAILED_ONLY currently supported).
144
+ """
145
+
146
+ RERUN_FAILED_ONLY = 'rerun_failed_only'
147
+ """Re-run only failed/cancelled child tasks (default, only supported mode)"""
148
+
149
+ RERUN_ALL = 'rerun_all'
150
+ """Re-run entire child workflow from scratch (not yet implemented)"""
151
+
152
+ NO_RERUN = 'no_rerun'
153
+ """Re-evaluate success policy without re-running (not yet implemented)"""
154
+
155
+
156
+ # =============================================================================
157
+ # SubWorkflowSummary
158
+ # =============================================================================
159
+
160
+
161
+ @dataclass
162
+ class SubWorkflowSummary(Generic[OkT_co]):
163
+ """
164
+ Summary of a child workflow's execution, available via WorkflowContext.summary_for().
165
+
166
+ Provides visibility into child workflow health without exposing internal DAG.
167
+ Useful for conditional logic based on partial success/failure.
168
+
169
+ Example:
170
+ def make_report(workflow_ctx: WorkflowContext | None) -> TaskResult[str, TaskError]:
171
+ if workflow_ctx is None:
172
+ return TaskResult(err=TaskError(error_code="NO_CTX"))
173
+ summary = workflow_ctx.summary_for(data_pipeline_node)
174
+ if summary.failed_tasks > 0:
175
+ return TaskResult(ok=f"Partial: {summary.completed_tasks}/{summary.total_tasks}")
176
+ return TaskResult(ok=f"Full: {summary.output}")
177
+ """
178
+
179
+ status: WorkflowStatus
180
+ """Child workflow's final status (COMPLETED, FAILED, etc.)"""
181
+
182
+ success_case: str | None
183
+ """Which SuccessCase was satisfied (if success_policy used)"""
184
+
185
+ output: OkT_co | None
186
+ """Child's output value (typed via generic parameter)"""
187
+
188
+ total_tasks: int
189
+ """Total number of tasks in child workflow"""
190
+
191
+ completed_tasks: int
192
+ """Number of COMPLETED tasks"""
193
+
194
+ failed_tasks: int
195
+ """Number of FAILED tasks"""
196
+
197
+ skipped_tasks: int
198
+ """Number of SKIPPED tasks"""
199
+
200
+ error_summary: str | None = None
201
+ """Brief description of failure (if child failed)"""
202
+
203
+ @classmethod
204
+ def from_json(cls, data: dict[str, Any]) -> 'SubWorkflowSummary[Any]':
205
+ """Build a SubWorkflowSummary from a JSON-like dict safely."""
206
+ status_val = data.get('status')
207
+ try:
208
+ status = WorkflowStatus(str(status_val))
209
+ except Exception:
210
+ status = WorkflowStatus.FAILED
211
+
212
+ success_case_val = data.get('success_case')
213
+ total_val = data.get('total_tasks', 0)
214
+ completed_val = data.get('completed_tasks', 0)
215
+ failed_val = data.get('failed_tasks', 0)
216
+ skipped_val = data.get('skipped_tasks', 0)
217
+ error_val = data.get('error_summary')
218
+
219
+ return cls(
220
+ status=status,
221
+ success_case=str(success_case_val) if success_case_val else None,
222
+ output=data.get('output'),
223
+ total_tasks=int(total_val) if isinstance(total_val, (int, float)) else 0,
224
+ completed_tasks=int(completed_val)
225
+ if isinstance(completed_val, (int, float))
226
+ else 0,
227
+ failed_tasks=int(failed_val) if isinstance(failed_val, (int, float)) else 0,
228
+ skipped_tasks=int(skipped_val)
229
+ if isinstance(skipped_val, (int, float))
230
+ else 0,
231
+ error_summary=str(error_val) if error_val else None,
232
+ )
233
+
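+ # --- Illustrative sketch (not part of the published module) -----------------
+ # Shows how a JSON-ish payload maps onto SubWorkflowSummary via from_json().
+ # The payload shape is assumed for illustration; from_json() falls back to safe
+ # defaults (status=FAILED, zero counts) when fields are missing or malformed.
+ def _example_summary_from_json() -> SubWorkflowSummary[Any]:
+     payload: dict[str, Any] = {
+         'status': 'COMPLETED',
+         'success_case': 'primary_path',
+         'output': {'rows': 1234},
+         'total_tasks': 5,
+         'completed_tasks': 4,
+         'failed_tasks': 0,
+         'skipped_tasks': 1,
+     }
+     summary = SubWorkflowSummary.from_json(payload)
+     assert summary.status is WorkflowStatus.COMPLETED
+     assert summary.completed_tasks == 4
+     return summary
+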
234
+
235
+ # =============================================================================
236
+ # Exceptions
237
+ # =============================================================================
238
+
239
+ # Re-export from errors module for backward compatibility
240
+ from horsies.core.errors import WorkflowValidationError
241
+
242
+
243
+ def _task_accepts_workflow_ctx(fn: Callable[..., Any]) -> bool:
244
+ inspect_target: Callable[..., Any] = fn
245
+ original = getattr(fn, '_original_fn', None)
246
+ if callable(original):
247
+ inspect_target = original
248
+ try:
249
+ sig = inspect.signature(inspect_target)
250
+ except (TypeError, ValueError):
251
+ return False
252
+ return 'workflow_ctx' in sig.parameters
253
+
254
+
255
+ NODE_ID_PATTERN = re.compile(r'^[A-Za-z0-9_\-:.]+$')
256
+
257
+
258
+ def slugify(value: str) -> str:
259
+ """
260
+ Convert a string to a valid node_id by replacing invalid characters.
261
+
262
+ Spaces become underscores, other invalid characters are removed.
263
+ Result matches NODE_ID_PATTERN: [A-Za-z0-9_\\-:.]+
264
+
265
+ Example:
266
+ slugify("My Workflow Name") # "My_Workflow_Name"
267
+ slugify("task:v2.0") # "task:v2.0" (unchanged)
268
+ """
269
+ result = value.replace(' ', '_')
270
+ result = re.sub(r'[^A-Za-z0-9_\-:.]', '', result)
271
+ return result or '_'
272
+
273
+
274
+ # =============================================================================
275
+ # TaskNode
276
+ # =============================================================================
277
+
278
+
279
+ @dataclass
280
+ class TaskNode(Generic[OkT_co]):
281
+ """
282
+ A node in the workflow DAG representing a single task execution.
283
+
284
+ Generic parameter OkT_co represents the success type of the task's TaskResult.
285
+ This enables type-safe access to results via WorkflowContext.result_for(node).
286
+
287
+ Example:
288
+ ```python
289
+ fetch = TaskNode(fn=fetch_data, args=("url",))
290
+ process = TaskNode(
291
+ fn=process_data,
292
+ waits_for=[fetch], # wait for fetch to be terminal
293
+ args_from={"raw": fetch}, # inject fetch's TaskResult as 'raw' kwarg
294
+ allow_failed_deps=True, # run even if fetch failed
295
+ )
296
+ ```
297
+ """
298
+
299
+ fn: TaskFunction[Any, OkT_co]
300
+ args: tuple[Any, ...] = ()
301
+ kwargs: dict[str, Any] = field(default_factory=lambda: {})
302
+ waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] = field(
303
+ default_factory=lambda: [],
304
+ )
305
+ """
306
+ - List of TaskNodes or SubWorkflowNodes that this task waits for
307
+ - This node waits for all of the listed dependencies to reach a terminal state (COMPLETED/FAILED/SKIPPED)
308
+ """
309
+
310
+ args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] = field(
311
+ default_factory=lambda: {},
312
+ )
313
+ """
314
+ - Data flow: inject dependency TaskResults as kwargs (keyword-only)
315
+ - Example: args_from={"validated": validate_node, "transformed": transform_node}
316
+ - Task receives: def my_task(validated: TaskResult[A, E], transformed: TaskResult[B, E])
317
+ """
318
+
319
+ workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None
320
+ """
321
+ - Optional context: subset of dependencies to include in WorkflowContext
322
+ - Only injected if task declares `workflow_ctx: WorkflowContext | None` parameter
323
+
324
+ """
325
+ # Queue/priority overrides (if None, use task decorator defaults)
326
+ queue: str | None = None
327
+ """
328
+ - Queue overrides (if None, use task decorator defaults)
329
+ """
330
+ priority: int | None = None
331
+ """
332
+ - Priority overrides (if None, use task decorator defaults)
333
+ """
334
+
335
+ allow_failed_deps: bool = False
336
+ """
337
+ - If True, this task runs even if dependencies failed (receives failed TaskResults)
338
+ - If False (default), task is SKIPPED when any dependency fails
339
+ """
340
+
341
+ run_when: Callable[['WorkflowContext'], bool] | None = field(
342
+ default=None, repr=False
343
+ )
344
+ """
345
+ - Conditional execution: evaluated after deps are terminal, before enqueue
346
+ - skip_when has priority over run_when
347
+ - Callables receive WorkflowContext built from workflow_ctx_from
348
+ """
349
+
350
+ skip_when: Callable[['WorkflowContext'], bool] | None = field(
351
+ default=None, repr=False
352
+ )
353
+ """
354
+ - Conditional execution: evaluated after deps are terminal, before enqueue
355
+ - skip_when has priority over run_when
356
+ - Callables receive WorkflowContext built from workflow_ctx_from
357
+ """
358
+
359
+ join: Literal['all', 'any', 'quorum'] = 'all'
360
+ """
361
+ - Dependency join semantics
362
+ - "all": task runs when ALL dependencies are terminal (default)
363
+ - "any": task runs when ANY dependency succeeds (COMPLETED)
364
+ - "quorum": task runs when at least min_success dependencies succeed
365
+ """
366
+ min_success: int | None = None
367
+ """
368
+ - Required for join="quorum"
369
+ - Minimum number of dependencies that must succeed
370
+ """
371
+
372
+ good_until: datetime | None = None
373
+ """
374
+ - Task expiry deadline (task skipped if not claimed by this time)
375
+ """
376
+
377
+ # Assigned during WorkflowSpec construction
378
+ index: int | None = field(default=None, repr=False)
379
+ """
380
+ - Assigned during WorkflowSpec construction
381
+ - Index of the task in the workflow
382
+ """
383
+
384
+ node_id: str | None = field(default=None, repr=False)
385
+ """
386
+ Optional stable identifier for this task within the workflow.
387
+ If None, auto-assigned as '{slugify(workflow_name)}:{task_index}'.
388
+ Must be unique within the workflow.
389
+ Must match pattern: [A-Za-z0-9_\\-:.]+
390
+ """
391
+
392
+ @property
393
+ def name(self) -> str:
394
+ """Get the task name from the wrapped function."""
395
+ return self.fn.task_name
396
+
397
+ def key(self) -> 'NodeKey[OkT_co]':
398
+ """Return a typed NodeKey for this task node."""
399
+ if self.node_id is None:
400
+ raise WorkflowValidationError(
401
+ 'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
402
+ 'or provide an explicit node_id.'
403
+ )
404
+ return NodeKey(self.node_id)
405
+
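+ # --- Illustrative sketch (not part of the published module) -----------------
+ # Wires a small fan-out/fan-in DAG with TaskNode: three fetches feed an
+ # aggregation step that runs once at least two of them succeed (join='quorum').
+ # The `fetch_*` and `aggregate` parameters stand in for functions registered as
+ # tasks via the library's task decorator; they are not defined in this module.
+ def _example_quorum_fanin(
+     fetch_a: 'TaskFunction[Any, dict[str, Any]]',
+     fetch_b: 'TaskFunction[Any, dict[str, Any]]',
+     fetch_c: 'TaskFunction[Any, dict[str, Any]]',
+     aggregate: 'TaskFunction[Any, list[str]]',
+ ) -> list['TaskNode[Any]']:
+     a = TaskNode(fn=fetch_a, args=('https://example.com/a',))
+     b = TaskNode(fn=fetch_b, args=('https://example.com/b',))
+     c = TaskNode(fn=fetch_c, args=('https://example.com/c',))
+     agg = TaskNode(
+         fn=aggregate,
+         waits_for=[a, b, c],                 # fan-in on all three fetches
+         args_from={'a': a, 'b': b, 'c': c},  # each injected as a TaskResult kwarg
+         join='quorum',
+         min_success=2,                       # run once any two fetches are COMPLETED
+         allow_failed_deps=True,              # a failed fetch arrives as an err result
+     )
+     return [a, b, c, agg]
+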
406
+
407
+ # =============================================================================
408
+ # SubWorkflowNode
409
+ # =============================================================================
410
+
411
+
412
+ @dataclass
413
+ class SubWorkflowNode(Generic[OkT_co]):
414
+ """
415
+ A node that runs a child workflow as a composite task.
416
+
417
+ The generic parameter OkT_co represents the child workflow's output type,
418
+ derived from WorkflowDefinition[OkT_co].
419
+
420
+ When the child workflow completes:
421
+ - COMPLETED: parent node receives TaskResult[OkT, TaskError] with child's output
422
+ - FAILED: parent node receives TaskResult with SubWorkflowError containing SubWorkflowSummary
423
+
424
+ Example:
425
+ class DataPipeline(WorkflowDefinition[ProcessedData]):
426
+ ...
427
+
428
+ pipeline: SubWorkflowNode[ProcessedData] = SubWorkflowNode(
429
+ workflow_def=DataPipeline,
430
+ kwargs={"source_url": "https://..."}, # passed to build_with()
431
+ )
432
+ # Downstream: args_from={"data": pipeline} → TaskResult[ProcessedData, TaskError]
433
+ """
434
+
435
+ workflow_def: type['WorkflowDefinition[OkT_co]']
436
+ """
437
+ - The WorkflowDefinition subclass to run as a child workflow
438
+ - Must implement build_with() for parameterization
439
+ """
440
+
441
+ args: tuple[Any, ...] = ()
442
+ """
443
+ - Positional arguments passed to workflow_def.build_with(app, *args, **kwargs)
444
+ """
445
+
446
+ kwargs: dict[str, Any] = field(default_factory=lambda: {})
447
+ """
448
+ - Keyword arguments passed to workflow_def.build_with(app, *args, **kwargs)
449
+ - Use with args_from to inject upstream results as parameters
450
+ """
451
+
452
+ waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] = field(
453
+ default_factory=lambda: [],
454
+ )
455
+ """
456
+ - Nodes this subworkflow waits for before starting
457
+ - Waits for all to be terminal (COMPLETED/FAILED/SKIPPED)
458
+ - Same semantics as TaskNode.waits_for
459
+ """
460
+
461
+ args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] = field(
462
+ default_factory=lambda: {},
463
+ )
464
+ """
465
+ - Maps kwarg names to upstream nodes for data injection
466
+ - Injected as TaskResult into build_with() kwargs
467
+ - Example: args_from={"input_data": fetch_node} → kwargs["input_data"] = TaskResult[T, E]
468
+ """
469
+
470
+ workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None
471
+ """
472
+ - Nodes whose results to include in WorkflowContext for run_when/skip_when
473
+ - For SubWorkflowNodes: access via ctx.summary_for(node) → SubWorkflowSummary
474
+ """
475
+
476
+ join: Literal['all', 'any', 'quorum'] = 'all'
477
+ """
478
+ - "all": start when ALL dependencies are terminal (default)
479
+ - "any": start when ANY dependency succeeds (COMPLETED)
480
+ - "quorum": start when min_success dependencies succeed
481
+ """
482
+
483
+ min_success: int | None = None
484
+ """
485
+ - Required for join="quorum": minimum dependencies that must succeed
486
+ """
487
+
488
+ allow_failed_deps: bool = False
489
+ """
490
+ - False (default): SKIPPED if any dependency failed
491
+ - True: starts regardless, failed deps passed as TaskResult(err=...) via args_from
492
+ """
493
+
494
+ run_when: Callable[['WorkflowContext'], bool] | None = field(
495
+ default=None, repr=False
496
+ )
497
+ """
498
+ - Condition evaluated after deps terminal, before starting child workflow
499
+ - If returns False: node is SKIPPED
500
+ - skip_when takes priority over run_when
501
+ """
502
+
503
+ skip_when: Callable[['WorkflowContext'], bool] | None = field(
504
+ default=None, repr=False
505
+ )
506
+ """
507
+ - Condition evaluated after deps terminal, before starting child workflow
508
+ - If returns True: node is SKIPPED
509
+ - skip_when takes priority over run_when
510
+ """
511
+
512
+ retry_mode: SubWorkflowRetryMode = SubWorkflowRetryMode.RERUN_FAILED_ONLY
513
+ """
514
+ - How to retry if child workflow fails (only RERUN_FAILED_ONLY supported)
515
+ """
516
+
517
+ index: int | None = field(default=None, repr=False)
518
+ """
519
+ - Auto-assigned during WorkflowSpec construction
520
+ """
521
+
522
+ node_id: str | None = field(default=None, repr=False)
523
+ """
524
+ Optional stable identifier for this subworkflow within the parent workflow.
525
+ If None, auto-assigned as '{slugify(workflow_name)}:{task_index}'.
526
+ Must be unique within the workflow.
527
+ Must match pattern: [A-Za-z0-9_\\-:.]+
528
+ """
529
+
530
+ @property
531
+ def name(self) -> str:
532
+ """Get the subworkflow name."""
533
+ return self.workflow_def.name
534
+
535
+ def key(self) -> 'NodeKey[OkT_co]':
536
+ """Return a typed NodeKey for this subworkflow node."""
537
+ if self.node_id is None:
538
+ raise WorkflowValidationError(
539
+ 'SubWorkflowNode node_id is not set. Ensure WorkflowSpec assigns node_id '
540
+ 'or provide an explicit node_id.'
541
+ )
542
+ return NodeKey(self.node_id)
543
+
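+ # --- Illustrative sketch (not part of the published module) -----------------
+ # Embeds a child workflow as a node in a parent DAG. `extract_def` is a
+ # hypothetical WorkflowDefinition subclass and `fetch_config` / `publish` are
+ # hypothetical registered tasks; only the wiring shown here (waits_for,
+ # args_from, retry_mode) comes from this module's API.
+ def _example_subworkflow_wiring(
+     fetch_config: 'TaskFunction[Any, dict[str, Any]]',
+     publish: 'TaskFunction[Any, str]',
+     extract_def: type['WorkflowDefinition[dict[str, Any]]'],
+ ) -> list['TaskNode[Any] | SubWorkflowNode[Any]']:
+     config = TaskNode(fn=fetch_config)
+     extract = SubWorkflowNode(
+         workflow_def=extract_def,
+         waits_for=[config],
+         args_from={'config': config},  # injected into build_with() as a TaskResult kwarg
+         retry_mode=SubWorkflowRetryMode.RERUN_FAILED_ONLY,
+     )
+     report = TaskNode(
+         fn=publish,
+         waits_for=[extract],
+         args_from={'data': extract},   # receives TaskResult[dict[str, Any], TaskError]
+     )
+     return [config, extract, report]
+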
544
+
545
+
546
+ AnyNode = TaskNode[Any] | SubWorkflowNode[Any]
547
+ '''
548
+ Type alias for any node type
549
+ '''
550
+
551
+ # =============================================================================
552
+ # NodeKey (typed, stable id)
553
+ # =============================================================================
554
+
555
+
556
+ @dataclass(frozen=True)
557
+ class NodeKey(Generic[OkT_co]):
558
+ """Typed stable identifier for a TaskNode."""
559
+
560
+ node_id: str
561
+
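+ # --- Illustrative sketch (not part of the published module) -----------------
+ # NodeKey carries a node's stable id together with its result type, so lookups
+ # via WorkflowContext.result_for() or WorkflowHandle.result_for() stay typed
+ # even when the TaskNode object itself is no longer in scope.
+ def _example_node_key(fetch: 'TaskNode[dict[str, Any]]') -> 'NodeKey[dict[str, Any]]':
+     # key() raises WorkflowValidationError if node_id has not been assigned yet,
+     # so call it only after WorkflowSpec construction (or set node_id explicitly).
+     return fetch.key()
+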
562
+
563
+ # =============================================================================
564
+ # Success Policy
565
+ # =============================================================================
566
+
567
+
568
+ @dataclass
569
+ class SuccessCase:
570
+ """
571
+ A single success scenario for a workflow.
572
+
573
+ The case is satisfied when ALL required tasks are COMPLETED.
574
+
575
+ Example:
576
+ # Workflow succeeds if either (A and B) or (C) completes
577
+ SuccessPolicy(cases=[
578
+ SuccessCase(required=[task_a, task_b]),
579
+ SuccessCase(required=[task_c]),
580
+ ])
581
+ """
582
+
583
+ required: list[TaskNode[Any]]
584
+ """
585
+ - All tasks in this list must be COMPLETED for the case to be satisfied
586
+ """
587
+
588
+
589
+ @dataclass
590
+ class SuccessPolicy:
591
+ """
592
+ Custom success criteria: workflow COMPLETED if ANY SuccessCase is satisfied.
593
+
594
+ Without a success_policy, default behavior is: any task failure → workflow FAILED.
595
+ With a success_policy, workflow is COMPLETED if at least one case has all
596
+ its required tasks COMPLETED, regardless of other task failures.
597
+
598
+ Example:
599
+ # "Succeed if primary path completes, even if fallback fails"
600
+ success_policy = SuccessPolicy(
601
+ cases=[SuccessCase(required=[primary_task])],
602
+ optional=[fallback_task], # can fail without affecting success
603
+ )
604
+ """
605
+
606
+ cases: list[SuccessCase]
607
+ """
608
+ - List of success scenarios
609
+ - Workflow succeeds if ANY case is fully satisfied (all required COMPLETED)
610
+ """
611
+
612
+ optional: list[TaskNode[Any]] | None = None
613
+ """
614
+ - Tasks that may fail without affecting success evaluation
615
+ - These failures don't block success cases from being satisfied
616
+ """
617
+
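+ # --- Illustrative sketch (not part of the published module) -----------------
+ # A success policy for a primary/fallback layout: the workflow counts as
+ # COMPLETED when either the primary or the fallback path finishes, and a
+ # best-effort metrics task may fail without affecting the outcome. The node
+ # parameters are placeholders for TaskNodes built elsewhere in the DAG.
+ def _example_success_policy(
+     primary: 'TaskNode[Any]',
+     fallback: 'TaskNode[Any]',
+     metrics: 'TaskNode[Any]',
+ ) -> SuccessPolicy:
+     return SuccessPolicy(
+         cases=[
+             SuccessCase(required=[primary]),
+             SuccessCase(required=[fallback]),
+         ],
+         optional=[metrics],  # allowed to fail without blocking either case
+     )
+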
618
+
619
+ # =============================================================================
620
+ # WorkflowSpec
621
+ # =============================================================================
622
+
623
+
624
+ @dataclass
625
+ class WorkflowSpec:
626
+ """
627
+ Specification for a workflow DAG. Created via app.workflow() or WorkflowDefinition.build().
628
+
629
+ Validates the DAG on construction (cycles, dependency refs, args_from, etc.)
630
+ to catch configuration errors early, before execution.
631
+
632
+ Example:
633
+ spec = app.workflow(
634
+ name="my_pipeline",
635
+ tasks=[fetch, process, persist],
636
+ output=persist,
637
+ on_error=OnError.PAUSE,
638
+ )
639
+ handle = await spec.start_async()
640
+ """
641
+
642
+ name: str
643
+ """
644
+ Human-readable workflow name (used in logs, DB, registry).
645
+ Can contain any characters including spaces. When auto-generating
646
+ node_ids, the name is passed through slugify() to ensure validity.
647
+ """
648
+
649
+ tasks: list[TaskNode[Any] | SubWorkflowNode[Any]]
650
+ """
651
+ - All nodes in the DAG (order determines index assignment)
652
+ - Root nodes (empty waits_for) start immediately
653
+ """
654
+
655
+ on_error: OnError = OnError.FAIL
656
+ """
657
+ - FAIL (default): on task failure, mark workflow as will-fail, skip dependent tasks
658
+ - PAUSE: on task failure, pause workflow for manual intervention (resume/cancel)
659
+ """
660
+
661
+ output: TaskNode[Any] | SubWorkflowNode[Any] | None = None
662
+ """
663
+ - Explicit output node: WorkflowHandle.get() returns this node's result
664
+ - If None: get() returns dict of all terminal node results keyed by node_id
665
+ """
666
+
667
+ success_policy: SuccessPolicy | None = None
668
+ """
669
+ - Custom success criteria: workflow COMPLETED if any SuccessCase is satisfied
670
+ - If None (default): any task failure → workflow FAILED
671
+ """
672
+
673
+ workflow_def_module: str | None = None
674
+ """
675
+ - Module path of WorkflowDefinition class (for import fallback in workers)
676
+ """
677
+
678
+ workflow_def_qualname: str | None = None
679
+ """
680
+ - Qualified name of WorkflowDefinition class (for import fallback in workers)
681
+ """
682
+
683
+ broker: PostgresBroker | None = field(default=None, repr=False)
684
+ """
685
+ - Database broker for start()/start_async()
686
+ - Set automatically by app.workflow()
687
+ """
688
+
689
+ def __post_init__(self) -> None:
690
+ """Validate DAG structure on construction.
691
+
692
+ Phase-gated: node_id errors gate the rest (downstream needs valid IDs).
693
+ All other validations run and collect errors together.
694
+ """
695
+ self._assign_indices()
696
+
697
+ # Gate 1: node_id assignment — if any ID errors, skip remaining validation
698
+ node_id_errors = self._collect_node_id_errors()
699
+ if node_id_errors:
700
+ report = ValidationReport('workflow')
701
+ for error in node_id_errors:
702
+ report.add(error)
703
+ raise_collected(report)
704
+ return
705
+
706
+ # Gate 2: collect all remaining validation errors
707
+ report = ValidationReport('workflow')
708
+ for error in self._collect_dag_errors():
709
+ report.add(error)
710
+ for error in self._collect_args_from_errors():
711
+ report.add(error)
712
+ for error in self._collect_workflow_ctx_from_errors():
713
+ report.add(error)
714
+ for error in self._collect_output_errors():
715
+ report.add(error)
716
+ for error in self._collect_success_policy_errors():
717
+ report.add(error)
718
+ for error in self._collect_join_semantics_errors():
719
+ report.add(error)
720
+ for error in self._collect_subworkflow_retry_mode_errors():
721
+ report.add(error)
722
+ for error in self._collect_subworkflow_cycle_errors():
723
+ report.add(error)
724
+
725
+ raise_collected(report)
726
+
727
+ # Only if clean: conditions + registration
728
+ self._validate_conditions()
729
+ self._register_for_conditions()
730
+
731
+ def _assign_indices(self) -> None:
732
+ """Assign index to each TaskNode based on list position."""
733
+ for i, task in enumerate(self.tasks):
734
+ task.index = i
735
+
736
+ def _collect_node_id_errors(self) -> list[WorkflowValidationError]:
737
+ """Assign node_id to each TaskNode if missing and validate uniqueness.
738
+
739
+ Returns all node_id errors instead of raising on first.
740
+
741
+ node_id source:
742
+ - User provides workflow NAME (e.g., "My Pipeline")
743
+ - node_id is either:
744
+ a) Derived from workflow name: slugify(name) + ":" + index
745
+ b) Explicitly provided by user on TaskNode
746
+ - Errors must distinguish between (a) and (b) so users know what to fix
747
+ """
748
+ errors: list[WorkflowValidationError] = []
749
+ seen_ids: set[str] = set()
750
+ for task in self.tasks:
751
+ # Track whether node_id comes from workflow name or was explicitly set
752
+ node_id_from_workflow_name = task.node_id is None
753
+ if node_id_from_workflow_name:
754
+ if task.index is None:
755
+ errors.append(WorkflowValidationError(
756
+ message='TaskNode index is not set before assigning node_id',
757
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
758
+ ))
759
+ continue
760
+ task.node_id = f'{slugify(self.name)}:{task.index}'
761
+
762
+ node_id = task.node_id
763
+ if node_id is None or not node_id.strip():
764
+ errors.append(WorkflowValidationError(
765
+ message='TaskNode node_id must be a non-empty string',
766
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
767
+ ))
768
+ continue
769
+ if len(node_id) > 128:
770
+ if node_id_from_workflow_name:
771
+ errors.append(WorkflowValidationError(
772
+ message='workflow name too long',
773
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
774
+ notes=[
775
+ f"workflow name: '{self.name}'",
776
+ f"derived node_id would be {len(node_id)} characters (max 128)",
777
+ ],
778
+ help_text='use a shorter workflow name',
779
+ ))
780
+ else:
781
+ errors.append(WorkflowValidationError(
782
+ message='TaskNode node_id exceeds 128 characters',
783
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
784
+ notes=[f"node_id '{node_id}' has {len(node_id)} characters"],
785
+ help_text='use a shorter node_id (max 128 characters)',
786
+ ))
787
+ continue
788
+ if NODE_ID_PATTERN.match(node_id) is None:
789
+ if node_id_from_workflow_name:
790
+ # This should never happen since slugify() sanitizes the name
791
+ errors.append(WorkflowValidationError(
792
+ message='workflow name produced invalid characters (internal error)',
793
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
794
+ notes=[
795
+ f"workflow name: '{self.name}'",
796
+ f"derived node_id: '{node_id}'",
797
+ 'slugify() failed to sanitize the name',
798
+ ],
799
+ help_text='please report this bug',
800
+ ))
801
+ else:
802
+ errors.append(WorkflowValidationError(
803
+ message='TaskNode node_id contains invalid characters',
804
+ code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
805
+ notes=[f"node_id '{node_id}'"],
806
+ help_text='node_id must match pattern: [A-Za-z0-9_\\-:.]+',
807
+ ))
808
+ continue
809
+ if node_id in seen_ids:
810
+ errors.append(WorkflowValidationError(
811
+ message=f"duplicate node_id '{node_id}'",
812
+ code=ErrorCode.WORKFLOW_DUPLICATE_NODE_ID,
813
+ help_text='each TaskNode must have a unique node_id within the workflow',
814
+ ))
815
+ seen_ids.add(node_id)
816
+ return errors
817
+
818
+ def _collect_dag_errors(self) -> list[WorkflowValidationError]:
819
+ """Validate DAG structure. Returns all errors found."""
820
+ errors: list[WorkflowValidationError] = []
821
+
822
+ # 1. Check for roots (tasks with no dependencies)
823
+ roots = [t for t in self.tasks if not t.waits_for]
824
+ if not roots:
825
+ errors.append(WorkflowValidationError(
826
+ message='no root tasks found',
827
+ code=ErrorCode.WORKFLOW_NO_ROOT_TASKS,
828
+ notes=[
829
+ 'all tasks have dependencies, creating an impossible start condition',
830
+ ],
831
+ help_text='at least one task must have empty waits_for list',
832
+ ))
833
+
834
+ # 2. Validate dependency references exist in workflow
835
+ task_ids = set(id(t) for t in self.tasks)
836
+ for task in self.tasks:
837
+ for dep in task.waits_for:
838
+ if id(dep) not in task_ids:
839
+ errors.append(WorkflowValidationError(
840
+ message='dependency references task not in workflow',
841
+ code=ErrorCode.WORKFLOW_INVALID_DEPENDENCY,
842
+ notes=[
843
+ f"task '{task.name}' waits for a TaskNode not in this workflow",
844
+ ],
845
+ help_text='ensure all dependencies are included in the workflow tasks list',
846
+ ))
847
+
848
+ # 3. Cycle detection (Kahn's algorithm) over valid dependencies only
849
+ in_degree: dict[int, int] = {}
850
+ for task in self.tasks:
851
+ idx = task.index
852
+ if idx is None:
853
+ continue
854
+ in_degree[idx] = 0
855
+ for dep in task.waits_for:
856
+ if id(dep) in task_ids and dep.index is not None:
857
+ in_degree[idx] += 1
858
+
859
+ queue = [
860
+ t.index
861
+ for t in self.tasks
862
+ if t.index is not None and in_degree.get(t.index, 0) == 0
863
+ ]
864
+ visited = 0
865
+
866
+ while queue:
867
+ node_idx = queue.pop(0)
868
+ visited += 1
869
+ for task in self.tasks:
870
+ dep_indices = [
871
+ d.index
872
+ for d in task.waits_for
873
+ if id(d) in task_ids and d.index is not None
874
+ ]
875
+ if node_idx in dep_indices:
876
+ task_idx = task.index
877
+ if task_idx is not None:
878
+ in_degree[task_idx] -= 1
879
+ if in_degree[task_idx] == 0:
880
+ queue.append(task_idx)
881
+
882
+ if visited != len(self.tasks):
883
+ errors.append(WorkflowValidationError(
884
+ message='cycle detected in workflow DAG',
885
+ code=ErrorCode.WORKFLOW_CYCLE_DETECTED,
886
+ notes=['workflows must be acyclic directed graphs (DAG)'],
887
+ help_text='remove circular dependencies between tasks',
888
+ ))
889
+
890
+ return errors
891
+
892
+ def _collect_args_from_errors(self) -> list[WorkflowValidationError]:
893
+ """Validate args_from references are valid dependencies. Returns all errors."""
894
+ errors: list[WorkflowValidationError] = []
895
+ for task in self.tasks:
896
+ deps_ids = set(id(d) for d in task.waits_for)
897
+ for kwarg_name, source_node in task.args_from.items():
898
+ if id(source_node) not in deps_ids:
899
+ errors.append(WorkflowValidationError(
900
+ message='args_from references task not in waits_for',
901
+ code=ErrorCode.WORKFLOW_INVALID_ARGS_FROM,
902
+ notes=[
903
+ f"task '{task.name}' args_from['{kwarg_name}'] references '{source_node.name}'",
904
+ f"'{source_node.name}' must be in waits_for to inject its result",
905
+ ],
906
+ help_text=f"add '{source_node.name}' to waits_for list",
907
+ ))
908
+ return errors
909
+
910
+ def _collect_workflow_ctx_from_errors(self) -> list[WorkflowValidationError]:
911
+ """Validate workflow_ctx_from references are valid dependencies. Returns all errors."""
912
+ errors: list[WorkflowValidationError] = []
913
+ for node in self.tasks:
914
+ if node.workflow_ctx_from is None:
915
+ continue
916
+ deps_ids = set(id(d) for d in node.waits_for)
917
+ for ctx_node in node.workflow_ctx_from:
918
+ if id(ctx_node) not in deps_ids:
919
+ errors.append(WorkflowValidationError(
920
+ message='workflow_ctx_from references task not in waits_for',
921
+ code=ErrorCode.WORKFLOW_INVALID_CTX_FROM,
922
+ notes=[
923
+ f"node '{node.name}' references '{ctx_node.name}'",
924
+ f"'{ctx_node.name}' must be in waits_for to use in workflow_ctx_from",
925
+ ],
926
+ help_text=f"add '{ctx_node.name}' to waits_for list",
927
+ ))
928
+
929
+ # Only check function parameter for TaskNode (SubWorkflowNode has no fn)
930
+ if isinstance(node, SubWorkflowNode):
931
+ continue
932
+
933
+ task = node
934
+ if not _task_accepts_workflow_ctx(task.fn):
935
+ # Get the original function for accurate source location
936
+ original_fn = getattr(task.fn, '_original_fn', task.fn)
937
+ fn_location = SourceLocation.from_function(original_fn)
938
+
939
+ errors.append(WorkflowValidationError(
940
+ message='workflow_ctx_from declared but function missing workflow_ctx param',
941
+ code=ErrorCode.WORKFLOW_CTX_PARAM_MISSING,
942
+ location=fn_location, # May be None for non-function callables
943
+ notes=[
944
+ f"workflow '{self.name}'\n"
945
+ f"TaskNode '{task.name}' declares workflow_ctx_from=[...]\n"
946
+ f"but function '{task.name}' has no workflow_ctx parameter",
947
+ ],
948
+ help_text=(
949
+ 'either:\n'
950
+ ' 1. add `workflow_ctx: WorkflowContext | None` param to the function above if it needs context\n'
951
+ ' 2. remove `workflow_ctx_from` from the TaskNode definition if this was a mistake'
952
+ ),
953
+ ))
954
+ return errors
955
+
956
+ def _collect_output_errors(self) -> list[WorkflowValidationError]:
957
+ """Validate output task is in the workflow. Returns all errors."""
958
+ errors: list[WorkflowValidationError] = []
959
+ if self.output is None:
960
+ return errors
961
+ task_ids = set(id(t) for t in self.tasks)
962
+ if id(self.output) not in task_ids:
963
+ errors.append(WorkflowValidationError(
964
+ f"Output task '{self.output.name}' is not in workflow",
965
+ ))
966
+ return errors
967
+
968
+ def _collect_success_policy_errors(self) -> list[WorkflowValidationError]:
969
+ """Validate success policy references are valid workflow tasks. Returns all errors."""
970
+ errors: list[WorkflowValidationError] = []
971
+ if self.success_policy is None:
972
+ return errors
973
+
974
+ # Validate cases list is not empty
975
+ if not self.success_policy.cases:
976
+ errors.append(WorkflowValidationError(
977
+ 'SuccessPolicy must have at least one SuccessCase',
978
+ ))
979
+ return errors
980
+
981
+ task_ids = set(id(t) for t in self.tasks)
982
+
983
+ # Validate each success case
984
+ for i, case in enumerate(self.success_policy.cases):
985
+ if not case.required:
986
+ errors.append(WorkflowValidationError(
987
+ f'SuccessCase[{i}] has no required tasks',
988
+ ))
989
+ for task in case.required:
990
+ if id(task) not in task_ids:
991
+ errors.append(WorkflowValidationError(
992
+ f"SuccessCase[{i}] required task '{task.name}' is not in workflow",
993
+ ))
994
+
995
+ # Validate optional tasks
996
+ if self.success_policy.optional:
997
+ for task in self.success_policy.optional:
998
+ if id(task) not in task_ids:
999
+ errors.append(WorkflowValidationError(
1000
+ f"SuccessPolicy optional task '{task.name}' is not in workflow",
1001
+ ))
1002
+
1003
+ return errors
1004
+
1005
+ def _collect_join_semantics_errors(self) -> list[WorkflowValidationError]:
1006
+ """Validate join and min_success settings. Returns all errors."""
1007
+ errors: list[WorkflowValidationError] = []
1008
+ for task in self.tasks:
1009
+ if task.join == 'quorum':
1010
+ if task.min_success is None:
1011
+ errors.append(WorkflowValidationError(
1012
+ f"Task '{task.name}' has join='quorum' but min_success is not set",
1013
+ ))
1014
+ elif task.min_success < 1:
1015
+ errors.append(WorkflowValidationError(
1016
+ f"Task '{task.name}' min_success must be >= 1, got {task.min_success}",
1017
+ ))
1018
+ else:
1019
+ dep_count = len(task.waits_for)
1020
+ if task.min_success > dep_count:
1021
+ errors.append(WorkflowValidationError(
1022
+ f"Task '{task.name}' min_success ({task.min_success}) exceeds "
1023
+ f'dependency count ({dep_count})',
1024
+ ))
1025
+ elif task.join in ('all', 'any'):
1026
+ if task.min_success is not None:
1027
+ errors.append(WorkflowValidationError(
1028
+ f"Task '{task.name}' has min_success set but join='{task.join}' "
1029
+ "(min_success is only used with join='quorum')",
1030
+ ))
1031
+ return errors
1032
+
1033
+ def _validate_conditions(self) -> None:
1034
+ """Validate condition callables have required context dependencies."""
1035
+ for task in self.tasks:
1036
+ has_condition = task.run_when is not None or task.skip_when is not None
1037
+ if has_condition and not task.workflow_ctx_from:
1038
+ # Conditions require context, but no context sources specified
1039
+ # This is allowed (empty context), but may cause KeyError if
1040
+ # condition tries to access dependency results
1041
+ pass # Allow - user may have conditions that don't use context
1042
+
1043
+ def _collect_subworkflow_cycle_errors(self) -> list[WorkflowValidationError]:
1044
+ """Detect cycles in nested workflow definitions. Returns all errors.
1045
+
1046
+ Prevents circular references like:
1047
+ - WorkflowA contains SubWorkflowNode(WorkflowB)
1048
+ and WorkflowB contains SubWorkflowNode(WorkflowA)
1049
+
1050
+ Uses DFS with a recursion stack to detect back-edges.
1051
+ """
1052
+ errors: list[WorkflowValidationError] = []
1053
+ visited: set[str] = set()
1054
+ stack: set[str] = set()
1055
+
1056
+ def visit(workflow_name: str, workflow_class: type[Any]) -> None:
1057
+ """DFS visit with cycle detection via recursion stack."""
1058
+ if workflow_name in stack:
1059
+ # Found a back-edge - this is a cycle
1060
+ errors.append(WorkflowValidationError(
1061
+ message='cycle detected in nested workflows',
1062
+ code=ErrorCode.WORKFLOW_CYCLE_DETECTED,
1063
+ notes=[
1064
+ f"workflow '{workflow_name}' creates a circular reference",
1065
+ 'cycles in nested workflows are not allowed',
1066
+ ],
1067
+ help_text='remove the circular SubWorkflowNode reference',
1068
+ ))
1069
+ return
1070
+
1071
+ if workflow_name in visited:
1072
+ # Already fully explored this workflow, no cycle through here
1073
+ return
1074
+
1075
+ visited.add(workflow_name)
1076
+ stack.add(workflow_name)
1077
+
1078
+ # Check all SubWorkflowNodes in this workflow's definition
1079
+ nodes = workflow_class.get_workflow_nodes()
1080
+ if nodes:
1081
+ for _, wf_node in nodes:
1082
+ if isinstance(wf_node, SubWorkflowNode):
1083
+ wf_node_any = cast(SubWorkflowNode[Any], wf_node)
1084
+ workflow_def = wf_node_any.workflow_def
1085
+ child_name: str = workflow_def.name
1086
+ visit(child_name, workflow_def)
1087
+
1088
+ # Done exploring this workflow
1089
+ stack.remove(workflow_name)
1090
+
1091
+ # Start DFS from each SubWorkflowNode in this workflow's tasks
1092
+ for node in self.tasks:
1093
+ if isinstance(node, SubWorkflowNode):
1094
+ visit(node.workflow_def.name, node.workflow_def)
1095
+
1096
+ return errors
1097
+
1098
+ def _collect_subworkflow_retry_mode_errors(self) -> list[WorkflowValidationError]:
1099
+ """Reject unsupported subworkflow retry modes. Returns all errors."""
1100
+ errors: list[WorkflowValidationError] = []
1101
+ for node in self.tasks:
1102
+ if not isinstance(node, SubWorkflowNode):
1103
+ continue
1104
+ if node.retry_mode != SubWorkflowRetryMode.RERUN_FAILED_ONLY:
1105
+ errors.append(WorkflowValidationError(
1106
+ message='unsupported SubWorkflowRetryMode',
1107
+ code=ErrorCode.WORKFLOW_INVALID_SUBWORKFLOW_RETRY_MODE,
1108
+ notes=[
1109
+ f"node '{node.name}' uses retry_mode='{node.retry_mode.value}'",
1110
+ "only 'rerun_failed_only' is supported in this release",
1111
+ ],
1112
+ help_text='use SubWorkflowRetryMode.RERUN_FAILED_ONLY',
1113
+ ))
1114
+ return errors
1115
+
1116
+ def _register_for_conditions(self) -> None:
1117
+ """Register this spec for condition evaluation at runtime."""
1118
+ # Register if any task has conditions OR any SubWorkflowNode exists
1119
+ has_conditions = any(
1120
+ t.run_when is not None or t.skip_when is not None for t in self.tasks
1121
+ )
1122
+ has_subworkflow = any(isinstance(t, SubWorkflowNode) for t in self.tasks)
1123
+ if has_conditions or has_subworkflow:
1124
+ from horsies.core.workflows.registry import register_workflow_spec
1125
+
1126
+ register_workflow_spec(self)
1127
+
1128
+ def start(self, workflow_id: str | None = None) -> 'WorkflowHandle':
1129
+ """
1130
+ Start workflow execution.
1131
+
1132
+ Args:
1133
+ workflow_id: Optional custom workflow ID. Auto-generated if not provided.
1134
+
1135
+ Returns:
1136
+ WorkflowHandle for tracking and retrieving results.
1137
+
1138
+ Raises:
1139
+ RuntimeError: If broker is not configured.
1140
+ """
1141
+ if self.broker is None:
1142
+ raise RuntimeError(
1143
+ 'WorkflowSpec requires a broker. Use app.workflow() or set broker.'
1144
+ )
1145
+
1146
+ # Import here to avoid circular imports
1147
+ from horsies.core.workflows.engine import start_workflow
1148
+
1149
+ return start_workflow(self, self.broker, workflow_id)
1150
+
1151
+ async def start_async(self, workflow_id: str | None = None) -> 'WorkflowHandle':
1152
+ """
1153
+ Start workflow execution (async).
1154
+
1155
+ Args:
1156
+ workflow_id: Optional custom workflow ID. Auto-generated if not provided.
1157
+
1158
+ Returns:
1159
+ WorkflowHandle for tracking and retrieving results.
1160
+
1161
+ Raises:
1162
+ RuntimeError: If broker is not configured.
1163
+ """
1164
+ if self.broker is None:
1165
+ raise RuntimeError(
1166
+ 'WorkflowSpec requires a broker. Use app.workflow() or set broker.'
1167
+ )
1168
+
1169
+ # Import here to avoid circular imports
1170
+ from horsies.core.workflows.engine import start_workflow_async
1171
+
1172
+ return await start_workflow_async(self, self.broker, workflow_id)
1173
+
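+ # --- Illustrative sketch (not part of the published module) -----------------
+ # End-to-end flow, assuming `app` is a configured Horsies instance and `nodes`
+ # was built as in the TaskNode examples above. app.workflow() attaches the
+ # broker so start()/start_async() can persist and enqueue the DAG.
+ async def _example_start(app: 'Horsies', nodes: list['TaskNode[Any]']) -> 'WorkflowHandle':
+     spec = app.workflow(
+         name='nightly sync',      # free-form; slugified when node_ids are derived
+         tasks=nodes,
+         output=nodes[-1],         # handle.get() will return this node's result
+         on_error=OnError.PAUSE,   # pause for manual intervention instead of failing
+     )
+     return await spec.start_async()
+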
1174
+
1175
+ # =============================================================================
1176
+ # WorkflowMeta (metadata only, no result access)
1177
+ # =============================================================================
1178
+
1179
+
1180
+ @dataclass
1181
+ class WorkflowMeta:
1182
+ """
1183
+ Workflow execution metadata.
1184
+
1185
+ Auto-injected if task declares `workflow_meta: WorkflowMeta | None` parameter.
1186
+ Contains only metadata, no result access.
1187
+
1188
+ Attributes:
1189
+ workflow_id: UUID of the workflow instance
1190
+ task_index: Index of the current task in the workflow
1191
+ task_name: Name of the current task
1192
+ """
1193
+
1194
+ workflow_id: str
1195
+ task_index: int
1196
+ task_name: str
1197
+
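+ # --- Illustrative sketch (not part of the published module) -----------------
+ # A task that opts into metadata injection by declaring a `workflow_meta`
+ # parameter. The body is a placeholder; when the task runs outside a workflow,
+ # workflow_meta stays None.
+ def _example_audit_task(
+     payload: dict[str, Any],  # ordinary task argument (placeholder)
+     workflow_meta: WorkflowMeta | None = None,
+ ) -> 'TaskResult[str, TaskError]':
+     from horsies.core.models.tasks import TaskResult
+
+     if workflow_meta is not None:
+         # e.g. tag logs or metrics with the owning workflow
+         return TaskResult(ok=f'{workflow_meta.workflow_id}:{workflow_meta.task_name}')
+     return TaskResult(ok='standalone run')
+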
1198
+
1199
+ # =============================================================================
1200
+ # WorkflowContext (type-safe result access via result_for)
1201
+ # =============================================================================
1202
+
1203
+
1204
+ class WorkflowContextMissingIdError(RuntimeError):
1205
+ """Raised when TaskNode node_id is missing for WorkflowContext.result_for()."""
1206
+
1207
+
1208
+ class WorkflowHandleMissingIdError(RuntimeError):
1209
+ """Raised when TaskNode node_id is missing for WorkflowHandle.result_for()."""
1210
+
1211
+
1212
+ class WorkflowContext(BaseModel):
1213
+ """
1214
+ Context passed to workflow tasks with type-safe access to dependency results.
1215
+
1216
+ Only injected if:
1217
+ 1. TaskNode has workflow_ctx_from set, AND
1218
+ 2. Task function declares `workflow_ctx: WorkflowContext | None` parameter
1219
+
1220
+ Use result_for(node) to access results in a type-safe manner.
1221
+ Use summary_for(node) to access SubWorkflowSummary for SubWorkflowNodes.
1222
+
1223
+ Attributes:
1224
+ workflow_id: UUID of the workflow instance
1225
+ task_index: Index of the current task in the workflow
1226
+ task_name: Name of the current task
1227
+ """
1228
+
1229
+ model_config = {'arbitrary_types_allowed': True}
1230
+
1231
+ workflow_id: str # UUID as string for JSON serialization
1232
+ task_index: int
1233
+ task_name: str
1234
+
1235
+ # Internal storage: results keyed by node_id
1236
+ _results_by_id: dict[str, Any] = {}
1237
+ # Internal storage: subworkflow summaries keyed by node_id
1238
+ _summaries_by_id: dict[str, Any] = {}
1239
+
1240
+ def __init__(
1241
+ self,
1242
+ workflow_id: str,
1243
+ task_index: int,
1244
+ task_name: str,
1245
+ results_by_id: dict[str, Any] | None = None,
1246
+ summaries_by_id: dict[str, Any] | None = None,
1247
+ **kwargs: Any,
1248
+ ) -> None:
1249
+ super().__init__(
1250
+ workflow_id=workflow_id,
1251
+ task_index=task_index,
1252
+ task_name=task_name,
1253
+ **kwargs,
1254
+ )
1255
+ # Store results internally (not exposed as Pydantic field)
1256
+ object.__setattr__(self, '_results_by_id', results_by_id or {})
1257
+ object.__setattr__(self, '_summaries_by_id', summaries_by_id or {})
1258
+
1259
+ def result_for(
1260
+ self,
1261
+ node: TaskNode[OkT] | NodeKey[OkT],
1262
+ ) -> 'TaskResult[OkT, TaskError]':
1263
+ """
1264
+ Get the result for a specific TaskNode.
1265
+
1266
+ Type-safe: returns TaskResult[T, TaskError] where T matches the node's type.
1267
+
1268
+ Args:
1269
+ node: The TaskNode or NodeKey whose result to retrieve. Must have been
1270
+ included in workflow_ctx_from and have a node_id assigned.
1271
+
1272
+ Returns:
1273
+ The TaskResult from the completed task.
1274
+
1275
+ Raises:
1276
+ KeyError: If the node's result is not in this context.
1277
+ RuntimeError: If the node has no node_id assigned.
1278
+ """
1279
+
1280
+ node_id: str | None
1281
+ if isinstance(node, NodeKey):
1282
+ node_id = node.node_id
1283
+ else:
1284
+ node_id = node.node_id
1285
+
1286
+ if node_id is None:
1287
+ raise WorkflowContextMissingIdError(
1288
+ 'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
1289
+ 'or provide an explicit node_id.'
1290
+ )
1291
+
1292
+ if node_id not in self._results_by_id:
1293
+ raise KeyError(
1294
+ f"TaskNode id '{node_id}' not in workflow context. "
1295
+ 'Ensure the node is included in workflow_ctx_from.'
1296
+ )
1297
+
1298
+ # Cast is safe because the generic parameter ensures type correctness
1299
+ return cast('TaskResult[OkT, TaskError]', self._results_by_id[node_id])
1300
+
1301
+ def has_result(self, node: TaskNode[Any] | SubWorkflowNode[Any]) -> bool:
1302
+ """Check if a result exists for the given node."""
1303
+ if node.node_id is None:
1304
+ return False
1305
+ return node.node_id in self._results_by_id
1306
+
1307
+ def summary_for(
1308
+ self,
1309
+ node: 'SubWorkflowNode[OkT]',
1310
+ ) -> 'SubWorkflowSummary[OkT]':
1311
+ """
1312
+ Get the SubWorkflowSummary for a completed SubWorkflowNode.
1313
+
1314
+ Type-safe: returns SubWorkflowSummary[T] where T matches the node's output type.
1315
+
1316
+ Args:
1317
+ node: The SubWorkflowNode whose summary to retrieve. Must have been
1318
+ included in workflow_ctx_from and have a node_id assigned.
1319
+
1320
+ Returns:
1321
+ The SubWorkflowSummary from the completed subworkflow.
1322
+
1323
+ Raises:
1324
+ KeyError: If the node's summary is not in this context.
1325
+ RuntimeError: If the node has no node_id assigned.
1326
+ """
1327
+ node_id = node.node_id
1328
+
1329
+ if node_id is None:
1330
+ raise WorkflowContextMissingIdError(
1331
+ 'SubWorkflowNode node_id is not set. Ensure WorkflowSpec assigns node_id.'
1332
+ )
1333
+
1334
+ if node_id not in self._summaries_by_id:
1335
+ raise KeyError(
1336
+ f"SubWorkflowNode id '{node_id}' not in workflow context summaries. "
1337
+ 'Ensure the node is included in workflow_ctx_from.'
1338
+ )
1339
+
1340
+ # Cast is safe because the generic parameter ensures type correctness
1341
+ return cast('SubWorkflowSummary[OkT]', self._summaries_by_id[node_id])
1342
+
1343
+ def has_summary(self, node: 'SubWorkflowNode[Any]') -> bool:
1344
+ """Check if a summary exists for the given SubWorkflowNode."""
1345
+ if node.node_id is None:
1346
+ return False
1347
+ return node.node_id in self._summaries_by_id
1348
+
1349
+ @classmethod
1350
+ def from_serialized(
1351
+ cls,
1352
+ workflow_id: str,
1353
+ task_index: int,
1354
+ task_name: str,
1355
+ results_by_id: dict[str, Any],
1356
+ summaries_by_id: dict[str, Any] | None = None,
1357
+ ) -> 'WorkflowContext':
1358
+ """Reconstruct WorkflowContext from serialized data."""
1359
+ return cls(
1360
+ workflow_id=workflow_id,
1361
+ task_index=task_index,
1362
+ task_name=task_name,
1363
+ results_by_id=results_by_id,
1364
+ summaries_by_id=summaries_by_id,
1365
+ )
1366
+
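+ # --- Illustrative sketch (not part of the published module) -----------------
+ # A downstream task that receives WorkflowContext because its TaskNode sets
+ # workflow_ctx_from and the function declares a `workflow_ctx` parameter, plus
+ # a skip_when condition built on the same context. The node_id in the NodeKey
+ # ('nightly_sync:0') is an assumed example value; in practice it comes from
+ # TaskNode.key() after WorkflowSpec construction.
+ _FETCH_KEY: 'NodeKey[dict[str, Any]]' = NodeKey('nightly_sync:0')
+
+
+ def _example_report_task(
+     workflow_ctx: WorkflowContext | None = None,
+ ) -> 'TaskResult[str, TaskError]':
+     from horsies.core.models.tasks import TaskResult, TaskError
+
+     if workflow_ctx is None:
+         return TaskResult(err=TaskError(error_code='NO_CTX'))
+     fetched = workflow_ctx.result_for(_FETCH_KEY)  # TaskResult[dict[str, Any], TaskError]
+     if fetched.is_err():
+         return TaskResult(ok='partial report')
+     return TaskResult(ok='full report')
+
+
+ def _example_skip_when(fetch_node: 'TaskNode[Any]') -> Callable[['WorkflowContext'], bool]:
+     # skip_when / run_when callables receive only the WorkflowContext; capture
+     # the node of interest in a closure so the condition can inspect its result.
+     def _skip(ctx: WorkflowContext) -> bool:
+         return ctx.has_result(fetch_node) and ctx.result_for(fetch_node).is_err()
+     return _skip
+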
1367
+
1368
+ # =============================================================================
1369
+ # WorkflowHandle
1370
+ # =============================================================================
1371
+
1372
+
1373
+ @dataclass
1374
+ class WorkflowTaskInfo:
1375
+ """Information about a task within a workflow."""
1376
+
1377
+ node_id: str | None
1378
+ index: int
1379
+ name: str
1380
+ status: WorkflowTaskStatus
1381
+ result: TaskResult[Any, TaskError] | None
1382
+ started_at: datetime | None
1383
+ completed_at: datetime | None
1384
+
1385
+
1386
+ @dataclass
1387
+ class WorkflowHandle:
1388
+ """
1389
+ Handle for tracking and retrieving workflow results.
1390
+
1391
+ Provides methods to:
1392
+ - Check workflow status
1393
+ - Wait for and retrieve results
1394
+ - Inspect individual task states
1395
+ - Cancel the workflow
1396
+ """
1397
+
1398
+ workflow_id: str
1399
+ broker: PostgresBroker
1400
+
1401
+ def status(self) -> WorkflowStatus:
1402
+ """Get current workflow status."""
1403
+
1404
+ runner = LoopRunner()
1405
+ try:
1406
+ return runner.call(self.status_async)
1407
+ finally:
1408
+ runner.stop()
1409
+
1410
+ async def status_async(self) -> WorkflowStatus:
1411
+ """Async version of status()."""
1412
+ from sqlalchemy import text
1413
+
1414
+ async with self.broker.session_factory() as session:
1415
+ result = await session.execute(
1416
+ text('SELECT status FROM horsies_workflows WHERE id = :wf_id'),
1417
+ {'wf_id': self.workflow_id},
1418
+ )
1419
+ row = result.fetchone()
1420
+ if row is None:
1421
+ raise ValueError(f'Workflow {self.workflow_id} not found')
1422
+ return WorkflowStatus(row[0])
1423
+
1424
+ def get(self, timeout_ms: int | None = None) -> TaskResult[Any, TaskError]:
1425
+ """
1426
+ Block until workflow completes or timeout.
1427
+
1428
+ Returns:
1429
+ If output task specified: that task's TaskResult
1430
+ Otherwise: TaskResult containing dict of terminal task results
1431
+ """
1432
+
1433
+ runner = LoopRunner()
1434
+ try:
1435
+ return runner.call(self.get_async, timeout_ms)
1436
+ finally:
1437
+ runner.stop()
1438
+
1439
+ async def get_async(
1440
+ self, timeout_ms: int | None = None
1441
+ ) -> TaskResult[Any, TaskError]:
1442
+ """Async version of get()."""
1443
+
1444
+ from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode
1445
+
1446
+ start = time.monotonic()
1447
+ timeout_sec = timeout_ms / 1000 if timeout_ms else None
1448
+
1449
+ while True:
1450
+ # Check current status
1451
+ status = await self.status_async()
1452
+
1453
+ if status == WorkflowStatus.COMPLETED:
1454
+ return await self._get_result()
1455
+
1456
+ if status in (WorkflowStatus.FAILED, WorkflowStatus.CANCELLED):
1457
+ return await self._get_error()
1458
+
1459
+ if status == WorkflowStatus.PAUSED:
1460
+ return TaskResult(
1461
+ err=TaskError(
1462
+ error_code='WORKFLOW_PAUSED',
1463
+ message='Workflow is paused awaiting intervention',
1464
+ )
1465
+ )
1466
+
1467
+ # Check timeout
1468
+ elapsed = time.monotonic() - start
1469
+ if timeout_sec and elapsed >= timeout_sec:
1470
+ return TaskResult(
1471
+ err=TaskError(
1472
+ error_code=LibraryErrorCode.WAIT_TIMEOUT,
1473
+ message=f'Workflow did not complete within {timeout_ms}ms',
1474
+ )
1475
+ )
1476
+
1477
+ # Wait for notification or poll
1478
+ remaining = (timeout_sec - elapsed) if timeout_sec else 5.0
1479
+ await self._wait_for_completion(min(remaining, 5.0))
1480
+
1481
+ async def _get_result(self) -> TaskResult[Any, TaskError]:
1482
+ """Fetch completed workflow result."""
1483
+ from sqlalchemy import text
1484
+
1485
+ from horsies.core.models.tasks import TaskResult
1486
+ from horsies.core.codec.serde import loads_json, task_result_from_json
1487
+
1488
+ async with self.broker.session_factory() as session:
1489
+ result = await session.execute(
1490
+ text('SELECT result FROM horsies_workflows WHERE id = :wf_id'),
1491
+ {'wf_id': self.workflow_id},
1492
+ )
1493
+ row = result.fetchone()
1494
+ if row and row[0]:
1495
+ return task_result_from_json(loads_json(row[0]))
1496
+ return TaskResult(ok=None)
1497
+
1498
+ async def _get_error(self) -> TaskResult[Any, TaskError]:
1499
+ """Fetch failed workflow error."""
1500
+ from sqlalchemy import text
1501
+
1502
+ from horsies.core.models.tasks import TaskResult, TaskError
1503
+ from horsies.core.codec.serde import loads_json
1504
+
1505
+ async with self.broker.session_factory() as session:
1506
+ result = await session.execute(
1507
+ text('SELECT error, status FROM horsies_workflows WHERE id = :wf_id'),
1508
+ {'wf_id': self.workflow_id},
1509
+ )
1510
+ row = result.fetchone()
1511
+ if row and row[0]:
1512
+ error_data = loads_json(row[0])
1513
+ if isinstance(error_data, dict):
1514
+ # Safely extract known TaskError fields with type narrowing
1515
+ raw_code = error_data.get('error_code')
1516
+ raw_msg = error_data.get('message')
1517
+ return TaskResult(
1518
+ err=TaskError(
1519
+ error_code=str(raw_code) if raw_code is not None else None,
1520
+ message=str(raw_msg) if raw_msg is not None else None,
1521
+ data=error_data.get('data'),
1522
+ )
1523
+ )
1524
+ status_str = row[1] if row else 'FAILED'
1525
+ return TaskResult(
1526
+ err=TaskError(
1527
+ error_code=f'WORKFLOW_{status_str}',
1528
+ message=f'Workflow {status_str.lower()}',
1529
+ )
1530
+ )
1531
+
1532
+ async def _wait_for_completion(self, timeout_sec: float) -> None:
1533
+ """Wait for workflow_done notification or poll interval."""
1534
+ import asyncio
1535
+
1536
+ try:
1537
+ q = await self.broker.listener.listen('workflow_done')
1538
+ try:
1539
+
1540
+ async def _wait_for_workflow() -> None:
1541
+ while True:
1542
+ note = await q.get()
1543
+ if note.payload == self.workflow_id:
1544
+ return
1545
+
1546
+ await asyncio.wait_for(_wait_for_workflow(), timeout=timeout_sec)
1547
+ finally:
1548
+ await self.broker.listener.unsubscribe('workflow_done', q)
1549
+ except asyncio.TimeoutError:
1550
+ pass # Polling fallback
1551
+
1552
+ def results(self) -> dict[str, TaskResult[Any, TaskError]]:
1553
+ """
1554
+ Get all task results keyed by unique identifier.
1555
+
1556
+ Keys are `node_id` values. If a TaskNode did not specify a node_id,
1557
+ WorkflowSpec auto-assigns one as "{workflow_name}:{task_index}".
1558
+ """
1559
+ from horsies.core.utils.loop_runner import LoopRunner
1560
+
1561
+ runner = LoopRunner()
1562
+ try:
1563
+ return runner.call(self.results_async)
1564
+ finally:
1565
+ runner.stop()
1566
+
1567
+ async def results_async(self) -> dict[str, TaskResult[Any, TaskError]]:
1568
+ """
1569
+ Async version of results().
1570
+
1571
+ Keys are `node_id` values. If a TaskNode did not specify a node_id,
1572
+ WorkflowSpec auto-assigns one as "{workflow_name}:{task_index}".
1573
+ """
1574
+ from sqlalchemy import text
1575
+
1576
+ from horsies.core.codec.serde import loads_json, task_result_from_json
1577
+
1578
+ async with self.broker.session_factory() as session:
1579
+ result = await session.execute(
1580
+ text("""
1581
+ SELECT node_id, result
1582
+ FROM horsies_workflow_tasks
1583
+ WHERE workflow_id = :wf_id
1584
+ AND result IS NOT NULL
1585
+ """),
1586
+ {'wf_id': self.workflow_id},
1587
+ )
1588
+
1589
+ return {
1590
+ row[0]: task_result_from_json(loads_json(row[1]))
1591
+ for row in result.fetchall()
1592
+ }
1593
+
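Because results_async() keys every entry by node_id (auto-assigned as "{workflow_name}:{task_index}" when a node does not set one), callers can report per-task outcomes without knowing task order. A short sketch, assuming `handle` is the workflow handle returned when the workflow was enqueued:

    # Summarize per-task outcomes via the synchronous accessor.
    for node_id, task_result in handle.results().items():
        if task_result.is_err():
            print(f'{node_id}: failed ({task_result.err.error_code})')
        else:
            print(f'{node_id}: ok')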
1594
+ def result_for(
1595
+ self, node: TaskNode[OkT] | NodeKey[OkT]
1596
+ ) -> 'TaskResult[OkT, TaskError]':
1597
+ """
1598
+ Get the result for a specific TaskNode or NodeKey.
1599
+
1600
+ Non-blocking: queries the database once and returns immediately.
1601
+
1602
+ Args:
1603
+ node: The TaskNode or NodeKey whose result to retrieve.
1604
+
1605
+ Returns:
1606
+ TaskResult[OkT, TaskError] where OkT matches the node's type.
1607
+ - If task completed: returns the task's result (success or error)
1608
+ - If task not completed: returns TaskResult with
1609
+ error_code=LibraryErrorCode.RESULT_NOT_READY
1610
+
1611
+ Raises:
1612
+ WorkflowHandleMissingIdError: If node has no node_id assigned.
1613
+
1614
+ Example:
1615
+ result = handle.result_for(node)
1616
+ if result.is_err() and result.err.error_code == LibraryErrorCode.RESULT_NOT_READY:
1617
+ # Task hasn't completed yet - wait or check later
1618
+ pass
1619
+ """
1620
+ from horsies.core.utils.loop_runner import LoopRunner
1621
+
1622
+ runner = LoopRunner()
1623
+ try:
1624
+ return runner.call(self.result_for_async, node)
1625
+ finally:
1626
+ runner.stop()
1627
+
1628
+ async def result_for_async(
1629
+ self, node: TaskNode[OkT] | NodeKey[OkT]
1630
+ ) -> 'TaskResult[OkT, TaskError]':
1631
+ """Async version of result_for(). See result_for() for full documentation."""
1632
+ from sqlalchemy import text
1633
+
1634
+ from horsies.core.codec.serde import loads_json, task_result_from_json
1635
+
1636
+ node_id: str | None
1637
+ if isinstance(node, NodeKey):
1638
+ node_id = node.node_id
1639
+ else:
1640
+ node_id = node.node_id
1641
+
1642
+ if node_id is None:
1643
+ raise WorkflowHandleMissingIdError(
1644
+ 'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
1645
+ 'or provide an explicit node_id.'
1646
+ )
1647
+
1648
+ async with self.broker.session_factory() as session:
1649
+ result = await session.execute(
1650
+ text("""
1651
+ SELECT result
1652
+ FROM horsies_workflow_tasks
1653
+ WHERE workflow_id = :wf_id
1654
+ AND node_id = :node_id
1655
+ AND result IS NOT NULL
1656
+ """),
1657
+ {'wf_id': self.workflow_id, 'node_id': node_id},
1658
+ )
1659
+ row = result.fetchone()
1660
+ if row is None or row[0] is None:
1661
+ from horsies.core.models.tasks import (
1662
+ TaskResult,
1663
+ TaskError,
1664
+ LibraryErrorCode,
1665
+ )
1666
+
1667
+ return cast(
1668
+ 'TaskResult[OkT, TaskError]',
1669
+ TaskResult(
1670
+ err=TaskError(
1671
+ error_code=LibraryErrorCode.RESULT_NOT_READY,
1672
+ message=(
1673
+ f"Task '{node_id}' has not completed yet "
1674
+ f"in workflow '{self.workflow_id}'"
1675
+ ),
1676
+ )
1677
+ ),
1678
+ )
1679
+
1680
+ return cast(
1681
+ 'TaskResult[OkT, TaskError]',
1682
+ task_result_from_json(loads_json(row[0])),
1683
+ )
1684
+
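result_for_async() casts the stored JSON back to TaskResult[OkT, TaskError], so the ok value keeps the node's declared type at the call site. A brief sketch; the `fetch` node is illustrative only, and the RESULT_NOT_READY handling mirrors the docstring above:

    from horsies.core.models.tasks import LibraryErrorCode

    result = await handle.result_for_async(fetch)  # fetch: a TaskNode defined elsewhere
    if result.is_err():
        if result.err.error_code == LibraryErrorCode.RESULT_NOT_READY:
            pass  # task has not finished yet; check again later
    else:
        payload = result.ok  # typed as the fetch node's OkT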
1685
+ def tasks(self) -> list[WorkflowTaskInfo]:
1686
+ """Get status of all tasks in workflow."""
1687
+ from horsies.core.utils.loop_runner import LoopRunner
1688
+
1689
+ runner = LoopRunner()
1690
+ try:
1691
+ return runner.call(self.tasks_async)
1692
+ finally:
1693
+ runner.stop()
1694
+
1695
+ async def tasks_async(self) -> list[WorkflowTaskInfo]:
1696
+ """Async version of tasks()."""
1697
+ from sqlalchemy import text
1698
+
1699
+ from horsies.core.codec.serde import loads_json, task_result_from_json
1700
+
1701
+ async with self.broker.session_factory() as session:
1702
+ result = await session.execute(
1703
+ text("""
1704
+ SELECT node_id, task_index, task_name, status, result, started_at, completed_at
1705
+ FROM horsies_workflow_tasks
1706
+ WHERE workflow_id = :wf_id
1707
+ ORDER BY task_index
1708
+ """),
1709
+ {'wf_id': self.workflow_id},
1710
+ )
1711
+
1712
+ return [
1713
+ WorkflowTaskInfo(
1714
+ node_id=row[0],
1715
+ index=row[1],
1716
+ name=row[2],
1717
+ status=WorkflowTaskStatus(row[3]),
1718
+ result=task_result_from_json(loads_json(row[4]))
1719
+ if row[4]
1720
+ else None,
1721
+ started_at=row[5],
1722
+ completed_at=row[6],
1723
+ )
1724
+ for row in result.fetchall()
1725
+ ]
1726
+
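WorkflowTaskInfo rows come back ordered by task_index, which makes a quick progress readout trivial; the field names below match the constructor above:

    infos = handle.tasks()
    done = sum(1 for t in infos if t.completed_at is not None)
    print(f'{done}/{len(infos)} tasks finished')
    for t in infos:
        print(t.index, t.node_id, t.name, t.status.value)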
1727
+ def cancel(self) -> None:
1728
+ """Request workflow cancellation."""
1729
+ from horsies.core.utils.loop_runner import LoopRunner
1730
+
1731
+ runner = LoopRunner()
1732
+ try:
1733
+ runner.call(self.cancel_async)
1734
+ finally:
1735
+ runner.stop()
1736
+
1737
+ async def cancel_async(self) -> None:
1738
+ """Async version of cancel()."""
1739
+ from sqlalchemy import text
1740
+
1741
+ async with self.broker.session_factory() as session:
1742
+ # Cancel workflow
1743
+ await session.execute(
1744
+ text("""
1745
+ UPDATE horsies_workflows
1746
+ SET status = 'CANCELLED', updated_at = NOW()
1747
+ WHERE id = :wf_id AND status IN ('PENDING', 'RUNNING', 'PAUSED')
1748
+ """),
1749
+ {'wf_id': self.workflow_id},
1750
+ )
1751
+
1752
+ # Skip pending/ready tasks
1753
+ await session.execute(
1754
+ text("""
1755
+ UPDATE horsies_workflow_tasks
1756
+ SET status = 'SKIPPED'
1757
+ WHERE workflow_id = :wf_id AND status IN ('PENDING', 'READY')
1758
+ """),
1759
+ {'wf_id': self.workflow_id},
1760
+ )
1761
+
1762
+ await session.commit()
1763
+
1764
+ def pause(self) -> bool:
1765
+ """
1766
+ Pause a running workflow.
1767
+
1768
+ Transitions workflow from RUNNING to PAUSED state. Already-running tasks
1769
+ will continue to completion, but no new tasks will be enqueued.
1770
+
1771
+ Use resume() to continue execution.
1772
+
1773
+ Returns:
1774
+ True if workflow was paused, False if not RUNNING (no-op)
1775
+ """
1776
+ from horsies.core.utils.loop_runner import LoopRunner
1777
+
1778
+ runner = LoopRunner()
1779
+ try:
1780
+ return runner.call(self.pause_async)
1781
+ finally:
1782
+ runner.stop()
1783
+
1784
+ async def pause_async(self) -> bool:
1785
+ """
1786
+ Async version of pause().
1787
+
1788
+ Returns:
1789
+ True if workflow was paused, False if not RUNNING (no-op)
1790
+ """
1791
+ from horsies.core.workflows.engine import pause_workflow
1792
+
1793
+ return await pause_workflow(self.broker, self.workflow_id)
1794
+
1795
+ def resume(self) -> bool:
1796
+ """
1797
+ Resume a paused workflow.
1798
+
1799
+ Re-evaluates all PENDING tasks (marks READY if deps are terminal) and
1800
+ enqueues all READY tasks. Only works if workflow is currently PAUSED.
1801
+
1802
+ Returns:
1803
+ True if workflow was resumed, False if not PAUSED (no-op)
1804
+ """
1805
+ from horsies.core.utils.loop_runner import LoopRunner
1806
+
1807
+ runner = LoopRunner()
1808
+ try:
1809
+ return runner.call(self.resume_async)
1810
+ finally:
1811
+ runner.stop()
1812
+
1813
+ async def resume_async(self) -> bool:
1814
+ """
1815
+ Async version of resume().
1816
+
1817
+ Returns:
1818
+ True if workflow was resumed, False if not PAUSED (no-op)
1819
+ """
1820
+ from horsies.core.workflows.engine import resume_workflow
1821
+
1822
+ return await resume_workflow(self.broker, self.workflow_id)
1823
+
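pause() and resume() only flip workflow state: in-flight tasks run to completion, and resume() re-evaluates PENDING tasks before enqueuing anything READY. A round-trip sketch; both calls return False when the workflow is not in the expected state:

    if handle.pause():
        # ... operator intervention: fix inputs, rotate credentials, etc. ...
        handle.resume()  # no-op (returns False) if the workflow is no longer PAUSED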
1824
+
1825
+ # =============================================================================
1826
+ # WorkflowDefinition (class-based workflow definition)
1827
+ # =============================================================================
1828
+
1829
+
1830
+ class WorkflowDefinitionMeta(type):
1831
+ """
1832
+ Metaclass for WorkflowDefinition that preserves attribute order.
1833
+
1834
+ Collects TaskNode and SubWorkflowNode instances from class attributes
1835
+ in definition order.
1836
+ """
1837
+
1838
+ def __new__(
1839
+ mcs,
1840
+ name: str,
1841
+ bases: tuple[type, ...],
1842
+ namespace: dict[str, Any],
1843
+ ) -> 'WorkflowDefinitionMeta':
1844
+ cls = super().__new__(mcs, name, bases, namespace)
1845
+
1846
+ # Skip processing for the base class itself
1847
+ if name == 'WorkflowDefinition':
1848
+ return cls
1849
+
1850
+ # Collect TaskNode and SubWorkflowNode instances in definition order
1851
+ nodes: list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]] = []
1852
+ for attr_name, attr_value in namespace.items():
1853
+ if isinstance(attr_value, (TaskNode, SubWorkflowNode)):
1854
+ nodes.append((attr_name, attr_value))
1855
+
1856
+ # Store the collected nodes on the class
1857
+ cls._workflow_nodes = nodes # type: ignore[attr-defined]
1858
+
1859
+ return cls
1860
+
1861
+
1862
+ class WorkflowDefinition(Generic[OkT_co], metaclass=WorkflowDefinitionMeta):
1863
+ """
1864
+ Base class for declarative workflow definitions.
1865
+
1866
+ Generic parameter OkT_co represents the workflow's output type, derived from
1867
+ Meta.output task's return type.
1868
+
1869
+ Provides a class-based alternative to app.workflow() for defining workflows.
1870
+ TaskNode and SubWorkflowNode instances defined as class attributes are
1871
+ automatically collected and used to build a WorkflowSpec.
1872
+
1873
+ Example:
1874
+ class ScrapeWorkflow(WorkflowDefinition[PersistResult]):
1875
+ name = "scrape_pipeline"
1876
+
1877
+ fetch = TaskNode(fn=fetch_listing, args=("url",))
1878
+ parse = TaskNode(fn=parse_listing, waits_for=[fetch], args_from={"raw": fetch})
1879
+ persist = TaskNode(fn=persist_listing, waits_for=[parse], args_from={"data": parse})
1880
+
1881
+ class Meta:
1882
+ output = persist # Output type is PersistResult
1883
+ on_error = OnError.FAIL
1884
+
1885
+ spec = ScrapeWorkflow.build(app)
1886
+
1887
+ Attributes:
1888
+ name: Required workflow name (class attribute).
1889
+ Meta: Optional inner class for workflow configuration.
1890
+ - output: TaskNode/SubWorkflowNode to use as workflow output (default: None)
1891
+ - on_error: Error handling policy (default: OnError.FAIL)
1892
+ - success_policy: Custom success policy (default: None)
1893
+ """
1894
+
1895
+ # Class attributes to be defined by subclasses
1896
+ name: ClassVar[str]
1897
+
1898
+ # Populated by metaclass
1899
+ _workflow_nodes: ClassVar[list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]]]
1900
+
1901
+ @classmethod
1902
+ def get_workflow_nodes(
1903
+ cls,
1904
+ ) -> list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]]:
1905
+ """Return collected workflow nodes or an empty list if none were defined."""
1906
+ nodes = getattr(cls, '_workflow_nodes', None)
1907
+ if not isinstance(nodes, list):
1908
+ return []
1909
+ return cast(list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]], nodes)
1910
+
1911
+ @classmethod
1912
+ def build(cls, app: 'Horsies') -> WorkflowSpec:
1913
+ """
1914
+ Build a WorkflowSpec from this workflow definition.
1915
+
1916
+ Collects all TaskNode class attributes, assigns node_ids from attribute
1917
+ names, and creates a WorkflowSpec with the configured options.
1918
+
1919
+ Args:
1920
+ app: Horsies application instance (provides broker).
1921
+
1922
+ Returns:
1923
+ WorkflowSpec ready for execution.
1924
+
1925
+ Raises:
1926
+ WorkflowValidationError: If workflow definition is invalid.
1927
+ """
1928
+ # Validate name is defined
1929
+ if not hasattr(cls, 'name') or not cls.name:
1930
+ raise WorkflowValidationError(
1931
+ f"WorkflowDefinition '{cls.__name__}' must define a 'name' class attribute"
1932
+ )
1933
+
1934
+ # Get collected nodes from metaclass
1935
+ nodes = cls.get_workflow_nodes()
1936
+ if not nodes:
1937
+ raise WorkflowValidationError(
1938
+ f"WorkflowDefinition '{cls.__name__}' has no TaskNode attributes"
1939
+ )
1940
+
1941
+ # Assign node_id from attribute name (if not already set)
1942
+ for attr_name, node in nodes:
1943
+ if node.node_id is None:
1944
+ node.node_id = attr_name
1945
+
1946
+ # Extract task list (preserving definition order)
1947
+ tasks = [node for _, node in nodes]
1948
+
1949
+ # Get Meta configuration
1950
+ output: TaskNode[Any] | SubWorkflowNode[Any] | None = None
1951
+ on_error: OnError = OnError.FAIL
1952
+ success_policy: SuccessPolicy | None = None
1953
+
1954
+ meta: type[Any] | None = getattr(cls, 'Meta', None)
1955
+ if meta is not None:
1956
+ output = getattr(meta, 'output', None)
1957
+ on_error = getattr(meta, 'on_error', OnError.FAIL)
1958
+ success_policy = getattr(meta, 'success_policy', None)
1959
+
1960
+ # Build WorkflowSpec
1961
+ spec = app.workflow(
1962
+ name=cls.name,
1963
+ tasks=tasks,
1964
+ output=output,
1965
+ on_error=on_error,
1966
+ success_policy=success_policy,
1967
+ )
1968
+ spec.workflow_def_module = cls.__module__
1969
+ spec.workflow_def_qualname = cls.__qualname__
1970
+ return spec
1971
+
1972
+ @classmethod
1973
+ def build_with(
1974
+ cls,
1975
+ app: 'Horsies',
1976
+ *args: Any,
1977
+ **params: Any,
1978
+ ) -> WorkflowSpec:
1979
+ """
1980
+ Build a WorkflowSpec with runtime parameters.
1981
+
1982
+ Subclasses can override this to apply params to TaskNodes.
1983
+ Default implementation forwards to build().
1984
+ """
1985
+ _ = args
1986
+ _ = params
1987
+ spec = cls.build(app)
1988
+ spec.workflow_def_module = cls.__module__
1989
+ spec.workflow_def_qualname = cls.__qualname__
1990
+ return spec
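The default build_with() discards its runtime parameters, so a subclass that needs per-run inputs can override it and adjust nodes before delegating to build(). A hedged sketch that reuses the ScrapeWorkflow example from the class docstring; mutating `fetch.args` assumes TaskNode keeps its positional arguments in an `args` attribute, which this excerpt does not confirm:

    class ParamScrapeWorkflow(ScrapeWorkflow):
        @classmethod
        def build_with(cls, app, *args, **params):
            # Hypothetical: inject the runtime URL into the fetch node's args.
            cls.fetch.args = (params['url'],)
            return cls.build(app)

    spec = ParamScrapeWorkflow.build_with(app, url='https://example.com/listing/42')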