zenml-nightly 0.84.1.dev20250804__py3-none-any.whl → 0.84.1.dev20250806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,367 @@
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+ #
+ # https://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ # or implied. See the License for the specific language governing
+ # permissions and limitations under the License.
+ """DAG runner."""
+
+ import queue
+ import threading
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Any, Callable, Dict, List, Optional
+
+ from pydantic import BaseModel
+
+ from zenml.logger import get_logger
+ from zenml.utils.enum_utils import StrEnum
+
+ logger = get_logger(__name__)
+
+
+ class NodeStatus(StrEnum):
+     """Status of a DAG node."""
+
+     NOT_READY = "not_ready"  # Cannot be started yet
+     READY = "ready"  # Can be started but is still waiting in the queue
+     STARTING = "starting"  # Is being started, but not yet running
+     RUNNING = "running"
+     COMPLETED = "completed"
+     FAILED = "failed"
+     SKIPPED = "skipped"
+     CANCELLED = "cancelled"
+
+
+ class InterruptMode(StrEnum):
+     """Interrupt mode."""
+
+     GRACEFUL = "graceful"
+     FORCE = "force"
+
+
+ class Node(BaseModel):
+     """DAG node."""
+
+     id: str
+     status: NodeStatus = NodeStatus.NOT_READY
+     upstream_nodes: List[str] = []
+     metadata: Dict[str, Any] = {}
+
+     @property
+     def is_finished(self) -> bool:
+         """Whether the node is finished.
+
+         Returns:
+             Whether the node is finished.
+         """
+         return self.status in {
+             NodeStatus.COMPLETED,
+             NodeStatus.FAILED,
+             NodeStatus.SKIPPED,
+             NodeStatus.CANCELLED,
+         }
+
+
+ class DagRunner:
+     """DAG runner.
+
+     This class orchestrates the execution of the nodes of a DAG. It runs
+     two loops in separate threads:
+     The main thread
+         - checks if any nodes should be skipped or are ready to run, in
+           which case the node is added to the startup queue
+         - submits startup tasks for queued nodes to a thread pool, as long
+           as there are nodes in the startup queue and the maximum
+           parallelism is not reached
+         - periodically checks if the DAG should be interrupted
+     The monitoring thread
+         - monitors the running nodes and updates their statuses
+     """
+
+     def __init__(
+         self,
+         nodes: List[Node],
+         node_startup_function: Callable[[Node], NodeStatus],
+         node_monitoring_function: Callable[[Node], NodeStatus],
+         node_stop_function: Optional[Callable[[Node], None]] = None,
+         interrupt_function: Optional[
+             Callable[[], Optional[InterruptMode]]
+         ] = None,
+         monitoring_interval: float = 1.0,
+         monitoring_delay: float = 0.0,
+         interrupt_check_interval: float = 1.0,
+         max_parallelism: Optional[int] = None,
+     ) -> None:
+         """Initialize the DAG runner.
+
+         Args:
+             nodes: The nodes of the DAG.
+             node_startup_function: The function to start a node.
+             node_monitoring_function: The function to monitor a node.
+             node_stop_function: The function to stop a node.
+             interrupt_function: Will be periodically called to check if the
+                 DAG should be interrupted.
+             monitoring_interval: The interval in seconds at which the nodes
+                 are monitored.
+             monitoring_delay: The delay in seconds to wait between monitoring
+                 different nodes.
+             interrupt_check_interval: The interval in seconds at which the
+                 interrupt function is called.
+             max_parallelism: The maximum number of nodes to run in parallel.
+         """
+         self.nodes = {node.id: node for node in nodes}
+         self.startup_queue: queue.Queue[Node] = queue.Queue()
+         self.node_startup_function = node_startup_function
+         self.node_monitoring_function = node_monitoring_function
+         self.node_stop_function = node_stop_function
+         self.interrupt_function = interrupt_function
+         self.monitoring_thread = threading.Thread(
+             name="DagRunner-Monitoring-Loop",
+             target=self._monitoring_loop,
+             daemon=True,
+         )
+         self.monitoring_interval = monitoring_interval
+         self.monitoring_delay = monitoring_delay
+         self.interrupt_check_interval = interrupt_check_interval
+         self.max_parallelism = max_parallelism
+         self.shutdown_event = threading.Event()
+         self.startup_executor = ThreadPoolExecutor(
+             max_workers=10, thread_name_prefix="DagRunner-Startup"
+         )
+
+     @property
+     def running_nodes(self) -> List[Node]:
+         """Running nodes.
+
+         Returns:
+             Running nodes.
+         """
+         return [
+             node
+             for node in self.nodes.values()
+             if node.status == NodeStatus.RUNNING
+         ]
+
+     @property
+     def active_nodes(self) -> List[Node]:
+         """Active nodes.
+
+         Active nodes are nodes that are either running or starting.
+
+         Returns:
+             Active nodes.
+         """
+         return [
+             node
+             for node in self.nodes.values()
+             if node.status in {NodeStatus.RUNNING, NodeStatus.STARTING}
+         ]
+
+     def _initialize_startup_queue(self) -> None:
+         """Initialize the startup queue.
+
+         The startup queue contains all nodes that are ready to be started.
+         """
+         for node in self.nodes.values():
+             if node.status in {NodeStatus.READY, NodeStatus.STARTING}:
+                 self.startup_queue.put(node)
+
+     def _can_start_node(self, node: Node) -> bool:
+         """Check if a node can be started.
+
+         Args:
+             node: The node to check.
+
+         Returns:
+             Whether the node can be started.
+         """
+         return all(
+             self.nodes[upstream_node_id].status == NodeStatus.COMPLETED
+             for upstream_node_id in node.upstream_nodes
+         )
+
+     def _should_skip_node(self, node: Node) -> bool:
+         """Check if a node should be skipped.
+
+         Args:
+             node: The node to check.
+
+         Returns:
+             Whether the node should be skipped.
+         """
+         return any(
+             self.nodes[upstream_node_id].status
+             in {NodeStatus.FAILED, NodeStatus.SKIPPED, NodeStatus.CANCELLED}
+             for upstream_node_id in node.upstream_nodes
+         )
+
+     def _start_node(self, node: Node) -> None:
+         """Start a node.
+
+         This submits a startup task for the node to the startup thread pool.
+
+         Args:
+             node: The node to start.
+         """
+         node.status = NodeStatus.STARTING
+
+         def _start_node_task() -> None:
+             if self.shutdown_event.is_set():
+                 logger.debug(
+                     "Cancelling startup of node `%s` because shutdown was "
+                     "requested.",
+                     node.id,
+                 )
+                 return
+
+             try:
+                 node.status = self.node_startup_function(node)
+             except Exception:
+                 node.status = NodeStatus.FAILED
+                 logger.exception("Node `%s` failed to start.", node.id)
+             else:
+                 logger.info(
+                     "Node `%s` started (status: %s)", node.id, node.status
+                 )
+
+         self.startup_executor.submit(_start_node_task)
+
+     def _stop_node(self, node: Node) -> None:
+         """Stop a node.
+
+         Args:
+             node: The node to stop.
+
+         Raises:
+             RuntimeError: If the node stop function is not set.
+         """
+         if not self.node_stop_function:
+             raise RuntimeError("Node stop function is not set.")
+
+         self.node_stop_function(node)
+
+     def _stop_all_nodes(self) -> None:
+         """Stop all running nodes."""
+         for node in self.running_nodes:
+             self._stop_node(node)
+             node.status = NodeStatus.CANCELLED
+
+     def _process_nodes(self) -> bool:
+         """Process the nodes.
+
+         This method marks nodes as skipped or ready; ready nodes are added
+         to the startup queue and started up to the maximum parallelism.
+
+         Returns:
+             Whether the DAG is finished.
+         """
+         finished = True
+
+         for node in self.nodes.values():
+             if node.status == NodeStatus.NOT_READY:
+                 if self._should_skip_node(node):
+                     node.status = NodeStatus.SKIPPED
+                     logger.warning(
+                         "Skipping node `%s` because an upstream node failed.",
+                         node.id,
+                     )
+                 elif self._can_start_node(node):
+                     node.status = NodeStatus.READY
+                     self.startup_queue.put(node)
+
+             if not node.is_finished:
+                 finished = False
+
+         # Start nodes until we reach the maximum configured parallelism
+         max_parallelism = self.max_parallelism or len(self.nodes)
+         while len(self.active_nodes) < max_parallelism:
+             try:
+                 node = self.startup_queue.get_nowait()
+             except queue.Empty:
+                 break
+             else:
+                 self.startup_queue.task_done()
+                 self._start_node(node)
+
+         return finished
+
+     def _monitoring_loop(self) -> None:
+         """Monitoring loop.
+
+         This runs in a separate thread and monitors the running nodes.
+         """
+         while not self.shutdown_event.is_set():
+             start_time = time.time()
+             for node in self.running_nodes:
+                 try:
+                     node.status = self.node_monitoring_function(node)
+                 except Exception:
+                     node.status = NodeStatus.FAILED
+                     logger.exception("Node `%s` failed.", node.id)
+                 else:
+                     logger.debug(
+                         "Node `%s` status updated to `%s`",
+                         node.id,
+                         node.status,
+                     )
+                     if node.status == NodeStatus.FAILED:
+                         logger.error("Node `%s` failed.", node.id)
+                     elif node.status == NodeStatus.COMPLETED:
+                         logger.info("Node `%s` completed.", node.id)
+
+                 time.sleep(self.monitoring_delay)
+
+             duration = time.time() - start_time
+             time_to_sleep = max(0, self.monitoring_interval - duration)
+             self.shutdown_event.wait(timeout=time_to_sleep)
+
+     def run(self) -> Dict[str, NodeStatus]:
+         """Run the DAG.
+
+         Returns:
+             The final node statuses.
+         """
+         self._initialize_startup_queue()
+
+         self.monitoring_thread.start()
+
+         interrupt_mode = None
+         last_interrupt_check = time.time()
+
+         while True:
+             if self.interrupt_function is not None:
+                 if (
+                     time.time() - last_interrupt_check
+                     >= self.interrupt_check_interval
+                 ):
+                     if interrupt_mode := self.interrupt_function():
+                         logger.warning("DAG execution interrupted.")
+                         break
+                     last_interrupt_check = time.time()
+
+             is_finished = self._process_nodes()
+             if is_finished:
+                 break
+
+             time.sleep(0.5)
+
+         self.shutdown_event.set()
+         if interrupt_mode == InterruptMode.FORCE:
+             # If a force interrupt was requested, we stop all running nodes.
+             self._stop_all_nodes()
+
+         self.monitoring_thread.join()
+
+         node_statuses = {
+             node_id: node.status for node_id, node in self.nodes.items()
+         }
+         logger.debug("Finished with node statuses: %s", node_statuses)
+
+         return node_statuses
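
For orientation, below is a minimal usage sketch of the DagRunner added in this file. It is illustrative only: the import path is assumed (the diff does not show the file's location inside the package), and start_node / monitor_node are hypothetical stand-ins for whatever the orchestrator actually does to launch and poll node executions.

# Illustrative sketch only; module path and the two callbacks are assumptions.
from zenml.orchestrators.dag_runner import (  # assumed module path
    DagRunner,
    Node,
    NodeStatus,
)


def start_node(node: Node) -> NodeStatus:
    # Hypothetical startup callback: kick off the node's work (e.g. submit
    # a job somewhere) and report that it is now running.
    print(f"starting {node.id}")
    return NodeStatus.RUNNING


def monitor_node(node: Node) -> NodeStatus:
    # Hypothetical monitoring callback: poll the backing job. Here every
    # node is reported as completed on the first check.
    return NodeStatus.COMPLETED


nodes = [
    Node(id="load_data"),
    Node(id="train", upstream_nodes=["load_data"]),
]
runner = DagRunner(
    nodes=nodes,
    node_startup_function=start_node,
    node_monitoring_function=monitor_node,
    max_parallelism=2,
)
statuses = runner.run()
# e.g. {"load_data": NodeStatus.COMPLETED, "train": NodeStatus.COMPLETED}

The runner itself only tracks statuses: "train" stays NOT_READY until "load_data" is reported COMPLETED by the monitoring loop, and a failed or skipped upstream node causes downstream nodes to be marked SKIPPED.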