apache-airflow-providers-amazon 8.29.0rc1__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/{datasets → assets}/s3.py +10 -6
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +5 -11
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +2 -5
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +0 -6
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +3 -17
- airflow/providers/amazon/aws/hooks/base_aws.py +4 -162
- airflow/providers/amazon/aws/hooks/logs.py +1 -20
- airflow/providers/amazon/aws/hooks/quicksight.py +1 -17
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +6 -120
- airflow/providers/amazon/aws/hooks/redshift_data.py +52 -14
- airflow/providers/amazon/aws/hooks/s3.py +24 -27
- airflow/providers/amazon/aws/hooks/sagemaker.py +4 -48
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -6
- airflow/providers/amazon/aws/operators/appflow.py +1 -10
- airflow/providers/amazon/aws/operators/batch.py +1 -29
- airflow/providers/amazon/aws/operators/datasync.py +1 -8
- airflow/providers/amazon/aws/operators/ecs.py +1 -25
- airflow/providers/amazon/aws/operators/eks.py +7 -46
- airflow/providers/amazon/aws/operators/emr.py +16 -232
- airflow/providers/amazon/aws/operators/glue_databrew.py +1 -10
- airflow/providers/amazon/aws/operators/rds.py +3 -17
- airflow/providers/amazon/aws/operators/redshift_data.py +18 -3
- airflow/providers/amazon/aws/operators/s3.py +12 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +10 -32
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -40
- airflow/providers/amazon/aws/sensors/batch.py +1 -8
- airflow/providers/amazon/aws/sensors/dms.py +1 -8
- airflow/providers/amazon/aws/sensors/dynamodb.py +22 -8
- airflow/providers/amazon/aws/sensors/emr.py +0 -7
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +1 -8
- airflow/providers/amazon/aws/sensors/glue_crawler.py +1 -8
- airflow/providers/amazon/aws/sensors/quicksight.py +1 -29
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -8
- airflow/providers/amazon/aws/sensors/s3.py +1 -8
- airflow/providers/amazon/aws/sensors/sagemaker.py +2 -9
- airflow/providers/amazon/aws/sensors/sqs.py +1 -8
- airflow/providers/amazon/aws/sensors/step_function.py +1 -8
- airflow/providers/amazon/aws/transfers/base.py +1 -14
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +5 -33
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +15 -10
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +3 -6
- airflow/providers/amazon/aws/triggers/batch.py +1 -168
- airflow/providers/amazon/aws/triggers/eks.py +1 -20
- airflow/providers/amazon/aws/triggers/emr.py +0 -32
- airflow/providers/amazon/aws/triggers/glue_crawler.py +0 -11
- airflow/providers/amazon/aws/triggers/glue_databrew.py +0 -21
- airflow/providers/amazon/aws/triggers/rds.py +0 -79
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +5 -64
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -93
- airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py +106 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +4 -164
- airflow/providers/amazon/aws/utils/mixins.py +1 -23
- airflow/providers/amazon/aws/utils/openlineage.py +3 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +1 -1
- airflow/providers/amazon/get_provider_info.py +13 -4
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/METADATA +12 -13
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/RECORD +64 -64
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +0 -149
- /airflow/providers/amazon/aws/{datasets → assets}/__init__.py +0 -0
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/triggers/rds.py

@@ -16,95 +16,16 @@
 # under the License.
 from __future__ import annotations

-from functools import cached_property
 from typing import TYPE_CHECKING, Any

-from deprecated import deprecated
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
-from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
-from airflow.triggers.base import BaseTrigger, TriggerEvent

 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


-@deprecated(
-    reason=(
-        "This trigger is deprecated, please use the other RDS triggers "
-        "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or RdsDbAvailableTrigger"
-    ),
-    category=AirflowProviderDeprecationWarning,
-)
-class RdsDbInstanceTrigger(BaseTrigger):
-    """
-    Deprecated Trigger for RDS operations. Do not use.
-
-    :param waiter_name: Name of the waiter to use, for instance 'db_instance_available'
-        or 'db_instance_deleted'.
-    :param db_instance_identifier: The DB instance identifier for the DB instance to be polled.
-    :param waiter_delay: The amount of time in seconds to wait between attempts.
-    :param waiter_max_attempts: The maximum number of attempts to be made.
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-    :param region_name: AWS region where the DB is located, if different from the default one.
-    :param response: The response from the RdsHook, to be passed back to the operator.
-    """
-
-    def __init__(
-        self,
-        waiter_name: str,
-        db_instance_identifier: str,
-        waiter_delay: int,
-        waiter_max_attempts: int,
-        aws_conn_id: str | None,
-        region_name: str | None,
-        response: dict[str, Any],
-    ):
-        self.db_instance_identifier = db_instance_identifier
-        self.waiter_delay = waiter_delay
-        self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.waiter_name = waiter_name
-        self.response = response
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            # dynamically generate the fully qualified name of the class
-            self.__class__.__module__ + "." + self.__class__.__qualname__,
-            {
-                "db_instance_identifier": self.db_instance_identifier,
-                "waiter_delay": str(self.waiter_delay),
-                "waiter_max_attempts": str(self.waiter_max_attempts),
-                "aws_conn_id": self.aws_conn_id,
-                "region_name": self.region_name,
-                "waiter_name": self.waiter_name,
-                "response": self.response,
-            },
-        )
-
-    @cached_property
-    def hook(self) -> RdsHook:
-        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-    async def run(self):
-        async with self.hook.async_conn as client:
-            waiter = client.get_waiter(self.waiter_name)
-            await async_wait(
-                waiter=waiter,
-                waiter_delay=int(self.waiter_delay),
-                waiter_max_attempts=int(self.waiter_max_attempts),
-                args={"DBInstanceIdentifier": self.db_instance_identifier},
-                failure_message="Error checking DB Instance status",
-                status_message="DB instance status is",
-                status_args=["DBInstances[0].DBInstanceStatus"],
-            )
-        yield TriggerEvent({"status": "success", "response": self.response})
-
-
 _waiter_arg = {
     RdsDbType.INSTANCE.value: "DBInstanceIdentifier",
     RdsDbType.CLUSTER.value: "DBClusterIdentifier",
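Note: the removed RdsDbInstanceTrigger had already been deprecated in favour of the waiter-specific RDS triggers named in its deprecation message (RdsDbAvailableTrigger, RdsDbDeletedTrigger, RdsDbStoppedTrigger). A minimal construction sketch of one replacement follows; the exact keyword names (db_identifier, db_type, response) are assumptions based on the removed trigger's parameters and the _waiter_arg mapping above, and should be verified against the installed provider version.

    # Sketch only: keyword names are assumptions, not confirmed by this diff.
    from airflow.providers.amazon.aws.triggers.rds import RdsDbAvailableTrigger
    from airflow.providers.amazon.aws.utils.rds import RdsDbType

    describe_response: dict = {}  # normally the RdsHook describe response passed back to the operator

    trigger = RdsDbAvailableTrigger(
        db_identifier="my-db-instance",
        db_type=RdsDbType.INSTANCE,  # selects "DBInstanceIdentifier" via _waiter_arg
        waiter_delay=30,
        waiter_max_attempts=40,
        aws_conn_id="aws_default",
        region_name=None,
        response=describe_response,
    )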
airflow/providers/amazon/aws/triggers/redshift_cluster.py

@@ -17,11 +17,9 @@
 from __future__ import annotations

 import asyncio
-import warnings
 from typing import TYPE_CHECKING, Any, AsyncIterator

-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook
+from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -45,21 +43,10 @@ class RedshiftCreateClusterTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         cluster_identifier: str,
-        poll_interval: int | None = None,
-        max_attempt: int | None = None,
         aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 15,
         waiter_max_attempts: int = 999999,
     ):
-        if poll_interval is not None or max_attempt is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempt.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        waiter_delay = poll_interval or waiter_delay
-        waiter_max_attempts = max_attempt or waiter_max_attempts
         super().__init__(
             serialized_fields={"cluster_identifier": cluster_identifier},
             waiter_name="cluster_available",
@@ -93,21 +80,10 @@ class RedshiftPauseClusterTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         cluster_identifier: str,
-        poll_interval: int | None = None,
-        max_attempts: int | None = None,
        aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 15,
         waiter_max_attempts: int = 999999,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempt.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        waiter_delay = poll_interval or waiter_delay
-        waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"cluster_identifier": cluster_identifier},
             waiter_name="cluster_paused",
@@ -141,21 +117,10 @@ class RedshiftCreateClusterSnapshotTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         cluster_identifier: str,
-        poll_interval: int | None = None,
-        max_attempts: int | None = None,
         aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 15,
         waiter_max_attempts: int = 999999,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempt.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        waiter_delay = poll_interval or waiter_delay
-        waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"cluster_identifier": cluster_identifier},
             waiter_name="snapshot_available",
@@ -189,21 +154,10 @@ class RedshiftResumeClusterTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         cluster_identifier: str,
-        poll_interval: int | None = None,
-        max_attempts: int | None = None,
         aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 15,
         waiter_max_attempts: int = 999999,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempt.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        waiter_delay = poll_interval or waiter_delay
-        waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"cluster_identifier": cluster_identifier},
             waiter_name="cluster_resumed",
@@ -234,21 +188,10 @@ class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         cluster_identifier: str,
-        poll_interval: int | None = None,
-        max_attempts: int | None = None,
         aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 30,
         waiter_max_attempts: int = 30,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempt.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        waiter_delay = poll_interval or waiter_delay
-        waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"cluster_identifier": cluster_identifier},
             waiter_name="cluster_deleted",
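Note: across all five Redshift waiter triggers above, the deprecated poll_interval/max_attempts keyword arguments are removed; only waiter_delay and waiter_max_attempts remain. A minimal construction sketch based on the signature visible in these hunks:

    # Sketch: constructing one of the Redshift waiter triggers with the 9.0.0 arguments.
    from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftPauseClusterTrigger

    trigger = RedshiftPauseClusterTrigger(
        cluster_identifier="my-cluster",
        aws_conn_id="aws_default",
        waiter_delay=15,          # was poll_interval in 8.x
        waiter_max_attempts=100,  # was max_attempts in 8.x
    )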
@@ -304,13 +247,11 @@ class RedshiftClusterTrigger(BaseTrigger):
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Run async until the cluster status matches the target status."""
         try:
-            hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
+            hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
             while True:
-                res = await hook.cluster_status(self.cluster_identifier)
-                if (res["status"] == "success" and res["cluster_state"] == self.target_status) or res[
-                    "status"
-                ] == "error":
-                    yield TriggerEvent(res)
+                status = await hook.cluster_status_async(self.cluster_identifier)
+                if status == self.target_status:
+                    yield TriggerEvent({"status": "success", "message": "target state met"})
                     return
                 await asyncio.sleep(self.poke_interval)
         except Exception as e:
airflow/providers/amazon/aws/triggers/sagemaker.py

@@ -18,17 +18,15 @@
 from __future__ import annotations

 import asyncio
-import time
 from collections import Counter
 from enum import IntEnum
 from functools import cached_property
 from typing import Any, AsyncIterator

 from botocore.exceptions import WaiterError
-from deprecated import deprecated

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -198,92 +196,3 @@ class SageMakerPipelineTrigger(BaseTrigger):
             await asyncio.sleep(int(self.waiter_delay))

         raise AirflowException("Waiter error: max attempts reached")
-
-
-@deprecated(
-    reason=(
-        "`airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger` "
-        "has been deprecated and will be removed in future. Please use ``SageMakerTrigger`` instead."
-    ),
-    category=AirflowProviderDeprecationWarning,
-)
-class SageMakerTrainingPrintLogTrigger(BaseTrigger):
-    """
-    SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.
-
-    :param job_name: name of the job to check status
-    :param poke_interval: polling period in seconds to check for the status
-    :param aws_conn_id: AWS connection ID for sagemaker
-    """
-
-    def __init__(
-        self,
-        job_name: str,
-        poke_interval: float,
-        aws_conn_id: str | None = "aws_default",
-    ):
-        super().__init__()
-        self.job_name = job_name
-        self.poke_interval = poke_interval
-        self.aws_conn_id = aws_conn_id
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
-        return (
-            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
-            {
-                "poke_interval": self.poke_interval,
-                "aws_conn_id": self.aws_conn_id,
-                "job_name": self.job_name,
-            },
-        )
-
-    @cached_property
-    def hook(self) -> SageMakerHook:
-        return SageMakerHook(aws_conn_id=self.aws_conn_id)
-
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        """Make async connection to sagemaker async hook and gets job status for a job submitted by the operator."""
-        stream_names: list[str] = []  # The list of log streams
-        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
-
-        last_description = await self.hook.describe_training_job_async(self.job_name)
-        instance_count = last_description["ResourceConfig"]["InstanceCount"]
-        status = last_description["TrainingJobStatus"]
-        job_already_completed = status not in self.hook.non_terminal_states
-        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
-        last_describe_job_call = time.time()
-        try:
-            while True:
-                (
-                    state,
-                    last_description,
-                    last_describe_job_call,
-                ) = await self.hook.describe_training_job_with_log_async(
-                    self.job_name,
-                    positions,
-                    stream_names,
-                    instance_count,
-                    state,
-                    last_description,
-                    last_describe_job_call,
-                )
-                status = last_description["TrainingJobStatus"]
-                if status in self.hook.non_terminal_states:
-                    await asyncio.sleep(self.poke_interval)
-                elif status in self.hook.failed_states:
-                    reason = last_description.get("FailureReason", "(No reason provided)")
-                    error_message = f"SageMaker job failed because {reason}"
-                    yield TriggerEvent({"status": "error", "message": error_message})
-                    return
-                else:
-                    billable_seconds = SageMakerHook.count_billable_seconds(
-                        training_start_time=last_description["TrainingStartTime"],
-                        training_end_time=last_description["TrainingEndTime"],
-                        instance_count=instance_count,
-                    )
-                    self.log.info("Billable seconds: %d", billable_seconds)
-                    yield TriggerEvent({"status": "success", "message": last_description})
-                    return
-        except Exception as e:
-            yield TriggerEvent({"status": "error", "message": str(e)})
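Note: the removed SageMakerTrainingPrintLogTrigger pointed users at SageMakerTrigger as its replacement. A minimal sketch follows; the parameter names shown (job_type, poke_interval, max_attempts) are assumptions from the trigger's public signature and should be checked against the installed provider version.

    # Sketch only: deferring to the replacement trigger named in the deprecation message.
    from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger

    trigger = SageMakerTrigger(
        job_name="my-training-job",
        job_type="training",
        poke_interval=30,
        max_attempts=480,
        aws_conn_id="aws_default",
    )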
airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py (new file)

@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from importlib.util import find_spec
+
+
+def _get_asset_compat_hook_lineage_collector():
+    from airflow.lineage.hook import get_hook_lineage_collector
+
+    collector = get_hook_lineage_collector()
+
+    if all(
+        getattr(collector, asset_method_name, None)
+        for asset_method_name in ("add_input_asset", "add_output_asset", "collected_assets")
+    ):
+        return collector
+
+    # dataset is renamed as asset in Airflow 3.0
+
+    from functools import wraps
+
+    from airflow.lineage.hook import DatasetLineageInfo, HookLineage
+
+    DatasetLineageInfo.asset = DatasetLineageInfo.dataset
+
+    def rename_dataset_kwargs_as_assets_kwargs(function):
+        @wraps(function)
+        def wrapper(*args, **kwargs):
+            if "asset_kwargs" in kwargs:
+                kwargs["dataset_kwargs"] = kwargs.pop("asset_kwargs")
+
+            if "asset_extra" in kwargs:
+                kwargs["dataset_extra"] = kwargs.pop("asset_extra")
+
+            return function(*args, **kwargs)
+
+        return wrapper
+
+    collector.create_asset = rename_dataset_kwargs_as_assets_kwargs(collector.create_dataset)
+    collector.add_input_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_input_dataset)
+    collector.add_output_asset = rename_dataset_kwargs_as_assets_kwargs(collector.add_output_dataset)
+
+    def collected_assets_compat(collector) -> HookLineage:
+        """Get the collected hook lineage information."""
+        lineage = collector.collected_datasets
+        return HookLineage(
+            [
+                DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context)
+                for item in lineage.inputs
+            ],
+            [
+                DatasetLineageInfo(dataset=item.dataset, count=item.count, context=item.context)
+                for item in lineage.outputs
+            ],
+        )
+
+    setattr(
+        collector.__class__,
+        "collected_assets",
+        property(lambda collector: collected_assets_compat(collector)),
+    )
+
+    return collector
+
+
+def get_hook_lineage_collector():
+    # HookLineageCollector added in 2.10
+    try:
+        if find_spec("airflow.assets"):
+            # Dataset has been renamed as Asset in 3.0
+            from airflow.lineage.hook import get_hook_lineage_collector
+
+            return get_hook_lineage_collector()
+
+        return _get_asset_compat_hook_lineage_collector()
+    except ImportError:
+
+        class NoOpCollector:
+            """
+            NoOpCollector is a hook lineage collector that does nothing.
+
+            It is used when you want to disable lineage collection.
+            """
+
+            def add_input_asset(self, *_, **__):
+                pass
+
+            def add_output_asset(self, *_, **__):
+                pass
+
+        return NoOpCollector()