apache-airflow-providers-standard 1.9.1rc1__py3-none-any.whl → 1.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/standard/__init__.py +3 -3
- airflow/providers/standard/decorators/bash.py +1 -2
- airflow/providers/standard/example_dags/example_bash_decorator.py +1 -1
- airflow/providers/standard/exceptions.py +1 -1
- airflow/providers/standard/hooks/subprocess.py +2 -9
- airflow/providers/standard/operators/bash.py +7 -3
- airflow/providers/standard/operators/datetime.py +1 -2
- airflow/providers/standard/operators/hitl.py +20 -10
- airflow/providers/standard/operators/latest_only.py +19 -10
- airflow/providers/standard/operators/python.py +39 -6
- airflow/providers/standard/operators/trigger_dagrun.py +82 -27
- airflow/providers/standard/sensors/bash.py +2 -4
- airflow/providers/standard/sensors/date_time.py +1 -16
- airflow/providers/standard/sensors/external_task.py +91 -51
- airflow/providers/standard/sensors/filesystem.py +2 -19
- airflow/providers/standard/sensors/time.py +2 -18
- airflow/providers/standard/sensors/time_delta.py +7 -6
- airflow/providers/standard/triggers/external_task.py +43 -40
- airflow/providers/standard/triggers/file.py +1 -1
- airflow/providers/standard/triggers/hitl.py +136 -87
- airflow/providers/standard/utils/openlineage.py +185 -0
- airflow/providers/standard/utils/python_virtualenv.py +38 -4
- airflow/providers/standard/utils/python_virtualenv_script.jinja2 +18 -3
- airflow/providers/standard/utils/sensor_helper.py +19 -8
- airflow/providers/standard/utils/skipmixin.py +2 -2
- airflow/providers/standard/version_compat.py +1 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/METADATA +25 -11
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/RECORD +32 -30
- apache_airflow_providers_standard-1.10.3.dist-info/licenses/NOTICE +5 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_standard-1.9.1rc1.dist-info → apache_airflow_providers_standard-1.10.3.dist-info}/entry_points.txt +0 -0
- {airflow/providers/standard → apache_airflow_providers_standard-1.10.3.dist-info/licenses}/LICENSE +0 -0
airflow/providers/standard/triggers/hitl.py

@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from airflow.
+from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
 from airflow.providers.standard.version_compat import AIRFLOW_V_3_1_PLUS
 
 if not AIRFLOW_V_3_1_PLUS:
@@ -30,6 +30,9 @@ from uuid import UUID
 
 from asgiref.sync import sync_to_async
 
+from airflow.providers.common.compat.sdk import ParamValidationError
+from airflow.sdk import Param
+from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.execution_time.hitl import (
     HITLUser,
     get_hitl_detail_content_detail,
@@ -43,7 +46,7 @@ class HITLTriggerEventSuccessPayload(TypedDict, total=False):
     """Minimum required keys for a success Human-in-the-loop TriggerEvent."""
 
     chosen_options: list[str]
-    params_input: dict[str, Any]
+    params_input: dict[str, dict[str, Any]]
     responded_by_user: HITLUser | None
     responded_at: datetime
     timedout: bool
@@ -53,7 +56,7 @@ class HITLTriggerEventFailurePayload(TypedDict):
     """Minimum required keys for a failed Human-in-the-loop TriggerEvent."""
 
     error: str
-    error_type: Literal["timeout", "unknown"]
+    error_type: Literal["timeout", "unknown", "validation"]
 
 
 class HITLTrigger(BaseTrigger):
@@ -64,7 +67,7 @@ class HITLTrigger(BaseTrigger):
         *,
         ti_id: UUID,
         options: list[str],
-        params: dict[str, Any],
+        params: dict[str, dict[str, Any]],
         defaults: list[str] | None = None,
         multiple: bool = False,
         timeout_datetime: datetime | None,
@@ -80,7 +83,21 @@ class HITLTrigger(BaseTrigger):
         self.defaults = defaults
         self.timeout_datetime = timeout_datetime
 
-        self.params =
+        self.params = ParamsDict(
+            {
+                k: Param(
+                    v.pop("value"),
+                    **v,
+                )
+                if HITLTrigger._is_param(v)
+                else Param(v)
+                for k, v in params.items()
+            },
+        )
+
+    @staticmethod
+    def _is_param(value: Any) -> bool:
+        return isinstance(value, dict) and all(key in value for key in ("description", "schema", "value"))
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize HITLTrigger arguments and classpath."""
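The constructor above rebuilds `Param` objects from their serialized form, and whether a value is treated as a serialized param is decided purely by its keys. A standalone sketch of that detection heuristic, using plain dicts and `print` instead of the `airflow.sdk` classes (names and values here are illustrative, not the provider's API):

```python
from typing import Any


def _is_serialized_param(value: Any) -> bool:
    # Mirrors the heuristic in HITLTrigger._is_param: a serialized Param carries these three keys.
    return isinstance(value, dict) and all(k in value for k in ("description", "schema", "value"))


params = {
    "approver": {"value": "alice", "description": "who signs off", "schema": {"type": "string"}},
    "retries": 3,  # a bare value, wrapped as-is
}
for name, value in params.items():
    if _is_serialized_param(value):
        print(f"{name}: rebuild Param({value['value']!r}, ...) from its serialized form")
    else:
        print(f"{name}: wrap bare value {value!r}")
```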
@@ -90,99 +107,131 @@ class HITLTrigger(BaseTrigger):
                 "ti_id": self.ti_id,
                 "options": self.options,
                 "defaults": self.defaults,
-                "params": self.params,
+                "params": {k: self.params.get_param(k).serialize() for k in self.params},
                 "multiple": self.multiple,
                 "timeout_datetime": self.timeout_datetime,
                 "poke_interval": self.poke_interval,
             },
         )
 
-    async def
-        """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        timedout=False,
-                    )
-                )
-                return
-
-            if self.defaults is None:
-                yield TriggerEvent(
-                    HITLTriggerEventFailurePayload(
-                        error="The timeout has passed, and the response has not yet been received.",
-                        error_type="timeout",
-                    )
-                )
-                return
-
-            resp = await sync_to_async(update_hitl_detail_response)(
-                ti_id=self.ti_id,
-                chosen_options=self.defaults,
-                params_input=self.params,
+    async def _handle_timeout(self) -> TriggerEvent:
+        """Handle HITL timeout logic and yield appropriate event."""
+        resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+
+        # Case 1: Response arrived just before timeout
+        if resp.response_received and resp.chosen_options:
+            if TYPE_CHECKING:
+                assert resp.responded_by_user is not None
+                assert resp.responded_at is not None
+
+            chosen_options_list = list(resp.chosen_options or [])
+            self.log.info(
+                "[HITL] responded_by=%s (id=%s) options=%s at %s (timeout fallback skipped)",
+                resp.responded_by_user.name,
+                resp.responded_by_user.id,
+                chosen_options_list,
+                resp.responded_at,
+            )
+            return TriggerEvent(
+                HITLTriggerEventSuccessPayload(
+                    chosen_options=chosen_options_list,
+                    params_input=resp.params_input or {},
+                    responded_at=resp.responded_at,
+                    responded_by_user=HITLUser(
+                        id=resp.responded_by_user.id,
+                        name=resp.responded_by_user.name,
+                    ),
+                    timedout=False,
                 )
-
-
-
-
+            )
+
+        # Case 2: No defaults defined → failure
+        if self.defaults is None:
+            return TriggerEvent(
+                HITLTriggerEventFailurePayload(
+                    error="The timeout has passed, and the response has not yet been received.",
+                    error_type="timeout",
                 )
-
-
-
-
-
-
-
+            )
+
+        # Case 3: Timeout fallback to default
+        resp = await sync_to_async(update_hitl_detail_response)(
+            ti_id=self.ti_id,
+            chosen_options=self.defaults,
+            params_input=self.params.dump(),
+        )
+        if TYPE_CHECKING:
+            assert resp.responded_at is not None
+
+        self.log.info(
+            "[HITL] timeout reached before receiving response, fallback to default %s",
+            self.defaults,
+        )
+        return TriggerEvent(
+            HITLTriggerEventSuccessPayload(
+                chosen_options=self.defaults,
+                params_input=self.params.dump(),
+                responded_by_user=None,
+                responded_at=resp.responded_at,
+                timedout=True,
+            )
+        )
+
+    async def _handle_response(self):
+        """Check if HITL response is ready and yield success if so."""
+        resp = await sync_to_async(get_hitl_detail_content_detail)(ti_id=self.ti_id)
+        if TYPE_CHECKING:
+            assert resp.responded_by_user is not None
+            assert resp.responded_at is not None
+
+        if not (resp.response_received and resp.chosen_options):
+            return None
+
+        # validate input
+        if params_input := resp.params_input:
+            try:
+                for key, value in params_input.items():
+                    self.params[key] = value
+            except ParamValidationError as err:
+                return TriggerEvent(
+                    HITLTriggerEventFailurePayload(
+                        error=str(err),
+                        error_type="validation",
                     )
                 )
+
+        chosen_options_list = list(resp.chosen_options or [])
+        self.log.info(
+            "[HITL] responded_by=%s (id=%s) options=%s at %s",
+            resp.responded_by_user.name,
+            resp.responded_by_user.id,
+            chosen_options_list,
+            resp.responded_at,
+        )
+        return TriggerEvent(
+            HITLTriggerEventSuccessPayload(
+                chosen_options=chosen_options_list,
+                params_input=params_input or {},
+                responded_at=resp.responded_at,
+                responded_by_user=HITLUser(
+                    id=resp.responded_by_user.id,
+                    name=resp.responded_by_user.name,
+                ),
+                timedout=False,
+            )
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Loop until the Human-in-the-loop response received or timeout reached."""
+        while True:
+            if self.timeout_datetime and self.timeout_datetime < utcnow():
+                event = await self._handle_timeout()
+                yield event
                 return
 
-
-            if
-
-                assert resp.responded_by_user is not None
-                assert resp.responded_at is not None
-                self.log.info(
-                    "[HITL] responded_by=%s (id=%s) options=%s at %s",
-                    resp.responded_by_user.name,
-                    resp.responded_by_user.id,
-                    resp.chosen_options,
-                    resp.responded_at,
-                )
-                yield TriggerEvent(
-                    HITLTriggerEventSuccessPayload(
-                        chosen_options=resp.chosen_options,
-                        params_input=resp.params_input or {},
-                        responded_at=resp.responded_at,
-                        responded_by_user=HITLUser(
-                            id=resp.responded_by_user.id,
-                            name=resp.responded_by_user.name,
-                        ),
-                        timedout=False,
-                    )
-                )
+            event = await self._handle_response()
+            if event:
+                yield event
                 return
+
             await asyncio.sleep(self.poke_interval)
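The refactor above splits the old monolithic `run()` into `_handle_timeout` and `_handle_response`, leaving `run()` as a small poll loop: check the deadline first, then a response, otherwise sleep. A self-contained sketch of that loop shape (stand-in names, not the provider's classes), assuming a response simply becomes available after a few polls:

```python
import asyncio
from datetime import datetime, timedelta, timezone


async def poll(timeout_at, poke_interval, get_response):
    # Async generator mirroring the trigger's loop: timeout wins, then responses, then sleep.
    while True:
        if timeout_at and timeout_at < datetime.now(timezone.utc):
            yield {"error": "timed out", "error_type": "timeout"}
            return
        event = await get_response()
        if event:
            yield event
            return
        await asyncio.sleep(poke_interval)


async def main():
    responses = iter([None, None, {"chosen_options": ["approve"]}])

    async def fake_response():
        # Stand-in for querying the HITL detail endpoint.
        return next(responses)

    deadline = datetime.now(timezone.utc) + timedelta(seconds=5)
    async for event in poll(deadline, 0.1, fake_response):
        print(event)


asyncio.run(main())
```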
airflow/providers/standard/utils/openlineage.py (new file)

@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+from airflow.providers.common.compat.openlineage.check import require_openlineage_version
+from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
+
+if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
+
+log = logging.getLogger(__name__)
+
+OPENLINEAGE_PROVIDER_MIN_VERSION = "2.8.0"
+
+
+def _is_openlineage_provider_accessible() -> bool:
+    """
+    Check if the OpenLineage provider is accessible.
+
+    This function attempts to import the necessary OpenLineage modules and checks if the provider
+    is enabled and the listener is available.
+
+    Returns:
+        bool: True if the OpenLineage provider is accessible, False otherwise.
+    """
+    try:
+        from airflow.providers.openlineage.conf import is_disabled
+        from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
+    except (ImportError, AttributeError):
+        log.debug("OpenLineage provider could not be imported.")
+        return False
+
+    if is_disabled():
+        log.debug("OpenLineage provider is disabled.")
+        return False
+
+    if not get_openlineage_listener():
+        log.debug("OpenLineage listener could not be found.")
+        return False
+
+    return True
+
+
+@require_openlineage_version(provider_min_version=OPENLINEAGE_PROVIDER_MIN_VERSION)
+def _get_openlineage_parent_info(ti: TaskInstance | RuntimeTI) -> dict[str, str]:
+    """Get OpenLineage metadata about the parent task."""
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_job_namespace,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    return {
+        "parentRunId": lineage_run_id(ti),
+        "parentJobName": lineage_job_name(ti),
+        "parentJobNamespace": lineage_job_namespace(),
+        "rootParentRunId": lineage_root_run_id(ti),
+        "rootParentJobName": lineage_root_job_name(ti),
+        "rootParentJobNamespace": lineage_root_job_namespace(ti),
+    }
+
+
+def _inject_openlineage_parent_info_to_dagrun_conf(
+    dr_conf: dict | None, ol_parent_info: dict[str, str]
+) -> dict:
+    """
+    Safely inject OpenLineage parent and root run metadata into a DAG run configuration.
+
+    This function adds parent and root job/run identifiers derived from the given TaskInstance into the
+    `openlineage` section of the DAG run configuration. If an `openlineage` key already exists, it is
+    preserved and extended, but no existing parent or root identifiers are overwritten.
+
+    The function performs several safety checks:
+    - If conf is not a dictionary or contains a non-dict `openlineage` section, conf is returned unmodified.
+    - If `openlineage` section contains any parent/root lineage identifiers, conf is returned unmodified.
+
+    Args:
+        dr_conf: The original DAG run configuration dictionary or None.
+        ol_parent_info: OpenLineage metadata about the parent task
+
+    Returns:
+        A modified DAG run conf with injected OpenLineage parent and root metadata,
+        or the original conf if injection is not possible.
+    """
+    current_ol_dr_conf = {}
+    if isinstance(dr_conf, dict) and dr_conf.get("openlineage"):
+        current_ol_dr_conf = dr_conf["openlineage"]
+        if not isinstance(current_ol_dr_conf, dict):
+            log.warning(
+                "Existing 'openlineage' section of DagRun conf is not a dictionary; "
+                "skipping injection of parent metadata."
+            )
+            return dr_conf
+    forbidden_keys = (
+        "parentRunId",
+        "parentJobName",
+        "parentJobNamespace",
+        "rootParentRunId",
+        "rootJobName",
+        "rootJobNamespace",
+    )
+
+    if existing := [k for k in forbidden_keys if k in current_ol_dr_conf]:
+        log.warning(
+            "'openlineage' section of DagRun conf already contains parent or root "
+            "identifiers: `%s`; skipping injection to avoid overwriting existing values.",
+            ", ".join(existing),
+        )
+        return dr_conf
+
+    return {**(dr_conf or {}), **{"openlineage": {**ol_parent_info, **current_ol_dr_conf}}}
+
+
+def safe_inject_openlineage_properties_into_dagrun_conf(
+    dr_conf: dict | None, ti: TaskInstance | RuntimeTI | None
+) -> dict | None:
+    """
+    Safely inject OpenLineage parent task metadata into a DAG run conf.
+
+    This function checks whether the OpenLineage provider is accessible and supports parent information
+    injection. If so, it enriches the DAG run conf with OpenLineage metadata about the parent task
+    to improve lineage tracking. The function does not modify other conf fields, will not overwrite
+    any existing content, and safely returns the original configuration if OpenLineage is unavailable,
+    unsupported, or an error occurs during injection.
+
+    :param dr_conf: The original DAG run configuration dictionary.
+    :param ti: The TaskInstance whose metadata may be injected.
+
+    :return: A potentially enriched DAG run conf with OpenLineage parent information,
+        or the original conf if injection was skipped or failed.
+    """
+    try:
+        if ti is None:
+            log.debug("Task instance not provided - dagrun conf not modified.")
+            return dr_conf
+
+        if not _is_openlineage_provider_accessible():
+            log.debug("OpenLineage provider not accessible - dagrun conf not modified.")
+            return dr_conf
+
+        ol_parent_info = _get_openlineage_parent_info(ti=ti)
+
+        log.info("Injecting openlineage parent task information into dagrun conf.")
+        new_conf = _inject_openlineage_parent_info_to_dagrun_conf(
+            dr_conf=dr_conf.copy() if isinstance(dr_conf, dict) else dr_conf,
+            ol_parent_info=ol_parent_info,
+        )
+        return new_conf
+    except AirflowOptionalProviderFeatureException:
+        log.info(
+            "Current OpenLineage provider version doesn't support parent information in "
+            "the DagRun conf. Upgrade `apache-airflow-providers-openlineage>=%s` to use this feature. "
+            "DagRun conf has not been modified by OpenLineage.",
+            OPENLINEAGE_PROVIDER_MIN_VERSION,
+        )
+        return dr_conf
+    except Exception as e:
+        log.warning(
+            "An error occurred while trying to inject OpenLineage information into dagrun conf. "
+            "DagRun conf has not been modified by OpenLineage. Error: %s",
+            str(e),
+        )
+        log.debug("Error details: ", exc_info=e)
+        return dr_conf
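A standalone illustration of the merge semantics implemented by `_inject_openlineage_parent_info_to_dagrun_conf`: parent metadata goes under the `openlineage` key, other conf fields are untouched, and keys the user already set in that section take precedence. The values below are made up for demonstration only:

```python
parent_info = {
    "parentRunId": "11111111-2222-3333-4444-555555555555",  # illustrative, not a real run id
    "parentJobName": "upstream_dag.trigger_task",
    "parentJobNamespace": "default",
}
dr_conf = {"foo": "bar", "openlineage": {"custom_key": "kept"}}

current = dr_conf.get("openlineage") or {}
merged = {**(dr_conf or {}), **{"openlineage": {**parent_info, **current}}}

print(merged["foo"])                        # 'bar'  - unrelated conf fields pass through
print(merged["openlineage"]["custom_key"])  # 'kept' - user-provided keys win over injected ones
print(sorted(merged["openlineage"]))        # parent keys added alongside the existing section
```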
airflow/providers/standard/utils/python_virtualenv.py

@@ -19,16 +19,18 @@
 
 from __future__ import annotations
 
+import logging
 import os
+import shlex
 import shutil
+import subprocess
 import warnings
 from pathlib import Path
 
 import jinja2
 from jinja2 import select_autoescape
 
-from airflow.
-from airflow.utils.process_utils import execute_in_subprocess
+from airflow.providers.common.compat.sdk import conf
 
 
 def _is_uv_installed() -> bool:
@@ -132,6 +134,37 @@ def _index_urls_to_uv_env_vars(index_urls: list[str] | None = None) -> dict[str,
     return uv_index_env_vars
 
 
+def _execute_in_subprocess(cmd: list[str], cwd: str | None = None, env: dict[str, str] | None = None) -> None:
+    """
+    Execute a process and stream output to logger.
+
+    :param cmd: command and arguments to run
+    :param cwd: Current working directory passed to the Popen constructor
+    :param env: Additional environment variables to set for the subprocess.
+    """
+    log = logging.getLogger(__name__)
+
+    log.info("Executing cmd: %s", " ".join(shlex.quote(c) for c in cmd))
+    with subprocess.Popen(
+        cmd,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        bufsize=0,
+        close_fds=False,
+        cwd=cwd,
+        env=env,
+    ) as proc:
+        log.info("Output:")
+        if proc.stdout:
+            with proc.stdout:
+                for line in iter(proc.stdout.readline, b""):
+                    log.info("%s", line.decode().rstrip())
+
+        exit_code = proc.wait()
+        if exit_code != 0:
+            raise subprocess.CalledProcessError(exit_code, cmd)
+
+
 def prepare_virtualenv(
     venv_directory: str,
     python_bin: str,
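For reference, a self-contained sketch of the same stream-output-to-logger subprocess pattern used by the new `_execute_in_subprocess` helper above (the command and logger name are illustrative):

```python
import logging
import shlex
import subprocess
import sys

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("venv_setup")

cmd = [sys.executable, "-c", "print('hello'); print('world')"]
log.info("Executing cmd: %s", " ".join(shlex.quote(c) for c in cmd))
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, bufsize=0) as proc:
    assert proc.stdout is not None
    for line in iter(proc.stdout.readline, b""):
        log.info("%s", line.decode().rstrip())  # each output line is logged as it arrives
    if proc.wait() != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd)
```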
@@ -167,9 +200,10 @@ def prepare_virtualenv(
 
     if _use_uv():
         venv_cmd = _generate_uv_cmd(venv_directory, python_bin, system_site_packages)
+        _execute_in_subprocess(venv_cmd, env={**os.environ, **_index_urls_to_uv_env_vars(index_urls)})
     else:
         venv_cmd = _generate_venv_cmd(venv_directory, python_bin, system_site_packages)
-
+        _execute_in_subprocess(venv_cmd)
 
     pip_cmd = None
     if requirements is not None and len(requirements) != 0:
@@ -188,7 +222,7 @@ def prepare_virtualenv(
         )
 
     if pip_cmd:
-
+        _execute_in_subprocess(pip_cmd, env={**os.environ, **_index_urls_to_uv_env_vars(index_urls)})
 
     return f"{venv_directory}/bin/python"
 
airflow/providers/standard/utils/python_virtualenv_script.jinja2

@@ -40,6 +40,23 @@ if sys.version_info >= (3,6):
     pass
 {% endif %}
 
+try:
+    from airflow.sdk.execution_time import task_runner
+except ModuleNotFoundError:
+    pass
+else:
+    {#-
+    We are in an Airflow 3.x environment, try and set up supervisor comms so
+    virtualenv can access Vars/Conn/XCom/etc that normal tasks can
+
+    We don't use the walrus operator (`:=`) below as it is possible people can
+    be using this on pre-3.8 versions of python, and while Airflow doesn't
+    support them, it's easy to not break it not using that operator here.
+    #}
+    reinit_supervisor_comms = getattr(task_runner, "reinit_supervisor_comms", None)
+    if reinit_supervisor_comms:
+        reinit_supervisor_comms()
+
 # Script
 {{ python_callable_source }}
 
@@ -49,12 +66,10 @@ if sys.version_info >= (3,6):
 import types
 
 {{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")
-
 {{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}
-
 sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}
 
-{
+{%- endif -%}
 
 {% if op_args or op_kwargs %}
 with open(sys.argv[1], "rb") as file:
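The template change above makes the generated script opportunistically re-initialise supervisor comms on Airflow 3.x: the import is guarded by try/except and the hook is looked up with `getattr` rather than the walrus operator so older interpreters can still parse the script. A standalone sketch of that detection pattern, using a standard-library module as a stand-in for `airflow.sdk.execution_time.task_runner`:

```python
try:
    # Stand-in for: from airflow.sdk.execution_time import task_runner
    import importlib.util as maybe_task_runner
except ModuleNotFoundError:
    maybe_task_runner = None

if maybe_task_runner is not None:
    # getattr() keeps the script parseable on interpreters without the := operator.
    hook = getattr(maybe_task_runner, "reinit_supervisor_comms", None)
    if hook:
        hook()
    else:
        print("hook not available; skipping")
```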
airflow/providers/standard/utils/sensor_helper.py

@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+from collections.abc import Collection
 from typing import TYPE_CHECKING, Any, cast
 
 from sqlalchemy import func, select, tuple_
@@ -27,7 +28,7 @@ from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
-    from sqlalchemy.sql import
+    from sqlalchemy.sql import Select
 
 
 @provide_session
@@ -59,6 +60,7 @@ def _get_count(
             session.scalar(
                 _count_stmt(TI, states, dttm_filter, external_dag_id).where(TI.task_id.in_(external_task_ids))
             )
+            or 0
         ) / len(external_task_ids)
     elif external_task_group_id:
         external_task_group_task_ids = _get_external_task_group_task_ids(
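The `or 0` guards added throughout this file handle the case where `session.scalar()` returns `None` because no rows matched, which would otherwise break the division that follows. A plain-Python illustration of the failure mode and the guard:

```python
def ratio(scalar_result, denominator):
    # scalar_result stands in for session.scalar(...), which may be None when nothing matches.
    return (scalar_result or 0) / denominator

print(ratio(None, 4))  # 0.0 instead of raising TypeError
print(ratio(6, 4))     # 1.5
```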
@@ -68,20 +70,25 @@ def _get_count(
             count = 0
         else:
             count = (
-
-
-
+                (
+                    session.scalar(
+                        _count_stmt(TI, states, dttm_filter, external_dag_id).where(
+                            tuple_(TI.task_id, TI.map_index).in_(external_task_group_task_ids)
+                        )
                     )
+                    or 0
                 )
                 / len(external_task_group_task_ids)
                 * len(dttm_filter)
             )
     else:
-        count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id))
+        count = session.scalar(_count_stmt(DR, states, dttm_filter, external_dag_id)) or 0
     return cast("int", count)
 
 
-def _count_stmt(
+def _count_stmt(
+    model: type[DagRun] | type[TaskInstance], states: list[str], dttm_filter: list[Any], external_dag_id: str
+) -> Select[tuple[int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -97,7 +104,9 @@ def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
     )
 
 
-def _get_external_task_group_task_ids(
+def _get_external_task_group_task_ids(
+    dttm_filter: list[Any], external_task_group_id: str, external_dag_id: str, session: Session
+) -> list[tuple[str, int]]:
     """
     Get the count of records against dttm filter and states.
 
@@ -107,6 +116,8 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
     :param session: airflow session object
     """
     refreshed_dag_info = SerializedDagModel.get_dag(external_dag_id, session=session)
+    if not refreshed_dag_info:
+        return [(external_task_group_id, -1)]
     task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)
 
     if task_group:
@@ -129,7 +140,7 @@ def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, exter
 
 def _get_count_by_matched_states(
     run_id_task_state_map: dict[str, dict[str, Any]],
-    states:
+    states: Collection[str],
 ):
     count = 0
     for _, task_states in run_id_task_state_map.items():
airflow/providers/standard/utils/skipmixin.py

@@ -21,7 +21,7 @@ from collections.abc import Iterable, Sequence
 from types import GeneratorType
 from typing import TYPE_CHECKING
 
-from airflow.
+from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
@@ -63,7 +63,7 @@ class SkipMixin(LoggingMixin):
         """
         # Import is internal for backward compatibility when importing PythonOperator
         # from airflow.providers.common.compat.standard.operators
-        from airflow.
+        from airflow.providers.common.compat.sdk import DownstreamTasksSkipped
 
         # The following could be applied only for non-mapped tasks,
         # as future mapped tasks have not been expanded yet. Such tasks