apache-airflow-providers-amazon 8.3.1__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. airflow/providers/amazon/__init__.py +4 -2
  2. airflow/providers/amazon/aws/hooks/base_aws.py +29 -12
  3. airflow/providers/amazon/aws/hooks/emr.py +17 -9
  4. airflow/providers/amazon/aws/hooks/eventbridge.py +27 -0
  5. airflow/providers/amazon/aws/hooks/redshift_data.py +10 -0
  6. airflow/providers/amazon/aws/hooks/sagemaker.py +24 -14
  7. airflow/providers/amazon/aws/notifications/chime.py +1 -1
  8. airflow/providers/amazon/aws/operators/eks.py +140 -7
  9. airflow/providers/amazon/aws/operators/emr.py +202 -22
  10. airflow/providers/amazon/aws/operators/eventbridge.py +87 -0
  11. airflow/providers/amazon/aws/operators/rds.py +120 -48
  12. airflow/providers/amazon/aws/operators/redshift_data.py +7 -0
  13. airflow/providers/amazon/aws/operators/sagemaker.py +75 -7
  14. airflow/providers/amazon/aws/operators/step_function.py +34 -2
  15. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
  16. airflow/providers/amazon/aws/triggers/batch.py +1 -1
  17. airflow/providers/amazon/aws/triggers/ecs.py +7 -5
  18. airflow/providers/amazon/aws/triggers/eks.py +174 -3
  19. airflow/providers/amazon/aws/triggers/emr.py +215 -1
  20. airflow/providers/amazon/aws/triggers/rds.py +161 -5
  21. airflow/providers/amazon/aws/triggers/sagemaker.py +84 -1
  22. airflow/providers/amazon/aws/triggers/step_function.py +59 -0
  23. airflow/providers/amazon/aws/utils/__init__.py +16 -1
  24. airflow/providers/amazon/aws/utils/rds.py +2 -2
  25. airflow/providers/amazon/aws/waiters/sagemaker.json +46 -0
  26. airflow/providers/amazon/aws/waiters/stepfunctions.json +36 -0
  27. airflow/providers/amazon/get_provider_info.py +21 -1
  28. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/METADATA +11 -11
  29. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/RECORD +34 -30
  30. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/WHEEL +1 -1
  31. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/LICENSE +0 -0
  32. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/NOTICE +0 -0
  33. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/entry_points.txt +0 -0
  34. {apache_airflow_providers_amazon-8.3.1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/top_level.txt +0 -0
@@ -178,7 +178,7 @@ class S3ToRedshiftOperator(BaseOperator):
178
178
  where_statement = " AND ".join([f"{self.table}.{k} = {copy_destination}.{k}" for k in keys])
179
179
 
180
180
  sql = [
181
- f"CREATE TABLE {copy_destination} (LIKE {destination});",
181
+ f"CREATE TABLE {copy_destination} (LIKE {destination} INCLUDING DEFAULTS);",
182
182
  copy_statement,
183
183
  "BEGIN;",
184
184
  f"DELETE FROM {destination} USING {copy_destination} WHERE {where_statement};",
@@ -209,7 +209,7 @@ class BatchJobTrigger(AwsBaseWaiterTrigger):
209
209
  def __init__(
210
210
  self,
211
211
  job_id: str | None,
212
- region_name: str | None,
212
+ region_name: str | None = None,
213
213
  aws_conn_id: str | None = "aws_default",
214
214
  waiter_delay: int = 5,
215
215
  waiter_max_attempts: int = 720,
@@ -22,6 +22,7 @@ from typing import Any, AsyncIterator
22
22
 
23
23
  from botocore.exceptions import ClientError, WaiterError
24
24
 
25
+ from airflow import AirflowException
25
26
  from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
26
27
  from airflow.providers.amazon.aws.hooks.ecs import EcsHook
27
28
  from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
@@ -48,7 +49,7 @@ class ClusterActiveTrigger(AwsBaseWaiterTrigger):
48
49
  waiter_delay: int,
49
50
  waiter_max_attempts: int,
50
51
  aws_conn_id: str | None,
51
- region_name: str | None,
52
+ region_name: str | None = None,
52
53
  ):
53
54
  super().__init__(
54
55
  serialized_fields={"cluster_arn": cluster_arn},
@@ -87,7 +88,7 @@ class ClusterInactiveTrigger(AwsBaseWaiterTrigger):
87
88
  waiter_delay: int,
88
89
  waiter_max_attempts: int,
89
90
  aws_conn_id: str | None,
90
- region_name: str | None,
91
+ region_name: str | None = None,
91
92
  ):
92
93
  super().__init__(
93
94
  serialized_fields={"cluster_arn": cluster_arn},
@@ -170,7 +171,9 @@ class TaskDoneTrigger(BaseTrigger):
170
171
  await waiter.wait(
171
172
  cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
172
173
  )
173
- break # we reach this point only if the waiter met a success criteria
174
+ # we reach this point only if the waiter met a success criteria
175
+ yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
176
+ return
174
177
  except WaiterError as error:
175
178
  if "terminal failure" in str(error):
176
179
  raise
@@ -179,8 +182,7 @@ class TaskDoneTrigger(BaseTrigger):
179
182
  finally:
180
183
  if self.log_group and self.log_stream:
181
184
  logs_token = await self._forward_logs(logs_client, logs_token)
182
-
183
- yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
185
+ raise AirflowException("Waiter error: max attempts reached")
184
186
 
185
187
  async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None:
186
188
  """
@@ -17,11 +17,178 @@
17
17
  from __future__ import annotations
18
18
 
19
19
  import warnings
20
+ from typing import Any
20
21
 
21
22
  from airflow.exceptions import AirflowProviderDeprecationWarning
22
23
  from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
23
24
  from airflow.providers.amazon.aws.hooks.eks import EksHook
24
25
  from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
26
+ from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
27
+ from airflow.triggers.base import TriggerEvent
28
+
29
+
30
+ class EksCreateClusterTrigger(AwsBaseWaiterTrigger):
31
+ """
32
+ Trigger for EksCreateClusterOperator.
33
+
34
+ The trigger will asynchronously wait for the cluster to be created.
35
+
36
+ :param cluster_name: The name of the EKS cluster
37
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
38
+ :param waiter_max_attempts: The maximum number of attempts to be made.
39
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
40
+ :param region_name: Which AWS region the connection should use.
41
+ If this is None or empty then the default boto3 behaviour is used.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ cluster_name: str,
47
+ waiter_delay: int,
48
+ waiter_max_attempts: int,
49
+ aws_conn_id: str,
50
+ region_name: str | None = None,
51
+ ):
52
+ super().__init__(
53
+ serialized_fields={"cluster_name": cluster_name, "region_name": region_name},
54
+ waiter_name="cluster_active",
55
+ waiter_args={"name": cluster_name},
56
+ failure_message="Error checking Eks cluster",
57
+ status_message="Eks cluster status is",
58
+ status_queries=["cluster.status"],
59
+ return_value=None,
60
+ waiter_delay=waiter_delay,
61
+ waiter_max_attempts=waiter_max_attempts,
62
+ aws_conn_id=aws_conn_id,
63
+ region_name=region_name,
64
+ )
65
+
66
+ def hook(self) -> AwsGenericHook:
67
+ return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
68
+
69
+
70
+ class EksDeleteClusterTrigger(AwsBaseWaiterTrigger):
71
+ """
72
+ Trigger for EksDeleteClusterOperator.
73
+
74
+ The trigger will asynchronously wait for the cluster to be deleted. If there are
75
+ any nodegroups or fargate profiles associated with the cluster, they will be deleted
76
+ before the cluster is deleted.
77
+
78
+ :param cluster_name: The name of the EKS cluster
79
+ :param waiter_delay: The amount of time in seconds to wait between attempts.
80
+ :param waiter_max_attempts: The maximum number of attempts to be made.
81
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
82
+ :param region_name: Which AWS region the connection should use.
83
+ If this is None or empty then the default boto3 behaviour is used.
84
+ :param force_delete_compute: If True, any nodegroups or fargate profiles associated
85
+ with the cluster will be deleted before the cluster is deleted.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ cluster_name,
91
+ waiter_delay: int,
92
+ waiter_max_attempts: int,
93
+ aws_conn_id: str,
94
+ region_name: str | None,
95
+ force_delete_compute: bool,
96
+ ):
97
+ self.cluster_name = cluster_name
98
+ self.waiter_delay = waiter_delay
99
+ self.waiter_max_attempts = waiter_max_attempts
100
+ self.aws_conn_id = aws_conn_id
101
+ self.region_name = region_name
102
+ self.force_delete_compute = force_delete_compute
103
+
104
+ def serialize(self) -> tuple[str, dict[str, Any]]:
105
+ return (
106
+ self.__class__.__module__ + "." + self.__class__.__qualname__,
107
+ {
108
+ "cluster_name": self.cluster_name,
109
+ "waiter_delay": str(self.waiter_delay),
110
+ "waiter_max_attempts": str(self.waiter_max_attempts),
111
+ "aws_conn_id": self.aws_conn_id,
112
+ "region_name": self.region_name,
113
+ "force_delete_compute": self.force_delete_compute,
114
+ },
115
+ )
116
+
117
+ def hook(self) -> AwsGenericHook:
118
+ return EksHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
119
+
120
+ async def run(self):
121
+ async with self.hook.async_conn as client:
122
+ waiter = client.get_waiter("cluster_deleted")
123
+ if self.force_delete_compute:
124
+ await self.delete_any_nodegroups(client=client)
125
+ await self.delete_any_fargate_profiles(client=client)
126
+ await client.delete_cluster(name=self.cluster_name)
127
+ await async_wait(
128
+ waiter=waiter,
129
+ waiter_delay=int(self.waiter_delay),
130
+ waiter_max_attempts=int(self.waiter_max_attempts),
131
+ args={"name": self.cluster_name},
132
+ failure_message="Error deleting cluster",
133
+ status_message="Status of cluster is",
134
+ status_args=["cluster.status"],
135
+ )
136
+
137
+ yield TriggerEvent({"status": "deleted"})
138
+
139
+ async def delete_any_nodegroups(self, client) -> None:
140
+ """
141
+ Deletes all EKS Nodegroups for a provided Amazon EKS Cluster.
142
+
143
+ All the EKS Nodegroups are deleted simultaneously. We wait for
144
+ all Nodegroups to be deleted before returning.
145
+ """
146
+ nodegroups = await client.list_nodegroups(clusterName=self.cluster_name)
147
+ if nodegroups.get("nodegroups", None):
148
+ self.log.info("Deleting nodegroups")
149
+ # ignoring attr-defined here because aws_base hook defines get_waiter for all hooks
150
+ waiter = self.hook.get_waiter( # type: ignore[attr-defined]
151
+ "all_nodegroups_deleted", deferrable=True, client=client
152
+ )
153
+ for group in nodegroups["nodegroups"]:
154
+ await client.delete_nodegroup(clusterName=self.cluster_name, nodegroupName=group)
155
+ await async_wait(
156
+ waiter=waiter,
157
+ waiter_delay=int(self.waiter_delay),
158
+ waiter_max_attempts=int(self.waiter_max_attempts),
159
+ args={"clusterName": self.cluster_name},
160
+ failure_message=f"Error deleting nodegroup for cluster {self.cluster_name}",
161
+ status_message="Deleting nodegroups associated with the cluster",
162
+ status_args=["nodegroups"],
163
+ )
164
+ self.log.info("All nodegroups deleted")
165
+ else:
166
+ self.log.info("No nodegroups associated with cluster %s", self.cluster_name)
167
+
168
+ async def delete_any_fargate_profiles(self, client) -> None:
169
+ """
170
+ Deletes all EKS Fargate profiles for a provided Amazon EKS Cluster.
171
+
172
+ EKS Fargate profiles must be deleted one at a time, so we must wait
173
+ for one to be deleted before sending the next delete command.
174
+ """
175
+ fargate_profiles = await client.list_fargate_profiles(clusterName=self.cluster_name)
176
+ if fargate_profiles.get("fargateProfileNames"):
177
+ self.log.info("Waiting for Fargate profiles to delete. This will take some time.")
178
+ for profile in fargate_profiles["fargateProfileNames"]:
179
+ await client.delete_fargate_profile(clusterName=self.cluster_name, fargateProfileName=profile)
180
+ await async_wait(
181
+ waiter=client.get_waiter("fargate_profile_deleted"),
182
+ waiter_delay=int(self.waiter_delay),
183
+ waiter_max_attempts=int(self.waiter_max_attempts),
184
+ args={"clusterName": self.cluster_name, "fargateProfileName": profile},
185
+ failure_message=f"Error deleting fargate profile for cluster {self.cluster_name}",
186
+ status_message="Status of fargate profile is",
187
+ status_args=["fargateProfile.status"],
188
+ )
189
+ self.log.info("All Fargate profiles deleted")
190
+ else:
191
+ self.log.info(f"No Fargate profiles associated with cluster {self.cluster_name}")
25
192
 
26
193
 
27
194
  class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger):
@@ -142,10 +309,14 @@ class EksCreateNodegroupTrigger(AwsBaseWaiterTrigger):
142
309
  waiter_delay: int,
143
310
  waiter_max_attempts: int,
144
311
  aws_conn_id: str,
145
- region_name: str | None,
312
+ region_name: str | None = None,
146
313
  ):
147
314
  super().__init__(
148
- serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name},
315
+ serialized_fields={
316
+ "cluster_name": cluster_name,
317
+ "nodegroup_name": nodegroup_name,
318
+ "region_name": region_name,
319
+ },
149
320
  waiter_name="nodegroup_active",
150
321
  waiter_args={"clusterName": cluster_name, "nodegroupName": nodegroup_name},
151
322
  failure_message="Error creating nodegroup",
@@ -186,7 +357,7 @@ class EksDeleteNodegroupTrigger(AwsBaseWaiterTrigger):
186
357
  waiter_delay: int,
187
358
  waiter_max_attempts: int,
188
359
  aws_conn_id: str,
189
- region_name: str | None,
360
+ region_name: str | None = None,
190
361
  ):
191
362
  super().__init__(
192
363
  serialized_fields={"cluster_name": cluster_name, "nodegroup_name": nodegroup_name},
@@ -24,7 +24,7 @@ from botocore.exceptions import WaiterError
24
24
 
25
25
  from airflow.exceptions import AirflowProviderDeprecationWarning
26
26
  from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
27
- from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
27
+ from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
28
28
  from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
29
29
  from airflow.triggers.base import BaseTrigger, TriggerEvent
30
30
 
@@ -283,3 +283,217 @@ class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
283
283
 
284
284
  def hook(self) -> AwsGenericHook:
285
285
  return EmrHook(self.aws_conn_id)
286
+
287
+
288
+ class EmrServerlessCreateApplicationTrigger(AwsBaseWaiterTrigger):
289
+ """
290
+ Poll an Emr Serverless application and wait for it to be created.
291
+
292
+ :param application_id: The ID of the application being polled.
293
+ :waiter_delay: polling period in seconds to check for the status
294
+ :param waiter_max_attempts: The maximum number of attempts to be made
295
+ :param aws_conn_id: Reference to AWS connection id
296
+ """
297
+
298
+ def __init__(
299
+ self,
300
+ application_id: str,
301
+ waiter_delay: int = 30,
302
+ waiter_max_attempts: int = 60,
303
+ aws_conn_id: str = "aws_default",
304
+ ) -> None:
305
+ super().__init__(
306
+ serialized_fields={"application_id": application_id},
307
+ waiter_name="serverless_app_created",
308
+ waiter_args={"applicationId": application_id},
309
+ failure_message="Application creation failed",
310
+ status_message="Application status is",
311
+ status_queries=["application.state", "application.stateDetails"],
312
+ return_key="application_id",
313
+ return_value=application_id,
314
+ waiter_delay=waiter_delay,
315
+ waiter_max_attempts=waiter_max_attempts,
316
+ aws_conn_id=aws_conn_id,
317
+ )
318
+
319
+ def hook(self) -> AwsGenericHook:
320
+ return EmrServerlessHook(self.aws_conn_id)
321
+
322
+
323
+ class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
324
+ """
325
+ Poll an Emr Serverless application and wait for it to be started.
326
+
327
+ :param application_id: The ID of the application being polled.
328
+ :waiter_delay: polling period in seconds to check for the status
329
+ :param waiter_max_attempts: The maximum number of attempts to be made
330
+ :param aws_conn_id: Reference to AWS connection id
331
+ """
332
+
333
+ def __init__(
334
+ self,
335
+ application_id: str,
336
+ waiter_delay: int = 30,
337
+ waiter_max_attempts: int = 60,
338
+ aws_conn_id: str = "aws_default",
339
+ ) -> None:
340
+ super().__init__(
341
+ serialized_fields={"application_id": application_id},
342
+ waiter_name="serverless_app_started",
343
+ waiter_args={"applicationId": application_id},
344
+ failure_message="Application failed to start",
345
+ status_message="Application status is",
346
+ status_queries=["application.state", "application.stateDetails"],
347
+ return_key="application_id",
348
+ return_value=application_id,
349
+ waiter_delay=waiter_delay,
350
+ waiter_max_attempts=waiter_max_attempts,
351
+ aws_conn_id=aws_conn_id,
352
+ )
353
+
354
+ def hook(self) -> AwsGenericHook:
355
+ return EmrServerlessHook(self.aws_conn_id)
356
+
357
+
358
+ class EmrServerlessStopApplicationTrigger(AwsBaseWaiterTrigger):
359
+ """
360
+ Poll an Emr Serverless application and wait for it to be stopped.
361
+
362
+ :param application_id: The ID of the application being polled.
363
+ :waiter_delay: polling period in seconds to check for the status
364
+ :param waiter_max_attempts: The maximum number of attempts to be made
365
+ :param aws_conn_id: Reference to AWS connection id.
366
+ """
367
+
368
+ def __init__(
369
+ self,
370
+ application_id: str,
371
+ waiter_delay: int = 30,
372
+ waiter_max_attempts: int = 60,
373
+ aws_conn_id: str = "aws_default",
374
+ ) -> None:
375
+ super().__init__(
376
+ serialized_fields={"application_id": application_id},
377
+ waiter_name="serverless_app_stopped",
378
+ waiter_args={"applicationId": application_id},
379
+ failure_message="Application failed to start",
380
+ status_message="Application status is",
381
+ status_queries=["application.state", "application.stateDetails"],
382
+ return_key="application_id",
383
+ return_value=application_id,
384
+ waiter_delay=waiter_delay,
385
+ waiter_max_attempts=waiter_max_attempts,
386
+ aws_conn_id=aws_conn_id,
387
+ )
388
+
389
+ def hook(self) -> AwsGenericHook:
390
+ return EmrServerlessHook(self.aws_conn_id)
391
+
392
+
393
+ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
394
+ """
395
+ Poll an Emr Serverless job run and wait for it to be completed.
396
+
397
+ :param application_id: The ID of the application the job in being run on.
398
+ :param job_id: The ID of the job run.
399
+ :waiter_delay: polling period in seconds to check for the status
400
+ :param waiter_max_attempts: The maximum number of attempts to be made
401
+ :param aws_conn_id: Reference to AWS connection id
402
+ """
403
+
404
+ def __init__(
405
+ self,
406
+ application_id: str,
407
+ job_id: str | None,
408
+ waiter_delay: int = 30,
409
+ waiter_max_attempts: int = 60,
410
+ aws_conn_id: str = "aws_default",
411
+ ) -> None:
412
+ super().__init__(
413
+ serialized_fields={"application_id": application_id, "job_id": job_id},
414
+ waiter_name="serverless_job_completed",
415
+ waiter_args={"applicationId": application_id, "jobRunId": job_id},
416
+ failure_message="Serverless Job failed",
417
+ status_message="Serverless Job status is",
418
+ status_queries=["jobRun.state", "jobRun.stateDetails"],
419
+ return_key="job_id",
420
+ return_value=job_id,
421
+ waiter_delay=waiter_delay,
422
+ waiter_max_attempts=waiter_max_attempts,
423
+ aws_conn_id=aws_conn_id,
424
+ )
425
+
426
+ def hook(self) -> AwsGenericHook:
427
+ return EmrServerlessHook(self.aws_conn_id)
428
+
429
+
430
+ class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
431
+ """
432
+ Poll an Emr Serverless application and wait for it to be deleted.
433
+
434
+ :param application_id: The ID of the application being polled.
435
+ :waiter_delay: polling period in seconds to check for the status
436
+ :param waiter_max_attempts: The maximum number of attempts to be made
437
+ :param aws_conn_id: Reference to AWS connection id
438
+ """
439
+
440
+ def __init__(
441
+ self,
442
+ application_id: str,
443
+ waiter_delay: int = 30,
444
+ waiter_max_attempts: int = 60,
445
+ aws_conn_id: str = "aws_default",
446
+ ) -> None:
447
+ super().__init__(
448
+ serialized_fields={"application_id": application_id},
449
+ waiter_name="serverless_app_terminated",
450
+ waiter_args={"applicationId": application_id},
451
+ failure_message="Application failed to start",
452
+ status_message="Application status is",
453
+ status_queries=["application.state", "application.stateDetails"],
454
+ return_key="application_id",
455
+ return_value=application_id,
456
+ waiter_delay=waiter_delay,
457
+ waiter_max_attempts=waiter_max_attempts,
458
+ aws_conn_id=aws_conn_id,
459
+ )
460
+
461
+ def hook(self) -> AwsGenericHook:
462
+ return EmrServerlessHook(self.aws_conn_id)
463
+
464
+
465
+ class EmrServerlessCancelJobsTrigger(AwsBaseWaiterTrigger):
466
+ """
467
+ Trigger for canceling a list of jobs in an EMR Serverless application.
468
+
469
+ :param application_id: EMR Serverless application ID
470
+ :param aws_conn_id: Reference to AWS connection id
471
+ :param waiter_delay: Delay in seconds between each attempt to check the status
472
+ :param waiter_max_attempts: Maximum number of attempts to check the status
473
+ """
474
+
475
+ def __init__(
476
+ self,
477
+ application_id: str,
478
+ aws_conn_id: str,
479
+ waiter_delay: int,
480
+ waiter_max_attempts: int,
481
+ ) -> None:
482
+ self.hook_instance = EmrServerlessHook(aws_conn_id)
483
+ states = list(self.hook_instance.JOB_INTERMEDIATE_STATES.union({"CANCELLING"}))
484
+ super().__init__(
485
+ serialized_fields={"application_id": application_id},
486
+ waiter_name="no_job_running",
487
+ waiter_args={"applicationId": application_id, "states": states},
488
+ failure_message="Error while waiting for jobs to cancel",
489
+ status_message="Currently running jobs",
490
+ status_queries=["jobRuns[*].applicationId", "jobRuns[*].state"],
491
+ return_key="application_id",
492
+ return_value=application_id,
493
+ waiter_delay=waiter_delay,
494
+ waiter_max_attempts=waiter_max_attempts,
495
+ aws_conn_id=aws_conn_id,
496
+ )
497
+
498
+ def hook(self) -> AwsGenericHook:
499
+ return self.hook_instance