apache-airflow-providers-amazon 8.16.0__py3-none-any.whl → 8.17.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
Files changed (46)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -0
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +34 -19
  4. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +44 -1
  5. airflow/providers/amazon/aws/auth_manager/cli/__init__.py +16 -0
  6. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +178 -0
  7. airflow/providers/amazon/aws/auth_manager/cli/definition.py +62 -0
  8. airflow/providers/amazon/aws/auth_manager/cli/schema.json +171 -0
  9. airflow/providers/amazon/aws/auth_manager/constants.py +1 -0
  10. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +77 -23
  11. airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +17 -0
  12. airflow/providers/amazon/aws/executors/ecs/utils.py +1 -1
  13. airflow/providers/amazon/aws/executors/utils/__init__.py +16 -0
  14. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +60 -0
  15. airflow/providers/amazon/aws/hooks/athena_sql.py +168 -0
  16. airflow/providers/amazon/aws/hooks/base_aws.py +14 -0
  17. airflow/providers/amazon/aws/hooks/quicksight.py +33 -18
  18. airflow/providers/amazon/aws/hooks/redshift_data.py +66 -17
  19. airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -1
  20. airflow/providers/amazon/aws/hooks/s3.py +18 -4
  21. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  22. airflow/providers/amazon/aws/operators/batch.py +33 -15
  23. airflow/providers/amazon/aws/operators/cloud_formation.py +37 -26
  24. airflow/providers/amazon/aws/operators/datasync.py +19 -18
  25. airflow/providers/amazon/aws/operators/dms.py +57 -69
  26. airflow/providers/amazon/aws/operators/ec2.py +19 -5
  27. airflow/providers/amazon/aws/operators/emr.py +30 -10
  28. airflow/providers/amazon/aws/operators/eventbridge.py +57 -80
  29. airflow/providers/amazon/aws/operators/quicksight.py +17 -24
  30. airflow/providers/amazon/aws/operators/redshift_data.py +68 -19
  31. airflow/providers/amazon/aws/operators/s3.py +1 -1
  32. airflow/providers/amazon/aws/operators/sagemaker.py +42 -12
  33. airflow/providers/amazon/aws/sensors/cloud_formation.py +30 -25
  34. airflow/providers/amazon/aws/sensors/dms.py +31 -24
  35. airflow/providers/amazon/aws/sensors/dynamodb.py +15 -15
  36. airflow/providers/amazon/aws/sensors/quicksight.py +34 -24
  37. airflow/providers/amazon/aws/sensors/redshift_cluster.py +41 -3
  38. airflow/providers/amazon/aws/sensors/s3.py +13 -8
  39. airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
  40. airflow/providers/amazon/aws/triggers/redshift_data.py +113 -0
  41. airflow/providers/amazon/aws/triggers/s3.py +9 -4
  42. airflow/providers/amazon/get_provider_info.py +55 -16
  43. {apache_airflow_providers_amazon-8.16.0.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/METADATA +15 -13
  44. {apache_airflow_providers_amazon-8.16.0.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/RECORD +46 -38
  45. {apache_airflow_providers_amazon-8.16.0.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/WHEEL +0 -0
  46. {apache_airflow_providers_amazon-8.16.0.dist-info → apache_airflow_providers_amazon-8.17.0.dist-info}/entry_points.txt +0 -0
@@ -26,6 +26,21 @@ from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
     from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa
+    from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef
+
+FINISHED_STATE = "FINISHED"
+FAILED_STATE = "FAILED"
+ABORTED_STATE = "ABORTED"
+FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
+RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+
+
+class RedshiftDataQueryFailedError(ValueError):
+    """Raise an error that redshift data query failed."""
+
+
+class RedshiftDataQueryAbortedError(ValueError):
+    """Raise an error that redshift data query was aborted."""
 
 
 class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
@@ -108,27 +123,40 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         return statement_id
 
-    def wait_for_results(self, statement_id, poll_interval):
+    def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
         while True:
             self.log.info("Polling statement %s", statement_id)
-            resp = self.conn.describe_statement(
-                Id=statement_id,
-            )
-            status = resp["Status"]
-            if status == "FINISHED":
-                num_rows = resp.get("ResultRows")
-                if num_rows is not None:
-                    self.log.info("Processed %s rows", num_rows)
-                return status
-            elif status in ("FAILED", "ABORTED"):
-                raise ValueError(
-                    f"Statement {statement_id!r} terminated with status {status}. "
-                    f"Response details: {pformat(resp)}"
-                )
-            else:
-                self.log.info("Query %s", status)
+            is_finished = self.check_query_is_finished(statement_id)
+            if is_finished:
+                return FINISHED_STATE
+
             time.sleep(poll_interval)
 
+    def check_query_is_finished(self, statement_id: str) -> bool:
+        """Check whether query finished, raise exception is failed."""
+        resp = self.conn.describe_statement(Id=statement_id)
+        return self.parse_statement_resposne(resp)
+
+    def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool:
+        """Parse the response of describe_statement."""
+        status = resp["Status"]
+        if status == FINISHED_STATE:
+            num_rows = resp.get("ResultRows")
+            if num_rows is not None:
+                self.log.info("Processed %s rows", num_rows)
+            return True
+        elif status in FAILURE_STATES:
+            exception_cls = (
+                RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
+            )
+            raise exception_cls(
+                f"Statement {resp['Id']} terminated with status {status}. "
+                f"Response details: {pformat(resp)}"
+            )
+
+        self.log.info("Query status: %s", status)
+        return False
+
     def get_table_primary_key(
         self,
         table: str,
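
The refactor above splits the old inline polling loop into `check_query_is_finished` and `parse_statement_resposne`, and replaces the bare `ValueError` with the two dedicated exception classes added at the top of the module. A minimal, illustrative sketch of the resulting sync API (the connection id and statement id are placeholders; running it requires real AWS credentials):

    from airflow.providers.amazon.aws.hooks.redshift_data import (
        RedshiftDataHook,
        RedshiftDataQueryAbortedError,
        RedshiftDataQueryFailedError,
    )

    hook = RedshiftDataHook(aws_conn_id="aws_default")

    # The statement id would normally come from the hook's query-submission helper
    # (the method ending in ``return statement_id`` above); a placeholder is used here.
    statement_id = "00000000-0000-0000-0000-000000000000"

    try:
        # Polls describe_statement every 10 seconds and returns "FINISHED" on success.
        status = hook.wait_for_results(statement_id, poll_interval=10)
    except RedshiftDataQueryFailedError as err:
        print(f"Query failed: {err}")
    except RedshiftDataQueryAbortedError as err:
        print(f"Query aborted: {err}")
    else:
        print(status)
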
@@ -201,3 +229,24 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 break
 
         return pk_columns or None
+
+    async def is_still_running(self, statement_id: str) -> bool:
+        """Async function to check whether the query is still running.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            desc = await client.describe_statement(Id=statement_id)
+            return desc["Status"] in RUNNING_STATES
+
+    async def check_query_is_finished_async(self, statement_id: str) -> bool:
+        """Async function to check statement is finished.
+
+        It takes statement_id, makes async connection to redshift data to get the query status
+        by statement_id and returns the query status.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            resp = await client.describe_statement(Id=statement_id)
+            return self.parse_statement_resposne(resp)
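
The new async helpers mirror the synchronous path and are what a deferrable workflow can poll; the new airflow/providers/amazon/aws/triggers/redshift_data.py listed above presumably builds on them. A rough standalone sketch (hypothetical statement id, real AWS credentials required):

    import asyncio

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook


    async def poll_statement(statement_id: str) -> bool:
        hook = RedshiftDataHook(aws_conn_id="aws_default")
        # Wait while the statement is in PICKED/STARTED/SUBMITTED, then check the
        # final state; failures raise the exception classes defined in this module.
        while await hook.is_still_running(statement_id):
            await asyncio.sleep(10)
        return await hook.check_query_is_finished_async(statement_id)


    # Placeholder statement id, for illustration only.
    print(asyncio.run(poll_statement("00000000-0000-0000-0000-000000000000")))
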
@@ -239,7 +239,7 @@ class RedshiftSQLHook(DbApiHook):
 
     def _get_identifier_from_hostname(self, hostname: str) -> str:
         parts = hostname.split(".")
-        if "amazonaws.com" in hostname and len(parts) == 6:
+        if hostname.endswith("amazonaws.com") and len(parts) == 6:
             return f"{parts[0]}.{parts[2]}"
         else:
             self.log.debug(
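
The stricter suffix check stops hostnames that merely contain `amazonaws.com` somewhere in the middle from being treated as managed Redshift endpoints. A small standalone sketch that mirrors (rather than calls) the parsing logic:

    from __future__ import annotations


    def identifier_from_hostname(hostname: str) -> str | None:
        # Mirrors the updated check: amazonaws.com suffix plus the 6-label cluster hostname shape.
        parts = hostname.split(".")
        if hostname.endswith("amazonaws.com") and len(parts) == 6:
            return f"{parts[0]}.{parts[2]}"
        return None


    # Managed cluster endpoint: cluster id and region are extracted.
    print(identifier_from_hostname("my-cluster.abc123xyz.eu-west-1.redshift.amazonaws.com"))
    # -> my-cluster.eu-west-1

    # Six labels and contains "amazonaws.com", but does not end with it:
    # the old "in" check accepted this shape, the new endswith() check does not.
    print(identifier_from_hostname("foo.amazonaws.com.internal.corp.example"))
    # -> None
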
@@ -462,7 +462,9 @@ class S3Hook(AwsBaseHook):
         return prefixes
 
     @provide_bucket_name_async
-    async def get_file_metadata_async(self, client: AioBaseClient, bucket_name: str, key: str) -> list[Any]:
+    async def get_file_metadata_async(
+        self, client: AioBaseClient, bucket_name: str, key: str | None = None
+    ) -> list[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.
 
@@ -470,7 +472,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_name: the name of the bucket
         :param key: the path to the key
         """
-        prefix = re.split(r"[\[*?]", key, 1)[0]
+        prefix = re.split(r"[\[\*\?]", key, 1)[0] if key else ""
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
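
The escaped character class behaves like the old one (the key is still cut at the first `[`, `*` or `?` to build the non-wildcard prefix), while the `if key else ""` guard supports the now-optional `key` argument. For example:

    from __future__ import annotations

    import re


    def wildcard_prefix(key: str | None = None) -> str:
        # Same expression as in the hunk above: everything before the first glob metacharacter.
        return re.split(r"[\[\*\?]", key, 1)[0] if key else ""


    print(wildcard_prefix("data/2024/*/part-?.csv"))  # -> data/2024/
    print(wildcard_prefix())                          # -> "" (lists from the bucket root)
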
@@ -486,6 +488,7 @@ class S3Hook(AwsBaseHook):
         bucket_val: str,
         wildcard_match: bool,
         key: str,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -498,6 +501,7 @@ class S3Hook(AwsBaseHook):
         :param bucket_val: the name of the bucket
         :param key: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
         if wildcard_match:
@@ -505,6 +509,11 @@ class S3Hook(AwsBaseHook):
             key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
             if not key_matches:
                 return False
+        elif use_regex:
+            keys = await self.get_file_metadata_async(client, bucket_name)
+            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
+            if not key_matches:
+                return False
         else:
             obj = await self.get_head_object_async(client, key, bucket_name)
             if obj is None:
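
Note that the two matching modes have different semantics: `fnmatch` glob patterns must cover the whole key, while `re.match` is only anchored at the start of the key. A quick standalone comparison:

    import fnmatch
    import re

    key_in_bucket = "logs/2024/app.log.gz"

    # Glob match: the pattern has to match the entire key.
    print(fnmatch.fnmatch(key_in_bucket, "logs/*/app.log"))   # False (".gz" is left over)
    print(fnmatch.fnmatch(key_in_bucket, "logs/*/app.log*"))  # True

    # Regex match: anchored at the start only, so trailing characters are allowed.
    print(bool(re.match(r"logs/\d{4}/app\.log", key_in_bucket)))  # True
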
@@ -518,6 +527,7 @@ class S3Hook(AwsBaseHook):
         bucket: str,
         bucket_keys: str | list[str],
         wildcard_match: bool,
+        use_regex: bool = False,
     ) -> bool:
         """
         Get a list of files that a key matching a wildcard expression or get the head object.
@@ -530,14 +540,18 @@ class S3Hook(AwsBaseHook):
         :param bucket: the name of the bucket
         :param bucket_keys: S3 keys that will point to the file
         :param wildcard_match: the path to the key
+        :param use_regex: whether to use regex to check bucket
         """
         if isinstance(bucket_keys, list):
             return all(
                 await asyncio.gather(
-                    *(self._check_key_async(client, bucket, wildcard_match, key) for key in bucket_keys)
+                    *(
+                        self._check_key_async(client, bucket, wildcard_match, key, use_regex)
+                        for key in bucket_keys
+                    )
                 )
             )
-        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys)
+        return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)
 
     async def check_for_prefix_async(
         self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
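
With the signature changes above, regex-based key checks are available on the async path as well; the changed sensors/s3.py and triggers/s3.py entries in the file list are presumably where the flag is plumbed through for the deferrable S3 sensors. A minimal sketch of calling the hook directly (bucket name and pattern are placeholders; requires real AWS credentials):

    import asyncio

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook


    async def key_matches_pattern() -> bool:
        hook = S3Hook(aws_conn_id="aws_default")
        # Same pattern as the S3 triggers use: async_conn yields an aiobotocore client.
        async with hook.async_conn as client:
            return await hook.check_key_async(
                client,
                bucket="example-bucket",
                bucket_keys=r"logs/\d{4}/app\.log",
                wildcard_match=False,
                use_regex=True,
            )


    print(asyncio.run(key_matches_pattern()))
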
@@ -98,13 +98,13 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None):
         super().set_context(ti)
-        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer")
+        _json_serialize = conf.getimport("aws", "cloudwatch_task_handler_json_serializer", fallback=None)
         self.handler = watchtower.CloudWatchLogHandler(
             log_group_name=self.log_group,
             log_stream_name=self._render_filename(ti, ti.try_number),
             use_queues=not getattr(ti, "is_trigger_log_context", False),
             boto3_client=self.hook.get_conn(),
-            json_serialize_default=_json_serialize,
+            json_serialize_default=_json_serialize or json_serialize_legacy,
         )
 
     def close(self):
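
The handler now tolerates a missing `cloudwatch_task_handler_json_serializer` option by falling back to the legacy serializer. When the option is set it must name an importable callable; a hedged sketch of what such a callable could look like (the section and option names come from the diff, the module path is hypothetical):

    # my_company/logging_utils.py (hypothetical module referenced from airflow.cfg)
    from __future__ import annotations

    import json
    from typing import Any


    def safe_json_serialize(value: Any) -> str | None:
        """Serialize structured log payloads, degrading to repr() for non-JSON-able objects."""
        try:
            return json.dumps(value)
        except (TypeError, ValueError):
            return repr(value)


    # airflow.cfg (hypothetical), read by conf.getimport("aws", ...):
    # [aws]
    # cloudwatch_task_handler_json_serializer = my_company.logging_utils.safe_json_serialize
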
@@ -230,7 +230,7 @@ class BatchOperator(BaseOperator):
             region_name=self.region_name,
         )
 
-    def execute(self, context: Context):
+    def execute(self, context: Context) -> str | None:
         """Submit and monitor an AWS Batch job.
 
         :raises: AirflowException
@@ -238,28 +238,46 @@ class BatchOperator(BaseOperator):
         self.submit_job(context)
 
         if self.deferrable:
-            self.defer(
-                timeout=self.execution_timeout,
-                trigger=BatchJobTrigger(
-                    job_id=self.job_id,
-                    waiter_max_attempts=self.max_retries,
-                    aws_conn_id=self.aws_conn_id,
-                    region_name=self.region_name,
-                    waiter_delay=self.poll_interval,
-                ),
-                method_name="execute_complete",
-            )
+            if not self.job_id:
+                raise AirflowException("AWS Batch job - job_id was not found")
+
+            job = self.hook.get_job_description(self.job_id)
+            job_status = job.get("status")
+            if job_status == self.hook.SUCCESS_STATE:
+                self.log.info("Job completed.")
+                return self.job_id
+            elif job_status == self.hook.FAILURE_STATE:
+                raise AirflowException(f"Error while running job: {self.job_id} is in {job_status} state")
+            elif job_status in self.hook.INTERMEDIATE_STATES:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=BatchJobTrigger(
+                        job_id=self.job_id,
+                        waiter_max_attempts=self.max_retries,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        waiter_delay=self.poll_interval,
+                    ),
+                    method_name="execute_complete",
+                )
+
+            raise AirflowException(f"Unexpected status: {job_status}")
 
         if self.wait_for_completion:
             self.monitor_job(context)
 
         return self.job_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.info(err_msg)
+            raise AirflowException(err_msg)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")
         return event["job_id"]
 
     def on_kill(self):
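
With this change a deferrable `BatchOperator` inspects the job state once right after submission: a terminal success returns immediately, a failure raises, intermediate states defer to `BatchJobTrigger`, and anything else raises an "Unexpected status" error. A minimal DAG-level sketch (queue and definition names are placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import BatchOperator

    with DAG("example_batch_deferrable", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        submit_job = BatchOperator(
            task_id="submit_batch_job",
            job_name="example-job",               # placeholder values
            job_queue="example-queue",
            job_definition="example-definition:1",
            deferrable=True,                      # takes the new code path above
            poll_interval=30,
            max_retries=60,
        )

execute() now also returns the job id (or raises) instead of always deferring, which is what the new `-> str | None` annotation reflects.
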
@@ -15,66 +15,79 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""This module contains CloudFormation create/delete stack operators."""
+"""This module contains AWS CloudFormation create/delete stack operators."""
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class CloudFormationCreateStackOperator(BaseOperator):
+class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that creates a CloudFormation stack.
+    An operator that creates a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:CloudFormationCreateStackOperator`
 
     :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to CloudFormation.
-    :param aws_conn_id: aws connection to uses
+    :param cloudformation_parameters: parameters to be passed to AWS CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters")
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name", "cloudformation_parameters")
     ui_color = "#6b9659"
 
-    def __init__(
-        self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs
-    ):
+    def __init__(self, *, stack_name: str, cloudformation_parameters: dict, **kwargs):
         super().__init__(**kwargs)
         self.stack_name = stack_name
         self.cloudformation_parameters = cloudformation_parameters
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context):
         self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters)
+        self.hook.create_stack(self.stack_name, self.cloudformation_parameters)
 
 
-class CloudFormationDeleteStackOperator(BaseOperator):
+class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that deletes a CloudFormation stack.
-
-    :param stack_name: stack name (templated)
-    :param cloudformation_parameters: parameters to be passed to CloudFormation.
+    An operator that deletes a AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:CloudFormationDeleteStackOperator`
 
-    :param aws_conn_id: aws connection to uses
+    :param stack_name: stack name (templated)
+    :param cloudformation_parameters: parameters to be passed to CloudFormation.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("stack_name",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = CloudFormationHook
+    template_fields: Sequence[str] = aws_template_fields("stack_name")
     ui_color = "#1d472b"
     ui_fgcolor = "#FFF"
 
@@ -93,6 +106,4 @@ class CloudFormationDeleteStackOperator(BaseOperator):
 
     def execute(self, context: Context):
         self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)
-
-        cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
-        cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters)
+        self.hook.delete_stack(self.stack_name, self.cloudformation_parameters)
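
Since both operators now extend `AwsBaseOperator[CloudFormationHook]`, the AWS connection, region and botocore settings are handled by the shared base class instead of a hand-rolled `aws_conn_id` attribute, and the hook is available as `self.hook`. A short usage sketch (stack name, template URL and region are placeholders; the delete operator's `cloudformation_parameters` default is not shown in this hunk, so it is omitted here):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.cloud_formation import (
        CloudFormationCreateStackOperator,
        CloudFormationDeleteStackOperator,
    )

    with DAG("example_cloud_formation", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        create_stack = CloudFormationCreateStackOperator(
            task_id="create_stack",
            stack_name="example-stack",
            cloudformation_parameters={
                "TemplateURL": "https://example-bucket.s3.amazonaws.com/template.yaml",
                "TimeoutInMinutes": 15,
            },
            region_name="eu-west-1",  # now handled by the AwsBaseOperator base class
        )

        delete_stack = CloudFormationDeleteStackOperator(
            task_id="delete_stack",
            stack_name="example-stack",
        )

        create_stack >> delete_stack
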
@@ -19,20 +19,20 @@ from __future__ import annotations
 
 import logging
 import random
-from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated.classic import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class DataSyncOperator(BaseOperator):
+class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
     """Find, Create, Update, Execute and Delete AWS DataSync Tasks.
 
     If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
@@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator):
     environment. The default behavior is to create a new Task if there are 0, or
     execute the Task if there was 1 Task, or fail if there were many Tasks.
 
-    :param aws_conn_id: AWS connection to use.
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :param max_iterations: Maximum number of
@@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator):
         ``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)``
     :param delete_task_after_execution: If True then the TaskArn which was executed
         will be deleted from AWS DataSync on successful completion.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     :raises AirflowException: If ``task_arn`` was not specified, or if
         either ``source_location_uri`` or ``destination_location_uri`` were
         not specified.
@@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator):
     :raises AirflowException: If Task creation, update, execution or delete fails.
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = DataSyncHook
+    template_fields: Sequence[str] = aws_template_fields(
         "task_arn",
         "source_location_uri",
         "destination_location_uri",
@@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator):
     def __init__(
         self,
         *,
-        aws_conn_id: str = "aws_default",
         wait_interval_seconds: int = 30,
         max_iterations: int = 60,
         wait_for_completion: bool = True,
@@ -142,7 +151,6 @@ class DataSyncOperator(BaseOperator):
         super().__init__(**kwargs)
 
         # Assignments
-        self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
         self.max_iterations = max_iterations
         self.wait_for_completion = wait_for_completion
@@ -185,16 +193,9 @@ class DataSyncOperator(BaseOperator):
         self.destination_location_arn: str | None = None
         self.task_execution_arn: str | None = None
 
-    @cached_property
-    def hook(self) -> DataSyncHook:
-        """Create and return DataSyncHook.
-
-        :return DataSyncHook: An DataSyncHook instance.
-        """
-        return DataSyncHook(
-            aws_conn_id=self.aws_conn_id,
-            wait_interval_seconds=self.wait_interval_seconds,
-        )
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
 
     @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
     def get_hook(self) -> DataSyncHook:
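
The `_hook_parameters` property is the hook-construction kwargs used by the `AwsBaseOperator` machinery, so overriding it (rather than keeping the removed `cached_property`) is how the operator forwards `wait_interval_seconds` to `DataSyncHook`. From a DAG author's point of view the operator is used as before, with the connection and region options now coming from the shared base class. A brief sketch (the task ARN is a placeholder):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

    with DAG("example_datasync", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        execute_task = DataSyncOperator(
            task_id="execute_datasync_task",
            task_arn="arn:aws:datasync:eu-west-1:111122223333:task/task-0123456789abcdef0",  # placeholder
            wait_interval_seconds=30,  # forwarded to DataSyncHook via _hook_parameters
            aws_conn_id="aws_default",  # handled by AwsBaseOperator
        )
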