zenml-nightly 0.84.1.dev20250804__py3-none-any.whl → 0.84.1.dev20250806__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/config/compiler.py +3 -3
- zenml/integrations/kubernetes/constants.py +27 -0
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +79 -36
- zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +55 -24
- zenml/integrations/kubernetes/orchestrators/dag_runner.py +367 -0
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +368 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +144 -262
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +392 -244
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +53 -85
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +74 -32
- zenml/logging/step_logging.py +33 -30
- zenml/steps/base_step.py +6 -6
- zenml/steps/step_decorator.py +4 -4
- zenml/zen_stores/sql_zen_store.py +8 -0
- {zenml_nightly-0.84.1.dev20250804.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/METADATA +1 -1
- {zenml_nightly-0.84.1.dev20250804.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/RECORD +20 -18
- {zenml_nightly-0.84.1.dev20250804.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.84.1.dev20250804.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.84.1.dev20250804.dist-info → zenml_nightly-0.84.1.dev20250806.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,367 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""DAG runner."""
|
15
|
+
|
16
|
+
import queue
|
17
|
+
import threading
|
18
|
+
import time
|
19
|
+
from concurrent.futures import ThreadPoolExecutor
|
20
|
+
from typing import Any, Callable, Dict, List, Optional
|
21
|
+
|
22
|
+
from pydantic import BaseModel
|
23
|
+
|
24
|
+
from zenml.logger import get_logger
|
25
|
+
from zenml.utils.enum_utils import StrEnum
|
26
|
+
|
27
|
+
# Module-level logger used by the DAG runner for node lifecycle messages.
logger = get_logger(__name__)
|
28
|
+
|
29
|
+
|
30
|
+
class NodeStatus(StrEnum):
    """Status of a DAG node.

    `COMPLETED`, `FAILED`, `SKIPPED` and `CANCELLED` are terminal statuses
    (a node in one of them is considered finished). The remaining statuses
    describe a node that has not finished yet. A node can only start once
    all of its upstream nodes are `COMPLETED`.
    """

    NOT_READY = "not_ready"  # Cannot be started yet
    READY = "ready"  # Can be started but is still waiting in the queue
    STARTING = "starting"  # Is being started, but not yet running
    RUNNING = "running"  # Started; its status is polled by the monitor loop
    COMPLETED = "completed"  # Finished successfully
    FAILED = "failed"  # Failed to start or failed while running
    SKIPPED = "skipped"  # Not run because an upstream node did not complete
    CANCELLED = "cancelled"  # Stopped by the runner on a forced interrupt
|
41
|
+
|
42
|
+
|
43
|
+
class InterruptMode(StrEnum):
    """Interrupt mode.

    Returned by the interrupt function of a DAG runner to request that the
    DAG execution stops. Any returned mode stops scheduling new nodes; the
    mode decides what happens to nodes that are already running.
    """

    # Stop scheduling new nodes but leave running nodes untouched.
    GRACEFUL = "graceful"
    # Stop scheduling new nodes and also stop all currently running nodes
    # (requires a node stop function).
    FORCE = "force"
|
48
|
+
|
49
|
+
|
50
|
+
class Node(BaseModel):
    """A single node of the DAG.

    Attributes:
        id: Identifier of the node, unique within its DAG.
        status: Current status of the node.
        upstream_nodes: IDs of the nodes that have to complete before this
            node can be started.
        metadata: Additional data attached to the node. The DAG runner does
            not interpret it.
    """

    id: str
    status: NodeStatus = NodeStatus.NOT_READY
    upstream_nodes: List[str] = []
    metadata: Dict[str, Any] = {}

    @property
    def is_finished(self) -> bool:
        """Whether the node has reached a terminal status.

        Returns:
            `True` if the node completed, failed, was skipped or was
            cancelled, `False` otherwise.
        """
        terminal_statuses = {
            NodeStatus.COMPLETED,
            NodeStatus.FAILED,
            NodeStatus.SKIPPED,
            NodeStatus.CANCELLED,
        }
        return self.status in terminal_statuses
|
71
|
+
|
72
|
+
|
73
|
+
class DagRunner:
    """DAG runner.

    This class orchestrates running the nodes of a DAG. It runs two loops in
    separate threads:
    The main thread
        - checks if any nodes should be skipped or are ready to run, in which
          case the node will be added to the startup queue
        - hands nodes from the startup queue to a thread pool that runs the
          node startup function, as long as the maximum parallelism is not
          reached
        - periodically checks if the DAG should be interrupted
    The monitoring thread
        - monitors the running nodes and updates their status

    A runner instance can only be run once: the monitoring thread and the
    startup thread pool are created in `__init__` and torn down at the end
    of `run()`.
    """

    def __init__(
        self,
        nodes: List[Node],
        node_startup_function: Callable[[Node], NodeStatus],
        node_monitoring_function: Callable[[Node], NodeStatus],
        node_stop_function: Optional[Callable[[Node], None]] = None,
        interrupt_function: Optional[
            Callable[[], Optional[InterruptMode]]
        ] = None,
        monitoring_interval: float = 1.0,
        monitoring_delay: float = 0.0,
        interrupt_check_interval: float = 1.0,
        max_parallelism: Optional[int] = None,
    ) -> None:
        """Initialize the DAG runner.

        Args:
            nodes: The nodes of the DAG. Every ID in a node's
                `upstream_nodes` must refer to a node in this list.
            node_startup_function: The function to start a node. It returns
                the status of the node after the startup.
            node_monitoring_function: The function to monitor a running node.
                It returns the current status of the node.
            node_stop_function: The function to stop a node. Only needed if
                the interrupt function can request a forced interrupt.
            interrupt_function: Will be periodically called to check if the
                DAG should be interrupted. Returning an interrupt mode stops
                the DAG execution.
            monitoring_interval: The interval in seconds in which the running
                nodes are monitored.
            monitoring_delay: The delay in seconds to wait between monitoring
                different nodes.
            interrupt_check_interval: The interval in seconds in which the
                interrupt function is called.
            max_parallelism: The maximum number of nodes that are starting or
                running at the same time. `None` (or `0`) means no limit.
        """
        self.nodes = {node.id: node for node in nodes}
        self.startup_queue: queue.Queue[Node] = queue.Queue()
        self.node_startup_function = node_startup_function
        self.node_monitoring_function = node_monitoring_function
        self.node_stop_function = node_stop_function
        self.interrupt_function = interrupt_function
        self.monitoring_thread = threading.Thread(
            name="DagRunner-Monitoring-Loop",
            target=self._monitoring_loop,
            daemon=True,
        )
        self.monitoring_interval = monitoring_interval
        self.monitoring_delay = monitoring_delay
        self.interrupt_check_interval = interrupt_check_interval
        self.max_parallelism = max_parallelism
        self.shutdown_event = threading.Event()
        self.startup_executor = ThreadPoolExecutor(
            max_workers=10, thread_name_prefix="DagRunner-Startup"
        )

    @property
    def running_nodes(self) -> List[Node]:
        """Running nodes.

        Returns:
            All nodes whose status is `RUNNING`.
        """
        return [
            node
            for node in self.nodes.values()
            if node.status == NodeStatus.RUNNING
        ]

    @property
    def active_nodes(self) -> List[Node]:
        """Active nodes.

        Active nodes are nodes that are either running or starting. They
        count towards the maximum parallelism.

        Returns:
            Active nodes.
        """
        return [
            node
            for node in self.nodes.values()
            if node.status in {NodeStatus.RUNNING, NodeStatus.STARTING}
        ]

    def _initialize_startup_queue(self) -> None:
        """Initialize the startup queue.

        The startup queue contains all nodes that are ready to be started.
        """
        for node in self.nodes.values():
            if node.status in {NodeStatus.READY, NodeStatus.STARTING}:
                # A node that is `STARTING` before the runner even ran is not
                # being started by this runner. Demote it to `READY` while it
                # waits in the queue: a queued `STARTING` node would count
                # towards `active_nodes` and could use up the whole
                # parallelism budget, so no queued node would ever start.
                node.status = NodeStatus.READY
                self.startup_queue.put(node)

    def _can_start_node(self, node: Node) -> bool:
        """Check if a node can be started.

        Args:
            node: The node to check.

        Returns:
            Whether all upstream nodes of the node are completed.
        """
        return all(
            self.nodes[upstream_node_id].status == NodeStatus.COMPLETED
            for upstream_node_id in node.upstream_nodes
        )

    def _should_skip_node(self, node: Node) -> bool:
        """Check if a node should be skipped.

        Args:
            node: The node to check.

        Returns:
            Whether any upstream node failed, was skipped or was cancelled.
        """
        return any(
            self.nodes[upstream_node_id].status
            in {NodeStatus.FAILED, NodeStatus.SKIPPED, NodeStatus.CANCELLED}
            for upstream_node_id in node.upstream_nodes
        )

    def _start_node(self, node: Node) -> None:
        """Start a node.

        The node is marked as `STARTING` right away (so it counts towards the
        parallelism) and the startup function is submitted to the startup
        thread pool.

        Args:
            node: The node to start.
        """
        node.status = NodeStatus.STARTING

        def _start_node_task() -> None:
            # Tasks still waiting in the pool when shutdown is requested must
            # not start their node anymore.
            if self.shutdown_event.is_set():
                logger.debug(
                    "Cancelling startup of node `%s` because shutdown was "
                    "requested.",
                    node.id,
                )
                return

            try:
                node.status = self.node_startup_function(node)
            except Exception:
                node.status = NodeStatus.FAILED
                logger.exception("Node `%s` failed to start.", node.id)
            else:
                logger.info(
                    "Node `%s` started (status: %s)", node.id, node.status
                )

        self.startup_executor.submit(_start_node_task)

    def _stop_node(self, node: Node) -> None:
        """Stop a node.

        Args:
            node: The node to stop.

        Raises:
            RuntimeError: If the node stop function is not set.
        """
        if not self.node_stop_function:
            raise RuntimeError("Node stop function is not set.")

        self.node_stop_function(node)

    def _stop_all_nodes(self) -> None:
        """Stop all running nodes.

        Stopping is best-effort: a failure to stop one node is logged and
        does not prevent the remaining nodes from being stopped. Only nodes
        that were stopped successfully are marked as `CANCELLED`.
        """
        for node in self.running_nodes:
            try:
                self._stop_node(node)
            except Exception:
                logger.exception("Failed to stop node `%s`.", node.id)
            else:
                node.status = NodeStatus.CANCELLED

    def _process_nodes(self) -> bool:
        """Process the nodes.

        This method will check if any nodes should be skipped or are ready to
        run, in which case the node will be added to the startup queue. It
        then starts queued nodes until the maximum parallelism is reached.

        Returns:
            Whether the DAG is finished.
        """
        finished = True

        for node in self.nodes.values():
            if node.status == NodeStatus.NOT_READY:
                if self._should_skip_node(node):
                    node.status = NodeStatus.SKIPPED
                    logger.warning(
                        "Skipping node `%s` because upstream node failed.",
                        node.id,
                    )
                elif self._can_start_node(node):
                    node.status = NodeStatus.READY
                    self.startup_queue.put(node)

            if not node.is_finished:
                finished = False

        # Start nodes until we reach the maximum configured parallelism
        max_parallelism = self.max_parallelism or len(self.nodes)
        while len(self.active_nodes) < max_parallelism:
            try:
                node = self.startup_queue.get_nowait()
            except queue.Empty:
                break
            else:
                self.startup_queue.task_done()
                self._start_node(node)

        return finished

    def _monitoring_loop(self) -> None:
        """Monitoring loop.

        This should run in a separate thread. It periodically polls the
        status of all running nodes until shutdown is requested.
        """
        while not self.shutdown_event.is_set():
            start_time = time.time()
            for node in self.running_nodes:
                try:
                    node.status = self.node_monitoring_function(node)
                except Exception:
                    node.status = NodeStatus.FAILED
                    logger.exception("Node `%s` failed.", node.id)
                else:
                    logger.debug(
                        "Node `%s` status updated to `%s`",
                        node.id,
                        node.status,
                    )
                    if node.status == NodeStatus.FAILED:
                        logger.error("Node `%s` failed.", node.id)
                    elif node.status == NodeStatus.COMPLETED:
                        logger.info("Node `%s` completed.", node.id)

                # Interruptible delay: once shutdown is requested the rest of
                # the pass runs without waiting so `run()` can finish quickly.
                self.shutdown_event.wait(timeout=self.monitoring_delay)

            duration = time.time() - start_time
            time_to_sleep = max(0, self.monitoring_interval - duration)
            self.shutdown_event.wait(timeout=time_to_sleep)

    def _main_loop(self) -> Optional[InterruptMode]:
        """Process nodes until the DAG is finished or interrupted.

        Returns:
            The requested interrupt mode if the DAG execution was
            interrupted, `None` if all nodes finished.
        """
        last_interrupt_check = time.time()

        while True:
            if (
                self.interrupt_function is not None
                and time.time() - last_interrupt_check
                >= self.interrupt_check_interval
            ):
                if interrupt_mode := self.interrupt_function():
                    logger.warning("DAG execution interrupted.")
                    return interrupt_mode
                last_interrupt_check = time.time()

            if self._process_nodes():
                return None

            time.sleep(0.5)

    def run(self) -> Dict[str, NodeStatus]:
        """Run the DAG.

        Returns:
            The final node states.
        """
        self._initialize_startup_queue()

        self.monitoring_thread.start()

        interrupt_mode: Optional[InterruptMode] = None
        try:
            interrupt_mode = self._main_loop()
        finally:
            # Always tear down the helper threads, even if the main loop
            # raised, so that no monitoring or startup work outlives `run()`.
            self.shutdown_event.set()
            # Joining the monitor before any force stop guarantees it cannot
            # overwrite the `CANCELLED` status set by `_stop_all_nodes`.
            self.monitoring_thread.join()
            # Startup tasks that are still queued return immediately because
            # the shutdown event is set. Startups already in progress are
            # allowed to finish, so a force stop also reaches nodes that were
            # still starting when the interrupt arrived.
            self.startup_executor.shutdown(wait=True)

        if interrupt_mode == InterruptMode.FORCE:
            # If a force interrupt was requested, we stop all running nodes.
            self._stop_all_nodes()

        node_statuses = {
            node_id: node.status for node_id, node in self.nodes.items()
        }
        logger.debug("Finished with node statuses: %s", node_statuses)

        return node_statuses
|