apache-airflow-providers-amazon 9.0.0__py3-none-any.whl → 9.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/assets/s3.py +7 -7
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +12 -1
- airflow/providers/amazon/aws/hooks/athena.py +25 -15
- airflow/providers/amazon/aws/hooks/eks.py +2 -2
- airflow/providers/amazon/aws/hooks/glue.py +5 -1
- airflow/providers/amazon/aws/hooks/redshift_sql.py +1 -1
- airflow/providers/amazon/aws/hooks/s3.py +79 -31
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -0
- airflow/providers/amazon/aws/operators/appflow.py +1 -1
- airflow/providers/amazon/aws/operators/athena.py +3 -1
- airflow/providers/amazon/aws/operators/comprehend.py +3 -3
- airflow/providers/amazon/aws/operators/dms.py +3 -3
- airflow/providers/amazon/aws/operators/ecs.py +11 -3
- airflow/providers/amazon/aws/operators/eks.py +4 -2
- airflow/providers/amazon/aws/operators/glue.py +10 -1
- airflow/providers/amazon/aws/operators/kinesis_analytics.py +3 -3
- airflow/providers/amazon/aws/operators/redshift_data.py +43 -20
- airflow/providers/amazon/aws/operators/sagemaker.py +2 -2
- airflow/providers/amazon/aws/sensors/sagemaker.py +32 -0
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +106 -7
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
- airflow/providers/amazon/aws/triggers/athena.py +1 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -6
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/get_provider_info.py +5 -5
- {apache_airflow_providers_amazon-9.0.0.dist-info → apache_airflow_providers_amazon-9.1.0.dist-info}/METADATA +21 -23
- {apache_airflow_providers_amazon-9.0.0.dist-info → apache_airflow_providers_amazon-9.1.0.dist-info}/RECORD +31 -32
- {apache_airflow_providers_amazon-9.0.0.dist-info → apache_airflow_providers_amazon-9.1.0.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py +0 -106
- {apache_airflow_providers_amazon-9.0.0.dist-info → apache_airflow_providers_amazon-9.1.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py

@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.0.0"
+__version__ = "9.1.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.8.0"
airflow/providers/amazon/aws/assets/s3.py

@@ -19,16 +19,14 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-
-try:
-    from airflow.assets import Asset
-except ModuleNotFoundError:
-    from airflow.datasets import Dataset as Asset  # type: ignore[no-redef]
+from airflow.providers.common.compat.assets import Asset
 
 if TYPE_CHECKING:
     from urllib.parse import SplitResult
 
-    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+    from airflow.providers.common.compat.openlineage.facet import (
+        Dataset as OpenLineageDataset,
+    )
 
 
 def create_asset(*, bucket: str, key: str, extra=None) -> Asset:
@@ -43,7 +41,9 @@ def sanitize_uri(uri: SplitResult) -> SplitResult:
 
 def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset:
     """Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
-    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset
+    from airflow.providers.common.compat.openlineage.facet import (
+        Dataset as OpenLineageDataset,
+    )
 
     bucket, key = S3Hook.parse_s3_url(asset.uri)
     return OpenLineageDataset(namespace=f"s3://{bucket}", name=key if key else "/")
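For orientation, a small usage sketch of the two helpers touched above. The bucket and key below are made up, and it assumes create_asset builds an s3://bucket/key URI (which convert_asset_to_openlineage's call to S3Hook.parse_s3_url implies); lineage_context is unused by the function body shown in the diff, so None is passed.

    from airflow.providers.amazon.aws.assets.s3 import convert_asset_to_openlineage, create_asset

    # Build an AIP-60 style S3 asset (placeholder bucket and key).
    asset = create_asset(bucket="example-bucket", key="data/2024/file.csv")

    # Translate it to an OpenLineage dataset: namespace is the bucket, name is the key.
    ol_dataset = convert_asset_to_openlineage(asset, lineage_context=None)
    print(ol_dataset.namespace, ol_dataset.name)  # expected: s3://example-bucket data/2024/file.csv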
airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py

@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import argparse
+import warnings
 from collections import defaultdict
 from functools import cached_property
 from typing import TYPE_CHECKING, Container, Sequence, cast
@@ -24,7 +25,7 @@ from typing import TYPE_CHECKING, Container, Sequence, cast
 from flask import session, url_for
 
 from airflow.cli.cli_config import CLICommand, DefaultHelpParser, GroupCommand
-from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.exceptions import AirflowOptionalProviderFeatureException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.auth_manager.avp.entities import AvpEntities
 from airflow.providers.amazon.aws.auth_manager.avp.facade import (
     AwsAuthManagerAmazonVerifiedPermissionsFacade,
@@ -166,6 +167,16 @@ class AwsAuthManager(BaseAuthManager):
             method=method, entity_type=AvpEntities.ASSET, user=user or self.get_user(), entity_id=asset_uri
         )
 
+    def is_authorized_dataset(
+        self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None
+    ) -> bool:
+        warnings.warn(
+            "is_authorized_dataset will be renamed as is_authorized_asset in Airflow 3 and will be removed when the minimum Airflow version is set to 3.0 for the amazon provider",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.is_authorized_asset(method=method, user=user)
+
     def is_authorized_pool(
         self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None
     ) -> bool:
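A minimal standalone sketch of the deprecation-shim pattern added above, using stand-in classes instead of the real AwsAuthManager so it runs without an Airflow deployment (class and method bodies here are illustrative, not the provider's code):

    import warnings

    class AirflowProviderDeprecationWarning(DeprecationWarning):
        """Stand-in for airflow.exceptions.AirflowProviderDeprecationWarning."""

    class AuthManagerSketch:
        def is_authorized_asset(self, *, method, user=None) -> bool:
            return True  # placeholder authorization decision

        def is_authorized_dataset(self, *, method, details=None, user=None) -> bool:
            # Mirrors the shim above: warn, then forward to the asset method.
            warnings.warn(
                "is_authorized_dataset will be renamed as is_authorized_asset in Airflow 3",
                AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            return self.is_authorized_asset(method=method, user=user)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        AuthManagerSketch().is_authorized_dataset(method="GET")
    assert any(issubclass(w.category, AirflowProviderDeprecationWarning) for w in caught)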
airflow/providers/amazon/aws/hooks/athena.py

@@ -155,14 +155,15 @@ class AthenaHook(AwsBaseHook):
         state = None
         try:
             state = response["QueryExecution"]["Status"]["State"]
-        except Exception:
-            self.log.exception(
-                "Exception while getting query state. Query execution id: %s", query_execution_id
-            )
-        finally:
+        except Exception as e:
             # The error is being absorbed here and is being handled by the caller.
             # The error is being absorbed to implement retries.
-            return state
+            self.log.exception(
+                "Exception while getting query state. Query execution id: %s, Exception: %s",
+                query_execution_id,
+                e,
+            )
+        return state
 
     def get_state_change_reason(self, query_execution_id: str, use_cache: bool = False) -> str | None:
         """
@@ -177,15 +178,15 @@ class AthenaHook(AwsBaseHook):
         reason = None
         try:
             reason = response["QueryExecution"]["Status"]["StateChangeReason"]
-        except Exception:
+        except Exception as e:
+            # The error is being absorbed here and is being handled by the caller.
+            # The error is being absorbed to implement retries.
             self.log.exception(
-                "Exception while getting query state change reason. Query execution id: %s",
+                "Exception while getting query state change reason. Query execution id: %s, Exception: %s",
                 query_execution_id,
+                e,
             )
-        finally:
-            # The error is being absorbed here and is being handled by the caller.
-            # The error is being absorbed to implement retries.
-            return reason
+        return reason
 
     def get_query_results(
         self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
@@ -287,9 +288,18 @@ class AthenaHook(AwsBaseHook):
             )
         except AirflowException as error:
             # this function does not raise errors to keep previous behavior.
-            self.log.warning(error)
-
-        return self.check_query_status(query_execution_id)
+            self.log.warning(
+                "AirflowException while polling query status. Query execution id: %s, Exception: %s",
+                query_execution_id,
+                error,
+            )
+        except Exception as e:
+            self.log.warning(
+                "Unexpected exception while polling query status. Query execution id: %s, Exception: %s",
+                query_execution_id,
+                e,
+            )
+        return self.check_query_status(query_execution_id)
 
     def get_output_location(self, query_execution_id: str) -> str:
         """
airflow/providers/amazon/aws/hooks/eks.py

@@ -85,8 +85,8 @@ COMMAND = """
         exit 1
     fi
 
-    expiration_timestamp=$(echo "$output" | grep -oP 'expirationTimestamp
-    token=$(echo "$output" | grep -oP 'token
+    expiration_timestamp=$(echo "$output" | grep -oP 'expirationTimestamp: \\K[^,]+')
+    token=$(echo "$output" | grep -oP 'token: \\K[^,]+')
 
     json_string=$(printf '{{"kind": "ExecCredential","apiVersion": \
     "client.authentication.k8s.io/v1alpha1","spec": {{}},"status": \
airflow/providers/amazon/aws/hooks/glue.py

@@ -282,13 +282,16 @@ class GlueJobHook(AwsBaseHook):
                 log_group_error, continuation_tokens.error_stream_continuation
             )
 
-    def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
+    def job_completion(
+        self, job_name: str, run_id: str, verbose: bool = False, sleep_before_return: int = 0
+    ) -> dict[str, str]:
         """
         Wait until Glue job with job_name finishes; return final state if finished or raises AirflowException.
 
         :param job_name: unique job name per AWS account
         :param run_id: The job-run ID of the predecessor job run
         :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False)
+        :param sleep_before_return: time in seconds to wait before returning final status.
         :return: Dict of JobRunState and JobRunId
         """
         next_log_tokens = self.LogContinuationTokens()
@@ -296,6 +299,7 @@ class GlueJobHook(AwsBaseHook):
             job_run_state = self.get_job_state(job_name, run_id)
             ret = self._handle_state(job_run_state, job_name, run_id, verbose, next_log_tokens)
             if ret:
+                time.sleep(sleep_before_return)
                 return ret
             else:
                 time.sleep(self.job_poll_interval)
airflow/providers/amazon/aws/hooks/redshift_sql.py

@@ -163,7 +163,7 @@ class RedshiftSQLHook(DbApiHook):
         # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
         # to replace calling the default URL constructor directly.
         create_url = getattr(URL, "create", URL)
-        return str(create_url(drivername="
+        return str(create_url(drivername="postgresql", **conn_params))
 
     def get_sqlalchemy_engine(self, engine_kwargs=None):
         """Overridden to pass Redshift-specific arguments."""
airflow/providers/amazon/aws/hooks/s3.py

@@ -41,14 +41,16 @@ from urllib.parse import urlsplit
 from uuid import uuid4
 
 if TYPE_CHECKING:
-    from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
+    from mypy_boto3_s3.service_resource import (
+        Bucket as S3Bucket,
+        Object as S3ResourceObject,
+    )
 
     from airflow.utils.types import ArgNotSet
 
 with suppress(ImportError):
     from aiobotocore.client import AioBaseClient
 
-from importlib.util import find_spec
 
 from asgiref.sync import sync_to_async
 from boto3.s3.transfer import S3Transfer, TransferConfig
@@ -58,15 +60,9 @@ from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.utils.tags import format_tags
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
 from airflow.utils.helpers import chunks
 
-if find_spec("airflow.assets"):
-    from airflow.lineage.hook import get_hook_lineage_collector
-else:
-    # TODO: import from common.compat directly after common.compat providers with
-    # asset_compat_lineage_collector released
-    from airflow.providers.amazon.aws.utils.asset_compat_lineage_collector import get_hook_lineage_collector
-
 logger = logging.getLogger(__name__)
 
 
@@ -90,7 +86,7 @@ def provide_bucket_name(func: Callable) -> Callable:
         async def maybe_add_bucket_name(*args, **kwargs):
             bound_args = function_signature.bind(*args, **kwargs)
 
-            if
+            if not bound_args.arguments.get("bucket_name"):
                 self = args[0]
                 if self.aws_conn_id:
                     connection = await sync_to_async(self.get_connection)(self.aws_conn_id)
@@ -120,7 +116,7 @@ def provide_bucket_name(func: Callable) -> Callable:
     def wrapper(*args, **kwargs) -> Callable:
         bound_args = function_signature.bind(*args, **kwargs)
 
-        if
+        if not bound_args.arguments.get("bucket_name"):
             self = args[0]
 
             if "bucket_name" in self.service_config:
@@ -148,9 +144,10 @@ def unify_bucket_name_and_key(func: Callable) -> Callable:
 
         if "bucket_name" not in bound_args.arguments:
             with suppress(S3HookUriParseFailure):
-                bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
-                    bound_args.arguments[key_name]
-                )
+                (
+                    bound_args.arguments["bucket_name"],
+                    bound_args.arguments[key_name],
+                ) = S3Hook.parse_s3_url(bound_args.arguments[key_name])
 
         return func(*bound_args.args, **bound_args.kwargs)
 
@@ -318,7 +315,8 @@ class S3Hook(AwsBaseHook):
                 self.log.info('Bucket "%s" does not exist', bucket_name)
             elif return_code == 403:
                 self.log.error(
-                    'Access to bucket "%s" is forbidden or there was an error with the request', bucket_name
+                    'Access to bucket "%s" is forbidden or there was an error with the request',
+                    bucket_name,
                 )
             self.log.error(e)
             return False
@@ -359,7 +357,8 @@ class S3Hook(AwsBaseHook):
             self.get_conn().create_bucket(Bucket=bucket_name)
         else:
             self.get_conn().create_bucket(
-                Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region_name}
+                Bucket=bucket_name,
+                CreateBucketConfiguration={"LocationConstraint": region_name},
             )
 
     @provide_bucket_name
@@ -410,7 +409,10 @@ class S3Hook(AwsBaseHook):
 
         paginator = self.get_conn().get_paginator("list_objects_v2")
         response = paginator.paginate(
-            Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+            Bucket=bucket_name,
+            Prefix=prefix,
+            Delimiter=delimiter,
+            PaginationConfig=config,
         )
 
         prefixes: list[str] = []
@@ -471,7 +473,10 @@ class S3Hook(AwsBaseHook):
 
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(
-            Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+            Bucket=bucket_name,
+            Prefix=prefix,
+            Delimiter=delimiter,
+            PaginationConfig=config,
        )
 
         prefixes = []
@@ -569,7 +574,11 @@ class S3Hook(AwsBaseHook):
         return await self._check_key_async(client, bucket, wildcard_match, bucket_keys, use_regex)
 
     async def check_for_prefix_async(
-        self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
+        self,
+        client: AioBaseClient,
+        prefix: str,
+        delimiter: str,
+        bucket_name: str | None = None,
     ) -> bool:
         """
         Check that a prefix exists in a bucket.
@@ -587,7 +596,11 @@ class S3Hook(AwsBaseHook):
         return prefix in plist
 
     async def _check_for_prefix_async(
-        self, client: AioBaseClient, prefix: str, delimiter: str, bucket_name: str | None = None
+        self,
+        client: AioBaseClient,
+        prefix: str,
+        delimiter: str,
+        bucket_name: str | None = None,
     ) -> bool:
         return await self.check_for_prefix_async(
             client, prefix=prefix, delimiter=delimiter, bucket_name=bucket_name
@@ -643,7 +656,10 @@ class S3Hook(AwsBaseHook):
 
         paginator = client.get_paginator("list_objects_v2")
         response = paginator.paginate(
-            Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
+            Bucket=bucket_name,
+            Prefix=prefix,
+            Delimiter=delimiter,
+            PaginationConfig=config,
         )
 
         keys = []
@@ -655,7 +671,10 @@ class S3Hook(AwsBaseHook):
         return keys
 
     def _list_key_object_filter(
-        self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None
+        self,
+        keys: list,
+        from_datetime: datetime | None = None,
+        to_datetime: datetime | None = None,
     ) -> list:
         def _is_in_period(input_date: datetime) -> bool:
             if from_datetime is not None and input_date <= from_datetime:
@@ -766,7 +785,10 @@ class S3Hook(AwsBaseHook):
                 "message": success_message,
             }
 
-        self.log.error("FAILURE: Inactivity Period passed, not enough objects found in %s", path)
+        self.log.error(
+            "FAILURE: Inactivity Period passed, not enough objects found in %s",
+            path,
+        )
         return {
             "status": "error",
             "message": f"FAILURE: Inactivity Period passed, not enough objects found in {path}",
@@ -1109,7 +1131,13 @@ class S3Hook(AwsBaseHook):
             extra_args["ACL"] = acl_policy
 
         client = self.get_conn()
-        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        client.upload_file(
+            filename,
+            bucket_name,
+            key,
+            ExtraArgs=extra_args,
+            Config=self.transfer_config,
+        )
         get_hook_lineage_collector().add_input_asset(
             context=self, scheme="file", asset_kwargs={"path": filename}
         )
@@ -1308,18 +1336,32 @@ class S3Hook(AwsBaseHook):
         )
 
         source_bucket_name, source_bucket_key = self.get_s3_bucket_key(
-            source_bucket_name, source_bucket_key, "source_bucket_name", "source_bucket_key"
+            source_bucket_name,
+            source_bucket_key,
+            "source_bucket_name",
+            "source_bucket_key",
         )
 
-        copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
+        copy_source = {
+            "Bucket": source_bucket_name,
+            "Key": source_bucket_key,
+            "VersionId": source_version_id,
+        }
         response = self.get_conn().copy_object(
-            Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
+            Bucket=dest_bucket_name,
+            Key=dest_bucket_key,
+            CopySource=copy_source,
+            **kwargs,
         )
         get_hook_lineage_collector().add_input_asset(
-            context=self, scheme="s3", asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+            context=self,
+            scheme="s3",
+            asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key},
        )
         get_hook_lineage_collector().add_output_asset(
-            context=self, scheme="s3", asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+            context=self,
+            scheme="s3",
+            asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key},
         )
         return response
 
@@ -1435,7 +1477,10 @@ class S3Hook(AwsBaseHook):
             file_path = Path(local_dir, subdir, filename_in_s3)
 
             if file_path.is_file():
-                self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
+                self.log.error(
+                    "file '%s' already exists. Failing the task and not overwriting it",
+                    file_path,
+                )
                 raise FileExistsError
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
@@ -1484,7 +1529,10 @@ class S3Hook(AwsBaseHook):
         s3_client = self.get_conn()
         try:
             return s3_client.generate_presigned_url(
-                ClientMethod=client_method, Params=params, ExpiresIn=expires_in, HttpMethod=http_method
+                ClientMethod=client_method,
+                Params=params,
+                ExpiresIn=expires_in,
+                HttpMethod=http_method,
             )
 
         except ClientError as e:
airflow/providers/amazon/aws/hooks/sagemaker.py

@@ -153,7 +153,9 @@ class SageMakerHook(AwsBaseHook):
     non_terminal_states = {"InProgress", "Stopping"}
     endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
     pipeline_non_terminal_states = {"Executing", "Stopping"}
+    processing_job_non_terminal_states = {"InProgress", "Stopping"}
     failed_states = {"Failed"}
+    processing_job_failed_states = {*failed_states, "Stopped"}
     training_failed_states = {*failed_states, "Stopped"}
 
     def __init__(self, *args, **kwargs):
airflow/providers/amazon/aws/operators/appflow.py

@@ -21,11 +21,11 @@ from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, cast
 
 from airflow.exceptions import AirflowException
-from airflow.operators.python import ShortCircuitOperator
 from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms
 from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, aws_template_fields
+from airflow.providers.common.compat.standard.operators import ShortCircuitOperator
 
 if TYPE_CHECKING:
     from mypy_boto3_appflow.type_defs import (
airflow/providers/amazon/aws/operators/athena.py

@@ -311,7 +311,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
         }
         fields = [
             SchemaDatasetFacetFields(
-                name=column["Name"], type=column["Type"], description=column.get("Comment")
+                name=column["Name"],
+                type=column["Type"],
+                description=column.get("Comment"),
             )
             for column in table_metadata["TableMetadata"]["Columns"]
         ]
airflow/providers/amazon/aws/operators/comprehend.py

@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -55,7 +55,7 @@ class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]):
         "input_data_config", "output_data_config", "data_access_role_arn", "language_code"
     )
 
-    template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"}
+    template_fields_renderers: ClassVar[dict] = {"input_data_config": "json", "output_data_config": "json"}
 
     def __init__(
         self,
@@ -248,7 +248,7 @@ class ComprehendCreateDocumentClassifierOperator(AwsBaseOperator[ComprehendHook]
         "document_classifier_kwargs",
     )
 
-    template_fields_renderers: dict = {
+    template_fields_renderers: ClassVar[dict] = {
         "input_data_config": "json",
         "output_data_config": "json",
         "document_classifier_kwargs": "json",
airflow/providers/amazon/aws/operators/dms.py

@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, ClassVar, Sequence
 
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
@@ -64,7 +64,7 @@ class DmsCreateTaskOperator(AwsBaseOperator[DmsHook]):
         "migration_type",
         "create_task_kwargs",
     )
-    template_fields_renderers = {
+    template_fields_renderers: ClassVar[dict] = {
         "table_mappings": "json",
         "create_task_kwargs": "json",
     }
@@ -173,7 +173,7 @@ class DmsDescribeTasksOperator(AwsBaseOperator[DmsHook]):
 
     aws_hook_class = DmsHook
     template_fields: Sequence[str] = aws_template_fields("describe_tasks_kwargs")
-    template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"}
+    template_fields_renderers: ClassVar[dict[str, str]] = {"describe_tasks_kwargs": "json"}
 
     def __init__(self, *, describe_tasks_kwargs: dict | None = None, **kwargs):
         super().__init__(**kwargs)
airflow/providers/amazon/aws/operators/ecs.py

@@ -368,7 +368,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         If None, this is the same as the `region` parameter. If that is also None,
         this is the default AWS region based on your connection settings.
     :param awslogs_stream_prefix: the stream prefix that is used for the CloudWatch logs.
-        This
+        This should match the prefix specified in the log configuration of the task definition.
         Only required if you want logs to be shown in the Airflow UI after your job has
         finished.
     :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
@@ -481,6 +481,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
             self.awslogs_region = self.region_name
 
         self.arn: str | None = None
+        self.container_name: str | None = None
         self._started_by: str | None = None
 
         self.retry_args = quota_retry
@@ -597,10 +598,10 @@ class EcsRunTaskOperator(EcsBaseOperator):
 
         if self.capacity_provider_strategy:
             run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy
-        if self.volume_configurations is not None:
-            run_opts["volumeConfigurations"] = self.volume_configurations
         elif self.launch_type:
             run_opts["launchType"] = self.launch_type
+        if self.volume_configurations is not None:
+            run_opts["volumeConfigurations"] = self.volume_configurations
         if self.platform_version is not None:
             run_opts["platformVersion"] = self.platform_version
         if self.group is not None:
@@ -624,6 +625,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.log.info("ECS Task started: %s", response)
 
         self.arn = response["tasks"][0]["taskArn"]
+        self.container_name = response["tasks"][0]["containers"][0]["name"]
         self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
 
     def _try_reattach_task(self, started_by: str):
@@ -659,6 +661,12 @@ class EcsRunTaskOperator(EcsBaseOperator):
         return self.awslogs_group and self.awslogs_stream_prefix
 
     def _get_logs_stream_name(self) -> str:
+        if (
+            self.awslogs_stream_prefix
+            and self.container_name
+            and not self.awslogs_stream_prefix.endswith(f"/{self.container_name}")
+        ):
+            return f"{self.awslogs_stream_prefix}/{self.container_name}/{self._get_ecs_task_id(self.arn)}"
         return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
     def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
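The new branch in _get_logs_stream_name above only changes which CloudWatch stream name gets computed. A standalone sketch of that logic, with made-up prefix, container, and task-id values:

    def get_logs_stream_name(awslogs_stream_prefix: str, container_name: str | None, task_id: str) -> str:
        # Append the container name unless the prefix already ends with it,
        # mirroring the branch added to EcsRunTaskOperator._get_logs_stream_name.
        if (
            awslogs_stream_prefix
            and container_name
            and not awslogs_stream_prefix.endswith(f"/{container_name}")
        ):
            return f"{awslogs_stream_prefix}/{container_name}/{task_id}"
        return f"{awslogs_stream_prefix}/{task_id}"

    print(get_logs_stream_name("airflow", "app", "1234abcd"))      # airflow/app/1234abcd
    print(get_logs_stream_name("airflow/app", "app", "1234abcd"))  # airflow/app/1234abcd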
airflow/providers/amazon/aws/operators/eks.py

@@ -45,8 +45,10 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction
 try:
     from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
 except ImportError:
-    # preserve backward compatibility for older versions of cncf.kubernetes provider
-    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import
+    # preserve backward compatibility for older versions of cncf.kubernetes provider, remove this when minimum cncf.kubernetes provider is 10.0
+    from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (  # type: ignore[no-redef]
+        KubernetesPodOperator,
+    )
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
airflow/providers/amazon/aws/operators/glue.py

@@ -74,6 +74,11 @@ class GlueJobOperator(BaseOperator):
     :param update_config: If True, Operator will update job configuration. (default: False)
     :param replace_script_file: If True, the script file will be replaced in S3. (default: False)
     :param stop_job_run_on_kill: If True, Operator will stop the job run when task is killed.
+    :param sleep_before_return: time in seconds to wait before returning final status. This is meaningful in case
+        of limiting concurrency, Glue needs 5-10 seconds to clean up resources.
+        Thus if status is returned immediately it might end up in case of more than 1 concurrent run.
+        It is recommended to set this parameter to 10 when you are using concurrency=1.
+        For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA
     """
 
     template_fields: Sequence[str] = (
@@ -118,6 +123,7 @@ class GlueJobOperator(BaseOperator):
         update_config: bool = False,
         job_poll_interval: int | float = 6,
         stop_job_run_on_kill: bool = False,
+        sleep_before_return: int = 0,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -145,6 +151,7 @@ class GlueJobOperator(BaseOperator):
         self.job_poll_interval = job_poll_interval
         self.stop_job_run_on_kill = stop_job_run_on_kill
         self._job_run_id: str | None = None
+        self.sleep_before_return: int = sleep_before_return
 
     @cached_property
     def glue_job_hook(self) -> GlueJobHook:
@@ -220,7 +227,9 @@ class GlueJobOperator(BaseOperator):
                 method_name="execute_complete",
             )
         elif self.wait_for_completion:
-            glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
+            glue_job_run = self.glue_job_hook.job_completion(
+                self.job_name, self._job_run_id, self.verbose, self.sleep_before_return
+            )
             self.log.info(
                 "AWS Glue Job: %s status: %s. Run Id: %s",
                 self.job_name,
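A hedged usage sketch of the new sleep_before_return knob; the DAG id, job name, script location, and IAM role below are placeholders, not values from this package:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    with DAG(dag_id="example_glue_sleep_before_return", start_date=datetime(2024, 1, 1), schedule=None):
        submit_glue_job = GlueJobOperator(
            task_id="submit_glue_job",
            job_name="example_glue_job",                            # placeholder job name
            script_location="s3://example-bucket/scripts/job.py",   # placeholder script
            iam_role_name="ExampleGlueRole",                        # placeholder role
            wait_for_completion=True,
            # New in 9.1.0: wait a few seconds before reporting the final status so Glue
            # can release the run slot; the docstring suggests 10 when concurrency is 1.
            sleep_before_return=10,
        )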
airflow/providers/amazon/aws/operators/kinesis_analytics.py

@@ -16,7 +16,7 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence
 
 from botocore.exceptions import ClientError
 
@@ -70,7 +70,7 @@ class KinesisAnalyticsV2CreateApplicationOperator(AwsBaseOperator[KinesisAnalyti
         "create_application_kwargs",
         "application_description",
     )
-    template_fields_renderers: dict = {
+    template_fields_renderers: ClassVar[dict] = {
         "create_application_kwargs": "json",
     }
 
@@ -149,7 +149,7 @@ class KinesisAnalyticsV2StartApplicationOperator(AwsBaseOperator[KinesisAnalytic
         "application_name",
         "run_configuration",
     )
-    template_fields_renderers: dict = {
+    template_fields_renderers: ClassVar[dict] = {
         "run_configuration": "json",
     }