zenml-nightly 0.83.1.dev20250709__py3-none-any.whl → 0.83.1.dev20250710__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/login.py +141 -18
- zenml/cli/project.py +8 -6
- zenml/cli/utils.py +63 -16
- zenml/client.py +4 -1
- zenml/config/compiler.py +1 -0
- zenml/config/retry_config.py +5 -3
- zenml/config/step_configurations.py +7 -1
- zenml/console.py +4 -1
- zenml/constants.py +0 -1
- zenml/enums.py +13 -4
- zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +58 -4
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +172 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +37 -23
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +92 -22
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +59 -0
- zenml/logger.py +6 -4
- zenml/login/web_login.py +13 -6
- zenml/models/v2/core/model_version.py +9 -1
- zenml/models/v2/core/pipeline_run.py +1 -0
- zenml/models/v2/core/step_run.py +35 -1
- zenml/orchestrators/base_orchestrator.py +63 -8
- zenml/orchestrators/dag_runner.py +3 -1
- zenml/orchestrators/publish_utils.py +4 -1
- zenml/orchestrators/step_launcher.py +77 -139
- zenml/orchestrators/step_run_utils.py +16 -0
- zenml/orchestrators/step_runner.py +1 -4
- zenml/pipelines/pipeline_decorator.py +6 -1
- zenml/pipelines/pipeline_definition.py +7 -0
- zenml/zen_server/auth.py +0 -1
- zenml/zen_stores/migrations/versions/360fa84718bf_step_run_versioning.py +64 -0
- zenml/zen_stores/migrations/versions/85289fea86ff_adding_source_to_logs.py +1 -1
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +21 -0
- zenml/zen_stores/schemas/pipeline_run_schemas.py +31 -2
- zenml/zen_stores/schemas/step_run_schemas.py +41 -17
- zenml/zen_stores/sql_zen_store.py +152 -32
- zenml/zen_stores/template_utils.py +29 -9
- zenml_nightly-0.83.1.dev20250710.dist-info/METADATA +499 -0
- {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/RECORD +42 -41
- zenml_nightly-0.83.1.dev20250709.dist-info/METADATA +0 -538
- {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.1.dev20250709.dist-info → zenml_nightly-0.83.1.dev20250710.dist-info}/entry_points.txt +0 -0
zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py
CHANGED
@@ -13,9 +13,9 @@
 # permissions and limitations under the License.
 """Kubernetes orchestrator flavor."""
 
-from typing import TYPE_CHECKING, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, Optional, Type
 
-from pydantic import NonNegativeInt, PositiveInt
+from pydantic import NonNegativeInt, PositiveInt, field_validator
 
 from zenml.config.base_settings import BaseSettings
 from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE
@@ -40,6 +40,9 @@ class KubernetesOrchestratorSettings(BaseSettings):
            asynchronously. Defaults to `True`.
        timeout: How many seconds to wait for synchronous runs. `0` means
            to wait for an unlimited duration.
+        stream_step_logs: If `True`, the orchestrator pod will stream the logs
+            of the step pods. This only has an effect if specified on the
+            pipeline, not on individual steps.
        service_account_name: Name of the service account to use for the
            orchestrator pod. If not provided, a new service account with "edit"
            permissions will be created.
@@ -65,8 +68,26 @@ class KubernetesOrchestratorSettings(BaseSettings):
        failed_jobs_history_limit: The number of failed jobs to retain.
            This only applies to jobs created when scheduling a pipeline.
        ttl_seconds_after_finished: The amount of seconds to keep finished jobs
-            before deleting them. This
-
+            before deleting them. **Note**: This does not clean up the
+            orchestrator pod for non-scheduled runs.
+        active_deadline_seconds: The active deadline seconds for the job that is
+            executing the step.
+        backoff_limit_margin: The value to add to the backoff limit in addition
+            to the step retries. The retry configuration defined on the step
+            defines the maximum number of retries that the server will accept
+            for a step. For this orchestrator, this controls how often the
+            job running the step will try to start the step pod. There are some
+            circumstances however where the job will start the pod, but the pod
+            doesn't actually get to the point of running the step. That means
+            the server will not receive the maximum amount of retry requests,
+            which in turn causes other inconsistencies like wrong step statuses.
+            To mitigate this, this attribute allows to add a margin to the
+            backoff limit. This means that the job will retry the pod startup
+            for the configured amount of times plus the margin, which increases
+            the chance of the server receiving the maximum amount of retry
+            requests.
+        pod_failure_policy: The pod failure policy to use for the job that is
+            executing the step.
        prevent_orchestrator_pod_caching: If `True`, the orchestrator pod will
            not try to compute cached steps before starting the step pods.
        always_build_pipeline_image: If `True`, the orchestrator will always
@@ -77,6 +98,7 @@ class KubernetesOrchestratorSettings(BaseSettings):
 
    synchronous: bool = True
    timeout: int = 0
+    stream_step_logs: bool = True
    service_account_name: Optional[str] = None
    step_pod_service_account_name: Optional[str] = None
    privileged: bool = False
@@ -91,10 +113,33 @@ class KubernetesOrchestratorSettings(BaseSettings):
    successful_jobs_history_limit: Optional[NonNegativeInt] = None
    failed_jobs_history_limit: Optional[NonNegativeInt] = None
    ttl_seconds_after_finished: Optional[NonNegativeInt] = None
+    active_deadline_seconds: Optional[NonNegativeInt] = None
+    backoff_limit_margin: NonNegativeInt = 0
+    pod_failure_policy: Optional[Dict[str, Any]] = None
    prevent_orchestrator_pod_caching: bool = False
    always_build_pipeline_image: bool = False
    pod_stop_grace_period: PositiveInt = 30
 
+    @field_validator("pod_failure_policy", mode="before")
+    @classmethod
+    def _convert_pod_failure_policy(cls, value: Any) -> Any:
+        """Converts Kubernetes pod failure policy to a dict.
+
+        Args:
+            value: The pod failure policy value.
+
+        Returns:
+            The converted value.
+        """
+        from kubernetes.client.models import V1PodFailurePolicy
+
+        from zenml.integrations.kubernetes import serialization_utils
+
+        if isinstance(value, V1PodFailurePolicy):
+            return serialization_utils.serialize_kubernetes_model(value)
+        else:
+            return value
+
 
 class KubernetesOrchestratorConfig(
    BaseOrchestratorConfig, KubernetesOrchestratorSettings
@@ -187,6 +232,15 @@ class KubernetesOrchestratorConfig(
        # This is currently not supported when using client-side caching.
        return False
 
+    @property
+    def handles_step_retries(self) -> bool:
+        """Whether the orchestrator handles step retries.
+
+        Returns:
+            Whether the orchestrator handles step retries.
+        """
+        return True
+
 
 class KubernetesOrchestratorFlavor(BaseOrchestratorFlavor):
    """Kubernetes orchestrator flavor."""
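For orientation, here is a minimal sketch of how the new settings could be passed to a pipeline. The pipeline, the numeric values, and the failure policy below are illustrative examples, not taken from this release; plain dicts are accepted for `pod_failure_policy`, and `V1PodFailurePolicy` objects are serialized to dicts by the new `field_validator`.

from zenml import pipeline
from zenml.integrations.kubernetes.flavors import KubernetesOrchestratorSettings

# Hypothetical values exercising the fields added in this diff.
kubernetes_settings = KubernetesOrchestratorSettings(
    stream_step_logs=True,          # stream step pod logs from the orchestrator pod
    active_deadline_seconds=3600,   # hard cap on how long a step job may run
    backoff_limit_margin=2,         # extra pod-startup retries on top of step retries
    pod_failure_policy={
        "rules": [
            {
                "action": "Count",
                "onExitCodes": {
                    "containerName": "main",
                    "operator": "NotIn",
                    "values": [0],
                },
            }
        ]
    },
)


@pipeline(settings={"orchestrator.kubernetes": kubernetes_settings})
def my_pipeline() -> None:
    ...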
zenml/integrations/kubernetes/orchestrators/kube_utils.py
CHANGED
@@ -32,8 +32,10 @@ Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils
 """
 
 import enum
+import functools
 import re
 import time
+from collections import defaultdict
 from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
 
 from kubernetes import client as k8s_client
@@ -51,6 +53,8 @@ from zenml.utils.time_utils import utc_now
 
 logger = get_logger(__name__)
 
+R = TypeVar("R")
+
 
 class PodPhase(enum.Enum):
    """Phase of the Kubernetes pod.
@@ -581,3 +585,171 @@ def get_pod_owner_references(
    return cast(
        List[k8s_client.V1OwnerReference], pod.metadata.owner_references
    )
+
+
+def retry_on_api_exception(
+    func: Callable[..., R],
+    max_retries: int = 3,
+    delay: float = 1,
+    backoff: float = 1,
+) -> Callable[..., R]:
+    """Retry a function on API exceptions.
+
+    Args:
+        func: The function to retry.
+        max_retries: The maximum number of retries.
+        delay: The delay between retries.
+        backoff: The backoff factor.
+
+    Returns:
+        The wrapped function with retry logic.
+    """
+
+    @functools.wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> R:
+        _delay = delay
+        retries = 0
+        while retries <= max_retries:
+            try:
+                return func(*args, **kwargs)
+            except ApiException as e:
+                retries += 1
+                if retries <= max_retries:
+                    logger.warning("Error calling %s: %s.", func.__name__, e)
+                    time.sleep(_delay)
+                    _delay *= backoff
+                else:
+                    raise
+
+        raise RuntimeError(
+            f"Failed to call {func.__name__} after {max_retries} retries."
+        )
+
+    return wrapper
+
+
+def create_job(
+    batch_api: k8s_client.BatchV1Api,
+    namespace: str,
+    job_manifest: k8s_client.V1Job,
+) -> None:
+    """Create a Kubernetes job.
+
+    Args:
+        batch_api: Kubernetes batch api.
+        namespace: Kubernetes namespace.
+        job_manifest: The manifest of the job to create.
+    """
+    retry_on_api_exception(batch_api.create_namespaced_job)(
+        namespace=namespace,
+        body=job_manifest,
+    )
+
+
+def wait_for_job_to_finish(
+    batch_api: k8s_client.BatchV1Api,
+    core_api: k8s_client.CoreV1Api,
+    namespace: str,
+    job_name: str,
+    backoff_interval: float = 1,
+    maximum_backoff: float = 32,
+    exponential_backoff: bool = False,
+    container_name: Optional[str] = None,
+    stream_logs: bool = True,
+) -> None:
+    """Wait for a job to finish.
+
+    Args:
+        batch_api: Kubernetes BatchV1Api client.
+        core_api: Kubernetes CoreV1Api client.
+        namespace: Kubernetes namespace.
+        job_name: Name of the job for which to wait.
+        backoff_interval: The interval to wait between polling the job status.
+        maximum_backoff: The maximum interval to wait between polling the job
+            status.
+        exponential_backoff: Whether to use exponential backoff.
+        stream_logs: Whether to stream the job logs.
+        container_name: Name of the container to stream logs from.
+
+    Raises:
+        RuntimeError: If the job failed or timed out.
+    """
+    logged_lines_per_pod: Dict[str, int] = defaultdict(int)
+    finished_pods = set()
+
+    while True:
+        job: k8s_client.V1Job = retry_on_api_exception(
+            batch_api.read_namespaced_job
+        )(name=job_name, namespace=namespace)
+
+        if job.status.conditions:
+            for condition in job.status.conditions:
+                if condition.type == "Complete" and condition.status == "True":
+                    return
+                if condition.type == "Failed" and condition.status == "True":
+                    raise RuntimeError(
+                        f"Job `{namespace}:{job_name}` failed: "
+                        f"{condition.message}"
+                    )
+
+        if stream_logs:
+            try:
+                pod_list: k8s_client.V1PodList = core_api.list_namespaced_pod(
+                    namespace=namespace,
+                    label_selector=f"job-name={job_name}",
+                )
+            except ApiException as e:
+                logger.error("Error fetching pods: %s.", e)
+                pod_list = []
+            else:
+                # Sort pods by creation timestamp, oldest first
+                pod_list.items.sort(
+                    key=lambda pod: pod.metadata.creation_timestamp,
+                )
+
+                for pod in pod_list.items:
+                    pod_name = pod.metadata.name
+                    pod_status = pod.status.phase
+
+                    if pod_name in finished_pods:
+                        # We've already streamed all logs for this pod, so we can
+                        # skip it.
+                        continue
+
+                    if pod_status == PodPhase.PENDING.value:
+                        # The pod is still pending, so we can't stream logs for it
+                        # yet.
+                        continue
+
+                    if pod_status in [
+                        PodPhase.SUCCEEDED.value,
+                        PodPhase.FAILED.value,
+                    ]:
+                        finished_pods.add(pod_name)
+
+                    containers = pod.spec.containers
+                    if not container_name:
+                        container_name = containers[0].name
+
+                    try:
+                        response = core_api.read_namespaced_pod_log(
+                            name=pod_name,
+                            namespace=namespace,
+                            container=container_name,
+                            _preload_content=False,
+                        )
+                    except ApiException as e:
+                        logger.error("Error reading pod logs: %s.", e)
+                    else:
+                        raw_data = response.data
+                        decoded_log = raw_data.decode("utf-8", errors="replace")
+                        logs = decoded_log.splitlines()
+                        logged_lines = logged_lines_per_pod[pod_name]
+                        if len(logs) > logged_lines:
+                            for line in logs[logged_lines:]:
+                                logger.info(line)
+                            logged_lines_per_pod[pod_name] = len(logs)
+
+        time.sleep(backoff_interval)
+        if exponential_backoff and backoff_interval < maximum_backoff:
+            backoff_interval *= 2
zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py
CHANGED
@@ -447,6 +447,13 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
                step_name,
            )
 
+        if retry_config := step.config.retry:
+            if retry_config.delay or retry_config.backoff:
+                logger.warning(
+                    "Specifying retry delay or backoff is not supported "
+                    "for the Kubernetes orchestrator."
+                )
+
        pipeline_name = deployment.pipeline_configuration.name
        settings = cast(
            KubernetesOrchestratorSettings, self.get_settings(deployment)
@@ -693,7 +700,7 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
        Args:
            run: The run that was executed by this orchestrator.
            graceful: If True, does nothing (lets the orchestrator and steps finish naturally).
-                If False, stops all running step
+                If False, stops all running step jobs.
 
        Raises:
            RuntimeError: If we fail to stop the run.
@@ -706,55 +713,63 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
            )
            return
 
-
+        jobs_stopped = []
        errors = []
 
-        # Find all
+        # Find all jobs running steps of the pipeline
        label_selector = f"run_id={kube_utils.sanitize_label(str(run.id))}"
        try:
-
+            jobs = self._k8s_batch_api.list_namespaced_job(
                namespace=self.config.kubernetes_namespace,
                label_selector=label_selector,
            )
        except Exception as e:
            raise RuntimeError(
-                f"Failed to list step
+                f"Failed to list step jobs with run ID {run.id}: {e}"
            )
 
-
-
-
-
-
-
-
+        for job in jobs.items:
+            if job.status.conditions:
+                # Don't delete completed/failed jobs
+                for condition in job.status.conditions:
+                    if (
+                        condition.type == "Complete"
+                        and condition.status == "True"
+                    ):
+                        continue
+                    if (
+                        condition.type == "Failed"
+                        and condition.status == "True"
+                    ):
+                        continue
 
            try:
-                self.
-                    name=
+                self._k8s_batch_api.delete_namespaced_job(
+                    name=job.metadata.name,
                    namespace=self.config.kubernetes_namespace,
+                    propagation_policy="Foreground",
                )
-
+                jobs_stopped.append(f"step job: {job.metadata.name}")
                logger.debug(
-                    f"Successfully initiated graceful stop of step
+                    f"Successfully initiated graceful stop of step job: {job.metadata.name}"
                )
            except Exception as e:
-                error_msg = f"Failed to stop step
+                error_msg = f"Failed to stop step job {job.metadata.name}: {e}"
                logger.warning(error_msg)
                errors.append(error_msg)
 
        # Summary logging
        settings = cast(KubernetesOrchestratorSettings, self.get_settings(run))
        grace_period_seconds = settings.pod_stop_grace_period
-        if
+        if jobs_stopped:
            logger.debug(
-                f"Successfully initiated graceful termination of: {', '.join(
+                f"Successfully initiated graceful termination of: {', '.join(jobs_stopped)}. "
                f"Pods will terminate within {grace_period_seconds} seconds."
            )
 
        if errors:
            error_summary = "; ".join(errors)
-            if not
+            if not jobs_stopped:
                # If nothing was stopped successfully, raise an error
                raise RuntimeError(
                    f"Failed to stop pipeline run: {error_summary}"
@@ -765,10 +780,9 @@ class KubernetesOrchestrator(ContainerizedOrchestrator):
                    f"Partial stop operation completed with errors: {error_summary}"
                )
 
-
-        if not pods_stopped and not errors:
+        if not jobs_stopped and not errors:
            logger.info(
-                f"No running step
+                f"No running step jobs found for pipeline run with ID: {run.id}"
            )
 
    def get_pipeline_run_metadata(
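For reference, the cleanup above amounts to a label-selector lookup followed by a foreground delete. A rough standalone sketch of the underlying Kubernetes client calls; the namespace and run ID are placeholders, and in the orchestrator the selector value comes from `sanitize_label(str(run.id))`:

from kubernetes import client as k8s_client

batch_api = k8s_client.BatchV1Api()
namespace = "zenml"                       # placeholder namespace
label_selector = "run_id=example-run-id"  # placeholder run ID label

jobs = batch_api.list_namespaced_job(
    namespace=namespace, label_selector=label_selector
)
for job in jobs.items:
    # "Foreground" propagation waits for the job's pods to be removed
    # before the job object itself is deleted.
    batch_api.delete_namespaced_job(
        name=job.metadata.name,
        namespace=namespace,
        propagation_policy="Foreground",
    )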
zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py
CHANGED
@@ -14,6 +14,7 @@
 """Entrypoint of the Kubernetes master/orchestrator pod."""
 
 import argparse
+import random
 import socket
 from typing import Callable, Dict, Optional, cast
 
@@ -36,7 +37,9 @@ from zenml.integrations.kubernetes.orchestrators.kubernetes_orchestrator import
    KubernetesOrchestrator,
 )
 from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
+    build_job_manifest,
    build_pod_manifest,
+    pod_template_manifest_from_pod,
 )
 from zenml.logger import get_logger
 from zenml.logging.step_logging import setup_orchestrator_logging
@@ -110,8 +113,16 @@ def main() -> None:
    # Get a Kubernetes client from the active Kubernetes orchestrator, but
    # override the `incluster` setting to `True` since we are running inside
    # the Kubernetes cluster.
-
+
+    api_client_config = orchestrator.get_kube_client(
+        incluster=True
+    ).configuration
+    api_client_config.connection_pool_maxsize = (
+        pipeline_settings.max_parallelism
+    )
+    kube_client = k8s_client.ApiClient(api_client_config)
    core_api = k8s_client.CoreV1Api(kube_client)
+    batch_api = k8s_client.BatchV1Api(kube_client)
 
    env = get_config_environment_vars()
    env[ENV_ZENML_KUBERNETES_RUN_ID] = orchestrator_pod_name
@@ -150,6 +161,9 @@ def main() -> None:
        Returns:
            Whether the step node needs to be run.
        """
+        if not step_run_request_factory.has_caching_enabled(step_name):
+            return True
+
        step_run_request = step_run_request_factory.create_request(
            step_name
        )
@@ -266,39 +280,95 @@ def main() -> None:
            service_account_name=settings.step_pod_service_account_name
            or settings.service_account_name,
            mount_local_stores=mount_local_stores,
-            owner_references=owner_references,
            termination_grace_period_seconds=settings.pod_stop_grace_period,
            labels=step_pod_labels,
        )
 
-
-
-
-
-
+        retry_config = step_config.retry
+        backoff_limit = (
+            retry_config.max_retries if retry_config else 0
+        ) + settings.backoff_limit_margin
+
+        # This is to fix a bug in the kubernetes client which has some wrong
+        # client-side validations that means the `on_exit_codes` field is
+        # unusable. See https://github.com/kubernetes-client/python/issues/2056
+        class PatchedFailurePolicyRule(k8s_client.V1PodFailurePolicyRule):  # type: ignore[misc]
+            @property
+            def on_pod_conditions(self):  # type: ignore[no-untyped-def]
+                return self._on_pod_conditions
+
+            @on_pod_conditions.setter
+            def on_pod_conditions(self, on_pod_conditions):  # type: ignore[no-untyped-def]
+                self._on_pod_conditions = on_pod_conditions
+
+        k8s_client.V1PodFailurePolicyRule = PatchedFailurePolicyRule
+        k8s_client.models.V1PodFailurePolicyRule = PatchedFailurePolicyRule
+
+        pod_failure_policy = settings.pod_failure_policy or {
+            # These rules are applied sequentially. This means any failure in
+            # the main container will count towards the max retries. Any other
+            # disruption will not count towards the max retries.
+            "rules": [
+                # If the main container fails, we count it towards the max
+                # retries.
+                {
+                    "action": "Count",
+                    "onExitCodes": {
+                        "containerName": "main",
+                        "operator": "NotIn",
+                        "values": [0],
+                    },
+                },
+                # If the pod is interrupted at any other time, we don't count
+                # it as a retry
+                {
+                    "action": "Ignore",
+                    "onPodConditions": [
+                        {
+                            "type": "DisruptionTarget",
+                        }
+                    ],
+                },
+            ]
+        }
+
+        job_name = settings.pod_name_prefix or ""
+        random_prefix = "".join(random.choices("0123456789abcdef", k=8))
+        job_name += f"-{random_prefix}-{step_name}-{deployment.pipeline_configuration.name}"
+        # The job name will be used as a label on the pods, so we need to make
+        # sure it doesn't exceed the label length limit
+        job_name = kube_utils.sanitize_label(job_name)
+
+        job_manifest = build_job_manifest(
+            job_name=job_name,
+            pod_template=pod_template_manifest_from_pod(pod_manifest),
+            backoff_limit=backoff_limit,
+            ttl_seconds_after_finished=settings.ttl_seconds_after_finished,
+            active_deadline_seconds=settings.active_deadline_seconds,
+            pod_failure_policy=pod_failure_policy,
+            owner_references=owner_references,
+            labels=step_pod_labels,
+        )
+
+        kube_utils.create_job(
+            batch_api=batch_api,
            namespace=namespace,
-
-            startup_failure_delay=settings.pod_failure_retry_delay,
-            startup_failure_backoff=settings.pod_failure_backoff,
-            startup_timeout=settings.pod_startup_timeout,
+            job_manifest=job_manifest,
        )
 
-
-        logger.info(f"Waiting for pod of step `{step_name}` to finish...")
+        logger.info(f"Waiting for job of step `{step_name}` to finish...")
        try:
-            kube_utils.
-
-
-            ),
-            pod_name=pod_name,
+            kube_utils.wait_for_job_to_finish(
+                batch_api=batch_api,
+                core_api=core_api,
                namespace=namespace,
-
-                stream_logs=
+                job_name=job_name,
+                stream_logs=pipeline_settings.stream_step_logs,
            )
 
-            logger.info(f"
+            logger.info(f"Job for step `{step_name}` completed.")
        except Exception:
-            logger.error(f"
+            logger.error(f"Job for step `{step_name}` failed.")
 
            raise
 
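The interplay between step retries and the job's backoff limit can be illustrated in isolation; a tiny sketch with made-up numbers:

# Hypothetical numbers: the step allows 3 retries and the orchestrator
# setting adds a margin of 2.
max_retries = 3
backoff_limit_margin = 2

# Mirrors the computation in the entrypoint: the Job may restart the step
# pod up to max_retries + margin times, while the ZenML server still only
# accepts max_retries retry requests for the step itself.
backoff_limit = max_retries + backoff_limit_margin
assert backoff_limit == 5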
zenml/integrations/kubernetes/orchestrators/manifest_utils.py
CHANGED
@@ -450,3 +450,62 @@ def build_secret_manifest(
        "type": secret_type,
        "data": encoded_data,
    }
+
+
+def pod_template_manifest_from_pod(
+    pod: k8s_client.V1Pod,
+) -> k8s_client.V1PodTemplateSpec:
+    """Build a Kubernetes pod template manifest from a pod.
+
+    Args:
+        pod: The pod manifest to build the template from.
+
+    Returns:
+        The pod template manifest.
+    """
+    return k8s_client.V1PodTemplateSpec(
+        metadata=pod.metadata,
+        spec=pod.spec,
+    )
+
+
+def build_job_manifest(
+    job_name: str,
+    pod_template: k8s_client.V1PodTemplateSpec,
+    backoff_limit: Optional[int] = None,
+    ttl_seconds_after_finished: Optional[int] = None,
+    labels: Optional[Dict[str, str]] = None,
+    active_deadline_seconds: Optional[int] = None,
+    pod_failure_policy: Optional[Dict[str, Any]] = None,
+    owner_references: Optional[List[k8s_client.V1OwnerReference]] = None,
+) -> k8s_client.V1Job:
+    """Build a Kubernetes job manifest.
+
+    Args:
+        job_name: Name of the job.
+        pod_template: The pod template to use for the job.
+        backoff_limit: The backoff limit for the job.
+        ttl_seconds_after_finished: The TTL seconds after finished for the job.
+        labels: The labels to use for the job.
+        active_deadline_seconds: The active deadline seconds for the job.
+        pod_failure_policy: The pod failure policy for the job.
+        owner_references: The owner references for the job.
+
+    Returns:
+        The Kubernetes job manifest.
+    """
+    job_spec = k8s_client.V1JobSpec(
+        template=pod_template,
+        backoff_limit=backoff_limit,
+        parallelism=1,
+        ttl_seconds_after_finished=ttl_seconds_after_finished,
+        active_deadline_seconds=active_deadline_seconds,
+        pod_failure_policy=pod_failure_policy,
+    )
+    job_metadata = k8s_client.V1ObjectMeta(
+        name=job_name,
+        labels=labels,
+        owner_references=owner_references,
+    )
+
+    return k8s_client.V1Job(spec=job_spec, metadata=job_metadata)
zenml/logger.py
CHANGED
@@ -39,14 +39,15 @@ ZENML_LOGGING_COLORS_DISABLED = handle_bool_env_var(
 class CustomFormatter(logging.Formatter):
    """Formats logs according to custom specifications."""
 
-    grey: str = "\x1b[
+    grey: str = "\x1b[90m"
+    white: str = "\x1b[37m"
    pink: str = "\x1b[35m"
    green: str = "\x1b[32m"
    yellow: str = "\x1b[33m"
    red: str = "\x1b[31m"
    cyan: str = "\x1b[1;36m"
    bold_red: str = "\x1b[31;1m"
-    purple: str = "\x1b[
+    purple: str = "\x1b[38;5;105m"
    blue: str = "\x1b[34m"
    reset: str = "\x1b[0m"
 
@@ -59,7 +60,7 @@ class CustomFormatter(logging.Formatter):
 
    COLORS: Dict[LoggingLevels, str] = {
        LoggingLevels.DEBUG: grey,
-        LoggingLevels.INFO:
+        LoggingLevels.INFO: white,
        LoggingLevels.WARN: yellow,
        LoggingLevels.ERROR: red,
        LoggingLevels.CRITICAL: bold_red,
@@ -87,12 +88,13 @@ class CustomFormatter(logging.Formatter):
        )
        formatter = logging.Formatter(log_fmt)
        formatted_message = formatter.format(record)
+
        quoted_groups = re.findall("`([^`]*)`", formatted_message)
        for quoted in quoted_groups:
            formatted_message = formatted_message.replace(
                "`" + quoted + "`",
                self.reset
-                + self.
+                + self.purple
                + quoted
                + self.COLORS.get(LoggingLevels(record.levelno)),
            )