apache-airflow-providers-cncf-kubernetes 10.1.0rc2__py3-none-any.whl → 10.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. airflow/providers/cncf/kubernetes/LICENSE +0 -52
  2. airflow/providers/cncf/kubernetes/__init__.py +1 -1
  3. airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py +2 -3
  4. airflow/providers/cncf/kubernetes/callbacks.py +90 -8
  5. airflow/providers/cncf/kubernetes/cli/kubernetes_command.py +3 -4
  6. airflow/providers/cncf/kubernetes/decorators/kubernetes.py +10 -5
  7. airflow/providers/cncf/kubernetes/exceptions.py +29 -0
  8. airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py +36 -113
  9. airflow/providers/cncf/kubernetes/executors/kubernetes_executor_utils.py +27 -15
  10. airflow/providers/cncf/kubernetes/get_provider_info.py +14 -21
  11. airflow/providers/cncf/kubernetes/hooks/kubernetes.py +20 -10
  12. airflow/providers/cncf/kubernetes/kube_config.py +0 -4
  13. airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py +1 -1
  14. airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py +3 -3
  15. airflow/providers/cncf/kubernetes/operators/job.py +4 -4
  16. airflow/providers/cncf/kubernetes/operators/kueue.py +2 -2
  17. airflow/providers/cncf/kubernetes/operators/pod.py +102 -44
  18. airflow/providers/cncf/kubernetes/operators/resource.py +1 -1
  19. airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +23 -19
  20. airflow/providers/cncf/kubernetes/pod_generator.py +51 -21
  21. airflow/providers/cncf/kubernetes/resource_convert/env_variable.py +1 -2
  22. airflow/providers/cncf/kubernetes/secret.py +1 -2
  23. airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py +1 -2
  24. airflow/providers/cncf/kubernetes/template_rendering.py +10 -2
  25. airflow/providers/cncf/kubernetes/utils/k8s_resource_iterator.py +1 -2
  26. airflow/providers/cncf/kubernetes/utils/pod_manager.py +12 -11
  27. {apache_airflow_providers_cncf_kubernetes-10.1.0rc2.dist-info → apache_airflow_providers_cncf_kubernetes-10.3.0.dist-info}/METADATA +10 -27
  28. {apache_airflow_providers_cncf_kubernetes-10.1.0rc2.dist-info → apache_airflow_providers_cncf_kubernetes-10.3.0.dist-info}/RECORD +30 -30
  29. airflow/providers/cncf/kubernetes/pod_generator_deprecated.py +0 -309
  30. {apache_airflow_providers_cncf_kubernetes-10.1.0rc2.dist-info → apache_airflow_providers_cncf_kubernetes-10.3.0.dist-info}/WHEEL +0 -0
  31. {apache_airflow_providers_cncf_kubernetes-10.1.0rc2.dist-info → apache_airflow_providers_cncf_kubernetes-10.3.0.dist-info}/entry_points.txt +0 -0
airflow/providers/cncf/kubernetes/get_provider_info.py
@@ -15,8 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE
-# OVERWRITTEN WHEN PREPARING PACKAGES.
+# NOTE! THIS FILE IS AUTOMATICALLY GENERATED AND WILL BE OVERWRITTEN!
 #
 # IF YOU WANT TO MODIFY THIS FILE, YOU SHOULD MODIFY THE TEMPLATE
 # `get_provider_info_TEMPLATE.py.jinja2` IN the `dev/breeze/src/airflow_breeze/templates` DIRECTORY
@@ -28,8 +27,9 @@ def get_provider_info():
         "name": "Kubernetes",
         "description": "`Kubernetes <https://kubernetes.io/>`__\n",
         "state": "ready",
-        "source-date-epoch": 1734537609,
+        "source-date-epoch": 1739959070,
         "versions": [
+            "10.3.0",
             "10.1.0",
             "10.0.1",
             "10.0.0",
@@ -100,27 +100,18 @@ def get_provider_info():
             "1.0.1",
             "1.0.0",
         ],
-        "dependencies": [
-            "aiofiles>=23.2.0",
-            "apache-airflow>=2.9.0",
-            "asgiref>=3.5.2",
-            "cryptography>=41.0.0",
-            "kubernetes>=29.0.0,<=31.0.0",
-            "kubernetes_asyncio>=29.0.0,<=31.0.0",
-            "google-re2>=1.0",
-        ],
         "integrations": [
             {
                 "integration-name": "Kubernetes",
                 "external-doc-url": "https://kubernetes.io/",
                 "how-to-guide": ["/docs/apache-airflow-providers-cncf-kubernetes/operators.rst"],
-                "logo": "/integration-logos/kubernetes/Kubernetes.png",
+                "logo": "/docs/integration-logos/Kubernetes.png",
                 "tags": ["software"],
             },
             {
                 "integration-name": "Spark on Kubernetes",
                 "external-doc-url": "https://github.com/GoogleCloudPlatform/spark-on-k8s-operator",
-                "logo": "/integration-logos/kubernetes/Spark-On-Kubernetes.png",
+                "logo": "/docs/integration-logos/Spark-On-Kubernetes.png",
                 "tags": ["software"],
             },
         ],
@@ -341,13 +332,6 @@ def get_provider_info():
                 "example": None,
                 "default": "True",
             },
-            "worker_pods_queued_check_interval": {
-                "description": 'How often in seconds to check for task instances stuck in "queued" status without a pod\n',
-                "version_added": None,
-                "type": "integer",
-                "example": None,
-                "default": "60",
-            },
             "ssl_ca_cert": {
                 "description": "Path to a CA certificate to be used by the Kubernetes client to verify the server's SSL certificate.\n",
                 "version_added": None,
@@ -366,4 +350,13 @@ def get_provider_info():
             },
         },
         "executors": ["airflow.providers.cncf.kubernetes.kubernetes_executor.KubernetesExecutor"],
+        "dependencies": [
+            "aiofiles>=23.2.0",
+            "apache-airflow>=2.9.0",
+            "asgiref>=3.5.2",
+            "cryptography>=41.0.0",
+            "kubernetes>=29.0.0,<=31.0.0",
+            "kubernetes_asyncio>=29.0.0,<=31.0.0",
+            "google-re2>=1.0",
+        ],
     }
airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -29,9 +29,6 @@ import aiofiles
 import requests
 import tenacity
 from asgiref.sync import sync_to_async
-from kubernetes import client, config, utils, watch
-from kubernetes.client.models import V1Deployment
-from kubernetes.config import ConfigException
 from kubernetes_asyncio import client as async_client, config as async_config
 from urllib3.exceptions import HTTPError
 
@@ -46,6 +43,9 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import (
     container_is_running,
 )
 from airflow.utils import yaml
+from kubernetes import client, config, utils, watch
+from kubernetes.client.models import V1Deployment
+from kubernetes.config import ConfigException
 
 if TYPE_CHECKING:
     from kubernetes.client import V1JobList
@@ -734,9 +734,19 @@ class AsyncKubernetesHook(KubernetesHook):
         """Return Kubernetes API session for use with requests."""
         in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
         cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
+        kubeconfig_path = await self._get_field("kube_config_path")
         kubeconfig = await self._get_field("kube_config")
+        num_selected_configuration = sum(
+            1 for o in [in_cluster, kubeconfig, kubeconfig_path, self.config_dict] if o
+        )
 
-        num_selected_configuration = sum(1 for o in [in_cluster, kubeconfig, self.config_dict] if o)
+        async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None):
+            await async_config.load_kube_config(
+                config_file=_kubeconfig_path,
+                client_configuration=self.client_configuration,
+                context=cluster_context,
+            )
+            return async_client.ApiClient()
 
         if num_selected_configuration > 1:
             raise AirflowException(
@@ -757,6 +767,11 @@ class AsyncKubernetesHook(KubernetesHook):
             await async_config.load_kube_config_from_dict(self.config_dict)
             return async_client.ApiClient()
 
+        if kubeconfig_path is not None:
+            self.log.debug("loading kube_config from: %s", kubeconfig_path)
+            self._is_in_cluster = False
+            return await api_client_from_kubeconfig_file(kubeconfig_path)
+
         if kubeconfig is not None:
             async with aiofiles.tempfile.NamedTemporaryFile() as temp_config:
                 self.log.debug(
@@ -766,12 +781,7 @@ class AsyncKubernetesHook(KubernetesHook):
                 await temp_config.write(kubeconfig.encode())
                 await temp_config.flush()
                 self._is_in_cluster = False
-                await async_config.load_kube_config(
-                    config_file=temp_config.name,
-                    client_configuration=self.client_configuration,
-                    context=cluster_context,
-                )
-                return async_client.ApiClient()
+                return await api_client_from_kubeconfig_file(temp_config.name)
         self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("default configuration file"))
         await async_config.load_kube_config(
             client_configuration=self.client_configuration,
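
Note on the hunks above: `kube_config_path` now participates in the async hook's config resolution with the same "exactly one source" rule as the sync hook. A minimal usage sketch (the connection id and pod name are illustrative, not from the diff):

    import asyncio

    from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook

    async def show_pod_phase():
        # Assumes an Airflow connection "k8s_via_path" whose extra sets
        # "kube_config_path"; selecting more than one of in_cluster,
        # kube_config, kube_config_path, or config_dict raises AirflowException.
        hook = AsyncKubernetesHook(conn_id="k8s_via_path")
        pod = await hook.get_pod(name="my-pod", namespace="default")
        print(pod.status.phase)

    asyncio.run(show_pod_phase())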
airflow/providers/cncf/kubernetes/kube_config.py
@@ -76,10 +76,6 @@ class KubeConfig:
         # interact with cluster components.
         self.executor_namespace = conf.get(self.kubernetes_section, "namespace")
 
-        self.worker_pods_queued_check_interval = conf.getint(
-            self.kubernetes_section, "worker_pods_queued_check_interval"
-        )
-
         self.kube_client_request_args = conf.getjson(
             self.kubernetes_section, "kube_client_request_args", fallback={}
         )
airflow/providers/cncf/kubernetes/kubernetes_helper_functions.py
@@ -23,11 +23,11 @@ from functools import cache
 from typing import TYPE_CHECKING
 
 import pendulum
-from kubernetes.client.rest import ApiException
 from slugify import slugify
 
 from airflow.configuration import conf
 from airflow.providers.cncf.kubernetes.backcompat import get_logical_date_key
+from kubernetes.client.rest import ApiException
 
 if TYPE_CHECKING:
     from airflow.models.taskinstancekey import TaskInstanceKey
airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py
@@ -24,8 +24,6 @@ from datetime import datetime as dt
 from functools import cached_property
 
 import tenacity
-from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
-from kubernetes.client.rest import ApiException
 
 from airflow.exceptions import AirflowException
 from airflow.providers.cncf.kubernetes.resource_convert.configmap import (
@@ -39,6 +37,8 @@ from airflow.providers.cncf.kubernetes.resource_convert.secret import (
 )
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
 from airflow.utils.log.logging_mixin import LoggingMixin
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from kubernetes.client.rest import ApiException
 
 
 def should_retry_start_spark_job(exception: BaseException) -> bool:
@@ -291,7 +291,7 @@ class CustomObjectLauncher(LoggingMixin):
         # Wait for the driver pod to come alive
         self.pod_spec = k8s.V1Pod(
             metadata=k8s.V1ObjectMeta(
-                labels=self.spark_obj_spec["spec"]["driver"]["labels"],
+                labels=self.spark_obj_spec["spec"]["driver"].get("labels"),
                 name=self.spark_obj_spec["metadata"]["name"] + "-driver",
                 namespace=self.namespace,
             )
airflow/providers/cncf/kubernetes/operators/job.py
@@ -26,10 +26,6 @@ from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING
 
-from kubernetes.client import BatchV1Api, models as k8s
-from kubernetes.client.api_client import ApiClient
-from kubernetes.client.rest import ApiException
-
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -44,6 +40,9 @@ from airflow.providers.cncf.kubernetes.triggers.job import KubernetesJobTrigger
 from airflow.providers.cncf.kubernetes.utils.pod_manager import EMPTY_XCOM_RESULT, PodNotFoundException
 from airflow.utils import yaml
 from airflow.utils.context import Context
+from kubernetes.client import BatchV1Api, models as k8s
+from kubernetes.client.api_client import ApiClient
+from kubernetes.client.rest import ApiException
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -167,6 +166,7 @@ class KubernetesJobOperator(KubernetesPodOperator):
         ti.xcom_push(key="job_name", value=self.job.metadata.name)
         ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)
 
+        self.pod: k8s.V1Pod | None
         if self.pod is None:
             self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
                 pod_request_obj=self.pod_request_obj,
airflow/providers/cncf/kubernetes/operators/kueue.py
@@ -22,12 +22,11 @@ import json
 from collections.abc import Sequence
 from functools import cached_property
 
-from kubernetes.utils import FailToCreateError
-
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.operators.job import KubernetesJobOperator
+from kubernetes.utils import FailToCreateError
 
 
 class KubernetesInstallKueueOperator(BaseOperator):
@@ -95,6 +94,7 @@ class KubernetesStartKueueJobOperator(KubernetesJobOperator):
         super().__init__(*args, **kwargs)
         self.queue_name = queue_name
 
+        self.suspend: bool
         if self.suspend is False:
             raise AirflowException(
                 "The `suspend` parameter can't be False. If you want to use Kueue for running Job"
airflow/providers/cncf/kubernetes/operators/pod.py
@@ -26,19 +26,16 @@ import os
 import re
 import shlex
 import string
-from collections.abc import Container, Iterable, Mapping, Sequence
+from collections.abc import Container, Iterable, Sequence
 from contextlib import AbstractContextManager
 from enum import Enum
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, Literal
 
-import kubernetes
 import tenacity
-from kubernetes.client import CoreV1Api, V1Pod, models as k8s
-from kubernetes.client.exceptions import ApiException
-from kubernetes.stream import stream
 from urllib3.exceptions import HTTPError
 
+import kubernetes
 from airflow.configuration import conf
 from airflow.exceptions import (
     AirflowException,
@@ -84,13 +81,21 @@ from airflow.settings import pod_mutation_hook
 from airflow.utils import yaml
 from airflow.utils.helpers import prune_dict, validate_key
 from airflow.version import version as airflow_version
+from kubernetes.client import CoreV1Api, V1Pod, models as k8s
+from kubernetes.client.exceptions import ApiException
+from kubernetes.stream import stream
 
 if TYPE_CHECKING:
     import jinja2
     from pendulum import DateTime
 
     from airflow.providers.cncf.kubernetes.secret import Secret
-    from airflow.utils.context import Context
+
+    try:
+        from airflow.sdk.definitions.context import Context
+    except ImportError:
+        # TODO: Remove once provider drops support for Airflow 2
+        from airflow.utils.context import Context
 
 alphanum_lower = string.ascii_lowercase + string.digits
 
@@ -238,6 +243,8 @@ class KubernetesPodOperator(BaseOperator):
 
     template_fields: Sequence[str] = (
         "image",
+        "name",
+        "hostname",
         "cmds",
         "annotations",
         "arguments",
@@ -319,7 +326,9 @@ class KubernetesPodOperator(BaseOperator):
         is_delete_operator_pod: None | bool = None,
         termination_message_policy: str = "File",
         active_deadline_seconds: int | None = None,
-        callbacks: type[KubernetesPodOperatorCallback] | None = None,
+        callbacks: (
+            list[type[KubernetesPodOperatorCallback]] | type[KubernetesPodOperatorCallback] | None
+        ) = None,
         progress_callback: Callable[[str], None] | None = None,
         logging_interval: int | None = None,
         **kwargs,
@@ -384,7 +393,7 @@ class KubernetesPodOperator(BaseOperator):
         self.priority_class_name = priority_class_name
         self.pod_template_file = pod_template_file
         self.pod_template_dict = pod_template_dict
-        self.name = self._set_name(name)
+        self.name = name
         self.random_name_suffix = random_name_suffix
         self.termination_grace_period = termination_grace_period
         self.pod_request_obj: k8s.V1Pod | None = None
@@ -410,7 +419,7 @@ class KubernetesPodOperator(BaseOperator):
 
         self._config_dict: dict | None = None  # TODO: remove it when removing convert_config_file_to_dict
         self._progress_callback = progress_callback
-        self.callbacks = callbacks
+        self.callbacks = [] if not callbacks else callbacks if isinstance(callbacks, list) else [callbacks]
         self._killed: bool = False
 
     @cached_property
@@ -423,7 +432,7 @@ class KubernetesPodOperator(BaseOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Mapping[str, Any],
+        context: Context,
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
@@ -480,11 +489,11 @@ class KubernetesPodOperator(BaseOperator):
         }
 
         map_index = ti.map_index
-        if map_index >= 0:
-            labels["map_index"] = map_index
+        if map_index is not None and map_index >= 0:
+            labels["map_index"] = str(map_index)
 
         if include_try_number:
-            labels.update(try_number=ti.try_number)
+            labels.update(try_number=str(ti.try_number))
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
         if getattr(context["dag"], "parent_dag", False):
@@ -514,8 +523,9 @@ class KubernetesPodOperator(BaseOperator):
     @cached_property
     def client(self) -> CoreV1Api:
         client = self.hook.core_v1_client
-        if self.callbacks:
-            self.callbacks.on_sync_client_creation(client=client)
+
+        for callback in self.callbacks:
+            callback.on_sync_client_creation(client=client, operator=self)
         return client
 
     def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
@@ -579,6 +589,7 @@ class KubernetesPodOperator(BaseOperator):
 
     def execute(self, context: Context):
         """Based on the deferrable parameter runs the pod asynchronously or synchronously."""
+        self.name = self._set_name(self.name)
         if not self.deferrable:
             return self.execute_sync(context)
 
@@ -589,6 +600,14 @@ class KubernetesPodOperator(BaseOperator):
         try:
             if self.pod_request_obj is None:
                 self.pod_request_obj = self.build_pod_request_obj(context)
+            for callback in self.callbacks:
+                callback.on_pod_manifest_created(
+                    pod_request=self.pod_request_obj,
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                    context=context,
+                    operator=self,
+                )
             if self.pod is None:
                 self.pod = self.get_or_create_pod(  # must set `self.pod` for `on_kill`
                     pod_request_obj=self.pod_request_obj,
@@ -601,28 +620,48 @@ class KubernetesPodOperator(BaseOperator):
 
             # get remote pod for use in cleanup methods
             self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
-            if self.callbacks:
-                self.callbacks.on_pod_creation(
-                    pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC
+            for callback in self.callbacks:
+                callback.on_pod_creation(
+                    pod=self.remote_pod,
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                    context=context,
+                    operator=self,
                 )
 
             self.await_init_containers_completion(pod=self.pod)
 
             self.await_pod_start(pod=self.pod)
             if self.callbacks:
-                self.callbacks.on_pod_starting(
-                    pod=self.find_pod(self.pod.metadata.namespace, context=context),
-                    client=self.client,
-                    mode=ExecutionMode.SYNC,
-                )
+                pod = self.find_pod(self.pod.metadata.namespace, context=context)
+                for callback in self.callbacks:
+                    callback.on_pod_starting(
+                        pod=pod,
+                        client=self.client,
+                        mode=ExecutionMode.SYNC,
+                        context=context,
+                        operator=self,
+                    )
 
             self.await_pod_completion(pod=self.pod)
             if self.callbacks:
-                self.callbacks.on_pod_completion(
-                    pod=self.find_pod(self.pod.metadata.namespace, context=context),
-                    client=self.client,
-                    mode=ExecutionMode.SYNC,
-                )
+                pod = self.find_pod(self.pod.metadata.namespace, context=context)
+                for callback in self.callbacks:
+                    callback.on_pod_completion(
+                        pod=pod,
+                        client=self.client,
+                        mode=ExecutionMode.SYNC,
+                        context=context,
+                        operator=self,
+                    )
+                for callback in self.callbacks:
+                    callback.on_pod_teardown(
+                        pod=pod,
+                        client=self.client,
+                        mode=ExecutionMode.SYNC,
+                        context=context,
+                        operator=self,
+                    )
 
             if self.do_xcom_push:
                 self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
@@ -637,8 +676,14 @@ class KubernetesPodOperator(BaseOperator):
                 pod=pod_to_clean,
                 remote_pod=self.remote_pod,
             )
-            if self.callbacks:
-                self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC)
+            for callback in self.callbacks:
+                callback.on_pod_cleanup(
+                    pod=pod_to_clean,
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                    context=context,
+                    operator=self,
+                )
 
         if self.do_xcom_push:
             return result
@@ -705,11 +750,15 @@ class KubernetesPodOperator(BaseOperator):
                 context=context,
             )
             if self.callbacks:
-                self.callbacks.on_pod_creation(
-                    pod=self.find_pod(self.pod.metadata.namespace, context=context),
-                    client=self.client,
-                    mode=ExecutionMode.SYNC,
-                )
+                pod = self.find_pod(self.pod.metadata.namespace, context=context)
+                for callback in self.callbacks:
+                    callback.on_pod_creation(
+                        pod=pod,
+                        client=self.client,
+                        mode=ExecutionMode.SYNC,
+                        context=context,
+                        operator=self,
+                    )
             ti = context["ti"]
             ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
             ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
@@ -770,10 +819,16 @@ class KubernetesPodOperator(BaseOperator):
         if not self.pod:
             raise PodNotFoundException("Could not find pod after resuming from deferral")
 
-        if self.callbacks and event["status"] != "running":
-            self.callbacks.on_operator_resuming(
-                pod=self.pod, event=event, client=self.client, mode=ExecutionMode.SYNC
-            )
+        if event["status"] != "running":
+            for callback in self.callbacks:
+                callback.on_operator_resuming(
+                    pod=self.pod,
+                    event=event,
+                    client=self.client,
+                    mode=ExecutionMode.SYNC,
+                    context=context,
+                    operator=self,
+                )
 
         follow = self.logging_interval is None
         last_log_time = event.get("last_log_time")
@@ -816,9 +871,9 @@ class KubernetesPodOperator(BaseOperator):
         except TaskDeferred:
             raise
         finally:
-            self._clean(event)
+            self._clean(event, context)
 
-    def _clean(self, event: dict[str, Any]) -> None:
+    def _clean(self, event: dict[str, Any], context: Context) -> None:
         if event["status"] == "running":
             return
         istio_enabled = self.is_istio_enabled(self.pod)
@@ -841,6 +896,7 @@ class KubernetesPodOperator(BaseOperator):
         self.post_complete_action(
             pod=self.pod,
            remote_pod=self.pod,
+            context=context,
         )
 
     def _write_logs(self, pod: k8s.V1Pod, follow: bool = False, since_time: DateTime | None = None) -> None:
@@ -870,14 +926,16 @@ class KubernetesPodOperator(BaseOperator):
                 e if not isinstance(e, ApiException) else e.reason,
             )
 
-    def post_complete_action(self, *, pod, remote_pod, **kwargs) -> None:
+    def post_complete_action(self, *, pod, remote_pod, context: Context, **kwargs) -> None:
         """Actions that must be done after operator finishes logic of the deferrable_execution."""
         self.cleanup(
             pod=pod,
             remote_pod=remote_pod,
         )
-        if self.callbacks:
-            self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC)
+        for callback in self.callbacks:
+            callback.on_pod_cleanup(
+                pod=pod, client=self.client, mode=ExecutionMode.SYNC, operator=self, context=context
+            )
 
     def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
         # Skip cleaning the pod in the following scenarios.
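
Note on the pod.py hunks above: `callbacks` now accepts a list of KubernetesPodOperatorCallback subclasses (a bare class is normalized to a one-item list), and every hook invocation gains `context` and `operator` keyword arguments. A hedged sketch of the new contract (class and task names are illustrative):

    from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback
    from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator

    class AuditCallback(KubernetesPodOperatorCallback):
        @staticmethod
        def on_pod_creation(*, pod, client, mode, **kwargs) -> None:
            # `context` and `operator` arrive via kwargs after this change.
            print(f"created pod {pod.metadata.name} for {kwargs['operator'].task_id}")

    KubernetesPodOperator(
        task_id="pod_with_callbacks",
        image="alpine:3.19",
        cmds=["sh", "-c", "echo hello"],
        callbacks=[AuditCallback],
    )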
airflow/providers/cncf/kubernetes/operators/resource.py
@@ -25,7 +25,6 @@ from typing import TYPE_CHECKING
 
 import tenacity
 import yaml
-from kubernetes.utils import create_from_yaml
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -33,6 +32,7 @@ from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import should_retry_creation
 from airflow.providers.cncf.kubernetes.utils.delete_from import delete_from_yaml
 from airflow.providers.cncf.kubernetes.utils.k8s_resource_iterator import k8s_resource_iterator
+from kubernetes.utils import create_from_yaml
 
 if TYPE_CHECKING:
     from kubernetes.client import ApiClient, CustomObjectsApi
airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,12 +17,9 @@
 # under the License.
 from __future__ import annotations
 
-from collections.abc import Mapping
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
+from typing import TYPE_CHECKING, Any, cast
 
 from airflow.exceptions import AirflowException
 from airflow.providers.cncf.kubernetes import pod_generator
@@ -33,11 +30,16 @@ from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperato
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
 from airflow.utils.helpers import prune_dict
+from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
 if TYPE_CHECKING:
     import jinja2
 
-    from airflow.utils.context import Context
+    try:
+        from airflow.sdk.definitions.context import Context
+    except ImportError:
+        # TODO: Remove once provider drops support for Airflow 2
+        from airflow.utils.context import Context
 
 
 class SparkKubernetesOperator(KubernetesPodOperator):
@@ -114,6 +116,10 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         self.success_run_history_limit = success_run_history_limit
         self.random_name_suffix = random_name_suffix
 
+        # fix mypy typing
+        self.base_container_name: str
+        self.container_logs: list[str]
+
         if self.base_container_name != self.BASE_CONTAINER_NAME:
             self.log.warning(
                 "base_container_name is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME
@@ -129,7 +135,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
     def _render_nested_template_fields(
         self,
         content: Any,
-        context: Mapping[str, Any],
+        context: Context,
         jinja_env: jinja2.Environment,
         seen_oids: set,
     ) -> None:
@@ -174,12 +180,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         return self._set_name(updated_name)
 
     @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
-        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
-        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
-
-    @staticmethod
-    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash.
 
@@ -190,8 +191,9 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         if not context:
             return {}
 
-        ti = context["ti"]
-        run_id = context["run_id"]
+        context_dict = cast(dict, context)
+        ti = context_dict["ti"]
+        run_id = context_dict["run_id"]
 
         labels = {
             "dag_id": ti.dag_id,
@@ -210,8 +212,8 @@ class SparkKubernetesOperator(KubernetesPodOperator):
 
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context["dag"], "is_subdag", False):
-            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        if getattr(context_dict["dag"], "is_subdag", False):
+            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
         for label_id, label in labels.items():
@@ -232,9 +234,11 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         """Templated body for CustomObjectLauncher."""
         return self.manage_template_specs()
 
-    def find_spark_job(self, context):
-        labels = self.create_labels_for_pod(context, include_try_number=False)
-        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+    def find_spark_job(self, context, exclude_checked: bool = True):
+        label_selector = (
+            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+            + ",spark-role=driver"
+        )
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
 
         pod = None