apache-airflow-providers-amazon 8.18.0rc1__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
Files changed (35)
  1. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +2 -2
  2. airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
  3. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
  4. airflow/providers/amazon/aws/hooks/base_aws.py +2 -2
  5. airflow/providers/amazon/aws/hooks/emr.py +6 -0
  6. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -1
  7. airflow/providers/amazon/aws/links/emr.py +122 -2
  8. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  9. airflow/providers/amazon/aws/operators/athena.py +4 -1
  10. airflow/providers/amazon/aws/operators/batch.py +5 -6
  11. airflow/providers/amazon/aws/operators/ecs.py +6 -2
  12. airflow/providers/amazon/aws/operators/eks.py +23 -20
  13. airflow/providers/amazon/aws/operators/emr.py +192 -26
  14. airflow/providers/amazon/aws/operators/glue.py +5 -2
  15. airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
  16. airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
  17. airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
  18. airflow/providers/amazon/aws/operators/rds.py +21 -12
  19. airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
  20. airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
  21. airflow/providers/amazon/aws/operators/sagemaker.py +24 -20
  22. airflow/providers/amazon/aws/operators/step_function.py +4 -1
  23. airflow/providers/amazon/aws/sensors/ec2.py +4 -2
  24. airflow/providers/amazon/aws/sensors/emr.py +13 -6
  25. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
  26. airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
  27. airflow/providers/amazon/aws/sensors/s3.py +3 -0
  28. airflow/providers/amazon/aws/sensors/sqs.py +4 -1
  29. airflow/providers/amazon/aws/utils/__init__.py +10 -0
  30. airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
  31. airflow/providers/amazon/get_provider_info.py +4 -0
  32. {apache_airflow_providers_amazon-8.18.0rc1.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +2 -2
  33. {apache_airflow_providers_amazon-8.18.0rc1.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +35 -34
  34. {apache_airflow_providers_amazon-8.18.0rc1.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
  35. {apache_airflow_providers_amazon-8.18.0rc1.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py

@@ -27,8 +27,17 @@ from uuid import uuid4
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
+from airflow.models.mappedoperator import MappedOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
-from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
+from airflow.providers.amazon.aws.links.emr import (
+    EmrClusterLink,
+    EmrLogsLink,
+    EmrServerlessCloudWatchLogsLink,
+    EmrServerlessDashboardLink,
+    EmrServerlessLogsLink,
+    EmrServerlessS3LogsLink,
+    get_log_uri,
+)
 from airflow.providers.amazon.aws.triggers.emr import (
     EmrAddStepsTrigger,
     EmrContainerTrigger,
@@ -41,6 +50,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrServerlessStopApplicationTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.waiter import waiter
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 from airflow.utils.helpers import exactly_one, prune_dict
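The newly imported EmrServerless*Link classes are defined in airflow/providers/amazon/aws/links/emr.py (+122 -2 in the file list above), which this excerpt does not show. For orientation only: AWS console links in this provider are usually BaseAwsLink subclasses that declare a display name, an XCom key, and a URL template; persist() stores the template arguments in XCom and format_link() renders the final URL. The class below is a hypothetical stand-in, not the actual definition; its name and template are placeholders.

from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink


class EmrServerlessS3LogsLinkSketch(BaseAwsLink):
    """Hypothetical stand-in for the new EmrServerlessS3LogsLink (illustrative only)."""

    name = "S3 Logs"
    key = "emr_serverless_s3_logs"
    # Placeholder template: the real class builds the S3 console URL from the job's
    # logUri plus the application and job run IDs (see persist_links later in this diff).
    format_str = (
        "https://console.{aws_domain}/s3/buckets/"
        "?prefix={log_uri}applications/{application_id}/jobs/{job_run_id}/"
    )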
@@ -180,11 +190,13 @@ class EmrAddStepsOperator(BaseOperator):
 
         return step_ids
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running steps: {event}")
-        else:
-            self.log.info("Steps completed successfully")
+
+        self.log.info("Steps completed successfully")
         return event["value"]
 
 
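The hunk above shows the refactor repeated throughout this release: the inline `event is None` guards in every deferrable operator's execute_complete are replaced by a call to validate_execute_complete_event, a helper added to airflow/providers/amazon/aws/utils/__init__.py (+10 -0 in the file list) whose body is not part of this excerpt. A minimal sketch of what the call sites imply it does, assuming its only job is to reject a missing trigger event:

from __future__ import annotations

from typing import Any

from airflow.exceptions import AirflowException


def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
    """Sketch: fail fast if a deferrable operator resumed without a trigger event."""
    if event is None:
        raise AirflowException("Trigger error: event is None")
    return event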
@@ -494,6 +506,8 @@ class EmrContainerOperator(BaseOperator):
     :param max_tries: Deprecated - use max_polling_attempts instead.
     :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
         Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
+    :param job_retry_max_attempts: Maximum number of times to retry when the EMR job fails.
+        Defaults to None, which disable the retry.
     :param tags: The tags assigned to job runs.
         Defaults to None
     :param deferrable: Run operator in the deferrable mode.
@@ -525,6 +539,7 @@ class EmrContainerOperator(BaseOperator):
         max_tries: int | None = None,
         tags: dict | None = None,
         max_polling_attempts: int | None = None,
+        job_retry_max_attempts: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ) -> None:
@@ -540,6 +555,7 @@ class EmrContainerOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.poll_interval = poll_interval
         self.max_polling_attempts = max_polling_attempts
+        self.job_retry_max_attempts = job_retry_max_attempts
         self.tags = tags
         self.job_id: str | None = None
         self.deferrable = deferrable
@@ -574,6 +590,7 @@ class EmrContainerOperator(BaseOperator):
             self.configuration_overrides,
             self.client_request_token,
             self.tags,
+            self.job_retry_max_attempts,
         )
         if self.deferrable:
             query_status = self.hook.check_query_status(job_id=self.job_id)
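Taken together, the four hunks above thread the new job_retry_max_attempts argument from the operator docstring through __init__ and into hook.submit_job. A hedged usage sketch; the cluster ID, role ARN, and job driver below are placeholders, and only job_retry_max_attempts is the new knob:

from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator

run_emr_eks_job = EmrContainerOperator(
    task_id="run_emr_eks_job",
    name="sample-spark-job",
    virtual_cluster_id="{{ var.value.emr_virtual_cluster_id }}",
    execution_role_arn="{{ var.value.emr_eks_execution_role_arn }}",
    release_label="emr-6.10.0-latest",
    job_driver={"sparkSubmitJobDriver": {"entryPoint": "s3://my-bucket/app.py"}},
    # New in this version: ask EMR on EKS to retry the job run itself up to 3 times
    # before the task fails; None (the default) keeps the previous single-attempt behaviour.
    job_retry_max_attempts=3,
)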
@@ -619,7 +636,9 @@ class EmrContainerOperator(BaseOperator):
                 f"query_execution_id is {self.job_id}. Error: {error_message}"
             )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while running job: {event}")
 
@@ -806,11 +825,13 @@ class EmrCreateJobFlowOperator(BaseOperator):
            )
        return self._job_flow_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating jobFlow: {event}")
-        else:
-            self.log.info("JobFlow created successfully")
+
+        self.log.info("JobFlow created successfully")
         return event["job_flow_id"]
 
     def on_kill(self) -> None:
@@ -969,12 +990,13 @@ class EmrTerminateJobFlowOperator(BaseOperator):
                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
            )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error terminating JobFlow: {event}")
-        else:
-            self.log.info("Jobflow terminated successfully.")
-            return
+
+        self.log.info("Jobflow terminated successfully.")
 
 
 class EmrServerlessCreateApplicationOperator(BaseOperator):
@@ -1135,7 +1157,9 @@ class EmrServerlessCreateApplicationOperator(BaseOperator):
            )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Trigger error: Application failed to start, event is {event}")
 
         self.log.info("Application %s started", event["application_id"])
@@ -1172,6 +1196,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the crawl to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param enable_application_ui_links: If True, the operator will generate one-time links to EMR Serverless
+        application UIs. The generated links will allow any user with access to the DAG to see the Spark or
+        Tez UI or Spark stdout logs. Defaults to False.
     """
 
     template_fields: Sequence[str] = (
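A hedged usage sketch of the new enable_application_ui_links flag; the application ID, role ARN, and S3 bucket below are placeholders. The dashboard and Spark stdout links come from the flag itself, while the S3 and CloudWatch console links depend on configuration_overrides (see is_monitoring_in_job_override and persist_links further down in this diff).

from airflow.providers.amazon.aws.operators.emr import EmrServerlessStartJobOperator

start_serverless_job = EmrServerlessStartJobOperator(
    task_id="start_serverless_job",
    application_id="{{ var.value.emr_serverless_app_id }}",
    execution_role_arn="{{ var.value.emr_serverless_role_arn }}",
    job_driver={"sparkSubmit": {"entryPoint": "s3://my-bucket/app.py"}},
    configuration_overrides={
        "monitoringConfiguration": {
            "s3MonitoringConfiguration": {"logUri": "s3://my-bucket/emr-serverless-logs/"},
        }
    },
    # New in this version; defaults to False because the generated one-time UI links
    # are visible to anyone who can view the DAG.
    enable_application_ui_links=True,
)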
@@ -1181,6 +1208,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         "job_driver",
         "configuration_overrides",
         "name",
+        "aws_conn_id",
     )
 
     template_fields_renderers = {
@@ -1188,12 +1216,48 @@ class EmrServerlessStartJobOperator(BaseOperator):
         "configuration_overrides": "json",
     }
 
+    @property
+    def operator_extra_links(self):
+        """
+        Dynamically add extra links depending on the job type and if they're enabled.
+
+        If S3 or CloudWatch monitoring configurations exist, add links directly to the relevant consoles.
+        Only add dashboard links if they're explicitly enabled. These are one-time links that any user
+        can access, but expire on first click or one hour, whichever comes first.
+        """
+        op_extra_links = []
+
+        if isinstance(self, MappedOperator):
+            enable_application_ui_links = self.partial_kwargs.get(
+                "enable_application_ui_links"
+            ) or self.expand_input.value.get("enable_application_ui_links")
+            job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver")
+            configuration_overrides = self.partial_kwargs.get(
+                "configuration_overrides"
+            ) or self.expand_input.value.get("configuration_overrides")
+
+        else:
+            enable_application_ui_links = self.enable_application_ui_links
+            configuration_overrides = self.configuration_overrides
+            job_driver = self.job_driver
+
+        if enable_application_ui_links:
+            op_extra_links.extend([EmrServerlessDashboardLink()])
+            if "sparkSubmit" in job_driver:
+                op_extra_links.extend([EmrServerlessLogsLink()])
+        if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides):
+            op_extra_links.extend([EmrServerlessS3LogsLink()])
+        if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides):
+            op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
+        return tuple(op_extra_links)
+
     def __init__(
         self,
         application_id: str,
         execution_role_arn: str,
         job_driver: dict,
-        configuration_overrides: dict | None,
+        configuration_overrides: dict | None = None,
         client_request_token: str = "",
         config: dict | None = None,
         wait_for_completion: bool = True,
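The MappedOperator branch in operator_extra_links above exists because, under dynamic task mapping, per-task arguments live in partial_kwargs or expand_input rather than as attributes on an operator instance. A sketch of the mapped usage that exercises that branch, assumed to sit inside a DAG body and reusing the placeholder names from the earlier sketch:

start_serverless_jobs = EmrServerlessStartJobOperator.partial(
    task_id="start_serverless_job",
    application_id="{{ var.value.emr_serverless_app_id }}",
    execution_role_arn="{{ var.value.emr_serverless_role_arn }}",
    enable_application_ui_links=True,
).expand(
    # One mapped task per job driver; the extra-links property reads these
    # from expand_input instead of instance attributes.
    job_driver=[
        {"sparkSubmit": {"entryPoint": "s3://my-bucket/app_a.py"}},
        {"sparkSubmit": {"entryPoint": "s3://my-bucket/app_b.py"}},
    ]
)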
@@ -1204,6 +1268,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         waiter_max_attempts: int | ArgNotSet = NOTSET,
         waiter_delay: int | ArgNotSet = NOTSET,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        enable_application_ui_links: bool = False,
         **kwargs,
     ):
         if waiter_check_interval_seconds is NOTSET:
@@ -1243,6 +1308,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
         self.waiter_delay = int(waiter_delay)  # type: ignore[arg-type]
         self.job_id: str | None = None
         self.deferrable = deferrable
+        self.enable_application_ui_links = enable_application_ui_links
         super().__init__(**kwargs)
 
         self.client_request_token = client_request_token or str(uuid4())
@@ -1300,6 +1366,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
 
         self.job_id = response["jobRunId"]
         self.log.info("EMR serverless job started: %s", self.job_id)
+
+        self.persist_links(context)
+
         if self.deferrable:
             self.defer(
                 trigger=EmrServerlessStartJobTrigger(
@@ -1312,6 +1381,7 @@ class EmrServerlessStartJobOperator(BaseOperator):
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
             )
+
         if self.wait_for_completion:
             waiter = self.hook.get_waiter("serverless_job_completed")
             wait(
@@ -1327,10 +1397,9 @@ class EmrServerlessStartJobOperator(BaseOperator):
         return self.job_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("Serverless job completed")
             return event["job_id"]
 
@@ -1369,6 +1438,105 @@ class EmrServerlessStartJobOperator(BaseOperator):
                 check_interval_seconds=self.waiter_delay,
             )
 
+    def is_monitoring_in_job_override(self, config_key: str, job_override: dict | None) -> bool:
+        """
+        Check if monitoring is enabled for the job.
+
+        Note: This is not compatible with application defaults:
+        https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/default-configs.html
+
+        This is used to determine what extra links should be shown.
+        """
+        monitoring_config = (job_override or {}).get("monitoringConfiguration")
+        if monitoring_config is None or config_key not in monitoring_config:
+            return False
+
+        # CloudWatch can have an "enabled" flag set to False
+        if config_key == "cloudWatchLoggingConfiguration":
+            return monitoring_config.get(config_key).get("enabled") is True
+
+        return config_key in monitoring_config
+
+    def persist_links(self, context: Context):
+        """Populate the relevant extra links for the EMR Serverless jobs."""
+        # Persist the EMR Serverless Dashboard link (Spark/Tez UI)
+        if self.enable_application_ui_links:
+            EmrServerlessDashboardLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                conn_id=self.hook.aws_conn_id,
+                application_id=self.application_id,
+                job_run_id=self.job_id,
+            )
+
+        # If this is a Spark job, persist the EMR Serverless logs link (Driver stdout)
+        if self.enable_application_ui_links and "sparkSubmit" in self.job_driver:
+            EmrServerlessLogsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                conn_id=self.hook.aws_conn_id,
+                application_id=self.application_id,
+                job_run_id=self.job_id,
+            )
+
+        # Add S3 and/or CloudWatch links if either is enabled
+        if self.is_monitoring_in_job_override("s3MonitoringConfiguration", self.configuration_overrides):
+            log_uri = (
+                (self.configuration_overrides or {})
+                .get("monitoringConfiguration", {})
+                .get("s3MonitoringConfiguration", {})
+                .get("logUri")
+            )
+            EmrServerlessS3LogsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                log_uri=log_uri,
+                application_id=self.application_id,
+                job_run_id=self.job_id,
+            )
+            emrs_s3_url = EmrServerlessS3LogsLink().format_link(
+                aws_domain=EmrServerlessCloudWatchLogsLink.get_aws_domain(self.hook.conn_partition),
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                log_uri=log_uri,
+                application_id=self.application_id,
+                job_run_id=self.job_id,
+            )
+            self.log.info("S3 logs available at: %s", emrs_s3_url)
+
+        if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", self.configuration_overrides):
+            cloudwatch_config = (
+                (self.configuration_overrides or {})
+                .get("monitoringConfiguration", {})
+                .get("cloudWatchLoggingConfiguration", {})
+            )
+            log_group_name = cloudwatch_config.get("logGroupName", "/aws/emr-serverless")
+            log_stream_prefix = cloudwatch_config.get("logStreamNamePrefix", "")
+            log_stream_prefix = f"{log_stream_prefix}/applications/{self.application_id}/jobs/{self.job_id}"
+
+            EmrServerlessCloudWatchLogsLink.persist(
+                context=context,
+                operator=self,
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                awslogs_group=log_group_name,
+                stream_prefix=log_stream_prefix,
+            )
+            emrs_cloudwatch_url = EmrServerlessCloudWatchLogsLink().format_link(
+                aws_domain=EmrServerlessCloudWatchLogsLink.get_aws_domain(self.hook.conn_partition),
+                region_name=self.hook.conn_region_name,
+                aws_partition=self.hook.conn_partition,
+                awslogs_group=log_group_name,
+                stream_prefix=log_stream_prefix,
+            )
+            self.log.info("CloudWatch logs available at: %s", emrs_cloudwatch_url)
+
 
 class EmrServerlessStopApplicationOperator(BaseOperator):
     """
@@ -1527,10 +1695,9 @@ class EmrServerlessStopApplicationOperator(BaseOperator):
            )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s stopped successfully", self.application_id)
 
 
@@ -1656,8 +1823,7 @@ class EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
         self.log.info("EMR serverless application deleted")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            self.log.error("Trigger error: event is None")
-            raise AirflowException("Trigger error: event is None")
-        elif event["status"] == "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] == "success":
             self.log.info("EMR serverless application %s deleted successfully", self.application_id)
airflow/providers/amazon/aws/operators/glue.py

@@ -20,7 +20,7 @@ from __future__ import annotations
 import os
 import urllib.parse
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.glue import GlueJobRunDetailsLink
 from airflow.providers.amazon.aws.triggers.glue import GlueJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -215,7 +216,9 @@ class GlueJobOperator(BaseOperator):
         self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
         return self._job_run_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue job: {event}")
         return event["value"]
airflow/providers/amazon/aws/operators/glue_crawler.py

@@ -18,11 +18,12 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -107,7 +108,9 @@ class GlueCrawlerOperator(BaseOperator):
 
         return crawler_name
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error in glue crawl: {event}")
         return self.config["Name"]
airflow/providers/amazon/aws/operators/glue_databrew.py

@@ -18,12 +18,13 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
 from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -101,7 +102,9 @@ class GlueDataBrewStartJobOperator(BaseOperator):
 
         return {"run_id": run_id}
 
-    def execute_complete(self, context: Context, event=None) -> dict[str, str]:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+        event = validate_execute_complete_event(event)
+
         run_id = event.get("run_id", "")
         status = event.get("status", "")
 
airflow/providers/amazon/aws/operators/lambda_function.py

@@ -26,6 +26,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.lambda_function import LambdaCreateFunctionCompleteTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -143,6 +144,8 @@ class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
         return response.get("FunctionArn")
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if not event or event["status"] != "success":
             raise AirflowException(f"Trigger error: event is {event}")
 
airflow/providers/amazon/aws/operators/rds.py

@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -637,11 +638,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
            )
        return json.dumps(create_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance creation failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsDeleteDbInstanceOperator(RdsBaseOperator):
@@ -720,11 +723,13 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
            )
        return json.dumps(delete_db_instance, default=str)
 
-    def execute_complete(self, context, event=None) -> str:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"DB instance deletion failed: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
 
 class RdsStartDbOperator(RdsBaseOperator):
@@ -786,10 +791,12 @@ class RdsStartDbOperator(RdsBaseOperator):
         return json.dumps(start_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
@@ -883,10 +890,12 @@ class RdsStopDbOperator(RdsBaseOperator):
         return json.dumps(stop_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        if event is None or event["status"] != "success":
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             raise AirflowException(f"Failed to start DB: {event}")
-        else:
-            return json.dumps(event["response"], default=str)
+
+        return json.dumps(event["response"], default=str)
 
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import (
     RedshiftPauseClusterTrigger,
     RedshiftResumeClusterTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -314,10 +315,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info("Created Redshift cluster %s", self.cluster_identifier)
         self.log.info(cluster)
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating cluster: {event}")
-        return
 
 
 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
@@ -409,12 +411,13 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator):
                },
            )
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error creating snapshot: {event}")
-        else:
-            self.log.info("Cluster snapshot created.")
-            return
+
+        self.log.info("Cluster snapshot created.")
 
 
 class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
@@ -569,10 +572,7 @@ class RedshiftResumeClusterOperator(BaseOperator):
            )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error resuming cluster: {event}")
@@ -659,10 +659,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
            )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error pausing cluster: {event}")
@@ -767,10 +764,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
            )
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
             raise AirflowException(f"Error deleting cluster: {event}")
airflow/providers/amazon/aws/operators/redshift_data.py

@@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -170,10 +171,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     def execute_complete(
         self, context: Context, event: dict[str, Any] | None = None
     ) -> GetStatementResultResponseTypeDef | str:
-        if event is None:
-            err_msg = "Trigger error: event is None"
-            self.log.info(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)
 
         if event["status"] == "error":
             msg = f"context: {context}, error message: {event['message']}"