apache-airflow-providers-amazon 9.4.0rc1__py3-none-any.whl → 9.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +3 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +80 -110
- airflow/providers/amazon/aws/auth_manager/router/login.py +11 -4
- airflow/providers/amazon/aws/auth_manager/user.py +7 -4
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/appflow.py +5 -15
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +34 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/batch_waiters.py +11 -3
- airflow/providers/amazon/aws/hooks/dms.py +3 -1
- airflow/providers/amazon/aws/hooks/ec2.py +1 -1
- airflow/providers/amazon/aws/hooks/eks.py +3 -6
- airflow/providers/amazon/aws/hooks/glue.py +6 -2
- airflow/providers/amazon/aws/hooks/logs.py +2 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +79 -15
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +10 -10
- airflow/providers/amazon/aws/hooks/redshift_data.py +3 -4
- airflow/providers/amazon/aws/hooks/s3.py +3 -1
- airflow/providers/amazon/aws/hooks/sagemaker.py +2 -2
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +188 -0
- airflow/providers/amazon/aws/links/athena.py +1 -2
- airflow/providers/amazon/aws/links/base_aws.py +8 -1
- airflow/providers/amazon/aws/links/sagemaker_unified_studio.py +27 -0
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +174 -54
- airflow/providers/amazon/aws/log/s3_task_handler.py +136 -84
- airflow/providers/amazon/aws/notifications/chime.py +1 -2
- airflow/providers/amazon/aws/notifications/sns.py +1 -1
- airflow/providers/amazon/aws/notifications/sqs.py +1 -1
- airflow/providers/amazon/aws/operators/ec2.py +91 -83
- airflow/providers/amazon/aws/operators/eks.py +3 -3
- airflow/providers/amazon/aws/operators/mwaa.py +73 -2
- airflow/providers/amazon/aws/operators/redshift_cluster.py +10 -3
- airflow/providers/amazon/aws/operators/s3.py +147 -157
- airflow/providers/amazon/aws/operators/sagemaker.py +4 -7
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +155 -0
- airflow/providers/amazon/aws/sensors/ec2.py +5 -12
- airflow/providers/amazon/aws/sensors/emr.py +1 -1
- airflow/providers/amazon/aws/sensors/glacier.py +1 -1
- airflow/providers/amazon/aws/sensors/mwaa.py +161 -0
- airflow/providers/amazon/aws/sensors/rds.py +10 -5
- airflow/providers/amazon/aws/sensors/s3.py +32 -43
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +73 -0
- airflow/providers/amazon/aws/sensors/step_function.py +2 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +2 -2
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +19 -4
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +19 -3
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/README.md +4 -4
- airflow/providers/amazon/aws/triggers/base.py +11 -2
- airflow/providers/amazon/aws/triggers/ecs.py +6 -2
- airflow/providers/amazon/aws/triggers/eks.py +2 -2
- airflow/providers/amazon/aws/triggers/glue.py +1 -1
- airflow/providers/amazon/aws/triggers/mwaa.py +128 -0
- airflow/providers/amazon/aws/triggers/s3.py +31 -6
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -2
- airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py +66 -0
- airflow/providers/amazon/aws/triggers/sqs.py +11 -3
- airflow/providers/amazon/aws/{auth_manager/security_manager/__init__.py → utils/sagemaker_unified_studio.py} +12 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +4 -3
- airflow/providers/amazon/aws/waiters/mwaa.json +36 -0
- airflow/providers/amazon/get_provider_info.py +46 -5
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/METADATA +40 -33
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/RECORD +68 -61
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/WHEEL +1 -1
- airflow/providers/amazon/aws/auth_manager/security_manager/aws_security_manager_override.py +0 -40
- {apache_airflow_providers_amazon-9.4.0rc1.dist-info → apache_airflow_providers_amazon-9.5.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/ec2.py

```diff
@@ -18,21 +18,22 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.links.ec2 import (
     EC2InstanceDashboardLink,
     EC2InstanceLink,
 )
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class EC2StartInstanceOperator(BaseOperator):
+class EC2StartInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Start AWS EC2 instance using boto3.
 
@@ -41,18 +42,21 @@ class EC2StartInstanceOperator(BaseOperator):
         :ref:`howto/operator:EC2StartInstanceOperator`
 
     :param instance_id: id of the AWS EC2 instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: (optional) aws region name associated with the client
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param check_interval: time in seconds that the job should wait in
         between each instance state checks until operation is completed
     """
 
+    aws_hook_class = EC2Hook
     operator_extra_links = (EC2InstanceLink(),)
-    template_fields: Sequence[str] = ("instance_id", "region_name")
+    template_fields: Sequence[str] = aws_template_fields("instance_id", "region_name")
     ui_color = "#eeaa11"
     ui_fgcolor = "#ffffff"
 
@@ -60,37 +64,32 @@ class EC2StartInstanceOperator(BaseOperator):
         self,
         *,
         instance_id: str,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         check_interval: float = 15,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.instance_id = instance_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.check_interval = check_interval
 
     def execute(self, context: Context):
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
         self.log.info("Starting EC2 instance %s", self.instance_id)
-        instance = ec2_hook.get_instance(instance_id=self.instance_id)
+        instance = self.hook.get_instance(instance_id=self.instance_id)
         instance.start()
         EC2InstanceLink.persist(
             context=context,
             operator=self,
-            aws_partition=ec2_hook.conn_partition,
+            aws_partition=self.hook.conn_partition,
             instance_id=self.instance_id,
-            region_name=ec2_hook.conn_region_name,
+            region_name=self.hook.conn_region_name,
         )
-        ec2_hook.wait_for_state(
+        self.hook.wait_for_state(
             instance_id=self.instance_id,
             target_state="running",
             check_interval=self.check_interval,
         )
 
 
-class EC2StopInstanceOperator(BaseOperator):
+class EC2StopInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Stop AWS EC2 instance using boto3.
 
@@ -100,17 +99,20 @@ class EC2StopInstanceOperator(BaseOperator):
 
     :param instance_id: id of the AWS EC2 instance
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: (optional) aws region name associated with the client
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param check_interval: time in seconds that the job should wait in
         between each instance state checks until operation is completed
     """
 
+    aws_hook_class = EC2Hook
     operator_extra_links = (EC2InstanceLink(),)
-    template_fields: Sequence[str] = ("instance_id", "region_name")
+    template_fields: Sequence[str] = aws_template_fields("instance_id", "region_name")
     ui_color = "#eeaa11"
     ui_fgcolor = "#ffffff"
 
@@ -118,38 +120,33 @@ class EC2StopInstanceOperator(BaseOperator):
         self,
         *,
         instance_id: str,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         check_interval: float = 15,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.instance_id = instance_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.check_interval = check_interval
 
     def execute(self, context: Context):
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
         self.log.info("Stopping EC2 instance %s", self.instance_id)
-        instance = ec2_hook.get_instance(instance_id=self.instance_id)
+        instance = self.hook.get_instance(instance_id=self.instance_id)
         EC2InstanceLink.persist(
             context=context,
             operator=self,
-            aws_partition=ec2_hook.conn_partition,
+            aws_partition=self.hook.conn_partition,
             instance_id=self.instance_id,
-            region_name=ec2_hook.conn_region_name,
+            region_name=self.hook.conn_region_name,
         )
         instance.stop()
 
-        ec2_hook.wait_for_state(
+        self.hook.wait_for_state(
             instance_id=self.instance_id,
             target_state="stopped",
             check_interval=self.check_interval,
         )
 
 
-class EC2CreateInstanceOperator(BaseOperator):
+class EC2CreateInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Create and start a specified number of EC2 Instances using boto3.
 
@@ -161,11 +158,13 @@ class EC2CreateInstanceOperator(BaseOperator):
     :param max_count: Maximum number of instances to launch. Defaults to 1.
     :param min_count: Minimum number of instances to launch. Defaults to 1.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: AWS
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param poll_interval: Number of seconds to wait before attempting to
         check state of instance. Only used if wait_for_completion is True. Default is 20.
     :param max_attempts: Maximum number of attempts when checking state of instance.
@@ -175,8 +174,10 @@ class EC2CreateInstanceOperator(BaseOperator):
         in the `running` state before returning.
     """
 
+    aws_hook_class = EC2Hook
+
     operator_extra_links = (EC2InstanceDashboardLink(),)
-    template_fields: Sequence[str] = (
+    template_fields: Sequence[str] = aws_template_fields(
         "image_id",
         "max_count",
         "min_count",
@@ -191,8 +192,6 @@ class EC2CreateInstanceOperator(BaseOperator):
         image_id: str,
         max_count: int = 1,
         min_count: int = 1,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         poll_interval: int = 20,
         max_attempts: int = 20,
         config: dict | None = None,
@@ -203,16 +202,17 @@ class EC2CreateInstanceOperator(BaseOperator):
         self.image_id = image_id
         self.max_count = max_count
         self.min_count = min_count
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.poll_interval = poll_interval
         self.max_attempts = max_attempts
         self.config = config or {}
         self.wait_for_completion = wait_for_completion
 
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "api_type": "client_type"}
+
     def execute(self, context: Context):
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
-        instances = ec2_hook.conn.run_instances(
+        instances = self.hook.conn.run_instances(
             ImageId=self.image_id,
             MinCount=self.min_count,
             MaxCount=self.max_count,
@@ -225,15 +225,15 @@ class EC2CreateInstanceOperator(BaseOperator):
         EC2InstanceDashboardLink.persist(
             context=context,
             operator=self,
-            region_name=ec2_hook.conn_region_name,
-            aws_partition=ec2_hook.conn_partition,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
             instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(instance_ids),
         )
         for instance_id in instance_ids:
             self.log.info("Created EC2 instance %s", instance_id)
 
             if self.wait_for_completion:
-                ec2_hook.get_waiter("instance_running").wait(
+                self.hook.get_waiter("instance_running").wait(
                     InstanceIds=[instance_id],
                     WaiterConfig={
                         "Delay": self.poll_interval,
@@ -249,16 +249,16 @@ class EC2CreateInstanceOperator(BaseOperator):
 
         if instance_ids:
             self.log.info("on_kill: Terminating instance/s %s", ", ".join(instance_ids))
-            ec2_hook = EC2Hook(
+            """ ec2_hook = EC2Hook(
                 aws_conn_id=self.aws_conn_id,
                 region_name=self.region_name,
                 api_type="client_type",
-            )
-            ec2_hook.terminate_instances(instance_ids=instance_ids)
+            ) """
+            self.hook.terminate_instances(instance_ids=instance_ids)
         super().on_kill()
 
 
-class EC2TerminateInstanceOperator(BaseOperator):
+class EC2TerminateInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Terminate EC2 Instances using boto3.
 
@@ -268,11 +268,13 @@ class EC2TerminateInstanceOperator(BaseOperator):
 
     :param instance_id: ID of the instance to be terminated.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: AWS
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param poll_interval: Number of seconds to wait before attempting to
         check state of instance. Only used if wait_for_completion is True. Default is 20.
     :param max_attempts: Maximum number of attempts when checking state of instance.
@@ -281,13 +283,14 @@ class EC2TerminateInstanceOperator(BaseOperator):
         in the `terminated` state before returning.
     """
 
-    template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion")
+    aws_hook_class = EC2Hook
+    template_fields: Sequence[str] = aws_template_fields(
+        "instance_ids", "region_name", "aws_conn_id", "wait_for_completion"
+    )
 
     def __init__(
         self,
         instance_ids: str | list[str],
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         poll_interval: int = 20,
         max_attempts: int = 20,
         wait_for_completion: bool = False,
@@ -295,22 +298,23 @@ class EC2TerminateInstanceOperator(BaseOperator):
     ):
         super().__init__(**kwargs)
         self.instance_ids = instance_ids
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.poll_interval = poll_interval
         self.max_attempts = max_attempts
         self.wait_for_completion = wait_for_completion
 
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "api_type": "client_type"}
+
     def execute(self, context: Context):
         if isinstance(self.instance_ids, str):
             self.instance_ids = [self.instance_ids]
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
-        ec2_hook.conn.terminate_instances(InstanceIds=self.instance_ids)
+        self.hook.conn.terminate_instances(InstanceIds=self.instance_ids)
 
         for instance_id in self.instance_ids:
             self.log.info("Terminating EC2 instance %s", instance_id)
             if self.wait_for_completion:
-                ec2_hook.get_waiter("instance_terminated").wait(
+                self.hook.get_waiter("instance_terminated").wait(
                     InstanceIds=[instance_id],
                     WaiterConfig={
                         "Delay": self.poll_interval,
@@ -319,7 +323,7 @@ class EC2TerminateInstanceOperator(BaseOperator):
                 )
 
 
-class EC2RebootInstanceOperator(BaseOperator):
+class EC2RebootInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Reboot Amazon EC2 instances.
 
@@ -329,11 +333,13 @@ class EC2RebootInstanceOperator(BaseOperator):
 
     :param instance_ids: ID of the instance(s) to be rebooted.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: AWS
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param poll_interval: Number of seconds to wait before attempting to
         check state of instance. Only used if wait_for_completion is True. Default is 20.
     :param max_attempts: Maximum number of attempts when checking state of instance.
@@ -342,8 +348,9 @@ class EC2RebootInstanceOperator(BaseOperator):
         in the `running` state before returning.
     """
 
+    aws_hook_class = EC2Hook
     operator_extra_links = (EC2InstanceDashboardLink(),)
-    template_fields: Sequence[str] = ("instance_ids", "region_name")
+    template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name")
     ui_color = "#eeaa11"
     ui_fgcolor = "#ffffff"
 
@@ -351,8 +358,6 @@ class EC2RebootInstanceOperator(BaseOperator):
         self,
         *,
         instance_ids: str | list[str],
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         poll_interval: int = 20,
         max_attempts: int = 20,
         wait_for_completion: bool = False,
@@ -360,29 +365,30 @@ class EC2RebootInstanceOperator(BaseOperator):
     ):
         super().__init__(**kwargs)
         self.instance_ids = instance_ids
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
        self.poll_interval = poll_interval
         self.max_attempts = max_attempts
         self.wait_for_completion = wait_for_completion
 
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "api_type": "client_type"}
+
     def execute(self, context: Context):
         if isinstance(self.instance_ids, str):
             self.instance_ids = [self.instance_ids]
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
         self.log.info("Rebooting EC2 instances %s", ", ".join(self.instance_ids))
-        ec2_hook.conn.reboot_instances(InstanceIds=self.instance_ids)
+        self.hook.conn.reboot_instances(InstanceIds=self.instance_ids)
 
         # Console link is for EC2 dashboard list, not individual instances
         EC2InstanceDashboardLink.persist(
             context=context,
             operator=self,
-            region_name=ec2_hook.conn_region_name,
-            aws_partition=ec2_hook.conn_partition,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
             instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids),
         )
         if self.wait_for_completion:
-            ec2_hook.get_waiter("instance_running").wait(
+            self.hook.get_waiter("instance_running").wait(
                 InstanceIds=self.instance_ids,
                 WaiterConfig={
                     "Delay": self.poll_interval,
@@ -391,7 +397,7 @@ class EC2RebootInstanceOperator(BaseOperator):
             )
 
 
-class EC2HibernateInstanceOperator(BaseOperator):
+class EC2HibernateInstanceOperator(AwsBaseOperator[EC2Hook]):
     """
     Hibernate Amazon EC2 instances.
 
@@ -401,11 +407,13 @@ class EC2HibernateInstanceOperator(BaseOperator):
 
     :param instance_ids: ID of the instance(s) to be hibernated.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: AWS
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
     :param poll_interval: Number of seconds to wait before attempting to
         check state of instance. Only used if wait_for_completion is True. Default is 20.
     :param max_attempts: Maximum number of attempts when checking state of instance.
@@ -414,8 +422,9 @@ class EC2HibernateInstanceOperator(BaseOperator):
         in the `stopped` state before returning.
     """
 
+    aws_hook_class = EC2Hook
     operator_extra_links = (EC2InstanceDashboardLink(),)
-    template_fields: Sequence[str] = ("instance_ids", "region_name")
+    template_fields: Sequence[str] = aws_template_fields("instance_ids", "region_name")
     ui_color = "#eeaa11"
     ui_fgcolor = "#ffffff"
 
@@ -423,8 +432,6 @@ class EC2HibernateInstanceOperator(BaseOperator):
         self,
         *,
         instance_ids: str | list[str],
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
         poll_interval: int = 20,
         max_attempts: int = 20,
         wait_for_completion: bool = False,
@@ -432,25 +439,26 @@ class EC2HibernateInstanceOperator(BaseOperator):
     ):
         super().__init__(**kwargs)
         self.instance_ids = instance_ids
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.poll_interval = poll_interval
         self.max_attempts = max_attempts
         self.wait_for_completion = wait_for_completion
 
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "api_type": "client_type"}
+
     def execute(self, context: Context):
         if isinstance(self.instance_ids, str):
             self.instance_ids = [self.instance_ids]
-        ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
         self.log.info("Hibernating EC2 instances %s", ", ".join(self.instance_ids))
-        instances = ec2_hook.get_instances(instance_ids=self.instance_ids)
+        instances = self.hook.get_instances(instance_ids=self.instance_ids)
 
         # Console link is for EC2 dashboard list, not individual instances
         EC2InstanceDashboardLink.persist(
             context=context,
             operator=self,
-            region_name=ec2_hook.conn_region_name,
-            aws_partition=ec2_hook.conn_partition,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
             instance_ids=EC2InstanceDashboardLink.format_instance_id_filter(self.instance_ids),
         )
 
@@ -459,10 +467,10 @@ class EC2HibernateInstanceOperator(BaseOperator):
             if not hibernation_options or not hibernation_options["Configured"]:
                 raise AirflowException(f"Instance {instance['InstanceId']} is not configured for hibernation")
 
-        ec2_hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True)
+        self.hook.conn.stop_instances(InstanceIds=self.instance_ids, Hibernate=True)
 
         if self.wait_for_completion:
-            ec2_hook.get_waiter("instance_stopped").wait(
+            self.hook.get_waiter("instance_stopped").wait(
                 InstanceIds=self.instance_ids,
                 WaiterConfig={
                     "Delay": self.poll_interval,
```
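For DAG authors the call sites stay the same: `aws_conn_id` and `region_name` are now consumed by `AwsBaseOperator` instead of each `__init__`, and the shared `verify` and `botocore_config` arguments become available on every EC2 operator. A minimal usage sketch (the dag/task ids, instance id, and region are illustrative, not taken from this diff):

```python
# A minimal sketch, assuming a deployment with this provider installed.
from airflow import DAG
from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator

with DAG(dag_id="ec2_example", schedule=None) as dag:
    start_instance = EC2StartInstanceOperator(
        task_id="start_instance",
        instance_id="i-0123456789abcdef0",  # illustrative instance id
        aws_conn_id="aws_default",          # now handled by AwsBaseOperator
        region_name="eu-west-1",            # likewise
        verify=True,                        # pass-through to boto3, newly documented
        check_interval=15,
    )
```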
airflow/providers/amazon/aws/operators/eks.py

```diff
@@ -338,7 +338,7 @@ class EksCreateClusterOperator(BaseOperator):
                 fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
                 fargate_selectors=self.fargate_selectors,
                 create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
-                subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+                subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
             )
 
     def deferrable_create_cluster_next(self, context: Context, event: dict[str, Any] | None = None) -> None:
@@ -377,7 +377,7 @@ class EksCreateClusterOperator(BaseOperator):
                 fargate_pod_execution_role_arn=self.fargate_pod_execution_role_arn,
                 fargate_selectors=self.fargate_selectors,
                 create_fargate_profile_kwargs=self.create_fargate_profile_kwargs,
-                subnets=cast(list[str], self.resources_vpc_config.get("subnetIds")),
+                subnets=cast("list[str]", self.resources_vpc_config.get("subnetIds")),
             )
         if self.compute == "fargate":
             self.defer(
@@ -503,7 +503,7 @@ class EksCreateNodegroupOperator(BaseOperator):
         nodegroup_subnets_list: list[str] = []
         if self.nodegroup_subnets != "":
             try:
-                nodegroup_subnets_list = cast(list, literal_eval(self.nodegroup_subnets))
+                nodegroup_subnets_list = cast("list", literal_eval(self.nodegroup_subnets))
             except ValueError:
                 self.log.warning(
                     "The nodegroup_subnets should be List or string representing "
```
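The three `cast(...)` changes only quote the target type: `typing.cast` accepts a string, type checkers still resolve it, and the subscripted generic is no longer built at runtime (the convention ruff enforces via rule TC006). A self-contained sketch of the equivalence, with illustrative data:

```python
# Illustrative only: cast() returns its value unchanged in both forms; the
# quoted form just skips evaluating list[str] at runtime.
from typing import cast

config = {"subnetIds": ["subnet-0aaa", "subnet-0bbb"]}

subnets_runtime = cast(list[str], config.get("subnetIds"))   # builds list[str] object
subnets_quoted = cast("list[str]", config.get("subnetIds"))  # string target, no evaluation

assert subnets_runtime == subnets_quoted == ["subnet-0aaa", "subnet-0bbb"]
```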
airflow/providers/amazon/aws/operators/mwaa.py

```diff
@@ -19,10 +19,14 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -48,6 +52,23 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
     :param conf: Additional configuration parameters. The value of this field can be set only when creating
         the object. (templated)
     :param note: Contains manually entered notes by the user about the DagRun. (templated)
+
+    :param wait_for_completion: Whether to wait for DAG run to stop. (default: False)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 120)
+    :param waiter_max_attempts: Maximum number of attempts to check for DAG run completion. (default: 720)
+    :param deferrable: If True, the operator will wait asynchronously for the DAG run to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = MwaaHook
@@ -74,6 +95,10 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         data_interval_end: str | None = None,
         conf: dict | None = None,
         note: str | None = None,
+        wait_for_completion: bool = False,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 720,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -85,6 +110,21 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         self.data_interval_end = data_interval_end
         self.conf = conf if conf else {}
         self.note = note
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict:
+        validated_event = validate_execute_complete_event(event)
+        if validated_event["status"] != "success":
+            raise AirflowException(f"DAG run failed: {validated_event}")
+
+        dag_run_id = validated_event["dag_run_id"]
+        self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id)
+        return self.hook.invoke_rest_api(
+            env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", method="GET"
+        )
 
     def execute(self, context: Context) -> dict:
         """
@@ -94,7 +134,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         :return: dict with information about the Dag run
             For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api`
         """
-        return self.hook.invoke_rest_api(
+        response = self.hook.invoke_rest_api(
             env_name=self.env_name,
             path=f"/dags/{self.trigger_dag_id}/dagRuns",
             method="POST",
@@ -107,3 +147,34 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
                 "note": self.note,
             },
         )
+
+        dag_run_id = response["RestApiResponse"]["dag_run_id"]
+        self.log.info("DAG run %s of DAG %s created", dag_run_id, self.trigger_dag_id)
+
+        task_description = f"DAG run {dag_run_id} of DAG {self.trigger_dag_id} to complete"
+        if self.deferrable:
+            self.log.info("Deferring for %s", task_description)
+            self.defer(
+                trigger=MwaaDagRunCompletedTrigger(
+                    external_env_name=self.env_name,
+                    external_dag_id=self.trigger_dag_id,
+                    external_dag_run_id=dag_run_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", task_description)
+            api_kwargs = {
+                "Name": self.env_name,
+                "Path": f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}",
+                "Method": "GET",
+            }
+            self.hook.get_waiter("mwaa_dag_run_complete").wait(
+                **api_kwargs,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return response
```
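A usage sketch of the new waiting modes (the dag/task ids and environment name are illustrative, not taken from this diff):

```python
# Assumes this provider is installed; with deferrable=True the polling is
# handed to MwaaDagRunCompletedTrigger and the worker slot is released, after
# which execute_complete fetches the final run state over the MWAA REST API.
from airflow import DAG
from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator

with DAG(dag_id="mwaa_example", schedule=None) as dag:
    trigger_remote_dag = MwaaTriggerDagRunOperator(
        task_id="trigger_remote_dag",
        env_name="my-mwaa-environment",  # illustrative MWAA environment name
        trigger_dag_id="remote_dag",     # DAG id inside that environment
        wait_for_completion=True,        # block on the new mwaa_dag_run_complete waiter
        waiter_delay=60,
        waiter_max_attempts=720,
        # deferrable=True would wait asynchronously instead of blocking a worker
    )
```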
airflow/providers/amazon/aws/operators/redshift_cluster.py

```diff
@@ -755,11 +755,18 @@ class RedshiftDeleteClusterOperator(BaseOperator):
                     final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
                 )
                 break
-            except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
+            except self.redshift_hook.conn.exceptions.InvalidClusterStateFault:
                 self._attempts -= 1
 
                 if self._attempts:
-                    self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
+                    current_state = self.redshift_hook.conn.describe_clusters(
+                        ClusterIdentifier=self.cluster_identifier
+                    )["Clusters"][0]["ClusterStatus"]
+                    self.log.error(
+                        "Cluster in %s state, unable to delete. %d attempts remaining.",
+                        current_state,
+                        self._attempts,
+                    )
                     time.sleep(self._attempt_interval)
                 else:
                     raise
@@ -785,7 +792,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
             )
 
         elif self.wait_for_completion:
-            waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
+            waiter = self.redshift_hook.conn.get_waiter("cluster_deleted")
             waiter.wait(
                 ClusterIdentifier=self.cluster_identifier,
                 WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
```