apache-airflow-providers-amazon 8.3.1rc1__py3-none-any.whl → 8.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +4 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +29 -12
- airflow/providers/amazon/aws/hooks/emr.py +17 -9
- airflow/providers/amazon/aws/hooks/eventbridge.py +27 -0
- airflow/providers/amazon/aws/hooks/redshift_data.py +10 -0
- airflow/providers/amazon/aws/hooks/sagemaker.py +24 -14
- airflow/providers/amazon/aws/notifications/chime.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +140 -7
- airflow/providers/amazon/aws/operators/emr.py +202 -22
- airflow/providers/amazon/aws/operators/eventbridge.py +87 -0
- airflow/providers/amazon/aws/operators/rds.py +120 -48
- airflow/providers/amazon/aws/operators/redshift_data.py +7 -0
- airflow/providers/amazon/aws/operators/sagemaker.py +75 -7
- airflow/providers/amazon/aws/operators/step_function.py +34 -2
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
- airflow/providers/amazon/aws/triggers/batch.py +1 -1
- airflow/providers/amazon/aws/triggers/ecs.py +7 -5
- airflow/providers/amazon/aws/triggers/eks.py +174 -3
- airflow/providers/amazon/aws/triggers/emr.py +215 -1
- airflow/providers/amazon/aws/triggers/rds.py +161 -5
- airflow/providers/amazon/aws/triggers/sagemaker.py +84 -1
- airflow/providers/amazon/aws/triggers/step_function.py +59 -0
- airflow/providers/amazon/aws/utils/__init__.py +16 -1
- airflow/providers/amazon/aws/utils/rds.py +2 -2
- airflow/providers/amazon/aws/waiters/sagemaker.json +46 -0
- airflow/providers/amazon/aws/waiters/stepfunctions.json +36 -0
- airflow/providers/amazon/get_provider_info.py +21 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/METADATA +13 -13
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/RECORD +34 -30
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/top_level.txt +0 -0
@@ -178,7 +178,7 @@ class S3ToRedshiftOperator(BaseOperator):
|
|
178
178
|
where_statement = " AND ".join([f"{self.table}.{k} = {copy_destination}.{k}" for k in keys])
|
179
179
|
|
180
180
|
sql = [
|
181
|
-
f"CREATE TABLE {copy_destination} (LIKE {destination});",
|
181
|
+
f"CREATE TABLE {copy_destination} (LIKE {destination} INCLUDING DEFAULTS);",
|
182
182
|
copy_statement,
|
183
183
|
"BEGIN;",
|
184
184
|
f"DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};",
|
@@ -209,7 +209,7 @@ class BatchJobTrigger(AwsBaseWaiterTrigger):
|
|
209
209
|
def __init__(
|
210
210
|
self,
|
211
211
|
job_id: str | None,
|
212
|
-
region_name: str | None,
|
212
|
+
region_name: str | None = None,
|
213
213
|
aws_conn_id: str | None = "aws_default",
|
214
214
|
waiter_delay: int = 5,
|
215
215
|
waiter_max_attempts: int = 720,
|
@@ -22,6 +22,7 @@ from typing import Any, AsyncIterator
|
|
22
22
|
|
23
23
|
from botocore.exceptions import ClientError, WaiterError
|
24
24
|
|
25
|
+
from airflow import AirflowException
|
25
26
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
26
27
|
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
|
27
28
|
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
|
@@ -48,7 +49,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
|
|
48
49
|
waiter_delay: int,
|
49
50
|
waiter_max_attempts: int,
|
50
51
|
aws_conn_id: str | None,
|
51
|
-
region_name: str | None,
|
52
|
+
region_name: str | None = None,
|
52
53
|
):
|
53
54
|
super().__init__(
|
54
55
|
serialized_fields={"cluster_arn": cluster_arn},
|
@@ -87,7 +88,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
|
|
87
88
|
waiter_delay: int,
|
88
89
|
waiter_max_attempts: int,
|
89
90
|
aws_conn_id: str | None,
|
90
|
-
region_name: str | None,
|
91
|
+
region_name: str | None = None,
|
91
92
|
):
|
92
93
|
super().__init__(
|
93
94
|
serialized_fields={"cluster_arn": cluster_arn},
|
@@ -170,7 +171,9 @@ class TaskDoneTrigger(BaseTrigger):
|
|
170
171
|
await waiter.wait(
|
171
172
|
cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
|
172
173
|
)
|
173
|
-
|
174
|
+
# we reach this point only if the waiter met a success criteria
|
175
|
+
yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
|
176
|
+
return
|
174
177
|
except WaiterError as error:
|
175
178
|
if "terminal failure" in str(error):
|
176
179
|
raise
|
@@ -179,8 +182,7 @@ class TaskDoneTrigger(BaseTrigger):
|
|
179
182
|
finally:
|
180
183
|
if self.log_group and self.log_stream:
|
181
184
|
logs_token = await self._forward_logs(logs_client, logs_token)
|
182
|
-
|
183
|
-
yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
|
185
|
+
raise AirflowException("Waiter error: max attempts reached")
|
184
186
|
|
185
187
|
async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None:
|
186
188
|
"""
|
@@ -17,11 +17,178 @@
|
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
19
|
import warnings
|
20
|
+
from typing import Any
|
20
21
|
|
21
22
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
22
23
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
23
24
|
from airflow.providers.amazon.aws.hooks.eks import EksHook
|
24
25
|
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
|
26
|
+
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
|
27
|
+
from airflow.triggers.base import TriggerEvent
|
28
|
+
|
29
|
+
|
30
|
+
class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
    """
    Defer until a newly requested EKS cluster becomes active.

    Used by EksCreateClusterOperator. Polling is delegated to the ``cluster_active``
    waiter, with ``cluster.status`` configured as the status query.

    :param cluster_name: Name of the EKS cluster to watch.
    :param waiter_delay: Seconds to sleep between two status checks.
    :param waiter_max_attempts: Upper bound on the number of status checks.
    :param aws_conn_id: Airflow connection holding the AWS credentials.
    :param region_name: AWS region to connect to.
        When None or empty, the default boto3 behaviour applies.
    """

    def __init__(
        self,
        cluster_name: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str,
        region_name: str | None = None,
    ):
        super().__init__(
            aws_conn_id=aws_conn_id,
            region_name=region_name,
            waiter_name="cluster_active",
            waiter_args={"name": cluster_name},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            status_queries=["cluster.status"],
            status_message="Eks cluster status is",
            failure_message="Error checking Eks cluster",
            return_value=None,
            serialized_fields={"cluster_name": cluster_name, "region_name": region_name},
        )

    def hook(self) -> AwsGenericHook:
        return EksHook(region_name=self.region_name, aws_conn_id=self.aws_conn_id)
|
68
|
+
|
69
|
+
|
70
|
+
class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
    """
    Trigger for EksDeleteClusterOperator.

    The trigger issues the cluster deletion and then asynchronously waits for the
    cluster to be deleted. If ``force_delete_compute`` is True, any nodegroups and
    Fargate profiles associated with the cluster are deleted first.

    :param cluster_name: The name of the EKS cluster
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: Which AWS region the connection should use.
        If this is None or empty then the default boto3 behaviour is used.
    :param force_delete_compute: If True, any nodegroups or fargate profiles associated
        with the cluster will be deleted before the cluster is deleted.
    """

    def __init__(
        self,
        cluster_name: str,
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str,
        region_name: str | None,
        force_delete_compute: bool,
    ):
        # This trigger provides its own serialize() and run(), so it keeps its own
        # state rather than calling AwsBaseWaiterTrigger.__init__.
        self.cluster_name = cluster_name
        self.waiter_delay = waiter_delay
        self.waiter_max_attempts = waiter_max_attempts
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name
        self.force_delete_compute = force_delete_compute

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the import path of this class and the kwargs needed to rebuild it."""
        return (
            self.__class__.__module__ + "." + self.__class__.__qualname__,
            {
                "cluster_name": self.cluster_name,
                # Delays are serialized as strings; run() converts them back with int().
                "waiter_delay": str(self.waiter_delay),
                "waiter_max_attempts": str(self.waiter_max_attempts),
                "aws_conn_id": self.aws_conn_id,
                "region_name": self.region_name,
                "force_delete_compute": self.force_delete_compute,
            },
        )

    def hook(self) -> AwsGenericHook:
        """Build an EksHook for this trigger's connection and region."""
        return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

    async def run(self):
        """
        Optionally delete attached compute, delete the cluster, then wait for deletion.

        Yields a single ``TriggerEvent({"status": "deleted"})`` once the
        ``cluster_deleted`` waiter completes.
        """
        # ``hook`` is a plain method, not a property: it must be called. Accessing
        # ``self.hook.async_conn`` would look the attribute up on the bound method.
        async with self.hook().async_conn as client:
            waiter = client.get_waiter("cluster_deleted")
            if self.force_delete_compute:
                await self.delete_any_nodegroups(client=client)
                await self.delete_any_fargate_profiles(client=client)
            await client.delete_cluster(name=self.cluster_name)
            await async_wait(
                waiter=waiter,
                waiter_delay=int(self.waiter_delay),
                waiter_max_attempts=int(self.waiter_max_attempts),
                args={"name": self.cluster_name},
                failure_message="Error deleting cluster",
                status_message="Status of cluster is",
                status_args=["cluster.status"],
            )

        yield TriggerEvent({"status": "deleted"})

    async def delete_any_nodegroups(self, client) -> None:
        """
        Deletes all EKS Nodegroups for a provided Amazon EKS Cluster.

        All the EKS Nodegroups are deleted simultaneously. We wait for
        all Nodegroups to be deleted before returning.

        :param client: Async EKS client used for the list/delete calls and the waiter.
        """
        # NOTE(review): only the first page of list_nodegroups is processed here;
        # confirm clusters never have more nodegroups than one page returns.
        nodegroups = await client.list_nodegroups(clusterName=self.cluster_name)
        if nodegroups.get("nodegroups", None):
            self.log.info("Deleting nodegroups")
            # ignoring attr-defined here because aws_base hook defines get_waiter for all hooks
            waiter = self.hook().get_waiter(  # type: ignore[attr-defined]
                "all_nodegroups_deleted", deferrable=True, client=client
            )
            for group in nodegroups["nodegroups"]:
                await client.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=group)
            await async_wait(
                waiter=waiter,
                waiter_delay=int(self.waiter_delay),
                waiter_max_attempts=int(self.waiter_max_attempts),
                args={"clusterName": self.cluster_name},
                failure_message=f"Error deleting nodegroup for cluster {self.cluster_name}",
                status_message="Deleting nodegroups associated with the cluster",
                status_args=["nodegroups"],
            )
            self.log.info("All nodegroups deleted")
        else:
            self.log.info("No nodegroups associated with cluster %s", self.cluster_name)

    async def delete_any_fargate_profiles(self, client) -> None:
        """
        Deletes all EKS Fargate profiles for a provided Amazon EKS Cluster.

        EKS Fargate profiles must be deleted one at a time, so we must wait
        for one to be deleted before sending the next delete command.

        :param client: Async EKS client used for the list/delete calls and the waiter.
        """
        fargate_profiles = await client.list_fargate_profiles(clusterName=self.cluster_name)
        if fargate_profiles.get("fargateProfileNames"):
            self.log.info("Waiting for Fargate profiles to delete. This will take some time.")
            # The waiter does not depend on the profile, so fetch it once.
            profile_waiter = client.get_waiter("fargate_profile_deleted")
            for profile in fargate_profiles["fargateProfileNames"]:
                await client.delete_fargate_profile(clusterName=self.cluster_name, fargateProfileName=profile)
                await async_wait(
                    waiter=profile_waiter,
                    waiter_delay=int(self.waiter_delay),
                    waiter_max_attempts=int(self.waiter_max_attempts),
                    args={"clusterName": self.cluster_name, "fargateProfileName": profile},
                    failure_message=f"Error deleting fargate profile for cluster {self.cluster_name}",
                    status_message="Status of fargate profile is",
                    status_args=["fargateProfile.status"],
                )
            self.log.info("All Fargate profiles deleted")
        else:
            self.log.info("No Fargate profiles associated with cluster %s", self.cluster_name)
|
25
192
|
|
26
193
|
|
27
194
|
class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger):
|
@@ -142,10 +309,14 @@ class EksCreateNodegroupTrigger(AwsBaseWaiterTrigger):
|
|
142
309
|
waiter_delay: int,
|
143
310
|
waiter_max_attempts: int,
|
144
311
|
aws_conn_id: str,
|
145
|
-
region_name: str | None,
|
312
|
+
region_name: str | None = None,
|
146
313
|
):
|
147
314
|
super().__init__(
|
148
|
-
serialized_fields={
|
315
|
+
serialized_fields={
|
316
|
+
"cluster_name": cluster_name,
|
317
|
+
"nodegroup_name": nodegroup_name,
|
318
|
+
"region_name": region_name,
|
319
|
+
},
|
149
320
|
waiter_name="nodegroup_active",
|
150
321
|
waiter_args={"clusterName": cluster_name, "nodegroupName": nodegroup_name},
|
151
322
|
failure_message="Error creating nodegroup",
|
@@ -186,7 +357,7 @@ class EksDeleteNodegroupTrigger(AwsBaseWaiterTrigger):
|
|
186
357
|
waiter_delay: int,
|
187
358
|
waiter_max_attempts: int,
|
188
359
|
aws_conn_id: str,
|
189
|
-
region_name: str | None,
|
360
|
+
region_name: str | None = None,
|
190
361
|
):
|
191
362
|
super().__init__(
|
192
363
|
serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name},
|
@@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError
|
|
24
24
|
|
25
25
|
from airflow.exceptions import AirflowProviderDeprecationWarning
|
26
26
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
|
27
|
-
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
|
27
|
+
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
|
28
28
|
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
|
29
29
|
from airflow.triggers.base import BaseTrigger, TriggerEvent
|
30
30
|
|
@@ -283,3 +283,217 @@ class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
|
|
283
283
|
|
284
284
|
def hook(self) -> AwsGenericHook:
    """Return an EmrHook bound to this trigger's AWS connection."""
    # NOTE(review): pass the connection id by keyword. EmrHook's first positional
    # parameter is believed to be ``emr_conn_id`` (its signature is not shown here),
    # in which case a positional call would leave ``aws_conn_id`` at its default.
    return EmrHook(aws_conn_id=self.aws_conn_id)
|
286
|
+
|
287
|
+
|
288
|
+
class EmrServerlessCreateApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be created.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_created",
            waiter_args={"applicationId": application_id},
            failure_message="Application creation failed",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        # Keyword form makes explicit which constructor parameter gets the connection id.
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
|
321
|
+
|
322
|
+
|
323
|
+
class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be started.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_started",
            waiter_args={"applicationId": application_id},
            failure_message="Application failed to start",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        # Keyword form makes explicit which constructor parameter gets the connection id.
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
|
356
|
+
|
357
|
+
|
358
|
+
class EmrServerlessStopApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be stopped.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_stopped",
            waiter_args={"applicationId": application_id},
            # Previously copy-pasted as "failed to start", which misreports a stop failure.
            failure_message="Application failed to stop",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        # Keyword form makes explicit which constructor parameter gets the connection id.
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
|
391
|
+
|
392
|
+
|
393
|
+
class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless job run and wait for it to be completed.

    :param application_id: The ID of the application the job is being run on.
    :param job_id: The ID of the job run.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        job_id: str | None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id, "job_id": job_id},
            waiter_name="serverless_job_completed",
            waiter_args={"applicationId": application_id, "jobRunId": job_id},
            failure_message="Serverless Job failed",
            status_message="Serverless Job status is",
            status_queries=["jobRun.state", "jobRun.stateDetails"],
            return_key="job_id",
            return_value=job_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        # Keyword form makes explicit which constructor parameter gets the connection id.
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
|
428
|
+
|
429
|
+
|
430
|
+
class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless application and wait for it to be deleted.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_terminated",
            waiter_args={"applicationId": application_id},
            # Previously copy-pasted as "failed to start", which misreports a deletion failure.
            failure_message="Application deletion failed",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        # Keyword form makes explicit which constructor parameter gets the connection id.
        return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
|
463
|
+
|
464
|
+
|
465
|
+
class EmrServerlessCancelJobsTrigger(AwsBaseWaiterTrigger):
    """
    Wait for the jobs of an EMR Serverless application to finish cancelling.

    The ``no_job_running`` waiter receives the hook's intermediate job states
    plus ``CANCELLING`` as the states to watch.

    :param application_id: EMR Serverless application ID
    :param aws_conn_id: Reference to AWS connection id
    :param waiter_delay: Delay in seconds between each attempt to check the status
    :param waiter_max_attempts: Maximum number of attempts to check the status
    """

    def __init__(
        self,
        application_id: str,
        aws_conn_id: str,
        waiter_delay: int,
        waiter_max_attempts: int,
    ) -> None:
        self.hook_instance = EmrServerlessHook(aws_conn_id)
        watched = self.hook_instance.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})
        waiter_args = {"applicationId": application_id, "states": list(watched)}
        super().__init__(
            aws_conn_id=aws_conn_id,
            waiter_name="no_job_running",
            waiter_args=waiter_args,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            status_queries=["jobRuns[*].applicationId", "jobRuns[*].state"],
            status_message="Currently running jobs",
            failure_message="Error while waiting for jobs to cancel",
            return_key="application_id",
            return_value=application_id,
            serialized_fields={"application_id": application_id},
        )

    def hook(self) -> AwsGenericHook:
        # Reuse the hook built in __init__ rather than constructing a new one.
        return self.hook_instance
|