apache-airflow-providers-amazon 8.16.0rc1__py3-none-any.whl → 8.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -0
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +34 -19
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +44 -1
- airflow/providers/amazon/aws/auth_manager/cli/__init__.py +16 -0
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +178 -0
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +62 -0
- airflow/providers/amazon/aws/auth_manager/cli/schema.json +171 -0
- airflow/providers/amazon/aws/auth_manager/constants.py +1 -0
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +77 -23
- airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +17 -0
- airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
- airflow/providers/amazon/aws/executors/utils/__init__.py +16 -0
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +60 -0
- airflow/providers/amazon/aws/hooks/athena_sql.py +168 -0
- airflow/providers/amazon/aws/hooks/base_aws.py +14 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +33 -18
- airflow/providers/amazon/aws/hooks/redshift_data.py +66 -17
- airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -1
- airflow/providers/amazon/aws/hooks/s3.py +18 -4
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
- airflow/providers/amazon/aws/operators/batch.py +33 -15
- airflow/providers/amazon/aws/operators/cloud_formation.py +37 -26
- airflow/providers/amazon/aws/operators/datasync.py +19 -18
- airflow/providers/amazon/aws/operators/dms.py +57 -69
- airflow/providers/amazon/aws/operators/ec2.py +19 -5
- airflow/providers/amazon/aws/operators/emr.py +30 -10
- airflow/providers/amazon/aws/operators/eventbridge.py +57 -80
- airflow/providers/amazon/aws/operators/quicksight.py +17 -24
- airflow/providers/amazon/aws/operators/redshift_data.py +68 -19
- airflow/providers/amazon/aws/operators/s3.py +1 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +42 -12
- airflow/providers/amazon/aws/sensors/cloud_formation.py +30 -25
- airflow/providers/amazon/aws/sensors/dms.py +31 -24
- airflow/providers/amazon/aws/sensors/dynamodb.py +15 -15
- airflow/providers/amazon/aws/sensors/quicksight.py +34 -24
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +41 -3
- airflow/providers/amazon/aws/sensors/s3.py +13 -8
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +113 -0
- airflow/providers/amazon/aws/triggers/s3.py +9 -4
- airflow/providers/amazon/get_provider_info.py +55 -16
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/METADATA +17 -15
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/RECORD +46 -38
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.16.0rc1.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/entry_points.txt +0 -0
@@ -26,6 +26,21 @@ from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
     from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa
+    from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef
+
+FINISHED_STATE = "FINISHED"
+FAILED_STATE = "FAILED"
+ABORTED_STATE = "ABORTED"
+FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
+RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+
+
+class RedshiftDataQueryFailedError(ValueError):
+    """Raise an error that redshift data query failed."""
+
+
+class RedshiftDataQueryAbortedError(ValueError):
+    """Raise an error that redshift data query was aborted."""
 
 
 class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
@@ -108,27 +123,40 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         return statement_id
 
-    def wait_for_results(self, statement_id, poll_interval):
+    def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
         while True:
             self.log.info("Polling statement %s", statement_id)
-            resp = self.conn.describe_statement(
-                Id=statement_id,
-            )
-            status = resp["Status"]
-            if status == "FINISHED":
-                num_rows = resp.get("ResultRows")
-                if num_rows is not None:
-                    self.log.info("Processed %s rows", num_rows)
-                return status
-            elif status in ("FAILED", "ABORTED"):
-                raise ValueError(
-                    f"Statement {statement_id!r} terminated with status {status}. "
-                    f"Response details: {pformat(resp)}"
-                )
-            else:
-                self.log.info("Query %s", status)
+            is_finished = self.check_query_is_finished(statement_id)
+            if is_finished:
+                return FINISHED_STATE
+
             time.sleep(poll_interval)
 
+    def check_query_is_finished(self, statement_id: str) -> bool:
+        """Check whether query finished, raise exception is failed."""
+        resp = self.conn.describe_statement(Id=statement_id)
+        return self.parse_statement_resposne(resp)
+
+    def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool:
+        """Parse the response of describe_statement."""
+        status = resp["Status"]
+        if status == FINISHED_STATE:
+            num_rows = resp.get("ResultRows")
+            if num_rows is not None:
+                self.log.info("Processed %s rows", num_rows)
+            return True
+        elif status in FAILURE_STATES:
+            exception_cls = (
+                RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
+            )
+            raise exception_cls(
+                f"Statement {resp['Id']} terminated with status {status}. "
+                f"Response details: {pformat(resp)}"
+            )
+
+        self.log.info("Query status: %s", status)
+        return False
+
     def get_table_primary_key(
         self,
         table: str,
@@ -201,3 +229,24 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 break
 
         return pk_columns or None
+
+    async def is_still_running(self, statement_id: str) -> bool:
+        """Async function to check whether the query is still running.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            desc = await client.describe_statement(Id=statement_id)
+            return desc["Status"] in RUNNING_STATES
+
+    async def check_query_is_finished_async(self, statement_id: str) -> bool:
+        """Async function to check statement is finished.
+
+        It takes statement_id, makes async connection to redshift data to get the query status
+        by statement_id and returns the query status.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            resp = await client.describe_statement(Id=statement_id)
+            return self.parse_statement_resposne(resp)
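The hunks above split the old inline polling loop into `check_query_is_finished` / `parse_statement_resposne` and raise dedicated `ValueError` subclasses for terminal states. A hedged usage sketch follows; the connection id, database, SQL, and cluster name are placeholders, and `execute_query` is the hook's pre-existing submission method, not part of this diff:

```python
from airflow.providers.amazon.aws.hooks.redshift_data import (
    RedshiftDataHook,
    RedshiftDataQueryAbortedError,
    RedshiftDataQueryFailedError,
)

hook = RedshiftDataHook(aws_conn_id="aws_default")  # placeholder connection id

# Submit a statement without blocking, then poll it with the refactored helpers.
statement_id = hook.execute_query(
    database="dev",
    sql="SELECT 1",
    cluster_identifier="example-cluster",
    wait_for_completion=False,
)

try:
    hook.wait_for_results(statement_id, poll_interval=10)  # returns "FINISHED"
except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as err:
    # Both exception classes subclass ValueError, so callers that caught
    # ValueError against the old code keep working.
    print(f"Statement did not complete cleanly: {err}")
```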
@@ -239,7 +239,7 @@ class RedshiftSQLHook(DbApiHook):
 
     def _get_identifier_from_hostname(self, hostname: str) -> str:
         parts = hostname.split(".")
-        if "amazonaws.com" in hostname and len(parts) == 6:
+        if hostname.endswith("amazonaws.com") and len(parts) == 6:
             return f"{parts[0]}.{parts[2]}"
         else:
             self.log.debug(
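The tightened check above only treats a hostname as a provisioned-cluster endpoint when it ends with `amazonaws.com` and splits into exactly six dot-separated parts. A standalone sketch of that parsing (the hostname is invented, and the `None` fallback stands in for the hook's debug-and-fall-through branch):

```python
from __future__ import annotations


def identifier_from_hostname(hostname: str) -> str | None:
    """Mirror of the parsing above: '<cluster_id>.<hash>.<region>.redshift.amazonaws.com'."""
    parts = hostname.split(".")
    if hostname.endswith("amazonaws.com") and len(parts) == 6:
        return f"{parts[0]}.{parts[2]}"  # "<cluster_id>.<region>"
    return None


print(identifier_from_hostname("my-cluster.abc123.eu-west-1.redshift.amazonaws.com"))
# my-cluster.eu-west-1
print(identifier_from_hostname("redshift.example.internal"))
# None (not an AWS endpoint, so no identifier is derived)
```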
@@ -462,7 +462,9 @@ class S3Hook(AwsBaseHook):
         return prefixes
 
     @provide_bucket_name_async
-    async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]:
+    async def get_file_metadata_async(
+        self, client: AioBaseClient, bucket_name: str, key: str | None = None
+    ) -> list[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.
 
@@ -470,7 +472,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_name: the name of the bucket
         :param key: the path to the key
         """
-        prefix = re.split(r"[\[\*\?]", key, 1)[0]
+        prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
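The guarded `re.split` above derives the literal prefix before the first glob metacharacter so the paginator can narrow the listing, and the new `key=None` default falls back to an empty prefix that lists the whole bucket. A quick illustration (key values are invented):

```python
from __future__ import annotations

import re


def listing_prefix(key: str | None) -> str:
    """Same expression as above: cut the key at the first '[', '*' or '?'."""
    return re.split(r"[\[\*\?]", key, 1)[0] if key else ""


print(listing_prefix("data/2024-01-*/part-[0-9].csv"))  # data/2024-01-
print(listing_prefix(None))                             # '' -> list the whole bucket
```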
@@ -486,6 +488,7 @@ class S3Hook(AwsBaseHook):
         bucket_val: str,
         wildcard_match: bool,
         key: str,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -498,6 +501,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_val: the name of the bucket
         :param key: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
         if wildcard_match:
@@ -505,6 +509,11 @@ class S3Hook(AwsBaseHook):
             key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
             if not key_matches:
                 return False
+        elif use_regex:
+            keys = await self.get_file_metadata_async(client, bucket_name)
+            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
+            if not key_matches:
+                return False
         else:
             obj = await self.get_head_object_async(client, key, bucket_name)
             if obj is None:
@@ -518,6 +527,7 @@ class S3Hook(AwsBaseHook):
         bucket: str,
         bucket_keys: str | list[str],
         wildcard_match: bool,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -530,14 +540,18 @@ class S3Hook(AwsBaseHook):
         :param bucket: the name of the bucket
         :param bucket_keys: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         if isinstance(bucket_keys, list):
             return all(
                 await asyncio.gather(
-                    *(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys)
+                    *(
+                        self._check_key_async(client, bucket, wildcard_match, key, use_regex)
+                        for key in bucket_keys
+                    )
                 )
             )
-        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys)
+        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)
 
     async def check_for_prefix_async(
         self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
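Taken together, the S3 hunks thread a new `use_regex` flag from `check_key_async` down to `_check_key_async`, which lists the bucket and filters keys with `re.match` instead of `fnmatch`. A hedged usage sketch follows; the bucket, pattern, and connection id are placeholders, and it assumes the hook's `async_conn` context manager, used the same way elsewhere in this provider:

```python
import asyncio

from airflow.providers.amazon.aws.hooks.s3 import S3Hook


async def regex_key_exists() -> bool:
    hook = S3Hook(aws_conn_id="aws_default")  # placeholder connection id
    async with hook.async_conn as client:
        # wildcard_match=False together with use_regex=True selects the new re.match branch.
        return await hook.check_key_async(
            client,
            bucket="example-bucket",
            bucket_keys=r"incoming/\d{8}/part-\d+\.csv",
            wildcard_match=False,
            use_regex=True,
        )


if __name__ == "__main__":
    print(asyncio.run(regex_key_exists()))
```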
@@ -98,13 +98,13 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
         super().set_context(ti)
-        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer")
+        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
         self.handler = watchtower.CloudWatchLogHandler(
             log_group_name=self.log_group,
             log_stream_name=self._render_filename(ti, ti.try_number),
             use_queues=not getattr(ti, "is_trigger_log_context", False),
             boto3_client=self.hook.get_conn(),
-            json_serialize_default=_json_serialize,
+            json_serialize_default=_json_serialize or json_serialize_legacy,
         )
 
     def close(self):
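The handler change reads the `[aws] cloudwatch_task_handler_json_serializer` option with `fallback=None` and only then falls back to `json_serialize_legacy`, so leaving the option unset no longer fails. A hedged sketch of the kind of custom serializer such an option could point at; the function name and behaviour are illustrative, and it assumes watchtower feeds this callable to `json.dumps` as its `default=` hook:

```python
from datetime import datetime
from typing import Any


def isoformat_or_repr(value: Any) -> str:
    """Render datetimes as ISO-8601 and fall back to repr() for anything else."""
    if isinstance(value, datetime):
        return value.isoformat()
    return repr(value)
```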
@@ -230,7 +230,7 @@ class BatchOperator(BaseOperator):
             region_name=self.region_name,
         )
 
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str | None:
         """Submit and monitor an AWS Batch job.
 
         :raises: AirflowException
@@ -238,28 +238,46 @@ class BatchOperator(BaseOperator):
         self.submit_job(context)
 
         if self.deferrable:
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=BatchJobTrigger(
-                    job_id=self.job_id,
-                    waiter_max_attempts=self.max_retries,
-                    aws_conn_id=self.aws_conn_id,
-                    region_name=self.region_name,
-                    waiter_delay=self.poll_interval,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.job_id:
+                raise AirflowException("AWS Batch job - job_id was not found")
+
+            job = self.hook.get_job_description(self.job_id)
+            job_status = job.get("status")
+            if job_status == self.hook.SUCCESS_STATE:
+                self.log.info("Job completed.")
+                return self.job_id
+            elif job_status == self.hook.FAILURE_STATE:
+                raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state")
+            elif job_status in self.hook.INTERMEDIATE_STATES:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=BatchJobTrigger(
+                        job_id=self.job_id,
+                        waiter_max_attempts=self.max_retries,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        waiter_delay=self.poll_interval,
+                    ),
+                    method_name="execute_complete",
+                )
+
+            raise AirflowException(f"Unexpected status: {job_status}")
 
         if self.wait_for_completion:
             self.monitor_job(context)
 
         return self.job_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.info(err_msg)
+            raise AirflowException(err_msg)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")
         return event["job_id"]
 
     def on_kill(self):
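With the expanded `execute` above, a deferrable `BatchOperator` now checks the job state once after submission and only defers when the job is still in an intermediate state. A hedged task-level sketch; job name, queue, definition, connection id, and region are placeholders, not taken from the diff:

```python
from airflow.providers.amazon.aws.operators.batch import BatchOperator

submit_batch_job = BatchOperator(
    task_id="submit_batch_job",
    job_name="example-job",
    job_queue="example-queue",
    job_definition="example-definition",
    deferrable=True,
    poll_interval=30,  # forwarded to the trigger as waiter_delay
    max_retries=1000,  # forwarded to the trigger as waiter_max_attempts
    aws_conn_id="aws_default",
    region_name="us-east-1",
)
```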
@@ -15,66 +15,79 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module contains CloudFormation create/delete stack operators."""
+"""This module contains AWS CloudFormation create/delete stack operators."""
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class CloudFormationCreateStackOperator(BaseOperator):
+class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that creates a CloudFormation stack.
+    An operator that creates a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:CloudFormationCreateStackOperator`
 
     :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to CloudFormation.
-    :param aws_conn_id: aws connection to uses
+    :param cloudformation_parameters: parameters to be passed to AWS CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters")
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name", "cloudformation_parameters")
     ui_color = "#6b9659"
 
-    def __init__(
-        self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs
-    ):
+    def __init__(self, *, stack_name: str, cloudformation_parameters: dict, **kwargs):
         super().__init__(**kwargs)
         self.stack_name = stack_name
         self.cloudformation_parameters = cloudformation_parameters
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters)
+        self.hook.create_stack(self.stack_name, self.cloudformation_parameters)
 
 
-class CloudFormationDeleteStackOperator(BaseOperator):
+class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that deletes a CloudFormation stack.
-
-    :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to CloudFormation.
+    An operator that deletes a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:CloudFormationDeleteStackOperator`
 
-    :param aws_conn_id: aws connection to uses
+    :param stack_name: stack name (templated)
+    :param cloudformation_parameters: parameters to be passed to CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name")
     ui_color = "#1d472b"
     ui_fgcolor = "#FFF"
 
@@ -93,6 +106,4 @@ class CloudFormationDeleteStackOperator(BaseOperator):
 
     def execute(self, context: Context):
         self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters)
+        self.hook.delete_stack(self.stack_name, self.cloudformation_parameters)
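After the move to `AwsBaseOperator[CloudFormationHook]` above, both operators inherit the common AWS arguments (`aws_conn_id`, `region_name`, `verify`, `botocore_config`) from the base class instead of building hooks by hand. A hedged sketch; the stack name, template URL, and region are placeholders, and it assumes `cloudformation_parameters` is passed through to the underlying boto3 `create_stack` call:

```python
from airflow.providers.amazon.aws.operators.cloud_formation import (
    CloudFormationCreateStackOperator,
    CloudFormationDeleteStackOperator,
)

create_stack = CloudFormationCreateStackOperator(
    task_id="create_stack",
    stack_name="example-stack",
    cloudformation_parameters={
        "TemplateURL": "https://example-bucket.s3.amazonaws.com/template.yaml",
        "TimeoutInMinutes": 15,
    },
    region_name="us-east-1",
    botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},
)

delete_stack = CloudFormationDeleteStackOperator(
    task_id="delete_stack",
    stack_name="example-stack",
    region_name="us-east-1",
)
```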
@@ -19,20 +19,20 @@ from __future__ import annotations
 
 import logging
 import random
-from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated.classic import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class DataSyncOperator(BaseOperator):
+class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
     """Find, Create, Update, Execute and Delete AWS DataSync Tasks.
 
     If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
@@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator):
         environment. The default behavior is to create a new Task if there are 0, or
         execute the Task if there was 1 Task, or fail if there were many Tasks.
 
-    :param aws_conn_id: AWS connection to use.
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :param max_iterations: Maximum number of
@@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator):
         ``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)``
     :param delete_task_after_execution: If True then the TaskArn which was executed
         will be deleted from AWS DataSync on successful completion.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     :raises AirflowException: If ``task_arn`` was not specified, or if
         either ``source_location_uri`` or ``destination_location_uri`` were
         not specified.
@@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator):
     :raises AirflowException: If Task creation, update, execution or delete fails.
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = DataSyncHook
+    template_fields: Sequence[str] = aws_template_fields(
         "task_arn",
         "source_location_uri",
         "destination_location_uri",
@@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator):
     def __init__(
         self,
         *,
-        aws_conn_id: str = "aws_default",
         wait_interval_seconds: int = 30,
         max_iterations: int = 60,
         wait_for_completion: bool = True,
@@ -142,7 +151,6 @@ class DataSyncOperator(BaseOperator):
         super().__init__(**kwargs)
 
         # Assignments
-        self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
         self.max_iterations = max_iterations
         self.wait_for_completion = wait_for_completion
@@ -185,16 +193,9 @@ class DataSyncOperator(BaseOperator):
         self.destination_location_arn: str | None = None
         self.task_execution_arn: str | None = None
 
-    @cached_property
-    def hook(self) -> DataSyncHook:
-        """Create and return DataSyncHook.
-
-        :return DataSyncHook: An DataSyncHook instance.
-        """
-        return DataSyncHook(
-            aws_conn_id=self.aws_conn_id,
-            wait_interval_seconds=self.wait_interval_seconds,
-        )
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
 
     @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
     def get_hook(self) -> DataSyncHook:
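The last hunk replaces the hand-rolled cached `hook` property with the `_hook_parameters` extension point from `AwsBaseOperator`: the base class builds `aws_hook_class` from those keyword arguments, so subclasses only append their extras. An illustrative subclass following the same pattern; the class itself is hypothetical and not part of the provider:

```python
from typing import Any

from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator


class MyDataSyncLikeOperator(AwsBaseOperator[DataSyncHook]):
    """Hypothetical operator showing how extra hook kwargs are injected."""

    aws_hook_class = DataSyncHook

    def __init__(self, *, wait_interval_seconds: int = 30, **kwargs):
        super().__init__(**kwargs)
        self.wait_interval_seconds = wait_interval_seconds

    @property
    def _hook_parameters(self) -> dict[str, Any]:
        # aws_conn_id, region_name, verify, and the botocore config come from
        # AwsBaseOperator; only the DataSync-specific setting is appended here.
        return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
```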
|