zenml-nightly 0.84.1.dev20250805__py3-none-any.whl → 0.84.2.dev20250807__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (21)
  1. zenml/VERSION +1 -1
  2. zenml/integrations/huggingface/__init__.py +1 -1
  3. zenml/integrations/hyperai/__init__.py +1 -1
  4. zenml/integrations/kubernetes/constants.py +27 -0
  5. zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +79 -36
  6. zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +55 -24
  7. zenml/integrations/kubernetes/orchestrators/dag_runner.py +367 -0
  8. zenml/integrations/kubernetes/orchestrators/kube_utils.py +368 -1
  9. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +144 -262
  10. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +392 -244
  11. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +53 -85
  12. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +74 -32
  13. zenml/logging/step_logging.py +33 -30
  14. zenml/steps/base_step.py +6 -6
  15. zenml/steps/step_decorator.py +4 -4
  16. zenml/zen_stores/migrations/versions/0.84.2_release.py +23 -0
  17. {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.2.dev20250807.dist-info}/METADATA +3 -3
  18. {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.2.dev20250807.dist-info}/RECORD +21 -18
  19. {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.2.dev20250807.dist-info}/LICENSE +0 -0
  20. {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.2.dev20250807.dist-info}/WHEEL +0 -0
  21. {zenml_nightly-0.84.1.dev20250805.dist-info → zenml_nightly-0.84.2.dev20250807.dist-info}/entry_points.txt +0 -0
@@ -43,12 +43,16 @@ from kubernetes import client as k8s_client
43
43
  from kubernetes import config as k8s_config
44
44
  from kubernetes.client.rest import ApiException
45
45
 
46
+ from zenml.integrations.kubernetes.constants import (
47
+ STEP_NAME_ANNOTATION_KEY,
48
+ )
46
49
  from zenml.integrations.kubernetes.orchestrators.manifest_utils import (
47
50
  build_namespace_manifest,
48
51
  build_role_binding_manifest_for_service_account,
49
52
  build_secret_manifest,
50
53
  build_service_account_manifest,
51
54
  )
55
+ from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings
52
56
  from zenml.logger import get_logger
53
57
  from zenml.utils.time_utils import utc_now
54
58
 
@@ -57,6 +61,35 @@ logger = get_logger(__name__)
57
61
  R = TypeVar("R")
58
62
 
59
63
 
64
# This is to fix a bug in the kubernetes client which has some wrong
# client-side validations that means the `on_exit_codes` field is
# unusable. See https://github.com/kubernetes-client/python/issues/2056
class PatchedFailurePolicyRule(k8s_client.V1PodFailurePolicyRule):  # type: ignore[misc]
    """Failure policy rule without the broken client-side validation."""

    @property
    def on_pod_conditions(self):  # type: ignore[no-untyped-def]
        """Read the `on_pod_conditions` field.

        Returns:
            The stored pod conditions.
        """
        return self._on_pod_conditions

    @on_pod_conditions.setter
    def on_pod_conditions(self, value):  # type: ignore[no-untyped-def]
        """Store the `on_pod_conditions` field.

        Overrides the upstream setter, which performs the broken
        validation mentioned in the class comment, and stores the value
        as-is instead.

        Args:
            value: The pod conditions to store.
        """
        self._on_pod_conditions = value


# Swap the patched class into both places the client resolves the model
# from, so deserialization picks it up as well.
k8s_client.V1PodFailurePolicyRule = PatchedFailurePolicyRule
k8s_client.models.V1PodFailurePolicyRule = PatchedFailurePolicyRule
91
+
92
+
60
93
  class PodPhase(enum.Enum):
61
94
  """Phase of the Kubernetes pod.
62
95
 
@@ -71,6 +104,14 @@ class PodPhase(enum.Enum):
71
104
  UNKNOWN = "Unknown"
72
105
 
73
106
 
107
class JobStatus(enum.Enum):
    """Possible states of a Kubernetes job."""

    # Job still has active or pending pods.
    RUNNING = "Running"
    # Job reached its `Complete` condition.
    SUCCEEDED = "Succeeded"
    # Job reached its `Failed` condition (or was aborted).
    FAILED = "Failed"
113
+
114
+
74
115
  def is_inside_kubernetes() -> bool:
75
116
  """Check whether we are inside a Kubernetes cluster or on a remote host.
76
117
 
@@ -211,7 +252,9 @@ def get_pod(
211
252
  The found pod object. None if it's not found.
212
253
  """
213
254
  try:
214
- return core_api.read_namespaced_pod(name=pod_name, namespace=namespace)
255
+ return retry_on_api_exception(core_api.read_namespaced_pod)(
256
+ name=pod_name, namespace=namespace
257
+ )
215
258
  except k8s_client.rest.ApiException as e:
216
259
  if e.status == 404:
217
260
  return None
@@ -603,6 +646,7 @@ def retry_on_api_exception(
603
646
  max_retries: int = 3,
604
647
  delay: float = 1,
605
648
  backoff: float = 1,
649
+ fail_on_status_codes: Tuple[int, ...] = (404,),
606
650
  ) -> Callable[..., R]:
607
651
  """Retry a function on API exceptions.
608
652
 
@@ -611,6 +655,7 @@ def retry_on_api_exception(
611
655
  max_retries: The maximum number of retries.
612
656
  delay: The delay between retries.
613
657
  backoff: The backoff factor.
658
+ fail_on_status_codes: The status codes to fail on immediately.
614
659
 
615
660
  Returns:
616
661
  The wrapped function with retry logic.
@@ -624,6 +669,9 @@ def retry_on_api_exception(
624
669
  try:
625
670
  return func(*args, **kwargs)
626
671
  except ApiException as e:
672
+ if e.status in fail_on_status_codes:
673
+ raise
674
+
627
675
  retries += 1
628
676
  if retries <= max_retries:
629
677
  logger.warning("Error calling %s: %s.", func.__name__, e)
@@ -657,6 +705,86 @@ def create_job(
657
705
  )
658
706
 
659
707
 
708
def get_job(
    batch_api: k8s_client.BatchV1Api,
    namespace: str,
    job_name: str,
) -> k8s_client.V1Job:
    """Fetch a job object by name.

    Args:
        batch_api: Kubernetes batch api.
        namespace: Kubernetes namespace.
        job_name: The name of the job to get.

    Returns:
        The job.
    """
    # Wrap the API call so transient errors are retried.
    read_job = retry_on_api_exception(batch_api.read_namespaced_job)
    return read_job(name=job_name, namespace=namespace)
726
+
727
+
728
def list_jobs(
    batch_api: k8s_client.BatchV1Api,
    namespace: str,
    label_selector: Optional[str] = None,
) -> k8s_client.V1JobList:
    """List jobs in a namespace.

    Args:
        batch_api: Kubernetes batch api.
        namespace: Kubernetes namespace.
        label_selector: The label selector to use.

    Returns:
        The job list.
    """
    # Wrap the API call so transient errors are retried.
    list_namespaced_jobs = retry_on_api_exception(
        batch_api.list_namespaced_job
    )
    return list_namespaced_jobs(
        namespace=namespace, label_selector=label_selector
    )
747
+
748
+
749
def update_job(
    batch_api: k8s_client.BatchV1Api,
    namespace: str,
    job_name: str,
    annotations: Dict[str, str],
) -> k8s_client.V1Job:
    """Update the annotations of a job.

    Args:
        batch_api: Kubernetes batch api.
        namespace: Kubernetes namespace.
        job_name: The name of the job to update.
        annotations: The annotations to update.

    Returns:
        The updated job.
    """
    # Only the annotations are patched; everything else is left untouched.
    patch_body = {"metadata": {"annotations": annotations}}
    patch_job = retry_on_api_exception(batch_api.patch_namespaced_job)
    return patch_job(name=job_name, namespace=namespace, body=patch_body)
771
+
772
+
773
def is_step_job(job: k8s_client.V1Job) -> bool:
    """Check if a job is a step job.

    Step jobs are identified by the step-name annotation that the
    orchestrator puts on them at creation time.

    Args:
        job: The job to check.

    Returns:
        Whether the job is a step job.
    """
    metadata = job.metadata
    if metadata and metadata.annotations:
        return STEP_NAME_ANNOTATION_KEY in metadata.annotations

    return False
786
+
787
+
660
788
  def get_container_status(
661
789
  pod: k8s_client.V1Pod, container_name: str
662
790
  ) -> Optional[k8s_client.V1ContainerState]:
@@ -841,3 +969,242 @@ def wait_for_job_to_finish(
841
969
  time.sleep(backoff_interval)
842
970
  if exponential_backoff and backoff_interval < maximum_backoff:
843
971
  backoff_interval *= 2
972
+
973
+
974
def _get_container_failure_reason(
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    job_name: str,
    container_name: str,
) -> Optional[str]:
    """Best-effort lookup of the termination reason of a job's container.

    Args:
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        job_name: Name of the job whose pods to inspect.
        container_name: Name of the container to check.

    Returns:
        A human-readable failure reason, or None if it could not be
        determined or the container exited successfully.
    """
    try:
        pods = core_api.list_namespaced_pod(
            label_selector=f"job-name={job_name}",
            namespace=namespace,
        ).items
        # Sort pods by creation timestamp, oldest first, so that we
        # inspect the most recent pod of the job.
        pods.sort(
            key=lambda pod: pod.metadata.creation_timestamp,
        )
        if pods:
            if termination_reason := get_container_termination_reason(
                pods[-1], container_name
            ):
                exit_code, reason = termination_reason
                if exit_code != 0:
                    return f"{reason}, exit_code={exit_code}"
    except Exception as e:
        # This lookup only enriches the error message, so it must never
        # fail the status check -- but we log instead of swallowing
        # silently so the failure is still diagnosable.
        logger.debug(
            "Failed to get container failure reason for job %s: %s",
            job_name,
            e,
        )

    return None


def check_job_status(
    batch_api: k8s_client.BatchV1Api,
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    job_name: str,
    fail_on_container_waiting_reasons: Optional[List[str]] = None,
    container_name: Optional[str] = None,
) -> Tuple[JobStatus, Optional[str]]:
    """Check the status of a job.

    Args:
        batch_api: Kubernetes BatchV1Api client.
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        job_name: Name of the job for which to wait.
        fail_on_container_waiting_reasons: List of container waiting reasons
            that will cause the job to fail.
        container_name: Name of the container to check for failure. Defaults
            to "main" if not given.

    Returns:
        The status of the job and an error message if the job failed.
    """
    job: k8s_client.V1Job = retry_on_api_exception(
        batch_api.read_namespaced_job
    )(name=job_name, namespace=namespace)

    if job.status.conditions:
        for condition in job.status.conditions:
            if condition.type == "Complete" and condition.status == "True":
                return JobStatus.SUCCEEDED, None
            if condition.type == "Failed" and condition.status == "True":
                error_message = condition.message or "Unknown"
                # Try to enrich the message with the actual container
                # termination reason (best-effort).
                container_failure_reason = _get_container_failure_reason(
                    core_api=core_api,
                    namespace=namespace,
                    job_name=job_name,
                    container_name=container_name or "main",
                )
                if container_failure_reason:
                    error_message += f" (container failure reason: {container_failure_reason})"

                return JobStatus.FAILED, error_message

    if fail_on_container_waiting_reasons:
        # Inspect pending pods of the job for containers stuck in one of
        # the fatal waiting states (e.g. `ImagePullBackOff`).
        pod_list: k8s_client.V1PodList = retry_on_api_exception(
            core_api.list_namespaced_pod
        )(
            namespace=namespace,
            label_selector=f"job-name={job_name}",
            field_selector="status.phase=Pending",
        )
        for pod in pod_list.items:
            container_state = get_container_status(
                pod, container_name or "main"
            )

            if (
                container_state
                and (waiting_state := container_state.waiting)
                and waiting_state.reason in fail_on_container_waiting_reasons
            ):
                # Delete the job so the stuck pod doesn't linger;
                # `Foreground` waits for dependents to be removed first.
                retry_on_api_exception(batch_api.delete_namespaced_job)(
                    name=job_name,
                    namespace=namespace,
                    propagation_policy="Foreground",
                )
                error_message = (
                    f"Detected container in state `{waiting_state.reason}`"
                )
                return JobStatus.FAILED, error_message

    return JobStatus.RUNNING, None
1065
+
1066
+
1067
def create_config_map(
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    name: str,
    data: Dict[str, str],
) -> None:
    """Create a Kubernetes config map.

    Args:
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        name: Name of the config map to create.
        data: Data to store in the config map.
    """
    config_map = k8s_client.V1ConfigMap(
        # Use the typed metadata model instead of a raw dict: the generated
        # client declares `metadata` as `V1ObjectMeta`, so this keeps the
        # object type-correct and serialized consistently.
        metadata=k8s_client.V1ObjectMeta(name=name),
        data=data,
    )
    retry_on_api_exception(core_api.create_namespaced_config_map)(
        namespace=namespace,
        body=config_map,
    )
1085
+
1086
+
1087
def update_config_map(
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    name: str,
    data: Dict[str, str],
) -> None:
    """Patch the data of an existing Kubernetes config map.

    Args:
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        name: Name of the config map to update.
        data: Data to store in the config map.
    """
    # Wrap the API call so transient errors are retried.
    patch_config_map = retry_on_api_exception(
        core_api.patch_namespaced_config_map
    )
    patch_config_map(
        namespace=namespace,
        name=name,
        body=k8s_client.V1ConfigMap(data=data),
    )
1106
+
1107
+
1108
def get_config_map(
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    name: str,
) -> k8s_client.V1ConfigMap:
    """Fetch a Kubernetes config map by name.

    Args:
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        name: Name of the config map to get.

    Returns:
        The config map.
    """
    # Wrap the API call so transient errors are retried.
    read_config_map = retry_on_api_exception(
        core_api.read_namespaced_config_map
    )
    return read_config_map(namespace=namespace, name=name)
1127
+
1128
+
1129
def delete_config_map(
    core_api: k8s_client.CoreV1Api,
    namespace: str,
    name: str,
) -> None:
    """Delete a Kubernetes config map.

    Args:
        core_api: Kubernetes CoreV1Api client.
        namespace: Kubernetes namespace.
        name: Name of the config map to delete.
    """
    # Wrap the API call so transient errors are retried.
    delete_namespaced_config_map = retry_on_api_exception(
        core_api.delete_namespaced_config_map
    )
    delete_namespaced_config_map(namespace=namespace, name=name)
1145
+
1146
+
1147
def get_parent_job_name(
    core_api: k8s_client.CoreV1Api,
    pod_name: str,
    namespace: str,
) -> Optional[str]:
    """Get the name of the job that created a pod.

    Args:
        core_api: Kubernetes CoreV1Api client.
        pod_name: Name of the pod.
        namespace: Kubernetes namespace.

    Returns:
        The name of the job that created the pod, or None if the pod is not
        associated with a job.
    """
    pod = get_pod(core_api, pod_name=pod_name, namespace=namespace)
    if pod is None or not pod.metadata or not pod.metadata.labels:
        return None

    # The job controller attaches a `job-name` label to every pod it owns.
    job_name = pod.metadata.labels.get("job-name", None)
    if job_name:
        return cast(str, job_name)

    return None
1173
+
1174
+
1175
def apply_default_resource_requests(
    memory: str,
    cpu: Optional[str] = None,
    pod_settings: Optional[KubernetesPodSettings] = None,
) -> KubernetesPodSettings:
    """Applies default resource requests to a pod settings object.

    Explicit requests already present in the pod settings always take
    precedence over the defaults given here.

    Args:
        memory: The memory resource request.
        cpu: The CPU resource request.
        pod_settings: The pod settings to update. A new one will be created
            if not provided.

    Returns:
        The new or updated pod settings.
    """
    default_requests = {"memory": memory}
    if cpu:
        default_requests["cpu"] = cpu

    if not pod_settings:
        return KubernetesPodSettings(
            resources={"requests": default_requests}
        )

    if not pod_settings.resources:
        # We can't update the pod settings in place (because it's a frozen
        # pydantic model), so we have to create a new one.
        return KubernetesPodSettings(
            **pod_settings.model_dump(exclude_unset=True),
            resources={"requests": default_requests},
        )

    # Merge: values the user already set win over the defaults.
    existing_requests = pod_settings.resources.get("requests", {})
    default_requests.update(existing_requests)
    pod_settings.resources["requests"] = default_requests
    return pod_settings