zenml-nightly 0.83.1.dev20250702__py3-none-any.whl → 0.83.1.dev20250704__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/pipeline.py +54 -1
- zenml/cli/utils.py +2 -0
- zenml/config/compiler.py +19 -3
- zenml/config/step_configurations.py +34 -2
- zenml/constants.py +1 -0
- zenml/enums.py +6 -3
- zenml/exceptions.py +8 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +8 -4
- zenml/integrations/aws/step_operators/sagemaker_step_operator.py +1 -1
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +5 -3
- zenml/integrations/azure/step_operators/azureml_step_operator.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +7 -8
- zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -1
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +6 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +109 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +36 -1
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +11 -3
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +1 -1
- zenml/integrations/modal/step_operators/modal_step_operator.py +1 -1
- zenml/integrations/spark/step_operators/kubernetes_step_operator.py +1 -1
- zenml/models/v2/core/pipeline_run.py +2 -2
- zenml/orchestrators/base_orchestrator.py +70 -0
- zenml/orchestrators/containerized_orchestrator.py +22 -0
- zenml/orchestrators/dag_runner.py +27 -8
- zenml/orchestrators/local_docker/local_docker_orchestrator.py +9 -0
- zenml/orchestrators/publish_utils.py +100 -13
- zenml/orchestrators/step_launcher.py +94 -8
- zenml/stack/stack.py +2 -2
- zenml/utils/run_utils.py +74 -0
- zenml/zen_server/routers/runs_endpoints.py +27 -23
- zenml/zen_stores/sql_zen_store.py +23 -3
- {zenml_nightly-0.83.1.dev20250702.dist-info → zenml_nightly-0.83.1.dev20250704.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.1.dev20250702.dist-info → zenml_nightly-0.83.1.dev20250704.dist-info}/RECORD +37 -36
- {zenml_nightly-0.83.1.dev20250702.dist-info → zenml_nightly-0.83.1.dev20250704.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.1.dev20250702.dist-info → zenml_nightly-0.83.1.dev20250704.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.1.dev20250702.dist-info → zenml_nightly-0.83.1.dev20250704.dist-info}/entry_points.txt +0 -0
@@ -343,7 +343,7 @@ class PipelineRunResponse(
|
|
343
343
|
if self.stack is None:
|
344
344
|
raise ValueError(
|
345
345
|
"The stack that this pipeline run response was executed on"
|
346
|
-
"has been deleted."
|
346
|
+
"is either not accessible or has been deleted."
|
347
347
|
)
|
348
348
|
|
349
349
|
# Create the orchestrator instance
|
@@ -358,7 +358,7 @@ class PipelineRunResponse(
|
|
358
358
|
if len(orchestrator_list) == 0:
|
359
359
|
raise ValueError(
|
360
360
|
"The orchestrator that this pipeline run response was "
|
361
|
-
"executed with has been deleted."
|
361
|
+
"executed with is either not accessible or has been deleted."
|
362
362
|
)
|
363
363
|
|
364
364
|
orchestrator = cast(
|
@@ -38,6 +38,7 @@ from zenml.logger import get_logger
|
|
38
38
|
from zenml.metadata.metadata_types import MetadataType
|
39
39
|
from zenml.orchestrators.publish_utils import (
|
40
40
|
publish_pipeline_run_metadata,
|
41
|
+
publish_pipeline_run_status_update,
|
41
42
|
publish_schedule_metadata,
|
42
43
|
)
|
43
44
|
from zenml.orchestrators.step_launcher import StepLauncher
|
@@ -210,6 +211,8 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
210
211
|
This will be deleted in case the pipeline deployment failed.
|
211
212
|
|
212
213
|
Raises:
|
214
|
+
KeyboardInterrupt: If the orchestrator is synchronous and the
|
215
|
+
pipeline run is keyboard interrupted.
|
213
216
|
RunMonitoringError: If a failure happened while monitoring the
|
214
217
|
pipeline run.
|
215
218
|
"""
|
@@ -324,8 +327,17 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
324
327
|
if submission_result.wait_for_completion:
|
325
328
|
try:
|
326
329
|
submission_result.wait_for_completion()
|
330
|
+
except KeyboardInterrupt:
|
331
|
+
error_message = "Received KeyboardInterrupt. Note that the run is still executing. "
|
332
|
+
if placeholder_run:
|
333
|
+
error_message += (
|
334
|
+
"If you want to stop the pipeline run, please use: "
|
335
|
+
f"`zenml pipeline runs stop {placeholder_run.id}`"
|
336
|
+
)
|
337
|
+
raise KeyboardInterrupt(error_message)
|
327
338
|
except BaseException as e:
|
328
339
|
raise RunMonitoringError(original_exception=e)
|
340
|
+
|
329
341
|
finally:
|
330
342
|
self._cleanup_run()
|
331
343
|
|
@@ -391,6 +403,64 @@ class BaseOrchestrator(StackComponent, ABC):
|
|
391
403
|
f"'{self.__class__.__name__}' orchestrator."
|
392
404
|
)
|
393
405
|
|
406
|
+
def stop_run(
    self, run: "PipelineRunResponse", graceful: bool = False
) -> None:
    """Request that a specific pipeline run be stopped.

    The run is first published with status STOPPING, after which the
    orchestrator-specific `_stop_run` hook does the actual work. Orchestrators
    that do not override `_stop_run` are rejected before any status change.

    Args:
        run: The pipeline run response to stop.
        graceful: If True, allows for graceful shutdown where possible.
            If False, forces immediate termination. Default is False.

    Raises:
        NotImplementedError: If the orchestrator subclass does not override
            `_stop_run`.
    """
    # A bound method exposes the underlying function via `__func__`; if it is
    # still the base-class function, the subclass never overrode the hook.
    stop_hook = getattr(self._stop_run, "__func__", None)
    hook_overridden = stop_hook is not BaseOrchestrator._stop_run
    if not hook_overridden:
        raise NotImplementedError(
            f"The '{self.__class__.__name__}' orchestrator does not "
            "support stopping pipeline runs."
        )

    # Mark the run as STOPPING before delegating to the concrete logic.
    # NOTE(review): if `_stop_run` raises, the run is left in STOPPING.
    publish_pipeline_run_status_update(
        pipeline_run_id=run.id,
        status=ExecutionStatus.STOPPING,
    )

    self._stop_run(run=run, graceful=graceful)
|
441
|
+
|
442
|
+
def _stop_run(
|
443
|
+
self, run: "PipelineRunResponse", graceful: bool = False
|
444
|
+
) -> None:
|
445
|
+
"""Concrete implementation of pipeline stopping logic.
|
446
|
+
|
447
|
+
This method should be implemented by concrete orchestrator classes
|
448
|
+
instead of stop_run to ensure proper status management.
|
449
|
+
|
450
|
+
Args:
|
451
|
+
run: A pipeline run response to stop (already updated to STOPPING status).
|
452
|
+
graceful: If True, allows for graceful shutdown where possible.
|
453
|
+
If False, forces immediate termination. Default is True.
|
454
|
+
|
455
|
+
Raises:
|
456
|
+
NotImplementedError: If any orchestrator inheriting from the base
|
457
|
+
class does not implement this logic.
|
458
|
+
"""
|
459
|
+
raise NotImplementedError(
|
460
|
+
"The stop run functionality is not implemented for the "
|
461
|
+
f"'{self.__class__.__name__}' orchestrator."
|
462
|
+
)
|
463
|
+
|
394
464
|
|
395
465
|
class BaseOrchestratorFlavor(Flavor):
|
396
466
|
"""Base orchestrator flavor class."""
|
@@ -53,6 +53,19 @@ class ContainerizedOrchestrator(BaseOrchestrator, ABC):
|
|
53
53
|
component_key=ORCHESTRATOR_DOCKER_IMAGE_KEY, step=step_name
|
54
54
|
)
|
55
55
|
|
56
|
+
def should_build_pipeline_image(
    self, deployment: "PipelineDeploymentBase"
) -> bool:
    """Decide whether an extra pipeline-level image should be built.

    This implementation always answers no; `get_docker_builds` consults it
    only when no step already required the pipeline image.

    Args:
        deployment: The pipeline deployment (unused here).

    Returns:
        Always False.
    """
    needs_pipeline_image = False
    return needs_pipeline_image
|
68
|
+
|
56
69
|
def get_docker_builds(
|
57
70
|
self, deployment: "PipelineDeploymentBase"
|
58
71
|
) -> List["BuildConfiguration"]:
|
@@ -87,4 +100,13 @@ class ContainerizedOrchestrator(BaseOrchestrator, ABC):
|
|
87
100
|
builds.append(pipeline_build)
|
88
101
|
included_pipeline_build = True
|
89
102
|
|
103
|
+
if not included_pipeline_build and self.should_build_pipeline_image(
|
104
|
+
deployment
|
105
|
+
):
|
106
|
+
pipeline_build = BuildConfiguration(
|
107
|
+
key=ORCHESTRATOR_DOCKER_IMAGE_KEY,
|
108
|
+
settings=pipeline_settings,
|
109
|
+
)
|
110
|
+
builds.append(pipeline_build)
|
111
|
+
|
90
112
|
return builds
|
@@ -56,6 +56,7 @@ class NodeStatus(Enum):
|
|
56
56
|
RUNNING = "running"
|
57
57
|
COMPLETED = "completed"
|
58
58
|
FAILED = "failed"
|
59
|
+
CANCELLED = "cancelled"
|
59
60
|
|
60
61
|
|
61
62
|
class ThreadedDagRunner:
|
@@ -76,6 +77,7 @@ class ThreadedDagRunner:
|
|
76
77
|
finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None,
|
77
78
|
parallel_node_startup_waiting_period: float = 0.0,
|
78
79
|
max_parallelism: Optional[int] = None,
|
80
|
+
continue_fn: Optional[Callable[[], bool]] = None,
|
79
81
|
) -> None:
|
80
82
|
"""Define attributes and initialize all nodes in waiting state.
|
81
83
|
|
@@ -92,6 +94,9 @@ class ThreadedDagRunner:
|
|
92
94
|
parallel_node_startup_waiting_period: Delay in seconds to wait in
|
93
95
|
between starting parallel nodes.
|
94
96
|
max_parallelism: Maximum number of nodes to run in parallel
|
97
|
+
continue_fn: A function that returns True if the run should continue
|
98
|
+
after each step execution, False if it should stop (e.g., due
|
99
|
+
to cancellation). If None, execution continues normally.
|
95
100
|
|
96
101
|
Raises:
|
97
102
|
ValueError: If max_parallelism is not greater than 0.
|
@@ -108,12 +113,15 @@ class ThreadedDagRunner:
|
|
108
113
|
self.run_fn = run_fn
|
109
114
|
self.preparation_fn = preparation_fn
|
110
115
|
self.finalize_fn = finalize_fn
|
116
|
+
self.continue_fn = continue_fn
|
111
117
|
self.nodes = dag.keys()
|
112
118
|
self.node_states = {
|
113
119
|
node: NodeStatus.NOT_STARTED for node in self.nodes
|
114
120
|
}
|
115
121
|
self._lock = threading.Lock()
|
116
122
|
|
123
|
+
self._stop_requested = False
|
124
|
+
|
117
125
|
def _can_run(self, node: str) -> bool:
|
118
126
|
"""Determine whether a node is ready to be run.
|
119
127
|
|
@@ -173,6 +181,15 @@ class ThreadedDagRunner:
|
|
173
181
|
"""
|
174
182
|
self._prepare_node_run(node)
|
175
183
|
|
184
|
+
# Check if execution should continue (e.g., check for cancellation)
|
185
|
+
if self.continue_fn:
|
186
|
+
self._stop_requested = (
|
187
|
+
self._stop_requested or not self.continue_fn()
|
188
|
+
)
|
189
|
+
if self._stop_requested:
|
190
|
+
self._finish_node(node, cancelled=True)
|
191
|
+
return
|
192
|
+
|
176
193
|
if self.preparation_fn:
|
177
194
|
run_required = self.preparation_fn(node)
|
178
195
|
if not run_required:
|
@@ -204,24 +221,26 @@ class ThreadedDagRunner:
|
|
204
221
|
thread.start()
|
205
222
|
return thread
|
206
223
|
|
207
|
-
def _finish_node(
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
Then starts all other nodes that can now be run and waits for them.
|
224
|
+
def _finish_node(
|
225
|
+
self, node: str, failed: bool = False, cancelled: bool = False
|
226
|
+
) -> None:
|
227
|
+
"""Mark a node as finished and potentially start new nodes.
|
212
228
|
|
213
229
|
Args:
|
214
|
-
node: The node.
|
230
|
+
node: The node to mark as finished.
|
215
231
|
failed: Whether the node failed.
|
232
|
+
cancelled: Whether the node was cancelled.
|
216
233
|
"""
|
217
234
|
with self._lock:
|
218
235
|
if failed:
|
219
236
|
self.node_states[node] = NodeStatus.FAILED
|
237
|
+
elif cancelled:
|
238
|
+
self.node_states[node] = NodeStatus.CANCELLED
|
220
239
|
else:
|
221
240
|
self.node_states[node] = NodeStatus.COMPLETED
|
222
241
|
|
223
|
-
if failed:
|
224
|
-
# If the node failed, we don't need to run any downstream nodes.
|
242
|
+
if failed or cancelled:
|
243
|
+
# If the node failed or was cancelled, we don't need to run any downstream nodes.
|
225
244
|
return
|
226
245
|
|
227
246
|
# Run downstream nodes.
|
@@ -63,6 +63,15 @@ class LocalDockerOrchestrator(ContainerizedOrchestrator):
|
|
63
63
|
"""
|
64
64
|
return LocalDockerOrchestratorSettings
|
65
65
|
|
66
|
+
@property
def config(self) -> "LocalDockerOrchestratorConfig":
    """Typed accessor for this orchestrator's configuration.

    Returns:
        The stored component configuration, typed as
        `LocalDockerOrchestratorConfig` (the cast performs no runtime
        conversion).
    """
    stored_config = self._config
    return cast(LocalDockerOrchestratorConfig, stored_config)
|
74
|
+
|
66
75
|
@property
|
67
76
|
def validator(self) -> Optional[StackValidator]:
|
68
77
|
"""Ensures there is an image builder in the stack.
|
@@ -13,7 +13,8 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Utilities to publish pipeline and step runs."""
|
15
15
|
|
16
|
-
from
|
16
|
+
from datetime import datetime
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
17
18
|
|
18
19
|
from zenml.client import Client
|
19
20
|
from zenml.enums import ExecutionStatus, MetadataResourceTypes
|
@@ -54,6 +55,40 @@ def publish_successful_step_run(
|
|
54
55
|
)
|
55
56
|
|
56
57
|
|
58
|
+
def publish_step_run_status_update(
    step_run_id: "UUID",
    status: "ExecutionStatus",
    end_time: Optional[datetime] = None,
) -> "StepRunResponse":
    """Publishes a step run status update.

    Unlike `publish_pipeline_run_status_update`, no end time is filled in
    automatically; it is only sent when the caller provides one.

    Args:
        step_run_id: ID of the step run.
        status: New status of the step run.
        end_time: New end time of the step run. Only allowed together with
            a finished status.

    Returns:
        The updated step run.

    Raises:
        ValueError: If the end time is set for a non-finished step run.
    """
    # Validate before touching the store so invalid input never hits the API.
    if end_time is not None and not status.is_finished:
        raise ValueError("End time cannot be set for a non-finished step run.")

    # `Client` is imported at module level; no local re-import is needed.
    return Client().zen_store.update_run_step(
        step_run_id=step_run_id,
        step_run_update=StepRunUpdate(
            status=status,
            end_time=end_time,
        ),
    )
|
90
|
+
|
91
|
+
|
57
92
|
def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
|
58
93
|
"""Publishes a failed step run.
|
59
94
|
|
@@ -63,12 +98,10 @@ def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
|
|
63
98
|
Returns:
|
64
99
|
The updated step run.
|
65
100
|
"""
|
66
|
-
return
|
101
|
+
return publish_step_run_status_update(
|
67
102
|
step_run_id=step_run_id,
|
68
|
-
|
69
|
-
|
70
|
-
end_time=utc_now(),
|
71
|
-
),
|
103
|
+
status=ExecutionStatus.FAILED,
|
104
|
+
end_time=utc_now(),
|
72
105
|
)
|
73
106
|
|
74
107
|
|
@@ -92,27 +125,81 @@ def publish_failed_pipeline_run(
|
|
92
125
|
)
|
93
126
|
|
94
127
|
|
128
|
+
def publish_pipeline_run_status_update(
    pipeline_run_id: "UUID",
    status: ExecutionStatus,
    end_time: Optional[datetime] = None,
) -> "PipelineRunResponse":
    """Send a new status (and optionally an end time) for a pipeline run.

    Args:
        pipeline_run_id: The ID of the pipeline run to update.
        status: The new status for the pipeline run.
        end_time: The end time for the pipeline run. When omitted and the
            new status is a finished one, the current UTC time is used.

    Returns:
        The updated pipeline run.
    """
    effective_end_time = end_time
    if effective_end_time is None and status.is_finished:
        effective_end_time = utc_now()

    # Resolve the store before building the update model, matching the
    # original evaluation order.
    store = Client().zen_store
    run_update = PipelineRunUpdate(
        status=status,
        end_time=effective_end_time,
    )
    return store.update_run(run_id=pipeline_run_id, run_update=run_update)
|
154
|
+
|
155
|
+
|
95
156
|
def get_pipeline_run_status(
    run_status: ExecutionStatus,
    step_statuses: List[ExecutionStatus],
    num_steps: int,
) -> ExecutionStatus:
    """Derive the overall pipeline run status from its step statuses.

    Precedence, from highest to lowest:
    1. A stop in progress (run is STOPPING or any step is STOPPED): STOPPED
       once every step status is finished, otherwise STOPPING.
    2. A failure (run is FAILED or any step is FAILED): FAILED.
    3. Any step RUNNING, or fewer step statuses than `num_steps`: RUNNING.
    4. Anything else: COMPLETED.

    Args:
        run_status: The status of the run.
        step_statuses: The status of steps in this run.
        num_steps: The total amount of steps in this run.

    Returns:
        The run status.
    """
    stop_in_progress = (
        run_status == ExecutionStatus.STOPPING
        or ExecutionStatus.STOPPED in step_statuses
    )
    if stop_in_progress:
        every_step_done = all(
            step_status.is_finished for step_status in step_statuses
        )
        if every_step_done:
            return ExecutionStatus.STOPPED
        return ExecutionStatus.STOPPING

    if (
        run_status == ExecutionStatus.FAILED
        or ExecutionStatus.FAILED in step_statuses
    ):
        return ExecutionStatus.FAILED

    # Either a step is actively running or some steps have not been
    # created yet.
    if (
        ExecutionStatus.RUNNING in step_statuses
        or len(step_statuses) < num_steps
    ):
        return ExecutionStatus.RUNNING

    return ExecutionStatus.COMPLETED
|
116
203
|
|
117
204
|
|
118
205
|
def publish_pipeline_run_metadata(
|
@@ -14,10 +14,11 @@
|
|
14
14
|
"""Class to launch (run directly or using a step operator) steps."""
|
15
15
|
|
16
16
|
import os
|
17
|
+
import signal
|
17
18
|
import time
|
18
19
|
from contextlib import nullcontext
|
19
20
|
from functools import partial
|
20
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
|
21
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
|
21
22
|
|
22
23
|
from zenml.client import Client
|
23
24
|
from zenml.config.step_configurations import Step
|
@@ -29,6 +30,7 @@ from zenml.constants import (
|
|
29
30
|
)
|
30
31
|
from zenml.enums import ExecutionStatus
|
31
32
|
from zenml.environment import get_run_environment_dict
|
33
|
+
from zenml.exceptions import RunInterruptedException, RunStoppedException
|
32
34
|
from zenml.logger import get_logger
|
33
35
|
from zenml.logging import step_logging
|
34
36
|
from zenml.models import (
|
@@ -53,7 +55,7 @@ logger = get_logger(__name__)
|
|
53
55
|
|
54
56
|
|
55
57
|
def _get_step_operator(
|
56
|
-
stack: "Stack", step_operator_name: str
|
58
|
+
stack: "Stack", step_operator_name: Optional[str]
|
57
59
|
) -> "BaseStepOperator":
|
58
60
|
"""Fetches the step operator from the stack.
|
59
61
|
|
@@ -76,7 +78,7 @@ def _get_step_operator(
|
|
76
78
|
f"No step operator specified for active stack '{stack.name}'."
|
77
79
|
)
|
78
80
|
|
79
|
-
if step_operator_name != step_operator.name:
|
81
|
+
if step_operator_name and step_operator_name != step_operator.name:
|
80
82
|
raise RuntimeError(
|
81
83
|
f"No step operator named '{step_operator_name}' in active "
|
82
84
|
f"stack '{stack.name}'."
|
@@ -131,11 +133,86 @@ class StepLauncher:
|
|
131
133
|
self._stack = Stack.from_model(deployment.stack)
|
132
134
|
self._step_name = step.spec.pipeline_parameter_name
|
133
135
|
|
136
|
+
# Internal properties and methods
|
137
|
+
self._step_run: Optional[StepRunResponse] = None
|
138
|
+
self._setup_signal_handlers()
|
139
|
+
|
140
|
+
def _setup_signal_handlers(self) -> None:
    """Set up signal handlers for graceful shutdown, chaining previous handlers.

    SIGTERM and SIGINT are routed to a handler that checks whether the
    pipeline run of the current step is being stopped. If it is, the step
    run is published as STOPPED and `RunStoppedException` is raised;
    otherwise `RunInterruptedException` is raised. Handlers that were
    installed before this call are invoked afterwards if they are callables.

    `signal.signal` may only be called from the main thread of the main
    interpreter; elsewhere registration is skipped (with a debug log) and
    the step runs without this shutdown handling instead of crashing.
    """
    # Save previous handlers so they can be chained.
    self._prev_sigterm_handler = signal.getsignal(signal.SIGTERM)
    self._prev_sigint_handler = signal.getsignal(signal.SIGINT)

    def signal_handler(signum: int, frame: Any) -> None:
        """Handle shutdown signals gracefully.

        Args:
            signum: The signal number.
            frame: The current stack frame at the time of the signal.

        Raises:
            RunStoppedException: If the pipeline run is stopped by the user.
            RunInterruptedException: If the execution is interrupted for any
                other reason.
        """
        logger.info(
            f"Received shutdown signal {signum}. Requesting shutdown "
            f"for step '{self._step_name}'..."
        )

        try:
            client = Client()

            if not self._step_run:
                raise RunInterruptedException(
                    "The execution was interrupted and the step does not "
                    "exist yet."
                )

            pipeline_run = client.get_pipeline_run(
                self._step_run.pipeline_run_id
            )

            if pipeline_run.status not in (
                ExecutionStatus.STOPPING,
                ExecutionStatus.STOPPED,
            ):
                raise RunInterruptedException(
                    "The execution was interrupted."
                )

            # The user is stopping the run: record this step as stopped
            # before aborting its execution.
            publish_utils.publish_step_run_status_update(
                step_run_id=self._step_run.id,
                status=ExecutionStatus.STOPPED,
                end_time=utc_now(),
            )
            raise RunStoppedException("Pipeline run is stopped.")
        except (RunStoppedException, RunInterruptedException):
            raise
        except Exception as e:
            raise RunInterruptedException(str(e)) from e
        finally:
            # Chain to the previous handler. SIG_DFL / SIG_IGN are plain
            # constants and therefore skipped by the `callable` check.
            # NOTE(review): Python's default SIGINT handler
            # (`signal.default_int_handler`) is callable, so it is chained
            # and its KeyboardInterrupt replaces the exception raised above
            # — confirm this is intended.
            if signum == signal.SIGTERM and callable(
                self._prev_sigterm_handler
            ):
                self._prev_sigterm_handler(signum, frame)
            elif signum == signal.SIGINT and callable(
                self._prev_sigint_handler
            ):
                self._prev_sigint_handler(signum, frame)

    # Register handlers for common termination signals. `signal.signal`
    # raises ValueError outside the main thread of the main interpreter;
    # in that case continue without custom signal handling.
    try:
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)
    except ValueError as e:
        logger.debug(f"Cannot register signal handlers: {e}")
|
210
|
+
|
134
211
|
def launch(self) -> None:
|
135
212
|
"""Launches the step.
|
136
213
|
|
137
214
|
Raises:
|
138
|
-
|
215
|
+
RunStoppedException: If the pipeline run is stopped by the user.
|
139
216
|
"""
|
140
217
|
pipeline_run, run_was_created = self._create_or_reuse_run()
|
141
218
|
|
@@ -207,6 +284,8 @@ class StepLauncher:
|
|
207
284
|
step_run = Client().zen_store.create_run_step(
|
208
285
|
step_run_request
|
209
286
|
)
|
287
|
+
# Store step run ID for signal handler
|
288
|
+
self._step_run = step_run
|
210
289
|
if model_version := step_run.model_version:
|
211
290
|
step_run_utils.log_model_version_dashboard_url(
|
212
291
|
model_version=model_version
|
@@ -259,6 +338,8 @@ class StepLauncher:
|
|
259
338
|
force_write_logs=force_write_logs,
|
260
339
|
)
|
261
340
|
break
|
341
|
+
except RunStoppedException as e:
|
342
|
+
raise e
|
262
343
|
except BaseException as e: # noqa: E722
|
263
344
|
retries += 1
|
264
345
|
if retries < max_retries:
|
@@ -292,10 +373,11 @@ class StepLauncher:
|
|
292
373
|
artifacts=step_run.outputs,
|
293
374
|
model_version=model_version,
|
294
375
|
)
|
295
|
-
|
376
|
+
except RunStoppedException:
|
377
|
+
logger.info(f"Pipeline run `{pipeline_run.name}` stopped.")
|
378
|
+
raise
|
296
379
|
except: # noqa: E722
|
297
380
|
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
|
298
|
-
publish_utils.publish_failed_pipeline_run(pipeline_run.id)
|
299
381
|
raise
|
300
382
|
|
301
383
|
def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
|
@@ -367,8 +449,12 @@ class StepLauncher:
|
|
367
449
|
start_time = time.time()
|
368
450
|
try:
|
369
451
|
if self._step.config.step_operator:
|
452
|
+
step_operator_name = None
|
453
|
+
if isinstance(self._step.config.step_operator, str):
|
454
|
+
step_operator_name = self._step.config.step_operator
|
455
|
+
|
370
456
|
self._run_step_with_step_operator(
|
371
|
-
step_operator_name=
|
457
|
+
step_operator_name=step_operator_name,
|
372
458
|
step_run_info=step_run_info,
|
373
459
|
last_retry=last_retry,
|
374
460
|
)
|
@@ -395,7 +481,7 @@ class StepLauncher:
|
|
395
481
|
|
396
482
|
def _run_step_with_step_operator(
|
397
483
|
self,
|
398
|
-
step_operator_name: str,
|
484
|
+
step_operator_name: Optional[str],
|
399
485
|
step_run_info: StepRunInfo,
|
400
486
|
last_retry: bool,
|
401
487
|
) -> None:
|
zenml/stack/stack.py
CHANGED
@@ -849,10 +849,10 @@ class Stack:
|
|
849
849
|
If the component is used in this step.
|
850
850
|
"""
|
851
851
|
if component.type == StackComponentType.STEP_OPERATOR:
|
852
|
-
return component.name
|
852
|
+
return step_config.uses_step_operator(component.name)
|
853
853
|
|
854
854
|
if component.type == StackComponentType.EXPERIMENT_TRACKER:
|
855
|
-
return component.name
|
855
|
+
return step_config.uses_experiment_tracker(component.name)
|
856
856
|
|
857
857
|
return True
|
858
858
|
|
zenml/utils/run_utils.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Utility functions for runs."""
|
15
|
+
|
16
|
+
from typing import cast
|
17
|
+
|
18
|
+
from zenml.enums import ExecutionStatus
|
19
|
+
from zenml.exceptions import IllegalOperationError
|
20
|
+
from zenml.models import PipelineRunResponse
|
21
|
+
|
22
|
+
|
23
|
+
def stop_run(run: PipelineRunResponse, graceful: bool = False) -> None:
    """Stop a pipeline run through the orchestrator it was executed with.

    Args:
        run: The pipeline run to stop.
        graceful: Whether to stop the run gracefully.

    Raises:
        IllegalOperationError: If the run is already completed, already
            stopped, or already being stopped.
        ValueError: If the stack or its orchestrator is not accessible.
    """
    if run.stack is None:
        raise ValueError(
            "The stack that this pipeline run response was executed on "
            "is either not accessible or has been deleted."
        )

    # Reject runs that are in a state where stopping makes no sense.
    rejected_states = (
        (
            ExecutionStatus.COMPLETED,
            "Cannot stop a run that is already completed.",
        ),
        (ExecutionStatus.STOPPED, "Run is already stopped."),
        (ExecutionStatus.STOPPING, "Run is already being stopped."),
    )
    for rejected_status, reason in rejected_states:
        if run.status == rejected_status:
            raise IllegalOperationError(reason)

    # Deferred imports, as in the original, to build the orchestrator.
    from zenml.enums import StackComponentType
    from zenml.orchestrators.base_orchestrator import BaseOrchestrator
    from zenml.stack.stack_component import StackComponent

    # Make sure the orchestrator of the run's stack is still accessible.
    orchestrator_models = run.stack.components.get(
        StackComponentType.ORCHESTRATOR, []
    )
    if len(orchestrator_models) == 0:
        raise ValueError(
            "The orchestrator that this pipeline run response was "
            "executed with is either not accessible or has been deleted."
        )

    orchestrator_component = StackComponent.from_model(
        component_model=orchestrator_models[0]
    )
    orchestrator = cast(BaseOrchestrator, orchestrator_component)

    orchestrator.stop_run(run=run, graceful=graceful)
|