highway-dsl 0.0.2__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of highway-dsl might be problematic.
- highway_dsl/__init__.py +15 -6
- highway_dsl/workflow_dsl.py +498 -72
- highway_dsl-1.2.0.dist-info/METADATA +481 -0
- highway_dsl-1.2.0.dist-info/RECORD +7 -0
- highway_dsl-0.0.2.dist-info/METADATA +0 -227
- highway_dsl-0.0.2.dist-info/RECORD +0 -7
- {highway_dsl-0.0.2.dist-info → highway_dsl-1.2.0.dist-info}/WHEEL +0 -0
- {highway_dsl-0.0.2.dist-info → highway_dsl-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {highway_dsl-0.0.2.dist-info → highway_dsl-1.2.0.dist-info}/top_level.txt +0 -0
highway_dsl/workflow_dsl.py
CHANGED
@@ -1,11 +1,12 @@
 # workflow_dsl.py
-from …
-from …
+from abc import ABC
+from collections.abc import Callable
 from datetime import datetime, timedelta
+from enum import Enum
+from typing import Any, Optional, Union
+
 import yaml
-import …
-from abc import ABC, abstractmethod
-from pydantic import BaseModel, Field, model_validator, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 
 class OperatorType(Enum):
@@ -16,6 +17,9 @@ class OperatorType(Enum):
     FOREACH = "foreach"
     SWITCH = "switch"
     TRY_CATCH = "try_catch"
+    WHILE = "while"
+    EMIT_EVENT = "emit_event"
+    WAIT_FOR_EVENT = "wait_for_event"
 
 
 class RetryPolicy(BaseModel):
@@ -27,38 +31,45 @@ class RetryPolicy(BaseModel):
 class TimeoutPolicy(BaseModel):
     timeout: timedelta = Field(..., description="Timeout duration")
     kill_on_timeout: bool = Field(
-        True, description="Whether to kill the task on timeout"
+        True, description="Whether to kill the task on timeout",
     )
 
 
 class BaseOperator(BaseModel, ABC):
     task_id: str
     operator_type: OperatorType
-    dependencies: …
-    retry_policy: …
-    timeout_policy: …
-    metadata: …
+    dependencies: list[str] = Field(default_factory=list)
+    retry_policy: RetryPolicy | None = None
+    timeout_policy: TimeoutPolicy | None = None
+    metadata: dict[str, Any] = Field(default_factory=dict)
+    description: str = Field(default="", description="Task description")
+    # Phase 3: Callback hooks
+    on_success_task_id: str | None = Field(None, description="Task to run on success")
+    on_failure_task_id: str | None = Field(None, description="Task to run on failure")
+    is_internal_loop_task: bool = Field(
+        default=False, exclude=True,
+    )  # Mark if task is internal to a loop
 
     model_config = ConfigDict(use_enum_values=True, arbitrary_types_allowed=True)
 
 
 class TaskOperator(BaseOperator):
     function: str
-    args: …
-    kwargs: …
-    result_key: …
+    args: list[Any] = Field(default_factory=list)
+    kwargs: dict[str, Any] = Field(default_factory=dict)
+    result_key: str | None = None
     operator_type: OperatorType = Field(OperatorType.TASK, frozen=True)
 
 
 class ConditionOperator(BaseOperator):
     condition: str
-    if_true: str
-    if_false: str
+    if_true: str | None
+    if_false: str | None
     operator_type: OperatorType = Field(OperatorType.CONDITION, frozen=True)
 
 
 class WaitOperator(BaseOperator):
-    wait_for: …
+    wait_for: timedelta | datetime | str
     operator_type: OperatorType = Field(OperatorType.WAIT, frozen=True)
 
     @model_validator(mode="before")
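Taken together, this hunk gives every operator a typed dependency list, optional retry/timeout policies, and the new Phase 3 callback fields. A minimal sketch of constructing a task under these definitions (the dotted function path and task ids are invented for illustration, not taken from the package):

    from datetime import timedelta

    from highway_dsl.workflow_dsl import TaskOperator, TimeoutPolicy

    # Hypothetical task; "etl.extract_rows" is a placeholder function path.
    extract = TaskOperator(
        task_id="extract",
        function="etl.extract_rows",
        result_key="raw_rows",
        timeout_policy=TimeoutPolicy(timeout=timedelta(minutes=10), kill_on_timeout=True),
        on_success_task_id="load",       # Phase 3 callback hook
        on_failure_task_id="alert_ops",  # Phase 3 callback hook
    )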
@@ -73,7 +84,7 @@ class WaitOperator(BaseOperator):
             data["wait_for"] = datetime.fromisoformat(wait_for.split(":", 1)[1])
         return data
 
-    def model_dump(self, **kwargs) -> …
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
         data = super().model_dump(**kwargs)
         wait_for = data["wait_for"]
         if isinstance(wait_for, timedelta):
@@ -84,44 +95,142 @@ class WaitOperator(BaseOperator):
 
 
 class ParallelOperator(BaseOperator):
-    branches: …
+    branches: dict[str, list[str]] = Field(default_factory=dict)
+    timeout: int | None = Field(None, description="Optional timeout in seconds for branch execution")
     operator_type: OperatorType = Field(OperatorType.PARALLEL, frozen=True)
 
 
 class ForEachOperator(BaseOperator):
     items: str
-    …
+    loop_body: list[
+        Union[
+            TaskOperator,
+            ConditionOperator,
+            WaitOperator,
+            ParallelOperator,
+            "ForEachOperator",
+            "WhileOperator",
+            "EmitEventOperator",
+            "WaitForEventOperator",
+            "SwitchOperator",
+        ]
+    ] = Field(default_factory=list)
     operator_type: OperatorType = Field(OperatorType.FOREACH, frozen=True)
 
 
-class …
-…
-…
-    description: str = ""
-    tasks: Dict[
-        str,
+class WhileOperator(BaseOperator):
+    condition: str
+    loop_body: list[
         Union[
             TaskOperator,
             ConditionOperator,
             WaitOperator,
             ParallelOperator,
             ForEachOperator,
-…
+            "WhileOperator",
+            "EmitEventOperator",
+            "WaitForEventOperator",
+            "SwitchOperator",
+        ]
+    ] = Field(default_factory=list)
+    operator_type: OperatorType = Field(OperatorType.WHILE, frozen=True)
+
+
+class EmitEventOperator(BaseOperator):
+    """Phase 2: Emit an event that other workflows can wait for."""
+    event_name: str = Field(..., description="Name of the event to emit")
+    payload: dict[str, Any] = Field(default_factory=dict, description="Event payload data")
+    operator_type: OperatorType = Field(OperatorType.EMIT_EVENT, frozen=True)
+
+
+class WaitForEventOperator(BaseOperator):
+    """Phase 2: Wait for an external event with optional timeout."""
+    event_name: str = Field(..., description="Name of the event to wait for")
+    timeout_seconds: int | None = Field(None, description="Timeout in seconds (None = wait forever)")
+    operator_type: OperatorType = Field(OperatorType.WAIT_FOR_EVENT, frozen=True)
+
+
+class SwitchOperator(BaseOperator):
+    """Phase 4: Multi-branch switch/case operator."""
+    switch_on: str = Field(..., description="Expression to evaluate for switch")
+    cases: dict[str, str] = Field(default_factory=dict, description="Map of case values to task IDs")
+    default: str | None = Field(None, description="Default task ID if no case matches")
+    operator_type: OperatorType = Field(OperatorType.SWITCH, frozen=True)
+
+
+class Workflow(BaseModel):
+    name: str
+    version: str = "1.1.0"
+    description: str = ""
+    tasks: dict[
+        str,
+        TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
     ] = Field(default_factory=dict)
-    variables: …
-    start_task: …
+    variables: dict[str, Any] = Field(default_factory=dict)
+    start_task: str | None = None
+
+    # Phase 1: Scheduling metadata
+    schedule: str | None = Field(None, description="Cron expression for scheduled execution")
+    start_date: datetime | None = Field(None, description="When the schedule becomes active")
+    catchup: bool = Field(False, description="Whether to backfill missed runs")
+    is_paused: bool = Field(False, description="Whether the workflow is paused")
+    tags: list[str] = Field(default_factory=list, description="Workflow categorization tags")
+    max_active_runs: int = Field(1, description="Maximum number of concurrent runs")
+    default_retry_policy: RetryPolicy | None = Field(None, description="Default retry policy for all tasks")
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_workflow_name_and_version(cls, data: Any) -> Any:
+        """Validate workflow name and version don't contain '__' (double underscore).
+
+        The double underscore is reserved as a separator for display purposes:
+        {workflow_name}__{version}__{step_name}
+
+        Workflow names must match: ^[a-z][a-z0-9_]*$ (lowercase, alphanumeric, single underscore)
+        Workflow versions must match: ^[a-zA-Z0-9._-]+$ (semver compatible)
+        """
+        import re
+
+        if isinstance(data, dict):
+            name = data.get("name", "")
+            version = data.get("version", "")
+
+            # Check for double underscore (reserved separator)
+            if "__" in name:
+                msg = f"Workflow name '{name}' cannot contain '__' (double underscore) - it's reserved as a separator"
+                raise ValueError(msg)
+
+            if "__" in version:
+                msg = f"Workflow version '{version}' cannot contain '__' (double underscore) - it's reserved as a separator"
+                raise ValueError(msg)
+
+            # Validate workflow name format
+            if name and not re.match(r"^[a-z][a-z0-9_]*$", name):
+                msg = f"Workflow name '{name}' must start with lowercase letter and contain only lowercase letters, digits, and single underscores"
+                raise ValueError(msg)
+
+            # Validate workflow version format (semver compatible)
+            if version and not re.match(r"^[a-zA-Z0-9._-]+$", version):
+                msg = f"Workflow version '{version}' must contain only alphanumeric characters, dots, hyphens, and underscores (semver compatible)"
+                raise ValueError(msg)
+
+        return data
 
     @model_validator(mode="before")
     @classmethod
     def validate_tasks(cls, data: Any) -> Any:
         if isinstance(data, dict) and "tasks" in data:
             validated_tasks = {}
-            operator_classes: …
+            operator_classes: dict[str, type[BaseOperator]] = {
                 OperatorType.TASK.value: TaskOperator,
                 OperatorType.CONDITION.value: ConditionOperator,
                 OperatorType.WAIT.value: WaitOperator,
                 OperatorType.PARALLEL.value: ParallelOperator,
                 OperatorType.FOREACH.value: ForEachOperator,
+                OperatorType.WHILE.value: WhileOperator,
+                OperatorType.EMIT_EVENT.value: EmitEventOperator,
+                OperatorType.WAIT_FOR_EVENT.value: WaitForEventOperator,
+                OperatorType.SWITCH.value: SwitchOperator,
             }
             for task_id, task_data in data["tasks"].items():
                 operator_type = task_data.get("operator_type")
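The name/version validator is the most visible behavioral change for 0.0.2 callers: names the old model accepted can now raise. A quick sketch of the rules (example names only; pydantic v2 surfaces the ValueError as a ValidationError):

    from pydantic import ValidationError

    from highway_dsl.workflow_dsl import Workflow

    Workflow(name="nightly_etl", version="1.2.0")   # ok: matches ^[a-z][a-z0-9_]*$

    try:
        Workflow(name="nightly__etl")               # '__' is reserved for
    except ValidationError as exc:                  # {workflow_name}__{version}__{step_name}
        print(exc)

    try:
        Workflow(name="NightlyETL")                 # must start with a lowercase letter
    except ValidationError as exc:
        print(exc)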
@@ -129,24 +238,19 @@ class Workflow(BaseModel):
                     operator_class = operator_classes[operator_type]
                     validated_tasks[task_id] = operator_class.model_validate(task_data)
                 else:
-                    …
+                    msg = f"Unknown operator type: {operator_type}"
+                    raise ValueError(msg)
             data["tasks"] = validated_tasks
         return data
 
     def add_task(
         self,
-        task: …
-            TaskOperator,
-            ConditionOperator,
-            WaitOperator,
-            ParallelOperator,
-            ForEachOperator,
-        ],
+        task: TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
     ) -> "Workflow":
         self.tasks[task.task_id] = task
         return self
 
-    def set_variables(self, variables: …
+    def set_variables(self, variables: dict[str, Any]) -> "Workflow":
         self.variables.update(variables)
         return self
 
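The validate_tasks hook above dispatches plain dicts on their operator_type string through the operator_classes table, which now covers the four new operators; unknown types raise. A small sketch of the dispatch (the workflow name, task id, and event name are invented):

    from highway_dsl.workflow_dsl import EmitEventOperator, Workflow

    wf = Workflow.model_validate({
        "name": "dispatch_demo",
        "tasks": {
            "announce": {
                "task_id": "announce",
                "operator_type": "emit_event",   # routed to EmitEventOperator
                "event_name": "rows_ready",
            },
        },
    })
    assert isinstance(wf.tasks["announce"], EmitEventOperator)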
@@ -154,6 +258,42 @@ class Workflow(BaseModel):
         self.start_task = task_id
         return self
 
+    # Phase 1: Scheduling methods
+    def set_schedule(self, cron: str) -> "Workflow":
+        """Set the cron schedule for this workflow."""
+        self.schedule = cron
+        return self
+
+    def set_start_date(self, start_date: datetime) -> "Workflow":
+        """Set when the schedule becomes active."""
+        self.start_date = start_date
+        return self
+
+    def set_catchup(self, enabled: bool) -> "Workflow":
+        """Set whether to backfill missed runs."""
+        self.catchup = enabled
+        return self
+
+    def set_paused(self, paused: bool) -> "Workflow":
+        """Set whether the workflow is paused."""
+        self.is_paused = paused
+        return self
+
+    def add_tags(self, *tags: str) -> "Workflow":
+        """Add tags to the workflow."""
+        self.tags.extend(tags)
+        return self
+
+    def set_max_active_runs(self, count: int) -> "Workflow":
+        """Set maximum number of concurrent runs."""
+        self.max_active_runs = count
+        return self
+
+    def set_default_retry_policy(self, policy: RetryPolicy) -> "Workflow":
+        """Set default retry policy for all tasks."""
+        self.default_retry_policy = policy
+        return self
+
     def to_yaml(self) -> str:
         data = self.model_dump(mode="json", by_alias=True, exclude_none=True)
         return yaml.dump(data, default_flow_style=False)
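All of the new scheduling setters return self, so workflow metadata can be configured fluently. A sketch with illustrative values:

    from datetime import datetime

    from highway_dsl.workflow_dsl import Workflow

    wf = (
        Workflow(name="nightly_etl")
        .set_schedule("0 2 * * *")             # cron: 02:00 every day
        .set_start_date(datetime(2026, 1, 1))
        .set_catchup(False)
        .add_tags("etl", "nightly")
        .set_max_active_runs(1)
    )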
@@ -161,6 +301,68 @@ class Workflow(BaseModel):
     def to_json(self) -> str:
         return self.model_dump_json(indent=2)
 
+    def to_mermaid(self) -> str:
+        """ convert to mermaid state diagram format """
+        lines = ["stateDiagram-v2"]
+
+        all_dependencies = {dep for task in self.tasks.values() for dep in task.dependencies}
+
+        for task_id, task in self.tasks.items():
+            # Add state with description for regular tasks
+            if task.description and not isinstance(task, (ForEachOperator, WhileOperator)):
+                lines.append(f'    state "{task.description}" as {task_id}')
+
+            # Add dependencies
+            if not task.dependencies:
+                if self.start_task == task_id or not self.start_task:
+                    lines.append(f'    [*] --> {task_id}')
+            else:
+                for dep in task.dependencies:
+                    lines.append(f'    {dep} --> {task_id}')
+
+            # Add transitions for conditional operator
+            if isinstance(task, ConditionOperator):
+                if task.if_true:
+                    lines.append(f'    {task_id} --> {task.if_true} : True')
+                if task.if_false:
+                    lines.append(f'    {task_id} --> {task.if_false} : False')
+
+            # Add composite state for parallel operator
+            if isinstance(task, ParallelOperator):
+                lines.append(f'    state {task_id} {{')
+                for i, branch in enumerate(task.branches):
+                    lines.append(f'        state "Branch {i+1}" as {branch}')
+                    if i < len(task.branches) - 1:
+                        lines.append('        --')
+                lines.append('    }')
+
+            # Add composite state for foreach operator
+            if isinstance(task, ForEachOperator):
+                lines.append(f'    state {task_id} {{')
+                for sub_task in task.loop_body:
+                    if sub_task.description:
+                        lines.append(f'        state "{sub_task.description}" as {sub_task.task_id}')
+                    else:
+                        lines.append(f'        {sub_task.task_id}')
+                lines.append('    }')
+
+            # Add composite state for while operator
+            if isinstance(task, WhileOperator):
+                lines.append(f'    state {task_id} {{')
+                for sub_task in task.loop_body:
+                    if sub_task.description:
+                        lines.append(f'        state "{sub_task.description}" as {sub_task.task_id}')
+                    else:
+                        lines.append(f'        {sub_task.task_id}')
+                lines.append('    }')
+
+            # End states
+            if task_id not in all_dependencies:
+                if not (isinstance(task, ConditionOperator) and (task.if_true or task.if_false)):
+                    lines.append(f'    {task_id} --> [*]')
+
+        return "\n".join(lines)
+
     @classmethod
     def from_yaml(cls, yaml_str: str) -> "Workflow":
         data = yaml.safe_load(yaml_str)
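to_mermaid walks the task graph once: [*] edges mark entry and terminal tasks, condition tasks get labeled True/False transitions, and parallel/foreach/while bodies render as composite states. A sketch of the expected output (task names invented; indentation of the emitted diagram is approximate):

    from highway_dsl.workflow_dsl import TaskOperator, Workflow

    wf = Workflow(name="mermaid_demo")
    wf.add_task(TaskOperator(task_id="extract", function="etl.extract", description="Pull rows"))
    wf.add_task(TaskOperator(task_id="load", function="etl.load", dependencies=["extract"]))
    print(wf.to_mermaid())
    # stateDiagram-v2
    #     state "Pull rows" as extract
    #     [*] --> extract
    #     extract --> load
    #     load --> [*]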
@@ -172,66 +374,208 @@ class Workflow(BaseModel):
 
 
 class WorkflowBuilder:
-    def __init__(…
+    def __init__(
+        self,
+        name: str,
+        existing_workflow: Workflow | None = None,
+        parent: Optional["WorkflowBuilder"] = None,
+    ) -> None:
         if existing_workflow:
             self.workflow = existing_workflow
         else:
-            self.workflow = Workflow(
-                …
+            self.workflow = Workflow(
+                name=name,
+                version="1.1.0",
+                description="",
+                tasks={},
+                variables={},
+                start_task=None,
+                schedule=None,
+                start_date=None,
+                catchup=False,
+                is_paused=False,
+                tags=[],
+                max_active_runs=1,
+                default_retry_policy=None,
+            )
+        self._current_task: str | None = None
+        self.parent = parent
+
+    def _add_task(
+        self,
+        task: TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
+        **kwargs: Any,
+    ) -> None:
+        dependencies = kwargs.get("dependencies", [])
+        if self._current_task and not dependencies:
+            dependencies.append(self._current_task)
+
+        task.dependencies = sorted(set(dependencies))
 
-    def task(self, task_id: str, function: str, **kwargs) -> "WorkflowBuilder":
-        task = TaskOperator(task_id=task_id, function=function, **kwargs)
-        if self._current_task:
-            task.dependencies.append(self._current_task)
         self.workflow.add_task(task)
-        self._current_task = task_id
+        self._current_task = task.task_id
+
+    def task(self, task_id: str, function: str, **kwargs: Any) -> "WorkflowBuilder":
+        task = TaskOperator(task_id=task_id, function=function, **kwargs)
+        self._add_task(task, **kwargs)
         return self
 
     def condition(
-        self, …
+        self,
+        task_id: str,
+        condition: str,
+        if_true: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
+        if_false: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
+        **kwargs: Any,
     ) -> "WorkflowBuilder":
+        true_builder = if_true(WorkflowBuilder(f"{task_id}_true", parent=self))
+        false_builder = if_false(WorkflowBuilder(f"{task_id}_false", parent=self))
+
+        true_tasks = list(true_builder.workflow.tasks.keys())
+        false_tasks = list(false_builder.workflow.tasks.keys())
+
         task = ConditionOperator(
             task_id=task_id,
             condition=condition,
-            if_true=…
-            if_false=…
+            if_true=true_tasks[0] if true_tasks else None,
+            if_false=false_tasks[0] if false_tasks else None,
             **kwargs,
         )
-        …
-        …
-        …
+
+        self._add_task(task, **kwargs)
+
+        for task_obj in true_builder.workflow.tasks.values():
+            # Only add the condition task as dependency, preserve original dependencies
+            if task_id not in task_obj.dependencies:
+                task_obj.dependencies.append(task_id)
+            self.workflow.add_task(task_obj)
+        for task_obj in false_builder.workflow.tasks.values():
+            # Only add the condition task as dependency, preserve original dependencies
+            if task_id not in task_obj.dependencies:
+                task_obj.dependencies.append(task_id)
+            self.workflow.add_task(task_obj)
+
         self._current_task = task_id
         return self
 
     def wait(
-        self, task_id: str, wait_for: …
+        self, task_id: str, wait_for: timedelta | datetime | str, **kwargs: Any,
     ) -> "WorkflowBuilder":
         task = WaitOperator(task_id=task_id, wait_for=wait_for, **kwargs)
-        …
-        task.dependencies.append(self._current_task)
-        self.workflow.add_task(task)
-        self._current_task = task_id
+        self._add_task(task, **kwargs)
         return self
 
     def parallel(
-        self, …
+        self,
+        task_id: str,
+        branches: dict[str, Callable[["WorkflowBuilder"], "WorkflowBuilder"]],
+        **kwargs: Any,
     ) -> "WorkflowBuilder":
-        …
-        …
-        …
-        …
+        branch_builders = {}
+        for name, branch_func in branches.items():
+            branch_builder = branch_func(
+                WorkflowBuilder(f"{task_id}_{name}", parent=self),
+            )
+            branch_builders[name] = branch_builder
+
+        branch_tasks = {
+            name: list(builder.workflow.tasks.keys())
+            for name, builder in branch_builders.items()
+        }
+
+        task = ParallelOperator(task_id=task_id, branches=branch_tasks, **kwargs)
+
+        self._add_task(task, **kwargs)
+
+        for builder in branch_builders.values():
+            for task_obj in builder.workflow.tasks.values():
+                # Only add the parallel task as dependency to non-internal tasks,
+                # preserve original dependencies
+                if (
+                    not getattr(task_obj, "is_internal_loop_task", False)
+                    and task_id not in task_obj.dependencies
+                ):
+                    task_obj.dependencies.append(task_id)
+                self.workflow.add_task(task_obj)
+
         self._current_task = task_id
         return self
 
     def foreach(
-        self, …
+        self,
+        task_id: str,
+        items: str,
+        loop_body: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
+        **kwargs: Any,
     ) -> "WorkflowBuilder":
+        # Create a temporary builder for the loop body.
+        temp_builder = WorkflowBuilder(f"{task_id}_loop", parent=self)
+        loop_builder = loop_body(temp_builder)
+        loop_tasks = list(loop_builder.workflow.tasks.values())
+
+        # Mark all loop body tasks as internal to prevent parallel dependency injection
+        for task_obj in loop_tasks:
+            task_obj.is_internal_loop_task = True
+
+        # Create the foreach operator
         task = ForEachOperator(
-            task_id=task_id,
+            task_id=task_id,
+            items=items,
+            loop_body=loop_tasks,
+            **kwargs,
         )
-        …
-        …
-        self.…
+
+        # Add the foreach task to workflow to establish initial dependencies
+        self._add_task(task, **kwargs)
+
+        # Add the foreach task as dependency to the FIRST task in the loop body
+        # and preserve the original dependency chain within the loop
+        if loop_tasks:
+            first_task = loop_tasks[0]
+            if task_id not in first_task.dependencies:
+                first_task.dependencies.append(task_id)
+
+        # Add all loop tasks to workflow
+        for task_obj in loop_tasks:
+            self.workflow.add_task(task_obj)
+
+        self._current_task = task_id
+        return self
+
+    def while_loop(
+        self,
+        task_id: str,
+        condition: str,
+        loop_body: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
+        **kwargs: Any,
+    ) -> "WorkflowBuilder":
+        loop_builder = loop_body(WorkflowBuilder(f"{task_id}_loop", parent=self))
+        loop_tasks = list(loop_builder.workflow.tasks.values())
+
+        # Mark all loop body tasks as internal to prevent parallel dependency injection
+        for task_obj in loop_tasks:
+            task_obj.is_internal_loop_task = True
+
+        task = WhileOperator(
+            task_id=task_id,
+            condition=condition,
+            loop_body=loop_tasks,
+            **kwargs,
+        )
+
+        self._add_task(task, **kwargs)
+
+        # Fix: Only add the while task as dependency to the FIRST task in the loop body
+        # and preserve the original dependency chain within the loop
+        if loop_tasks:
+            first_task = loop_tasks[0]
+            if task_id not in first_task.dependencies:
+                first_task.dependencies.append(task_id)
+
+        # Add all loop tasks to workflow without modifying their dependencies further
+        for task_obj in loop_tasks:
+            self.workflow.add_task(task_obj)
+
         self._current_task = task_id
         return self
 
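Every builder method now funnels through _add_task, which wires in the previous task as an implicit dependency unless dependencies is passed explicitly, and the branching methods take callables that receive a child builder. A short sketch of the resulting fluent API (function paths and the condition expression are invented):

    from highway_dsl.workflow_dsl import WorkflowBuilder

    wf = (
        WorkflowBuilder("branchy_demo")
        .task("fetch", "demo.fetch")
        .condition(
            "check",
            condition="results.fetch > 0",    # expression syntax is illustrative
            if_true=lambda b: b.task("process", "demo.process"),
            if_false=lambda b: b.task("skip", "demo.skip"),
        )
        .build()
    )
    print(sorted(wf.tasks))   # ['check', 'fetch', 'process', 'skip']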
@@ -242,24 +586,106 @@ class WorkflowBuilder:
         backoff_factor: float = 2.0,
     ) -> "WorkflowBuilder":
         if self._current_task and isinstance(
-            self.workflow.tasks[self._current_task], TaskOperator
+            self.workflow.tasks[self._current_task], TaskOperator,
         ):
             self.workflow.tasks[self._current_task].retry_policy = RetryPolicy(
-                max_retries=max_retries, delay=delay, backoff_factor=backoff_factor
+                max_retries=max_retries, delay=delay, backoff_factor=backoff_factor,
             )
         return self
 
     def timeout(
-        self, timeout: timedelta, kill_on_timeout: bool = True
+        self, timeout: timedelta, kill_on_timeout: bool = True,
     ) -> "WorkflowBuilder":
         if self._current_task and isinstance(
-            self.workflow.tasks[self._current_task], TaskOperator
+            self.workflow.tasks[self._current_task], TaskOperator,
        ):
             self.workflow.tasks[self._current_task].timeout_policy = TimeoutPolicy(
-                timeout=timeout, kill_on_timeout=kill_on_timeout
+                timeout=timeout, kill_on_timeout=kill_on_timeout,
             )
         return self
 
+    # Phase 2: Event-based operators
+    def emit_event(self, task_id: str, event_name: str, **kwargs: Any) -> "WorkflowBuilder":
+        """Emit an event that other workflows can wait for."""
+        task = EmitEventOperator(task_id=task_id, event_name=event_name, **kwargs)
+        self._add_task(task, **kwargs)
+        return self
+
+    def wait_for_event(
+        self, task_id: str, event_name: str, timeout_seconds: int | None = None, **kwargs: Any,
+    ) -> "WorkflowBuilder":
+        """Wait for an external event with optional timeout."""
+        task = WaitForEventOperator(
+            task_id=task_id, event_name=event_name, timeout_seconds=timeout_seconds, **kwargs,
+        )
+        self._add_task(task, **kwargs)
+        return self
+
+    # Phase 3: Callback hooks (applies to current task)
+    def on_success(self, success_task_id: str) -> "WorkflowBuilder":
+        """Set the task to run when the current task succeeds."""
+        if self._current_task:
+            self.workflow.tasks[self._current_task].on_success_task_id = success_task_id
+        return self
+
+    def on_failure(self, failure_task_id: str) -> "WorkflowBuilder":
+        """Set the task to run when the current task fails."""
+        if self._current_task:
+            self.workflow.tasks[self._current_task].on_failure_task_id = failure_task_id
+        return self
+
+    # Phase 4: Switch operator
+    def switch(
+        self,
+        task_id: str,
+        switch_on: str,
+        cases: dict[str, str],
+        default: str | None = None,
+        **kwargs: Any,
+    ) -> "WorkflowBuilder":
+        """Multi-branch switch/case operator."""
+        task = SwitchOperator(
+            task_id=task_id, switch_on=switch_on, cases=cases, default=default, **kwargs,
+        )
+        self._add_task(task, **kwargs)
+        return self
+
+    # Phase 1: Scheduling methods (delegate to Workflow)
+    def set_schedule(self, cron: str) -> "WorkflowBuilder":
+        """Set the cron schedule for this workflow."""
+        self.workflow.set_schedule(cron)
+        return self
+
+    def set_start_date(self, start_date: datetime) -> "WorkflowBuilder":
+        """Set when the schedule becomes active."""
+        self.workflow.set_start_date(start_date)
+        return self
+
+    def set_catchup(self, enabled: bool) -> "WorkflowBuilder":
+        """Set whether to backfill missed runs."""
+        self.workflow.set_catchup(enabled)
+        return self
+
+    def set_paused(self, paused: bool) -> "WorkflowBuilder":
+        """Set whether the workflow is paused."""
+        self.workflow.set_paused(paused)
+        return self
+
+    def add_tags(self, *tags: str) -> "WorkflowBuilder":
+        """Add tags to the workflow."""
+        self.workflow.add_tags(*tags)
+        return self
+
+    def set_max_active_runs(self, count: int) -> "WorkflowBuilder":
+        """Set maximum number of concurrent runs."""
+        self.workflow.set_max_active_runs(count)
+        return self
+
+    def set_default_retry_policy(self, policy: RetryPolicy) -> "WorkflowBuilder":
+        """Set default retry policy for all tasks."""
+        self.workflow.set_default_retry_policy(policy)
+        return self
+
     def build(self) -> Workflow:
         if not self.workflow.start_task and self.workflow.tasks:
             self.workflow.start_task = next(iter(self.workflow.tasks.keys()))