highway-dsl 0.0.2__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of highway-dsl has been flagged as potentially problematic; see the associated security report for details.

@@ -1,11 +1,12 @@
1
1
  # workflow_dsl.py
2
- from typing import Any, Dict, List, Optional, Union, Callable, Type
3
- from enum import Enum
2
+ from abc import ABC
3
+ from collections.abc import Callable
4
4
  from datetime import datetime, timedelta
5
+ from enum import Enum
6
+ from typing import Any, Optional, Union
7
+
5
8
  import yaml
6
- import json
7
- from abc import ABC, abstractmethod
8
- from pydantic import BaseModel, Field, model_validator, ConfigDict
9
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
9
10
 
10
11
 
11
12
  class OperatorType(Enum):
@@ -16,6 +17,9 @@ class OperatorType(Enum):
16
17
  FOREACH = "foreach"
17
18
  SWITCH = "switch"
18
19
  TRY_CATCH = "try_catch"
20
+ WHILE = "while"
21
+ EMIT_EVENT = "emit_event"
22
+ WAIT_FOR_EVENT = "wait_for_event"
19
23
 
20
24
 
21
25
  class RetryPolicy(BaseModel):
@@ -27,38 +31,45 @@ class RetryPolicy(BaseModel):
27
31
class TimeoutPolicy(BaseModel):
    """Policy limiting how long a task may run and what happens on expiry."""

    # Maximum wall-clock duration the task is allowed to run.
    timeout: timedelta = Field(..., description="Timeout duration")
    # When True the task is killed on expiry; when False it is only flagged.
    kill_on_timeout: bool = Field(
        True, description="Whether to kill the task on timeout",
    )
32
36
 
33
37
 
34
38
class BaseOperator(BaseModel, ABC):
    """Common fields shared by every workflow operator (one task node).

    Concrete subclasses pin ``operator_type`` with a frozen Field.
    """

    # Unique id of this task within the workflow.
    task_id: str
    operator_type: OperatorType
    # Ids of tasks that must complete before this one runs.
    dependencies: list[str] = Field(default_factory=list)
    retry_policy: RetryPolicy | None = None
    timeout_policy: TimeoutPolicy | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    description: str = Field(default="", description="Task description")
    # Phase 3: Callback hooks
    on_success_task_id: str | None = Field(None, description="Task to run on success")
    on_failure_task_id: str | None = Field(None, description="Task to run on failure")
    # Excluded from serialization; set by the builder for tasks that live
    # inside a foreach/while loop body.
    is_internal_loop_task: bool = Field(
        default=False, exclude=True,
    )  # Mark if task is internal to a loop

    # use_enum_values makes operator_type serialize as its string value.
    model_config = ConfigDict(use_enum_values=True, arbitrary_types_allowed=True)
43
54
 
44
55
 
45
56
class TaskOperator(BaseOperator):
    """Executes a named function with positional and keyword arguments."""

    # Dotted/registered name of the callable to execute.
    function: str
    args: list[Any] = Field(default_factory=list)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    # Variable name under which the task's result is stored, if any.
    result_key: str | None = None
    operator_type: OperatorType = Field(OperatorType.TASK, frozen=True)
51
62
 
52
63
 
53
64
class ConditionOperator(BaseOperator):
    """Two-way branch: routes to one of two task ids based on *condition*."""

    # Expression evaluated at runtime to pick the branch.
    condition: str
    # Target task ids; None when the corresponding branch is empty.
    if_true: str | None
    if_false: str | None
    operator_type: OperatorType = Field(OperatorType.CONDITION, frozen=True)
58
69
 
59
70
 
60
71
  class WaitOperator(BaseOperator):
61
- wait_for: Union[timedelta, datetime, str]
72
+ wait_for: timedelta | datetime | str
62
73
  operator_type: OperatorType = Field(OperatorType.WAIT, frozen=True)
63
74
 
64
75
  @model_validator(mode="before")
@@ -73,7 +84,7 @@ class WaitOperator(BaseOperator):
73
84
  data["wait_for"] = datetime.fromisoformat(wait_for.split(":", 1)[1])
74
85
  return data
75
86
 
76
- def model_dump(self, **kwargs) -> Dict[str, Any]:
87
+ def model_dump(self, **kwargs: Any) -> dict[str, Any]:
77
88
  data = super().model_dump(**kwargs)
78
89
  wait_for = data["wait_for"]
79
90
  if isinstance(wait_for, timedelta):
@@ -84,44 +95,142 @@ class WaitOperator(BaseOperator):
84
95
 
85
96
 
86
97
class ParallelOperator(BaseOperator):
    """Runs several named branches (lists of task ids) concurrently."""

    # Branch name -> ordered task ids belonging to that branch.
    branches: dict[str, list[str]] = Field(default_factory=dict)
    timeout: int | None = Field(None, description="Optional timeout in seconds for branch execution")
    operator_type: OperatorType = Field(OperatorType.PARALLEL, frozen=True)
89
101
 
90
102
 
91
103
class ForEachOperator(BaseOperator):
    """Iterates over *items* (an expression) running *loop_body* per item."""

    # Expression yielding the iterable to loop over.
    items: str
    # Nested operators executed for each item; string forward references
    # allow loops to nest loops, event operators, and switches.
    loop_body: list[
        Union[
            TaskOperator,
            ConditionOperator,
            WaitOperator,
            ParallelOperator,
            "ForEachOperator",
            "WhileOperator",
            "EmitEventOperator",
            "WaitForEventOperator",
            "SwitchOperator",
        ]
    ] = Field(default_factory=list)
    operator_type: OperatorType = Field(OperatorType.FOREACH, frozen=True)
95
119
 
96
120
 
97
- class Workflow(BaseModel):
98
- name: str
99
- version: str = "1.0.0"
100
- description: str = ""
101
- tasks: Dict[
102
- str,
121
class WhileOperator(BaseOperator):
    """Repeats *loop_body* while *condition* evaluates truthy."""

    # Expression re-evaluated before each iteration.
    condition: str
    loop_body: list[
        Union[
            TaskOperator,
            ConditionOperator,
            WaitOperator,
            ParallelOperator,
            ForEachOperator,
            "WhileOperator",
            "EmitEventOperator",
            "WaitForEventOperator",
            "SwitchOperator",
        ]
    ] = Field(default_factory=list)
    operator_type: OperatorType = Field(OperatorType.WHILE, frozen=True)
137
+
138
+
139
class EmitEventOperator(BaseOperator):
    """Phase 2: Emit an event that other workflows can wait for."""

    event_name: str = Field(..., description="Name of the event to emit")
    payload: dict[str, Any] = Field(default_factory=dict, description="Event payload data")
    operator_type: OperatorType = Field(OperatorType.EMIT_EVENT, frozen=True)
144
+
145
+
146
class WaitForEventOperator(BaseOperator):
    """Phase 2: Wait for an external event with optional timeout."""

    event_name: str = Field(..., description="Name of the event to wait for")
    timeout_seconds: int | None = Field(None, description="Timeout in seconds (None = wait forever)")
    operator_type: OperatorType = Field(OperatorType.WAIT_FOR_EVENT, frozen=True)
151
+
152
+
153
class SwitchOperator(BaseOperator):
    """Phase 4: Multi-branch switch/case operator."""

    switch_on: str = Field(..., description="Expression to evaluate for switch")
    cases: dict[str, str] = Field(default_factory=dict, description="Map of case values to task IDs")
    default: str | None = Field(None, description="Default task ID if no case matches")
    operator_type: OperatorType = Field(OperatorType.SWITCH, frozen=True)
159
+
160
+
161
class Workflow(BaseModel):
    """A named collection of operators plus variables and scheduling metadata."""

    name: str
    version: str = "1.1.0"
    description: str = ""
    # task_id -> operator instance (coerced by validate_tasks below).
    tasks: dict[
        str,
        TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
    ] = Field(default_factory=dict)
    variables: dict[str, Any] = Field(default_factory=dict)
    # Entry-point task id; WorkflowBuilder.build falls back to the first task.
    start_task: str | None = None

    # Phase 1: Scheduling metadata
    schedule: str | None = Field(None, description="Cron expression for scheduled execution")
    start_date: datetime | None = Field(None, description="When the schedule becomes active")
    catchup: bool = Field(False, description="Whether to backfill missed runs")
    is_paused: bool = Field(False, description="Whether the workflow is paused")
    tags: list[str] = Field(default_factory=list, description="Workflow categorization tags")
    max_active_runs: int = Field(1, description="Maximum number of concurrent runs")
    default_retry_policy: RetryPolicy | None = Field(None, description="Default retry policy for all tasks")
180
+
181
+ @model_validator(mode="before")
182
+ @classmethod
183
+ def validate_workflow_name_and_version(cls, data: Any) -> Any:
184
+ """Validate workflow name and version don't contain '__' (double underscore).
185
+
186
+ The double underscore is reserved as a separator for display purposes:
187
+ {workflow_name}__{version}__{step_name}
188
+
189
+ Workflow names must match: ^[a-z][a-z0-9_]*$ (lowercase, alphanumeric, single underscore)
190
+ Workflow versions must match: ^[a-zA-Z0-9._-]+$ (semver compatible)
191
+ """
192
+ import re
193
+
194
+ if isinstance(data, dict):
195
+ name = data.get("name", "")
196
+ version = data.get("version", "")
197
+
198
+ # Check for double underscore (reserved separator)
199
+ if "__" in name:
200
+ msg = f"Workflow name '{name}' cannot contain '__' (double underscore) - it's reserved as a separator"
201
+ raise ValueError(msg)
202
+
203
+ if "__" in version:
204
+ msg = f"Workflow version '{version}' cannot contain '__' (double underscore) - it's reserved as a separator"
205
+ raise ValueError(msg)
206
+
207
+ # Validate workflow name format
208
+ if name and not re.match(r"^[a-z][a-z0-9_]*$", name):
209
+ msg = f"Workflow name '{name}' must start with lowercase letter and contain only lowercase letters, digits, and single underscores"
210
+ raise ValueError(msg)
211
+
212
+ # Validate workflow version format (semver compatible)
213
+ if version and not re.match(r"^[a-zA-Z0-9._-]+$", version):
214
+ msg = f"Workflow version '{version}' must contain only alphanumeric characters, dots, hyphens, and underscores (semver compatible)"
215
+ raise ValueError(msg)
216
+
217
+ return data
113
218
 
114
219
    @model_validator(mode="before")
    @classmethod
    def validate_tasks(cls, data: Any) -> Any:
        """Coerce raw task dicts into their concrete operator classes.

        Dispatches on each task's ``operator_type`` string; raises
        ValueError for an unknown type.
        """
        if isinstance(data, dict) and "tasks" in data:
            validated_tasks = {}
            operator_classes: dict[str, type[BaseOperator]] = {
                OperatorType.TASK.value: TaskOperator,
                OperatorType.CONDITION.value: ConditionOperator,
                OperatorType.WAIT.value: WaitOperator,
                OperatorType.PARALLEL.value: ParallelOperator,
                OperatorType.FOREACH.value: ForEachOperator,
                OperatorType.WHILE.value: WhileOperator,
                OperatorType.EMIT_EVENT.value: EmitEventOperator,
                OperatorType.WAIT_FOR_EVENT.value: WaitForEventOperator,
                OperatorType.SWITCH.value: SwitchOperator,
            }
            for task_id, task_data in data["tasks"].items():
                operator_type = task_data.get("operator_type")
                if operator_type in operator_classes:
                    operator_class = operator_classes[operator_type]
                    validated_tasks[task_id] = operator_class.model_validate(task_data)
                else:
                    msg = f"Unknown operator type: {operator_type}"
                    raise ValueError(msg)
            data["tasks"] = validated_tasks
        return data
135
245
 
136
246
    def add_task(
        self,
        task: TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
    ) -> "Workflow":
        """Insert (or replace) *task* keyed by its task_id; returns self for chaining."""
        self.tasks[task.task_id] = task
        return self
148
252
 
149
- def set_variables(self, variables: Dict[str, Any]) -> "Workflow":
253
+ def set_variables(self, variables: dict[str, Any]) -> "Workflow":
150
254
  self.variables.update(variables)
151
255
  return self
152
256
 
@@ -154,6 +258,42 @@ class Workflow(BaseModel):
154
258
  self.start_task = task_id
155
259
  return self
156
260
 
261
    # Phase 1: Scheduling methods — each setter returns self for chaining.
    def set_schedule(self, cron: str) -> "Workflow":
        """Set the cron schedule for this workflow."""
        self.schedule = cron
        return self

    def set_start_date(self, start_date: datetime) -> "Workflow":
        """Set when the schedule becomes active."""
        self.start_date = start_date
        return self

    def set_catchup(self, enabled: bool) -> "Workflow":
        """Set whether to backfill missed runs."""
        self.catchup = enabled
        return self

    def set_paused(self, paused: bool) -> "Workflow":
        """Set whether the workflow is paused."""
        self.is_paused = paused
        return self

    def add_tags(self, *tags: str) -> "Workflow":
        """Add tags to the workflow."""
        self.tags.extend(tags)
        return self

    def set_max_active_runs(self, count: int) -> "Workflow":
        """Set maximum number of concurrent runs."""
        self.max_active_runs = count
        return self

    def set_default_retry_policy(self, policy: RetryPolicy) -> "Workflow":
        """Set default retry policy for all tasks."""
        self.default_retry_policy = policy
        return self
296
+
157
297
    def to_yaml(self) -> str:
        """Serialize to YAML via a JSON-mode model dump (None fields dropped)."""
        data = self.model_dump(mode="json", by_alias=True, exclude_none=True)
        return yaml.dump(data, default_flow_style=False)
@@ -161,6 +301,68 @@ class Workflow(BaseModel):
161
301
    def to_json(self) -> str:
        """Serialize the workflow to pretty-printed JSON."""
        return self.model_dump_json(indent=2)
163
303
 
304
    def to_mermaid(self) -> str:
        """ convert to mermaid state diagram format """
        lines = ["stateDiagram-v2"]

        # Every task id some other task depends on; leaves get an end-state edge.
        all_dependencies = {dep for task in self.tasks.values() for dep in task.dependencies}

        for task_id, task in self.tasks.items():
            # Add state with description for regular tasks
            if task.description and not isinstance(task, (ForEachOperator, WhileOperator)):
                lines.append(f' state "{task.description}" as {task_id}')

            # Add dependencies
            if not task.dependencies:
                if self.start_task == task_id or not self.start_task:
                    lines.append(f' [*] --> {task_id}')
            else:
                for dep in task.dependencies:
                    lines.append(f' {dep} --> {task_id}')

            # Add transitions for conditional operator
            if isinstance(task, ConditionOperator):
                if task.if_true:
                    lines.append(f' {task_id} --> {task.if_true} : True')
                if task.if_false:
                    lines.append(f' {task_id} --> {task.if_false} : False')

            # Add composite state for parallel operator
            if isinstance(task, ParallelOperator):
                lines.append(f' state {task_id} {{')
                # NOTE(review): iterating the branches dict yields branch *names*,
                # which are used as state aliases; the branch task-id lists are
                # not rendered — confirm this is intended.
                for i, branch in enumerate(task.branches):
                    lines.append(f' state "Branch {i+1}" as {branch}')
                    if i < len(task.branches) - 1:
                        lines.append(' --')
                lines.append(' }')

            # Add composite state for foreach operator
            if isinstance(task, ForEachOperator):
                lines.append(f' state {task_id} {{')
                for sub_task in task.loop_body:
                    if sub_task.description:
                        lines.append(f' state "{sub_task.description}" as {sub_task.task_id}')
                    else:
                        lines.append(f' {sub_task.task_id}')
                lines.append(' }')

            # Add composite state for while operator
            if isinstance(task, WhileOperator):
                lines.append(f' state {task_id} {{')
                for sub_task in task.loop_body:
                    if sub_task.description:
                        lines.append(f' state "{sub_task.description}" as {sub_task.task_id}')
                    else:
                        lines.append(f' {sub_task.task_id}')
                lines.append(' }')

            # End states
            if task_id not in all_dependencies:
                if not (isinstance(task, ConditionOperator) and (task.if_true or task.if_false)):
                    lines.append(f' {task_id} --> [*]')

        return "\n".join(lines)
365
+
164
366
  @classmethod
165
367
  def from_yaml(cls, yaml_str: str) -> "Workflow":
166
368
  data = yaml.safe_load(yaml_str)
@@ -172,66 +374,208 @@ class Workflow(BaseModel):
172
374
 
173
375
 
174
376
  class WorkflowBuilder:
175
- def __init__(self, name: str, existing_workflow: Optional[Workflow] = None):
377
    def __init__(
        self,
        name: str,
        existing_workflow: Workflow | None = None,
        parent: Optional["WorkflowBuilder"] = None,
    ) -> None:
        """Create a builder, optionally wrapping an existing workflow.

        *parent* is set for the nested builders created by condition/
        parallel/loop helpers; *name* is ignored when *existing_workflow*
        is given.
        """
        if existing_workflow:
            self.workflow = existing_workflow
        else:
            # Every Workflow field is initialized explicitly.
            self.workflow = Workflow(
                name=name,
                version="1.1.0",
                description="",
                tasks={},
                variables={},
                start_task=None,
                schedule=None,
                start_date=None,
                catchup=False,
                is_paused=False,
                tags=[],
                max_active_runs=1,
                default_retry_policy=None,
            )
        # Cursor: the most recently added task, used for implicit chaining.
        self._current_task: str | None = None
        self.parent = parent
403
+
404
+ def _add_task(
405
+ self,
406
+ task: TaskOperator | ConditionOperator | WaitOperator | ParallelOperator | ForEachOperator | WhileOperator | EmitEventOperator | WaitForEventOperator | SwitchOperator,
407
+ **kwargs: Any,
408
+ ) -> None:
409
+ dependencies = kwargs.get("dependencies", [])
410
+ if self._current_task and not dependencies:
411
+ dependencies.append(self._current_task)
412
+
413
+ task.dependencies = sorted(set(dependencies))
181
414
 
182
- def task(self, task_id: str, function: str, **kwargs) -> "WorkflowBuilder":
183
- task = TaskOperator(task_id=task_id, function=function, **kwargs)
184
- if self._current_task:
185
- task.dependencies.append(self._current_task)
186
415
  self.workflow.add_task(task)
187
- self._current_task = task_id
416
+ self._current_task = task.task_id
417
+
418
+ def task(self, task_id: str, function: str, **kwargs: Any) -> "WorkflowBuilder":
419
+ task = TaskOperator(task_id=task_id, function=function, **kwargs)
420
+ self._add_task(task, **kwargs)
188
421
  return self
189
422
 
190
423
    def condition(
        self,
        task_id: str,
        condition: str,
        if_true: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
        if_false: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Add a two-way branch whose arms are built by callbacks.

        Each callback receives a fresh nested builder; the first task each
        callback creates becomes that branch's target on the
        ConditionOperator (None for an empty branch). All branch tasks are
        merged into this workflow with the condition task added as a
        dependency.
        """
        true_builder = if_true(WorkflowBuilder(f"{task_id}_true", parent=self))
        false_builder = if_false(WorkflowBuilder(f"{task_id}_false", parent=self))

        true_tasks = list(true_builder.workflow.tasks.keys())
        false_tasks = list(false_builder.workflow.tasks.keys())

        task = ConditionOperator(
            task_id=task_id,
            condition=condition,
            if_true=true_tasks[0] if true_tasks else None,
            if_false=false_tasks[0] if false_tasks else None,
            **kwargs,
        )

        self._add_task(task, **kwargs)

        for task_obj in true_builder.workflow.tasks.values():
            # Only add the condition task as dependency, preserve original dependencies
            if task_id not in task_obj.dependencies:
                task_obj.dependencies.append(task_id)
            self.workflow.add_task(task_obj)
        for task_obj in false_builder.workflow.tasks.values():
            # Only add the condition task as dependency, preserve original dependencies
            if task_id not in task_obj.dependencies:
                task_obj.dependencies.append(task_id)
            self.workflow.add_task(task_obj)

        self._current_task = task_id
        return self
205
460
 
206
461
  def wait(
207
- self, task_id: str, wait_for: Union[timedelta, datetime, str], **kwargs
462
+ self, task_id: str, wait_for: timedelta | datetime | str, **kwargs: Any,
208
463
  ) -> "WorkflowBuilder":
209
464
  task = WaitOperator(task_id=task_id, wait_for=wait_for, **kwargs)
210
- if self._current_task:
211
- task.dependencies.append(self._current_task)
212
- self.workflow.add_task(task)
213
- self._current_task = task_id
465
+ self._add_task(task, **kwargs)
214
466
  return self
215
467
 
216
468
    def parallel(
        self,
        task_id: str,
        branches: dict[str, Callable[["WorkflowBuilder"], "WorkflowBuilder"]],
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Add a ParallelOperator whose branches are built by callbacks.

        Each callback receives a nested builder named "{task_id}_{name}".
        The operator records only the per-branch task-id lists; the branch
        tasks themselves are merged into this workflow with the parallel
        task added as a dependency (loop-internal tasks excluded).
        """
        branch_builders = {}
        for name, branch_func in branches.items():
            branch_builder = branch_func(
                WorkflowBuilder(f"{task_id}_{name}", parent=self),
            )
            branch_builders[name] = branch_builder

        branch_tasks = {
            name: list(builder.workflow.tasks.keys())
            for name, builder in branch_builders.items()
        }

        task = ParallelOperator(task_id=task_id, branches=branch_tasks, **kwargs)

        self._add_task(task, **kwargs)

        for builder in branch_builders.values():
            for task_obj in builder.workflow.tasks.values():
                # Only add the parallel task as dependency to non-internal tasks,
                # preserve original dependencies
                if (
                    not getattr(task_obj, "is_internal_loop_task", False)
                    and task_id not in task_obj.dependencies
                ):
                    task_obj.dependencies.append(task_id)
                self.workflow.add_task(task_obj)

        self._current_task = task_id
        return self
225
503
 
226
504
    def foreach(
        self,
        task_id: str,
        items: str,
        loop_body: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Add a ForEachOperator over *items*; the body is built by *loop_body*."""
        # Create a temporary builder for the loop body.
        temp_builder = WorkflowBuilder(f"{task_id}_loop", parent=self)
        loop_builder = loop_body(temp_builder)
        loop_tasks = list(loop_builder.workflow.tasks.values())

        # Mark all loop body tasks as internal to prevent parallel dependency injection
        for task_obj in loop_tasks:
            task_obj.is_internal_loop_task = True

        # Create the foreach operator
        task = ForEachOperator(
            task_id=task_id,
            items=items,
            loop_body=loop_tasks,
            **kwargs,
        )

        # Add the foreach task to workflow to establish initial dependencies
        self._add_task(task, **kwargs)

        # Add the foreach task as dependency to the FIRST task in the loop body
        # and preserve the original dependency chain within the loop
        if loop_tasks:
            first_task = loop_tasks[0]
            if task_id not in first_task.dependencies:
                first_task.dependencies.append(task_id)

        # Add all loop tasks to workflow
        for task_obj in loop_tasks:
            self.workflow.add_task(task_obj)

        self._current_task = task_id
        return self
544
+
545
    def while_loop(
        self,
        task_id: str,
        condition: str,
        loop_body: Callable[["WorkflowBuilder"], "WorkflowBuilder"],
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Add a WhileOperator repeating the callback-built body while *condition* holds."""
        loop_builder = loop_body(WorkflowBuilder(f"{task_id}_loop", parent=self))
        loop_tasks = list(loop_builder.workflow.tasks.values())

        # Mark all loop body tasks as internal to prevent parallel dependency injection
        for task_obj in loop_tasks:
            task_obj.is_internal_loop_task = True

        task = WhileOperator(
            task_id=task_id,
            condition=condition,
            loop_body=loop_tasks,
            **kwargs,
        )

        self._add_task(task, **kwargs)

        # Fix: Only add the while task as dependency to the FIRST task in the loop body
        # and preserve the original dependency chain within the loop
        if loop_tasks:
            first_task = loop_tasks[0]
            if task_id not in first_task.dependencies:
                first_task.dependencies.append(task_id)

        # Add all loop tasks to workflow without modifying their dependencies further
        for task_obj in loop_tasks:
            self.workflow.add_task(task_obj)

        self._current_task = task_id
        return self
237
581
 
@@ -242,24 +586,106 @@ class WorkflowBuilder:
242
586
  backoff_factor: float = 2.0,
243
587
  ) -> "WorkflowBuilder":
244
588
  if self._current_task and isinstance(
245
- self.workflow.tasks[self._current_task], TaskOperator
589
+ self.workflow.tasks[self._current_task], TaskOperator,
246
590
  ):
247
591
  self.workflow.tasks[self._current_task].retry_policy = RetryPolicy(
248
- max_retries=max_retries, delay=delay, backoff_factor=backoff_factor
592
+ max_retries=max_retries, delay=delay, backoff_factor=backoff_factor,
249
593
  )
250
594
  return self
251
595
 
252
596
  def timeout(
253
- self, timeout: timedelta, kill_on_timeout: bool = True
597
+ self, timeout: timedelta, kill_on_timeout: bool = True,
254
598
  ) -> "WorkflowBuilder":
255
599
  if self._current_task and isinstance(
256
- self.workflow.tasks[self._current_task], TaskOperator
600
+ self.workflow.tasks[self._current_task], TaskOperator,
257
601
  ):
258
602
  self.workflow.tasks[self._current_task].timeout_policy = TimeoutPolicy(
259
- timeout=timeout, kill_on_timeout=kill_on_timeout
603
+ timeout=timeout, kill_on_timeout=kill_on_timeout,
260
604
  )
261
605
  return self
262
606
 
607
+ # Phase 2: Event-based operators
608
+ def emit_event(self, task_id: str, event_name: str, **kwargs: Any) -> "WorkflowBuilder":
609
+ """Emit an event that other workflows can wait for."""
610
+ task = EmitEventOperator(task_id=task_id, event_name=event_name, **kwargs)
611
+ self._add_task(task, **kwargs)
612
+ return self
613
+
614
+ def wait_for_event(
615
+ self, task_id: str, event_name: str, timeout_seconds: int | None = None, **kwargs: Any,
616
+ ) -> "WorkflowBuilder":
617
+ """Wait for an external event with optional timeout."""
618
+ task = WaitForEventOperator(
619
+ task_id=task_id, event_name=event_name, timeout_seconds=timeout_seconds, **kwargs,
620
+ )
621
+ self._add_task(task, **kwargs)
622
+ return self
623
+
624
+ # Phase 3: Callback hooks (applies to current task)
625
+ def on_success(self, success_task_id: str) -> "WorkflowBuilder":
626
+ """Set the task to run when the current task succeeds."""
627
+ if self._current_task:
628
+ self.workflow.tasks[self._current_task].on_success_task_id = success_task_id
629
+ return self
630
+
631
+ def on_failure(self, failure_task_id: str) -> "WorkflowBuilder":
632
+ """Set the task to run when the current task fails."""
633
+ if self._current_task:
634
+ self.workflow.tasks[self._current_task].on_failure_task_id = failure_task_id
635
+ return self
636
+
637
+ # Phase 4: Switch operator
638
    def switch(
        self,
        task_id: str,
        switch_on: str,
        cases: dict[str, str],
        default: str | None = None,
        **kwargs: Any,
    ) -> "WorkflowBuilder":
        """Multi-branch switch/case operator.

        *cases* maps case values to target task ids; *default* is the task
        id used when no case matches (may be None).
        """
        task = SwitchOperator(
            task_id=task_id, switch_on=switch_on, cases=cases, default=default, **kwargs,
        )
        self._add_task(task, **kwargs)
        return self
652
+
653
+ # Phase 1: Scheduling methods (delegate to Workflow)
654
    # Each setter delegates to the wrapped Workflow and returns the builder
    # so scheduling calls can be chained fluently.
    def set_schedule(self, cron: str) -> "WorkflowBuilder":
        """Set the cron schedule for this workflow."""
        self.workflow.set_schedule(cron)
        return self

    def set_start_date(self, start_date: datetime) -> "WorkflowBuilder":
        """Set when the schedule becomes active."""
        self.workflow.set_start_date(start_date)
        return self

    def set_catchup(self, enabled: bool) -> "WorkflowBuilder":
        """Set whether to backfill missed runs."""
        self.workflow.set_catchup(enabled)
        return self

    def set_paused(self, paused: bool) -> "WorkflowBuilder":
        """Set whether the workflow is paused."""
        self.workflow.set_paused(paused)
        return self

    def add_tags(self, *tags: str) -> "WorkflowBuilder":
        """Add tags to the workflow."""
        self.workflow.add_tags(*tags)
        return self

    def set_max_active_runs(self, count: int) -> "WorkflowBuilder":
        """Set maximum number of concurrent runs."""
        self.workflow.set_max_active_runs(count)
        return self

    def set_default_retry_policy(self, policy: RetryPolicy) -> "WorkflowBuilder":
        """Set default retry policy for all tasks."""
        self.workflow.set_default_retry_policy(policy)
        return self
688
+
263
689
  def build(self) -> Workflow:
264
690
  if not self.workflow.start_task and self.workflow.tasks:
265
691
  self.workflow.start_task = next(iter(self.workflow.tasks.keys()))