apache-airflow-providers-amazon 8.3.1rc1__py3-none-any.whl → 8.4.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, exactly as they were published to their public registry. It is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +4 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +29 -12
- airflow/providers/amazon/aws/hooks/emr.py +17 -9
- airflow/providers/amazon/aws/hooks/eventbridge.py +27 -0
- airflow/providers/amazon/aws/hooks/redshift_data.py +10 -0
- airflow/providers/amazon/aws/hooks/sagemaker.py +24 -14
- airflow/providers/amazon/aws/notifications/chime.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +140 -7
- airflow/providers/amazon/aws/operators/emr.py +202 -22
- airflow/providers/amazon/aws/operators/eventbridge.py +87 -0
- airflow/providers/amazon/aws/operators/rds.py +120 -48
- airflow/providers/amazon/aws/operators/redshift_data.py +7 -0
- airflow/providers/amazon/aws/operators/sagemaker.py +75 -7
- airflow/providers/amazon/aws/operators/step_function.py +34 -2
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
- airflow/providers/amazon/aws/triggers/batch.py +1 -1
- airflow/providers/amazon/aws/triggers/ecs.py +7 -5
- airflow/providers/amazon/aws/triggers/eks.py +174 -3
- airflow/providers/amazon/aws/triggers/emr.py +215 -1
- airflow/providers/amazon/aws/triggers/rds.py +161 -5
- airflow/providers/amazon/aws/triggers/sagemaker.py +84 -1
- airflow/providers/amazon/aws/triggers/step_function.py +59 -0
- airflow/providers/amazon/aws/utils/__init__.py +16 -1
- airflow/providers/amazon/aws/utils/rds.py +2 -2
- airflow/providers/amazon/aws/waiters/sagemaker.json +46 -0
- airflow/providers/amazon/aws/waiters/stepfunctions.json +36 -0
- airflow/providers/amazon/get_provider_info.py +21 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/METADATA +13 -13
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/RECORD +34 -30
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/top_level.txt +0 -0
--- a/airflow/providers/amazon/__init__.py
+++ b/airflow/providers/amazon/__init__.py

@@ -28,14 +28,16 @@ import packaging.version
 
 __all__ = ["__version__"]
 
-__version__ = "8.3.1"
+__version__ = "8.4.0"
 
 try:
     from airflow import __version__ as airflow_version
 except ImportError:
     from airflow.version import version as airflow_version
 
-if packaging.version.parse(airflow_version) < packaging.version.parse("2.4.0"):
+if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
+    "2.4.0"
+):
     raise RuntimeError(
         f"The package `apache-airflow-providers-amazon:{__version__}` requires Apache Airflow 2.4.0+"  # NOQA: E501
     )
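The version-gate change is worth a note: `packaging` sorts pre-release builds below the final release, so the old comparison could reject an in-development build of a supported Airflow version. Parsing `base_version` first strips the pre-release segment. A minimal sketch of the difference, using only the `packaging` library:

```python
# Why the guard now compares base_version: under PEP 440,
# "2.4.0.dev0" sorts *before* "2.4.0", so the old check failed
# on in-development builds of a supported Airflow version.
from packaging import version

dev_build = version.parse("2.4.0.dev0")
floor = version.parse("2.4.0")

print(dev_build < floor)  # True: the old check rejects this build
print(version.parse(dev_build.base_version) < floor)  # False: the new check accepts it
```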
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py

@@ -197,22 +197,35 @@ class BaseSessionFactory(LoggingMixin):
     def _create_session_with_assume_role(
         self, session_kwargs: dict[str, Any], deferrable: bool = False
     ) -> boto3.session.Session:
-
         if self.conn.assume_role_method == "assume_role_with_web_identity":
             # Deferred credentials have no initial credentials
             credential_fetcher = self._get_web_identity_credential_fetcher()
-            credentials = botocore.credentials.DeferredRefreshableCredentials(
-                method="assume-role-with-web-identity",
-                refresh_using=credential_fetcher.fetch_credentials,
-                time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
-            )
+
+            params = {
+                "method": "assume-role-with-web-identity",
+                "refresh_using": credential_fetcher.fetch_credentials,
+                "time_fetcher": lambda: datetime.datetime.now(tz=tzlocal()),
+            }
+
+            if deferrable:
+                from aiobotocore.credentials import AioDeferredRefreshableCredentials
+
+                credentials = AioDeferredRefreshableCredentials(**params)
+            else:
+                credentials = botocore.credentials.DeferredRefreshableCredentials(**params)
         else:
             # Refreshable credentials do have initial credentials
-            credentials = botocore.credentials.RefreshableCredentials.create_from_metadata(
-                metadata=self._refresh_credentials(),
-                refresh_using=self._refresh_credentials,
-                method="sts-assume-role",
-            )
+            params = {
+                "metadata": self._refresh_credentials(),
+                "refresh_using": self._refresh_credentials,
+                "method": "sts-assume-role",
+            }
+            if deferrable:
+                from aiobotocore.credentials import AioRefreshableCredentials
+
+                credentials = AioRefreshableCredentials.create_from_metadata(**params)
+            else:
+                credentials = botocore.credentials.RefreshableCredentials.create_from_metadata(**params)
 
         if deferrable:
             from aiobotocore.session import get_session as async_get_session
@@ -796,7 +809,11 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """
         try:
             session = self.get_session()
-            conn_info = session.client("sts").get_caller_identity()
+            test_endpoint_url = self.conn_config.extra_config.get("test_endpoint_url")
+            conn_info = session.client(
+                "sts",
+                endpoint_url=test_endpoint_url,
+            ).get_caller_identity()
             metadata = conn_info.pop("ResponseMetadata", {})
             if metadata.get("HTTPStatusCode") != 200:
                 try:
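The connection-test path now honours an optional `test_endpoint_url` key in the connection extras, letting `get_caller_identity` hit a non-default STS endpoint. A hedged sketch of a connection that would exercise it; the conn ID and endpoint URL are illustrative:

```python
# Illustrative connection carrying the new test_endpoint_url extra.
# Conn ID and endpoint are placeholders; any reachable STS-compatible
# endpoint (regional endpoint, VPC endpoint, moto server, ...) would do.
import json

from airflow.models.connection import Connection

conn = Connection(
    conn_id="aws_sts_test",
    conn_type="aws",
    extra=json.dumps({"test_endpoint_url": "https://sts.eu-west-1.amazonaws.com"}),
)
```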
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py

@@ -256,9 +256,14 @@ class EmrServerlessHook(AwsBaseHook):
         kwargs["client_type"] = "emr-serverless"
         super().__init__(*args, **kwargs)
 
-    def cancel_running_jobs(self, application_id: str, waiter_config: dict | None = None):
+    def cancel_running_jobs(
+        self, application_id: str, waiter_config: dict | None = None, wait_for_completion: bool = True
+    ) -> int:
         """
-        Cancel jobs in an intermediate state, and wait for them to confirm as cancelled.
+        Cancel jobs in an intermediate state, and return the number of cancelled jobs.
+
+        If wait_for_completion is True, then the method will wait until all jobs are
+        cancelled before returning.
 
         Note: if new jobs are triggered while this operation is ongoing,
         it's going to time out and return an error.

@@ -284,13 +289,16 @@ class EmrServerlessHook(AwsBaseHook):
         )
         for job_id in job_ids:
             self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
-        if count > 0:
-            self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
-            self.get_waiter("no_job_running").wait(
-                applicationId=application_id,
-                states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
-                WaiterConfig=waiter_config or {},
-            )
+        if wait_for_completion:
+            if count > 0:
+                self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
+                self.get_waiter("no_job_running").wait(
+                    applicationId=application_id,
+                    states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
+                    WaiterConfig=waiter_config or {},
+                )
+
+        return count
 
 
 class EmrContainerHook(AwsBaseHook):
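`cancel_running_jobs` now returns the number of cancelled jobs and can skip the waiter, which suits callers that want to hand the wait off elsewhere (for example, to a deferrable trigger). A hedged usage sketch; the application ID is a placeholder:

```python
# Sketch: cancel without blocking, then wait only if something was cancelled.
# "app-1234567890abcdef" is a placeholder EMR Serverless application ID.
from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook

hook = EmrServerlessHook(aws_conn_id="aws_default")
count = hook.cancel_running_jobs("app-1234567890abcdef", wait_for_completion=False)

if count > 0:
    # Same waiter the hook uses internally when wait_for_completion=True.
    hook.get_waiter("no_job_running").wait(
        applicationId="app-1234567890abcdef",
        states=list(hook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
    )
```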
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/eventbridge.py

@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class EventBridgeHook(AwsBaseHook):
+    """Amazon EventBridge Hook."""
+
+    def __init__(self, *args, **kwargs):
+        """Creating object."""
+        super().__init__(client_type="events", *args, **kwargs)
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py

@@ -60,6 +60,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         with_event: bool = False,
         wait_for_completion: bool = True,
         poll_interval: int = 10,
+        workgroup_name: str | None = None,
     ) -> str:
         """
         Execute a statement against Amazon Redshift.

@@ -74,6 +75,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         :param with_event: indicates whether to send an event to EventBridge
         :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
         :param poll_interval: how often in seconds to check the query status
+        :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+            `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
+            https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
 
         :returns statement_id: str, the UUID of the statement
         """

@@ -85,6 +89,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             "WithEvent": with_event,
             "SecretArn": secret_arn,
             "StatementName": statement_name,
+            "WorkgroupName": workgroup_name,
         }
         if isinstance(sql, list):
             kwargs["Sqls"] = sql

@@ -95,6 +100,9 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         statement_id = resp["Id"]
 
+        if bool(cluster_identifier) is bool(workgroup_name):
+            raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")
+
         if wait_for_completion:
             self.wait_for_results(statement_id, poll_interval=poll_interval)
 

@@ -127,6 +135,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
         database: str,
         schema: str | None = "public",
         cluster_identifier: str | None = None,
+        workgroup_name: str | None = None,
         db_user: str | None = None,
         secret_arn: str | None = None,
         statement_name: str | None = None,

@@ -168,6 +177,7 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
             sql=sql,
             database=database,
             cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
             db_user=db_user,
             secret_arn=secret_arn,
             statement_name=statement_name,
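With `workgroup_name`, the Data API hook can target Redshift Serverless; the new `bool(...) is bool(...)` guard requires exactly one of `cluster_identifier` or `workgroup_name` to be set. A hedged sketch with placeholder names:

```python
# Sketch: querying Redshift Serverless through the Data API.
# Workgroup, database, and SQL below are placeholders.
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")
statement_id = hook.execute_query(
    database="dev",
    sql="SELECT 1;",
    workgroup_name="my-serverless-workgroup",  # mutually exclusive with cluster_identifier
    wait_for_completion=True,
)
```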
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py

@@ -23,6 +23,7 @@ import re
 import tarfile
 import tempfile
 import time
+import warnings
 from collections import Counter
 from datetime import datetime
 from functools import partial

@@ -30,7 +31,7 @@ from typing import Any, Callable, Generator, cast
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook

@@ -1061,7 +1062,7 @@ class SageMakerHook(AwsBaseHook):
         display_name: str = "airflow-triggered-execution",
         pipeline_params: dict | None = None,
         wait_for_completion: bool = False,
-        check_interval: int = 30,
+        check_interval: int | None = None,
         verbose: bool = True,
     ) -> str:
         """Start a new execution for a SageMaker pipeline.

@@ -1073,14 +1074,19 @@ class SageMakerHook(AwsBaseHook):
         :param display_name: The name this pipeline execution will have in the UI. Doesn't need to be unique.
         :param pipeline_params: Optional parameters for the pipeline.
             All parameters supplied need to already be present in the pipeline definition.
-        :param wait_for_completion: Will only return once the pipeline is complete if true.
-        :param check_interval: How long to wait between checks for pipeline status when waiting for
-            completion.
-        :param verbose: Whether to print steps details when waiting for completion.
-            Defaults to true, consider turning off for pipelines that have thousands of steps.
 
         :return: the ARN of the pipeline execution launched.
         """
+        if wait_for_completion or check_interval is not None:
+            warnings.warn(
+                "parameter `wait_for_completion` and `check_interval` are deprecated, "
+                "remove them and call check_status yourself if you want to wait for completion",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            if check_interval is None:
+                check_interval = 30
+
         formatted_params = format_tags(pipeline_params, key_label="Name")
 
         try:

@@ -1108,7 +1114,7 @@ class SageMakerHook(AwsBaseHook):
         self,
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
-        check_interval: int = 10,
+        check_interval: int | None = None,
         verbose: bool = True,
         fail_if_not_running: bool = False,
     ) -> str:

@@ -1119,12 +1125,6 @@ class SageMakerHook(AwsBaseHook):
 
         :param pipeline_exec_arn: Amazon Resource Name (ARN) of the pipeline execution.
             It's the ARN of the pipeline itself followed by "/execution/" and an id.
-        :param wait_for_completion: Whether to wait for the pipeline to reach a final state.
-            (i.e. either 'Stopped' or 'Failed')
-        :param check_interval: How long to wait between checks for pipeline status when waiting for
-            completion.
-        :param verbose: Whether to print steps details when waiting for completion.
-            Defaults to true, consider turning off for pipelines that have thousands of steps.
         :param fail_if_not_running: This method will raise an exception if the pipeline we're trying to stop
             is not in an "Executing" state when the call is sent (which would mean that the pipeline is
             already either stopping or stopped).

@@ -1133,6 +1133,16 @@ class SageMakerHook(AwsBaseHook):
         :return: Status of the pipeline execution after the operation.
             One of 'Executing'|'Stopping'|'Stopped'|'Failed'|'Succeeded'.
         """
+        if wait_for_completion or check_interval is not None:
+            warnings.warn(
+                "parameter `wait_for_completion` and `check_interval` are deprecated, "
+                "remove them and call check_status yourself if you want to wait for completion",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            if check_interval is None:
+                check_interval = 10
+
         retries = 2  # i.e. 3 calls max, 1 initial + 2 retries
         while True:
             try:
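Since waiting inside `start_pipeline`/`stop_pipeline` is now deprecated, the replacement is to wait for the terminal state yourself. The warning suggests `check_status`, whose pipeline-specific arguments are not shown in this diff, so the sketch below substitutes a plain polling loop over the boto3 client; the pipeline name is a placeholder:

```python
# Sketch: start a pipeline and wait for it yourself rather than passing
# the deprecated wait_for_completion/check_interval arguments.
import time

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

hook = SageMakerHook(aws_conn_id="aws_default")
arn = hook.start_pipeline(pipeline_name="my-pipeline")  # placeholder name

while True:
    status = hook.conn.describe_pipeline_execution(PipelineExecutionArn=arn)[
        "PipelineExecutionStatus"
    ]
    if status not in ("Executing", "Stopping"):
        break  # terminal: Stopped, Failed, or Succeeded
    time.sleep(30)
```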
--- a/airflow/providers/amazon/aws/notifications/chime.py
+++ b/airflow/providers/amazon/aws/notifications/chime.py

@@ -53,7 +53,7 @@ class ChimeNotifier(BaseNotifier):
         """To reduce overhead cache the hook for the notifier."""
         return ChimeWebhookHook(chime_conn_id=self.chime_conn_id)
 
-    def notify(self, context: Context) -> None:
+    def notify(self, context: Context) -> None:  # type: ignore[override]
         """Send a message to a Chime Chat Room."""
         self.hook.send_message(message=self.message)
 
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py

@@ -21,7 +21,7 @@ import logging
 import warnings
 from ast import literal_eval
 from datetime import timedelta
-from typing import TYPE_CHECKING, List, Sequence, cast
+from typing import TYPE_CHECKING, Any, List, Sequence, cast
 
 from botocore.exceptions import ClientError, WaiterError
 

@@ -30,8 +30,10 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.eks import EksHook
 from airflow.providers.amazon.aws.triggers.eks import (
+    EksCreateClusterTrigger,
     EksCreateFargateProfileTrigger,
     EksCreateNodegroupTrigger,
+    EksDeleteClusterTrigger,
     EksDeleteFargateProfileTrigger,
     EksDeleteNodegroupTrigger,
 )
@@ -187,6 +189,9 @@ class EksCreateClusterOperator(BaseOperator):
         (templated)
     :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
     :param waiter_max_attempts: The maximum number of attempts to check cluster state
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
 
     """
 

@@ -225,6 +230,7 @@ class EksCreateClusterOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 40,
         **kwargs,

@@ -237,7 +243,7 @@ class EksCreateClusterOperator(BaseOperator):
         self.nodegroup_role_arn = nodegroup_role_arn
         self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
         self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.aws_conn_id = aws_conn_id

@@ -246,6 +252,7 @@ class EksCreateClusterOperator(BaseOperator):
         self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
         self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
         self.fargate_profile_name = fargate_profile_name
+        self.deferrable = deferrable
         super().__init__(
             **kwargs,
         )
@@ -274,12 +281,25 @@ class EksCreateClusterOperator(BaseOperator):
 
         # Short circuit early if we don't need to wait to attach compute
         # and the caller hasn't requested to wait for the cluster either.
-        if not self.compute and not self.wait_for_completion:
+        if not any([self.compute, self.wait_for_completion, self.deferrable]):
             return None
 
-        self.log.info("Waiting for EKS Cluster to provision.  This will take some time.")
+        self.log.info("Waiting for EKS Cluster to provision. This will take some time.")
         client = self.eks_hook.conn
 
+        if self.deferrable:
+            self.defer(
+                trigger=EksCreateClusterTrigger(
+                    cluster_name=self.cluster_name,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                ),
+                method_name="deferrable_create_cluster_next",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
+
         try:
             client.get_waiter("cluster_active").wait(
                 name=self.cluster_name,
@@ -311,6 +331,89 @@ class EksCreateClusterOperator(BaseOperator):
             subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
         )
 
+    def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        if event is None:
+            self.log.error("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        elif event["status"] == "failed":
+            self.log.error("Cluster failed to start and will be torn down.")
+            self.eks_hook.delete_cluster(name=self.cluster_name)
+            self.defer(
+                trigger=EksDeleteClusterTrigger(
+                    cluster_name=self.cluster_name,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region,
+                    force_delete_compute=False,
+                ),
+                method_name="execute_failed",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
+        elif event["status"] == "success":
+            self.log.info("Cluster is ready to provision compute.")
+            _create_compute(
+                compute=self.compute,
+                cluster_name=self.cluster_name,
+                aws_conn_id=self.aws_conn_id,
+                region=self.region,
+                wait_for_completion=self.wait_for_completion,
+                waiter_delay=self.waiter_delay,
+                waiter_max_attempts=self.waiter_max_attempts,
+                nodegroup_name=self.nodegroup_name,
+                nodegroup_role_arn=self.nodegroup_role_arn,
+                create_nodegroup_kwargs=self.create_nodegroup_kwargs,
+                fargate_profile_name=self.fargate_profile_name,
+                fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
+                fargate_selectors=self.fargate_selectors,
+                create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
+                subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
+            )
+            if self.compute == "fargate":
+                self.defer(
+                    trigger=EksCreateFargateProfileTrigger(
+                        cluster_name=self.cluster_name,
+                        fargate_profile_name=self.fargate_profile_name,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                        aws_conn_id=self.aws_conn_id,
+                        region=self.region,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+                )
+            else:
+                self.defer(
+                    trigger=EksCreateNodegroupTrigger(
+                        nodegroup_name=self.nodegroup_name,
+                        cluster_name=self.cluster_name,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                    ),
+                    method_name="execute_complete",
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+                )
+
+    def execute_failed(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        if event is None:
+            self.log.info("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        elif event["status"] == "delteted":
+            self.log.info("Cluster deleted")
+            raise event["exception"]
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        resource = "fargate profile" if self.compute == "fargate" else self.compute
+        if event is None:
+            self.log.info("Trigger error: event is None")
+            raise AirflowException("Trigger error: event is None")
+        elif event["status"] != "success":
+            raise AirflowException(f"Error creating {resource}: {event}")
+
+        self.log.info("%s created successfully", resource)
+
 
 class EksCreateNodegroupOperator(BaseOperator):
     """
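At the DAG level, the new deferrable path is opt-in: the operator defers to `EksCreateClusterTrigger` and resumes in `deferrable_create_cluster_next`, which tears the cluster down on failure or defers again while compute attaches. A hedged sketch; names, ARNs, and subnet IDs are placeholders, and aiobotocore must be installed for this mode:

```python
# Sketch: opting into the new deferrable create path. All names, ARNs,
# and subnet IDs below are placeholders.
from airflow.providers.amazon.aws.operators.eks import EksCreateClusterOperator

create_cluster = EksCreateClusterOperator(
    task_id="create_eks_cluster",
    cluster_name="example-cluster",
    cluster_role_arn="arn:aws:iam::123456789012:role/eks-cluster-role",
    resources_vpc_config={"subnetIds": ["subnet-0123456789abcdef0", "subnet-0fedcba9876543210"]},
    compute="nodegroup",
    nodegroup_name="example-nodegroup",
    nodegroup_role_arn="arn:aws:iam::123456789012:role/eks-nodegroup-role",
    deferrable=True,  # releases the worker slot while the trigger polls
)
```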
@@ -564,6 +667,11 @@ class EksDeleteClusterOperator(BaseOperator):
         maintained on each worker node).
     :param region: Which AWS region the connection should use. (templated)
         If this is None or empty then the default boto3 behaviour is used.
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check cluster state
+    :param waiter_max_attempts: The maximum number of attempts to check cluster state
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to be deleted.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
 
     """
 

@@ -582,13 +690,19 @@ class EksDeleteClusterOperator(BaseOperator):
         wait_for_completion: bool = False,
         aws_conn_id: str = DEFAULT_CONN_ID,
         region: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
         **kwargs,
     ) -> None:
         self.cluster_name = cluster_name
         self.force_delete_compute = force_delete_compute
-        self.wait_for_completion = wait_for_completion
+        self.wait_for_completion = False if deferrable else wait_for_completion
         self.aws_conn_id = aws_conn_id
         self.region = region
+        self.deferrable = deferrable
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
         super().__init__(**kwargs)
 
     def execute(self, context: Context):

@@ -596,8 +710,20 @@ class EksDeleteClusterOperator(BaseOperator):
             aws_conn_id=self.aws_conn_id,
             region_name=self.region,
         )
-
-        if self.force_delete_compute:
+        if self.deferrable:
+            self.defer(
+                trigger=EksDeleteClusterTrigger(
+                    cluster_name=self.cluster_name,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region,
+                    force_delete_compute=self.force_delete_compute,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
+            )
+        elif self.force_delete_compute:
             self.delete_any_nodegroups(eks_hook)
             self.delete_any_fargate_profiles(eks_hook)
 

@@ -645,6 +771,13 @@ class EksDeleteClusterOperator(BaseOperator):
         )
         self.log.info(SUCCESS_MSG.format(compute=FARGATE_FULL_NAME))
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        if event is None:
+            self.log.error("Trigger error. Event is None")
+            raise AirflowException("Trigger error. Event is None")
+        elif event["status"] == "success":
+            self.log.info("Cluster deleted successfully.")
+
 
 class EksDeleteNodegroupOperator(BaseOperator):
     """