apache-airflow-providers-standard 1.9.1rc1__py3-none-any.whl → 1.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. airflow/providers/standard/__init__.py +3 -3
  2. airflow/providers/standard/decorators/bash.py +1 -2
  3. airflow/providers/standard/example_dags/example_bash_decorator.py +1 -1
  4. airflow/providers/standard/exceptions.py +1 -1
  5. airflow/providers/standard/hooks/subprocess.py +2 -9
  6. airflow/providers/standard/operators/bash.py +7 -3
  7. airflow/providers/standard/operators/datetime.py +1 -2
  8. airflow/providers/standard/operators/hitl.py +20 -10
  9. airflow/providers/standard/operators/latest_only.py +19 -10
  10. airflow/providers/standard/operators/python.py +39 -6
  11. airflow/providers/standard/operators/trigger_dagrun.py +82 -27
  12. airflow/providers/standard/sensors/bash.py +2 -4
  13. airflow/providers/standard/sensors/date_time.py +1 -16
  14. airflow/providers/standard/sensors/external_task.py +91 -51
  15. airflow/providers/standard/sensors/filesystem.py +2 -19
  16. airflow/providers/standard/sensors/time.py +2 -18
  17. airflow/providers/standard/sensors/time_delta.py +7 -6
  18. airflow/providers/standard/triggers/external_task.py +43 -40
  19. airflow/providers/standard/triggers/file.py +1 -1
  20. airflow/providers/standard/triggers/hitl.py +136 -87
  21. airflow/providers/standard/utils/openlineage.py +185 -0
  22. airflow/providers/standard/utils/python_virtualenv.py +38 -4
  23. airflow/providers/standard/utils/python_virtualenv_script.jinja2 +18 -3
  24. airflow/providers/standard/utils/sensor_helper.py +19 -8
  25. airflow/providers/standard/utils/skipmixin.py +2 -2
  26. airflow/providers/standard/version_compat.py +1 -0
  27. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/METADATA +25 -11
  28. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/RECORD +32 -30
  29. apache_airflow_providers_standard-1.10.3.dist-info/licenses/NOTICE +5 -0
  30. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/WHEEL +0 -0
  31. {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/entry_points.txt +0 -0
  32. {airflow/providers/standard → apache_airflow_providers_standard-1.10.3.dist-info/licenses}/LICENSE +0 -0
airflow/providers/standard/triggers/hitl.py
@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
 
 if not AIRFLOW_V_3_1_PLUS:
@@ -30,6 +30,9 @@ from uuid import UUID
 
 from asgiref.sync import sync_to_async
 
+from airflow.providers.common.compat.sdk import ParamValidationError
+from airflow.sdk import Param
+from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.execution_time.hitl import (
     HITLUser,
     get_hitl_detail_content_detail,
@@ -43,7 +46,7 @@ class HITLTriggerEventSuccessPayload(TypedDict, total=False):
     """Minimum required keys for a success Human-in-the-loop TriggerEvent."""
 
     chosen_options: list[str]
-    params_input: dict[str, Any]
+    params_input: dict[str, dict[str, Any]]
     responded_by_user: HITLUser | None
     responded_at: datetime
     timedout: bool
@@ -53,7 +56,7 @@ class HITLTriggerEventFailurePayload(TypedDict):
     """Minimum required keys for a failed Human-in-the-loop TriggerEvent."""
 
     error: str
-    error_type: Literal["timeout", "unknown"]
+    error_type: Literal["timeout", "unknown", "validation"]
 
 
 class HITLTrigger(BaseTrigger):
@@ -64,7 +67,7 @@ class HITLTrigger(BaseTrigger):
         *,
         ti_id: UUID,
         options: list[str],
-        params: dict[str, Any],
+        params: dict[str, dict[str, Any]],
         defaults: list[str] | None = None,
         multiple: bool = False,
         timeout_datetime: datetime | None,
@@ -80,7 +83,21 @@ class HITLTrigger(BaseTrigger):
         self.defaults = defaults
         self.timeout_datetime = timeout_datetime
 
-        self.params = params
+        self.params = ParamsDict(
+            {
+                k: Param(
+                    v.pop("value"),
+                    **v,
+                )
+                if HITLTrigger._is_param(v)
+                else Param(v)
+                for k, v in params.items()
+            },
+        )
+
+    @staticmethod
+    def _is_param(value: Any) -> bool:
+        return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value"))
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize HITLTrigger arguments and classpath."""
@@ -90,99 +107,131 @@ class HITLTrigger(BaseTrigger):
                 "ti_id": self.ti_id,
                 "options": self.options,
                 "defaults": self.defaults,
-                "params": self.params,
+                "params": {k: self.params.get_param(k).serialize() for k in self.params},
                 "multiple": self.multiple,
                 "timeout_datetime": self.timeout_datetime,
                 "poke_interval": self.poke_interval,
             },
         )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        """Loop until the Human-in-the-loop response received or timeout reached."""
-        while True:
-            if self.timeout_datetime and self.timeout_datetime < utcnow():
-                # Fetch latest HITL detail before fallback
-                resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
-                # Response already received, yield success and exit
-                if resp.response_received and resp.chosen_options:
-                    if TYPE_CHECKING:
-                        assert resp.responded_by_user is not None
-                        assert resp.responded_at is not None
-
-                    self.log.info(
-                        "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)",
-                        resp.responded_by_user.name,
-                        resp.responded_by_user.id,
-                        resp.chosen_options,
-                        resp.responded_at,
-                    )
-                    yield TriggerEvent(
-                        HITLTriggerEventSuccessPayload(
-                            chosen_options=resp.chosen_options,
-                            params_input=resp.params_input or {},
-                            responded_at=resp.responded_at,
-                            responded_by_user=HITLUser(
-                                id=resp.responded_by_user.id,
-                                name=resp.responded_by_user.name,
-                            ),
-                            timedout=False,
-                        )
-                    )
-                    return
-
-                if self.defaults is None:
-                    yield TriggerEvent(
-                        HITLTriggerEventFailurePayload(
-                            error="The timeout has passed, and the response has not yet been received.",
-                            error_type="timeout",
-                        )
-                    )
-                    return
-
-                resp = await sync_to_async(update_hitl_detail_response)(
-                    ti_id=self.ti_id,
-                    chosen_options=self.defaults,
-                    params_input=self.params,
+    async def _handle_timeout(self) -> TriggerEvent:
+        """Handle HITL timeout logic and yield appropriate event."""
+        resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+
+        # Case 1: Response arrived just before timeout
+        if resp.response_received and resp.chosen_options:
+            if TYPE_CHECKING:
+                assert resp.responded_by_user is not None
+                assert resp.responded_at is not None
+
+            chosen_options_list = list(resp.chosen_options or [])
+            self.log.info(
+                "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)",
+                resp.responded_by_user.name,
+                resp.responded_by_user.id,
+                chosen_options_list,
+                resp.responded_at,
+            )
+            return TriggerEvent(
+                HITLTriggerEventSuccessPayload(
+                    chosen_options=chosen_options_list,
+                    params_input=resp.params_input or {},
+                    responded_at=resp.responded_at,
+                    responded_by_user=HITLUser(
+                        id=resp.responded_by_user.id,
+                        name=resp.responded_by_user.name,
+                    ),
+                    timedout=False,
                 )
-                if TYPE_CHECKING:
-                    assert resp.responded_at is not None
-                self.log.info(
-                    "[HITL] timeout reached before receiving response, fallback to default %s", self.defaults
+            )
+
+        # Case 2: No defaults defined → failure
+        if self.defaults is None:
+            return TriggerEvent(
+                HITLTriggerEventFailurePayload(
+                    error="The timeout has passed, and the response has not yet been received.",
+                    error_type="timeout",
                 )
-                yield TriggerEvent(
-                    HITLTriggerEventSuccessPayload(
-                        chosen_options=self.defaults,
-                        params_input=self.params,
-                        responded_by_user=None,
-                        responded_at=resp.responded_at,
-                        timedout=True,
+            )
+
+        # Case 3: Timeout fallback to default
+        resp = await sync_to_async(update_hitl_detail_response)(
+            ti_id=self.ti_id,
+            chosen_options=self.defaults,
+            params_input=self.params.dump(),
+        )
+        if TYPE_CHECKING:
+            assert resp.responded_at is not None
+
+        self.log.info(
+            "[HITL] timeout reached before receiving response, fallback to default %s",
+            self.defaults,
+        )
+        return TriggerEvent(
+            HITLTriggerEventSuccessPayload(
+                chosen_options=self.defaults,
+                params_input=self.params.dump(),
+                responded_by_user=None,
+                responded_at=resp.responded_at,
+                timedout=True,
+            )
+        )
+
+    async def _handle_response(self):
+        """Check if HITL response is ready and yield success if so."""
+        resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+        if TYPE_CHECKING:
+            assert resp.responded_by_user is not None
+            assert resp.responded_at is not None
+
+        if not (resp.response_received and resp.chosen_options):
+            return None
+
+        # validate input
+        if params_input := resp.params_input:
+            try:
+                for key, value in params_input.items():
+                    self.params[key] = value
+            except ParamValidationError as err:
+                return TriggerEvent(
+                    HITLTriggerEventFailurePayload(
+                        error=str(err),
+                        error_type="validation",
                     )
                 )
+
+        chosen_options_list = list(resp.chosen_options or [])
+        self.log.info(
+            "[HITL] responded_by=%s (id=%s) options=%s at %s",
+            resp.responded_by_user.name,
+            resp.responded_by_user.id,
+            chosen_options_list,
+            resp.responded_at,
+        )
+        return TriggerEvent(
+            HITLTriggerEventSuccessPayload(
+                chosen_options=chosen_options_list,
+                params_input=params_input or {},
+                responded_at=resp.responded_at,
+                responded_by_user=HITLUser(
+                    id=resp.responded_by_user.id,
+                    name=resp.responded_by_user.name,
+                ),
+                timedout=False,
+            )
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Loop until the Human-in-the-loop response received or timeout reached."""
+        while True:
+            if self.timeout_datetime and self.timeout_datetime < utcnow():
+                event = await self._handle_timeout()
+                yield event
                 return
 
-            resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
-            if resp.response_received and resp.chosen_options:
-                if TYPE_CHECKING:
-                    assert resp.responded_by_user is not None
-                    assert resp.responded_at is not None
-                self.log.info(
-                    "[HITL] responded_by=%s (id=%s) options=%s at %s",
-                    resp.responded_by_user.name,
-                    resp.responded_by_user.id,
-                    resp.chosen_options,
-                    resp.responded_at,
                )
-                yield TriggerEvent(
-                    HITLTriggerEventSuccessPayload(
-                        chosen_options=resp.chosen_options,
-                        params_input=resp.params_input or {},
-                        responded_at=resp.responded_at,
-                        responded_by_user=HITLUser(
-                            id=resp.responded_by_user.id,
-                            name=resp.responded_by_user.name,
-                        ),
-                        timedout=False,
-                    )
-                )
+            event = await self._handle_response()
+            if event:
+                yield event
                 return
+
             await asyncio.sleep(self.poke_interval)
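
Note: with this change the trigger's `params` argument is a mapping of serialized `Param` payloads rather than plain values, and `_is_param` decides which shape it is looking at. A minimal standalone sketch of that decision (the sample payloads are made up; the real code rebuilds `Param`/`ParamsDict` objects from the airflow.sdk package):

# Mirrors HITLTrigger._is_param from the diff above; sample data is hypothetical.
def _is_param(value):
    return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value"))

serialized_params = {
    # the shape produced by Param.serialize(): a value plus description/schema metadata
    "approval": {"value": "approve", "description": "Reviewer choice", "schema": {"type": "string"}},
    # a bare value: wrapped as Param(v) during deserialization
    "note": "looks good",
}

for key, payload in serialized_params.items():
    if _is_param(payload):
        print(f"{key}: rebuild Param({payload['value']!r}, description=..., schema=...)")
    else:
        print(f"{key}: wrap as Param({payload!r})")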
airflow/providers/standard/utils/openlineage.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.openlineage.check import require_openlineage_version
+from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
+
+if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+
+log = logging.getLogger(__name__)
+
+OPENLINEAGE_PROVIDER_MIN_VERSION = "2.8.0"
+
+
+def _is_openlineage_provider_accessible() -> bool:
+    """
+    Check if the OpenLineage provider is accessible.
+
+    This function attempts to import the necessary OpenLineage modules and checks if the provider
+    is enabled and the listener is available.
+
+    Returns:
+        bool: True if the OpenLineage provider is accessible, False otherwise.
+    """
+    try:
+        from airflow.providers.openlineage.conf import is_disabled
+        from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+    except (ImportError, AttributeError):
+        log.debug("OpenLineage provider could not be imported.")
+        return False
+
+    if is_disabled():
+        log.debug("OpenLineage provider is disabled.")
+        return False
+
+    if not get_openlineage_listener():
+        log.debug("OpenLineage listener could not be found.")
+        return False
+
+    return True
+
+
+@require_openlineage_version(provider_min_version=OPENLINEAGE_PROVIDER_MIN_VERSION)
+def _get_openlineage_parent_info(ti: TaskInstance | RuntimeTI) -> dict[str, str]:
+    """Get OpenLineage metadata about the parent task."""
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_job_namespace,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    return {
+        "parentRunId": lineage_run_id(ti),
+        "parentJobName": lineage_job_name(ti),
+        "parentJobNamespace": lineage_job_namespace(),
+        "rootParentRunId": lineage_root_run_id(ti),
+        "rootParentJobName": lineage_root_job_name(ti),
+        "rootParentJobNamespace": lineage_root_job_namespace(ti),
+    }
+
+
+def _inject_openlineage_parent_info_to_dagrun_conf(
+    dr_conf: dict | None, ol_parent_info: dict[str, str]
+) -> dict:
+    """
+    Safely inject OpenLineage parent and root run metadata into a DAG run configuration.
+
+    This function adds parent and root job/run identifiers derived from the given TaskInstance into the
+    `openlineage` section of the DAG run configuration. If an `openlineage` key already exists, it is
+    preserved and extended, but no existing parent or root identifiers are overwritten.
+
+    The function performs several safety checks:
+    - If conf is not a dictionary or contains a non-dict `openlineage` section, conf is returned unmodified.
+    - If `openlineage` section contains any parent/root lineage identifiers, conf is returned unmodified.
+
+    Args:
+        dr_conf: The original DAG run configuration dictionary or None.
+        ol_parent_info: OpenLineage metadata about the parent task
+
+    Returns:
+        A modified DAG run conf with injected OpenLineage parent and root metadata,
+        or the original conf if injection is not possible.
+    """
+    current_ol_dr_conf = {}
+    if isinstance(dr_conf, dict) and dr_conf.get("openlineage"):
+        current_ol_dr_conf = dr_conf["openlineage"]
+        if not isinstance(current_ol_dr_conf, dict):
+            log.warning(
+                "Existing 'openlineage' section of DagRun conf is not a dictionary; "
+                "skipping injection of parent metadata."
+            )
+            return dr_conf
+    forbidden_keys = (
+        "parentRunId",
+        "parentJobName",
+        "parentJobNamespace",
+        "rootParentRunId",
+        "rootJobName",
+        "rootJobNamespace",
+    )
+
+    if existing := [k for k in forbidden_keys if k in current_ol_dr_conf]:
+        log.warning(
+            "'openlineage' section of DagRun conf already contains parent or root "
+            "identifiers: `%s`; skipping injection to avoid overwriting existing values.",
+            ", ".join(existing),
+        )
+        return dr_conf
+
+    return {**(dr_conf or {}), **{"openlineage": {**ol_parent_info, **current_ol_dr_conf}}}
+
+
+def safe_inject_openlineage_properties_into_dagrun_conf(
+    dr_conf: dict | None, ti: TaskInstance | RuntimeTI | None
+) -> dict | None:
+    """
+    Safely inject OpenLineage parent task metadata into a DAG run conf.
+
+    This function checks whether the OpenLineage provider is accessible and supports parent information
+    injection. If so, it enriches the DAG run conf with OpenLineage metadata about the parent task
+    to improve lineage tracking. The function does not modify other conf fields, will not overwrite
+    any existing content, and safely returns the original configuration if OpenLineage is unavailable,
+    unsupported, or an error occurs during injection.
+
+    :param dr_conf: The original DAG run configuration dictionary.
+    :param ti: The TaskInstance whose metadata may be injected.
+
+    :return: A potentially enriched DAG run conf with OpenLineage parent information,
+        or the original conf if injection was skipped or failed.
+    """
+    try:
+        if ti is None:
+            log.debug("Task instance not provided - dagrun conf not modified.")
+            return dr_conf
+
+        if not _is_openlineage_provider_accessible():
+            log.debug("OpenLineage provider not accessible - dagrun conf not modified.")
+            return dr_conf
+
+        ol_parent_info = _get_openlineage_parent_info(ti=ti)
+
+        log.info("Injecting openlineage parent task information into dagrun conf.")
+        new_conf = _inject_openlineage_parent_info_to_dagrun_conf(
+            dr_conf=dr_conf.copy() if isinstance(dr_conf, dict) else dr_conf,
+            ol_parent_info=ol_parent_info,
+        )
+        return new_conf
+    except AirflowOptionalProviderFeatureException:
+        log.info(
+            "Current OpenLineage provider version doesn't support parent information in "
+            "the DagRun conf. Upgrade `apache-airflow-providers-openlineage>=%s` to use this feature. "
+            "DagRun conf has not been modified by OpenLineage.",
+            OPENLINEAGE_PROVIDER_MIN_VERSION,
+        )
+        return dr_conf
+    except Exception as e:
+        log.warning(
+            "An error occurred while trying to inject OpenLineage information into dagrun conf. "
+            "DagRun conf has not been modified by OpenLineage. Error: %s",
+            str(e),
+        )
+        log.debug("Error details: ", exc_info=e)
+        return dr_conf
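
Note: a standalone sketch (with made-up identifiers) of the merge rule at the heart of `_inject_openlineage_parent_info_to_dagrun_conf`: lineage metadata is added under the `openlineage` key, and keys the user already set there take precedence (and in the real function, any pre-existing parent/root identifiers skip injection entirely):

# Hypothetical conf values; the merge expression is the one used by the function above.
dr_conf = {"retrain": True, "openlineage": {"custom": "kept"}}
ol_parent_info = {
    "parentRunId": "11111111-1111-1111-1111-111111111111",  # made-up run id
    "parentJobName": "my_dag.my_task",
}

current_ol_dr_conf = dr_conf.get("openlineage", {})
merged = {**(dr_conf or {}), **{"openlineage": {**ol_parent_info, **current_ol_dr_conf}}}
print(merged["openlineage"]["custom"])         # 'kept' -- user-set keys win
print(merged["openlineage"]["parentJobName"])  # 'my_dag.my_task' -- lineage info added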
airflow/providers/standard/utils/python_virtualenv.py
@@ -19,16 +19,18 @@
 
 from __future__ import annotations
 
+import logging
 import os
+import shlex
 import shutil
+import subprocess
 import warnings
 from pathlib import Path
 
 import jinja2
 from jinja2 import select_autoescape
 
-from airflow.configuration import conf
-from airflow.utils.process_utils import execute_in_subprocess
+from airflow.providers.common.compat.sdk import conf
 
 
 def _is_uv_installed() -> bool:
@@ -132,6 +134,37 @@ def _index_urls_to_uv_env_vars(index_urls: list[str] | None = None) -> dict[str,
     return uv_index_env_vars
 
 
+def _execute_in_subprocess(cmd: list[str], cwd: str | None = None, env: dict[str, str] | None = None) -> None:
+    """
+    Execute a process and stream output to logger.
+
+    :param cmd: command and arguments to run
+    :param cwd: Current working directory passed to the Popen constructor
+    :param env: Additional environment variables to set for the subprocess.
+    """
+    log = logging.getLogger(__name__)
+
+    log.info("Executing cmd: %s", " ".join(shlex.quote(c) for c in cmd))
+    with subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        bufsize=0,
+        close_fds=False,
+        cwd=cwd,
+        env=env,
+    ) as proc:
+        log.info("Output:")
+        if proc.stdout:
+            with proc.stdout:
+                for line in iter(proc.stdout.readline, b""):
+                    log.info("%s", line.decode().rstrip())
+
+    exit_code = proc.wait()
+    if exit_code != 0:
+        raise subprocess.CalledProcessError(exit_code, cmd)
+
+
 def prepare_virtualenv(
     venv_directory: str,
     python_bin: str,
@@ -167,9 +200,10 @@ def prepare_virtualenv(
 
     if _use_uv():
         venv_cmd = _generate_uv_cmd(venv_directory, python_bin, system_site_packages)
+        _execute_in_subprocess(venv_cmd, env={**os.environ, **_index_urls_to_uv_env_vars(index_urls)})
     else:
         venv_cmd = _generate_venv_cmd(venv_directory, python_bin, system_site_packages)
-    execute_in_subprocess(venv_cmd)
+        _execute_in_subprocess(venv_cmd)
 
     pip_cmd = None
     if requirements is not None and len(requirements) != 0:
@@ -188,7 +222,7 @@ def prepare_virtualenv(
         )
 
     if pip_cmd:
-        execute_in_subprocess(pip_cmd, env={**os.environ, **_index_urls_to_uv_env_vars(index_urls)})
+        _execute_in_subprocess(pip_cmd, env={**os.environ, **_index_urls_to_uv_env_vars(index_urls)})
 
     return f"{venv_directory}/bin/python"
 
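Note: `_execute_in_subprocess` replaces the removed dependency on `airflow.utils.process_utils.execute_in_subprocess`. It is a private helper, but its contract is easy to see in isolation; a usage sketch, assuming the 1.10.3 wheel is installed:

import logging
import subprocess

# The helper logs each output line at INFO level, so make those records visible.
logging.basicConfig(level=logging.INFO)

from airflow.providers.standard.utils.python_virtualenv import _execute_in_subprocess

# Streams "hello" into the log and returns normally (exit code 0).
_execute_in_subprocess(["python", "-c", "print('hello')"])

# A non-zero exit code is surfaced as CalledProcessError, as in the diff above.
try:
    _execute_in_subprocess(["python", "-c", "import sys; sys.exit(3)"])
except subprocess.CalledProcessError as err:
    print("command failed with exit code", err.returncode)  # 3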
airflow/providers/standard/utils/python_virtualenv_script.jinja2
@@ -40,6 +40,23 @@ if sys.version_info >= (3,6):
     pass
 {% endif %}
 
+try:
+    from airflow.sdk.execution_time import task_runner
+except ModuleNotFoundError:
+    pass
+else:
+    {#-
+    We are in an Airflow 3.x environment, try and set up supervisor comms so
+    virtualenv can access Vars/Conn/XCom/etc that normal tasks can
+
+    We don't use the walrus operator (`:=`) below as it is possible people can
+    be using this on pre-3.8 versions of python, and while Airflow doesn't
+    support them, it's easy to not break it not using that operator here.
+    #}
+    reinit_supervisor_comms = getattr(task_runner, "reinit_supervisor_comms", None)
+    if reinit_supervisor_comms:
+        reinit_supervisor_comms()
+
 # Script
 {{ python_callable_source }}
 
@@ -49,12 +66,10 @@ if sys.version_info >= (3,6):
 import types
 
 {{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")
-
 {{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}
-
 sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}
 
-{% endif%}
+{%- endif -%}
 
 {% if op_args or op_kwargs %}
 with open(sys.argv[1], "rb") as file:
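
Note: the template comment above explains why `getattr` is used instead of the walrus operator in the generated script. A standalone sketch of the two forms, with `task_runner` stubbed out:

import types

# Stand-in for airflow.sdk.execution_time.task_runner as it looks in an Airflow 2.x
# virtualenv, where reinit_supervisor_comms does not exist.
task_runner = types.ModuleType("task_runner")

# Portable form used by the template (parses on any Python version):
reinit_supervisor_comms = getattr(task_runner, "reinit_supervisor_comms", None)
if reinit_supervisor_comms:
    reinit_supervisor_comms()  # skipped here, since the stub lacks the attribute

# Walrus form the template comment rules out (SyntaxError before Python 3.8):
# if (reinit := getattr(task_runner, "reinit_supervisor_comms", None)):
#     reinit()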
airflow/providers/standard/utils/sensor_helper.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Collection
 from typing import TYPE_CHECKING, Any, cast
 
 from sqlalchemy import func, select, tuple_
@@ -27,7 +28,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
-    from sqlalchemy.sql import Executable
+    from sqlalchemy.sql import Select
 
 
 @provide_session
@@ -59,6 +60,7 @@ def _get_count(
             session.scalar(
                 _count_stmt(TI, states, dttm_filter, external_dag_id).where(TI.task_id.in_(external_task_ids))
             )
+            or 0
        ) / len(external_task_ids)
     elif external_task_group_id:
         external_task_group_task_ids = _get_external_task_group_task_ids(
@@ -68,20 +70,25 @@ def _get_count(
             count = 0
         else:
             count = (
-                session.scalar(
-                    _count_stmt(TI, states, dttm_filter, external_dag_id).where(
-                        tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+                (
+                    session.scalar(
+                        _count_stmt(TI, states, dttm_filter, external_dag_id).where(
+                            tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+                        )
                     )
+                    or 0
                 )
                 / len(external_task_group_task_ids)
                 * len(dttm_filter)
             )
     else:
-        count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id))
+        count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id)) or 0
     return cast("int", count)
 
 
-def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
+def _count_stmt(
+    model: type[DagRun] | type[TaskInstance], states: list[str], dttm_filter: list[Any], external_dag_id: str
+) -> Select[tuple[int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -97,7 +104,9 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
     )
 
 
-def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, external_dag_id, session):
+def _get_external_task_group_task_ids(
+    dttm_filter: list[Any], external_task_group_id: str, external_dag_id: str, session: Session
+) -> list[tuple[str, int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -107,6 +116,8 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
     :param session: airflow session object
     """
     refreshed_dag_info = SerializedDagModel.get_dag(external_dag_id, session=session)
+    if not refreshed_dag_info:
+        return [(external_task_group_id, -1)]
     task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
 
     if task_group:
@@ -129,7 +140,7 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
 
 def _get_count_by_matched_states(
     run_id_task_state_map: dict[str, dict[str, Any]],
-    states: list[str],
+    states: Collection[str],
 ):
     count = 0
     for _, task_states in run_id_task_state_map.items():
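
Note: the `or 0` guards added above matter because `session.scalar()` returns `None` when a statement matches no rows, and the old code then divided `None` by a list length. A minimal sketch of the failure mode:

scalar_result = None  # what session.scalar(...) returns for an empty result set
external_task_ids = ["task_a", "task_b"]

# Old behaviour (guard missing):
#   scalar_result / len(external_task_ids)
#   -> TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'

count = (scalar_result or 0) / len(external_task_ids)
print(count)  # 0.0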
airflow/providers/standard/utils/skipmixin.py
@@ -21,7 +21,7 @@ from collections.abc import Iterable, Sequence
 from types import GeneratorType
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowException
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
@@ -63,7 +63,7 @@ class SkipMixin(LoggingMixin):
         """
         # Import is internal for backward compatibility when importing PythonOperator
         # from airflow.providers.common.compat.standard.operators
-        from airflow.exceptions import DownstreamTasksSkipped
+        from airflow.providers.common.compat.sdk import DownstreamTasksSkipped
 
         # The following could be applied only for non-mapped tasks,
         # as future mapped tasks have not been expanded yet. Such tasks