zenml-nightly 0.84.1.dev20250804__py3-none-any.whl → 0.84.1.dev20250806__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -16,7 +16,10 @@
16
16
  import argparse
17
17
  import random
18
18
  import socket
19
- from typing import Callable, Dict, Optional, cast
19
+ import threading
20
+ import time
21
+ from typing import List, Optional, Tuple, cast
22
+ from uuid import UUID
20
23
 
21
24
  from kubernetes import client as k8s_client
22
25
  from kubernetes.client.rest import ApiException
@@ -27,13 +30,24 @@ from zenml.entrypoints.step_entrypoint_configuration import (
27
30
  )
28
31
  from zenml.enums import ExecutionStatus
29
32
  from zenml.exceptions import AuthorizationException
33
+ from zenml.integrations.kubernetes.constants import (
34
+ ENV_ZENML_KUBERNETES_RUN_ID,
35
+ KUBERNETES_SECRET_TOKEN_KEY_NAME,
36
+ ORCHESTRATOR_RUN_ID_ANNOTATION_KEY,
37
+ RUN_ID_ANNOTATION_KEY,
38
+ STEP_NAME_ANNOTATION_KEY,
39
+ )
30
40
  from zenml.integrations.kubernetes.flavors.kubernetes_orchestrator_flavor import (
31
41
  KubernetesOrchestratorSettings,
32
42
  )
33
43
  from zenml.integrations.kubernetes.orchestrators import kube_utils
44
+ from zenml.integrations.kubernetes.orchestrators.dag_runner import (
45
+ DagRunner,
46
+ InterruptMode,
47
+ Node,
48
+ NodeStatus,
49
+ )
34
50
  from zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator import (
35
- ENV_ZENML_KUBERNETES_RUN_ID,
36
- KUBERNETES_SECRET_TOKEN_KEY_NAME,
37
51
  KubernetesOrchestrator,
38
52
  )
39
53
  from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
@@ -43,8 +57,8 @@ from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
43
57
  )
44
58
  from zenml.logger import get_logger
45
59
  from zenml.logging.step_logging import setup_orchestrator_logging
60
+ from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
46
61
  from zenml.orchestrators import publish_utils
47
- from zenml.orchestrators.dag_runner import NodeStatus, ThreadedDagRunner
48
62
  from zenml.orchestrators.step_run_utils import (
49
63
  StepRunRequestFactory,
50
64
  fetch_step_runs_by_names,
@@ -52,7 +66,6 @@ from zenml.orchestrators.step_run_utils import (
52
66
  )
53
67
  from zenml.orchestrators.utils import (
54
68
  get_config_environment_vars,
55
- get_orchestrator_run_name,
56
69
  )
57
70
  from zenml.pipelines.run_utils import create_placeholder_run
58
71
 
@@ -71,10 +84,115 @@ def parse_args() -> argparse.Namespace:
71
84
  return parser.parse_args()
72
85
 
73
86
 
87
def _get_orchestrator_job_state(
    batch_api: k8s_client.BatchV1Api, namespace: str, job_name: str
) -> Tuple[Optional[UUID], Optional[str]]:
    """Read the run IDs stored in the annotations of the orchestrator job.

    `main` writes the pipeline run ID and the orchestrator run ID into the
    annotations of the orchestrator pod's parent job. This reads them back,
    so a restarted orchestrator pod can continue the existing run.

    Args:
        batch_api: The batch api.
        namespace: The namespace.
        job_name: The name of the orchestrator job.

    Returns:
        A tuple `(run_id, orchestrator_run_id)`. Each entry is `None` when
        the job has no metadata, no annotations, or no (non-empty) value for
        the corresponding annotation key.
    """
    job = kube_utils.get_job(
        batch_api=batch_api,
        namespace=namespace,
        job_name=job_name,
    )

    # Collapse "no metadata" and "no/empty annotations" into an empty dict,
    # so both lookups below simply return `None`.
    metadata = job.metadata
    annotations = (metadata.annotations if metadata else None) or {}

    raw_run_id = annotations.get(RUN_ID_ANNOTATION_KEY, None)
    orchestrator_run_id = annotations.get(
        ORCHESTRATOR_RUN_ID_ANNOTATION_KEY, None
    )

    run_id = UUID(raw_run_id) if raw_run_id else None
    return run_id, orchestrator_run_id
118
+
119
+
120
def _reconstruct_nodes(
    deployment: PipelineDeploymentResponse,
    pipeline_run: PipelineRunResponse,
    namespace: str,
    batch_api: k8s_client.BatchV1Api,
) -> List[Node]:
    """Reconstruct the DAG nodes of a pipeline run that is being continued.

    Every step of the deployment becomes a node. A node's status is taken
    first from the step run stored in the ZenML database, if that step run
    is finished. For steps that are not finished there, the status is
    derived from the conditions of an existing Kubernetes step job. The name
    of every discovered step job is stored in `node.metadata["job_name"]`.

    Args:
        deployment: The deployment.
        pipeline_run: The pipeline run.
        namespace: The namespace.
        batch_api: The batch api.

    Returns:
        The reconstructed nodes.
    """
    nodes = {
        step_name: Node(id=step_name, upstream_nodes=step.spec.upstream_steps)
        for step_name, step in deployment.step_configurations.items()
    }

    # Finished step runs in the ZenML database determine the node status.
    for step_name, existing_step_run in pipeline_run.steps.items():
        node = nodes[step_name]
        if existing_step_run.status.is_successful:
            node.status = NodeStatus.COMPLETED
        elif existing_step_run.status.is_finished:
            node.status = NodeStatus.FAILED

    # Step jobs are labeled with the *sanitized* run ID, so the selector
    # must use the same transformation to be guaranteed to match.
    run_id_label = kube_utils.sanitize_label(str(pipeline_run.id))
    job_list = kube_utils.list_jobs(
        batch_api=batch_api,
        namespace=namespace,
        label_selector=f"run_id={run_id_label}",
    )
    for job in job_list.items:
        annotations = job.metadata.annotations or {}
        step_name = annotations.get(STEP_NAME_ANNOTATION_KEY, None)
        if not step_name:
            # Jobs without a step name annotation are not step jobs.
            continue

        node = nodes.get(step_name)
        if node is None:
            # Do not abort the whole orchestrator because of a stray job.
            logger.warning(
                "Ignoring job `%s` for unknown step `%s`.",
                job.metadata.name,
                step_name,
            )
            continue

        node.metadata["job_name"] = job.metadata.name

        if node.status != NodeStatus.NOT_READY:
            continue

        # The step is not finished in the ZenML database, so we base its
        # status on the job conditions. A job without a terminal condition
        # is treated as still running.
        node_status = NodeStatus.RUNNING
        conditions = (job.status.conditions if job.status else None) or []
        for condition in conditions:
            if condition.status != "True":
                continue
            if condition.type == "Complete":
                node_status = NodeStatus.COMPLETED
                break
            if condition.type == "Failed":
                node_status = NodeStatus.FAILED
                break

        node.status = node_status
        logger.debug(
            "Existing job for step `%s` status: %s.",
            step_name,
            node_status,
        )

    return list(nodes.values())
187
+
188
+
74
189
  def main() -> None:
75
- """Entrypoint of the k8s master/orchestrator pod."""
76
- # Log to the container's stdout so it can be streamed by the client.
77
- logger.info("Kubernetes orchestrator pod started.")
190
+ """Entrypoint of the k8s master/orchestrator pod.
191
+
192
+ Raises:
193
+ RuntimeError: If the orchestrator pod is not associated with a job.
194
+ """
195
+ logger.info("Orchestrator pod started.")
78
196
 
79
197
  args = parse_args()
80
198
 
@@ -82,50 +200,98 @@ def main() -> None:
82
200
 
83
201
  client = Client()
84
202
  deployment = client.get_deployment(args.deployment_id)
203
+ active_stack = client.active_stack
204
+ orchestrator = active_stack.orchestrator
205
+ assert isinstance(orchestrator, KubernetesOrchestrator)
206
+ namespace = orchestrator.config.kubernetes_namespace
207
+
208
+ pipeline_settings = cast(
209
+ KubernetesOrchestratorSettings,
210
+ orchestrator.get_settings(deployment),
211
+ )
85
212
 
86
- if args.run_id:
87
- pipeline_run = client.get_pipeline_run(args.run_id)
88
- else:
89
- pipeline_run = create_placeholder_run(
213
+ # Get a Kubernetes client from the active Kubernetes orchestrator, but
214
+ # override the `incluster` setting to `True` since we are running inside
215
+ # the Kubernetes cluster.
216
+ api_client_config = orchestrator.get_kube_client(
217
+ incluster=True
218
+ ).configuration
219
+ api_client_config.connection_pool_maxsize = (
220
+ pipeline_settings.max_parallelism
221
+ )
222
+ kube_client = k8s_client.ApiClient(api_client_config)
223
+ core_api = k8s_client.CoreV1Api(kube_client)
224
+ batch_api = k8s_client.BatchV1Api(kube_client)
225
+
226
+ job_name = kube_utils.get_parent_job_name(
227
+ core_api=core_api,
228
+ pod_name=orchestrator_pod_name,
229
+ namespace=namespace,
230
+ )
231
+ if not job_name:
232
+ raise RuntimeError("Failed to fetch job name for orchestrator pod.")
233
+
234
+ run_id, orchestrator_run_id = _get_orchestrator_job_state(
235
+ batch_api=batch_api,
236
+ namespace=namespace,
237
+ job_name=job_name,
238
+ )
239
+ existing_logs_response = None
240
+
241
+ if run_id and orchestrator_run_id:
242
+ logger.info("Continuing existing run `%s`.", run_id)
243
+ pipeline_run = client.get_pipeline_run(run_id)
244
+ nodes = _reconstruct_nodes(
90
245
  deployment=deployment,
91
- orchestrator_run_id=orchestrator_pod_name,
246
+ pipeline_run=pipeline_run,
247
+ namespace=namespace,
248
+ batch_api=batch_api,
249
+ )
250
+ logger.debug("Reconstructed nodes: %s", nodes)
251
+
252
+ # Continue logging to the same log file if it exists
253
+ for log_response in pipeline_run.log_collection or []:
254
+ if log_response.source == "orchestrator":
255
+ existing_logs_response = log_response
256
+ break
257
+ else:
258
+ orchestrator_run_id = orchestrator_pod_name
259
+ if args.run_id:
260
+ pipeline_run = client.get_pipeline_run(args.run_id)
261
+ else:
262
+ pipeline_run = create_placeholder_run(
263
+ deployment=deployment,
264
+ orchestrator_run_id=orchestrator_run_id,
265
+ )
266
+
267
+ # Store in the job annotations so we can continue the run if the pod
268
+ # is restarted
269
+ kube_utils.update_job(
270
+ batch_api=batch_api,
271
+ namespace=namespace,
272
+ job_name=job_name,
273
+ annotations={
274
+ RUN_ID_ANNOTATION_KEY: str(pipeline_run.id),
275
+ ORCHESTRATOR_RUN_ID_ANNOTATION_KEY: orchestrator_run_id,
276
+ },
92
277
  )
278
+ nodes = [
279
+ Node(id=step_name, upstream_nodes=step.spec.upstream_steps)
280
+ for step_name, step in deployment.step_configurations.items()
281
+ ]
93
282
 
94
283
  logs_context = setup_orchestrator_logging(
95
- run_id=str(pipeline_run.id), deployment=deployment
284
+ run_id=pipeline_run.id,
285
+ deployment=deployment,
286
+ logs_response=existing_logs_response,
96
287
  )
97
288
 
98
289
  with logs_context:
99
- active_stack = client.active_stack
100
- orchestrator = active_stack.orchestrator
101
- assert isinstance(orchestrator, KubernetesOrchestrator)
102
- namespace = orchestrator.config.kubernetes_namespace
103
-
104
- pipeline_settings = cast(
105
- KubernetesOrchestratorSettings,
106
- orchestrator.get_settings(deployment),
107
- )
108
-
109
290
  step_command = StepEntrypointConfiguration.get_entrypoint_command()
110
-
111
291
  mount_local_stores = active_stack.orchestrator.config.is_local
112
292
 
113
- # Get a Kubernetes client from the active Kubernetes orchestrator, but
114
- # override the `incluster` setting to `True` since we are running inside
115
- # the Kubernetes cluster.
116
-
117
- api_client_config = orchestrator.get_kube_client(
118
- incluster=True
119
- ).configuration
120
- api_client_config.connection_pool_maxsize = (
121
- pipeline_settings.max_parallelism
122
- )
123
- kube_client = k8s_client.ApiClient(api_client_config)
124
- core_api = k8s_client.CoreV1Api(kube_client)
125
- batch_api = k8s_client.BatchV1Api(kube_client)
126
-
127
293
  env = get_config_environment_vars()
128
- env[ENV_ZENML_KUBERNETES_RUN_ID] = orchestrator_pod_name
294
+ env[ENV_ZENML_KUBERNETES_RUN_ID] = orchestrator_run_id
129
295
 
130
296
  try:
131
297
  owner_references = kube_utils.get_pod_owner_references(
@@ -142,50 +308,12 @@ def main() -> None:
142
308
  for owner_reference in owner_references:
143
309
  owner_reference.controller = False
144
310
 
145
- pre_step_run: Optional[Callable[[str], bool]] = None
146
-
147
- if not pipeline_settings.prevent_orchestrator_pod_caching:
148
- step_run_request_factory = StepRunRequestFactory(
149
- deployment=deployment,
150
- pipeline_run=pipeline_run,
151
- stack=active_stack,
152
- )
153
- step_runs = {}
154
-
155
- def pre_step_run(step_name: str) -> bool:
156
- """Pre-step run.
157
-
158
- Args:
159
- step_name: Name of the step.
160
-
161
- Returns:
162
- Whether the step node needs to be run.
163
- """
164
- if not step_run_request_factory.has_caching_enabled(step_name):
165
- return True
166
-
167
- step_run_request = step_run_request_factory.create_request(
168
- step_name
169
- )
170
- try:
171
- step_run_request_factory.populate_request(step_run_request)
172
- except Exception as e:
173
- logger.error(
174
- f"Failed to populate step run request for step {step_name}: {e}"
175
- )
176
- return True
177
-
178
- if step_run_request.status == ExecutionStatus.CACHED:
179
- step_run = publish_cached_step_run(
180
- step_run_request, pipeline_run
181
- )
182
- step_runs[step_name] = step_run
183
- logger.info(
184
- "Using cached version of step `%s`.", step_name
185
- )
186
- return False
187
-
188
- return True
311
+ step_run_request_factory = StepRunRequestFactory(
312
+ deployment=deployment,
313
+ pipeline_run=pipeline_run,
314
+ stack=active_stack,
315
+ )
316
+ step_runs = {}
189
317
 
190
318
  base_labels = {
191
319
  "run_id": kube_utils.sanitize_label(str(pipeline_run.id)),
@@ -195,15 +323,44 @@ def main() -> None:
195
323
  ),
196
324
  }
197
325
 
198
- def run_step_on_kubernetes(step_name: str) -> None:
326
+ def _cache_step_run_if_possible(step_name: str) -> bool:
327
+ if not step_run_request_factory.has_caching_enabled(step_name):
328
+ return False
329
+
330
+ step_run_request = step_run_request_factory.create_request(
331
+ step_name
332
+ )
333
+ try:
334
+ step_run_request_factory.populate_request(step_run_request)
335
+ except Exception as e:
336
+ logger.error(
337
+ f"Failed to populate step run request for step {step_name}: {e}"
338
+ )
339
+ return False
340
+
341
+ if step_run_request.status == ExecutionStatus.CACHED:
342
+ step_run = publish_cached_step_run(
343
+ step_run_request, pipeline_run
344
+ )
345
+ step_runs[step_name] = step_run
346
+ logger.info("Using cached version of step `%s`.", step_name)
347
+ return True
348
+
349
+ return False
350
+
351
+ startup_lock = threading.Lock()
352
+ last_startup_time: float = 0.0
353
+
354
+ def start_step_job(node: Node) -> NodeStatus:
199
355
  """Run a pipeline step in a separate Kubernetes pod.
200
356
 
201
357
  Args:
202
- step_name: Name of the step.
358
+ node: The node to start.
203
359
 
204
- Raises:
205
- Exception: If the pod fails to start.
360
+ Returns:
361
+ The status of the node.
206
362
  """
363
+ step_name = node.id
207
364
  step_config = deployment.step_configurations[step_name].config
208
365
  settings = step_config.settings.get(
209
366
  "orchestrator.kubernetes", None
@@ -211,32 +368,15 @@ def main() -> None:
211
368
  settings = KubernetesOrchestratorSettings.model_validate(
212
369
  settings.model_dump() if settings else {}
213
370
  )
371
+ if not pipeline_settings.prevent_orchestrator_pod_caching:
372
+ if _cache_step_run_if_possible(step_name):
373
+ return NodeStatus.COMPLETED
214
374
 
215
- if (
216
- settings.pod_name_prefix
217
- and not orchestrator_pod_name.startswith(
218
- settings.pod_name_prefix
219
- )
220
- ):
221
- max_length = (
222
- kube_utils.calculate_max_pod_name_length_for_namespace(
223
- namespace=namespace
224
- )
225
- )
226
- pod_name_prefix = get_orchestrator_run_name(
227
- settings.pod_name_prefix, max_length=max_length
228
- )
229
- pod_name = f"{pod_name_prefix}-{step_name}"
230
- else:
231
- pod_name = f"{orchestrator_pod_name}-{step_name}"
232
-
233
- pod_name = kube_utils.sanitize_pod_name(
234
- pod_name, namespace=namespace
235
- )
236
-
237
- # Add step name to labels so both pod and job have consistent labeling
238
375
  step_labels = base_labels.copy()
239
376
  step_labels["step_name"] = kube_utils.sanitize_label(step_name)
377
+ step_annotations = {
378
+ STEP_NAME_ANNOTATION_KEY: step_name,
379
+ }
240
380
 
241
381
  image = KubernetesOrchestrator.get_image(
242
382
  deployment=deployment, step_name=step_name
@@ -250,11 +390,9 @@ def main() -> None:
250
390
  # some memory resources itself and, if not specified, the pod will be
251
391
  # scheduled on any node regardless of available memory and risk
252
392
  # negatively impacting or even crashing the node due to memory pressure.
253
- pod_settings = (
254
- KubernetesOrchestrator.apply_default_resource_requests(
255
- memory="400Mi",
256
- pod_settings=settings.pod_settings,
257
- )
393
+ pod_settings = kube_utils.apply_default_resource_requests(
394
+ memory="400Mi",
395
+ pod_settings=settings.pod_settings,
258
396
  )
259
397
 
260
398
  if orchestrator.config.pass_zenml_token_as_secret:
@@ -272,9 +410,8 @@ def main() -> None:
272
410
  }
273
411
  )
274
412
 
275
- # Define Kubernetes pod manifest.
276
413
  pod_manifest = build_pod_manifest(
277
- pod_name=pod_name,
414
+ pod_name=None,
278
415
  image_name=image,
279
416
  command=step_command,
280
417
  args=step_args,
@@ -293,21 +430,6 @@ def main() -> None:
293
430
  retry_config.max_retries if retry_config else 0
294
431
  ) + settings.backoff_limit_margin
295
432
 
296
- # This is to fix a bug in the kubernetes client which has some wrong
297
- # client-side validations that means the `on_exit_codes` field is
298
- # unusable. See https://github.com/kubernetes-client/python/issues/2056
299
- class PatchedFailurePolicyRule(k8s_client.V1PodFailurePolicyRule): # type: ignore[misc]
300
- @property
301
- def on_pod_conditions(self): # type: ignore[no-untyped-def]
302
- return self._on_pod_conditions
303
-
304
- @on_pod_conditions.setter
305
- def on_pod_conditions(self, on_pod_conditions): # type: ignore[no-untyped-def]
306
- self._on_pod_conditions = on_pod_conditions
307
-
308
- k8s_client.V1PodFailurePolicyRule = PatchedFailurePolicyRule
309
- k8s_client.models.V1PodFailurePolicyRule = PatchedFailurePolicyRule
310
-
311
433
  pod_failure_policy = settings.pod_failure_policy or {
312
434
  # These rules are applied sequentially. This means any failure in
313
435
  # the main container will count towards the max retries. Any other
@@ -336,7 +458,7 @@ def main() -> None:
336
458
  ]
337
459
  }
338
460
 
339
- job_name = settings.pod_name_prefix or ""
461
+ job_name = settings.job_name_prefix or ""
340
462
  random_prefix = "".join(random.choices("0123456789abcdef", k=8))
341
463
  job_name += f"-{random_prefix}-{step_name}-{deployment.pipeline_configuration.name}"
342
464
  # The job name will be used as a label on the pods, so we need to make
@@ -352,109 +474,85 @@ def main() -> None:
352
474
  pod_failure_policy=pod_failure_policy,
353
475
  owner_references=owner_references,
354
476
  labels=step_labels,
477
+ annotations=step_annotations,
355
478
  )
356
479
 
480
+ if (
481
+ startup_interval
482
+ := orchestrator.config.parallel_step_startup_waiting_period
483
+ ):
484
+ nonlocal last_startup_time
485
+
486
+ with startup_lock:
487
+ now = time.time()
488
+ time_since_last_startup = now - last_startup_time
489
+ sleep_time = startup_interval - time_since_last_startup
490
+ if sleep_time > 0:
491
+ logger.debug(
492
+ f"Sleeping for {sleep_time} seconds before "
493
+ f"starting job for step {step_name}."
494
+ )
495
+ time.sleep(sleep_time)
496
+ last_startup_time = now
497
+
357
498
  kube_utils.create_job(
358
499
  batch_api=batch_api,
359
500
  namespace=namespace,
360
501
  job_manifest=job_manifest,
361
502
  )
362
503
 
363
- logger.info(f"Waiting for job of step `{step_name}` to finish...")
364
- try:
365
- kube_utils.wait_for_job_to_finish(
366
- batch_api=batch_api,
367
- core_api=core_api,
368
- namespace=namespace,
369
- job_name=job_name,
370
- fail_on_container_waiting_reasons=settings.fail_on_container_waiting_reasons,
371
- stream_logs=pipeline_settings.stream_step_logs,
372
- backoff_interval=settings.job_monitoring_interval,
373
- )
374
-
375
- logger.info(f"Job for step `{step_name}` completed.")
376
- except Exception:
377
- reason = "Unknown"
378
- try:
379
- pods = core_api.list_namespaced_pod(
380
- label_selector=f"job-name={job_name}",
381
- namespace=namespace,
382
- ).items
383
- # Sort pods by creation timestamp, oldest first
384
- pods.sort(
385
- key=lambda pod: pod.metadata.creation_timestamp,
386
- )
387
- if pods:
388
- if (
389
- termination_reason
390
- := kube_utils.get_container_termination_reason(
391
- pods[-1], "main"
392
- )
393
- ):
394
- exit_code, reason = termination_reason
395
- if exit_code != 0:
396
- reason = f"{reason} (exit_code={exit_code})"
397
- except Exception:
398
- pass
399
- logger.error(
400
- f"Job for step `{step_name}` failed. Reason: {reason}"
401
- )
504
+ node.metadata["job_name"] = job_name
402
505
 
403
- raise
506
+ return NodeStatus.RUNNING
404
507
 
405
- def finalize_run(node_states: Dict[str, NodeStatus]) -> None:
406
- """Finalize the run.
508
+ def check_job_status(node: Node) -> NodeStatus:
509
+ """Check the status of a job.
407
510
 
408
511
  Args:
409
- node_states: The states of the nodes.
512
+ node: The node to check.
513
+
514
+ Returns:
515
+ The status of the node.
410
516
  """
411
- try:
412
- # Some steps may have failed because the pods could not be created.
413
- # We need to check for this and mark the step run as failed if so.
414
- pipeline_failed = False
415
- failed_step_names = [
416
- step_name
417
- for step_name, node_state in node_states.items()
418
- if node_state == NodeStatus.FAILED
419
- ]
420
- step_runs = fetch_step_runs_by_names(
421
- step_run_names=failed_step_names, pipeline_run=pipeline_run
517
+ step_name = node.id
518
+ job_name = node.metadata.get("job_name", None)
519
+ if not job_name:
520
+ logger.error(
521
+ "Missing job name to monitor step `%s`.", step_name
422
522
  )
523
+ return NodeStatus.FAILED
423
524
 
424
- for step_name, node_state in node_states.items():
425
- if node_state != NodeStatus.FAILED:
426
- continue
427
-
428
- pipeline_failed = True
429
-
430
- if step_run := step_runs.get(step_name, None):
431
- # Try to update the step run status, if it exists and is in
432
- # a transient state.
433
- if step_run and step_run.status in {
434
- ExecutionStatus.INITIALIZING,
435
- ExecutionStatus.RUNNING,
436
- }:
437
- publish_utils.publish_failed_step_run(step_run.id)
438
-
439
- # If any steps failed and the pipeline run is still in a transient
440
- # state, we need to mark it as failed.
441
- if pipeline_failed and pipeline_run.status in {
442
- ExecutionStatus.INITIALIZING,
443
- ExecutionStatus.RUNNING,
444
- }:
445
- publish_utils.publish_failed_pipeline_run(pipeline_run.id)
446
- except AuthorizationException:
447
- # If a step of the pipeline failed or all of them completed
448
- # successfully, the pipeline run will be finished and the API token
449
- # will be invalidated. We catch this exception and do nothing here,
450
- # as the pipeline run status will already have been published.
451
- pass
452
-
453
- def check_pipeline_cancellation() -> bool:
454
- """Check if the pipeline should continue execution.
525
+ step_config = deployment.step_configurations[step_name].config
526
+ settings = step_config.settings.get(
527
+ "orchestrator.kubernetes", None
528
+ )
529
+ settings = KubernetesOrchestratorSettings.model_validate(
530
+ settings.model_dump() if settings else {}
531
+ )
532
+ status, error_message = kube_utils.check_job_status(
533
+ batch_api=batch_api,
534
+ core_api=core_api,
535
+ namespace=namespace,
536
+ job_name=job_name,
537
+ fail_on_container_waiting_reasons=settings.fail_on_container_waiting_reasons,
538
+ )
539
+ if status == kube_utils.JobStatus.SUCCEEDED:
540
+ return NodeStatus.COMPLETED
541
+ elif status == kube_utils.JobStatus.FAILED:
542
+ logger.error(
543
+ "Job for step `%s` failed: %s",
544
+ step_name,
545
+ error_message,
546
+ )
547
+ return NodeStatus.FAILED
548
+ else:
549
+ return NodeStatus.RUNNING
550
+
551
+ def should_interrupt_execution() -> Optional[InterruptMode]:
552
+ """Check if the DAG execution should be interrupted.
455
553
 
456
554
  Returns:
457
- True if execution should continue, False if it should stop.
555
+ If the DAG execution should be interrupted.
458
556
  """
459
557
  try:
460
558
  run = client.get_pipeline_run(
@@ -463,44 +561,34 @@ def main() -> None:
463
561
  hydrate=False, # We only need status, not full hydration
464
562
  )
465
563
 
466
- # If the run is STOPPING or STOPPED, we should stop the execution
467
564
  if run.status in [
468
565
  ExecutionStatus.STOPPING,
469
566
  ExecutionStatus.STOPPED,
470
567
  ]:
471
568
  logger.info(
472
- f"Pipeline run is in {run.status} state, stopping execution"
569
+ "Stopping DAG execution because pipeline run is in "
570
+ "`%s` state.",
571
+ run.status,
473
572
  )
474
- return False
475
-
476
- return True
477
-
573
+ return InterruptMode.GRACEFUL
478
574
  except Exception as e:
479
- # If we can't check the status, assume we should continue
480
575
  logger.warning(
481
- f"Failed to check pipeline cancellation status: {e}"
576
+ "Failed to check pipeline cancellation status: %s", e
482
577
  )
483
- return True
484
578
 
485
- parallel_node_startup_waiting_period = (
486
- orchestrator.config.parallel_step_startup_waiting_period or 0.0
487
- )
579
+ return None
488
580
 
489
- pipeline_dag = {
490
- step_name: step.spec.upstream_steps
491
- for step_name, step in deployment.step_configurations.items()
492
- }
493
581
  try:
494
- ThreadedDagRunner(
495
- dag=pipeline_dag,
496
- run_fn=run_step_on_kubernetes,
497
- preparation_fn=pre_step_run,
498
- finalize_fn=finalize_run,
499
- continue_fn=check_pipeline_cancellation,
500
- parallel_node_startup_waiting_period=parallel_node_startup_waiting_period,
582
+ nodes_statuses = DagRunner(
583
+ nodes=nodes,
584
+ node_startup_function=start_step_job,
585
+ node_monitoring_function=check_job_status,
586
+ interrupt_function=should_interrupt_execution,
587
+ monitoring_interval=pipeline_settings.job_monitoring_interval,
588
+ monitoring_delay=pipeline_settings.job_monitoring_delay,
589
+ interrupt_check_interval=pipeline_settings.interrupt_check_interval,
501
590
  max_parallelism=pipeline_settings.max_parallelism,
502
591
  ).run()
503
- logger.info("Orchestration pod completed.")
504
592
  finally:
505
593
  if (
506
594
  orchestrator.config.pass_zenml_token_as_secret
@@ -518,6 +606,66 @@ def main() -> None:
518
606
  f"Error cleaning up secret {secret_name}: {e}"
519
607
  )
520
608
 
609
+ try:
610
+ pipeline_failed = False
611
+ failed_step_names = [
612
+ step_name
613
+ for step_name, node_state in nodes_statuses.items()
614
+ if node_state == NodeStatus.FAILED
615
+ ]
616
+ skipped_step_names = [
617
+ step_name
618
+ for step_name, node_state in nodes_statuses.items()
619
+ if node_state == NodeStatus.SKIPPED
620
+ ]
621
+
622
+ if failed_step_names:
623
+ logger.error(
624
+ "The following steps failed: %s",
625
+ ", ".join(failed_step_names),
626
+ )
627
+ if skipped_step_names:
628
+ logger.error(
629
+ "The following steps were skipped because some of their "
630
+ "upstream steps failed: %s",
631
+ ", ".join(skipped_step_names),
632
+ )
633
+
634
+ step_runs = fetch_step_runs_by_names(
635
+ step_run_names=failed_step_names, pipeline_run=pipeline_run
636
+ )
637
+
638
+ for step_name, node_state in nodes_statuses.items():
639
+ if node_state != NodeStatus.FAILED:
640
+ continue
641
+
642
+ pipeline_failed = True
643
+
644
+ if step_run := step_runs.get(step_name, None):
645
+ # Try to update the step run status, if it exists and is in
646
+ # a transient state.
647
+ if step_run and step_run.status in {
648
+ ExecutionStatus.INITIALIZING,
649
+ ExecutionStatus.RUNNING,
650
+ }:
651
+ publish_utils.publish_failed_step_run(step_run.id)
652
+
653
+ # If any steps failed and the pipeline run is still in a transient
654
+ # state, we need to mark it as failed.
655
+ if pipeline_failed and pipeline_run.status in {
656
+ ExecutionStatus.INITIALIZING,
657
+ ExecutionStatus.RUNNING,
658
+ }:
659
+ publish_utils.publish_failed_pipeline_run(pipeline_run.id)
660
+ except AuthorizationException:
661
+ # If a step of the pipeline failed or all of them completed
662
+ # successfully, the pipeline run will be finished and the API token
663
+ # will be invalidated. We catch this exception and do nothing here,
664
+ # as the pipeline run status will already have been published.
665
+ pass
666
+
667
+ logger.info("Orchestrator pod finished.")
668
+
521
669
 
522
670
  if __name__ == "__main__":
523
671
  main()