apache-airflow-providers-amazon 8.3.1rc1__py3-none-any.whl → 8.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. airflow/providers/amazon/__init__.py +4 -2
  2. airflow/providers/amazon/aws/hooks/base_aws.py +29 -12
  3. airflow/providers/amazon/aws/hooks/emr.py +17 -9
  4. airflow/providers/amazon/aws/hooks/eventbridge.py +27 -0
  5. airflow/providers/amazon/aws/hooks/redshift_data.py +10 -0
  6. airflow/providers/amazon/aws/hooks/sagemaker.py +24 -14
  7. airflow/providers/amazon/aws/notifications/chime.py +1 -1
  8. airflow/providers/amazon/aws/operators/eks.py +140 -7
  9. airflow/providers/amazon/aws/operators/emr.py +202 -22
  10. airflow/providers/amazon/aws/operators/eventbridge.py +87 -0
  11. airflow/providers/amazon/aws/operators/rds.py +120 -48
  12. airflow/providers/amazon/aws/operators/redshift_data.py +7 -0
  13. airflow/providers/amazon/aws/operators/sagemaker.py +75 -7
  14. airflow/providers/amazon/aws/operators/step_function.py +34 -2
  15. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
  16. airflow/providers/amazon/aws/triggers/batch.py +1 -1
  17. airflow/providers/amazon/aws/triggers/ecs.py +7 -5
  18. airflow/providers/amazon/aws/triggers/eks.py +174 -3
  19. airflow/providers/amazon/aws/triggers/emr.py +215 -1
  20. airflow/providers/amazon/aws/triggers/rds.py +161 -5
  21. airflow/providers/amazon/aws/triggers/sagemaker.py +84 -1
  22. airflow/providers/amazon/aws/triggers/step_function.py +59 -0
  23. airflow/providers/amazon/aws/utils/__init__.py +16 -1
  24. airflow/providers/amazon/aws/utils/rds.py +2 -2
  25. airflow/providers/amazon/aws/waiters/sagemaker.json +46 -0
  26. airflow/providers/amazon/aws/waiters/stepfunctions.json +36 -0
  27. airflow/providers/amazon/get_provider_info.py +21 -1
  28. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/METADATA +13 -13
  29. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/RECORD +34 -30
  30. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/WHEEL +1 -1
  31. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/LICENSE +0 -0
  32. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/NOTICE +0 -0
  33. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/entry_points.txt +0 -0
  34. {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/top_level.txt +0 -0
airflow/providers/amazon/aws/operators/rds.py
@@ -20,15 +20,20 @@ from __future__ import annotations
 import json
 import warnings
 from datetime import timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from functools import cached_property
 from mypy_boto3_rds.type_defs import TagTypeDef
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import RdsDbInstanceTrigger
+from airflow.providers.amazon.aws.triggers.rds import (
+    RdsDbAvailableTrigger,
+    RdsDbDeletedTrigger,
+    RdsDbStoppedTrigger,
+)
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -62,13 +67,17 @@ class RdsBaseOperator(BaseOperator):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,  # 2 is in the operator's init, 3 is in the user code creating the operator
             )
-        hook_params = hook_params or {}
-        self.region_name = region_name or hook_params.pop("region_name", None)
-        self.hook = RdsHook(aws_conn_id=aws_conn_id, region_name=self.region_name, **(hook_params))
+        self.hook_params = hook_params or {}
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name or self.hook_params.pop("region_name", None)
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
 
+    @cached_property
+    def hook(self) -> RdsHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, **self.hook_params)
+
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
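
The hunk above stops building the `RdsHook` eagerly in `__init__`: the hook is now created on first access and memoized, so instantiating the operator (which happens on every DAG parse) no longer touches Airflow connections or creates a boto3 client. A minimal standalone sketch of the pattern; `Client` and `Operator` below are illustrative stand-ins, not provider code:

```python
from functools import cached_property


class Client:
    """Stands in for an expensive resource such as a boto3 client."""

    def __init__(self, conn_id: str) -> None:
        print(f"expensive setup for {conn_id}")  # runs once, on first access
        self.conn_id = conn_id


class Operator:
    def __init__(self, conn_id: str = "aws_default") -> None:
        self.conn_id = conn_id  # the constructor stays cheap: no client yet

    @cached_property
    def hook(self) -> Client:
        # Created on first attribute access, then cached on the instance.
        return Client(self.conn_id)


op = Operator()          # safe at DAG-parse time; prints nothing
first = op.hook          # "expensive setup" runs here, exactly once
assert op.hook is first  # later accesses reuse the cached object
```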
@@ -106,10 +115,9 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         db_snapshot_identifier: str,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_type = RdsDbType(db_type)
         self.db_identifier = db_identifier
         self.db_snapshot_identifier = db_snapshot_identifier
@@ -194,10 +202,9 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         target_custom_availability_zone: str = "",
         source_region: str = "",
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.source_db_snapshot_identifier = source_db_snapshot_identifier
@@ -274,10 +281,9 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
         db_type: str,
         db_snapshot_identifier: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier
@@ -345,10 +351,9 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         s3_prefix: str = "",
         export_only: list[str] | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.source_arn = source_arn
@@ -397,10 +402,9 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         export_task_identifier: str,
         wait_for_completion: bool = True,
         check_interval: int = 30,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.wait_for_completion = wait_for_completion
@@ -461,10 +465,9 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         enabled: bool = True,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
         self.sns_topic_arn = sns_topic_arn
@@ -511,10 +514,9 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         self,
         *,
         subscription_name: str,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
 
@@ -545,7 +547,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param engine: The name of the database engine to be used for this instance
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
@@ -563,14 +564,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         db_instance_class: str,
         engine: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
 
         self.db_instance_identifier = db_instance_identifier
         self.db_instance_class = db_instance_class
@@ -580,7 +580,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Creating new DB instance %s", self.db_instance_identifier)
@@ -593,15 +592,15 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=RdsDbInstanceTrigger(
-                    db_instance_identifier=self.db_instance_identifier,
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_available",
                     # ignoring type because create_db_instance is a dict
                     response=create_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
@@ -638,7 +637,6 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
@@ -654,21 +652,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         *,
         db_instance_identifier: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_instance_identifier = db_instance_identifier
         self.rds_kwargs = rds_kwargs or {}
         self.wait_for_completion = False if deferrable else wait_for_completion
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Deleting DB instance %s", self.db_instance_identifier)
@@ -679,15 +675,15 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=RdsDbInstanceTrigger(
-                    db_instance_identifier=self.db_instance_identifier,
+                trigger=RdsDbDeletedTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_deleted",
                     # ignoring type because delete_db_instance is a dict
                     response=delete_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
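
With the generic `RdsDbInstanceTrigger` split into state-specific triggers, the create and delete operators now defer to `RdsDbAvailableTrigger` and `RdsDbDeletedTrigger` respectively. A hedged usage sketch of the pair in deferrable mode; identifiers, engine settings, and the DAG schedule are illustrative (real calls also need credentials such as `MasterUserPassword` in `rds_kwargs`):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.rds import (
    RdsCreateDbInstanceOperator,
    RdsDeleteDbInstanceOperator,
)

with DAG("rds_instance_lifecycle", start_date=datetime(2023, 1, 1), schedule=None):
    create_db = RdsCreateDbInstanceOperator(
        task_id="create_db",
        db_instance_identifier="example-db",  # illustrative identifier
        db_instance_class="db.t3.micro",
        engine="postgres",
        rds_kwargs={"AllocatedStorage": 20, "MasterUsername": "admin"},
        deferrable=True,  # poll from the triggerer instead of holding a worker slot
        waiter_delay=30,
        waiter_max_attempts=60,
    )
    delete_db = RdsDeleteDbInstanceOperator(
        task_id="delete_db",
        db_instance_identifier="example-db",
        rds_kwargs={"SkipFinalSnapshot": True},
        deferrable=True,
    )
    create_db >> delete_db
```

Note that `aws_conn_id` is still accepted by all of these operators; after this release it is consumed once by `RdsBaseOperator` instead of being redeclared on every subclass.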
@@ -723,8 +719,11 @@ class RdsStartDbOperator(RdsBaseOperator):
 
     :param db_identifier: The AWS identifier of the DB to start
     :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion: If True, waits for DB to start. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to start.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_type")
@@ -734,26 +733,52 @@ class RdsStartDbOperator(RdsBaseOperator):
         *,
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        start_db_response = self._start_db()
-        if self.wait_for_completion:
+        start_db_response: dict[str, Any] = self._start_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=start_db_response,
+                    db_type=RdsDbType.INSTANCE,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_available()
         return json.dumps(start_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to start DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:
-            response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+            response = self.hook.conn.start_db_instance(
+                DBInstanceIdentifier=self.db_identifier,
+            )
         else:
             response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
         return response
@@ -761,9 +786,19 @@ class RdsStartDbOperator(RdsBaseOperator):
     def _wait_until_db_available(self):
         self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(self.db_identifier, target_state="available")
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="available")
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 class RdsStopDbOperator(RdsBaseOperator):
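
`RdsStartDbOperator` wires the same waiter knobs into both the synchronous path (`wait_for_db_instance_state`) and the deferrable path (`RdsDbAvailableTrigger`). A hedged usage sketch; the identifier and task_id are illustrative:

```python
from airflow.providers.amazon.aws.operators.rds import RdsStartDbOperator

start_db = RdsStartDbOperator(
    task_id="start_db",
    db_identifier="example-db",
    db_type="instance",
    deferrable=True,         # RdsDbAvailableTrigger polls from the triggerer
    waiter_delay=30,         # seconds between status checks
    waiter_max_attempts=40,  # i.e. give up after roughly 20 minutes
)
```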
@@ -779,8 +814,11 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
         stopping the DB instance. The default value (None) skips snapshot creation. This
         parameter is ignored when ``db_type`` is "cluster"
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion: If True, waits for DB to stop. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
@@ -791,23 +829,47 @@ class RdsStopDbOperator(RdsBaseOperator):
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         db_snapshot_identifier: str | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.db_snapshot_identifier = db_snapshot_identifier
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        stop_db_response = self._stop_db()
-        if self.wait_for_completion:
+        stop_db_response: dict[str, Any] = self._stop_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbStoppedTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=stop_db_response,
+                    db_type=RdsDbType.INSTANCE,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_stopped()
         return json.dumps(stop_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to stop DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:
@@ -829,9 +891,19 @@ class RdsStopDbOperator(RdsBaseOperator):
     def _wait_until_db_stopped(self):
         self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(self.db_identifier, target_state="stopped")
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="stopped")
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 __all__ = [
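
`RdsStopDbOperator` mirrors the start operator, deferring to `RdsDbStoppedTrigger`, and can snapshot an instance before stopping it. A hedged sketch with illustrative identifiers:

```python
from airflow.providers.amazon.aws.operators.rds import RdsStopDbOperator

stop_db = RdsStopDbOperator(
    task_id="stop_db",
    db_identifier="example-db",
    db_snapshot_identifier="example-db-pre-stop",  # ignored when db_type="cluster"
    deferrable=True,  # RdsDbStoppedTrigger polls from the triggerer process
)
```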
airflow/providers/amazon/aws/operators/redshift_data.py
@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
         if False (default) will return statement ID
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
+    :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+        `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
+        https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
     """
 
     template_fields = (
@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
         "statement_name",
         "aws_conn_id",
         "region",
+        "workgroup_name",
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -82,12 +86,14 @@ class RedshiftDataOperator(BaseOperator):
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
+        workgroup_name: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.database = database
         self.sql = sql
         self.cluster_identifier = cluster_identifier
+        self.workgroup_name = workgroup_name
         self.db_user = db_user
         self.parameters = parameters
         self.secret_arn = secret_arn
@@ -119,6 +125,7 @@ class RedshiftDataOperator(BaseOperator):
             database=self.database,
             sql=self.sql,
             cluster_identifier=self.cluster_identifier,
+            workgroup_name=self.workgroup_name,
             db_user=self.db_user,
             parameters=self.parameters,
             secret_arn=self.secret_arn,
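
The new `workgroup_name` lets `RedshiftDataOperator` target Redshift Serverless: pass it instead of `cluster_identifier` (the two are mutually exclusive). A hedged sketch using only parameters visible in this diff; the workgroup, database, and table names are illustrative:

```python
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

count_rows = RedshiftDataOperator(
    task_id="count_rows",
    database="dev",
    workgroup_name="example-workgroup",  # Serverless target; omit cluster_identifier
    sql="SELECT COUNT(*) FROM example_table;",
    aws_conn_id="aws_default",
)
```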
airflow/providers/amazon/aws/operators/sagemaker.py
@@ -30,7 +30,10 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+    SageMakerPipelineTrigger,
+    SageMakerTrigger,
+)
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
@@ -998,8 +1001,10 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         All parameters supplied need to already be present in the pipeline definition.
     :param wait_for_completion: If true, this operator will only complete once the pipeline is complete.
     :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
+    :param waiter_max_attempts: How many times to check the status before failing.
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have thousands of steps.
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns The ARN of the pipeline execution created in Amazon SageMaker.
     """
@@ -1015,7 +1020,9 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         pipeline_params: dict | None = None,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
@@ -1024,22 +1031,46 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         self.pipeline_params = pipeline_params
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         arn = self.hook.start_pipeline(
             pipeline_name=self.pipeline_name,
             display_name=self.display_name,
             pipeline_params=self.pipeline_params,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
         )
         self.log.info(
             "Starting a new execution for pipeline %s, running with ARN %s", self.pipeline_name, arn
         )
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
+                    pipeline_execution_arn=arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.hook.check_status(
+                arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * self.check_interval,
+            )
         return arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {event}")
+        return event["value"]
+
 
 class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     """
@@ -1057,6 +1088,7 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have thousands of steps.
     :param fail_if_not_running: raises an exception if the pipeline stopped or succeeded before this was run
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns the status of the pipeline execution after the operation has been done.
     """
@@ -1073,23 +1105,24 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
         fail_if_not_running: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
         self.pipeline_exec_arn = pipeline_exec_arn
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
         self.fail_if_not_running = fail_if_not_running
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         status = self.hook.stop_pipeline(
             pipeline_exec_arn=self.pipeline_exec_arn,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
             fail_if_not_running=self.fail_if_not_running,
         )
         self.log.info(
@@ -1097,8 +1130,43 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
             self.pipeline_exec_arn,
             status,
         )
+
+        if status not in self.hook.pipeline_non_terminal_states:
+            # pipeline already stopped
+            return status
+
+        # else, eventually wait for completion
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.STOPPED,
+                    pipeline_execution_arn=self.pipeline_exec_arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            status = self.hook.check_status(
+                self.pipeline_exec_arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * self.check_interval,
+            )["PipelineExecutionStatus"]
+
         return status
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {event}")
+        else:
+            # theoretically we should do a `describe` call to know this,
+            # but if we reach this point, this is the only possible status
+            return "Stopped"
+
 
 class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
     """
airflow/providers/amazon/aws/operators/step_function.py
@@ -17,11 +17,14 @@
 from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: aws connection to uses
     :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
+    :param waiter_max_attempts: Maximum number of attempts to poll the execution.
+    :param waiter_delay: Number of seconds between polling the state of the execution.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     """
 
     template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
@@ -56,6 +64,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         state_machine_input: dict | str | None = None,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        waiter_max_attempts: int = 30,
+        waiter_delay: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -64,6 +75,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         self.input = state_machine_input
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
@@ -74,9 +88,27 @@ class StepFunctionStartExecutionOperator(BaseOperator):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
 
         self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
-
+        if self.deferrable:
+            self.defer(
+                trigger=StepFunctionsExecutionCompleteTrigger(
+                    execution_arn=execution_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
         return execution_arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("State Machine execution completed successfully")
+        return event["execution_arn"]
+
 
 class StepFunctionGetExecutionOutputOperator(BaseOperator):
     """