zenml-nightly 0.83.1.dev20250701__py3-none-any.whl → 0.83.1.dev20250703__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27):
  1. zenml/VERSION +1 -1
  2. zenml/cli/pipeline.py +54 -1
  3. zenml/cli/utils.py +2 -0
  4. zenml/constants.py +1 -0
  5. zenml/enums.py +6 -3
  6. zenml/exceptions.py +8 -0
  7. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +8 -4
  8. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +5 -3
  9. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +7 -8
  10. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +3 -0
  11. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +88 -0
  12. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +36 -1
  13. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +11 -3
  14. zenml/models/v2/core/pipeline_run.py +2 -2
  15. zenml/orchestrators/base_orchestrator.py +70 -0
  16. zenml/orchestrators/dag_runner.py +27 -8
  17. zenml/orchestrators/local_docker/local_docker_orchestrator.py +9 -0
  18. zenml/orchestrators/publish_utils.py +100 -13
  19. zenml/orchestrators/step_launcher.py +86 -4
  20. zenml/utils/run_utils.py +74 -0
  21. zenml/zen_server/routers/runs_endpoints.py +27 -23
  22. zenml/zen_stores/sql_zen_store.py +23 -3
  23. {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/METADATA +1 -1
  24. {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/RECORD +27 -26
  25. {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/LICENSE +0 -0
  26. {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/WHEEL +0 -0
  27. {zenml_nightly-0.83.1.dev20250701.dist-info → zenml_nightly-0.83.1.dev20250703.dist-info}/entry_points.txt +0 -0
@@ -56,6 +56,7 @@ class NodeStatus(Enum):
56
56
  RUNNING = "running"
57
57
  COMPLETED = "completed"
58
58
  FAILED = "failed"
59
+ CANCELLED = "cancelled"
59
60
 
60
61
 
61
62
  class ThreadedDagRunner:
@@ -76,6 +77,7 @@ class ThreadedDagRunner:
76
77
  finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None,
77
78
  parallel_node_startup_waiting_period: float = 0.0,
78
79
  max_parallelism: Optional[int] = None,
80
+ continue_fn: Optional[Callable[[], bool]] = None,
79
81
  ) -> None:
80
82
  """Define attributes and initialize all nodes in waiting state.
81
83
 
@@ -92,6 +94,9 @@ class ThreadedDagRunner:
92
94
  parallel_node_startup_waiting_period: Delay in seconds to wait in
93
95
  between starting parallel nodes.
94
96
  max_parallelism: Maximum number of nodes to run in parallel
97
+ continue_fn: A function that returns True if the run should continue
98
+ after each step execution, False if it should stop (e.g., due
99
+ to cancellation). If None, execution continues normally.
95
100
 
96
101
  Raises:
97
102
  ValueError: If max_parallelism is not greater than 0.
@@ -108,12 +113,15 @@ class ThreadedDagRunner:
108
113
  self.run_fn = run_fn
109
114
  self.preparation_fn = preparation_fn
110
115
  self.finalize_fn = finalize_fn
116
+ self.continue_fn = continue_fn
111
117
  self.nodes = dag.keys()
112
118
  self.node_states = {
113
119
  node: NodeStatus.NOT_STARTED for node in self.nodes
114
120
  }
115
121
  self._lock = threading.Lock()
116
122
 
123
+ self._stop_requested = False
124
+
117
125
  def _can_run(self, node: str) -> bool:
118
126
  """Determine whether a node is ready to be run.
119
127
 
@@ -173,6 +181,15 @@ class ThreadedDagRunner:
173
181
  """
174
182
  self._prepare_node_run(node)
175
183
 
184
+ # Check if execution should continue (e.g., check for cancellation)
185
+ if self.continue_fn:
186
+ self._stop_requested = (
187
+ self._stop_requested or not self.continue_fn()
188
+ )
189
+ if self._stop_requested:
190
+ self._finish_node(node, cancelled=True)
191
+ return
192
+
176
193
  if self.preparation_fn:
177
194
  run_required = self.preparation_fn(node)
178
195
  if not run_required:
@@ -204,24 +221,26 @@ class ThreadedDagRunner:
204
221
  thread.start()
205
222
  return thread
206
223
 
207
- def _finish_node(self, node: str, failed: bool = False) -> None:
208
- """Finish a node run.
209
-
210
- First updates the node status to completed.
211
- Then starts all other nodes that can now be run and waits for them.
224
+ def _finish_node(
225
+ self, node: str, failed: bool = False, cancelled: bool = False
226
+ ) -> None:
227
+ """Mark a node as finished and potentially start new nodes.
212
228
 
213
229
  Args:
214
- node: The node.
230
+ node: The node to mark as finished.
215
231
  failed: Whether the node failed.
232
+ cancelled: Whether the node was cancelled.
216
233
  """
217
234
  with self._lock:
218
235
  if failed:
219
236
  self.node_states[node] = NodeStatus.FAILED
237
+ elif cancelled:
238
+ self.node_states[node] = NodeStatus.CANCELLED
220
239
  else:
221
240
  self.node_states[node] = NodeStatus.COMPLETED
222
241
 
223
- if failed:
224
- # If the node failed, we don't need to run any downstream nodes.
242
+ if failed or cancelled:
243
+ # If the node failed or was cancelled, we don't need to run any downstream nodes.
225
244
  return
226
245
 
227
246
  # Run downstream nodes.
@@ -63,6 +63,15 @@ class LocalDockerOrchestrator(ContainerizedOrchestrator):
63
63
  """
64
64
  return LocalDockerOrchestratorSettings
65
65
 
66
+ @property
67
+ def config(self) -> "LocalDockerOrchestratorConfig":
68
+ """Returns the `LocalDockerOrchestratorConfig` config.
69
+
70
+ Returns:
71
+ The configuration.
72
+ """
73
+ return cast(LocalDockerOrchestratorConfig, self._config)
74
+
66
75
  @property
67
76
  def validator(self) -> Optional[StackValidator]:
68
77
  """Ensures there is an image builder in the stack.
@@ -13,7 +13,8 @@
13
13
  # permissions and limitations under the License.
14
14
  """Utilities to publish pipeline and step runs."""
15
15
 
16
- from typing import TYPE_CHECKING, Dict, List
16
+ from datetime import datetime
17
+ from typing import TYPE_CHECKING, Dict, List, Optional
17
18
 
18
19
  from zenml.client import Client
19
20
  from zenml.enums import ExecutionStatus, MetadataResourceTypes
@@ -54,6 +55,40 @@ def publish_successful_step_run(
54
55
  )
55
56
 
56
57
 
58
+ def publish_step_run_status_update(
59
+ step_run_id: "UUID",
60
+ status: "ExecutionStatus",
61
+ end_time: Optional[datetime] = None,
62
+ ) -> "StepRunResponse":
63
+ """Publishes a step run update.
64
+
65
+ Args:
66
+ step_run_id: ID of the step run.
67
+ status: New status of the step run.
68
+ end_time: New end time of the step run.
69
+
70
+ Returns:
71
+ The updated step run.
72
+
73
+ Raises:
74
+ ValueError: If the end time is set for a non-finished step run.
75
+ """
76
+ from zenml.client import Client
77
+
78
+ if end_time is not None and not status.is_finished:
79
+ raise ValueError("End time cannot be set for a non-finished step run.")
80
+
81
+ step_run = Client().zen_store.update_run_step(
82
+ step_run_id=step_run_id,
83
+ step_run_update=StepRunUpdate(
84
+ status=status,
85
+ end_time=end_time,
86
+ ),
87
+ )
88
+
89
+ return step_run
90
+
91
+
57
92
  def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
58
93
  """Publishes a failed step run.
59
94
 
@@ -63,12 +98,10 @@ def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
63
98
  Returns:
64
99
  The updated step run.
65
100
  """
66
- return Client().zen_store.update_run_step(
101
+ return publish_step_run_status_update(
67
102
  step_run_id=step_run_id,
68
- step_run_update=StepRunUpdate(
69
- status=ExecutionStatus.FAILED,
70
- end_time=utc_now(),
71
- ),
103
+ status=ExecutionStatus.FAILED,
104
+ end_time=utc_now(),
72
105
  )
73
106
 
74
107
 
@@ -92,27 +125,81 @@ def publish_failed_pipeline_run(
92
125
  )
93
126
 
94
127
 
128
+ def publish_pipeline_run_status_update(
129
+ pipeline_run_id: "UUID",
130
+ status: ExecutionStatus,
131
+ end_time: Optional[datetime] = None,
132
+ ) -> "PipelineRunResponse":
133
+ """Publishes a pipeline run status update.
134
+
135
+ Args:
136
+ pipeline_run_id: The ID of the pipeline run to update.
137
+ status: The new status for the pipeline run.
138
+ end_time: The end time for the pipeline run. If None, will be set to current time
139
+ for finished statuses.
140
+
141
+ Returns:
142
+ The updated pipeline run.
143
+ """
144
+ if end_time is None and status.is_finished:
145
+ end_time = utc_now()
146
+
147
+ return Client().zen_store.update_run(
148
+ run_id=pipeline_run_id,
149
+ run_update=PipelineRunUpdate(
150
+ status=status,
151
+ end_time=end_time,
152
+ ),
153
+ )
154
+
155
+
95
156
  def get_pipeline_run_status(
96
- step_statuses: List[ExecutionStatus], num_steps: int
157
+ run_status: ExecutionStatus,
158
+ step_statuses: List[ExecutionStatus],
159
+ num_steps: int,
97
160
  ) -> ExecutionStatus:
98
161
  """Gets the pipeline run status for the given step statuses.
99
162
 
100
163
  Args:
164
+ run_status: The status of the run.
101
165
  step_statuses: The status of steps in this run.
102
166
  num_steps: The total amount of steps in this run.
103
167
 
104
168
  Returns:
105
169
  The run status.
106
170
  """
107
- if ExecutionStatus.FAILED in step_statuses:
108
- return ExecutionStatus.FAILED
109
- if (
110
- ExecutionStatus.RUNNING in step_statuses
111
- or len(step_statuses) < num_steps
171
+ # STOPPING state
172
+ if run_status == ExecutionStatus.STOPPING:
173
+ if all(status.is_finished for status in step_statuses):
174
+ return ExecutionStatus.STOPPED
175
+ else:
176
+ return ExecutionStatus.STOPPING
177
+
178
+ # If there is a stopped step, the run is stopped or stopping
179
+ if ExecutionStatus.STOPPED in step_statuses:
180
+ if all(status.is_finished for status in step_statuses):
181
+ return ExecutionStatus.STOPPED
182
+ else:
183
+ return ExecutionStatus.STOPPING
184
+
185
+ # Otherwise, if there is a failed step, the run is failed
186
+ elif (
187
+ ExecutionStatus.FAILED in step_statuses
188
+ or run_status == ExecutionStatus.FAILED
112
189
  ):
190
+ return ExecutionStatus.FAILED
191
+
192
+ # If there is a running step, the run is running
193
+ elif ExecutionStatus.RUNNING in step_statuses:
194
+ return ExecutionStatus.RUNNING
195
+
196
+ # If there are less steps than the total number of steps, it is running
197
+ elif len(step_statuses) < num_steps:
113
198
  return ExecutionStatus.RUNNING
114
199
 
115
- return ExecutionStatus.COMPLETED
200
+ # Any other state is completed
201
+ else:
202
+ return ExecutionStatus.COMPLETED
116
203
 
117
204
 
118
205
  def publish_pipeline_run_metadata(
@@ -14,10 +14,11 @@
14
14
  """Class to launch (run directly or using a step operator) steps."""
15
15
 
16
16
  import os
17
+ import signal
17
18
  import time
18
19
  from contextlib import nullcontext
19
20
  from functools import partial
20
- from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple
21
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple
21
22
 
22
23
  from zenml.client import Client
23
24
  from zenml.config.step_configurations import Step
@@ -29,6 +30,7 @@ from zenml.constants import (
29
30
  )
30
31
  from zenml.enums import ExecutionStatus
31
32
  from zenml.environment import get_run_environment_dict
33
+ from zenml.exceptions import RunInterruptedException, RunStoppedException
32
34
  from zenml.logger import get_logger
33
35
  from zenml.logging import step_logging
34
36
  from zenml.models import (
@@ -131,11 +133,86 @@ class StepLauncher:
131
133
  self._stack = Stack.from_model(deployment.stack)
132
134
  self._step_name = step.spec.pipeline_parameter_name
133
135
 
136
+ # Internal properties and methods
137
+ self._step_run: Optional[StepRunResponse] = None
138
+ self._setup_signal_handlers()
139
+
140
+ def _setup_signal_handlers(self) -> None:
141
+ """Set up signal handlers for graceful shutdown, chaining previous handlers."""
142
+ # Save previous handlers
143
+ self._prev_sigterm_handler = signal.getsignal(signal.SIGTERM)
144
+ self._prev_sigint_handler = signal.getsignal(signal.SIGINT)
145
+
146
+ def signal_handler(signum: int, frame: Any) -> None:
147
+ """Handle shutdown signals gracefully.
148
+
149
+ Args:
150
+ signum: The signal number.
151
+ frame: The frame of the signal handler.
152
+
153
+ Raises:
154
+ RunStoppedException: If the pipeline run is stopped by the user.
155
+ RunInterruptedException: If the execution is interrupted for any
156
+ other reason.
157
+ """
158
+ logger.info(
159
+ f"Received signal shutdown {signum}. Requesting shutdown "
160
+ f"for step '{self._step_name}'..."
161
+ )
162
+
163
+ try:
164
+ client = Client()
165
+ pipeline_run = None
166
+
167
+ if self._step_run:
168
+ pipeline_run = client.get_pipeline_run(
169
+ self._step_run.pipeline_run_id
170
+ )
171
+ else:
172
+ raise RunInterruptedException(
173
+ "The execution was interrupted and the step does not "
174
+ "exist yet."
175
+ )
176
+
177
+ if pipeline_run and pipeline_run.status in [
178
+ ExecutionStatus.STOPPING,
179
+ ExecutionStatus.STOPPED,
180
+ ]:
181
+ if self._step_run:
182
+ publish_utils.publish_step_run_status_update(
183
+ step_run_id=self._step_run.id,
184
+ status=ExecutionStatus.STOPPED,
185
+ end_time=utc_now(),
186
+ )
187
+ raise RunStoppedException("Pipeline run in stopped.")
188
+ else:
189
+ raise RunInterruptedException(
190
+ "The execution was interrupted."
191
+ )
192
+ except (RunStoppedException, RunInterruptedException):
193
+ raise
194
+ except Exception as e:
195
+ raise RunInterruptedException(str(e))
196
+ finally:
197
+ # Chain to previous handler if it exists and is not default/ignore
198
+ if signum == signal.SIGTERM and callable(
199
+ self._prev_sigterm_handler
200
+ ):
201
+ self._prev_sigterm_handler(signum, frame)
202
+ elif signum == signal.SIGINT and callable(
203
+ self._prev_sigint_handler
204
+ ):
205
+ self._prev_sigint_handler(signum, frame)
206
+
207
+ # Register handlers for common termination signals
208
+ signal.signal(signal.SIGTERM, signal_handler)
209
+ signal.signal(signal.SIGINT, signal_handler)
210
+
134
211
  def launch(self) -> None:
135
212
  """Launches the step.
136
213
 
137
214
  Raises:
138
- BaseException: If the step failed to launch, run, or publish.
215
+ RunStoppedException: If the pipeline run is stopped by the user.
139
216
  """
140
217
  pipeline_run, run_was_created = self._create_or_reuse_run()
141
218
 
@@ -207,6 +284,8 @@ class StepLauncher:
207
284
  step_run = Client().zen_store.create_run_step(
208
285
  step_run_request
209
286
  )
287
+ # Store step run ID for signal handler
288
+ self._step_run = step_run
210
289
  if model_version := step_run.model_version:
211
290
  step_run_utils.log_model_version_dashboard_url(
212
291
  model_version=model_version
@@ -259,6 +338,8 @@ class StepLauncher:
259
338
  force_write_logs=force_write_logs,
260
339
  )
261
340
  break
341
+ except RunStoppedException as e:
342
+ raise e
262
343
  except BaseException as e: # noqa: E722
263
344
  retries += 1
264
345
  if retries < max_retries:
@@ -292,10 +373,11 @@ class StepLauncher:
292
373
  artifacts=step_run.outputs,
293
374
  model_version=model_version,
294
375
  )
295
-
376
+ except RunStoppedException:
377
+ logger.info(f"Pipeline run `{pipeline_run.name}` stopped.")
378
+ raise
296
379
  except: # noqa: E722
297
380
  logger.error(f"Pipeline run `{pipeline_run.name}` failed.")
298
- publish_utils.publish_failed_pipeline_run(pipeline_run.id)
299
381
  raise
300
382
 
301
383
  def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]:
@@ -0,0 +1,74 @@
1
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Utility functions for runs."""
15
+
16
+ from typing import cast
17
+
18
+ from zenml.enums import ExecutionStatus
19
+ from zenml.exceptions import IllegalOperationError
20
+ from zenml.models import PipelineRunResponse
21
+
22
+
23
+ def stop_run(run: PipelineRunResponse, graceful: bool = False) -> None:
24
+ """Stop a pipeline run.
25
+
26
+ Args:
27
+ run: The pipeline run to stop.
28
+ graceful: Whether to stop the run gracefully.
29
+
30
+ Raises:
31
+ IllegalOperationError: If the run is already stopped or being stopped.
32
+ ValueError: If the stack is not accessible.
33
+ """
34
+ # Check if the stack is still accessible
35
+ if run.stack is None:
36
+ raise ValueError(
37
+ "The stack that this pipeline run response was executed on "
38
+ "is either not accessible or has been deleted."
39
+ )
40
+
41
+ # Check if pipeline can be stopped
42
+ if run.status == ExecutionStatus.COMPLETED:
43
+ raise IllegalOperationError(
44
+ "Cannot stop a run that is already completed."
45
+ )
46
+
47
+ if run.status == ExecutionStatus.STOPPED:
48
+ raise IllegalOperationError("Run is already stopped.")
49
+
50
+ if run.status == ExecutionStatus.STOPPING:
51
+ raise IllegalOperationError("Run is already being stopped.")
52
+
53
+ # Create the orchestrator instance
54
+ from zenml.enums import StackComponentType
55
+ from zenml.orchestrators.base_orchestrator import BaseOrchestrator
56
+ from zenml.stack.stack_component import StackComponent
57
+
58
+ # Check if the stack is still accessible
59
+ orchestrator_list = run.stack.components.get(
60
+ StackComponentType.ORCHESTRATOR, []
61
+ )
62
+ if len(orchestrator_list) == 0:
63
+ raise ValueError(
64
+ "The orchestrator that this pipeline run response was "
65
+ "executed with is either not accessible or has been deleted."
66
+ )
67
+
68
+ orchestrator = cast(
69
+ BaseOrchestrator,
70
+ StackComponent.from_model(component_model=orchestrator_list[0]),
71
+ )
72
+
73
+ # Stop the run
74
+ orchestrator.stop_run(run=run, graceful=graceful)
@@ -25,6 +25,7 @@ from zenml.constants import (
25
25
  RUNS,
26
26
  STATUS,
27
27
  STEPS,
28
+ STOP,
28
29
  VERSION_1,
29
30
  )
30
31
  from zenml.enums import ExecutionStatus, StackComponentType
@@ -40,6 +41,7 @@ from zenml.models import (
40
41
  StepRunFilter,
41
42
  StepRunResponse,
42
43
  )
44
+ from zenml.utils import run_utils
43
45
  from zenml.zen_server.auth import AuthContext, authorize
44
46
  from zenml.zen_server.exceptions import error_response
45
47
  from zenml.zen_server.rbac.endpoint_utils import (
@@ -51,6 +53,7 @@ from zenml.zen_server.rbac.endpoint_utils import (
51
53
  )
52
54
  from zenml.zen_server.rbac.models import Action, ResourceType
53
55
  from zenml.zen_server.rbac.utils import (
56
+ dehydrate_response_model,
54
57
  verify_permission_for_model,
55
58
  )
56
59
  from zenml.zen_server.routers.projects_endpoints import workspace_router
@@ -389,38 +392,39 @@ def refresh_run_status(
389
392
 
390
393
  Args:
391
394
  run_id: ID of the pipeline run to refresh.
392
-
393
- Raises:
394
- RuntimeError: If the stack or the orchestrator of the run is deleted.
395
395
  """
396
- # Verify access to the run
397
396
  run = verify_permissions_and_get_entity(
398
397
  id=run_id,
399
398
  get_method=zen_store().get_run,
400
399
  hydrate=True,
401
400
  )
402
-
403
- # Check the stack and its orchestrator
404
- if run.stack is not None:
405
- orchestrators = run.stack.components.get(
406
- StackComponentType.ORCHESTRATOR, []
407
- )
408
- if orchestrators:
409
- verify_permission_for_model(
410
- model=orchestrators[0], action=Action.READ
411
- )
412
- else:
413
- raise RuntimeError(
414
- f"The orchestrator, the run '{run.id}' was executed with, is "
415
- "deleted."
416
- )
417
- else:
418
- raise RuntimeError(
419
- f"The stack, the run '{run.id}' was executed on, is deleted."
420
- )
421
401
  run.refresh_run_status()
422
402
 
423
403
 
404
+ @router.post(
405
+ "/{run_id}" + STOP,
406
+ responses={401: error_response, 404: error_response, 422: error_response},
407
+ )
408
+ @async_fastapi_endpoint_wrapper
409
+ def stop_run(
410
+ run_id: UUID,
411
+ graceful: bool = False,
412
+ _: AuthContext = Security(authorize),
413
+ ) -> None:
414
+ """Stops a specific pipeline run.
415
+
416
+ Args:
417
+ run_id: ID of the pipeline run to stop.
418
+ graceful: If True, allows for graceful shutdown where possible.
419
+ If False, forces immediate termination. Default is False.
420
+ """
421
+ run = zen_store().get_run(run_id, hydrate=True)
422
+ verify_permission_for_model(run, action=Action.READ)
423
+ verify_permission_for_model(run, action=Action.UPDATE)
424
+ dehydrate_response_model(run)
425
+ run_utils.stop_run(run=run, graceful=graceful)
426
+
427
+
424
428
  @router.get(
425
429
  "/{run_id}/logs",
426
430
  responses={
@@ -6119,7 +6119,14 @@ class SqlZenStore(BaseZenStore):
6119
6119
  resources=existing_run,
6120
6120
  session=session,
6121
6121
  )
6122
+
6123
+ if run_update.status is not None:
6124
+ self._update_pipeline_run_status(
6125
+ pipeline_run_id=run_id,
6126
+ session=session,
6127
+ )
6122
6128
  session.refresh(existing_run)
6129
+
6123
6130
  return existing_run.to_model(
6124
6131
  include_metadata=True, include_resources=True
6125
6132
  )
@@ -8824,6 +8831,7 @@ class SqlZenStore(BaseZenStore):
8824
8831
 
8825
8832
  Raises:
8826
8833
  EntityExistsError: if the step run already exists.
8834
+ IllegalOperationError: if the pipeline run is stopped or stopping.
8827
8835
  """
8828
8836
  with Session(self.engine) as session:
8829
8837
  self._set_request_user_id(request_model=step_run, session=session)
@@ -8835,6 +8843,16 @@ class SqlZenStore(BaseZenStore):
8835
8843
  reference_id=step_run.pipeline_run_id,
8836
8844
  session=session,
8837
8845
  )
8846
+
8847
+ # Validate pipeline status before creating step
8848
+ if run.status in [
8849
+ ExecutionStatus.STOPPING,
8850
+ ExecutionStatus.STOPPED,
8851
+ ]:
8852
+ raise IllegalOperationError(
8853
+ f"Cannot create step '{step_run.name}' for pipeline in "
8854
+ f"{run.status} state. Pipeline run ID: {step_run.pipeline_run_id}"
8855
+ )
8838
8856
  self._get_reference_schema_by_id(
8839
8857
  resource=step_run,
8840
8858
  reference_schema=StepRunSchema,
@@ -8996,6 +9014,8 @@ class SqlZenStore(BaseZenStore):
8996
9014
  session=session,
8997
9015
  )
8998
9016
 
9017
+ session.commit()
9018
+
8999
9019
  if step_run.status != ExecutionStatus.RUNNING:
9000
9020
  self._update_pipeline_run_status(
9001
9021
  pipeline_run_id=step_run.pipeline_run_id, session=session
@@ -9130,15 +9150,14 @@ class SqlZenStore(BaseZenStore):
9130
9150
  input_type=StepRunInputArtifactType.MANUAL,
9131
9151
  session=session,
9132
9152
  )
9153
+ session.commit()
9154
+ session.refresh(existing_step_run)
9133
9155
 
9134
9156
  self._update_pipeline_run_status(
9135
9157
  pipeline_run_id=existing_step_run.pipeline_run_id,
9136
9158
  session=session,
9137
9159
  )
9138
9160
 
9139
- session.commit()
9140
- session.refresh(existing_step_run)
9141
-
9142
9161
  return existing_step_run.to_model(
9143
9162
  include_metadata=True, include_resources=True
9144
9163
  )
@@ -9375,6 +9394,7 @@ class SqlZenStore(BaseZenStore):
9375
9394
  assert pipeline_run.deployment
9376
9395
  num_steps = pipeline_run.deployment.step_count
9377
9396
  new_status = get_pipeline_run_status(
9397
+ run_status=ExecutionStatus(pipeline_run.status),
9378
9398
  step_statuses=[
9379
9399
  ExecutionStatus(status) for status in step_run_statuses
9380
9400
  ],
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: zenml-nightly
3
- Version: 0.83.1.dev20250701
3
+ Version: 0.83.1.dev20250703
4
4
  Summary: ZenML: Write production-ready ML code.
5
5
  License: Apache-2.0
6
6
  Keywords: machine learning,production,pipeline,mlops,devops