zenml-nightly 0.83.1.dev20250701__py3-none-any.whl → 0.83.1.dev20250703__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/pipeline.py +54 -1
- zenml/cli/utils.py +2 -0
- zenml/constants.py +1 -0
- zenml/enums.py +6 -3
- zenml/exceptions.py +8 -0
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +8 -4
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +5 -3
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +7 -8
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +3 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +88 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +36 -1
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +11 -3
- zenml/models/v2/core/pipeline_run.py +2 -2
- zenml/orchestrators/base_orchestrator.py +70 -0
- zenml/orchestrators/dag_runner.py +27 -8
- zenml/orchestrators/local_docker/local_docker_orchestrator.py +9 -0
- zenml/orchestrators/publish_utils.py +100 -13
- zenml/orchestrators/step_launcher.py +86 -4
- zenml/utils/run_utils.py +74 -0
- zenml/zen_server/routers/runs_endpoints.py +27 -23
- zenml/zen_stores/sql_zen_store.py +23 -3
- {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/RECORD +27 -26
- {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/entry_points.txt +0 -0
@@ -56,6 +56,7 @@ class NodeStatus(Enum):
|
|
56
56
|
RUNNING = "running"
|
57
57
|
COMPLETED = "completed"
|
58
58
|
FAILED = "failed"
|
59
|
+
CANCELLED = "cancelled"
|
59
60
|
|
60
61
|
|
61
62
|
class ThreadedDagRunner:
|
@@ -76,6 +77,7 @@ class ThreadedDagRunner:
|
|
76
77
|
finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None,
|
77
78
|
parallel_node_startup_waiting_period: float = 0.0,
|
78
79
|
max_parallelism: Optional[int] = None,
|
80
|
+
continue_fn: Optional[Callable[[], bool]] = None,
|
79
81
|
) -> None:
|
80
82
|
"""Define attributes and initialize all nodes in waiting state.
|
81
83
|
|
@@ -92,6 +94,9 @@ class ThreadedDagRunner:
|
|
92
94
|
parallel_node_startup_waiting_period: Delay in seconds to wait in
|
93
95
|
between starting parallel nodes.
|
94
96
|
max_parallelism: Maximum number of nodes to run in parallel
|
97
|
+
continue_fn: A function that returns True if the run should continue
|
98
|
+
after each step execution, False if it should stop (e.g., due
|
99
|
+
to cancellation). If None, execution continues normally.
|
95
100
|
|
96
101
|
Raises:
|
97
102
|
ValueError: If max_parallelism is not greater than 0.
|
@@ -108,12 +113,15 @@ class ThreadedDagRunner:
|
|
108
113
|
self.run_fn = run_fn
|
109
114
|
self.preparation_fn = preparation_fn
|
110
115
|
self.finalize_fn = finalize_fn
|
116
|
+
self.continue_fn = continue_fn
|
111
117
|
self.nodes = dag.keys()
|
112
118
|
self.node_states = {
|
113
119
|
node: NodeStatus.NOT_STARTED for node in self.nodes
|
114
120
|
}
|
115
121
|
self._lock = threading.Lock()
|
116
122
|
|
123
|
+
self._stop_requested = False
|
124
|
+
|
117
125
|
def _can_run(self, node: str) -> bool:
|
118
126
|
"""Determine whether a node is ready to be run.
|
119
127
|
|
@@ -173,6 +181,15 @@ class ThreadedDagRunner:
|
|
173
181
|
"""
|
174
182
|
self._prepare_node_run(node)
|
175
183
|
|
184
|
+
# Check if execution should continue (e.g., check for cancellation)
|
185
|
+
if self.continue_fn:
|
186
|
+
self._stop_requested = (
|
187
|
+
self._stop_requested or not self.continue_fn()
|
188
|
+
)
|
189
|
+
if self._stop_requested:
|
190
|
+
self._finish_node(node, cancelled=True)
|
191
|
+
return
|
192
|
+
|
176
193
|
if self.preparation_fn:
|
177
194
|
run_required = self.preparation_fn(node)
|
178
195
|
if not run_required:
|
@@ -204,24 +221,26 @@ class ThreadedDagRunner:
|
|
204
221
|
thread.start()
|
205
222
|
return thread
|
206
223
|
|
207
|
-
def _finish_node(
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
Then starts all other nodes that can now be run and waits for them.
|
224
|
+
def _finish_node(
|
225
|
+
self, node: str, failed: bool = False, cancelled: bool = False
|
226
|
+
) -> None:
|
227
|
+
"""Mark a node as finished and potentially start new nodes.
|
212
228
|
|
213
229
|
Args:
|
214
|
-
node: The node.
|
230
|
+
node: The node to mark as finished.
|
215
231
|
failed: Whether the node failed.
|
232
|
+
cancelled: Whether the node was cancelled.
|
216
233
|
"""
|
217
234
|
with self._lock:
|
218
235
|
if failed:
|
219
236
|
self.node_states[node] = NodeStatus.FAILED
|
237
|
+
elif cancelled:
|
238
|
+
self.node_states[node] = NodeStatus.CANCELLED
|
220
239
|
else:
|
221
240
|
self.node_states[node] = NodeStatus.COMPLETED
|
222
241
|
|
223
|
-
if failed:
|
224
|
-
# If the node failed, we don't need to run any downstream nodes.
|
242
|
+
if failed or cancelled:
|
243
|
+
# If the node failed or was cancelled, we don't need to run any downstream nodes.
|
225
244
|
return
|
226
245
|
|
227
246
|
# Run downstream nodes.
|
@@ -63,6 +63,15 @@ class LocalDockerOrchestrator(ContainerizedOrchestrator):
|
|
63
63
|
"""
|
64
64
|
return LocalDockerOrchestratorSettings
|
65
65
|
|
66
|
+
@property
|
67
|
+
def config(self) -> "LocalDockerOrchestratorConfig":
|
68
|
+
"""Returns the `LocalDockerOrchestratorConfig` config.
|
69
|
+
|
70
|
+
Returns:
|
71
|
+
The configuration.
|
72
|
+
"""
|
73
|
+
return cast(LocalDockerOrchestratorConfig, self._config)
|
74
|
+
|
66
75
|
@property
|
67
76
|
def validator(self) -> Optional[StackValidator]:
|
68
77
|
"""Ensures there is an image builder in the stack.
|
@@ -13,7 +13,8 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Utilities to publish pipeline and step runs."""
|
15
15
|
|
16
|
-
from
|
16
|
+
from datetime import datetime
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional
|
17
18
|
|
18
19
|
from zenml.client import Client
|
19
20
|
from zenml.enums import ExecutionStatus, MetadataResourceTypes
|
@@ -54,6 +55,40 @@ def publish_successful_step_run(
|
|
54
55
|
)
|
55
56
|
|
56
57
|
|
58
|
+
def publish_step_run_status_update(
|
59
|
+
step_run_id: "UUID",
|
60
|
+
status: "ExecutionStatus",
|
61
|
+
end_time: Optional[datetime] = None,
|
62
|
+
) -> "StepRunResponse":
|
63
|
+
"""Publishes a step run update.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
step_run_id: ID of the step run.
|
67
|
+
status: New status of the step run.
|
68
|
+
end_time: New end time of the step run.
|
69
|
+
|
70
|
+
Returns:
|
71
|
+
The updated step run.
|
72
|
+
|
73
|
+
Raises:
|
74
|
+
ValueError: If the end time is set for a non-finished step run.
|
75
|
+
"""
|
76
|
+
from zenml.client import Client
|
77
|
+
|
78
|
+
if end_time is not None and not status.is_finished:
|
79
|
+
raise ValueError("End time cannot be set for a non-finished step run.")
|
80
|
+
|
81
|
+
step_run = Client().zen_store.update_run_step(
|
82
|
+
step_run_id=step_run_id,
|
83
|
+
step_run_update=StepRunUpdate(
|
84
|
+
status=status,
|
85
|
+
end_time=end_time,
|
86
|
+
),
|
87
|
+
)
|
88
|
+
|
89
|
+
return step_run
|
90
|
+
|
91
|
+
|
57
92
|
def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
|
58
93
|
"""Publishes a failed step run.
|
59
94
|
|
@@ -63,12 +98,10 @@ def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
|
|
63
98
|
Returns:
|
64
99
|
The updated step run.
|
65
100
|
"""
|
66
|
-
return
|
101
|
+
return publish_step_run_status_update(
|
67
102
|
step_run_id=step_run_id,
|
68
|
-
|
69
|
-
|
70
|
-
end_time=utc_now(),
|
71
|
-
),
|
103
|
+
status=ExecutionStatus.FAILED,
|
104
|
+
end_time=utc_now(),
|
72
105
|
)
|
73
106
|
|
74
107
|
|
@@ -92,27 +125,81 @@ def publish_failed_pipeline_run(
|
|
92
125
|
)
|
93
126
|
|
94
127
|
|
128
|
+
def publish_pipeline_run_status_update(
|
129
|
+
pipeline_run_id: "UUID",
|
130
|
+
status: ExecutionStatus,
|
131
|
+
end_time: Optional[datetime] = None,
|
132
|
+
) -> "PipelineRunResponse":
|
133
|
+
"""Publishes a pipeline run status update.
|
134
|
+
|
135
|
+
Args:
|
136
|
+
pipeline_run_id: The ID of the pipeline run to update.
|
137
|
+
status: The new status for the pipeline run.
|
138
|
+
end_time: The end time for the pipeline run. If None, will be set to current time
|
139
|
+
for finished statuses.
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
The updated pipeline run.
|
143
|
+
"""
|
144
|
+
if end_time is None and status.is_finished:
|
145
|
+
end_time = utc_now()
|
146
|
+
|
147
|
+
return Client().zen_store.update_run(
|
148
|
+
run_id=pipeline_run_id,
|
149
|
+
run_update=PipelineRunUpdate(
|
150
|
+
status=status,
|
151
|
+
end_time=end_time,
|
152
|
+
),
|
153
|
+
)
|
154
|
+
|
155
|
+
|
95
156
|
def get_pipeline_run_status(
|
96
|
-
|
157
|
+
run_status: ExecutionStatus,
|
158
|
+
step_statuses: List[ExecutionStatus],
|
159
|
+
num_steps: int,
|
97
160
|
) -> ExecutionStatus:
|
98
161
|
"""Gets the pipeline run status for the given step statuses.
|
99
162
|
|
100
163
|
Args:
|
164
|
+
run_status: The status of the run.
|
101
165
|
step_statuses: The status of steps in this run.
|
102
166
|
num_steps: The total amount of steps in this run.
|
103
167
|
|
104
168
|
Returns:
|
105
169
|
The run status.
|
106
170
|
"""
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
171
|
+
# STOPPING state
|
172
|
+
if run_status == ExecutionStatus.STOPPING:
|
173
|
+
if all(status.is_finished for status in step_statuses):
|
174
|
+
return ExecutionStatus.STOPPED
|
175
|
+
else:
|
176
|
+
return ExecutionStatus.STOPPING
|
177
|
+
|
178
|
+
# If there is a stopped step, the run is stopped or stopping
|
179
|
+
if ExecutionStatus.STOPPED in step_statuses:
|
180
|
+
if all(status.is_finished for status in step_statuses):
|
181
|
+
return ExecutionStatus.STOPPED
|
182
|
+
else:
|
183
|
+
return ExecutionStatus.STOPPING
|
184
|
+
|
185
|
+
# Otherwise, if there is a failed step, the run is failed
|
186
|
+
elif (
|
187
|
+
ExecutionStatus.FAILED in step_statuses
|
188
|
+
or run_status == ExecutionStatus.FAILED
|
112
189
|
):
|
190
|
+
return ExecutionStatus.FAILED
|
191
|
+
|
192
|
+
# If there is a running step, the run is running
|
193
|
+
elif ExecutionStatus.RUNNING in step_statuses:
|
194
|
+
return ExecutionStatus.RUNNING
|
195
|
+
|
196
|
+
# If there are less steps than the total number of steps, it is running
|
197
|
+
elif len(step_statuses) < num_steps:
|
113
198
|
return ExecutionStatus.RUNNING
|
114
199
|
|
115
|
-
|
200
|
+
# Any other state is completed
|
201
|
+
else:
|
202
|
+
return ExecutionStatus.COMPLETED
|
116
203
|
|
117
204
|
|
118
205
|
def publish_pipeline_run_metadata(
|
@@ -14,10 +14,11 @@
|
|
14
14
|
"""Class to launch (run directly or using a step operator) steps."""
|
15
15
|
|
16
16
|
import os
|
17
|
+
import signal
|
17
18
|
import time
|
18
19
|
from contextlib import nullcontext
|
19
20
|
from functools import partial
|
20
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
|
21
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
|
21
22
|
|
22
23
|
from zenml.client import Client
|
23
24
|
from zenml.config.step_configurations import Step
|
@@ -29,6 +30,7 @@ from zenml.constants import (
|
|
29
30
|
)
|
30
31
|
from zenml.enums import ExecutionStatus
|
31
32
|
from zenml.environment import get_run_environment_dict
|
33
|
+
from zenml.exceptions import RunInterruptedException, RunStoppedException
|
32
34
|
from zenml.logger import get_logger
|
33
35
|
from zenml.logging import step_logging
|
34
36
|
from zenml.models import (
|
@@ -131,11 +133,86 @@ class StepLauncher:
|
|
131
133
|
self._stack = Stack.from_model(deployment.stack)
|
132
134
|
self._step_name = step.spec.pipeline_parameter_name
|
133
135
|
|
136
|
+
# Internal properties and methods
|
137
|
+
self._step_run: Optional[StepRunResponse] = None
|
138
|
+
self._setup_signal_handlers()
|
139
|
+
|
140
|
+
def _setup_signal_handlers(self) -> None:
|
141
|
+
"""Set up signal handlers for graceful shutdown, chaining previous handlers."""
|
142
|
+
# Save previous handlers
|
143
|
+
self._prev_sigterm_handler = signal.getsignal(signal.SIGTERM)
|
144
|
+
self._prev_sigint_handler = signal.getsignal(signal.SIGINT)
|
145
|
+
|
146
|
+
def signal_handler(signum: int, frame: Any) -> None:
|
147
|
+
"""Handle shutdown signals gracefully.
|
148
|
+
|
149
|
+
Args:
|
150
|
+
signum: The signal number.
|
151
|
+
frame: The frame of the signal handler.
|
152
|
+
|
153
|
+
Raises:
|
154
|
+
RunStoppedException: If the pipeline run is stopped by the user.
|
155
|
+
RunInterruptedException: If the execution is interrupted for any
|
156
|
+
other reason.
|
157
|
+
"""
|
158
|
+
logger.info(
|
159
|
+
f"Received signal shutdown {signum}. Requesting shutdown "
|
160
|
+
f"for step '{self._step_name}'..."
|
161
|
+
)
|
162
|
+
|
163
|
+
try:
|
164
|
+
client = Client()
|
165
|
+
pipeline_run = None
|
166
|
+
|
167
|
+
if self._step_run:
|
168
|
+
pipeline_run = client.get_pipeline_run(
|
169
|
+
self._step_run.pipeline_run_id
|
170
|
+
)
|
171
|
+
else:
|
172
|
+
raise RunInterruptedException(
|
173
|
+
"The execution was interrupted and the step does not "
|
174
|
+
"exist yet."
|
175
|
+
)
|
176
|
+
|
177
|
+
if pipeline_run and pipeline_run.status in [
|
178
|
+
ExecutionStatus.STOPPING,
|
179
|
+
ExecutionStatus.STOPPED,
|
180
|
+
]:
|
181
|
+
if self._step_run:
|
182
|
+
publish_utils.publish_step_run_status_update(
|
183
|
+
step_run_id=self._step_run.id,
|
184
|
+
status=ExecutionStatus.STOPPED,
|
185
|
+
end_time=utc_now(),
|
186
|
+
)
|
187
|
+
raise RunStoppedException("Pipeline run in stopped.")
|
188
|
+
else:
|
189
|
+
raise RunInterruptedException(
|
190
|
+
"The execution was interrupted."
|
191
|
+
)
|
192
|
+
except (RunStoppedException, RunInterruptedException):
|
193
|
+
raise
|
194
|
+
except Exception as e:
|
195
|
+
raise RunInterruptedException(str(e))
|
196
|
+
finally:
|
197
|
+
# Chain to previous handler if it exists and is not default/ignore
|
198
|
+
if signum == signal.SIGTERM and callable(
|
199
|
+
self._prev_sigterm_handler
|
200
|
+
):
|
201
|
+
self._prev_sigterm_handler(signum, frame)
|
202
|
+
elif signum == signal.SIGINT and callable(
|
203
|
+
self._prev_sigint_handler
|
204
|
+
):
|
205
|
+
self._prev_sigint_handler(signum, frame)
|
206
|
+
|
207
|
+
# Register handlers for common termination signals
|
208
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
209
|
+
signal.signal(signal.SIGINT, signal_handler)
|
210
|
+
|
134
211
|
def launch(self) -> None:
|
135
212
|
"""Launches the step.
|
136
213
|
|
137
214
|
Raises:
|
138
|
-
|
215
|
+
RunStoppedException: If the pipeline run is stopped by the user.
|
139
216
|
"""
|
140
217
|
pipeline_run, run_was_created = self._create_or_reuse_run()
|
141
218
|
|
@@ -207,6 +284,8 @@ class StepLauncher:
|
|
207
284
|
step_run = Client().zen_store.create_run_step(
|
208
285
|
step_run_request
|
209
286
|
)
|
287
|
+
# Store step run ID for signal handler
|
288
|
+
self._step_run = step_run
|
210
289
|
if model_version := step_run.model_version:
|
211
290
|
step_run_utils.log_model_version_dashboard_url(
|
212
291
|
model_version=model_version
|
@@ -259,6 +338,8 @@ class StepLauncher:
|
|
259
338
|
force_write_logs=force_write_logs,
|
260
339
|
)
|
261
340
|
break
|
341
|
+
except RunStoppedException as e:
|
342
|
+
raise e
|
262
343
|
except BaseException as e: # noqa: E722
|
263
344
|
retries += 1
|
264
345
|
if retries < max_retries:
|
@@ -292,10 +373,11 @@ class StepLauncher:
|
|
292
373
|
artifacts=step_run.outputs,
|
293
374
|
model_version=model_version,
|
294
375
|
)
|
295
|
-
|
376
|
+
except RunStoppedException:
|
377
|
+
logger.info(f"Pipeline run `{pipeline_run.name}` stopped.")
|
378
|
+
raise
|
296
379
|
except: # noqa: E722
|
297
380
|
logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
|
298
|
-
publish_utils.publish_failed_pipeline_run(pipeline_run.id)
|
299
381
|
raise
|
300
382
|
|
301
383
|
def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
|
zenml/utils/run_utils.py
ADDED
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Utility functions for runs."""
|
15
|
+
|
16
|
+
from typing import cast
|
17
|
+
|
18
|
+
from zenml.enums import ExecutionStatus
|
19
|
+
from zenml.exceptions import IllegalOperationError
|
20
|
+
from zenml.models import PipelineRunResponse
|
21
|
+
|
22
|
+
|
23
|
+
def stop_run(run: PipelineRunResponse, graceful: bool = False) -> None:
|
24
|
+
"""Stop a pipeline run.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
run: The pipeline run to stop.
|
28
|
+
graceful: Whether to stop the run gracefully.
|
29
|
+
|
30
|
+
Raises:
|
31
|
+
IllegalOperationError: If the run is already stopped or being stopped.
|
32
|
+
ValueError: If the stack is not accessible.
|
33
|
+
"""
|
34
|
+
# Check if the stack is still accessible
|
35
|
+
if run.stack is None:
|
36
|
+
raise ValueError(
|
37
|
+
"The stack that this pipeline run response was executed on "
|
38
|
+
"is either not accessible or has been deleted."
|
39
|
+
)
|
40
|
+
|
41
|
+
# Check if pipeline can be stopped
|
42
|
+
if run.status == ExecutionStatus.COMPLETED:
|
43
|
+
raise IllegalOperationError(
|
44
|
+
"Cannot stop a run that is already completed."
|
45
|
+
)
|
46
|
+
|
47
|
+
if run.status == ExecutionStatus.STOPPED:
|
48
|
+
raise IllegalOperationError("Run is already stopped.")
|
49
|
+
|
50
|
+
if run.status == ExecutionStatus.STOPPING:
|
51
|
+
raise IllegalOperationError("Run is already being stopped.")
|
52
|
+
|
53
|
+
# Create the orchestrator instance
|
54
|
+
from zenml.enums import StackComponentType
|
55
|
+
from zenml.orchestrators.base_orchestrator import BaseOrchestrator
|
56
|
+
from zenml.stack.stack_component import StackComponent
|
57
|
+
|
58
|
+
# Check if the stack is still accessible
|
59
|
+
orchestrator_list = run.stack.components.get(
|
60
|
+
StackComponentType.ORCHESTRATOR, []
|
61
|
+
)
|
62
|
+
if len(orchestrator_list) == 0:
|
63
|
+
raise ValueError(
|
64
|
+
"The orchestrator that this pipeline run response was "
|
65
|
+
"executed with is either not accessible or has been deleted."
|
66
|
+
)
|
67
|
+
|
68
|
+
orchestrator = cast(
|
69
|
+
BaseOrchestrator,
|
70
|
+
StackComponent.from_model(component_model=orchestrator_list[0]),
|
71
|
+
)
|
72
|
+
|
73
|
+
# Stop the run
|
74
|
+
orchestrator.stop_run(run=run, graceful=graceful)
|
@@ -25,6 +25,7 @@ from zenml.constants import (
|
|
25
25
|
RUNS,
|
26
26
|
STATUS,
|
27
27
|
STEPS,
|
28
|
+
STOP,
|
28
29
|
VERSION_1,
|
29
30
|
)
|
30
31
|
from zenml.enums import ExecutionStatus, StackComponentType
|
@@ -40,6 +41,7 @@ from zenml.models import (
|
|
40
41
|
StepRunFilter,
|
41
42
|
StepRunResponse,
|
42
43
|
)
|
44
|
+
from zenml.utils import run_utils
|
43
45
|
from zenml.zen_server.auth import AuthContext, authorize
|
44
46
|
from zenml.zen_server.exceptions import error_response
|
45
47
|
from zenml.zen_server.rbac.endpoint_utils import (
|
@@ -51,6 +53,7 @@ from zenml.zen_server.rbac.endpoint_utils import (
|
|
51
53
|
)
|
52
54
|
from zenml.zen_server.rbac.models import Action, ResourceType
|
53
55
|
from zenml.zen_server.rbac.utils import (
|
56
|
+
dehydrate_response_model,
|
54
57
|
verify_permission_for_model,
|
55
58
|
)
|
56
59
|
from zenml.zen_server.routers.projects_endpoints import workspace_router
|
@@ -389,38 +392,39 @@ def refresh_run_status(
|
|
389
392
|
|
390
393
|
Args:
|
391
394
|
run_id: ID of the pipeline run to refresh.
|
392
|
-
|
393
|
-
Raises:
|
394
|
-
RuntimeError: If the stack or the orchestrator of the run is deleted.
|
395
395
|
"""
|
396
|
-
# Verify access to the run
|
397
396
|
run = verify_permissions_and_get_entity(
|
398
397
|
id=run_id,
|
399
398
|
get_method=zen_store().get_run,
|
400
399
|
hydrate=True,
|
401
400
|
)
|
402
|
-
|
403
|
-
# Check the stack and its orchestrator
|
404
|
-
if run.stack is not None:
|
405
|
-
orchestrators = run.stack.components.get(
|
406
|
-
StackComponentType.ORCHESTRATOR, []
|
407
|
-
)
|
408
|
-
if orchestrators:
|
409
|
-
verify_permission_for_model(
|
410
|
-
model=orchestrators[0], action=Action.READ
|
411
|
-
)
|
412
|
-
else:
|
413
|
-
raise RuntimeError(
|
414
|
-
f"The orchestrator, the run '{run.id}' was executed with, is "
|
415
|
-
"deleted."
|
416
|
-
)
|
417
|
-
else:
|
418
|
-
raise RuntimeError(
|
419
|
-
f"The stack, the run '{run.id}' was executed on, is deleted."
|
420
|
-
)
|
421
401
|
run.refresh_run_status()
|
422
402
|
|
423
403
|
|
404
|
+
@router.post(
|
405
|
+
"/{run_id}" + STOP,
|
406
|
+
responses={401: error_response, 404: error_response, 422: error_response},
|
407
|
+
)
|
408
|
+
@async_fastapi_endpoint_wrapper
|
409
|
+
def stop_run(
|
410
|
+
run_id: UUID,
|
411
|
+
graceful: bool = False,
|
412
|
+
_: AuthContext = Security(authorize),
|
413
|
+
) -> None:
|
414
|
+
"""Stops a specific pipeline run.
|
415
|
+
|
416
|
+
Args:
|
417
|
+
run_id: ID of the pipeline run to stop.
|
418
|
+
graceful: If True, allows for graceful shutdown where possible.
|
419
|
+
If False, forces immediate termination. Default is False.
|
420
|
+
"""
|
421
|
+
run = zen_store().get_run(run_id, hydrate=True)
|
422
|
+
verify_permission_for_model(run, action=Action.READ)
|
423
|
+
verify_permission_for_model(run, action=Action.UPDATE)
|
424
|
+
dehydrate_response_model(run)
|
425
|
+
run_utils.stop_run(run=run, graceful=graceful)
|
426
|
+
|
427
|
+
|
424
428
|
@router.get(
|
425
429
|
"/{run_id}/logs",
|
426
430
|
responses={
|
@@ -6119,7 +6119,14 @@ class SqlZenStore(BaseZenStore):
|
|
6119
6119
|
resources=existing_run,
|
6120
6120
|
session=session,
|
6121
6121
|
)
|
6122
|
+
|
6123
|
+
if run_update.status is not None:
|
6124
|
+
self._update_pipeline_run_status(
|
6125
|
+
pipeline_run_id=run_id,
|
6126
|
+
session=session,
|
6127
|
+
)
|
6122
6128
|
session.refresh(existing_run)
|
6129
|
+
|
6123
6130
|
return existing_run.to_model(
|
6124
6131
|
include_metadata=True, include_resources=True
|
6125
6132
|
)
|
@@ -8824,6 +8831,7 @@ class SqlZenStore(BaseZenStore):
|
|
8824
8831
|
|
8825
8832
|
Raises:
|
8826
8833
|
EntityExistsError: if the step run already exists.
|
8834
|
+
IllegalOperationError: if the pipeline run is stopped or stopping.
|
8827
8835
|
"""
|
8828
8836
|
with Session(self.engine) as session:
|
8829
8837
|
self._set_request_user_id(request_model=step_run, session=session)
|
@@ -8835,6 +8843,16 @@ class SqlZenStore(BaseZenStore):
|
|
8835
8843
|
reference_id=step_run.pipeline_run_id,
|
8836
8844
|
session=session,
|
8837
8845
|
)
|
8846
|
+
|
8847
|
+
# Validate pipeline status before creating step
|
8848
|
+
if run.status in [
|
8849
|
+
ExecutionStatus.STOPPING,
|
8850
|
+
ExecutionStatus.STOPPED,
|
8851
|
+
]:
|
8852
|
+
raise IllegalOperationError(
|
8853
|
+
f"Cannot create step '{step_run.name}' for pipeline in "
|
8854
|
+
f"{run.status} state. Pipeline run ID: {step_run.pipeline_run_id}"
|
8855
|
+
)
|
8838
8856
|
self._get_reference_schema_by_id(
|
8839
8857
|
resource=step_run,
|
8840
8858
|
reference_schema=StepRunSchema,
|
@@ -8996,6 +9014,8 @@ class SqlZenStore(BaseZenStore):
|
|
8996
9014
|
session=session,
|
8997
9015
|
)
|
8998
9016
|
|
9017
|
+
session.commit()
|
9018
|
+
|
8999
9019
|
if step_run.status != ExecutionStatus.RUNNING:
|
9000
9020
|
self._update_pipeline_run_status(
|
9001
9021
|
pipeline_run_id=step_run.pipeline_run_id, session=session
|
@@ -9130,15 +9150,14 @@ class SqlZenStore(BaseZenStore):
|
|
9130
9150
|
input_type=StepRunInputArtifactType.MANUAL,
|
9131
9151
|
session=session,
|
9132
9152
|
)
|
9153
|
+
session.commit()
|
9154
|
+
session.refresh(existing_step_run)
|
9133
9155
|
|
9134
9156
|
self._update_pipeline_run_status(
|
9135
9157
|
pipeline_run_id=existing_step_run.pipeline_run_id,
|
9136
9158
|
session=session,
|
9137
9159
|
)
|
9138
9160
|
|
9139
|
-
session.commit()
|
9140
|
-
session.refresh(existing_step_run)
|
9141
|
-
|
9142
9161
|
return existing_step_run.to_model(
|
9143
9162
|
include_metadata=True, include_resources=True
|
9144
9163
|
)
|
@@ -9375,6 +9394,7 @@ class SqlZenStore(BaseZenStore):
|
|
9375
9394
|
assert pipeline_run.deployment
|
9376
9395
|
num_steps = pipeline_run.deployment.step_count
|
9377
9396
|
new_status = get_pipeline_run_status(
|
9397
|
+
run_status=ExecutionStatus(pipeline_run.status),
|
9378
9398
|
step_statuses=[
|
9379
9399
|
ExecutionStatus(status) for status in step_run_statuses
|
9380
9400
|
],
|