apache-airflow-providers-amazon 8.29.0__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/{datasets → assets}/s3.py +10 -6
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +5 -11
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +2 -5
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +0 -6
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +3 -17
- airflow/providers/amazon/aws/hooks/base_aws.py +4 -162
- airflow/providers/amazon/aws/hooks/logs.py +1 -20
- airflow/providers/amazon/aws/hooks/quicksight.py +1 -17
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +6 -120
- airflow/providers/amazon/aws/hooks/redshift_data.py +52 -14
- airflow/providers/amazon/aws/hooks/s3.py +24 -27
- airflow/providers/amazon/aws/hooks/sagemaker.py +4 -48
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -6
- airflow/providers/amazon/aws/operators/appflow.py +1 -10
- airflow/providers/amazon/aws/operators/batch.py +1 -29
- airflow/providers/amazon/aws/operators/datasync.py +1 -8
- airflow/providers/amazon/aws/operators/ecs.py +1 -25
- airflow/providers/amazon/aws/operators/eks.py +7 -46
- airflow/providers/amazon/aws/operators/emr.py +16 -232
- airflow/providers/amazon/aws/operators/glue_databrew.py +1 -10
- airflow/providers/amazon/aws/operators/rds.py +3 -17
- airflow/providers/amazon/aws/operators/redshift_data.py +18 -3
- airflow/providers/amazon/aws/operators/s3.py +12 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +10 -32
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -40
- airflow/providers/amazon/aws/sensors/batch.py +1 -8
- airflow/providers/amazon/aws/sensors/dms.py +1 -8
- airflow/providers/amazon/aws/sensors/dynamodb.py +22 -8
- airflow/providers/amazon/aws/sensors/emr.py +0 -7
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +1 -8
- airflow/providers/amazon/aws/sensors/glue_crawler.py +1 -8
- airflow/providers/amazon/aws/sensors/quicksight.py +1 -29
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -8
- airflow/providers/amazon/aws/sensors/s3.py +1 -8
- airflow/providers/amazon/aws/sensors/sagemaker.py +2 -9
- airflow/providers/amazon/aws/sensors/sqs.py +1 -8
- airflow/providers/amazon/aws/sensors/step_function.py +1 -8
- airflow/providers/amazon/aws/transfers/base.py +1 -14
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +5 -33
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +15 -10
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +3 -6
- airflow/providers/amazon/aws/triggers/batch.py +1 -168
- airflow/providers/amazon/aws/triggers/eks.py +1 -20
- airflow/providers/amazon/aws/triggers/emr.py +0 -32
- airflow/providers/amazon/aws/triggers/glue_crawler.py +0 -11
- airflow/providers/amazon/aws/triggers/glue_databrew.py +0 -21
- airflow/providers/amazon/aws/triggers/rds.py +0 -79
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +5 -64
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -93
- airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py +106 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +4 -164
- airflow/providers/amazon/aws/utils/mixins.py +1 -23
- airflow/providers/amazon/aws/utils/openlineage.py +3 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +1 -1
- airflow/providers/amazon/get_provider_info.py +13 -4
- {apache_airflow_providers_amazon-8.29.0.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/METADATA +8 -9
- {apache_airflow_providers_amazon-8.29.0.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/RECORD +64 -64
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +0 -149
- /airflow/providers/amazon/aws/{datasets → assets}/__init__.py +0 -0
- {apache_airflow_providers_amazon-8.29.0.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.29.0.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/entry_points.txt +0 -0

--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -18,8 +18,12 @@
 from __future__ import annotations
 
 import time
+from dataclasses import dataclass
 from pprint import pformat
 from typing import TYPE_CHECKING, Any, Iterable
+from uuid import UUID
+
+from pendulum import duration
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -35,6 +39,14 @@ FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
 RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
 
 
+@dataclass
+class QueryExecutionOutput:
+    """Describes the output of a query execution."""
+
+    statement_id: str
+    session_id: str | None
+
+
 class RedshiftDataQueryFailedError(ValueError):
     """Raise an error that redshift data query failed."""
 
@@ -65,8 +77,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
     def execute_query(
         self,
-        database: str,
         sql: str | list[str],
+        database: str | None = None,
         cluster_identifier: str | None = None,
         db_user: str | None = None,
         parameters: Iterable | None = None,
@@ -76,23 +88,28 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         wait_for_completion: bool = True,
         poll_interval: int = 10,
         workgroup_name: str | None = None,
-    ) -> str:
+        session_id: str | None = None,
+        session_keep_alive_seconds: int | None = None,
+    ) -> QueryExecutionOutput:
         """
         Execute a statement against Amazon Redshift.
 
-        :param database: the name of the database
         :param sql: the SQL statement or list of SQL statement to run
+        :param database: the name of the database
         :param cluster_identifier: unique identifier of a cluster
         :param db_user: the database username
         :param parameters: the parameters for the SQL statement
         :param secret_arn: the name or ARN of the secret that enables db access
         :param statement_name: the name of the SQL statement
-        :param with_event:
-        :param wait_for_completion:
+        :param with_event: whether to send an event to EventBridge
+        :param wait_for_completion: whether to wait for a result
         :param poll_interval: how often in seconds to check the query status
         :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
             `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
             https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
+        :param session_id: the session identifier of the query
+        :param session_keep_alive_seconds: duration in seconds to keep the session alive after the query
+            finishes. The maximum time a session can keep alive is 24 hours
 
         :returns statement_id: str, the UUID of the statement
         """
@@ -105,7 +122,28 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             "SecretArn": secret_arn,
             "StatementName": statement_name,
             "WorkgroupName": workgroup_name,
+            "SessionId": session_id,
+            "SessionKeepAliveSeconds": session_keep_alive_seconds,
         }
+
+        if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
+            raise ValueError(
+                "Exactly one of cluster_identifier, workgroup_name, or session_id must be provided"
+            )
+
+        if session_id is not None:
+            msg = "session_id must be a valid UUID4"
+            try:
+                if UUID(session_id).version != 4:
+                    raise ValueError(msg)
+            except ValueError:
+                raise ValueError(msg)
+
+        if session_keep_alive_seconds is not None and (
+            session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24
+        ):
+            raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.")
+
         if isinstance(sql, list):
             kwargs["Sqls"] = sql
             resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
@@ -115,13 +153,10 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         statement_id = resp["Id"]
 
-        if bool(cluster_identifier) is bool(workgroup_name):
-            raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")
-
         if wait_for_completion:
             self.wait_for_results(statement_id, poll_interval=poll_interval)
 
-        return statement_id
+        return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))
 
     def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
         while True:
@@ -135,9 +170,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
     def check_query_is_finished(self, statement_id: str) -> bool:
         """Check whether query finished, raise exception is failed."""
         resp = self.conn.describe_statement(Id=statement_id)
-        return self.parse_statement_resonse(resp)
+        return self.parse_statement_response(resp)
 
-    def parse_statement_resonse(self, resp: DescribeStatementResponseTypeDef) -> bool:
+    def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
         """Parse the response of describe_statement."""
         status = resp["Status"]
         if status == FINISHED_STATE:
@@ -179,8 +214,10 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         :param table: Name of the target table
         :param database: the name of the database
         :param schema: Name of the target schema, public by default
-        :param sql: the SQL statement or list of SQL statement to run
         :param cluster_identifier: unique identifier of a cluster
+        :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+            `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
+            https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
         :param db_user: the database username
         :param secret_arn: the name or ARN of the secret that enables db access
         :param statement_name: the name of the SQL statement
@@ -212,7 +249,8 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             with_event=with_event,
             wait_for_completion=wait_for_completion,
             poll_interval=poll_interval,
-        )
+        ).statement_id
+
         pk_columns = []
         token = ""
         while True:
@@ -251,4 +289,4 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         """
         async with self.async_conn as client:
             resp = await client.describe_statement(Id=statement_id)
-            return self.parse_statement_resonse(resp)
+            return self.parse_statement_response(resp)
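
The `execute_query` changes above are backwards incompatible: the method now returns a `QueryExecutionOutput` dataclass rather than a bare statement ID, and exactly one of `cluster_identifier`, `workgroup_name`, or `session_id` must be supplied. A minimal caller-side sketch against the new signature; the connection ID, cluster name, and SQL below are illustrative placeholders, not values from this diff:

```python
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")

# 8.29.0 returned the statement id as a plain string; 9.0.0 wraps it in a dataclass,
# so callers that stored the return value directly now need `.statement_id`.
out = hook.execute_query(
    sql="SELECT 1",
    database="dev",
    cluster_identifier="example-cluster",  # exactly one of cluster_identifier,
    session_keep_alive_seconds=600,        # workgroup_name, or session_id is required
)
print(out.statement_id)

# If the first statement kept its session alive, the returned SessionId (a UUID4)
# can be passed back to run a follow-up statement in the same session.
if out.session_id:
    hook.execute_query(sql="SELECT 2", session_id=out.session_id)
```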

--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -28,7 +28,6 @@ import os
 import re
 import shutil
 import time
-import warnings
 from contextlib import suppress
 from copy import deepcopy
 from datetime import datetime
@@ -41,8 +40,6 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable
 from urllib.parse import urlsplit
 from uuid import uuid4
 
-from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
-
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -51,16 +48,25 @@
 with suppress(ImportError):
     from aiobotocore.client import AioBaseClient
 
+from importlib.util import find_spec
+
 from asgiref.sync import sync_to_async
 from boto3.s3.transfer import S3Transfer, TransferConfig
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException, AirflowNotFoundException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.utils.helpers import chunks
 
+if find_spec("airflow.assets"):
+    from airflow.lineage.hook import get_hook_lineage_collector
+else:
+    # TODO: import from common.compat directly after common.compat providers with
+    # asset_compat_lineage_collector released
+    from airflow.providers.amazon.aws.utils.asset_compat_lineage_collector import get_hook_lineage_collector
+
 logger = logging.getLogger(__name__)
 
 
@@ -119,15 +125,6 @@ def provide_bucket_name(func: Callable) -> Callable:
 
         if "bucket_name" in self.service_config:
             bound_args.arguments["bucket_name"] = self.service_config["bucket_name"]
-        elif self.conn_config and self.conn_config.schema:
-            warnings.warn(
-                "s3 conn_type, and the associated schema field, is deprecated. "
-                "Please use aws conn_type instead, and specify `bucket_name` "
-                "in `service_config.s3` within `extras`.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            bound_args.arguments["bucket_name"] = self.conn_config.schema
 
         return func(*bound_args.args, **bound_args.kwargs)
 
@@ -1113,11 +1110,11 @@ class S3Hook(AwsBaseHook):
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
-        get_hook_lineage_collector().add_input_dataset(
-            context=self, scheme="file", dataset_kwargs={"path": filename}
+        get_hook_lineage_collector().add_input_asset(
+            context=self, scheme="file", asset_kwargs={"path": filename}
         )
-        get_hook_lineage_collector().add_output_dataset(
-            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        get_hook_lineage_collector().add_output_asset(
+            context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key}
         )
 
     @unify_bucket_name_and_key
@@ -1260,8 +1257,8 @@ class S3Hook(AwsBaseHook):
             Config=self.transfer_config,
         )
         # No input because file_obj can be anything - handle in calling function if possible
-        get_hook_lineage_collector().add_output_dataset(
-            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        get_hook_lineage_collector().add_output_asset(
+            context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key}
         )
 
     def copy_object(
@@ -1318,11 +1315,11 @@ class S3Hook(AwsBaseHook):
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
-        get_hook_lineage_collector().add_input_dataset(
-            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        get_hook_lineage_collector().add_input_asset(
+            context=self, scheme="s3", asset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
         )
-        get_hook_lineage_collector().add_output_dataset(
-            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        get_hook_lineage_collector().add_output_asset(
+            context=self, scheme="s3", asset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
         )
         return response
 
@@ -1443,10 +1440,10 @@ class S3Hook(AwsBaseHook):
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
-            get_hook_lineage_collector().add_output_dataset(
+            get_hook_lineage_collector().add_output_asset(
                 context=self,
                 scheme="file",
-                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+                asset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
             )
             file = open(file_path, "wb")
         else:
@@ -1458,8 +1455,8 @@ class S3Hook(AwsBaseHook):
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
-        get_hook_lineage_collector().add_input_dataset(
-            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        get_hook_lineage_collector().add_input_asset(
+            context=self, scheme="s3", asset_kwargs={"bucket": bucket_name, "key": key}
         )
         return file.name
 
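
Besides the dataset-to-asset rename in the lineage calls, the `provide_bucket_name` hunk removes the long-deprecated fallback that read a default bucket from the connection's `schema` field. A sketch of declaring the default bucket the supported way, through the `service_config` extra of an `aws` connection; the connection ID and bucket name are placeholders:

```python
import json

from airflow.models.connection import Connection

conn = Connection(
    conn_id="aws_default",
    conn_type="aws",
    extra=json.dumps({"service_config": {"s3": {"bucket_name": "example-default-bucket"}}}),
)
# S3Hook calls that omit bucket_name now resolve it from service_config only;
# the removed code path that fell back to the connection's schema field is gone.
```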

--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -22,7 +22,6 @@ import re
 import tarfile
 import tempfile
 import time
-import warnings
 from collections import Counter, namedtuple
 from datetime import datetime
 from functools import partial
@@ -31,7 +30,7 @@ from typing import Any, AsyncGenerator, Callable, Generator, cast
 from asgiref.sync import sync_to_async
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -155,6 +154,7 @@ class SageMakerHook(AwsBaseHook):
     endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
     pipeline_non_terminal_states = {"Executing", "Stopping"}
     failed_states = {"Failed"}
+    training_failed_states = {*failed_states, "Stopped"}
 
     def __init__(self, *args, **kwargs):
         super().__init__(client_type="sagemaker", *args, **kwargs)
@@ -309,7 +309,7 @@
             self.check_training_status_with_log(
                 config["TrainingJobName"],
                 self.non_terminal_states,
-                self.failed_states,
+                self.training_failed_states,
                 wait_for_completion,
                 check_interval,
                 max_ingestion_time,
@@ -1097,9 +1097,6 @@ class SageMakerHook(AwsBaseHook):
         pipeline_name: str,
         display_name: str = "airflow-triggered-execution",
         pipeline_params: dict | None = None,
-        wait_for_completion: bool = False,
-        check_interval: int | None = None,
-        verbose: bool = True,
     ) -> str:
         """
         Start a new execution for a SageMaker pipeline.
@@ -1114,16 +1111,6 @@ class SageMakerHook(AwsBaseHook):
 
         :return: the ARN of the pipeline execution launched.
         """
-        if wait_for_completion or check_interval is not None:
-            warnings.warn(
-                "parameter `wait_for_completion` and `check_interval` are deprecated, "
-                "remove them and call check_status yourself if you want to wait for completion",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        if check_interval is None:
-            check_interval = 30
-
         formatted_params = format_tags(pipeline_params, key_label="Name")
 
         try:
@@ -1136,23 +1123,11 @@ class SageMakerHook(AwsBaseHook):
             self.log.error("Failed to start pipeline execution, error: %s", ce)
             raise
 
-
-        if wait_for_completion:
-            self.check_status(
-                arn,
-                "PipelineExecutionStatus",
-                lambda p: self.describe_pipeline_exec(p, verbose),
-                check_interval,
-                non_terminal_states=self.pipeline_non_terminal_states,
-            )
-        return arn
+        return res["PipelineExecutionArn"]
 
     def stop_pipeline(
         self,
         pipeline_exec_arn: str,
-        wait_for_completion: bool = False,
-        check_interval: int | None = None,
-        verbose: bool = True,
         fail_if_not_running: bool = False,
     ) -> str:
         """
@@ -1171,16 +1146,6 @@ class SageMakerHook(AwsBaseHook):
         :return: Status of the pipeline execution after the operation.
             One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
         """
-        if wait_for_completion or check_interval is not None:
-            warnings.warn(
-                "parameter `wait_for_completion` and `check_interval` are deprecated, "
-                "remove them and call check_status yourself if you want to wait for completion",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        if check_interval is None:
-            check_interval = 10
-
         for retries in reversed(range(5)):
             try:
                 self.conn.stop_pipeline_execution(PipelineExecutionArn=pipeline_exec_arn)
@@ -1212,15 +1177,6 @@ class SageMakerHook(AwsBaseHook):
 
         res = self.describe_pipeline_exec(pipeline_exec_arn)
 
-        if wait_for_completion and res["PipelineExecutionStatus"] in self.pipeline_non_terminal_states:
-            res = self.check_status(
-                pipeline_exec_arn,
-                "PipelineExecutionStatus",
-                lambda p: self.describe_pipeline_exec(p, verbose),
-                check_interval,
-                non_terminal_states=self.pipeline_non_terminal_states,
-            )
-
         return res["PipelineExecutionStatus"]
 
     def create_model_package_group(self, package_group_name: str, package_group_desc: str = "") -> bool:
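
`start_pipeline` and `stop_pipeline` no longer accept `wait_for_completion`, `check_interval`, or `verbose`; per the removed deprecation message, callers that want to block should call `check_status` themselves. A sketch mirroring the removed waiting logic; the pipeline name and interval are placeholders:

```python
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")
arn = hook.start_pipeline(pipeline_name="example-pipeline")

# Equivalent of the old wait_for_completion=True behaviour, now done explicitly by the caller.
hook.check_status(
    arn,
    "PipelineExecutionStatus",
    hook.describe_pipeline_exec,
    check_interval=30,
    non_terminal_states=hook.pipeline_non_terminal_states,
)
```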

--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -62,12 +62,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         )
 
     def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None:
-
-        # after Airflow 2.8 can always pass `identifier`
-        if getattr(super(), "supports_task_context_logging", False):
-            super().set_context(ti, identifier=identifier)
-        else:
-            super().set_context(ti)
+        super().set_context(ti, identifier=identifier)
         # Local location and remote location is needed to open and
         # upload local log file to S3 remote storage.
         if TYPE_CHECKING:

--- a/airflow/providers/amazon/aws/operators/appflow.py
+++ b/airflow/providers/amazon/aws/operators/appflow.py
@@ -17,11 +17,10 @@
 from __future__ import annotations
 
 import time
-import warnings
 from datetime import datetime, timedelta
 from typing import TYPE_CHECKING, cast
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.operators.python import ShortCircuitOperator
 from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
@@ -140,7 +139,6 @@ class AppflowRunOperator(AppflowBaseOperator):
     For more information on how to use this operator, take a look at the guide:
     :ref:`howto/operator:AppflowRunOperator`
 
-    :param source: Obsolete, unnecessary for this operator
     :param flow_name: The flow name
     :param poll_interval: how often in seconds to check the query status
     :param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -155,17 +153,10 @@ class AppflowRunOperator(AppflowBaseOperator):
     def __init__(
         self,
         flow_name: str,
-        source: str | None = None,
         poll_interval: int = 20,
         wait_for_completion: bool = True,
         **kwargs,
     ) -> None:
-        if source is not None:
-            warnings.warn(
-                "The `source` parameter is unused when simply running a flow, please remove it.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
         super().__init__(
             flow_name=flow_name,
             flow_update=False,

--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -26,13 +26,12 @@ AWS Batch services.
 
 from __future__ import annotations
 
-import warnings
 from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.models.mappedoperator import MappedOperator
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -48,7 +47,6 @@ from airflow.providers.amazon.aws.triggers.batch import (
 )
 from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
-from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -65,7 +63,6 @@ class BatchOperator(BaseOperator):
     :param job_name: the name for the job that will run on AWS Batch (templated)
     :param job_definition: the job definition name on AWS Batch
     :param job_queue: the queue name on AWS Batch
-    :param overrides: DEPRECATED, use container_overrides instead with the same value.
     :param container_overrides: the `containerOverrides` parameter for boto3 (templated)
     :param ecs_properties_override: the `ecsPropertiesOverride` parameter for boto3 (templated)
     :param eks_properties_override: the `eksPropertiesOverride` parameter for boto3 (templated)
@@ -165,7 +162,6 @@ class BatchOperator(BaseOperator):
         job_name: str,
         job_definition: str,
         job_queue: str,
-        overrides: dict | None = None,  # deprecated
         container_overrides: dict | None = None,
         array_properties: dict | None = None,
         ecs_properties_override: dict | None = None,
@@ -196,21 +192,6 @@ class BatchOperator(BaseOperator):
         self.job_queue = job_queue
 
         self.container_overrides = container_overrides
-        # handle `overrides` deprecation in favor of `container_overrides`
-        if overrides:
-            if container_overrides:
-                # disallow setting both old and new params
-                raise AirflowException(
-                    "'container_overrides' replaces the 'overrides' parameter. "
-                    "You cannot specify both. Please remove assignation to the deprecated 'overrides'."
-                )
-            self.container_overrides = overrides
-            warnings.warn(
-                "Parameter `overrides` is deprecated, Please use `container_overrides` instead.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-
         self.ecs_properties_override = ecs_properties_override
         self.eks_properties_override = eks_properties_override
         self.node_overrides = node_overrides
@@ -501,17 +482,8 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         aws_conn_id: str | None = None,
         region_name: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
-        status_retries=NOTSET,
         **kwargs,
     ):
-        if status_retries is not NOTSET:
-            warnings.warn(
-                "The `status_retries` parameter is unused and should be removed. "
-                "It'll be deleted in a future version.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-
         super().__init__(**kwargs)
 
         self.compute_environment_name = compute_environment_name
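
With the deprecated `overrides` argument removed from `BatchOperator`, container overrides must be passed through `container_overrides` (the boto3 `containerOverrides` payload). A sketch with placeholder job, queue, and definition names:

```python
from airflow.providers.amazon.aws.operators.batch import BatchOperator

submit_job = BatchOperator(
    task_id="submit_batch_job",
    job_name="example-job",
    job_queue="example-queue",
    job_definition="example-definition",
    # Previously accepted as `overrides=`; 9.0.0 only accepts `container_overrides=`.
    container_overrides={
        "command": ["echo", "hello"],
        "environment": [{"name": "STAGE", "value": "dev"}],
    },
)
```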

--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -22,9 +22,7 @@ import logging
 import random
 from typing import TYPE_CHECKING, Any, Sequence
 
-from deprecated import deprecated
-
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -199,11 +197,6 @@ class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
     def _hook_parameters(self) -> dict[str, Any]:
         return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
 
-    @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
-    def get_hook(self) -> DataSyncHook:
-        """Create and return DataSyncHook."""
-        return self.hook
-
     def execute(self, context: Context):
         # If task_arn was not specified then try to
         # find 0, 1 or many candidate DataSync Tasks to run
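
`DataSyncOperator.get_hook()` is gone; the `hook` property it already delegated to is the replacement. A sketch (task ID and ARN are placeholders for illustration only):

```python
from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

# task_id and task_arn are placeholder values.
op = DataSyncOperator(task_id="sync", task_arn="arn:aws:datasync:...:task/example")
client = op.hook.get_conn()  # previously: op.get_hook().get_conn()
```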

--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -18,13 +18,12 @@
 from __future__ import annotations
 
 import re
-import warnings
 from datetime import timedelta
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, should_retry_eni
@@ -40,7 +39,6 @@ from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
-from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
     import boto3
@@ -258,19 +256,8 @@ class EcsDeregisterTaskDefinitionOperator(EcsBaseOperator):
         self,
         *,
         task_definition: str,
-        wait_for_completion=NOTSET,
-        waiter_delay=NOTSET,
-        waiter_max_attempts=NOTSET,
         **kwargs,
     ):
-        if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]):
-            warnings.warn(
-                "'wait_for_completion' and waiter related params have no effect and are deprecated, "
-                "please remove them.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-
         super().__init__(**kwargs)
         self.task_definition = task_definition
 
@@ -311,19 +298,8 @@ class EcsRegisterTaskDefinitionOperator(EcsBaseOperator):
         family: str,
         container_definitions: list[dict],
         register_task_kwargs: dict | None = None,
-        wait_for_completion=NOTSET,
-        waiter_delay=NOTSET,
-        waiter_max_attempts=NOTSET,
         **kwargs,
     ):
-        if any(arg is not NOTSET for arg in [wait_for_completion, waiter_delay, waiter_max_attempts]):
-            warnings.warn(
-                "'wait_for_completion' and waiter related params have no effect and are deprecated, "
-                "please remove them.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-
         super().__init__(**kwargs)
         self.family = family
         self.container_definitions = container_definitions