apache-airflow-providers-amazon 8.3.1rc1__py3-none-any.whl → 8.4.0__py3-none-any.whl
This diff shows the changes between two package versions as released to a supported public registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- airflow/providers/amazon/__init__.py +4 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +29 -12
- airflow/providers/amazon/aws/hooks/emr.py +17 -9
- airflow/providers/amazon/aws/hooks/eventbridge.py +27 -0
- airflow/providers/amazon/aws/hooks/redshift_data.py +10 -0
- airflow/providers/amazon/aws/hooks/sagemaker.py +24 -14
- airflow/providers/amazon/aws/notifications/chime.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +140 -7
- airflow/providers/amazon/aws/operators/emr.py +202 -22
- airflow/providers/amazon/aws/operators/eventbridge.py +87 -0
- airflow/providers/amazon/aws/operators/rds.py +120 -48
- airflow/providers/amazon/aws/operators/redshift_data.py +7 -0
- airflow/providers/amazon/aws/operators/sagemaker.py +75 -7
- airflow/providers/amazon/aws/operators/step_function.py +34 -2
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -1
- airflow/providers/amazon/aws/triggers/batch.py +1 -1
- airflow/providers/amazon/aws/triggers/ecs.py +7 -5
- airflow/providers/amazon/aws/triggers/eks.py +174 -3
- airflow/providers/amazon/aws/triggers/emr.py +215 -1
- airflow/providers/amazon/aws/triggers/rds.py +161 -5
- airflow/providers/amazon/aws/triggers/sagemaker.py +84 -1
- airflow/providers/amazon/aws/triggers/step_function.py +59 -0
- airflow/providers/amazon/aws/utils/__init__.py +16 -1
- airflow/providers/amazon/aws/utils/rds.py +2 -2
- airflow/providers/amazon/aws/waiters/sagemaker.json +46 -0
- airflow/providers/amazon/aws/waiters/stepfunctions.json +36 -0
- airflow/providers/amazon/get_provider_info.py +21 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/METADATA +13 -13
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/RECORD +34 -30
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-8.3.1rc1.dist-info → apache_airflow_providers_amazon-8.4.0.dist-info}/top_level.txt +0 -0
airflow/providers/amazon/aws/operators/rds.py

@@ -20,15 +20,20 @@ from __future__ import annotations
 import json
 import warnings
 from datetime import timedelta
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from functools import cached_property
 from mypy_boto3_rds.type_defs import TagTypeDef
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
-from airflow.providers.amazon.aws.triggers.rds import
+from airflow.providers.amazon.aws.triggers.rds import (
+    RdsDbAvailableTrigger,
+    RdsDbDeletedTrigger,
+    RdsDbStoppedTrigger,
+)
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -62,13 +67,17 @@ class RdsBaseOperator(BaseOperator):
                 AirflowProviderDeprecationWarning,
                 stacklevel=3,  # 2 is in the operator's init, 3 is in the user code creating the operator
             )
-        hook_params = hook_params or {}
-        self.
-        self.
+        self.hook_params = hook_params or {}
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name or self.hook_params.pop("region_name", None)
         super().__init__(*args, **kwargs)
 
         self._await_interval = 60  # seconds
 
+    @cached_property
+    def hook(self) -> RdsHook:
+        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, **self.hook_params)
+
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
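The base-operator change above stops building an `RdsHook` eagerly in `__init__` and instead exposes it as a lazily cached property, so only cheap attribute assignments happen at DAG-parse time and the boto3 client is created on first use. A minimal, stdlib-only sketch of the same pattern (the `Client` and `Operator` names here are illustrative, not from the diff):

from functools import cached_property


class Client:
    """Stand-in for an expensive resource such as RdsHook."""

    def __init__(self, conn_id: str):
        print(f"connecting via {conn_id}")  # happens once, on first access
        self.conn_id = conn_id


class Operator:
    def __init__(self, conn_id: str = "aws_default"):
        self.conn_id = conn_id  # cheap: only store the parameters

    @cached_property
    def hook(self) -> Client:
        # first access builds the client; the result is cached on the instance
        return Client(self.conn_id)


op = Operator()            # no connection made yet
first = op.hook            # prints "connecting via aws_default"
assert op.hook is first    # later accesses reuse the cached object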
@@ -106,10 +115,9 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
         db_snapshot_identifier: str,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
         self.db_type = RdsDbType(db_type)
         self.db_identifier = db_identifier
         self.db_snapshot_identifier = db_snapshot_identifier

@@ -194,10 +202,9 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         target_custom_availability_zone: str = "",
         source_region: str = "",
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.source_db_snapshot_identifier = source_db_snapshot_identifier

@@ -274,10 +281,9 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
         db_type: str,
         db_snapshot_identifier: str,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.db_type = RdsDbType(db_type)
         self.db_snapshot_identifier = db_snapshot_identifier

@@ -345,10 +351,9 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
         s3_prefix: str = "",
         export_only: list[str] | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.source_arn = source_arn

@@ -397,10 +402,9 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
         export_task_identifier: str,
         wait_for_completion: bool = True,
         check_interval: int = 30,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.export_task_identifier = export_task_identifier
         self.wait_for_completion = wait_for_completion

@@ -461,10 +465,9 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
         enabled: bool = True,
         tags: Sequence[TagTypeDef] | dict | None = None,
         wait_for_completion: bool = True,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
         self.sns_topic_arn = sns_topic_arn

@@ -511,10 +514,9 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         self,
         *,
         subscription_name: str,
-        aws_conn_id: str = "aws_default",
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.subscription_name = subscription_name
 

@@ -545,7 +547,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param engine: The name of the database engine to be used for this instance
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state

@@ -563,14 +564,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         db_instance_class: str,
         engine: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
 
         self.db_instance_identifier = db_instance_identifier
         self.db_instance_class = db_instance_class

@@ -580,7 +580,6 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Creating new DB instance %s", self.db_instance_identifier)

@@ -593,15 +592,15 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=
-
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_available",
                     # ignoring type because create_db_instance is a dict
                     response=create_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
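For orientation, the new deferrable path would be used from a DAG roughly as follows. This is a sketch only; the identifiers and `rds_kwargs` values are illustrative rather than taken from the diff:

import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbInstanceOperator

with DAG(dag_id="example_rds_create", start_date=datetime.datetime(2023, 1, 1), schedule=None):
    create_db = RdsCreateDbInstanceOperator(
        task_id="create_db",
        db_instance_identifier="my-instance",    # illustrative
        db_instance_class="db.t4g.micro",
        engine="postgres",
        rds_kwargs={"AllocatedStorage": 20},     # forwarded to boto3 create_db_instance
        deferrable=True,                         # releases the worker slot while the trigger polls
        waiter_delay=30,                         # deferral timeout is waiter_delay * waiter_max_attempts
        waiter_max_attempts=60,
    )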
@@ -638,7 +637,6 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
     :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
     :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state

@@ -654,21 +652,19 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         *,
         db_instance_identifier: str,
         rds_kwargs: dict | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
         self.db_instance_identifier = db_instance_identifier
         self.rds_kwargs = rds_kwargs or {}
         self.wait_for_completion = False if deferrable else wait_for_completion
         self.deferrable = deferrable
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
-        self.aws_conn_id = aws_conn_id
 
     def execute(self, context: Context) -> str:
         self.log.info("Deleting DB instance %s", self.db_instance_identifier)

@@ -679,15 +675,15 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         )
         if self.deferrable:
             self.defer(
-                trigger=
-
+                trigger=RdsDbDeletedTrigger(
+                    db_identifier=self.db_instance_identifier,
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
-                    waiter_name="db_instance_deleted",
                     # ignoring type because delete_db_instance is a dict
                     response=delete_db_instance,  # type: ignore[arg-type]
+                    db_type=RdsDbType.INSTANCE,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_delay * self.waiter_max_attempts),
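Both deferrable branches above follow the same contract: the trigger fires an event that `execute_complete` inspects, and the deferral deadline equals the synchronous polling budget. A sketch of that contract, with the dict shape taken from the `execute_complete` bodies in this diff (the `response` payload itself is illustrative):

import json

# success event shape implied by execute_complete():
event = {"status": "success", "response": {"DBInstance": {"DBInstanceStatus": "deleting"}}}

if event is None or event["status"] != "success":
    raise RuntimeError(f"waiter failed: {event}")
result = json.dumps(event["response"], default=str)  # what the task returns / pushes to XCom

# the deferral deadline mirrors the polling budget:
waiter_delay, waiter_max_attempts = 30, 60
timeout_seconds = waiter_delay * waiter_max_attempts  # 1800 seconds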
@@ -723,8 +719,11 @@ class RdsStartDbOperator(RdsBaseOperator):
 
     :param db_identifier: The AWS identifier of the DB to start
     :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion: If True, waits for DB to start. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to start.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_type")

@@ -734,26 +733,52 @@ class RdsStartDbOperator(RdsBaseOperator):
         *,
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        start_db_response = self._start_db()
-        if self.
+        start_db_response: dict[str, Any] = self._start_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbAvailableTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=start_db_response,
+                    db_type=RdsDbType.INSTANCE,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_available()
         return json.dumps(start_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to start DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:
-            response = self.hook.conn.start_db_instance(
+            response = self.hook.conn.start_db_instance(
+                DBInstanceIdentifier=self.db_identifier,
+            )
         else:
             response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
         return response
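A sketch of wiring the start/stop pair with the new parameters (identifiers are illustrative). Note that, per the diff, both operators pass `db_type=RdsDbType.INSTANCE` to their triggers regardless of the configured `db_type`:

import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.rds import RdsStartDbOperator, RdsStopDbOperator

with DAG(dag_id="example_rds_start_stop", start_date=datetime.datetime(2023, 1, 1), schedule=None):
    start_db = RdsStartDbOperator(
        task_id="start_db",
        db_identifier="my-instance",  # illustrative
        db_type="instance",
        deferrable=True,              # poll via trigger instead of holding a worker slot
        waiter_delay=30,
        waiter_max_attempts=40,       # the new defaults shown above
    )
    stop_db = RdsStopDbOperator(task_id="stop_db", db_identifier="my-instance", deferrable=True)
    start_db >> stop_db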
@@ -761,9 +786,19 @@ class RdsStartDbOperator(RdsBaseOperator):
     def _wait_until_db_available(self):
         self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="available",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 class RdsStopDbOperator(RdsBaseOperator):
@@ -779,8 +814,11 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
         stopping the DB instance. The default value (None) skips snapshot creation. This
         parameter is ignored when ``db_type`` is "cluster"
-    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
     :param wait_for_completion: If True, waits for DB to stop. (default: True)
+    :param waiter_delay: Time (in seconds) to wait between two consecutive calls to check DB instance state
+    :param waiter_max_attempts: The maximum number of attempts to check DB instance state
+    :param deferrable: If True, the operator will wait asynchronously for the DB instance to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
     """
 
     template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")

@@ -791,23 +829,47 @@ class RdsStopDbOperator(RdsBaseOperator):
         db_identifier: str,
         db_type: RdsDbType | str = RdsDbType.INSTANCE,
         db_snapshot_identifier: str | None = None,
-        aws_conn_id: str = "aws_default",
         wait_for_completion: bool = True,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 40,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
-        super().__init__(
+        super().__init__(**kwargs)
         self.db_identifier = db_identifier
         self.db_type = db_type
         self.db_snapshot_identifier = db_snapshot_identifier
         self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         self.db_type = RdsDbType(self.db_type)
-        stop_db_response = self._stop_db()
-        if self.
+        stop_db_response: dict[str, Any] = self._stop_db()
+        if self.deferrable:
+            self.defer(
+                trigger=RdsDbStoppedTrigger(
+                    db_identifier=self.db_identifier,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    response=stop_db_response,
+                    db_type=RdsDbType.INSTANCE,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
             self._wait_until_db_stopped()
         return json.dumps(stop_db_response, default=str)
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failed to stop DB: {event}")
+        else:
+            return json.dumps(event["response"], default=str)
+
     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
         if self.db_type == RdsDbType.INSTANCE:

@@ -829,9 +891,19 @@ class RdsStopDbOperator(RdsBaseOperator):
     def _wait_until_db_stopped(self):
         self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value)
         if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(
+            self.hook.wait_for_db_instance_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
         else:
-            self.hook.wait_for_db_cluster_state(
+            self.hook.wait_for_db_cluster_state(
+                self.db_identifier,
+                target_state="stopped",
+                check_interval=self.waiter_delay,
+                max_attempts=self.waiter_max_attempts,
+            )
 
 
 __all__ = [
airflow/providers/amazon/aws/operators/redshift_data.py

@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
         if False (default) will return statement ID
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
+    :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+        `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
+        https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
     """
 
     template_fields = (

@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
         "statement_name",
         "aws_conn_id",
         "region",
+        "workgroup_name",
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}

@@ -82,12 +86,14 @@ class RedshiftDataOperator(BaseOperator):
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
+        workgroup_name: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.database = database
         self.sql = sql
         self.cluster_identifier = cluster_identifier
+        self.workgroup_name = workgroup_name
         self.db_user = db_user
         self.parameters = parameters
         self.secret_arn = secret_arn

@@ -119,6 +125,7 @@ class RedshiftDataOperator(BaseOperator):
             database=self.database,
             sql=self.sql,
             cluster_identifier=self.cluster_identifier,
+            workgroup_name=self.workgroup_name,
             db_user=self.db_user,
             parameters=self.parameters,
             secret_arn=self.secret_arn,
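A sketch of targeting Redshift Serverless with the new parameter (the workgroup name and SQL are illustrative); per the docstring, `workgroup_name` is used instead of, never together with, `cluster_identifier`:

from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

query_serverless = RedshiftDataOperator(
    task_id="query_serverless",
    database="dev",
    workgroup_name="my-serverless-workgroup",  # replaces cluster_identifier for Serverless
    sql="SELECT count(*) FROM my_table;",
)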
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -30,7 +30,10 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-from airflow.providers.amazon.aws.triggers.sagemaker import
+from airflow.providers.amazon.aws.triggers.sagemaker import (
+    SageMakerPipelineTrigger,
+    SageMakerTrigger,
+)
 from airflow.providers.amazon.aws.utils import trim_none_values
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags

@@ -998,8 +1001,10 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         All parameters supplied need to already be present in the pipeline definition.
     :param wait_for_completion: If true, this operator will only complete once the pipeline is complete.
     :param check_interval: How long to wait between checks for pipeline status when waiting for completion.
+    :param waiter_max_attempts: How many times to check the status before failing.
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have thousands of steps.
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns the ARN of the pipeline execution created in Amazon SageMaker.
     """

@@ -1015,7 +1020,9 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         pipeline_params: dict | None = None,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)

@@ -1024,22 +1031,46 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         self.pipeline_params = pipeline_params
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         arn = self.hook.start_pipeline(
             pipeline_name=self.pipeline_name,
             display_name=self.display_name,
             pipeline_params=self.pipeline_params,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
         )
         self.log.info(
             "Starting a new execution for pipeline %s, running with ARN %s", self.pipeline_name, arn
         )
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.COMPLETE,
+                    pipeline_execution_arn=arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.hook.check_status(
+                arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * self.check_interval,
+            )
         return arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {event}")
+        return event["value"]
+
 
 class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     """

@@ -1057,6 +1088,7 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
     :param verbose: Whether to print steps details when waiting for completion.
         Defaults to true, consider turning off for pipelines that have thousands of steps.
     :param fail_if_not_running: raises an exception if the pipeline stopped or succeeded before this was run
+    :param deferrable: Run operator in the deferrable mode.
 
     :return str: Returns the status of the pipeline execution after the operation has been done.
     """

@@ -1073,23 +1105,24 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         pipeline_exec_arn: str,
         wait_for_completion: bool = False,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        waiter_max_attempts: int = 9999,
         verbose: bool = True,
         fail_if_not_running: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(config={}, aws_conn_id=aws_conn_id, **kwargs)
         self.pipeline_exec_arn = pipeline_exec_arn
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.waiter_max_attempts = waiter_max_attempts
         self.verbose = verbose
         self.fail_if_not_running = fail_if_not_running
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> str:
         status = self.hook.stop_pipeline(
             pipeline_exec_arn=self.pipeline_exec_arn,
-            wait_for_completion=self.wait_for_completion,
-            check_interval=self.check_interval,
-            verbose=self.verbose,
             fail_if_not_running=self.fail_if_not_running,
         )
         self.log.info(

@@ -1097,8 +1130,43 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
             self.pipeline_exec_arn,
             status,
         )
+
+        if status not in self.hook.pipeline_non_terminal_states:
+            # pipeline already stopped
+            return status
+
+        # else, eventually wait for completion
+        if self.deferrable:
+            self.defer(
+                trigger=SageMakerPipelineTrigger(
+                    waiter_type=SageMakerPipelineTrigger.Type.STOPPED,
+                    pipeline_execution_arn=self.pipeline_exec_arn,
+                    waiter_delay=self.check_interval,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            status = self.hook.check_status(
+                self.pipeline_exec_arn,
+                "PipelineExecutionStatus",
+                lambda p: self.hook.describe_pipeline_exec(p, self.verbose),
+                self.check_interval,
+                non_terminal_states=self.hook.pipeline_non_terminal_states,
+                max_ingestion_time=self.waiter_max_attempts * self.check_interval,
+            )["PipelineExecutionStatus"]
+
         return status
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {event}")
+        else:
+            # theoretically we should do a `describe` call to know this,
+            # but if we reach this point, this is the only possible status
+            return "Stopped"
+
 
 class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
     """
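A sketch of starting a pipeline in the new deferrable mode (the pipeline name and parameters are illustrative). In this mode `check_interval` doubles as the trigger's `waiter_delay`, so the effective budget is roughly `check_interval * waiter_max_attempts` seconds:

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerStartPipelineOperator

start_pipeline = SageMakerStartPipelineOperator(
    task_id="start_pipeline",
    pipeline_name="my-pipeline",                         # illustrative
    pipeline_params={"InputPath": "s3://my-bucket/data"},  # must already exist in the definition
    deferrable=True,
    check_interval=30,
    waiter_max_attempts=120,
)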
airflow/providers/amazon/aws/operators/step_function.py

@@ -17,11 +17,14 @@
 from __future__ import annotations
 
 import json
-from
+from datetime import timedelta
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context

@@ -42,6 +45,11 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param state_machine_input: JSON data input to pass to the State Machine
     :param aws_conn_id: aws connection to use
     :param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
+    :param waiter_max_attempts: Maximum number of attempts to poll the execution.
+    :param waiter_delay: Number of seconds between polling the state of the execution.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
     """
 
     template_fields: Sequence[str] = ("state_machine_arn", "name", "input")

@@ -56,6 +64,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         state_machine_input: dict | str | None = None,
         aws_conn_id: str = "aws_default",
         region_name: str | None = None,
+        waiter_max_attempts: int = 30,
+        waiter_delay: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)

@@ -64,6 +75,9 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         self.input = state_machine_input
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)

@@ -74,9 +88,27 @@ class StepFunctionStartExecutionOperator(BaseOperator):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
 
         self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
-
+        if self.deferrable:
+            self.defer(
+                trigger=StepFunctionsExecutionCompleteTrigger(
+                    execution_arn=execution_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                ),
+                method_name="execute_complete",
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
+            )
         return execution_arn
 
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        if event is None or event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {event}")
+
+        self.log.info("State Machine execution completed successfully")
+        return event["execution_arn"]
+
 
 class StepFunctionGetExecutionOutputOperator(BaseOperator):
     """