zenml-nightly 0.84.1.dev20250805__py3-none-any.whl → 0.84.1.dev20250806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zenml/VERSION CHANGED
@@ -1 +1 @@
- 0.84.1.dev20250805
+ 0.84.1.dev20250806
@@ -0,0 +1,27 @@
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ #
+ # Parts of the `prepare_or_run_pipeline()` method of this file are
+ # inspired by the Kubernetes dag runner implementation of tfx
+ """Kubernetes orchestrator constants."""
+
+ ENV_ZENML_KUBERNETES_RUN_ID = "ZENML_KUBERNETES_RUN_ID"
+ KUBERNETES_SECRET_TOKEN_KEY_NAME = "zenml_api_token"
+ KUBERNETES_CRON_JOB_METADATA_KEY = "cron_job_name"
+ # Annotation keys
+ ORCHESTRATOR_ANNOTATION_KEY = "zenml.io/orchestrator"
+ RUN_ID_ANNOTATION_KEY = "zenml.io/run-id"
+ ORCHESTRATOR_RUN_ID_ANNOTATION_KEY = "zenml.io/orchestrator-run-id"
+ STEP_NAME_ANNOTATION_KEY = "zenml.io/step-name"
+ STEP_OPERATOR_ANNOTATION_KEY = "zenml.io/step-operator"
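The new annotation keys let the pods and jobs that ZenML creates be traced back to a pipeline run. As a rough sketch of how such keys end up on cluster resources (the job name and annotation values below are hypothetical, and this is not necessarily how the orchestrator assembles its metadata), using the official `kubernetes` Python client:

```python
# Illustrative only: attaching the ZenML annotation keys to job metadata with
# the official `kubernetes` client. All names and values here are made up.
from kubernetes import client

annotations = {
    "zenml.io/orchestrator": "kubernetes",  # ORCHESTRATOR_ANNOTATION_KEY
    "zenml.io/run-id": "run-1234",          # RUN_ID_ANNOTATION_KEY (hypothetical)
    "zenml.io/step-name": "trainer",        # STEP_NAME_ANNOTATION_KEY
}

metadata = client.V1ObjectMeta(
    name="zenml-trainer-job",  # hypothetical job name
    annotations=annotations,
)
```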
@@ -15,7 +15,13 @@
 
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
- from pydantic import Field, NonNegativeInt, PositiveInt, field_validator
+ from pydantic import (
+     Field,
+     NonNegativeInt,
+     PositiveFloat,
+     PositiveInt,
+     field_validator,
+ )
 
  from zenml.config.base_settings import BaseSettings
  from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE
@@ -23,6 +29,7 @@ from zenml.integrations.kubernetes import KUBERNETES_ORCHESTRATOR_FLAVOR
  from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
  from zenml.models import ServiceConnectorRequirements
  from zenml.orchestrators import BaseOrchestratorConfig, BaseOrchestratorFlavor
+ from zenml.utils import deprecation_utils
 
  if TYPE_CHECKING:
      from zenml.integrations.kubernetes.orchestrators import (
@@ -42,16 +49,6 @@ class KubernetesOrchestratorSettings(BaseSettings):
          description="Whether to wait for all pipeline steps to complete. "
          "When `False`, the client returns immediately and execution continues asynchronously.",
      )
-     timeout: int = Field(
-         default=0,
-         description="Maximum seconds to wait for synchronous runs. Set to `0` for unlimited duration.",
-     )
-     stream_step_logs: bool = Field(
-         default=True,
-         description="If `True`, the orchestrator pod will stream the logs "
-         "of the step pods. This only has an effect if specified on the "
-         "pipeline, not on individual steps.",
-     )
      service_account_name: Optional[str] = Field(
          default=None,
          description="Kubernetes service account for the orchestrator pod. "
@@ -74,25 +71,9 @@
          default=None,
          description="Pod configuration for the orchestrator container that launches step pods.",
      )
-     pod_name_prefix: Optional[str] = Field(
+     job_name_prefix: Optional[str] = Field(
          default=None,
-         description="Custom prefix for generated pod names. Helps identify pods in the cluster.",
-     )
-     pod_startup_timeout: int = Field(
-         default=600,
-         description="Maximum seconds to wait for step pods to start. Default is 10 minutes.",
-     )
-     pod_failure_max_retries: int = Field(
-         default=3,
-         description="Maximum retry attempts when step pods fail to start.",
-     )
-     pod_failure_retry_delay: int = Field(
-         default=10,
-         description="Delay in seconds between pod failure retry attempts.",
-     )
-     pod_failure_backoff: float = Field(
-         default=1.0,
-         description="Exponential backoff factor for retry delays. Values > 1.0 increase delay with each retry.",
+         description="Prefix for the job name.",
      )
      max_parallelism: Optional[PositiveInt] = Field(
          default=None,
@@ -100,19 +81,20 @@
      )
      successful_jobs_history_limit: Optional[NonNegativeInt] = Field(
          default=None,
-         description="Number of successful scheduled jobs to retain in cluster history.",
+         description="Number of successful scheduled jobs to retain in history.",
      )
      failed_jobs_history_limit: Optional[NonNegativeInt] = Field(
          default=None,
-         description="Number of failed scheduled jobs to retain in cluster history.",
+         description="Number of failed scheduled jobs to retain in history.",
      )
      ttl_seconds_after_finished: Optional[NonNegativeInt] = Field(
          default=None,
-         description="Seconds to keep finished scheduled jobs before automatic cleanup.",
+         description="Seconds to keep finished jobs before automatic cleanup.",
      )
      active_deadline_seconds: Optional[NonNegativeInt] = Field(
          default=None,
-         description="Deadline in seconds for the active pod. If the pod is inactive for this many seconds, it will be terminated.",
+         description="Job deadline in seconds. If the job doesn't finish "
+         "within this time, it will be terminated.",
      )
      backoff_limit_margin: NonNegativeInt = Field(
          default=0,
@@ -131,6 +113,10 @@ class KubernetesOrchestratorSettings(BaseSettings):
          "the chance of the server receiving the maximum amount of retry "
          "requests.",
      )
+     orchestrator_job_backoff_limit: NonNegativeInt = Field(
+         default=3,
+         description="The backoff limit for the orchestrator job.",
+     )
      fail_on_container_waiting_reasons: Optional[List[str]] = Field(
          default=[
              "InvalidImageName",
@@ -144,11 +130,21 @@ class KubernetesOrchestratorSettings(BaseSettings):
          "`pod.status.containerStatuses[*].state.waiting.reason` of a job pod, "
          "should cause the job to fail immediately.",
      )
-     job_monitoring_interval: int = Field(
+     job_monitoring_interval: PositiveFloat = Field(
          default=3,
          description="The interval in seconds to monitor the job. Each interval "
-         "is used to check for container issues and streaming logs for the "
-         "job pods.",
+         "is used to check for container issues for the job pods.",
+     )
+     job_monitoring_delay: PositiveFloat = Field(
+         default=0.0,
+         description="The delay in seconds to wait between monitoring active "
+         "step jobs. This can be used to reduce load on the Kubernetes API "
+         "server.",
+     )
+     interrupt_check_interval: PositiveFloat = Field(
+         default=1.0,
+         description="The interval in seconds to check for run interruptions.",
+         ge=0.5,
      )
      pod_failure_policy: Optional[Dict[str, Any]] = Field(
          default=None,
@@ -169,6 +165,53 @@ class KubernetesOrchestratorSettings(BaseSettings):
          description="When stopping a pipeline run, the amount of seconds to wait for a step pod to shutdown gracefully.",
      )
 
+     # Deprecated fields
+     timeout: Optional[int] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     stream_step_logs: Optional[bool] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     pod_startup_timeout: Optional[int] = Field(
+         default=None,
+         description="DEPRECATED/UNUSED.",
+         deprecated=True,
+     )
+     pod_failure_max_retries: Optional[int] = Field(
+         default=None,
+         description="DEPRECATED/UNUSED.",
+         deprecated=True,
+     )
+     pod_failure_retry_delay: Optional[int] = Field(
+         default=None,
+         description="DEPRECATED/UNUSED.",
+         deprecated=True,
+     )
+     pod_failure_backoff: Optional[float] = Field(
+         default=None,
+         description="DEPRECATED/UNUSED.",
+         deprecated=True,
+     )
+     pod_name_prefix: Optional[str] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+
+     _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
+         "timeout",
+         "stream_step_logs",
+         "pod_startup_timeout",
+         "pod_failure_max_retries",
+         "pod_failure_retry_delay",
+         "pod_failure_backoff",
+         ("pod_name_prefix", "job_name_prefix"),
+     )
+
      @field_validator("pod_failure_policy", mode="before")
      @classmethod
      def _convert_pod_failure_policy(cls, value: Any) -> Any:
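Together with the renamed `job_name_prefix` field above, the deprecation validator keeps old configurations working. A minimal sketch of the expected behavior, assuming the class is importable from `zenml.integrations.kubernetes.flavors` and that `deprecate_pydantic_attributes` migrates each `(old, new)` pair while warning about the rest:

```python
# Sketch of the expected deprecation behavior; the value migration and the
# warning are assumptions based on the ("pod_name_prefix", "job_name_prefix")
# tuple registered above.
from zenml.integrations.kubernetes.flavors import (
    KubernetesOrchestratorSettings,
)

settings = KubernetesOrchestratorSettings(
    job_monitoring_interval=5.0,  # new PositiveFloat field
    pod_name_prefix="training",   # deprecated key, should emit a warning
)
print(settings.job_name_prefix)  # expected: "training" (migrated value)
```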
@@ -13,9 +13,9 @@
  # permissions and limitations under the License.
  """Kubernetes step operator flavor."""
 
- from typing import TYPE_CHECKING, Optional, Type
+ from typing import TYPE_CHECKING, List, Optional, Type
 
- from pydantic import Field
+ from pydantic import Field, NonNegativeInt
 
  from zenml.config.base_settings import BaseSettings
  from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE
@@ -23,6 +23,7 @@ from zenml.integrations.kubernetes import KUBERNETES_STEP_OPERATOR_FLAVOR
  from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
  from zenml.models import ServiceConnectorRequirements
  from zenml.step_operators import BaseStepOperatorConfig, BaseStepOperatorFlavor
+ from zenml.utils import deprecation_utils
 
  if TYPE_CHECKING:
      from zenml.integrations.kubernetes.step_operators import (
@@ -31,11 +32,7 @@ if TYPE_CHECKING:
 
 
  class KubernetesStepOperatorSettings(BaseSettings):
-     """Settings for the Kubernetes step operator.
-
-     Configuration options for individual step execution on Kubernetes.
-     Field descriptions are defined inline using Field() descriptors.
-     """
+     """Settings for the Kubernetes step operator."""
 
      pod_settings: Optional[KubernetesPodSettings] = Field(
          default=None,
@@ -49,32 +46,66 @@ class KubernetesStepOperatorSettings(BaseSettings):
          default=False,
          description="Whether to run step containers in privileged mode with extended permissions.",
      )
-     pod_startup_timeout: int = Field(
-         default=600,
-         description="Maximum seconds to wait for step pods to start. Default is 10 minutes.",
+     job_name_prefix: Optional[str] = Field(
+         default=None,
+         description="Prefix for the job name.",
      )
-     pod_failure_max_retries: int = Field(
-         default=3,
-         description="Maximum retry attempts when step pods fail to start.",
+     ttl_seconds_after_finished: Optional[NonNegativeInt] = Field(
+         default=None,
+         description="Seconds to keep finished jobs before automatic cleanup.",
      )
-     pod_failure_retry_delay: int = Field(
-         default=10,
-         description="Delay in seconds between pod failure retry attempts.",
+     active_deadline_seconds: Optional[NonNegativeInt] = Field(
+         default=None,
+         description="Job deadline in seconds. If the job doesn't finish "
+         "within this time, it will be terminated.",
      )
-     pod_failure_backoff: float = Field(
-         default=1.0,
-         description="Exponential backoff factor for retry delays. Values > 1.0 increase delay with each retry.",
+     fail_on_container_waiting_reasons: Optional[List[str]] = Field(
+         default=[
+             "InvalidImageName",
+             "ErrImagePull",
+             "ImagePullBackOff",
+             "CreateContainerConfigError",
+         ],
+         description="List of container waiting reasons that should cause the "
+         "job to fail immediately. This should be set to a list of "
+         "nonrecoverable reasons, which if found in any "
+         "`pod.status.containerStatuses[*].state.waiting.reason` of a job pod, "
+         "should cause the job to fail immediately.",
+     )
+
+     # Deprecated fields
+     pod_startup_timeout: Optional[int] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     pod_failure_max_retries: Optional[int] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     pod_failure_retry_delay: Optional[int] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     pod_failure_backoff: Optional[float] = Field(
+         default=None,
+         deprecated=True,
+         description="DEPRECATED/UNUSED.",
+     )
+     _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes(
+         "pod_startup_timeout",
+         "pod_failure_max_retries",
+         "pod_failure_retry_delay",
+         "pod_failure_backoff",
      )
 
 
  class KubernetesStepOperatorConfig(
      BaseStepOperatorConfig, KubernetesStepOperatorSettings
  ):
-     """Configuration for the Kubernetes step operator.
-
-     Defines cluster connection and execution settings.
-     Field descriptions are defined inline using Field() descriptors.
-     """
+     """Configuration for the Kubernetes step operator."""
 
      kubernetes_namespace: str = Field(
          default="zenml",
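The step operator gains the same job-level controls as the orchestrator. A minimal sketch of setting them on a step, assuming the documented `"<component_type>.<flavor>"` settings key and a step operator registered in the stack under a hypothetical name:

```python
# Illustrative only: the component name "kubernetes" and the settings key
# format are assumptions, not taken from this diff.
from zenml import step
from zenml.integrations.kubernetes.flavors import (
    KubernetesStepOperatorSettings,
)

@step(
    step_operator="kubernetes",  # hypothetical component name
    settings={
        "step_operator.kubernetes": KubernetesStepOperatorSettings(
            job_name_prefix="training",
            ttl_seconds_after_finished=3600,  # clean up 1h after the job ends
            active_deadline_seconds=1800,     # terminate jobs running > 30 min
        )
    },
)
def train() -> None:
    ...
```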
@@ -0,0 +1,367 @@
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ """DAG runner."""
+
+ import queue
+ import threading
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, Callable, Dict, List, Optional
+
+ from pydantic import BaseModel
+
+ from zenml.logger import get_logger
+ from zenml.utils.enum_utils import StrEnum
+
+ logger = get_logger(__name__)
+
+
+ class NodeStatus(StrEnum):
+     """Status of a DAG node."""
+
+     NOT_READY = "not_ready"  # Can not be started yet
+     READY = "ready"  # Can be started but is still waiting in the queue
+     STARTING = "starting"  # Is being started, but not yet running
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+     SKIPPED = "skipped"
+     CANCELLED = "cancelled"
+
+
+ class InterruptMode(StrEnum):
+     """Interrupt mode."""
+
+     GRACEFUL = "graceful"
+     FORCE = "force"
+
+
+ class Node(BaseModel):
+     """DAG node."""
+
+     id: str
+     status: NodeStatus = NodeStatus.NOT_READY
+     upstream_nodes: List[str] = []
+     metadata: Dict[str, Any] = {}
+
+     @property
+     def is_finished(self) -> bool:
+         """Whether the node is finished.
+
+         Returns:
+             Whether the node is finished.
+         """
+         return self.status in {
+             NodeStatus.COMPLETED,
+             NodeStatus.FAILED,
+             NodeStatus.SKIPPED,
+             NodeStatus.CANCELLED,
+         }
+
+
+ class DagRunner:
+     """DAG runner.
+
+     This class orchestrates the execution of the nodes of a DAG. It runs
+     two loops in separate threads:
+     The main thread
+     - checks if any nodes should be skipped or are ready to run, in which
+       case the node is added to the startup queue
+     - submits startup tasks for queued nodes to a thread pool as long as
+       the maximum parallelism is not reached
+     - periodically checks if the DAG should be interrupted
+     The monitoring thread
+     - monitors the running nodes and updates their status
+     """
+
+     def __init__(
+         self,
+         nodes: List[Node],
+         node_startup_function: Callable[[Node], NodeStatus],
+         node_monitoring_function: Callable[[Node], NodeStatus],
+         node_stop_function: Optional[Callable[[Node], None]] = None,
+         interrupt_function: Optional[
+             Callable[[], Optional[InterruptMode]]
+         ] = None,
+         monitoring_interval: float = 1.0,
+         monitoring_delay: float = 0.0,
+         interrupt_check_interval: float = 1.0,
+         max_parallelism: Optional[int] = None,
+     ) -> None:
+         """Initialize the DAG runner.
+
+         Args:
+             nodes: The nodes of the DAG.
+             node_startup_function: The function to start a node.
+             node_monitoring_function: The function to monitor a node.
+             node_stop_function: The function to stop a node.
+             interrupt_function: Will be periodically called to check if the
+                 DAG should be interrupted.
+             monitoring_interval: The interval in which the nodes are monitored.
+             monitoring_delay: The delay in seconds to wait between monitoring
+                 different nodes.
+             interrupt_check_interval: The interval in which the interrupt
+                 function is called.
+             max_parallelism: The maximum number of nodes to run in parallel.
+         """
+         self.nodes = {node.id: node for node in nodes}
+         self.startup_queue: queue.Queue[Node] = queue.Queue()
+         self.node_startup_function = node_startup_function
+         self.node_monitoring_function = node_monitoring_function
+         self.node_stop_function = node_stop_function
+         self.interrupt_function = interrupt_function
+         self.monitoring_thread = threading.Thread(
+             name="DagRunner-Monitoring-Loop",
+             target=self._monitoring_loop,
+             daemon=True,
+         )
+         self.monitoring_interval = monitoring_interval
+         self.monitoring_delay = monitoring_delay
+         self.interrupt_check_interval = interrupt_check_interval
+         self.max_parallelism = max_parallelism
+         self.shutdown_event = threading.Event()
+         self.startup_executor = ThreadPoolExecutor(
+             max_workers=10, thread_name_prefix="DagRunner-Startup"
+         )
+
+     @property
+     def running_nodes(self) -> List[Node]:
+         """Running nodes.
+
+         Returns:
+             Running nodes.
+         """
+         return [
+             node
+             for node in self.nodes.values()
+             if node.status == NodeStatus.RUNNING
+         ]
+
+     @property
+     def active_nodes(self) -> List[Node]:
+         """Active nodes.
+
+         Active nodes are nodes that are either running or starting.
+
+         Returns:
+             Active nodes.
+         """
+         return [
+             node
+             for node in self.nodes.values()
+             if node.status in {NodeStatus.RUNNING, NodeStatus.STARTING}
+         ]
+
+     def _initialize_startup_queue(self) -> None:
+         """Initialize the startup queue.
+
+         The startup queue contains all nodes that are ready to be started.
+         """
+         for node in self.nodes.values():
+             if node.status in {NodeStatus.READY, NodeStatus.STARTING}:
+                 self.startup_queue.put(node)
+
+     def _can_start_node(self, node: Node) -> bool:
+         """Check if a node can be started.
+
+         Args:
+             node: The node to check.
+
+         Returns:
+             Whether the node can be started.
+         """
+         return all(
+             self.nodes[upstream_node_id].status == NodeStatus.COMPLETED
+             for upstream_node_id in node.upstream_nodes
+         )
+
+     def _should_skip_node(self, node: Node) -> bool:
+         """Check if a node should be skipped.
+
+         Args:
+             node: The node to check.
+
+         Returns:
+             Whether the node should be skipped.
+         """
+         return any(
+             self.nodes[upstream_node_id].status
+             in {NodeStatus.FAILED, NodeStatus.SKIPPED, NodeStatus.CANCELLED}
+             for upstream_node_id in node.upstream_nodes
+         )
+
+     def _start_node(self, node: Node) -> None:
+         """Start a node.
+
+         This will start a thread that runs the startup function.
+
+         Args:
+             node: The node to start.
+         """
+         node.status = NodeStatus.STARTING
+
+         def _start_node_task() -> None:
+             if self.shutdown_event.is_set():
+                 logger.debug(
+                     "Cancelling startup of node `%s` because shutdown was "
+                     "requested.",
+                     node.id,
+                 )
+                 return
+
+             try:
+                 node.status = self.node_startup_function(node)
+             except Exception:
+                 node.status = NodeStatus.FAILED
+                 logger.exception("Node `%s` failed to start.", node.id)
+             else:
+                 logger.info(
+                     "Node `%s` started (status: %s)", node.id, node.status
+                 )
+
+         self.startup_executor.submit(_start_node_task)
+
+     def _stop_node(self, node: Node) -> None:
+         """Stop a node.
+
+         Args:
+             node: The node to stop.
+
+         Raises:
+             RuntimeError: If the node stop function is not set.
+         """
+         if not self.node_stop_function:
+             raise RuntimeError("Node stop function is not set.")
+
+         self.node_stop_function(node)
+
+     def _stop_all_nodes(self) -> None:
+         """Stop all running nodes."""
+         for node in self.running_nodes:
+             self._stop_node(node)
+             node.status = NodeStatus.CANCELLED
+
+     def _process_nodes(self) -> bool:
+         """Process the nodes.
+
+         This method will check if any nodes should be skipped or are ready to
+         run, in which case the node will be added to the startup queue.
+
+         Returns:
+             Whether the DAG is finished.
+         """
+         finished = True
+
+         for node in self.nodes.values():
+             if node.status == NodeStatus.NOT_READY:
+                 if self._should_skip_node(node):
+                     node.status = NodeStatus.SKIPPED
+                     logger.warning(
+                         "Skipping node `%s` because upstream node failed.",
+                         node.id,
+                     )
+                 elif self._can_start_node(node):
+                     node.status = NodeStatus.READY
+                     self.startup_queue.put(node)
+
+             if not node.is_finished:
+                 finished = False
+
+         # Start nodes until we reach the maximum configured parallelism
+         max_parallelism = self.max_parallelism or len(self.nodes)
+         while len(self.active_nodes) < max_parallelism:
+             try:
+                 node = self.startup_queue.get_nowait()
+             except queue.Empty:
+                 break
+             else:
+                 self.startup_queue.task_done()
+                 self._start_node(node)
+
+         return finished
+
+     def _monitoring_loop(self) -> None:
+         """Monitoring loop.
+
+         This should run in a separate thread and monitors the running nodes.
+         """
+         while not self.shutdown_event.is_set():
+             start_time = time.time()
+             for node in self.running_nodes:
+                 try:
+                     node.status = self.node_monitoring_function(node)
+                 except Exception:
+                     node.status = NodeStatus.FAILED
+                     logger.exception("Node `%s` failed.", node.id)
+                 else:
+                     logger.debug(
+                         "Node `%s` status updated to `%s`",
+                         node.id,
+                         node.status,
+                     )
+                     if node.status == NodeStatus.FAILED:
+                         logger.error("Node `%s` failed.", node.id)
+                     elif node.status == NodeStatus.COMPLETED:
+                         logger.info("Node `%s` completed.", node.id)
+
+                 time.sleep(self.monitoring_delay)
+
+             duration = time.time() - start_time
+             time_to_sleep = max(0, self.monitoring_interval - duration)
+             self.shutdown_event.wait(timeout=time_to_sleep)
+
+     def run(self) -> Dict[str, NodeStatus]:
+         """Run the DAG.
+
+         Returns:
+             The final node states.
+         """
+         self._initialize_startup_queue()
+
+         self.monitoring_thread.start()
+
+         interrupt_mode = None
+         last_interrupt_check = time.time()
+
+         while True:
+             if self.interrupt_function is not None:
+                 if (
+                     time.time() - last_interrupt_check
+                     >= self.interrupt_check_interval
+                 ):
+                     if interrupt_mode := self.interrupt_function():
+                         logger.warning("DAG execution interrupted.")
+                         break
+                     last_interrupt_check = time.time()
+
+             is_finished = self._process_nodes()
+             if is_finished:
+                 break
+
+             time.sleep(0.5)
+
+         self.shutdown_event.set()
+         if interrupt_mode == InterruptMode.FORCE:
+             # If a force interrupt was requested, we stop all running nodes.
+             self._stop_all_nodes()
+
+         self.monitoring_thread.join()
+
+         node_statuses = {
+             node_id: node.status for node_id, node in self.nodes.items()
+         }
+         logger.debug("Finished with node statuses: %s", node_statuses)
+
+         return node_statuses
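To make the control flow of the runner concrete, here is a self-contained sketch that drives it with trivial callbacks. In the real orchestrator the callbacks create and watch Kubernetes jobs; here every node completes as soon as it starts, and the module path in the import is an assumption:

```python
# Toy example: three sequential nodes run through the DagRunner. Each node is
# reported as completed immediately by its startup callback.
from zenml.integrations.kubernetes.orchestrators.dag_runner import (  # assumed path
    DagRunner,
    Node,
    NodeStatus,
)

nodes = [
    Node(id="load_data"),
    Node(id="train", upstream_nodes=["load_data"]),
    Node(id="evaluate", upstream_nodes=["train"]),
]

def start(node: Node) -> NodeStatus:
    print(f"starting {node.id}")
    return NodeStatus.COMPLETED  # pretend the work finishes instantly

def monitor(node: Node) -> NodeStatus:
    return node.status  # nothing to poll in this toy setup

statuses = DagRunner(
    nodes=nodes,
    node_startup_function=start,
    node_monitoring_function=monitor,
    max_parallelism=2,
).run()
print(statuses)  # final status per node, e.g. all NodeStatus.COMPLETED
```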