apache-airflow-providers-amazon 8.25.0rc1__py3-none-any.whl → 8.26.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (50)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
  3. airflow/providers/amazon/aws/executors/batch/batch_executor.py +19 -16
  4. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +22 -15
  5. airflow/providers/amazon/aws/hooks/athena.py +18 -9
  6. airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
  7. airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
  8. airflow/providers/amazon/aws/hooks/chime.py +2 -1
  9. airflow/providers/amazon/aws/hooks/datasync.py +6 -3
  10. airflow/providers/amazon/aws/hooks/ecr.py +2 -1
  11. airflow/providers/amazon/aws/hooks/ecs.py +12 -6
  12. airflow/providers/amazon/aws/hooks/glacier.py +8 -4
  13. airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
  14. airflow/providers/amazon/aws/hooks/logs.py +4 -2
  15. airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
  16. airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
  17. airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
  18. airflow/providers/amazon/aws/hooks/s3.py +70 -53
  19. airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
  20. airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
  21. airflow/providers/amazon/aws/hooks/sts.py +2 -1
  22. airflow/providers/amazon/aws/operators/athena.py +21 -8
  23. airflow/providers/amazon/aws/operators/batch.py +12 -6
  24. airflow/providers/amazon/aws/operators/datasync.py +2 -1
  25. airflow/providers/amazon/aws/operators/ecs.py +1 -0
  26. airflow/providers/amazon/aws/operators/emr.py +6 -86
  27. airflow/providers/amazon/aws/operators/glue.py +4 -2
  28. airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
  29. airflow/providers/amazon/aws/operators/neptune.py +2 -1
  30. airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
  31. airflow/providers/amazon/aws/operators/s3.py +11 -1
  32. airflow/providers/amazon/aws/operators/sagemaker.py +8 -10
  33. airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
  34. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
  35. airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
  36. airflow/providers/amazon/aws/sensors/s3.py +11 -5
  37. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
  38. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
  39. airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
  40. airflow/providers/amazon/aws/triggers/ecs.py +3 -1
  41. airflow/providers/amazon/aws/triggers/glue.py +15 -3
  42. airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
  43. airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
  44. airflow/providers/amazon/aws/utils/mixins.py +2 -1
  45. airflow/providers/amazon/aws/utils/redshift.py +2 -1
  46. airflow/providers/amazon/get_provider_info.py +2 -1
  47. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/METADATA +9 -9
  48. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/RECORD +50 -50
  49. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/WHEEL +0 -0
  50. {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/entry_points.txt +0 -0
@@ -27,7 +27,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBas
 
 
 class RedshiftHook(AwsBaseHook):
-    """Interact with Amazon Redshift.
+    """
+    Interact with Amazon Redshift.
 
     This is a thin wrapper around
     :external+boto3:py:class:`boto3.client("redshift") <Redshift.Client>`.
@@ -53,7 +54,8 @@ class RedshiftHook(AwsBaseHook):
         master_user_password: str,
         params: dict[str, Any],
     ) -> dict[str, Any]:
-        """Create a new cluster with the specified parameters.
+        """
+        Create a new cluster with the specified parameters.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.create_cluster`
@@ -80,7 +82,8 @@ class RedshiftHook(AwsBaseHook):
 
     # TODO: Wrap create_cluster_snapshot
     def cluster_status(self, cluster_identifier: str) -> str:
-        """Get status of a cluster.
+        """
+        Get status of a cluster.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.describe_clusters`
@@ -101,7 +104,8 @@ class RedshiftHook(AwsBaseHook):
         skip_final_cluster_snapshot: bool = True,
         final_cluster_snapshot_identifier: str | None = None,
     ):
-        """Delete a cluster and optionally create a snapshot.
+        """
+        Delete a cluster and optionally create a snapshot.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.delete_cluster`
@@ -120,7 +124,8 @@ class RedshiftHook(AwsBaseHook):
         return response["Cluster"] if response["Cluster"] else None
 
     def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str] | None:
-        """List snapshots for a cluster.
+        """
+        List snapshots for a cluster.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.describe_cluster_snapshots`
@@ -136,7 +141,8 @@ class RedshiftHook(AwsBaseHook):
         return snapshots
 
     def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str:
-        """Restore a cluster from its snapshot.
+        """
+        Restore a cluster from its snapshot.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.restore_from_cluster_snapshot`
@@ -156,7 +162,8 @@ class RedshiftHook(AwsBaseHook):
         retention_period: int = -1,
         tags: list[Any] | None = None,
     ) -> str:
-        """Create a snapshot of a cluster.
+        """
+        Create a snapshot of a cluster.
 
         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.create_cluster_snapshot`
@@ -178,7 +185,8 @@ class RedshiftHook(AwsBaseHook):
         return response["Snapshot"] if response["Snapshot"] else None
 
     def get_cluster_snapshot_status(self, snapshot_identifier: str):
-        """Get Redshift cluster snapshot status.
+        """
+        Get Redshift cluster snapshot status.
 
         If cluster snapshot not found, *None* is returned.
 
@@ -210,7 +218,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         super().__init__(*args, **kwargs)
 
     async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]:
-        """Get the cluster status.
+        """
+        Get the cluster status.
 
         :param cluster_identifier: unique identifier of a cluster
        :param delete_operation: whether the method has been called as part of delete cluster operation
@@ -228,7 +237,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         return {"status": "error", "message": str(error)}
 
     async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]:
-        """Pause the cluster.
+        """
+        Pause the cluster.
 
         :param cluster_identifier: unique identifier of a cluster
         :param poll_interval: polling period in seconds to check for the status
@@ -255,7 +265,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         cluster_identifier: str,
         polling_period_seconds: float = 5.0,
     ) -> dict[str, Any]:
-        """Resume the cluster.
+        """
+        Resume the cluster.
 
         :param cluster_identifier: unique identifier of a cluster
         :param polling_period_seconds: polling period in seconds to check for the status
@@ -284,7 +295,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         flag: asyncio.Event,
         delete_operation: bool = False,
     ) -> dict[str, Any]:
-        """Check for expected Redshift cluster state.
+        """
+        Check for expected Redshift cluster state.
 
         :param cluster_identifier: unique identifier of a cluster
         :param expected_state: expected_state example("available", "pausing", "paused"")
@@ -231,7 +231,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         return pk_columns or None
 
     async def is_still_running(self, statement_id: str) -> bool:
-        """Async function to check whether the query is still running.
+        """
+        Async function to check whether the query is still running.
 
         :param statement_id: the UUID of the statement
         """
@@ -240,7 +241,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         return desc["Status"] in RUNNING_STATES
 
     async def check_query_is_finished_async(self, statement_id: str) -> bool:
-        """Async function to check statement is finished.
+        """
+        Async function to check statement is finished.
 
         It takes statement_id, makes async connection to redshift data to get the query status
         by statement_id and returns the query status.
@@ -39,7 +39,8 @@ if TYPE_CHECKING:
 
 
 class RedshiftSQLHook(DbApiHook):
-    """Execute statements against Amazon Redshift.
+    """
+    Execute statements against Amazon Redshift.
 
     This hook requires the redshift_conn_id connection.
 
@@ -103,7 +104,8 @@ class RedshiftSQLHook(DbApiHook):
         return conn_params
 
     def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
-        """Retrieve a temporary password to connect to Redshift.
+        """
+        Retrieve a temporary password to connect to Redshift.
 
         Port is required. If none is provided, default is used for each service.
         """
@@ -177,7 +179,8 @@ class RedshiftSQLHook(DbApiHook):
         return create_engine(self.get_uri(), **engine_kwargs)
 
     def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
-        """Get the table's primary key.
+        """
+        Get the table's primary key.
 
         :param table: Name of the target table
         :param schema: Name of the target schema, public by default
@@ -22,6 +22,7 @@ from __future__ import annotations
 import asyncio
 import fnmatch
 import gzip as gz
+import inspect
 import logging
 import os
 import re
@@ -36,7 +37,7 @@ from inspect import signature
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
@@ -61,42 +62,24 @@ from airflow.utils.helpers import chunks
 logger = logging.getLogger(__name__)
 
 
+# Explicit value that would remove ACLs from a copy
+# No conflicts with Canned ACLs:
+# https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
+NO_ACL = "no-acl"
+
+
 def provide_bucket_name(func: Callable) -> Callable:
     """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
         logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
-    function_signature = signature(func)
-
-    @wraps(func)
-    def wrapper(*args, **kwargs) -> Callable:
-        bound_args = function_signature.bind(*args, **kwargs)
-
-        if "bucket_name" not in bound_args.arguments:
-            self = args[0]
-
-            if "bucket_name" in self.service_config:
-                bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
-            elif self.conn_config and self.conn_config.schema:
-                warnings.warn(
-                    "s3 conn_type, and the associated schema field, is deprecated. "
-                    "Please use aws conn_type instead, and specify `bucket_name` "
-                    "in `service_config.s3` within `extras`.",
-                    AirflowProviderDeprecationWarning,
-                    stacklevel=2,
-                )
-                bound_args.arguments["bucket_name"] = self.conn_config.schema
-
-        return func(*bound_args.args, **bound_args.kwargs)
 
-    return wrapper
-
-
-def provide_bucket_name_async(func: Callable) -> Callable:
-    """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     function_signature = signature(func)
+    if "bucket_name" not in function_signature.parameters:
+        raise RuntimeError(
+            "Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
+        )
 
-    @wraps(func)
-    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+    async def maybe_add_bucket_name(*args, **kwargs):
         bound_args = function_signature.bind(*args, **kwargs)
 
         if "bucket_name" not in bound_args.arguments:
@@ -105,8 +88,46 @@ def provide_bucket_name_async(func: Callable) -> Callable:
                 connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
                 if connection.schema:
                     bound_args.arguments["bucket_name"] = connection.schema
+        return bound_args
+
+    if inspect.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            print(f"invoking async function {func=}")
+            return await func(*bound_args.args, **bound_args.kwargs)
+
+    elif inspect.isasyncgenfunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            async for thing in func(*bound_args.args, **bound_args.kwargs):
+                yield thing
+
+    else:
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> Callable:
+            bound_args = function_signature.bind(*args, **kwargs)
+
+            if "bucket_name" not in bound_args.arguments:
+                self = args[0]
+
+                if "bucket_name" in self.service_config:
+                    bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
+                elif self.conn_config and self.conn_config.schema:
+                    warnings.warn(
+                        "s3 conn_type, and the associated schema field, is deprecated. "
+                        "Please use aws conn_type instead, and specify `bucket_name` "
+                        "in `service_config.s3` within `extras`.",
+                        AirflowProviderDeprecationWarning,
+                        stacklevel=2,
+                    )
+                    bound_args.arguments["bucket_name"] = self.conn_config.schema
 
-        return await func(*bound_args.args, **bound_args.kwargs)
+            return func(*bound_args.args, **bound_args.kwargs)
 
     return wrapper
 
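The reworked provide_bucket_name decorator above now dispatches on the kind of callable it wraps. Below is a minimal, self-contained sketch of that dispatch pattern, for illustration only: the decorator and function names are hypothetical, and a constant stands in for the bucket name that the real decorator resolves from the Airflow connection.

    # Hypothetical demo of the dispatch pattern: pick the wrapper based on
    # whether the wrapped callable is a coroutine function, an async generator
    # function, or a plain function.
    import asyncio
    import inspect
    from functools import wraps
    from inspect import signature


    def demo_provide_default(func):
        sig = signature(func)
        if "bucket_name" not in sig.parameters:
            raise RuntimeError("only wrap callables that accept 'bucket_name'")

        def fill(bound):
            # Inject a fallback only when the caller did not pass bucket_name.
            if "bucket_name" not in bound.arguments:
                bound.arguments["bucket_name"] = "fallback-bucket"
            return bound

        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def wrapper(*args, **kwargs):
                bound = fill(sig.bind(*args, **kwargs))
                return await func(*bound.args, **bound.kwargs)

        elif inspect.isasyncgenfunction(func):

            @wraps(func)
            async def wrapper(*args, **kwargs):
                bound = fill(sig.bind(*args, **kwargs))
                async for item in func(*bound.args, **bound.kwargs):
                    yield item

        else:

            @wraps(func)
            def wrapper(*args, **kwargs):
                bound = fill(sig.bind(*args, **kwargs))
                return func(*bound.args, **bound.kwargs)

        return wrapper


    @demo_provide_default
    async def list_keys(bucket_name=None):
        # An async generator, as get_file_metadata_async now is.
        for key in ("a.csv", "b.csv"):
            yield f"s3://{bucket_name}/{key}"


    async def main():
        print([key async for key in list_keys()])  # uses "fallback-bucket"


    asyncio.run(main())
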
@@ -400,8 +421,8 @@ class S3Hook(AwsBaseHook):
 
         return prefixes
 
-    @provide_bucket_name_async
     @unify_bucket_name_and_key
+    @provide_bucket_name
     async def get_head_object_async(
         self, client: AioBaseClient, key: str, bucket_name: str | None = None
     ) -> dict[str, Any] | None:
@@ -462,10 +483,10 @@ class S3Hook(AwsBaseHook):
 
         return prefixes
 
-    @provide_bucket_name_async
+    @provide_bucket_name
     async def get_file_metadata_async(
         self, client: AioBaseClient, bucket_name: str, key: str | None = None
-    ) -> list[Any]:
+    ) -> AsyncIterator[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.
 
@@ -477,11 +498,10 @@ class S3Hook(AwsBaseHook):
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
-        files = []
         async for page in response:
             if "Contents" in page:
-                files += page["Contents"]
-        return files
+                for row in page["Contents"]:
+                    yield row
 
     async def _check_key_async(
         self,
@@ -506,21 +526,16 @@ class S3Hook(AwsBaseHook):
         """
         bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
         if wildcard_match:
-            keys = await self.get_file_metadata_async(client, bucket_name, key)
-            key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
-            if not key_matches:
-                return False
-        elif use_regex:
-            keys = await self.get_file_metadata_async(client, bucket_name)
-            key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
-            if not key_matches:
-                return False
-        else:
-            obj = await self.get_head_object_async(client, key, bucket_name)
-            if obj is None:
-                return False
-
-        return True
+            async for k in self.get_file_metadata_async(client, bucket_name, key):
+                if fnmatch.fnmatch(k["Key"], key):
+                    return True
+            return False
+        if use_regex:
+            async for k in self.get_file_metadata_async(client, bucket_name):
+                if re.match(pattern=key, string=k["Key"]):
+                    return True
+            return False
+        return bool(await self.get_head_object_async(client, key, bucket_name))
 
     async def check_key_async(
         self,
@@ -1276,6 +1291,8 @@ class S3Hook(AwsBaseHook):
             object to be copied which is private by default.
         """
         acl_policy = acl_policy or "private"
+        if acl_policy != NO_ACL:
+            kwargs["ACL"] = acl_policy
 
         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1287,7 +1304,7 @@ class S3Hook(AwsBaseHook):
 
         copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
         response = self.get_conn().copy_object(
-            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy, **kwargs
+            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
         return response
 
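The last two hunks make the ACL optional on copies. Below is a hedged usage sketch of the new NO_ACL sentinel, which keeps the CopyObject request free of any ACL argument (previously the hook always sent one, defaulting to "private"); the connection id and bucket/key names are placeholders.

    from airflow.providers.amazon.aws.hooks.s3 import NO_ACL, S3Hook

    hook = S3Hook(aws_conn_id="aws_default")
    hook.copy_object(
        source_bucket_key="s3://source-bucket/reports/latest.csv",
        dest_bucket_key="s3://dest-bucket/reports/latest.csv",
        acl_policy=NO_ACL,  # omit the ACL entirely instead of the "private" default
    )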