apache-airflow-providers-cncf-kubernetes 7.13.0rc1__py3-none-any.whl → 7.14.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (18)
  1. airflow/providers/cncf/kubernetes/__init__.py +1 -1
  2. airflow/providers/cncf/kubernetes/callbacks.py +111 -0
  3. airflow/providers/cncf/kubernetes/get_provider_info.py +4 -2
  4. airflow/providers/cncf/kubernetes/hooks/kubernetes.py +4 -4
  5. airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py +367 -0
  6. airflow/providers/cncf/kubernetes/operators/pod.py +74 -13
  7. airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +221 -136
  8. airflow/providers/cncf/kubernetes/pod_generator.py +13 -6
  9. airflow/providers/cncf/kubernetes/pod_launcher_deprecated.py +3 -3
  10. airflow/providers/cncf/kubernetes/resource_convert/__init__.py +16 -0
  11. airflow/providers/cncf/kubernetes/resource_convert/configmap.py +52 -0
  12. airflow/providers/cncf/kubernetes/resource_convert/env_variable.py +39 -0
  13. airflow/providers/cncf/kubernetes/resource_convert/secret.py +40 -0
  14. airflow/providers/cncf/kubernetes/utils/pod_manager.py +18 -4
  15. {apache_airflow_providers_cncf_kubernetes-7.13.0rc1.dist-info → apache_airflow_providers_cncf_kubernetes-7.14.0rc1.dist-info}/METADATA +6 -6
  16. {apache_airflow_providers_cncf_kubernetes-7.13.0rc1.dist-info → apache_airflow_providers_cncf_kubernetes-7.14.0rc1.dist-info}/RECORD +18 -12
  17. {apache_airflow_providers_cncf_kubernetes-7.13.0rc1.dist-info → apache_airflow_providers_cncf_kubernetes-7.14.0rc1.dist-info}/WHEEL +0 -0
  18. {apache_airflow_providers_cncf_kubernetes-7.13.0rc1.dist-info → apache_airflow_providers_cncf_kubernetes-7.14.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/cncf/kubernetes/operators/pod.py

@@ -29,6 +29,7 @@ from contextlib import AbstractContextManager
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
 
+import kubernetes
 from kubernetes.client import CoreV1Api, V1Pod, models as k8s
 from kubernetes.stream import stream
 from urllib3.exceptions import HTTPError
@@ -48,6 +49,7 @@ from airflow.providers.cncf.kubernetes.backcompat.backwards_compat_converters im
     convert_volume,
     convert_volume_mount,
 )
+from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
     POD_NAME_MAX_LENGTH,
@@ -195,9 +197,12 @@ class KubernetesPodOperator(BaseOperator):
         Deprecated - use `on_finish_action` instead.
     :param termination_message_policy: The termination message policy of the base container.
         Default value is "File"
-    :param active_deadline_seconds: The active_deadline_seconds which matches to active_deadline_seconds
+    :param active_deadline_seconds: The active_deadline_seconds which translates to active_deadline_seconds
         in V1PodSpec.
+    :param callbacks: KubernetesPodOperatorCallback instance which contains the callback methods invoked at
+        different steps of the KubernetesPodOperator lifecycle.
     :param progress_callback: Callback function for receiving k8s container logs.
+        `progress_callback` is deprecated, please use `callbacks` instead.
     """
 
     # !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!!
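
Both new elements are wired through the hunks that follow: `callbacks` expects the callback class itself rather than an instance (see the `type[KubernetesPodOperatorCallback]` annotation in the constructor change below). A minimal sketch of how a DAG might use it; the task arguments are illustrative, and the exact hook signatures should be checked against the new callbacks.py shipped in this version:

    from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback
    from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

    class AuditCallback(KubernetesPodOperatorCallback):
        # Hook names mirror the call sites added in this diff: on_sync_client_creation,
        # on_pod_creation, on_pod_starting, on_pod_completion, on_operator_resuming,
        # and on_pod_cleanup.
        @staticmethod
        def on_pod_creation(*, pod, client, mode, **kwargs) -> None:
            print(f"pod created: {pod.metadata.name}")

        @staticmethod
        def on_pod_cleanup(*, pod, client, mode, **kwargs) -> None:
            print(f"pod cleaned up: {pod.metadata.name}")

    hello = KubernetesPodOperator(
        task_id="hello",
        image="alpine:3.18",
        cmds=["echo", "hello"],
        callbacks=AuditCallback,  # pass the class, not an instance
    )
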
@@ -225,6 +230,7 @@ class KubernetesPodOperator(BaseOperator):
         "volumes",
         "volume_mounts",
         "cluster_context",
+        "configmaps",
     )
     template_fields_renderers = {"env_vars": "py"}
 
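
Adding "configmaps" to template_fields means the ConfigMap names are now rendered with Jinja at runtime. A small sketch of what that enables; the ConfigMap name is an illustrative assumption:

    inject_env = KubernetesPodOperator(
        task_id="env_from_configmap",
        image="alpine:3.18",
        cmds=["env"],
        # rendered per run now that "configmaps" is a templated field
        configmaps=["app-config-{{ ds_nodash }}"],
    )
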
@@ -289,6 +295,7 @@ class KubernetesPodOperator(BaseOperator):
         is_delete_operator_pod: None | bool = None,
         termination_message_policy: str = "File",
         active_deadline_seconds: int | None = None,
+        callbacks: type[KubernetesPodOperatorCallback] | None = None,
         progress_callback: Callable[[str], None] | None = None,
         **kwargs,
     ) -> None:
@@ -306,8 +313,9 @@
         if pod_runtime_info_envs:
             self.env_vars.extend([convert_pod_runtime_info_env(p) for p in pod_runtime_info_envs])
         self.env_from = env_from or []
-        if configmaps:
-            self.env_from.extend([convert_configmap(c) for c in configmaps])
+        self.configmaps = configmaps
+        if self.configmaps:
+            self.env_from.extend([convert_configmap(c) for c in self.configmaps])
         self.ports = [convert_port(p) for p in ports] if ports else []
         self.volume_mounts = [convert_volume_mount(v) for v in volume_mounts] if volume_mounts else []
         self.volumes = [convert_volume(volume) for volume in volumes] if volumes else []
@@ -380,6 +388,8 @@
 
         self._config_dict: dict | None = None  # TODO: remove it when removing convert_config_file_to_dict
         self._progress_callback = progress_callback
+        self.callbacks = callbacks
+        self._killed: bool = False
 
     @cached_property
     def _incluster_namespace(self):
@@ -403,11 +413,13 @@
             elif isinstance(content, k8s.V1ResourceRequirements):
                 template_fields = ("limits", "requests")
             elif isinstance(content, k8s.V1Volume):
-                template_fields = ("name", "persistent_volume_claim")
+                template_fields = ("name", "persistent_volume_claim", "config_map")
             elif isinstance(content, k8s.V1VolumeMount):
                 template_fields = ("name", "sub_path")
             elif isinstance(content, k8s.V1PersistentVolumeClaimVolumeSource):
                 template_fields = ("claim_name",)
+            elif isinstance(content, k8s.V1ConfigMapVolumeSource):
+                template_fields = ("name",)
             else:
                 template_fields = None
 
@@ -457,7 +469,9 @@
 
     @cached_property
     def pod_manager(self) -> PodManager:
-        return PodManager(kube_client=self.client, progress_callback=self._progress_callback)
+        return PodManager(
+            kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback
+        )
 
     @cached_property
     def hook(self) -> PodOperatorHookProtocol:
@@ -471,7 +485,10 @@
 
     @cached_property
     def client(self) -> CoreV1Api:
-        return self.hook.core_v1_client
+        client = self.hook.core_v1_client
+        if self.callbacks:
+            self.callbacks.on_sync_client_creation(client=client)
+        return client
 
     def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
         """Return an already-running pod for this task instance if one exists."""
@@ -536,11 +553,13 @@
     def execute_sync(self, context: Context):
         result = None
         try:
-            self.pod_request_obj = self.build_pod_request_obj(context)
-            self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
-                pod_request_obj=self.pod_request_obj,
-                context=context,
-            )
+            if self.pod_request_obj is None:
+                self.pod_request_obj = self.build_pod_request_obj(context)
+            if self.pod is None:
+                self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
+                    pod_request_obj=self.pod_request_obj,
+                    context=context,
+                )
             # push to xcom now so that if there is an error we still have the values
             ti = context["ti"]
             ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
@@ -548,7 +567,17 @@
 
             # get remote pod for use in cleanup methods
             self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
+            if self.callbacks:
+                self.callbacks.on_pod_creation(
+                    pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
+                )
             self.await_pod_start(pod=self.pod)
+            if self.callbacks:
+                self.callbacks.on_pod_starting(
+                    pod=self.find_pod(self.pod.metadata.namespace, context=context),
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                )
 
             if self.get_logs:
                 self.pod_manager.fetch_requested_container_logs(
@@ -562,6 +591,12 @@
                 self.pod_manager.await_container_completion(
                     pod=self.pod, container_name=self.base_container_name
                 )
+                if self.callbacks:
+                    self.callbacks.on_pod_completion(
+                        pod=self.find_pod(self.pod.metadata.namespace, context=context),
+                        client=self.client,
+                        mode=ExecutionMode.SYNC,
+                    )
 
             if self.do_xcom_push:
                 self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
@@ -571,10 +606,14 @@
                 self.pod, istio_enabled, self.base_container_name
             )
         finally:
+            pod_to_clean = self.pod or self.pod_request_obj
             self.cleanup(
-                pod=self.pod or self.pod_request_obj,
+                pod=pod_to_clean,
                 remote_pod=self.remote_pod,
             )
+            if self.callbacks:
+                self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)
+
         if self.do_xcom_push:
             return result
 
@@ -584,6 +623,12 @@
             pod_request_obj=self.pod_request_obj,
             context=context,
         )
+        if self.callbacks:
+            self.callbacks.on_pod_creation(
+                pod=self.find_pod(self.pod.metadata.namespace, context=context),
+                client=self.client,
+                mode=ExecutionMode.SYNC,
+            )
         ti = context["ti"]
         ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
         ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
@@ -620,6 +665,10 @@
                 event["name"],
                 event["namespace"],
             )
+            if self.callbacks:
+                self.callbacks.on_operator_resuming(
+                    pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC
+                )
             if event["status"] in ("error", "failed", "timeout"):
                 # fetch some logs when pod is failed
                 if self.get_logs:
@@ -672,8 +721,15 @@
             pod=pod,
             remote_pod=remote_pod,
         )
+        if self.callbacks:
+            self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC)
 
     def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
+        # If a task got marked as failed, "on_kill" method would be called and the pod will be cleaned up
+        # there. Cleaning it up again will raise an exception (which might cause retry).
+        if self._killed:
+            return
+
         istio_enabled = self.is_istio_enabled(remote_pod)
         pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None
 
@@ -816,6 +872,7 @@
         )
 
     def on_kill(self) -> None:
+        self._killed = True
         if self.pod:
             pod = self.pod
             kwargs = {
@@ -824,7 +881,11 @@
             }
             if self.termination_grace_period is not None:
                 kwargs.update(grace_period_seconds=self.termination_grace_period)
-            self.client.delete_namespaced_pod(**kwargs)
+
+            try:
+                self.client.delete_namespaced_pod(**kwargs)
+            except kubernetes.client.exceptions.ApiException:
+                self.log.exception("Unable to delete pod %s", self.pod.metadata.name)
 
     def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod:
         """
airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py

@@ -17,179 +17,264 @@
 # under the License.
 from __future__ import annotations
 
-import datetime
+import re
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any
 
-from kubernetes.client import ApiException
-from kubernetes.watch import Watch
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
+from airflow.providers.cncf.kubernetes import pod_generator
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
+from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
+from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
+from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
-    from kubernetes.client.models import CoreV1EventList
+    import jinja2
 
     from airflow.utils.context import Context
 
 
-class SparkKubernetesOperator(BaseOperator):
+class SparkKubernetesOperator(KubernetesPodOperator):
     """
     Creates sparkApplication object in kubernetes cluster.
 
     .. seealso::
         For more detail about Spark Application Object have a look at the reference:
-        https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.1.0-2.4.5/docs/api-docs.md#sparkapplication
+        https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.3.3-3.1.1/docs/api-docs.md#sparkapplication
 
-    :param application_file: Defines Kubernetes 'custom_resource_definition' of 'sparkApplication' as either a
-        path to a '.yaml' file, '.json' file, YAML string or python dictionary.
+    :param application_file: filepath to kubernetes custom_resource_definition of sparkApplication
+    :param kubernetes_conn_id: the connection to Kubernetes cluster
+    :param image: Docker image you wish to launch. Defaults to hub.docker.com,
+    :param code_path: path to the spark code in image,
     :param namespace: kubernetes namespace to put sparkApplication
-    :param kubernetes_conn_id: The :ref:`kubernetes connection id <howto/connection:kubernetes>`
-        for the to Kubernetes cluster.
-    :param api_group: kubernetes api group of sparkApplication
-    :param api_version: kubernetes api version of sparkApplication
-    :param watch: whether to watch the job status and logs or not
+    :param cluster_context: context of the cluster
+    :param application_file: yaml file if passed
+    :param get_logs: get the stdout of the container as logs of the tasks.
+    :param do_xcom_push: If True, the content of the file
+        /airflow/xcom/return.json in the container will also be pushed to an
+        XCom when the container completes.
+    :param success_run_history_limit: Number of past successful runs of the application to keep.
+    :param delete_on_termination: What to do when the pod reaches its final
+        state, or the execution is interrupted. If True (default), delete the
+        pod; if False, leave the pod.
+    :param startup_timeout_seconds: timeout in seconds to startup the pod.
+    :param log_events_on_failure: Log the pod's events if a failure occurs
+    :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor
     """
 
-    template_fields: Sequence[str] = ("application_file", "namespace")
-    template_ext: Sequence[str] = (".yaml", ".yml", ".json")
+    template_fields = ["application_file", "namespace", "template_spec"]
+    template_fields_renderers = {"template_spec": "py"}
+    template_ext = ("yaml", "yml", "json")
     ui_color = "#f4a460"
 
     def __init__(
         self,
         *,
-        application_file: str | dict,
-        namespace: str | None = None,
+        image: str | None = None,
+        code_path: str | None = None,
+        namespace: str = "default",
+        name: str = "default",
+        application_file: str | None = None,
+        template_spec=None,
+        get_logs: bool = True,
+        do_xcom_push: bool = False,
+        success_run_history_limit: int = 1,
+        startup_timeout_seconds=600,
+        log_events_on_failure: bool = False,
+        reattach_on_restart: bool = True,
+        delete_on_termination: bool = True,
         kubernetes_conn_id: str = "kubernetes_default",
-        api_group: str = "sparkoperator.k8s.io",
-        api_version: str = "v1beta2",
-        in_cluster: bool | None = None,
-        cluster_context: str | None = None,
-        config_file: str | None = None,
-        watch: bool = False,
         **kwargs,
     ) -> None:
-        super().__init__(**kwargs)
-        self.namespace = namespace
-        self.kubernetes_conn_id = kubernetes_conn_id
-        self.api_group = api_group
-        self.api_version = api_version
-        self.plural = "sparkapplications"
+        if kwargs.get("xcom_push") is not None:
+            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
+        super().__init__(name=name, **kwargs)
+        self.image = image
+        self.code_path = code_path
         self.application_file = application_file
-        self.in_cluster = in_cluster
-        self.cluster_context = cluster_context
-        self.config_file = config_file
-        self.watch = watch
+        self.template_spec = template_spec
+        self.name = self.create_job_name()
+        self.kubernetes_conn_id = kubernetes_conn_id
+        self.startup_timeout_seconds = startup_timeout_seconds
+        self.reattach_on_restart = reattach_on_restart
+        self.delete_on_termination = delete_on_termination
+        self.do_xcom_push = do_xcom_push
+        self.namespace = namespace
+        self.get_logs = get_logs
+        self.log_events_on_failure = log_events_on_failure
+        self.success_run_history_limit = success_run_history_limit
+        self.template_body = self.manage_template_specs()
+
+    def _render_nested_template_fields(
+        self,
+        content: Any,
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set,
+    ) -> None:
+        if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
+            seen_oids.add(id(content))
+            self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids)
+            return
+
+        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
+
+    def manage_template_specs(self):
+        if self.application_file:
+            template_body = _load_body_to_dict(open(self.application_file))
+        elif self.template_spec:
+            template_body = self.template_spec
+        else:
+            raise AirflowException("either application_file or template_spec should be passed")
+        if "spark" not in template_body:
+            template_body = {"spark": template_body}
+        return template_body
+
+    def create_job_name(self):
+        initial_name = PodGenerator.make_unique_pod_id(self.task_id)[:MAX_LABEL_LEN]
+        return re.sub(r"[^a-z0-9-]+", "-", initial_name.lower())
+
+    @staticmethod
+    def _get_pod_identifying_label_string(labels) -> str:
+        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
+        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
+
+    @staticmethod
+    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+        """
+        Generate labels for the pod to track the pod in case of Operator crash.
+
+        :param include_try_number: add try number to labels
+        :param context: task context provided by airflow DAG
+        :return: dict.
+        """
+        if not context:
+            return {}
+
+        ti = context["ti"]
+        run_id = context["run_id"]
+
+        labels = {
+            "dag_id": ti.dag_id,
+            "task_id": ti.task_id,
+            "run_id": run_id,
+            "spark_kubernetes_operator": "True",
+            # 'execution_date': context['ts'],
+            # 'try_number': context['ti'].try_number,
+        }
+
+        # If running on Airflow 2.3+:
+        map_index = getattr(ti, "map_index", -1)
+        if map_index >= 0:
+            labels["map_index"] = map_index
+
+        if include_try_number:
+            labels.update(try_number=ti.try_number)
+
+        # In the case of sub dags this is just useful
+        if context["dag"].is_subdag:
+            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        # Ensure that label is valid for Kube,
+        # and if not truncate/remove invalid chars and replace with short hash.
+        for label_id, label in labels.items():
+            safe_label = pod_generator.make_safe_label_value(str(label))
+            labels[label_id] = safe_label
+        return labels
+
+    @cached_property
+    def pod_manager(self) -> PodManager:
+        return PodManager(kube_client=self.client)
+
+    @staticmethod
+    def _try_numbers_match(context, pod) -> bool:
+        return pod.metadata.labels["try_number"] == context["ti"].try_number
+
+    def find_spark_job(self, context):
+        labels = self.create_labels_for_pod(context, include_try_number=False)
+        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
+
+        pod = None
+        if len(pod_list) > 1:  # and self.reattach_on_restart:
+            raise AirflowException(f"More than one pod running with labels: {label_selector}")
+        elif len(pod_list) == 1:
+            pod = pod_list[0]
+            self.log.info(
+                "Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels
+            )
+            self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
+            self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
+        return pod
+
+    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
+        if self.reattach_on_restart:
+            driver_pod = self.find_spark_job(context)
+            if driver_pod:
+                return driver_pod
+
+        driver_pod, spark_obj_spec = launcher.start_spark_job(
+            image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds
+        )
+        return driver_pod
+
+    def process_pod_deletion(self, pod, *, reraise=True):
+        if pod is not None:
+            if self.delete_on_termination:
+                self.log.info("Deleting spark job: %s", pod.metadata.name.replace("-driver", ""))
+                self.launcher.delete_spark_job(pod.metadata.name.replace("-driver", ""))
+            else:
+                self.log.info("skipping deleting spark job: %s", pod.metadata.name)
 
     @cached_property
     def hook(self) -> KubernetesHook:
-        return KubernetesHook(
+        hook = KubernetesHook(
             conn_id=self.kubernetes_conn_id,
-            in_cluster=self.in_cluster,
-            config_file=self.config_file,
-            cluster_context=self.cluster_context,
+            in_cluster=self.in_cluster or self.template_body.get("kubernetes", {}).get("in_cluster", False),
+            config_file=self.config_file
+            or self.template_body.get("kubernetes", {}).get("kube_config_file", None),
+            cluster_context=self.cluster_context
+            or self.template_body.get("kubernetes", {}).get("cluster_context", None),
         )
+        return hook
 
-    def _get_namespace_event_stream(self, namespace, query_kwargs=None):
-        try:
-            return Watch().stream(
-                self.hook.core_v1_client.list_namespaced_event,
-                namespace=namespace,
-                watch=True,
-                **(query_kwargs or {}),
-            )
-        except ApiException as e:
-            if e.status == 410:  # Resource version is too old
-                events: CoreV1EventList = self.hook.core_v1_client.list_namespaced_event(
-                    namespace=namespace, watch=False
-                )
-                resource_version = events.metadata.resource_version
-                query_kwargs["resource_version"] = resource_version
-                return self._get_namespace_event_stream(namespace, query_kwargs)
-            else:
-                raise
+    @cached_property
+    def client(self) -> CoreV1Api:
+        return self.hook.core_v1_client
 
-    def execute(self, context: Context):
-        if isinstance(self.application_file, str):
-            body = _load_body_to_dict(self.application_file)
-        else:
-            body = self.application_file
-        name = body["metadata"]["name"]
-        namespace = self.namespace or self.hook.get_namespace()
-
-        response = None
-        is_job_created = False
-        if self.watch:
-            try:
-                namespace_event_stream = self._get_namespace_event_stream(
-                    namespace=namespace,
-                    query_kwargs={
-                        "field_selector": f"involvedObject.kind=SparkApplication,involvedObject.name={name}"
-                    },
-                )
-
-                response = self.hook.create_custom_object(
-                    group=self.api_group,
-                    version=self.api_version,
-                    plural=self.plural,
-                    body=body,
-                    namespace=namespace,
-                )
-                is_job_created = True
-                for event in namespace_event_stream:
-                    obj = event["object"]
-                    if event["object"].last_timestamp >= datetime.datetime.strptime(
-                        response["metadata"]["creationTimestamp"], "%Y-%m-%dT%H:%M:%S%z"
-                    ):
-                        self.log.info(obj.message)
-                        if obj.reason == "SparkDriverRunning":
-                            pod_log_stream = Watch().stream(
-                                self.hook.core_v1_client.read_namespaced_pod_log,
-                                name=f"{name}-driver",
-                                namespace=namespace,
-                                timestamps=True,
-                            )
-                            for line in pod_log_stream:
-                                self.log.info(line)
-                        elif obj.reason in [
-                            "SparkApplicationSubmissionFailed",
-                            "SparkApplicationFailed",
-                            "SparkApplicationDeleted",
-                        ]:
-                            is_job_created = False
-                            raise AirflowException(obj.message)
-                        elif obj.reason == "SparkApplicationCompleted":
-                            break
-                        else:
-                            continue
-            except Exception:
-                if is_job_created:
-                    self.on_kill()
-                raise
+    @cached_property
+    def custom_obj_api(self) -> CustomObjectsApi:
+        return CustomObjectsApi()
 
-        else:
-            response = self.hook.create_custom_object(
-                group=self.api_group,
-                version=self.api_version,
-                plural=self.plural,
-                body=body,
-                namespace=namespace,
-            )
+    def execute(self, context: Context):
+        self.log.info("Creating sparkApplication.")
+        self.launcher = CustomObjectLauncher(
+            name=self.name,
+            namespace=self.namespace,
+            kube_client=self.client,
+            custom_obj_api=self.custom_obj_api,
+            template_body=self.template_body,
+        )
+        self.pod = self.get_or_create_spark_crd(self.launcher, context)
+        self.BASE_CONTAINER_NAME = "spark-kubernetes-driver"
+        self.pod_request_obj = self.launcher.pod_spec
 
-        return response
+        return super().execute(context=context)
 
     def on_kill(self) -> None:
-        if isinstance(self.application_file, str):
-            body = _load_body_to_dict(self.application_file)
-        else:
-            body = self.application_file
-        name = body["metadata"]["name"]
-        namespace = self.namespace or self.hook.get_namespace()
-        self.hook.delete_custom_object(
-            group=self.api_group,
-            version=self.api_version,
-            plural=self.plural,
-            namespace=namespace,
-            name=name,
-        )
+        if self.launcher:
+            self.log.debug("Deleting spark job for task %s", self.task_id)
+            self.launcher.delete_spark_job()
+
+    def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
+        """Add an "already checked" annotation to ensure we don't reattach on retries."""
+        pod.metadata.labels["already_checked"] = "True"
+        body = PodGenerator.serialize_pod(pod)
+        self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)
+
+    def dry_run(self) -> None:
+        """Prints out the spark job that would be created by this operator."""
+        print(prune_dict(self.launcher.body, mode="strict"))
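
Taken together, the rewritten SparkKubernetesOperator is driven either by an application_file on disk or by an inline template_spec, and manage_template_specs wraps whichever is supplied under a top-level "spark" key. A minimal usage sketch; the image, paths, and spec fields are illustrative assumptions rather than values prescribed by this release:

    from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator

    # Variant 1: a SparkApplication manifest on disk (path is hypothetical).
    submit_from_file = SparkKubernetesOperator(
        task_id="spark_pi_from_file",
        namespace="spark-jobs",
        application_file="manifests/spark-pi.yaml",
        get_logs=True,
    )

    # Variant 2: an inline spec; it is wrapped as {"spark": {...}} internally.
    submit_inline = SparkKubernetesOperator(
        task_id="spark_pi_inline",
        namespace="spark-jobs",
        image="my-registry/spark:3.4",  # hypothetical image
        code_path="local:///opt/app/pi.py",  # hypothetical code path
        template_spec={"spec": {"driver": {}, "executor": {}}},  # skeleton only
    )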