apache-airflow-providers-amazon 8.25.0rc1__py3-none-any.whl → 8.26.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +10 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +19 -16
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +22 -15
- airflow/providers/amazon/aws/hooks/athena.py +18 -9
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -10
- airflow/providers/amazon/aws/hooks/chime.py +2 -1
- airflow/providers/amazon/aws/hooks/datasync.py +6 -3
- airflow/providers/amazon/aws/hooks/ecr.py +2 -1
- airflow/providers/amazon/aws/hooks/ecs.py +12 -6
- airflow/providers/amazon/aws/hooks/glacier.py +8 -4
- airflow/providers/amazon/aws/hooks/kinesis.py +2 -1
- airflow/providers/amazon/aws/hooks/logs.py +4 -2
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +24 -12
- airflow/providers/amazon/aws/hooks/redshift_data.py +4 -2
- airflow/providers/amazon/aws/hooks/redshift_sql.py +6 -3
- airflow/providers/amazon/aws/hooks/s3.py +70 -53
- airflow/providers/amazon/aws/hooks/sagemaker.py +82 -41
- airflow/providers/amazon/aws/hooks/secrets_manager.py +6 -3
- airflow/providers/amazon/aws/hooks/sts.py +2 -1
- airflow/providers/amazon/aws/operators/athena.py +21 -8
- airflow/providers/amazon/aws/operators/batch.py +12 -6
- airflow/providers/amazon/aws/operators/datasync.py +2 -1
- airflow/providers/amazon/aws/operators/ecs.py +1 -0
- airflow/providers/amazon/aws/operators/emr.py +6 -86
- airflow/providers/amazon/aws/operators/glue.py +4 -2
- airflow/providers/amazon/aws/operators/glue_crawler.py +22 -19
- airflow/providers/amazon/aws/operators/neptune.py +2 -1
- airflow/providers/amazon/aws/operators/redshift_cluster.py +2 -1
- airflow/providers/amazon/aws/operators/s3.py +11 -1
- airflow/providers/amazon/aws/operators/sagemaker.py +8 -10
- airflow/providers/amazon/aws/sensors/base_aws.py +2 -1
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +25 -17
- airflow/providers/amazon/aws/sensors/glue_crawler.py +16 -12
- airflow/providers/amazon/aws/sensors/s3.py +11 -5
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +6 -3
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +2 -1
- airflow/providers/amazon/aws/transfers/s3_to_sql.py +2 -1
- airflow/providers/amazon/aws/triggers/ecs.py +3 -1
- airflow/providers/amazon/aws/triggers/glue.py +15 -3
- airflow/providers/amazon/aws/triggers/glue_crawler.py +8 -1
- airflow/providers/amazon/aws/utils/connection_wrapper.py +10 -5
- airflow/providers/amazon/aws/utils/mixins.py +2 -1
- airflow/providers/amazon/aws/utils/redshift.py +2 -1
- airflow/providers/amazon/get_provider_info.py +2 -1
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/METADATA +9 -9
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/RECORD +50 -50
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.25.0rc1.dist-info → apache_airflow_providers_amazon-8.26.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/redshift_cluster.py

@@ -27,7 +27,8 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBas


 class RedshiftHook(AwsBaseHook):
-    """Interact with Amazon Redshift.
+    """
+    Interact with Amazon Redshift.

     This is a thin wrapper around
     :external+boto3:py:class:`boto3.client("redshift") <Redshift.Client>`.
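This hunk, and every remaining hunk in the three Redshift hook files below, applies the same mechanical cleanup: the docstring summary moves from the line carrying the opening `"""` down to the line below it (pydocstyle's multi-line-summary-second-line convention, D213). A minimal before/after sketch of the pattern, with hypothetical function bodies:

```python
# Before: summary shares a line with the opening quotes (D212 style).
def cluster_status_old(cluster_identifier: str) -> str:
    """Get status of a cluster.

    Details follow after a blank line.
    """
    return "available"


# After: summary starts on the second line (D213 style).
def cluster_status_new(cluster_identifier: str) -> str:
    """
    Get status of a cluster.

    Details follow after a blank line.
    """
    return "available"
```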
@@ -53,7 +54,8 @@ class RedshiftHook(AwsBaseHook):
         master_user_password: str,
         params: dict[str, Any],
     ) -> dict[str, Any]:
-        """Create a new cluster with the specified parameters.
+        """
+        Create a new cluster with the specified parameters.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.create_cluster`
@@ -80,7 +82,8 @@ class RedshiftHook(AwsBaseHook):

     # TODO: Wrap create_cluster_snapshot
     def cluster_status(self, cluster_identifier: str) -> str:
-        """Get status of a cluster.
+        """
+        Get status of a cluster.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.describe_clusters`
@@ -101,7 +104,8 @@ class RedshiftHook(AwsBaseHook):
         skip_final_cluster_snapshot: bool = True,
         final_cluster_snapshot_identifier: str | None = None,
     ):
-        """Delete a cluster and optionally create a snapshot.
+        """
+        Delete a cluster and optionally create a snapshot.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.delete_cluster`
@@ -120,7 +124,8 @@ class RedshiftHook(AwsBaseHook):
         return response["Cluster"] if response["Cluster"] else None

     def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str] | None:
-        """List snapshots for a cluster.
+        """
+        List snapshots for a cluster.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.describe_cluster_snapshots`
@@ -136,7 +141,8 @@ class RedshiftHook(AwsBaseHook):
         return snapshots

     def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str:
-        """Restore a cluster from its snapshot.
+        """
+        Restore a cluster from its snapshot.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.restore_from_cluster_snapshot`
@@ -156,7 +162,8 @@ class RedshiftHook(AwsBaseHook):
         retention_period: int = -1,
         tags: list[Any] | None = None,
     ) -> str:
-        """Create a snapshot of a cluster.
+        """
+        Create a snapshot of a cluster.

         .. seealso::
             - :external+boto3:py:meth:`Redshift.Client.create_cluster_snapshot`
@@ -178,7 +185,8 @@ class RedshiftHook(AwsBaseHook):
         return response["Snapshot"] if response["Snapshot"] else None

     def get_cluster_snapshot_status(self, snapshot_identifier: str):
-        """Get Redshift cluster snapshot status.
+        """
+        Get Redshift cluster snapshot status.

         If cluster snapshot not found, *None* is returned.

@@ -210,7 +218,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         super().__init__(*args, **kwargs)

     async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]:
-        """Get the cluster status.
+        """
+        Get the cluster status.

         :param cluster_identifier: unique identifier of a cluster
         :param delete_operation: whether the method has been called as part of delete cluster operation
@@ -228,7 +237,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
             return {"status": "error", "message": str(error)}

     async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]:
-        """Pause the cluster.
+        """
+        Pause the cluster.

         :param cluster_identifier: unique identifier of a cluster
         :param poll_interval: polling period in seconds to check for the status
@@ -255,7 +265,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         cluster_identifier: str,
         polling_period_seconds: float = 5.0,
     ) -> dict[str, Any]:
-        """Resume the cluster.
+        """
+        Resume the cluster.

         :param cluster_identifier: unique identifier of a cluster
         :param polling_period_seconds: polling period in seconds to check for the status
@@ -284,7 +295,8 @@ class RedshiftAsyncHook(AwsBaseAsyncHook):
         flag: asyncio.Event,
         delete_operation: bool = False,
     ) -> dict[str, Any]:
-        """Check for expected Redshift cluster state.
+        """
+        Check for expected Redshift cluster state.

         :param cluster_identifier: unique identifier of a cluster
         :param expected_state: expected_state example("available", "pausing", "paused"")
airflow/providers/amazon/aws/hooks/redshift_data.py

@@ -231,7 +231,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         return pk_columns or None

     async def is_still_running(self, statement_id: str) -> bool:
-        """Async function to check whether the query is still running.
+        """
+        Async function to check whether the query is still running.

         :param statement_id: the UUID of the statement
         """
@@ -240,7 +241,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         return desc["Status"] in RUNNING_STATES

     async def check_query_is_finished_async(self, statement_id: str) -> bool:
-        """Async function to check statement is finished.
+        """
+        Async function to check statement is finished.

         It takes statement_id, makes async connection to redshift data to get the query status
         by statement_id and returns the query status.
airflow/providers/amazon/aws/hooks/redshift_sql.py

@@ -39,7 +39,8 @@ if TYPE_CHECKING:


 class RedshiftSQLHook(DbApiHook):
-    """Execute statements against Amazon Redshift.
+    """
+    Execute statements against Amazon Redshift.

     This hook requires the redshift_conn_id connection.

@@ -103,7 +104,8 @@ class RedshiftSQLHook(DbApiHook):
         return conn_params

     def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
-        """Retrieve a temporary password to connect to Redshift.
+        """
+        Retrieve a temporary password to connect to Redshift.

         Port is required. If none is provided, default is used for each service.
         """
@@ -177,7 +179,8 @@ class RedshiftSQLHook(DbApiHook):
         return create_engine(self.get_uri(), **engine_kwargs)

     def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
-        """Get the table's primary key.
+        """
+        Get the table's primary key.

         :param table: Name of the target table
         :param schema: Name of the target schema, public by default
airflow/providers/amazon/aws/hooks/s3.py

@@ -22,6 +22,7 @@ from __future__ import annotations
 import asyncio
 import fnmatch
 import gzip as gz
+import inspect
 import logging
 import os
 import re
@@ -36,7 +37,7 @@ from inspect import signature
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, gettempdir
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4

@@ -61,42 +62,24 @@ from airflow.utils.helpers import chunks
 logger = logging.getLogger(__name__)


+# Explicit value that would remove ACLs from a copy
+# No conflicts with Canned ACLs:
+# https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
+NO_ACL = "no-acl"
+
+
 def provide_bucket_name(func: Callable) -> Callable:
     """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
         logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
-    function_signature = signature(func)
-
-    @wraps(func)
-    def wrapper(*args, **kwargs) -> Callable:
-        bound_args = function_signature.bind(*args, **kwargs)
-
-        if "bucket_name" not in bound_args.arguments:
-            self = args[0]
-
-            if "bucket_name" in self.service_config:
-                bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
-            elif self.conn_config and self.conn_config.schema:
-                warnings.warn(
-                    "s3 conn_type, and the associated schema field, is deprecated. "
-                    "Please use aws conn_type instead, and specify `bucket_name` "
-                    "in `service_config.s3` within `extras`.",
-                    AirflowProviderDeprecationWarning,
-                    stacklevel=2,
-                )
-                bound_args.arguments["bucket_name"] = self.conn_config.schema
-
-        return func(*bound_args.args, **bound_args.kwargs)

-    return wrapper
-
-
-def provide_bucket_name_async(func: Callable) -> Callable:
-    """Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
     function_signature = signature(func)
+    if "bucket_name" not in function_signature.parameters:
+        raise RuntimeError(
+            "Decorator provide_bucket_name should only wrap a function with param 'bucket_name'."
+        )

-
-    async def wrapper(*args: Any, **kwargs: Any) -> Any:
+    async def maybe_add_bucket_name(*args, **kwargs):
         bound_args = function_signature.bind(*args, **kwargs)

         if "bucket_name" not in bound_args.arguments:
@@ -105,8 +88,46 @@ def provide_bucket_name_async(func: Callable) -> Callable:
                 connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
                 if connection.schema:
                     bound_args.arguments["bucket_name"] = connection.schema
+        return bound_args
+
+    if inspect.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            print(f"invoking async function {func=}")
+            return await func(*bound_args.args, **bound_args.kwargs)
+
+    elif inspect.isasyncgenfunction(func):
+
+        @wraps(func)
+        async def wrapper(*args: Any, **kwargs: Any) -> Any:
+            bound_args = await maybe_add_bucket_name(*args, **kwargs)
+            async for thing in func(*bound_args.args, **bound_args.kwargs):
+                yield thing
+
+    else:
+
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> Callable:
+            bound_args = function_signature.bind(*args, **kwargs)
+
+            if "bucket_name" not in bound_args.arguments:
+                self = args[0]
+
+                if "bucket_name" in self.service_config:
+                    bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
+                elif self.conn_config and self.conn_config.schema:
+                    warnings.warn(
+                        "s3 conn_type, and the associated schema field, is deprecated. "
+                        "Please use aws conn_type instead, and specify `bucket_name` "
+                        "in `service_config.s3` within `extras`.",
+                        AirflowProviderDeprecationWarning,
+                        stacklevel=2,
+                    )
+                    bound_args.arguments["bucket_name"] = self.conn_config.schema

-
+            return func(*bound_args.args, **bound_args.kwargs)

     return wrapper

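With this change `provide_bucket_name_async` is gone: `provide_bucket_name` now raises at decoration time if the wrapped callable has no `bucket_name` parameter, then builds a wrapper matching the callable's kind via `inspect.iscoroutinefunction` (note the leftover debug `print` in that branch) and `inspect.isasyncgenfunction`, falling back to the original synchronous wrapper. A self-contained sketch of the same dispatch technique; the `traced` decorator and `double` function are illustrative, not provider code:

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


def traced(func: Callable) -> Callable:
    """Build a wrapper that matches the kind of the decorated callable."""
    if inspect.iscoroutinefunction(func):

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Coroutine function: the result must be awaited.
            return await func(*args, **kwargs)

    elif inspect.isasyncgenfunction(func):

        @wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Async generator function: re-yield items; awaiting it would fail.
            async for item in func(*args, **kwargs):
                yield item

    else:

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Plain function: call through synchronously.
            return func(*args, **kwargs)

    return wrapper


@traced
async def double(x: int) -> int:
    return x * 2


print(asyncio.run(double(21)))  # 42
```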
@@ -400,8 +421,8 @@ class S3Hook(AwsBaseHook):

         return prefixes

-    @provide_bucket_name_async
     @unify_bucket_name_and_key
+    @provide_bucket_name
     async def get_head_object_async(
         self, client: AioBaseClient, key: str, bucket_name: str | None = None
     ) -> dict[str, Any] | None:
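The decorator order flips here as well: `@unify_bucket_name_and_key` is now the outer decorator, which matches the warning `provide_bucket_name` itself emits ("`unify_bucket_name_and_key` should wrap `provide_bucket_name`."): the S3 URL is split into bucket and key first, and only then is a missing bucket name filled in from the connection.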
@@ -462,10 +483,10 @@ class S3Hook(AwsBaseHook):

         return prefixes

-    @provide_bucket_name_async
+    @provide_bucket_name
     async def get_file_metadata_async(
         self, client: AioBaseClient, bucket_name: str, key: str | None = None
-    ) -> list[Any]:
+    ) -> AsyncIterator[Any]:
         """
         Get a list of files that a key matching a wildcard expression exists in a bucket asynchronously.

@@ -477,11 +498,10 @@ class S3Hook(AwsBaseHook):
         delimiter = ""
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter)
-        files = []
         async for page in response:
             if "Contents" in page:
-                files += page["Contents"]
-        return files
+                for row in page["Contents"]:
+                    yield row

     async def _check_key_async(
         self,
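`get_file_metadata_async` now returns an `AsyncIterator` of object metadata rather than a fully materialized list (hence the new `AsyncIterator` import above), so callers iterate with `async for` and can stop paginating early. A usage sketch mirroring the new wildcard branch of `_check_key_async`; the `aws_default` connection id, bucket, and pattern are assumptions, and `async_conn` is the hook's aiobotocore client context manager as used by this provider's triggers:

```python
import asyncio
import fnmatch

from airflow.providers.amazon.aws.hooks.s3 import S3Hook


async def wildcard_key_exists(bucket: str, wildcard_key: str) -> bool:
    hook = S3Hook(aws_conn_id="aws_default")  # assumed connection id
    async with hook.async_conn as client:
        # Async generator: iterate instead of awaiting a list, and stop
        # requesting pages as soon as one object matches the pattern.
        async for obj in hook.get_file_metadata_async(client, bucket, wildcard_key):
            if fnmatch.fnmatch(obj["Key"], wildcard_key):
                return True
    return False


print(asyncio.run(wildcard_key_exists("my-bucket", "data/*.csv")))
```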
|
@@ -506,21 +526,16 @@ class S3Hook(AwsBaseHook):
|
|
506
526
|
"""
|
507
527
|
bucket_name, key = self.get_s3_bucket_key(bucket_val, key, "bucket_name", "bucket_key")
|
508
528
|
if wildcard_match:
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
|
516
|
-
|
517
|
-
|
518
|
-
|
519
|
-
obj = await self.get_head_object_async(client, key, bucket_name)
|
520
|
-
if obj is None:
|
521
|
-
return False
|
522
|
-
|
523
|
-
return True
|
529
|
+
async for k in self.get_file_metadata_async(client, bucket_name, key):
|
530
|
+
if fnmatch.fnmatch(k["Key"], key):
|
531
|
+
return True
|
532
|
+
return False
|
533
|
+
if use_regex:
|
534
|
+
async for k in self.get_file_metadata_async(client, bucket_name):
|
535
|
+
if re.match(pattern=key, string=k["Key"]):
|
536
|
+
return True
|
537
|
+
return False
|
538
|
+
return bool(await self.get_head_object_async(client, key, bucket_name))
|
524
539
|
|
525
540
|
async def check_key_async(
|
526
541
|
self,
|
@@ -1276,6 +1291,8 @@ class S3Hook(AwsBaseHook):
             object to be copied which is private by default.
         """
         acl_policy = acl_policy or "private"
+        if acl_policy != NO_ACL:
+            kwargs["ACL"] = acl_policy

         dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
             dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1287,7 +1304,7 @@ class S3Hook(AwsBaseHook):

         copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
         response = self.get_conn().copy_object(
-            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy
+            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
         return response

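The two `copy_object` hunks work together with the new `NO_ACL` sentinel: the ACL now travels through `kwargs`, and only when the policy is a real canned ACL, so passing `acl_policy=NO_ACL` sends no `ACL` parameter at all. That matters for destination buckets with ACLs disabled (bucket-owner-enforced ownership), where any ACL argument is rejected. A usage sketch; the connection id, buckets, and keys are assumptions:

```python
from airflow.providers.amazon.aws.hooks.s3 import NO_ACL, S3Hook

hook = S3Hook(aws_conn_id="aws_default")  # assumed connection id

# Unchanged default: the copy is created with ACL="private".
hook.copy_object(
    source_bucket_key="s3://source-bucket/report.csv",
    dest_bucket_key="s3://dest-bucket/report.csv",
)

# New in 8.26.0: NO_ACL omits the ACL argument entirely, so the copy also
# works against buckets that enforce bucket-owner ownership (ACLs disabled).
hook.copy_object(
    source_bucket_key="s3://source-bucket/report.csv",
    dest_bucket_key="s3://dest-bucket/report.csv",
    acl_policy=NO_ACL,
)
```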