horsies-0.1.0a1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- horsies/__init__.py +115 -0
- horsies/core/__init__.py +0 -0
- horsies/core/app.py +552 -0
- horsies/core/banner.py +144 -0
- horsies/core/brokers/__init__.py +5 -0
- horsies/core/brokers/listener.py +444 -0
- horsies/core/brokers/postgres.py +864 -0
- horsies/core/cli.py +624 -0
- horsies/core/codec/serde.py +575 -0
- horsies/core/errors.py +535 -0
- horsies/core/logging.py +90 -0
- horsies/core/models/__init__.py +0 -0
- horsies/core/models/app.py +268 -0
- horsies/core/models/broker.py +79 -0
- horsies/core/models/queues.py +23 -0
- horsies/core/models/recovery.py +101 -0
- horsies/core/models/schedule.py +229 -0
- horsies/core/models/task_pg.py +307 -0
- horsies/core/models/tasks.py +332 -0
- horsies/core/models/workflow.py +1988 -0
- horsies/core/models/workflow_pg.py +245 -0
- horsies/core/registry/tasks.py +101 -0
- horsies/core/scheduler/__init__.py +26 -0
- horsies/core/scheduler/calculator.py +267 -0
- horsies/core/scheduler/service.py +569 -0
- horsies/core/scheduler/state.py +260 -0
- horsies/core/task_decorator.py +615 -0
- horsies/core/types/status.py +38 -0
- horsies/core/utils/imports.py +203 -0
- horsies/core/utils/loop_runner.py +44 -0
- horsies/core/worker/current.py +17 -0
- horsies/core/worker/worker.py +1967 -0
- horsies/core/workflows/__init__.py +23 -0
- horsies/core/workflows/engine.py +2344 -0
- horsies/core/workflows/recovery.py +501 -0
- horsies/core/workflows/registry.py +97 -0
- horsies/py.typed +0 -0
- horsies-0.1.0a1.dist-info/METADATA +31 -0
- horsies-0.1.0a1.dist-info/RECORD +42 -0
- horsies-0.1.0a1.dist-info/WHEEL +5 -0
- horsies-0.1.0a1.dist-info/entry_points.txt +2 -0
- horsies-0.1.0a1.dist-info/top_level.txt +1 -0
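The largest addition is horsies/core/models/workflow.py, diffed below. For orientation, a minimal usage sketch assembled from the docstring examples in that file follows; the fetch_data/process_data task functions and the pre-configured `app` object are assumptions for illustration and are not part of this diff:

# Hypothetical sketch based on the TaskNode and WorkflowSpec docstrings below.
# Assumes `app` is a configured Horsies application and that fetch_data/process_data
# are registered task functions; neither is defined in this diff.
from horsies.core.models.workflow import TaskNode, OnError

fetch = TaskNode(fn=fetch_data, args=("https://example.com/data",))
process = TaskNode(
    fn=process_data,
    waits_for=[fetch],         # run only after fetch is terminal
    args_from={"raw": fetch},  # inject fetch's TaskResult as the 'raw' kwarg
)

spec = app.workflow(
    name="my_pipeline",
    tasks=[fetch, process],
    output=process,            # WorkflowHandle.get() returns this node's result
    on_error=OnError.PAUSE,    # pause for manual intervention on failure
)
handle = await spec.start_async()  # or spec.start() outside async code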
@@ -0,0 +1,1988 @@
|
|
|
1
|
+
"""Core workflow models for DAG-based task orchestration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
import re
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import (
|
|
11
|
+
TYPE_CHECKING,
|
|
12
|
+
Any,
|
|
13
|
+
TypeVar,
|
|
14
|
+
Generic,
|
|
15
|
+
cast,
|
|
16
|
+
Callable,
|
|
17
|
+
Literal,
|
|
18
|
+
ClassVar,
|
|
19
|
+
)
|
|
20
|
+
import inspect
|
|
21
|
+
import time
|
|
22
|
+
from horsies.core.utils.loop_runner import LoopRunner
|
|
23
|
+
from horsies.core.errors import ErrorCode, SourceLocation, ValidationReport, raise_collected
|
|
24
|
+
from pydantic import BaseModel
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from horsies.core.task_decorator import TaskFunction
|
|
28
|
+
from horsies.core.brokers.postgres import PostgresBroker
|
|
29
|
+
from horsies.core.models.tasks import TaskResult, TaskError
|
|
30
|
+
from horsies.core.app import Horsies
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# TypeVar for TaskNode generic parameter (the "ok" type of TaskResult)
|
|
34
|
+
OkT = TypeVar('OkT')
|
|
35
|
+
OkT_co = TypeVar('OkT_co', covariant=True)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# =============================================================================
|
|
39
|
+
# Enums
|
|
40
|
+
# =============================================================================
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class WorkflowStatus(str, Enum):
|
|
44
|
+
"""
|
|
45
|
+
Status of a workflow instance.
|
|
46
|
+
|
|
47
|
+
State machine:
|
|
48
|
+
PENDING → RUNNING → COMPLETED
|
|
49
|
+
→ FAILED (on task failure with on_error=FAIL)
|
|
50
|
+
→ PAUSED (on task failure with on_error=PAUSE)
|
|
51
|
+
→ CANCELLED (user requested)
|
|
52
|
+
"""
|
|
53
|
+
|
|
54
|
+
PENDING = 'PENDING'
|
|
55
|
+
"""Created but not yet started"""
|
|
56
|
+
|
|
57
|
+
RUNNING = 'RUNNING'
|
|
58
|
+
"""At least one task executing or ready"""
|
|
59
|
+
|
|
60
|
+
COMPLETED = 'COMPLETED'
|
|
61
|
+
"""All tasks terminal and success criteria met"""
|
|
62
|
+
|
|
63
|
+
FAILED = 'FAILED'
|
|
64
|
+
"""A task failed and on_error=FAIL (or no success case satisfied)"""
|
|
65
|
+
|
|
66
|
+
PAUSED = 'PAUSED'
|
|
67
|
+
"""A task failed with on_error=PAUSE; awaiting resume() or cancel()"""
|
|
68
|
+
|
|
69
|
+
CANCELLED = 'CANCELLED'
|
|
70
|
+
"""User cancelled via WorkflowHandle.cancel()"""
|
|
71
|
+
|
|
72
|
+
@property
|
|
73
|
+
def is_terminal(self) -> bool:
|
|
74
|
+
"""Whether this status represents a final state (no further transitions)."""
|
|
75
|
+
return self in WORKFLOW_TERMINAL_STATES
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
WORKFLOW_TERMINAL_STATES: frozenset[WorkflowStatus] = frozenset({
|
|
79
|
+
WorkflowStatus.COMPLETED,
|
|
80
|
+
WorkflowStatus.FAILED,
|
|
81
|
+
WorkflowStatus.CANCELLED,
|
|
82
|
+
})
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class WorkflowTaskStatus(str, Enum):
|
|
86
|
+
"""
|
|
87
|
+
Status of a single task/node within a workflow.
|
|
88
|
+
|
|
89
|
+
State machine:
|
|
90
|
+
PENDING → READY → ENQUEUED → RUNNING → COMPLETED
|
|
91
|
+
→ FAILED
|
|
92
|
+
→ SKIPPED (deps failed and allow_failed_deps=False)
|
|
93
|
+
"""
|
|
94
|
+
|
|
95
|
+
PENDING = 'PENDING'
|
|
96
|
+
"""Waiting for dependencies to become terminal"""
|
|
97
|
+
|
|
98
|
+
READY = 'READY'
|
|
99
|
+
"""Dependencies satisfied, waiting to be enqueued"""
|
|
100
|
+
|
|
101
|
+
ENQUEUED = 'ENQUEUED'
|
|
102
|
+
"""Task created in tasks table, waiting for worker"""
|
|
103
|
+
|
|
104
|
+
RUNNING = 'RUNNING'
|
|
105
|
+
"""Worker is executing the task (or child workflow is running)"""
|
|
106
|
+
|
|
107
|
+
COMPLETED = 'COMPLETED'
|
|
108
|
+
"""Task succeeded (TaskResult.is_ok())"""
|
|
109
|
+
|
|
110
|
+
FAILED = 'FAILED'
|
|
111
|
+
"""Task failed (TaskResult.is_err())"""
|
|
112
|
+
|
|
113
|
+
SKIPPED = 'SKIPPED'
|
|
114
|
+
"""Skipped due to: upstream failure, condition, or quorum impossible"""
|
|
115
|
+
|
|
116
|
+
@property
|
|
117
|
+
def is_terminal(self) -> bool:
|
|
118
|
+
"""Whether this status represents a final state (no further transitions)."""
|
|
119
|
+
return self in WORKFLOW_TASK_TERMINAL_STATES
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
WORKFLOW_TASK_TERMINAL_STATES: frozenset[WorkflowTaskStatus] = frozenset({
|
|
123
|
+
WorkflowTaskStatus.COMPLETED,
|
|
124
|
+
WorkflowTaskStatus.FAILED,
|
|
125
|
+
WorkflowTaskStatus.SKIPPED,
|
|
126
|
+
})
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class OnError(str, Enum):
|
|
130
|
+
"""
|
|
131
|
+
Error handling policy for workflows when a task fails.
|
|
132
|
+
"""
|
|
133
|
+
|
|
134
|
+
FAIL = 'fail'
|
|
135
|
+
"""Continue DAG resolution but mark workflow as will-fail. Skip tasks without allow_failed_deps."""
|
|
136
|
+
|
|
137
|
+
PAUSE = 'pause'
|
|
138
|
+
"""Pause workflow immediately. No new tasks enqueued until resume()."""
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class SubWorkflowRetryMode(str, Enum):
|
|
142
|
+
"""
|
|
143
|
+
Retry behavior for subworkflows (only RERUN_FAILED_ONLY currently supported).
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
RERUN_FAILED_ONLY = 'rerun_failed_only'
|
|
147
|
+
"""Re-run only failed/cancelled child tasks (default, only supported mode)"""
|
|
148
|
+
|
|
149
|
+
RERUN_ALL = 'rerun_all'
|
|
150
|
+
"""Re-run entire child workflow from scratch (not yet implemented)"""
|
|
151
|
+
|
|
152
|
+
NO_RERUN = 'no_rerun'
|
|
153
|
+
"""Re-evaluate success policy without re-running (not yet implemented)"""
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
# =============================================================================
|
|
157
|
+
# SubWorkflowSummary
|
|
158
|
+
# =============================================================================
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@dataclass
|
|
162
|
+
class SubWorkflowSummary(Generic[OkT_co]):
|
|
163
|
+
"""
|
|
164
|
+
Summary of a child workflow's execution, available via WorkflowContext.summary_for().
|
|
165
|
+
|
|
166
|
+
Provides visibility into child workflow health without exposing internal DAG.
|
|
167
|
+
Useful for conditional logic based on partial success/failure.
|
|
168
|
+
|
|
169
|
+
Example:
|
|
170
|
+
def make_report(workflow_ctx: WorkflowContext | None) -> TaskResult[str, TaskError]:
|
|
171
|
+
if workflow_ctx is None:
|
|
172
|
+
return TaskResult(err=TaskError(error_code="NO_CTX"))
|
|
173
|
+
summary = workflow_ctx.summary_for(data_pipeline_node)
|
|
174
|
+
if summary.failed_tasks > 0:
|
|
175
|
+
return TaskResult(ok=f"Partial: {summary.completed_tasks}/{summary.total_tasks}")
|
|
176
|
+
return TaskResult(ok=f"Full: {summary.output}")
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
status: WorkflowStatus
|
|
180
|
+
"""Child workflow's final status (COMPLETED, FAILED, etc.)"""
|
|
181
|
+
|
|
182
|
+
success_case: str | None
|
|
183
|
+
"""Which SuccessCase was satisfied (if success_policy used)"""
|
|
184
|
+
|
|
185
|
+
output: OkT_co | None
|
|
186
|
+
"""Child's output value (typed via generic parameter)"""
|
|
187
|
+
|
|
188
|
+
total_tasks: int
|
|
189
|
+
"""Total number of tasks in child workflow"""
|
|
190
|
+
|
|
191
|
+
completed_tasks: int
|
|
192
|
+
"""Number of COMPLETED tasks"""
|
|
193
|
+
|
|
194
|
+
failed_tasks: int
|
|
195
|
+
"""Number of FAILED tasks"""
|
|
196
|
+
|
|
197
|
+
skipped_tasks: int
|
|
198
|
+
"""Number of SKIPPED tasks"""
|
|
199
|
+
|
|
200
|
+
error_summary: str | None = None
|
|
201
|
+
"""Brief description of failure (if child failed)"""
|
|
202
|
+
|
|
203
|
+
@classmethod
|
|
204
|
+
def from_json(cls, data: dict[str, Any]) -> 'SubWorkflowSummary[Any]':
|
|
205
|
+
"""Build a SubWorkflowSummary from a JSON-like dict safely."""
|
|
206
|
+
status_val = data.get('status')
|
|
207
|
+
try:
|
|
208
|
+
status = WorkflowStatus(str(status_val))
|
|
209
|
+
except Exception:
|
|
210
|
+
status = WorkflowStatus.FAILED
|
|
211
|
+
|
|
212
|
+
success_case_val = data.get('success_case')
|
|
213
|
+
total_val = data.get('total_tasks', 0)
|
|
214
|
+
completed_val = data.get('completed_tasks', 0)
|
|
215
|
+
failed_val = data.get('failed_tasks', 0)
|
|
216
|
+
skipped_val = data.get('skipped_tasks', 0)
|
|
217
|
+
error_val = data.get('error_summary')
|
|
218
|
+
|
|
219
|
+
return cls(
|
|
220
|
+
status=status,
|
|
221
|
+
success_case=str(success_case_val) if success_case_val else None,
|
|
222
|
+
output=data.get('output'),
|
|
223
|
+
total_tasks=int(total_val) if isinstance(total_val, (int, float)) else 0,
|
|
224
|
+
completed_tasks=int(completed_val)
|
|
225
|
+
if isinstance(completed_val, (int, float))
|
|
226
|
+
else 0,
|
|
227
|
+
failed_tasks=int(failed_val) if isinstance(failed_val, (int, float)) else 0,
|
|
228
|
+
skipped_tasks=int(skipped_val)
|
|
229
|
+
if isinstance(skipped_val, (int, float))
|
|
230
|
+
else 0,
|
|
231
|
+
error_summary=str(error_val) if error_val else None,
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# =============================================================================
|
|
236
|
+
# Exceptions
|
|
237
|
+
# =============================================================================
|
|
238
|
+
|
|
239
|
+
# Re-export from errors module for backward compatibility
|
|
240
|
+
from horsies.core.errors import WorkflowValidationError
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _task_accepts_workflow_ctx(fn: Callable[..., Any]) -> bool:
|
|
244
|
+
inspect_target: Callable[..., Any] = fn
|
|
245
|
+
original = getattr(fn, '_original_fn', None)
|
|
246
|
+
if callable(original):
|
|
247
|
+
inspect_target = original
|
|
248
|
+
try:
|
|
249
|
+
sig = inspect.signature(inspect_target)
|
|
250
|
+
except (TypeError, ValueError):
|
|
251
|
+
return False
|
|
252
|
+
return 'workflow_ctx' in sig.parameters
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
NODE_ID_PATTERN = re.compile(r'^[A-Za-z0-9_\-:.]+$')
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def slugify(value: str) -> str:
|
|
259
|
+
"""
|
|
260
|
+
Convert a string to a valid node_id by replacing invalid characters.
|
|
261
|
+
|
|
262
|
+
Spaces become underscores, other invalid characters are removed.
|
|
263
|
+
Result matches NODE_ID_PATTERN: [A-Za-z0-9_\\-:.]+
|
|
264
|
+
|
|
265
|
+
Example:
|
|
266
|
+
slugify("My Workflow Name") # "My_Workflow_Name"
|
|
267
|
+
slugify("task:v2.0") # "task:v2.0" (unchanged)
|
|
268
|
+
"""
|
|
269
|
+
result = value.replace(' ', '_')
|
|
270
|
+
result = re.sub(r'[^A-Za-z0-9_\-:.]', '', result)
|
|
271
|
+
return result or '_'
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# =============================================================================
|
|
275
|
+
# TaskNode
|
|
276
|
+
# =============================================================================
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@dataclass
|
|
280
|
+
class TaskNode(Generic[OkT_co]):
|
|
281
|
+
"""
|
|
282
|
+
A node in the workflow DAG representing a single task execution.
|
|
283
|
+
|
|
284
|
+
Generic parameter OkT represents the success type of the task's TaskResult.
|
|
285
|
+
This enables type-safe access to results via WorkflowContext.result_for(node).
|
|
286
|
+
|
|
287
|
+
Example:
|
|
288
|
+
```python
|
|
289
|
+
fetch = TaskNode(fn=fetch_data, args=("url",))
|
|
290
|
+
process = TaskNode(
|
|
291
|
+
fn=process_data,
|
|
292
|
+
waits_for=[fetch], # wait for fetch to be terminal
|
|
293
|
+
args_from={"raw": fetch}, # inject fetch's TaskResult as 'raw' kwarg
|
|
294
|
+
allow_failed_deps=True, # run even if fetch failed
|
|
295
|
+
)
|
|
296
|
+
```
|
|
297
|
+
"""
|
|
298
|
+
|
|
299
|
+
fn: TaskFunction[Any, OkT_co]
|
|
300
|
+
args: tuple[Any, ...] = ()
|
|
301
|
+
kwargs: dict[str, Any] = field(default_factory=lambda: {})
|
|
302
|
+
waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] = field(
|
|
303
|
+
default_factory=lambda: [],
|
|
304
|
+
)
|
|
305
|
+
"""
|
|
306
|
+
- List of TaskNodes or SubWorkflowNodes that this task waits for
|
|
307
|
+
- The node with the dependencies will wait for all dependencies to be terminal (COMPLETED/FAILED/SKIPPED)
|
|
308
|
+
"""
|
|
309
|
+
|
|
310
|
+
args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] = field(
|
|
311
|
+
default_factory=lambda: {},
|
|
312
|
+
)
|
|
313
|
+
"""
|
|
314
|
+
- Data flow: inject dependency TaskResults as kwargs (keyword-only)
|
|
315
|
+
- Example: args_from={"validated": validate_node, "transformed": transform_node}
|
|
316
|
+
- Task receives: def my_task(validated: TaskResult[A, E], transformed: TaskResult[B, E])
|
|
317
|
+
"""
|
|
318
|
+
|
|
319
|
+
workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None
|
|
320
|
+
"""
|
|
321
|
+
- Optional context: subset of dependencies to include in WorkflowContext
|
|
322
|
+
- Only injected if task declares `workflow_ctx: WorkflowContext | None` parameter
|
|
323
|
+
|
|
324
|
+
"""
|
|
325
|
+
# Queue/priority overrides (if None, use task decorator defaults)
|
|
326
|
+
queue: str | None = None
|
|
327
|
+
"""
|
|
328
|
+
- Queue overrides (if None, use task decorator defaults)
|
|
329
|
+
"""
|
|
330
|
+
priority: int | None = None
|
|
331
|
+
"""
|
|
332
|
+
- Priority overrides (if None, use task decorator defaults)
|
|
333
|
+
"""
|
|
334
|
+
|
|
335
|
+
allow_failed_deps: bool = False
|
|
336
|
+
"""
|
|
337
|
+
- If True, this task runs even if dependencies failed (receives failed TaskResults)
|
|
338
|
+
- If False (default), task is SKIPPED when any dependency fails
|
|
339
|
+
"""
|
|
340
|
+
|
|
341
|
+
run_when: Callable[['WorkflowContext'], bool] | None = field(
|
|
342
|
+
default=None, repr=False
|
|
343
|
+
)
|
|
344
|
+
"""
|
|
345
|
+
- Conditional execution: evaluated after deps are terminal, before enqueue
|
|
346
|
+
- skip_when has priority over run_when
|
|
347
|
+
- Callables receive WorkflowContext built from workflow_ctx_from
|
|
348
|
+
"""
|
|
349
|
+
|
|
350
|
+
skip_when: Callable[['WorkflowContext'], bool] | None = field(
|
|
351
|
+
default=None, repr=False
|
|
352
|
+
)
|
|
353
|
+
"""
|
|
354
|
+
- Conditional execution: evaluated after deps are terminal, before enqueue
|
|
355
|
+
- skip_when has priority over run_when
|
|
356
|
+
- Callables receive WorkflowContext built from workflow_ctx_from
|
|
357
|
+
"""
|
|
358
|
+
|
|
359
|
+
join: Literal['all', 'any', 'quorum'] = 'all'
|
|
360
|
+
"""
|
|
361
|
+
- Dependency join semantics
|
|
362
|
+
- "all": task runs when ALL dependencies are terminal (default)
|
|
363
|
+
- "any": task runs when ANY dependency succeeds (COMPLETED)
|
|
364
|
+
- "quorum": task runs when at least min_success dependencies succeed
|
|
365
|
+
"""
|
|
366
|
+
min_success: int | None = None
|
|
367
|
+
"""
|
|
368
|
+
- Required for join="quorum"
|
|
369
|
+
- Minimum number of dependencies that must succeed
|
|
370
|
+
"""
|
|
371
|
+
|
|
372
|
+
good_until: datetime | None = None
|
|
373
|
+
"""
|
|
374
|
+
- Task expiry deadline (task skipped if not claimed by this time)
|
|
375
|
+
"""
|
|
376
|
+
|
|
377
|
+
# Assigned during WorkflowSpec construction
|
|
378
|
+
index: int | None = field(default=None, repr=False)
|
|
379
|
+
"""
|
|
380
|
+
- Assigned during WorkflowSpec construction
|
|
381
|
+
- Index of the task in the workflow
|
|
382
|
+
"""
|
|
383
|
+
|
|
384
|
+
node_id: str | None = field(default=None, repr=False)
|
|
385
|
+
"""
|
|
386
|
+
Optional stable identifier for this task within the workflow.
|
|
387
|
+
If None, auto-assigned as '{slugify(workflow_name)}:{task_index}'.
|
|
388
|
+
Must be unique within the workflow.
|
|
389
|
+
Must match pattern: [A-Za-z0-9_\\-:.]+
|
|
390
|
+
"""
|
|
391
|
+
|
|
392
|
+
@property
|
|
393
|
+
def name(self) -> str:
|
|
394
|
+
"""Get the task name from the wrapped function."""
|
|
395
|
+
return self.fn.task_name
|
|
396
|
+
|
|
397
|
+
def key(self) -> 'NodeKey[OkT_co]':
|
|
398
|
+
"""Return a typed NodeKey for this task node."""
|
|
399
|
+
if self.node_id is None:
|
|
400
|
+
raise WorkflowValidationError(
|
|
401
|
+
'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
|
|
402
|
+
'or provide an explicit node_id.'
|
|
403
|
+
)
|
|
404
|
+
return NodeKey(self.node_id)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
# =============================================================================
|
|
408
|
+
# SubWorkflowNode
|
|
409
|
+
# =============================================================================
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
@dataclass
|
|
413
|
+
class SubWorkflowNode(Generic[OkT_co]):
|
|
414
|
+
"""
|
|
415
|
+
A node that runs a child workflow as a composite task.
|
|
416
|
+
|
|
417
|
+
The generic parameter OkT represents the child workflow's output type,
|
|
418
|
+
derived from WorkflowDefinition[OkT].
|
|
419
|
+
|
|
420
|
+
When the child workflow completes:
|
|
421
|
+
- COMPLETED: parent node receives TaskResult[OkT, TaskError] with child's output
|
|
422
|
+
- FAILED: parent node receives TaskResult with SubWorkflowError containing SubWorkflowSummary
|
|
423
|
+
|
|
424
|
+
Example:
|
|
425
|
+
class DataPipeline(WorkflowDefinition[ProcessedData]):
|
|
426
|
+
...
|
|
427
|
+
|
|
428
|
+
pipeline: SubWorkflowNode[ProcessedData] = SubWorkflowNode(
|
|
429
|
+
workflow_def=DataPipeline,
|
|
430
|
+
kwargs={"source_url": "https://..."}, # passed to build_with()
|
|
431
|
+
)
|
|
432
|
+
# Downstream: args_from={"data": pipeline} → TaskResult[ProcessedData, TaskError]
|
|
433
|
+
"""
|
|
434
|
+
|
|
435
|
+
workflow_def: type['WorkflowDefinition[OkT_co]']
|
|
436
|
+
"""
|
|
437
|
+
- The WorkflowDefinition subclass to run as a child workflow
|
|
438
|
+
- Must implement build_with() for parameterization
|
|
439
|
+
"""
|
|
440
|
+
|
|
441
|
+
args: tuple[Any, ...] = ()
|
|
442
|
+
"""
|
|
443
|
+
- Positional arguments passed to workflow_def.build_with(app, *args, **kwargs)
|
|
444
|
+
"""
|
|
445
|
+
|
|
446
|
+
kwargs: dict[str, Any] = field(default_factory=lambda: {})
|
|
447
|
+
"""
|
|
448
|
+
- Keyword arguments passed to workflow_def.build_with(app, *args, **kwargs)
|
|
449
|
+
- Use with args_from to inject upstream results as parameters
|
|
450
|
+
"""
|
|
451
|
+
|
|
452
|
+
waits_for: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] = field(
|
|
453
|
+
default_factory=lambda: [],
|
|
454
|
+
)
|
|
455
|
+
"""
|
|
456
|
+
- Nodes this subworkflow waits for before starting
|
|
457
|
+
- Waits for all to be terminal (COMPLETED/FAILED/SKIPPED)
|
|
458
|
+
- Same semantics as TaskNode.waits_for
|
|
459
|
+
"""
|
|
460
|
+
|
|
461
|
+
args_from: dict[str, 'TaskNode[Any] | SubWorkflowNode[Any]'] = field(
|
|
462
|
+
default_factory=lambda: {},
|
|
463
|
+
)
|
|
464
|
+
"""
|
|
465
|
+
- Maps kwarg names to upstream nodes for data injection
|
|
466
|
+
- Injected as TaskResult into build_with() kwargs
|
|
467
|
+
- Example: args_from={"input_data": fetch_node} → kwargs["input_data"] = TaskResult[T, E]
|
|
468
|
+
"""
|
|
469
|
+
|
|
470
|
+
workflow_ctx_from: Sequence['TaskNode[Any] | SubWorkflowNode[Any]'] | None = None
|
|
471
|
+
"""
|
|
472
|
+
- Nodes whose results to include in WorkflowContext for run_when/skip_when
|
|
473
|
+
- For SubWorkflowNodes: access via ctx.summary_for(node) → SubWorkflowSummary
|
|
474
|
+
"""
|
|
475
|
+
|
|
476
|
+
join: Literal['all', 'any', 'quorum'] = 'all'
|
|
477
|
+
"""
|
|
478
|
+
- "all": start when ALL dependencies are terminal (default)
|
|
479
|
+
- "any": start when ANY dependency succeeds (COMPLETED)
|
|
480
|
+
- "quorum": start when min_success dependencies succeed
|
|
481
|
+
"""
|
|
482
|
+
|
|
483
|
+
min_success: int | None = None
|
|
484
|
+
"""
|
|
485
|
+
- Required for join="quorum": minimum dependencies that must succeed
|
|
486
|
+
"""
|
|
487
|
+
|
|
488
|
+
allow_failed_deps: bool = False
|
|
489
|
+
"""
|
|
490
|
+
- False (default): SKIPPED if any dependency failed
|
|
491
|
+
- True: starts regardless, failed deps passed as TaskResult(err=...) via args_from
|
|
492
|
+
"""
|
|
493
|
+
|
|
494
|
+
run_when: Callable[['WorkflowContext'], bool] | None = field(
|
|
495
|
+
default=None, repr=False
|
|
496
|
+
)
|
|
497
|
+
"""
|
|
498
|
+
- Condition evaluated after deps terminal, before starting child workflow
|
|
499
|
+
- If returns False: node is SKIPPED
|
|
500
|
+
- skip_when takes priority over run_when
|
|
501
|
+
"""
|
|
502
|
+
|
|
503
|
+
skip_when: Callable[['WorkflowContext'], bool] | None = field(
|
|
504
|
+
default=None, repr=False
|
|
505
|
+
)
|
|
506
|
+
"""
|
|
507
|
+
- Condition evaluated after deps terminal, before starting child workflow
|
|
508
|
+
- If returns True: node is SKIPPED
|
|
509
|
+
- skip_when takes priority over run_when
|
|
510
|
+
"""
|
|
511
|
+
|
|
512
|
+
retry_mode: SubWorkflowRetryMode = SubWorkflowRetryMode.RERUN_FAILED_ONLY
|
|
513
|
+
"""
|
|
514
|
+
- How to retry if child workflow fails (only RERUN_FAILED_ONLY supported)
|
|
515
|
+
"""
|
|
516
|
+
|
|
517
|
+
index: int | None = field(default=None, repr=False)
|
|
518
|
+
"""
|
|
519
|
+
- Auto-assigned during WorkflowSpec construction
|
|
520
|
+
"""
|
|
521
|
+
|
|
522
|
+
node_id: str | None = field(default=None, repr=False)
|
|
523
|
+
"""
|
|
524
|
+
Optional stable identifier for this subworkflow within the parent workflow.
|
|
525
|
+
If None, auto-assigned as '{slugify(workflow_name)}:{task_index}'.
|
|
526
|
+
Must be unique within the workflow.
|
|
527
|
+
Must match pattern: [A-Za-z0-9_\\-:.]+
|
|
528
|
+
"""
|
|
529
|
+
|
|
530
|
+
@property
|
|
531
|
+
def name(self) -> str:
|
|
532
|
+
"""Get the subworkflow name."""
|
|
533
|
+
return self.workflow_def.name
|
|
534
|
+
|
|
535
|
+
def key(self) -> 'NodeKey[OkT_co]':
|
|
536
|
+
"""Return a typed NodeKey for this subworkflow node."""
|
|
537
|
+
if self.node_id is None:
|
|
538
|
+
raise WorkflowValidationError(
|
|
539
|
+
'SubWorkflowNode node_id is not set. Ensure WorkflowSpec assigns node_id '
|
|
540
|
+
'or provide an explicit node_id.'
|
|
541
|
+
)
|
|
542
|
+
return NodeKey(self.node_id)
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
AnyNode = TaskNode[Any] | SubWorkflowNode[Any]
|
|
547
|
+
'''
|
|
548
|
+
Type alias for any node type
|
|
549
|
+
'''
|
|
550
|
+
|
|
551
|
+
# =============================================================================
|
|
552
|
+
# NodeKey (typed, stable id)
|
|
553
|
+
# =============================================================================
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
@dataclass(frozen=True)
|
|
557
|
+
class NodeKey(Generic[OkT_co]):
|
|
558
|
+
"""Typed stable identifier for a TaskNode."""
|
|
559
|
+
|
|
560
|
+
node_id: str
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
# =============================================================================
|
|
564
|
+
# Success Policy
|
|
565
|
+
# =============================================================================
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
@dataclass
|
|
569
|
+
class SuccessCase:
|
|
570
|
+
"""
|
|
571
|
+
A single success scenario for a workflow.
|
|
572
|
+
|
|
573
|
+
The case is satisfied when ALL required tasks are COMPLETED.
|
|
574
|
+
|
|
575
|
+
Example:
|
|
576
|
+
# Workflow succeeds if either (A and B) or (C) completes
|
|
577
|
+
SuccessPolicy(cases=[
|
|
578
|
+
SuccessCase(required=[task_a, task_b]),
|
|
579
|
+
SuccessCase(required=[task_c]),
|
|
580
|
+
])
|
|
581
|
+
"""
|
|
582
|
+
|
|
583
|
+
required: list[TaskNode[Any]]
|
|
584
|
+
"""
|
|
585
|
+
- All tasks in this list must be COMPLETED for the case to be satisfied
|
|
586
|
+
"""
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@dataclass
|
|
590
|
+
class SuccessPolicy:
|
|
591
|
+
"""
|
|
592
|
+
Custom success criteria: workflow COMPLETED if ANY SuccessCase is satisfied.
|
|
593
|
+
|
|
594
|
+
Without a success_policy, default behavior is: any task failure → workflow FAILED.
|
|
595
|
+
With a success_policy, workflow is COMPLETED if at least one case has all
|
|
596
|
+
its required tasks COMPLETED, regardless of other task failures.
|
|
597
|
+
|
|
598
|
+
Example:
|
|
599
|
+
# "Succeed if primary path completes, even if fallback fails"
|
|
600
|
+
success_policy = SuccessPolicy(
|
|
601
|
+
cases=[SuccessCase(required=[primary_task])],
|
|
602
|
+
optional=[fallback_task], # can fail without affecting success
|
|
603
|
+
)
|
|
604
|
+
"""
|
|
605
|
+
|
|
606
|
+
cases: list[SuccessCase]
|
|
607
|
+
"""
|
|
608
|
+
- List of success scenarios
|
|
609
|
+
- Workflow succeeds if ANY case is fully satisfied (all required COMPLETED)
|
|
610
|
+
"""
|
|
611
|
+
|
|
612
|
+
optional: list[TaskNode[Any]] | None = None
|
|
613
|
+
"""
|
|
614
|
+
- Tasks that may fail without affecting success evaluation
|
|
615
|
+
- These failures don't block success cases from being satisfied
|
|
616
|
+
"""
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
# =============================================================================
|
|
620
|
+
# WorkflowSpec
|
|
621
|
+
# =============================================================================
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
@dataclass
|
|
625
|
+
class WorkflowSpec:
|
|
626
|
+
"""
|
|
627
|
+
Specification for a workflow DAG. Created via app.workflow() or WorkflowDefinition.build().
|
|
628
|
+
|
|
629
|
+
Validates the DAG on construction (cycles, dependency refs, args_from, etc.)
|
|
630
|
+
to catch configuration errors early, before execution.
|
|
631
|
+
|
|
632
|
+
Example:
|
|
633
|
+
spec = app.workflow(
|
|
634
|
+
name="my_pipeline",
|
|
635
|
+
tasks=[fetch, process, persist],
|
|
636
|
+
output=persist,
|
|
637
|
+
on_error=OnError.PAUSE,
|
|
638
|
+
)
|
|
639
|
+
handle = await spec.start_async()
|
|
640
|
+
"""
|
|
641
|
+
|
|
642
|
+
name: str
|
|
643
|
+
"""
|
|
644
|
+
Human-readable workflow name (used in logs, DB, registry).
|
|
645
|
+
Can contain any characters including spaces. When auto-generating
|
|
646
|
+
node_ids, the name is passed through slugify() to ensure validity.
|
|
647
|
+
"""
|
|
648
|
+
|
|
649
|
+
tasks: list[TaskNode[Any] | SubWorkflowNode[Any]]
|
|
650
|
+
"""
|
|
651
|
+
- All nodes in the DAG (order determines index assignment)
|
|
652
|
+
- Root nodes (empty waits_for) start immediately
|
|
653
|
+
"""
|
|
654
|
+
|
|
655
|
+
on_error: OnError = OnError.FAIL
|
|
656
|
+
"""
|
|
657
|
+
- FAIL (default): on task failure, mark workflow as will-fail, skip dependent tasks
|
|
658
|
+
- PAUSE: on task failure, pause workflow for manual intervention (resume/cancel)
|
|
659
|
+
"""
|
|
660
|
+
|
|
661
|
+
output: TaskNode[Any] | SubWorkflowNode[Any] | None = None
|
|
662
|
+
"""
|
|
663
|
+
- Explicit output node: WorkflowHandle.get() returns this node's result
|
|
664
|
+
- If None: get() returns dict of all terminal node results keyed by node_id
|
|
665
|
+
"""
|
|
666
|
+
|
|
667
|
+
success_policy: SuccessPolicy | None = None
|
|
668
|
+
"""
|
|
669
|
+
- Custom success criteria: workflow COMPLETED if any SuccessCase is satisfied
|
|
670
|
+
- If None (default): any task failure → workflow FAILED
|
|
671
|
+
"""
|
|
672
|
+
|
|
673
|
+
workflow_def_module: str | None = None
|
|
674
|
+
"""
|
|
675
|
+
- Module path of WorkflowDefinition class (for import fallback in workers)
|
|
676
|
+
"""
|
|
677
|
+
|
|
678
|
+
workflow_def_qualname: str | None = None
|
|
679
|
+
"""
|
|
680
|
+
- Qualified name of WorkflowDefinition class (for import fallback in workers)
|
|
681
|
+
"""
|
|
682
|
+
|
|
683
|
+
broker: PostgresBroker | None = field(default=None, repr=False)
|
|
684
|
+
"""
|
|
685
|
+
- Database broker for start()/start_async()
|
|
686
|
+
- Set automatically by app.workflow()
|
|
687
|
+
"""
|
|
688
|
+
|
|
689
|
+
def __post_init__(self) -> None:
|
|
690
|
+
"""Validate DAG structure on construction.
|
|
691
|
+
|
|
692
|
+
Phase-gated: node_id errors gate the rest (downstream needs valid IDs).
|
|
693
|
+
All other validations run and collect errors together.
|
|
694
|
+
"""
|
|
695
|
+
self._assign_indices()
|
|
696
|
+
|
|
697
|
+
# Gate 1: node_id assignment — if any ID errors, skip remaining validation
|
|
698
|
+
node_id_errors = self._collect_node_id_errors()
|
|
699
|
+
if node_id_errors:
|
|
700
|
+
report = ValidationReport('workflow')
|
|
701
|
+
for error in node_id_errors:
|
|
702
|
+
report.add(error)
|
|
703
|
+
raise_collected(report)
|
|
704
|
+
return
|
|
705
|
+
|
|
706
|
+
# Gate 2: collect all remaining validation errors
|
|
707
|
+
report = ValidationReport('workflow')
|
|
708
|
+
for error in self._collect_dag_errors():
|
|
709
|
+
report.add(error)
|
|
710
|
+
for error in self._collect_args_from_errors():
|
|
711
|
+
report.add(error)
|
|
712
|
+
for error in self._collect_workflow_ctx_from_errors():
|
|
713
|
+
report.add(error)
|
|
714
|
+
for error in self._collect_output_errors():
|
|
715
|
+
report.add(error)
|
|
716
|
+
for error in self._collect_success_policy_errors():
|
|
717
|
+
report.add(error)
|
|
718
|
+
for error in self._collect_join_semantics_errors():
|
|
719
|
+
report.add(error)
|
|
720
|
+
for error in self._collect_subworkflow_retry_mode_errors():
|
|
721
|
+
report.add(error)
|
|
722
|
+
for error in self._collect_subworkflow_cycle_errors():
|
|
723
|
+
report.add(error)
|
|
724
|
+
|
|
725
|
+
raise_collected(report)
|
|
726
|
+
|
|
727
|
+
# Only if clean: conditions + registration
|
|
728
|
+
self._validate_conditions()
|
|
729
|
+
self._register_for_conditions()
|
|
730
|
+
|
|
731
|
+
def _assign_indices(self) -> None:
|
|
732
|
+
"""Assign index to each TaskNode based on list position."""
|
|
733
|
+
for i, task in enumerate(self.tasks):
|
|
734
|
+
task.index = i
|
|
735
|
+
|
|
736
|
+
def _collect_node_id_errors(self) -> list[WorkflowValidationError]:
|
|
737
|
+
"""Assign node_id to each TaskNode if missing and validate uniqueness.
|
|
738
|
+
|
|
739
|
+
Returns all node_id errors instead of raising on first.
|
|
740
|
+
|
|
741
|
+
node_id source:
|
|
742
|
+
- User provides workflow NAME (e.g., "My Pipeline")
|
|
743
|
+
- node_id is either:
|
|
744
|
+
a) Derived from workflow name: slugify(name) + ":" + index
|
|
745
|
+
b) Explicitly provided by user on TaskNode
|
|
746
|
+
- Errors must distinguish between (a) and (b) so users know what to fix
|
|
747
|
+
"""
|
|
748
|
+
errors: list[WorkflowValidationError] = []
|
|
749
|
+
seen_ids: set[str] = set()
|
|
750
|
+
for task in self.tasks:
|
|
751
|
+
# Track whether node_id comes from workflow name or was explicitly set
|
|
752
|
+
node_id_from_workflow_name = task.node_id is None
|
|
753
|
+
if node_id_from_workflow_name:
|
|
754
|
+
if task.index is None:
|
|
755
|
+
errors.append(WorkflowValidationError(
|
|
756
|
+
message='TaskNode index is not set before assigning node_id',
|
|
757
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
758
|
+
))
|
|
759
|
+
continue
|
|
760
|
+
task.node_id = f'{slugify(self.name)}:{task.index}'
|
|
761
|
+
|
|
762
|
+
node_id = task.node_id
|
|
763
|
+
if node_id is None or not node_id.strip():
|
|
764
|
+
errors.append(WorkflowValidationError(
|
|
765
|
+
message='TaskNode node_id must be a non-empty string',
|
|
766
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
767
|
+
))
|
|
768
|
+
continue
|
|
769
|
+
if len(node_id) > 128:
|
|
770
|
+
if node_id_from_workflow_name:
|
|
771
|
+
errors.append(WorkflowValidationError(
|
|
772
|
+
message='workflow name too long',
|
|
773
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
774
|
+
notes=[
|
|
775
|
+
f"workflow name: '{self.name}'",
|
|
776
|
+
f"derived node_id would be {len(node_id)} characters (max 128)",
|
|
777
|
+
],
|
|
778
|
+
help_text='use a shorter workflow name',
|
|
779
|
+
))
|
|
780
|
+
else:
|
|
781
|
+
errors.append(WorkflowValidationError(
|
|
782
|
+
message='TaskNode node_id exceeds 128 characters',
|
|
783
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
784
|
+
notes=[f"node_id '{node_id}' has {len(node_id)} characters"],
|
|
785
|
+
help_text='use a shorter node_id (max 128 characters)',
|
|
786
|
+
))
|
|
787
|
+
continue
|
|
788
|
+
if NODE_ID_PATTERN.match(node_id) is None:
|
|
789
|
+
if node_id_from_workflow_name:
|
|
790
|
+
# This should never happen since slugify() sanitizes the name
|
|
791
|
+
errors.append(WorkflowValidationError(
|
|
792
|
+
message='workflow name produced invalid characters (internal error)',
|
|
793
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
794
|
+
notes=[
|
|
795
|
+
f"workflow name: '{self.name}'",
|
|
796
|
+
f"derived node_id: '{node_id}'",
|
|
797
|
+
'slugify() failed to sanitize the name',
|
|
798
|
+
],
|
|
799
|
+
help_text='please report this bug',
|
|
800
|
+
))
|
|
801
|
+
else:
|
|
802
|
+
errors.append(WorkflowValidationError(
|
|
803
|
+
message='TaskNode node_id contains invalid characters',
|
|
804
|
+
code=ErrorCode.WORKFLOW_INVALID_NODE_ID,
|
|
805
|
+
notes=[f"node_id '{node_id}'"],
|
|
806
|
+
help_text='node_id must match pattern: [A-Za-z0-9_\\-:.]+',
|
|
807
|
+
))
|
|
808
|
+
continue
|
|
809
|
+
if node_id in seen_ids:
|
|
810
|
+
errors.append(WorkflowValidationError(
|
|
811
|
+
message=f"duplicate node_id '{node_id}'",
|
|
812
|
+
code=ErrorCode.WORKFLOW_DUPLICATE_NODE_ID,
|
|
813
|
+
help_text='each TaskNode must have a unique node_id within the workflow',
|
|
814
|
+
))
|
|
815
|
+
seen_ids.add(node_id)
|
|
816
|
+
return errors
|
|
817
|
+
|
|
818
|
+
def _collect_dag_errors(self) -> list[WorkflowValidationError]:
|
|
819
|
+
"""Validate DAG structure. Returns all errors found."""
|
|
820
|
+
errors: list[WorkflowValidationError] = []
|
|
821
|
+
|
|
822
|
+
# 1. Check for roots (tasks with no dependencies)
|
|
823
|
+
roots = [t for t in self.tasks if not t.waits_for]
|
|
824
|
+
if not roots:
|
|
825
|
+
errors.append(WorkflowValidationError(
|
|
826
|
+
message='no root tasks found',
|
|
827
|
+
code=ErrorCode.WORKFLOW_NO_ROOT_TASKS,
|
|
828
|
+
notes=[
|
|
829
|
+
'all tasks have dependencies, creating an impossible start condition',
|
|
830
|
+
],
|
|
831
|
+
help_text='at least one task must have empty waits_for list',
|
|
832
|
+
))
|
|
833
|
+
|
|
834
|
+
# 2. Validate dependency references exist in workflow
|
|
835
|
+
task_ids = set(id(t) for t in self.tasks)
|
|
836
|
+
for task in self.tasks:
|
|
837
|
+
for dep in task.waits_for:
|
|
838
|
+
if id(dep) not in task_ids:
|
|
839
|
+
errors.append(WorkflowValidationError(
|
|
840
|
+
message='dependency references task not in workflow',
|
|
841
|
+
code=ErrorCode.WORKFLOW_INVALID_DEPENDENCY,
|
|
842
|
+
notes=[
|
|
843
|
+
f"task '{task.name}' waits for a TaskNode not in this workflow",
|
|
844
|
+
],
|
|
845
|
+
help_text='ensure all dependencies are included in the workflow tasks list',
|
|
846
|
+
))
|
|
847
|
+
|
|
848
|
+
# 3. Cycle detection (Kahn's algorithm) over valid dependencies only
|
|
849
|
+
in_degree: dict[int, int] = {}
|
|
850
|
+
for task in self.tasks:
|
|
851
|
+
idx = task.index
|
|
852
|
+
if idx is None:
|
|
853
|
+
continue
|
|
854
|
+
in_degree[idx] = 0
|
|
855
|
+
for dep in task.waits_for:
|
|
856
|
+
if id(dep) in task_ids and dep.index is not None:
|
|
857
|
+
in_degree[idx] += 1
|
|
858
|
+
|
|
859
|
+
queue = [
|
|
860
|
+
t.index
|
|
861
|
+
for t in self.tasks
|
|
862
|
+
if t.index is not None and in_degree.get(t.index, 0) == 0
|
|
863
|
+
]
|
|
864
|
+
visited = 0
|
|
865
|
+
|
|
866
|
+
while queue:
|
|
867
|
+
node_idx = queue.pop(0)
|
|
868
|
+
visited += 1
|
|
869
|
+
for task in self.tasks:
|
|
870
|
+
dep_indices = [
|
|
871
|
+
d.index
|
|
872
|
+
for d in task.waits_for
|
|
873
|
+
if id(d) in task_ids and d.index is not None
|
|
874
|
+
]
|
|
875
|
+
if node_idx in dep_indices:
|
|
876
|
+
task_idx = task.index
|
|
877
|
+
if task_idx is not None:
|
|
878
|
+
in_degree[task_idx] -= 1
|
|
879
|
+
if in_degree[task_idx] == 0:
|
|
880
|
+
queue.append(task_idx)
|
|
881
|
+
|
|
882
|
+
if visited != len(self.tasks):
|
|
883
|
+
errors.append(WorkflowValidationError(
|
|
884
|
+
message='cycle detected in workflow DAG',
|
|
885
|
+
code=ErrorCode.WORKFLOW_CYCLE_DETECTED,
|
|
886
|
+
notes=['workflows must be acyclic directed graphs (DAG)'],
|
|
887
|
+
help_text='remove circular dependencies between tasks',
|
|
888
|
+
))
|
|
889
|
+
|
|
890
|
+
return errors
|
|
891
|
+
|
|
892
|
+
def _collect_args_from_errors(self) -> list[WorkflowValidationError]:
|
|
893
|
+
"""Validate args_from references are valid dependencies. Returns all errors."""
|
|
894
|
+
errors: list[WorkflowValidationError] = []
|
|
895
|
+
for task in self.tasks:
|
|
896
|
+
deps_ids = set(id(d) for d in task.waits_for)
|
|
897
|
+
for kwarg_name, source_node in task.args_from.items():
|
|
898
|
+
if id(source_node) not in deps_ids:
|
|
899
|
+
errors.append(WorkflowValidationError(
|
|
900
|
+
message='args_from references task not in waits_for',
|
|
901
|
+
code=ErrorCode.WORKFLOW_INVALID_ARGS_FROM,
|
|
902
|
+
notes=[
|
|
903
|
+
f"task '{task.name}' args_from['{kwarg_name}'] references '{source_node.name}'",
|
|
904
|
+
f"'{source_node.name}' must be in waits_for to inject its result",
|
|
905
|
+
],
|
|
906
|
+
help_text=f"add '{source_node.name}' to waits_for list",
|
|
907
|
+
))
|
|
908
|
+
return errors
|
|
909
|
+
|
|
910
|
+
def _collect_workflow_ctx_from_errors(self) -> list[WorkflowValidationError]:
|
|
911
|
+
"""Validate workflow_ctx_from references are valid dependencies. Returns all errors."""
|
|
912
|
+
errors: list[WorkflowValidationError] = []
|
|
913
|
+
for node in self.tasks:
|
|
914
|
+
if node.workflow_ctx_from is None:
|
|
915
|
+
continue
|
|
916
|
+
deps_ids = set(id(d) for d in node.waits_for)
|
|
917
|
+
for ctx_node in node.workflow_ctx_from:
|
|
918
|
+
if id(ctx_node) not in deps_ids:
|
|
919
|
+
errors.append(WorkflowValidationError(
|
|
920
|
+
message='workflow_ctx_from references task not in waits_for',
|
|
921
|
+
code=ErrorCode.WORKFLOW_INVALID_CTX_FROM,
|
|
922
|
+
notes=[
|
|
923
|
+
f"node '{node.name}' references '{ctx_node.name}'",
|
|
924
|
+
f"'{ctx_node.name}' must be in waits_for to use in workflow_ctx_from",
|
|
925
|
+
],
|
|
926
|
+
help_text=f"add '{ctx_node.name}' to waits_for list",
|
|
927
|
+
))
|
|
928
|
+
|
|
929
|
+
# Only check function parameter for TaskNode (SubWorkflowNode has no fn)
|
|
930
|
+
if isinstance(node, SubWorkflowNode):
|
|
931
|
+
continue
|
|
932
|
+
|
|
933
|
+
task = node
|
|
934
|
+
if not _task_accepts_workflow_ctx(task.fn):
|
|
935
|
+
# Get the original function for accurate source location
|
|
936
|
+
original_fn = getattr(task.fn, '_original_fn', task.fn)
|
|
937
|
+
fn_location = SourceLocation.from_function(original_fn)
|
|
938
|
+
|
|
939
|
+
errors.append(WorkflowValidationError(
|
|
940
|
+
message='workflow_ctx_from declared but function missing workflow_ctx param',
|
|
941
|
+
code=ErrorCode.WORKFLOW_CTX_PARAM_MISSING,
|
|
942
|
+
location=fn_location, # May be None for non-function callables
|
|
943
|
+
notes=[
|
|
944
|
+
f"workflow '{self.name}'\n"
|
|
945
|
+
f"TaskNode '{task.name}' declares workflow_ctx_from=[...]\n"
|
|
946
|
+
f"but function '{task.name}' has no workflow_ctx parameter",
|
|
947
|
+
],
|
|
948
|
+
help_text=(
|
|
949
|
+
'either:\n'
|
|
950
|
+
' 1. add `workflow_ctx: WorkflowContext | None` param to the function above if needs context\n'
|
|
951
|
+
' 2. remove `workflow_ctx_from` from the TaskNode definition if this was a mistake'
|
|
952
|
+
),
|
|
953
|
+
))
|
|
954
|
+
return errors
|
|
955
|
+
|
|
956
|
+
def _collect_output_errors(self) -> list[WorkflowValidationError]:
|
|
957
|
+
"""Validate output task is in the workflow. Returns all errors."""
|
|
958
|
+
errors: list[WorkflowValidationError] = []
|
|
959
|
+
if self.output is None:
|
|
960
|
+
return errors
|
|
961
|
+
task_ids = set(id(t) for t in self.tasks)
|
|
962
|
+
if id(self.output) not in task_ids:
|
|
963
|
+
errors.append(WorkflowValidationError(
|
|
964
|
+
f"Output task '{self.output.name}' is not in workflow",
|
|
965
|
+
))
|
|
966
|
+
return errors
|
|
967
|
+
|
|
968
|
+
def _collect_success_policy_errors(self) -> list[WorkflowValidationError]:
|
|
969
|
+
"""Validate success policy references are valid workflow tasks. Returns all errors."""
|
|
970
|
+
errors: list[WorkflowValidationError] = []
|
|
971
|
+
if self.success_policy is None:
|
|
972
|
+
return errors
|
|
973
|
+
|
|
974
|
+
# Validate cases list is not empty
|
|
975
|
+
if not self.success_policy.cases:
|
|
976
|
+
errors.append(WorkflowValidationError(
|
|
977
|
+
'SuccessPolicy must have at least one SuccessCase',
|
|
978
|
+
))
|
|
979
|
+
return errors
|
|
980
|
+
|
|
981
|
+
task_ids = set(id(t) for t in self.tasks)
|
|
982
|
+
|
|
983
|
+
# Validate each success case
|
|
984
|
+
for i, case in enumerate(self.success_policy.cases):
|
|
985
|
+
if not case.required:
|
|
986
|
+
errors.append(WorkflowValidationError(
|
|
987
|
+
f'SuccessCase[{i}] has no required tasks',
|
|
988
|
+
))
|
|
989
|
+
for task in case.required:
|
|
990
|
+
if id(task) not in task_ids:
|
|
991
|
+
errors.append(WorkflowValidationError(
|
|
992
|
+
f"SuccessCase[{i}] required task '{task.name}' is not in workflow",
|
|
993
|
+
))
|
|
994
|
+
|
|
995
|
+
# Validate optional tasks
|
|
996
|
+
if self.success_policy.optional:
|
|
997
|
+
for task in self.success_policy.optional:
|
|
998
|
+
if id(task) not in task_ids:
|
|
999
|
+
errors.append(WorkflowValidationError(
|
|
1000
|
+
f"SuccessPolicy optional task '{task.name}' is not in workflow",
|
|
1001
|
+
))
|
|
1002
|
+
|
|
1003
|
+
return errors
|
|
1004
|
+
|
|
1005
|
+
def _collect_join_semantics_errors(self) -> list[WorkflowValidationError]:
|
|
1006
|
+
"""Validate join and min_success settings. Returns all errors."""
|
|
1007
|
+
errors: list[WorkflowValidationError] = []
|
|
1008
|
+
for task in self.tasks:
|
|
1009
|
+
if task.join == 'quorum':
|
|
1010
|
+
if task.min_success is None:
|
|
1011
|
+
errors.append(WorkflowValidationError(
|
|
1012
|
+
f"Task '{task.name}' has join='quorum' but min_success is not set",
|
|
1013
|
+
))
|
|
1014
|
+
elif task.min_success < 1:
|
|
1015
|
+
errors.append(WorkflowValidationError(
|
|
1016
|
+
f"Task '{task.name}' min_success must be >= 1, got {task.min_success}",
|
|
1017
|
+
))
|
|
1018
|
+
else:
|
|
1019
|
+
dep_count = len(task.waits_for)
|
|
1020
|
+
if task.min_success > dep_count:
|
|
1021
|
+
errors.append(WorkflowValidationError(
|
|
1022
|
+
f"Task '{task.name}' min_success ({task.min_success}) exceeds "
|
|
1023
|
+
f'dependency count ({dep_count})',
|
|
1024
|
+
))
|
|
1025
|
+
elif task.join in ('all', 'any'):
|
|
1026
|
+
if task.min_success is not None:
|
|
1027
|
+
errors.append(WorkflowValidationError(
|
|
1028
|
+
f"Task '{task.name}' has min_success set but join='{task.join}' "
|
|
1029
|
+
"(min_success is only used with join='quorum')",
|
|
1030
|
+
))
|
|
1031
|
+
return errors
|
|
1032
|
+
|
|
1033
|
+
def _validate_conditions(self) -> None:
|
|
1034
|
+
"""Validate condition callables have required context dependencies."""
|
|
1035
|
+
for task in self.tasks:
|
|
1036
|
+
has_condition = task.run_when is not None or task.skip_when is not None
|
|
1037
|
+
if has_condition and not task.workflow_ctx_from:
|
|
1038
|
+
# Conditions require context, but no context sources specified
|
|
1039
|
+
# This is allowed (empty context), but may cause KeyError if
|
|
1040
|
+
# condition tries to access dependency results
|
|
1041
|
+
pass # Allow - user may have conditions that don't use context
|
|
1042
|
+
|
|
1043
|
+
def _collect_subworkflow_cycle_errors(self) -> list[WorkflowValidationError]:
|
|
1044
|
+
"""Detect cycles in nested workflow definitions. Returns all errors.
|
|
1045
|
+
|
|
1046
|
+
Prevents circular references like:
|
|
1047
|
+
- WorkflowA contains SubWorkflowNode(WorkflowB)
|
|
1048
|
+
and WorkflowB contains SubWorkflowNode(WorkflowA)
|
|
1049
|
+
|
|
1050
|
+
Uses DFS with a recursion stack to detect back-edges.
|
|
1051
|
+
"""
|
|
1052
|
+
errors: list[WorkflowValidationError] = []
|
|
1053
|
+
visited: set[str] = set()
|
|
1054
|
+
stack: set[str] = set()
|
|
1055
|
+
|
|
1056
|
+
def visit(workflow_name: str, workflow_class: type[Any]) -> None:
|
|
1057
|
+
"""DFS visit with cycle detection via recursion stack."""
|
|
1058
|
+
if workflow_name in stack:
|
|
1059
|
+
# Found a back-edge - this is a cycle
|
|
1060
|
+
errors.append(WorkflowValidationError(
|
|
1061
|
+
message='cycle detected in nested workflows',
|
|
1062
|
+
code=ErrorCode.WORKFLOW_CYCLE_DETECTED,
|
|
1063
|
+
notes=[
|
|
1064
|
+
f"workflow '{workflow_name}' creates a circular reference",
|
|
1065
|
+
'cycles in nested workflows are not allowed',
|
|
1066
|
+
],
|
|
1067
|
+
help_text='remove the circular SubWorkflowNode reference',
|
|
1068
|
+
))
|
|
1069
|
+
return
|
|
1070
|
+
|
|
1071
|
+
if workflow_name in visited:
|
|
1072
|
+
# Already fully explored this workflow, no cycle through here
|
|
1073
|
+
return
|
|
1074
|
+
|
|
1075
|
+
visited.add(workflow_name)
|
|
1076
|
+
stack.add(workflow_name)
|
|
1077
|
+
|
|
1078
|
+
# Check all SubWorkflowNodes in this workflow's definition
|
|
1079
|
+
nodes = workflow_class.get_workflow_nodes()
|
|
1080
|
+
if nodes:
|
|
1081
|
+
for _, wf_node in nodes:
|
|
1082
|
+
if isinstance(wf_node, SubWorkflowNode):
|
|
1083
|
+
wf_node_any = cast(SubWorkflowNode[Any], wf_node)
|
|
1084
|
+
workflow_def = wf_node_any.workflow_def
|
|
1085
|
+
child_name: str = workflow_def.name
|
|
1086
|
+
visit(child_name, workflow_def)
|
|
1087
|
+
|
|
1088
|
+
# Done exploring this workflow
|
|
1089
|
+
stack.remove(workflow_name)
|
|
1090
|
+
|
|
1091
|
+
# Start DFS from each SubWorkflowNode in this workflow's tasks
|
|
1092
|
+
for node in self.tasks:
|
|
1093
|
+
if isinstance(node, SubWorkflowNode):
|
|
1094
|
+
visit(node.workflow_def.name, node.workflow_def)
|
|
1095
|
+
|
|
1096
|
+
return errors
|
|
1097
|
+
|
|
1098
|
+
def _collect_subworkflow_retry_mode_errors(self) -> list[WorkflowValidationError]:
|
|
1099
|
+
"""Reject unsupported subworkflow retry modes. Returns all errors."""
|
|
1100
|
+
errors: list[WorkflowValidationError] = []
|
|
1101
|
+
for node in self.tasks:
|
|
1102
|
+
if not isinstance(node, SubWorkflowNode):
|
|
1103
|
+
continue
|
|
1104
|
+
if node.retry_mode != SubWorkflowRetryMode.RERUN_FAILED_ONLY:
|
|
1105
|
+
errors.append(WorkflowValidationError(
|
|
1106
|
+
message='unsupported SubWorkflowRetryMode',
|
|
1107
|
+
code=ErrorCode.WORKFLOW_INVALID_SUBWORKFLOW_RETRY_MODE,
|
|
1108
|
+
notes=[
|
|
1109
|
+
f"node '{node.name}' uses retry_mode='{node.retry_mode.value}'",
|
|
1110
|
+
"only 'rerun_failed_only' is supported in this release",
|
|
1111
|
+
],
|
|
1112
|
+
help_text='use SubWorkflowRetryMode.RERUN_FAILED_ONLY',
|
|
1113
|
+
))
|
|
1114
|
+
return errors
|
|
1115
|
+
|
|
1116
|
+
def _register_for_conditions(self) -> None:
|
|
1117
|
+
"""Register this spec for condition evaluation at runtime."""
|
|
1118
|
+
# Register if any task has conditions OR any SubWorkflowNode exists
|
|
1119
|
+
has_conditions = any(
|
|
1120
|
+
t.run_when is not None or t.skip_when is not None for t in self.tasks
|
|
1121
|
+
)
|
|
1122
|
+
has_subworkflow = any(isinstance(t, SubWorkflowNode) for t in self.tasks)
|
|
1123
|
+
if has_conditions or has_subworkflow:
|
|
1124
|
+
from horsies.core.workflows.registry import register_workflow_spec
|
|
1125
|
+
|
|
1126
|
+
register_workflow_spec(self)
|
|
1127
|
+
|
|
1128
|
+
def start(self, workflow_id: str | None = None) -> 'WorkflowHandle':
|
|
1129
|
+
"""
|
|
1130
|
+
Start workflow execution.
|
|
1131
|
+
|
|
1132
|
+
Args:
|
|
1133
|
+
workflow_id: Optional custom workflow ID. Auto-generated if not provided.
|
|
1134
|
+
|
|
1135
|
+
Returns:
|
|
1136
|
+
WorkflowHandle for tracking and retrieving results.
|
|
1137
|
+
|
|
1138
|
+
Raises:
|
|
1139
|
+
RuntimeError: If broker is not configured.
|
|
1140
|
+
"""
|
|
1141
|
+
if self.broker is None:
|
|
1142
|
+
raise RuntimeError(
|
|
1143
|
+
'WorkflowSpec requires a broker. Use app.workflow() or set broker.'
|
|
1144
|
+
)
|
|
1145
|
+
|
|
1146
|
+
# Import here to avoid circular imports
|
|
1147
|
+
from horsies.core.workflows.engine import start_workflow
|
|
1148
|
+
|
|
1149
|
+
return start_workflow(self, self.broker, workflow_id)
|
|
1150
|
+
|
|
1151
|
+
async def start_async(self, workflow_id: str | None = None) -> 'WorkflowHandle':
|
|
1152
|
+
"""
|
|
1153
|
+
Start workflow execution (async).
|
|
1154
|
+
|
|
1155
|
+
Args:
|
|
1156
|
+
workflow_id: Optional custom workflow ID. Auto-generated if not provided.
|
|
1157
|
+
|
|
1158
|
+
Returns:
|
|
1159
|
+
WorkflowHandle for tracking and retrieving results.
|
|
1160
|
+
|
|
1161
|
+
Raises:
|
|
1162
|
+
RuntimeError: If broker is not configured.
|
|
1163
|
+
"""
|
|
1164
|
+
if self.broker is None:
|
|
1165
|
+
raise RuntimeError(
|
|
1166
|
+
'WorkflowSpec requires a broker. Use app.workflow() or set broker.'
|
|
1167
|
+
)
|
|
1168
|
+
|
|
1169
|
+
# Import here to avoid circular imports
|
|
1170
|
+
from horsies.core.workflows.engine import start_workflow_async
|
|
1171
|
+
|
|
1172
|
+
return await start_workflow_async(self, self.broker, workflow_id)
|
|
1173
|
+
|
|
1174
|
+
|
|
1175
|
+
# =============================================================================
|
|
1176
|
+
# WorkflowMeta (metadata only, no result access)
|
|
1177
|
+
# =============================================================================
|
|
1178
|
+
|
|
1179
|
+
|
|
1180
|
+
@dataclass
|
|
1181
|
+
class WorkflowMeta:
|
|
1182
|
+
"""
|
|
1183
|
+
Workflow execution metadata.
|
|
1184
|
+
|
|
1185
|
+
Auto-injected if task declares `workflow_meta: WorkflowMeta | None` parameter.
|
|
1186
|
+
Contains only metadata, no result access.
|
|
1187
|
+
|
|
1188
|
+
Attributes:
|
|
1189
|
+
workflow_id: UUID of the workflow instance
|
|
1190
|
+
task_index: Index of the current task in the workflow
|
|
1191
|
+
task_name: Name of the current task
|
|
1192
|
+
"""
|
|
1193
|
+
|
|
1194
|
+
workflow_id: str
|
|
1195
|
+
task_index: int
|
|
1196
|
+
task_name: str
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
# =============================================================================
|
|
1200
|
+
# WorkflowContext (type-safe result access via result_for)
|
|
1201
|
+
# =============================================================================
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
class WorkflowContextMissingIdError(RuntimeError):
|
|
1205
|
+
"""Raised when TaskNode node_id is missing for WorkflowContext.result_for()."""
|
|
1206
|
+
|
|
1207
|
+
|
|
1208
|
+
class WorkflowHandleMissingIdError(RuntimeError):
|
|
1209
|
+
"""Raised when TaskNode node_id is missing for WorkflowHandle.result_for()."""
|
|
1210
|
+
|
|
1211
|
+
|
|
1212
|
+
class WorkflowContext(BaseModel):
|
|
1213
|
+
"""
|
|
1214
|
+
Context passed to workflow tasks with type-safe access to dependency results.
|
|
1215
|
+
|
|
1216
|
+
Only injected if:
|
|
1217
|
+
1. TaskNode has workflow_ctx_from set, AND
|
|
1218
|
+
2. Task function declares `workflow_ctx: WorkflowContext | None` parameter
|
|
1219
|
+
|
|
1220
|
+
Use result_for(node) to access results in a type-safe manner.
|
|
1221
|
+
Use summary_for(node) to access SubWorkflowSummary for SubWorkflowNodes.
|
|
1222
|
+
|
|
1223
|
+
Attributes:
|
|
1224
|
+
workflow_id: UUID of the workflow instance
|
|
1225
|
+
task_index: Index of the current task in the workflow
|
|
1226
|
+
task_name: Name of the current task
|
|
1227
|
+
"""
|
|
1228
|
+
|
|
1229
|
+
model_config = {'arbitrary_types_allowed': True}

    workflow_id: str  # UUID as string for JSON serialization
    task_index: int
    task_name: str

    # Internal storage: results keyed by node_id
    _results_by_id: dict[str, Any] = {}
    # Internal storage: subworkflow summaries keyed by node_id
    _summaries_by_id: dict[str, Any] = {}

    def __init__(
        self,
        workflow_id: str,
        task_index: int,
        task_name: str,
        results_by_id: dict[str, Any] | None = None,
        summaries_by_id: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            workflow_id=workflow_id,
            task_index=task_index,
            task_name=task_name,
            **kwargs,
        )
        # Store results internally (not exposed as Pydantic field)
        object.__setattr__(self, '_results_by_id', results_by_id or {})
        object.__setattr__(self, '_summaries_by_id', summaries_by_id or {})

    def result_for(
        self,
        node: TaskNode[OkT] | NodeKey[OkT],
    ) -> 'TaskResult[OkT, TaskError]':
        """
        Get the result for a specific TaskNode.

        Type-safe: returns TaskResult[T, TaskError] where T matches the node's type.

        Args:
            node: The TaskNode or NodeKey whose result to retrieve. Must have been
                included in workflow_ctx_from and have a node_id assigned.

        Returns:
            The TaskResult from the completed task.

        Raises:
            KeyError: If the node's result is not in this context.
            RuntimeError: If the node has no node_id assigned.
        """

        node_id: str | None
        if isinstance(node, NodeKey):
            node_id = node.node_id
        else:
            node_id = node.node_id

        if node_id is None:
            raise WorkflowContextMissingIdError(
                'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
                'or provide an explicit node_id.'
            )

        if node_id not in self._results_by_id:
            raise KeyError(
                f"TaskNode id '{node_id}' not in workflow context. "
                'Ensure the node is included in workflow_ctx_from.'
            )

        # Cast is safe because the generic parameter ensures type correctness
        return cast('TaskResult[OkT, TaskError]', self._results_by_id[node_id])

    def has_result(self, node: TaskNode[Any] | SubWorkflowNode[Any]) -> bool:
        """Check if a result exists for the given node."""
        if node.node_id is None:
            return False
        return node.node_id in self._results_by_id

    def summary_for(
        self,
        node: 'SubWorkflowNode[OkT]',
    ) -> 'SubWorkflowSummary[OkT]':
        """
        Get the SubWorkflowSummary for a completed SubWorkflowNode.

        Type-safe: returns SubWorkflowSummary[T] where T matches the node's output type.

        Args:
            node: The SubWorkflowNode whose summary to retrieve. Must have been
                included in workflow_ctx_from and have a node_id assigned.

        Returns:
            The SubWorkflowSummary from the completed subworkflow.

        Raises:
            KeyError: If the node's summary is not in this context.
            RuntimeError: If the node has no node_id assigned.
        """
        node_id = node.node_id

        if node_id is None:
            raise WorkflowContextMissingIdError(
                'SubWorkflowNode node_id is not set. Ensure WorkflowSpec assigns node_id.'
            )

        if node_id not in self._summaries_by_id:
            raise KeyError(
                f"SubWorkflowNode id '{node_id}' not in workflow context summaries. "
                'Ensure the node is included in workflow_ctx_from.'
            )

        # Cast is safe because the generic parameter ensures type correctness
        return cast('SubWorkflowSummary[OkT]', self._summaries_by_id[node_id])

    def has_summary(self, node: 'SubWorkflowNode[Any]') -> bool:
        """Check if a summary exists for the given SubWorkflowNode."""
        if node.node_id is None:
            return False
        return node.node_id in self._summaries_by_id

    @classmethod
    def from_serialized(
        cls,
        workflow_id: str,
        task_index: int,
        task_name: str,
        results_by_id: dict[str, Any],
        summaries_by_id: dict[str, Any] | None = None,
    ) -> 'WorkflowContext':
        """Reconstruct WorkflowContext from serialized data."""
        return cls(
            workflow_id=workflow_id,
            task_index=task_index,
            task_name=task_name,
            results_by_id=results_by_id,
            summaries_by_id=summaries_by_id,
        )
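
# Illustrative usage sketch, not part of the released module: a downstream task
# body reading an upstream node's result from its injected WorkflowContext. How
# `ctx` reaches the task (workflow_ctx_from / the task decorator) is assumed
# here; only has_result() and result_for() above are taken from this class.
def _example_read_upstream(
    ctx: 'WorkflowContext',
    upstream: TaskNode[Any],
) -> 'TaskResult[Any, TaskError] | None':
    """Return the upstream result if it is present and successful, else None."""
    if not ctx.has_result(upstream):
        return None
    result = ctx.result_for(upstream)
    return None if result.is_err() else result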


# =============================================================================
# WorkflowHandle
# =============================================================================


@dataclass
class WorkflowTaskInfo:
    """Information about a task within a workflow."""

    index: int
    name: str
    status: WorkflowTaskStatus
    result: TaskResult[Any, TaskError] | None
    started_at: datetime | None
    completed_at: datetime | None


@dataclass
class WorkflowHandle:
    """
    Handle for tracking and retrieving workflow results.

    Provides methods to:
    - Check workflow status
    - Wait for and retrieve results
    - Inspect individual task states
    - Cancel the workflow
    """

    workflow_id: str
    broker: PostgresBroker

    def status(self) -> WorkflowStatus:
        """Get current workflow status."""

        runner = LoopRunner()
        try:
            return runner.call(self.status_async)
        finally:
            runner.stop()

    async def status_async(self) -> WorkflowStatus:
        """Async version of status()."""
        from sqlalchemy import text

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text('SELECT status FROM horsies_workflows WHERE id = :wf_id'),
                {'wf_id': self.workflow_id},
            )
            row = result.fetchone()
            if row is None:
                raise ValueError(f'Workflow {self.workflow_id} not found')
            return WorkflowStatus(row[0])
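
    # Usage sketch (illustrative): checking for a terminal state from async code.
    #
    #     status = await handle.status_async()
    #     done = status in (
    #         WorkflowStatus.COMPLETED,
    #         WorkflowStatus.FAILED,
    #         WorkflowStatus.CANCELLED,
    #     )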

    def get(self, timeout_ms: int | None = None) -> TaskResult[Any, TaskError]:
        """
        Block until workflow completes or timeout.

        Returns:
            If output task specified: that task's TaskResult
            Otherwise: TaskResult containing dict of terminal task results
        """

        runner = LoopRunner()
        try:
            return runner.call(self.get_async, timeout_ms)
        finally:
            runner.stop()

    async def get_async(
        self, timeout_ms: int | None = None
    ) -> TaskResult[Any, TaskError]:
        """Async version of get()."""

        from horsies.core.models.tasks import TaskResult, TaskError, LibraryErrorCode

        start = time.monotonic()
        timeout_sec = timeout_ms / 1000 if timeout_ms else None

        while True:
            # Check current status
            status = await self.status_async()

            if status == WorkflowStatus.COMPLETED:
                return await self._get_result()

            if status in (WorkflowStatus.FAILED, WorkflowStatus.CANCELLED):
                return await self._get_error()

            if status == WorkflowStatus.PAUSED:
                return TaskResult(
                    err=TaskError(
                        error_code='WORKFLOW_PAUSED',
                        message='Workflow is paused awaiting intervention',
                    )
                )

            # Check timeout
            elapsed = time.monotonic() - start
            if timeout_sec and elapsed >= timeout_sec:
                return TaskResult(
                    err=TaskError(
                        error_code=LibraryErrorCode.WAIT_TIMEOUT,
                        message=f'Workflow did not complete within {timeout_ms}ms',
                    )
                )

            # Wait for notification or poll
            remaining = (timeout_sec - elapsed) if timeout_sec else 5.0
            await self._wait_for_completion(min(remaining, 5.0))
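
    # Usage sketch (illustrative): blocking wait with a timeout, distinguishing
    # the library's WAIT_TIMEOUT from a workflow-level failure.
    #
    #     result = handle.get(timeout_ms=30_000)
    #     if result.is_err():
    #         if result.err.error_code == LibraryErrorCode.WAIT_TIMEOUT:
    #             ...  # still running; poll again later
    #         else:
    #             ...  # workflow failed, was cancelled, or is paused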

    async def _get_result(self) -> TaskResult[Any, TaskError]:
        """Fetch completed workflow result."""
        from sqlalchemy import text

        from horsies.core.models.tasks import TaskResult
        from horsies.core.codec.serde import loads_json, task_result_from_json

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text('SELECT result FROM horsies_workflows WHERE id = :wf_id'),
                {'wf_id': self.workflow_id},
            )
            row = result.fetchone()
            if row and row[0]:
                return task_result_from_json(loads_json(row[0]))
            return TaskResult(ok=None)

    async def _get_error(self) -> TaskResult[Any, TaskError]:
        """Fetch failed workflow error."""
        from sqlalchemy import text

        from horsies.core.models.tasks import TaskResult, TaskError
        from horsies.core.codec.serde import loads_json

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text('SELECT error, status FROM horsies_workflows WHERE id = :wf_id'),
                {'wf_id': self.workflow_id},
            )
            row = result.fetchone()
            if row and row[0]:
                error_data = loads_json(row[0])
                if isinstance(error_data, dict):
                    # Safely extract known TaskError fields with type narrowing
                    raw_code = error_data.get('error_code')
                    raw_msg = error_data.get('message')
                    return TaskResult(
                        err=TaskError(
                            error_code=str(raw_code) if raw_code is not None else None,
                            message=str(raw_msg) if raw_msg is not None else None,
                            data=error_data.get('data'),
                        )
                    )
            status_str = row[1] if row else 'FAILED'
            return TaskResult(
                err=TaskError(
                    error_code=f'WORKFLOW_{status_str}',
                    message=f'Workflow {status_str.lower()}',
                )
            )

    async def _wait_for_completion(self, timeout_sec: float) -> None:
        """Wait for workflow_done notification or poll interval."""
        import asyncio

        try:
            q = await self.broker.listener.listen('workflow_done')
            try:

                async def _wait_for_workflow() -> None:
                    while True:
                        note = await q.get()
                        if note.payload == self.workflow_id:
                            return

                await asyncio.wait_for(_wait_for_workflow(), timeout=timeout_sec)
            finally:
                await self.broker.listener.unsubscribe('workflow_done', q)
        except asyncio.TimeoutError:
            pass  # Polling fallback

    def results(self) -> dict[str, TaskResult[Any, TaskError]]:
        """
        Get all task results keyed by unique identifier.

        Keys are `node_id` values. If a TaskNode did not specify a node_id,
        WorkflowSpec auto-assigns one as "{workflow_name}:{task_index}".
        """
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            return runner.call(self.results_async)
        finally:
            runner.stop()

    async def results_async(self) -> dict[str, TaskResult[Any, TaskError]]:
        """
        Async version of results().

        Keys are `node_id` values. If a TaskNode did not specify a node_id,
        WorkflowSpec auto-assigns one as "{workflow_name}:{task_index}".
        """
        from sqlalchemy import text

        from horsies.core.codec.serde import loads_json, task_result_from_json

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text("""
                    SELECT node_id, result
                    FROM horsies_workflow_tasks
                    WHERE workflow_id = :wf_id
                    AND result IS NOT NULL
                """),
                {'wf_id': self.workflow_id},
            )

            return {
                row[0]: task_result_from_json(loads_json(row[1]))
                for row in result.fetchall()
            }
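
    # Usage sketch (illustrative): iterating every completed task's result,
    # keyed by node_id (auto-assigned as "{workflow_name}:{task_index}" when a
    # TaskNode does not set one).
    #
    #     for node_id, result in handle.results().items():
    #         print(node_id, 'failed' if result.is_err() else 'ok')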

    def result_for(
        self, node: TaskNode[OkT] | NodeKey[OkT]
    ) -> 'TaskResult[OkT, TaskError]':
        """
        Get the result for a specific TaskNode or NodeKey.

        Non-blocking: queries the database once and returns immediately.

        Args:
            node: The TaskNode or NodeKey whose result to retrieve.

        Returns:
            TaskResult[T, TaskError] where T matches the node's type.
            - If task completed: returns the task's result (success or error)
            - If task not completed: returns TaskResult with
              error_code=LibraryErrorCode.RESULT_NOT_READY

        Raises:
            WorkflowHandleMissingIdError: If node has no node_id assigned.

        Example:
            result = handle.result_for(node)
            if result.is_err() and result.err.error_code == LibraryErrorCode.RESULT_NOT_READY:
                # Task hasn't completed yet - wait or check later
                pass
        """
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            return runner.call(self.result_for_async, node)
        finally:
            runner.stop()

    async def result_for_async(
        self, node: TaskNode[OkT] | NodeKey[OkT]
    ) -> 'TaskResult[OkT, TaskError]':
        """Async version of result_for(). See result_for() for full documentation."""
        from sqlalchemy import text

        from horsies.core.codec.serde import loads_json, task_result_from_json

        node_id: str | None
        if isinstance(node, NodeKey):
            node_id = node.node_id
        else:
            node_id = node.node_id

        if node_id is None:
            raise WorkflowHandleMissingIdError(
                'TaskNode node_id is not set. Ensure WorkflowSpec assigns node_id '
                'or provide an explicit node_id.'
            )

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text("""
                    SELECT result
                    FROM horsies_workflow_tasks
                    WHERE workflow_id = :wf_id
                    AND node_id = :node_id
                    AND result IS NOT NULL
                """),
                {'wf_id': self.workflow_id, 'node_id': node_id},
            )
            row = result.fetchone()
            if row is None or row[0] is None:
                from horsies.core.models.tasks import (
                    TaskResult,
                    TaskError,
                    LibraryErrorCode,
                )

                return cast(
                    'TaskResult[OkT, TaskError]',
                    TaskResult(
                        err=TaskError(
                            error_code=LibraryErrorCode.RESULT_NOT_READY,
                            message=(
                                f"Task '{node_id}' has not completed yet "
                                f"in workflow '{self.workflow_id}'"
                            ),
                        )
                    ),
                )

            return cast(
                'TaskResult[OkT, TaskError]',
                task_result_from_json(loads_json(row[0])),
            )

    def tasks(self) -> list[WorkflowTaskInfo]:
        """Get status of all tasks in workflow."""
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            return runner.call(self.tasks_async)
        finally:
            runner.stop()

    async def tasks_async(self) -> list[WorkflowTaskInfo]:
        """Async version of tasks()."""
        from sqlalchemy import text

        from horsies.core.codec.serde import loads_json, task_result_from_json

        async with self.broker.session_factory() as session:
            result = await session.execute(
                text("""
                    SELECT task_index, task_name, status, result, started_at, completed_at
                    FROM horsies_workflow_tasks
                    WHERE workflow_id = :wf_id
                    ORDER BY task_index
                """),
                {'wf_id': self.workflow_id},
            )

            return [
                WorkflowTaskInfo(
                    index=row[0],
                    name=row[1],
                    status=WorkflowTaskStatus(row[2]),
                    result=task_result_from_json(loads_json(row[3]))
                    if row[3]
                    else None,
                    started_at=row[4],
                    completed_at=row[5],
                )
                for row in result.fetchall()
            ]
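
    # Usage sketch (illustrative): summarizing per-task progress from the
    # WorkflowTaskInfo records returned by tasks().
    #
    #     for info in handle.tasks():
    #         print(f'[{info.index}] {info.name}: {info.status}')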

    def cancel(self) -> None:
        """Request workflow cancellation."""
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            runner.call(self.cancel_async)
        finally:
            runner.stop()

    async def cancel_async(self) -> None:
        """Async version of cancel()."""
        from sqlalchemy import text

        async with self.broker.session_factory() as session:
            # Cancel workflow
            await session.execute(
                text("""
                    UPDATE horsies_workflows
                    SET status = 'CANCELLED', updated_at = NOW()
                    WHERE id = :wf_id AND status IN ('PENDING', 'RUNNING', 'PAUSED')
                """),
                {'wf_id': self.workflow_id},
            )

            # Skip pending/ready tasks
            await session.execute(
                text("""
                    UPDATE horsies_workflow_tasks
                    SET status = 'SKIPPED'
                    WHERE workflow_id = :wf_id AND status IN ('PENDING', 'READY')
                """),
                {'wf_id': self.workflow_id},
            )

            await session.commit()

    def pause(self) -> bool:
        """
        Pause a running workflow.

        Transitions workflow from RUNNING to PAUSED state. Already-running tasks
        will continue to completion, but no new tasks will be enqueued.

        Use resume() to continue execution.

        Returns:
            True if workflow was paused, False if not RUNNING (no-op)
        """
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            return runner.call(self.pause_async)
        finally:
            runner.stop()

    async def pause_async(self) -> bool:
        """
        Async version of pause().

        Returns:
            True if workflow was paused, False if not RUNNING (no-op)
        """
        from horsies.core.workflows.engine import pause_workflow

        return await pause_workflow(self.broker, self.workflow_id)

    def resume(self) -> bool:
        """
        Resume a paused workflow.

        Re-evaluates all PENDING tasks (marks READY if deps are terminal) and
        enqueues all READY tasks. Only works if workflow is currently PAUSED.

        Returns:
            True if workflow was resumed, False if not PAUSED (no-op)
        """
        from horsies.core.utils.loop_runner import LoopRunner

        runner = LoopRunner()
        try:
            return runner.call(self.resume_async)
        finally:
            runner.stop()

    async def resume_async(self) -> bool:
        """
        Async version of resume().

        Returns:
            True if workflow was resumed, False if not PAUSED (no-op)
        """
        from horsies.core.workflows.engine import resume_workflow

        return await resume_workflow(self.broker, self.workflow_id)
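

# Illustrative usage sketch, not part of the released module: an operator-style
# helper around the WorkflowHandle control surface. How the handle is obtained
# (the workflow submission API) is assumed rather than shown here.
def _example_drain_workflow(handle: WorkflowHandle) -> 'TaskResult[Any, TaskError]':
    """Pause, inspect progress, resume, then block for the final result."""
    if handle.pause():
        # While paused, already-running tasks finish but nothing new is enqueued.
        unfinished = [t for t in handle.tasks() if t.result is None]
        print(f'{len(unfinished)} task(s) still without a result')
        handle.resume()
    return handle.get(timeout_ms=60_000)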


# =============================================================================
# WorkflowDefinition (class-based workflow definition)
# =============================================================================


class WorkflowDefinitionMeta(type):
    """
    Metaclass for WorkflowDefinition that preserves attribute order.

    Collects TaskNode and SubWorkflowNode instances from class attributes
    in definition order.
    """

    def __new__(
        mcs,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
    ) -> 'WorkflowDefinitionMeta':
        cls = super().__new__(mcs, name, bases, namespace)

        # Skip processing for the base class itself
        if name == 'WorkflowDefinition':
            return cls

        # Collect TaskNode and SubWorkflowNode instances in definition order
        nodes: list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]] = []
        for attr_name, attr_value in namespace.items():
            if isinstance(attr_value, (TaskNode, SubWorkflowNode)):
                nodes.append((attr_name, attr_value))

        # Store the collected nodes on the class
        cls._workflow_nodes = nodes  # type: ignore[attr-defined]

        return cls


class WorkflowDefinition(Generic[OkT_co], metaclass=WorkflowDefinitionMeta):
    """
    Base class for declarative workflow definitions.

    Generic parameter OkT represents the workflow's output type, derived from
    Meta.output task's return type.

    Provides a class-based alternative to app.workflow() for defining workflows.
    TaskNode and SubWorkflowNode instances defined as class attributes are
    automatically collected and used to build a WorkflowSpec.

    Example:
        class ScrapeWorkflow(WorkflowDefinition[PersistResult]):
            name = "scrape_pipeline"

            fetch = TaskNode(fn=fetch_listing, args=("url",))
            parse = TaskNode(fn=parse_listing, waits_for=[fetch], args_from={"raw": fetch})
            persist = TaskNode(fn=persist_listing, waits_for=[parse], args_from={"data": parse})

            class Meta:
                output = persist  # Output type is PersistResult
                on_error = OnError.FAIL

        spec = ScrapeWorkflow.build(app)

    Attributes:
        name: Required workflow name (class attribute).
        Meta: Optional inner class for workflow configuration.
            - output: TaskNode/SubWorkflowNode to use as workflow output (default: None)
            - on_error: Error handling policy (default: OnError.FAIL)
            - success_policy: Custom success policy (default: None)
    """

    # Class attributes to be defined by subclasses
    name: ClassVar[str]

    # Populated by metaclass
    _workflow_nodes: ClassVar[list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]]]

    @classmethod
    def get_workflow_nodes(
        cls,
    ) -> list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]]:
        """Return collected workflow nodes or an empty list if none were defined."""
        nodes = getattr(cls, '_workflow_nodes', None)
        if not isinstance(nodes, list):
            return []
        return cast(list[tuple[str, TaskNode[Any] | SubWorkflowNode[Any]]], nodes)

    @classmethod
    def build(cls, app: 'Horsies') -> WorkflowSpec:
        """
        Build a WorkflowSpec from this workflow definition.

        Collects all TaskNode class attributes, assigns node_ids from attribute
        names, and creates a WorkflowSpec with the configured options.

        Args:
            app: Horsies application instance (provides broker).

        Returns:
            WorkflowSpec ready for execution.

        Raises:
            WorkflowValidationError: If workflow definition is invalid.
        """
        # Validate name is defined
        if not hasattr(cls, 'name') or not cls.name:
            raise WorkflowValidationError(
                f"WorkflowDefinition '{cls.__name__}' must define a 'name' class attribute"
            )

        # Get collected nodes from metaclass
        nodes = cls.get_workflow_nodes()
        if not nodes:
            raise WorkflowValidationError(
                f"WorkflowDefinition '{cls.__name__}' has no TaskNode attributes"
            )

        # Assign node_id from attribute name (if not already set)
        for attr_name, node in nodes:
            if node.node_id is None:
                node.node_id = attr_name

        # Extract task list (preserving definition order)
        tasks = [node for _, node in nodes]

        # Get Meta configuration
        output: TaskNode[Any] | SubWorkflowNode[Any] | None = None
        on_error: OnError = OnError.FAIL
        success_policy: SuccessPolicy | None = None

        meta: type[Any] | None = getattr(cls, 'Meta', None)
        if meta is not None:
            output = getattr(meta, 'output', None)
            on_error = getattr(meta, 'on_error', OnError.FAIL)
            success_policy = getattr(meta, 'success_policy', None)

        # Build WorkflowSpec
        spec = app.workflow(
            name=cls.name,
            tasks=tasks,
            output=output,
            on_error=on_error,
            success_policy=success_policy,
        )
        spec.workflow_def_module = cls.__module__
        spec.workflow_def_qualname = cls.__qualname__
        return spec

    @classmethod
    def build_with(
        cls,
        app: 'Horsies',
        *args: Any,
        **params: Any,
    ) -> WorkflowSpec:
        """
        Build a WorkflowSpec with runtime parameters.

        Subclasses can override this to apply params to TaskNodes.
        Default implementation forwards to build().
        """
        _ = args
        _ = params
        spec = cls.build(app)
        spec.workflow_def_module = cls.__module__
        spec.workflow_def_qualname = cls.__qualname__
        return spec