apache-airflow-providers-amazon 9.14.0__py3-none-any.whl → 9.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +3 -3
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +106 -5
- airflow/providers/amazon/aws/auth_manager/routes/login.py +7 -1
- airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py +5 -1
- airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +6 -2
- airflow/providers/amazon/aws/hooks/athena_sql.py +2 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +2 -2
- airflow/providers/amazon/aws/hooks/batch_client.py +4 -6
- airflow/providers/amazon/aws/hooks/batch_waiters.py +0 -1
- airflow/providers/amazon/aws/hooks/chime.py +1 -1
- airflow/providers/amazon/aws/hooks/datasync.py +3 -3
- airflow/providers/amazon/aws/hooks/firehose.py +56 -0
- airflow/providers/amazon/aws/hooks/glue.py +7 -1
- airflow/providers/amazon/aws/hooks/kinesis.py +31 -13
- airflow/providers/amazon/aws/hooks/mwaa.py +38 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +20 -6
- airflow/providers/amazon/aws/hooks/s3.py +41 -11
- airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/hooks/ses.py +76 -10
- airflow/providers/amazon/aws/hooks/sns.py +74 -18
- airflow/providers/amazon/aws/hooks/sqs.py +64 -11
- airflow/providers/amazon/aws/hooks/ssm.py +34 -6
- airflow/providers/amazon/aws/hooks/step_function.py +1 -1
- airflow/providers/amazon/aws/links/base_aws.py +1 -1
- airflow/providers/amazon/aws/notifications/ses.py +139 -0
- airflow/providers/amazon/aws/notifications/sns.py +16 -1
- airflow/providers/amazon/aws/notifications/sqs.py +17 -1
- airflow/providers/amazon/aws/operators/base_aws.py +2 -2
- airflow/providers/amazon/aws/operators/bedrock.py +2 -0
- airflow/providers/amazon/aws/operators/cloud_formation.py +2 -2
- airflow/providers/amazon/aws/operators/datasync.py +2 -1
- airflow/providers/amazon/aws/operators/emr.py +44 -33
- airflow/providers/amazon/aws/operators/mwaa.py +12 -3
- airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/operators/ssm.py +122 -17
- airflow/providers/amazon/aws/secrets/secrets_manager.py +3 -4
- airflow/providers/amazon/aws/sensors/base_aws.py +2 -2
- airflow/providers/amazon/aws/sensors/mwaa.py +14 -1
- airflow/providers/amazon/aws/sensors/s3.py +27 -13
- airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py +1 -1
- airflow/providers/amazon/aws/sensors/ssm.py +33 -17
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +3 -3
- airflow/providers/amazon/aws/transfers/base.py +5 -5
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +4 -4
- airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/ftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +48 -5
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -1
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +2 -5
- airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/local_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +6 -6
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_ftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
- airflow/providers/amazon/aws/transfers/s3_to_sftp.py +1 -1
- airflow/providers/amazon/aws/transfers/s3_to_sql.py +1 -1
- airflow/providers/amazon/aws/transfers/salesforce_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sftp_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +4 -5
- airflow/providers/amazon/aws/triggers/bedrock.py +1 -1
- airflow/providers/amazon/aws/triggers/s3.py +29 -2
- airflow/providers/amazon/aws/triggers/ssm.py +17 -1
- airflow/providers/amazon/aws/utils/connection_wrapper.py +2 -5
- airflow/providers/amazon/aws/utils/mixins.py +1 -1
- airflow/providers/amazon/aws/utils/waiter.py +2 -2
- airflow/providers/amazon/aws/waiters/emr.json +6 -6
- airflow/providers/amazon/get_provider_info.py +19 -1
- airflow/providers/amazon/version_compat.py +19 -16
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/METADATA +25 -19
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/RECORD +79 -76
- apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses/NOTICE +5 -0
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.14.0.dist-info → apache_airflow_providers_amazon-9.18.0rc2.dist-info}/entry_points.txt +0 -0
- {airflow/providers/amazon → apache_airflow_providers_amazon-9.18.0rc2.dist-info/licenses}/LICENSE +0 -0
airflow/providers/amazon/aws/notifications/sns.py

@@ -21,6 +21,7 @@ from collections.abc import Sequence
 from functools import cached_property
 
 from airflow.providers.amazon.aws.hooks.sns import SnsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
 from airflow.providers.common.compat.notifier import BaseNotifier
 
 
@@ -60,8 +61,13 @@ class SnsNotifier(BaseNotifier):
         subject: str | None = None,
         message_attributes: dict | None = None,
         region_name: str | None = None,
+        **kwargs,
     ):
-        super().__init__()
+        if AIRFLOW_V_3_1_PLUS:
+            # Support for passing context was added in 3.1.0
+            super().__init__(**kwargs)
+        else:
+            super().__init__()
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
         self.target_arn = target_arn
@@ -83,5 +89,14 @@ class SnsNotifier(BaseNotifier):
             message_attributes=self.message_attributes,
         )
 
+    async def async_notify(self, context):
+        """Publish the notification message to Amazon SNS (async)."""
+        await self.hook.apublish_to_target(
+            target_arn=self.target_arn,
+            message=self.message,
+            subject=self.subject,
+            message_attributes=self.message_attributes,
+        )
+
 
 send_sns_notification = SnsNotifier
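A minimal usage sketch of the notifier (the connection id, topic ARN, and callback wiring below are illustrative, not taken from the package):

from airflow.providers.amazon.aws.notifications.sns import send_sns_notification

# Hypothetical failure callback: publishes to an SNS topic when a task fails.
# On Airflow 3.1+, extra keyword arguments are now forwarded to BaseNotifier.
notify_failure = send_sns_notification(
    aws_conn_id="aws_default",
    target_arn="arn:aws:sns:us-east-1:123456789012:airflow-alerts",  # illustrative ARN
    subject="Airflow alert",
    message="Task {{ ti.task_id }} in DAG {{ dag.dag_id }} failed.",
)

# Attach it e.g. as DAG(..., on_failure_callback=notify_failure).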
airflow/providers/amazon/aws/notifications/sqs.py

@@ -21,6 +21,7 @@ from collections.abc import Sequence
 from functools import cached_property
 
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS
 from airflow.providers.common.compat.notifier import BaseNotifier
 
 
@@ -64,8 +65,13 @@ class SqsNotifier(BaseNotifier):
         message_group_id: str | None = None,
         delay_seconds: int = 0,
         region_name: str | None = None,
+        **kwargs,
     ):
-        super().__init__()
+        if AIRFLOW_V_3_1_PLUS:
+            # Support for passing context was added in 3.1.0
+            super().__init__(**kwargs)
+        else:
+            super().__init__()
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
         self.queue_url = queue_url
@@ -89,5 +95,15 @@ class SqsNotifier(BaseNotifier):
             message_group_id=self.message_group_id,
         )
 
+    async def async_notify(self, context):
+        """Publish the notification message to Amazon SQS queue (async)."""
+        await self.hook.asend_message(
+            queue_url=self.queue_url,
+            message_body=self.message_body,
+            delay_seconds=self.delay_seconds,
+            message_attributes=self.message_attributes,
+            message_group_id=self.message_group_id,
+        )
+
 
 send_sqs_notification = SqsNotifier
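The new async_notify mirrors notify through the hook's async asend_message, so deferrable triggerers can await it directly. A sketch of direct use (queue URL and connection id are illustrative; inside Airflow the triggerer drives the coroutine for you):

import asyncio

from airflow.providers.amazon.aws.notifications.sqs import SqsNotifier

notifier = SqsNotifier(
    aws_conn_id="aws_default",
    queue_url="https://sqs.us-east-1.amazonaws.com/123456789012/airflow-events",  # illustrative
    message_body="dag run finished",
)

# Outside Airflow you could drive the coroutine manually:
asyncio.run(notifier.async_notify(context={}))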
airflow/providers/amazon/aws/operators/base_aws.py

@@ -25,8 +25,8 @@ from airflow.providers.amazon.aws.utils.mixins import (
     AwsHookType,
     aws_template_fields,
 )
-from airflow.providers.amazon.version_compat import BaseOperator
-from airflow.utils.types import NOTSET, ArgNotSet
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
+from airflow.providers.common.compat.sdk import BaseOperator
 
 
 class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]):
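Both NOTSET/ArgNotSet and BaseOperator now come from compat shims, so the same code imports cleanly on Airflow 2 and 3. For orientation, a toy subclass showing the pattern AwsBaseOperator supports (the operator name and bucket are made up for illustration):

from __future__ import annotations

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator


class ListBucketKeysOperator(AwsBaseOperator[S3Hook]):
    """Toy example: list keys in a bucket through the operator's managed hook."""

    aws_hook_class = S3Hook  # the mixin builds self.hook from this class

    def __init__(self, *, bucket: str, **kwargs):
        # kwargs carry aws_conn_id, region_name, verify, botocore_config, ...
        super().__init__(**kwargs)
        self.bucket = bucket

    def execute(self, context):
        # self.hook is an S3Hook configured with the AWS arguments above
        return self.hook.list_keys(bucket_name=self.bucket)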
airflow/providers/amazon/aws/operators/bedrock.py

@@ -482,6 +482,8 @@ class BedrockCreateKnowledgeBaseOperator(AwsBaseOperator[BedrockAgentHook]):
             # It may also be that permissions haven't even propagated yet to check for the index
             or "server returned 401" in error_message
             or "user does not have permissions" in error_message
+            or "status code: 403" in error_message
+            or "bad authorization" in error_message
         )
         if all(
             [
airflow/providers/amazon/aws/operators/cloud_formation.py

@@ -32,7 +32,7 @@ if TYPE_CHECKING:
 
 class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that creates a AWS CloudFormation stack.
+    An operator that creates an AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
@@ -68,7 +68,7 @@ class CloudFormationCreateStackOperator(AwsBaseOperator[CloudFormationHook]):
 
 class CloudFormationDeleteStackOperator(AwsBaseOperator[CloudFormationHook]):
     """
-    An operator that deletes a AWS CloudFormation stack.
+    An operator that deletes an AWS CloudFormation stack.
 
     .. seealso::
         For more information on how to use this operator, take a look at the guide:
airflow/providers/amazon/aws/operators/datasync.py

@@ -23,11 +23,12 @@ import random
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
-from airflow.exceptions import AirflowException, AirflowTaskTimeout
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
 from airflow.providers.amazon.aws.links.datasync import DataSyncTaskExecutionLink, DataSyncTaskLink
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.providers.common.compat.sdk import AirflowTaskTimeout
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
airflow/providers/amazon/aws/operators/emr.py

@@ -57,8 +57,8 @@ from airflow.providers.amazon.aws.utils.waiter import (
     waiter,
 )
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
 from airflow.utils.helpers import exactly_one, prune_dict
-from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -654,11 +654,10 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
     :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
     :param verify: Whether or not to verify SSL certificates. See:
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
-    :param wait_for_completion:
-        Whether to finish task immediately after creation (False) or wait for jobflow
+    :param wait_for_completion: Whether to finish task immediately after creation (False) or wait for jobflow
         completion (True)
         (default: None)
-    :param wait_policy: Whether to finish the task immediately after creation (None) or:
+    :param wait_policy: Deprecated. Use `wait_for_completion` instead. Whether to finish the task immediately after creation (None) or:
         - wait for the jobflow completion (WaitPolicy.WAIT_FOR_COMPLETION)
         - wait for the jobflow completion and cluster to terminate (WaitPolicy.WAIT_FOR_STEPS_COMPLETION)
         (default: None)
@@ -698,19 +697,29 @@ class EmrCreateJobFlowOperator(AwsBaseOperator[EmrHook]):
         super().__init__(**kwargs)
         self.emr_conn_id = emr_conn_id
         self.job_flow_overrides = job_flow_overrides or {}
-        self.wait_policy = wait_policy
+        self.wait_for_completion = wait_for_completion
         self.waiter_max_attempts = waiter_max_attempts or 60
         self.waiter_delay = waiter_delay or 60
         self.deferrable = deferrable
 
-        if wait_for_completion is not None:
+        if wait_policy is not None:
             warnings.warn(
-                "`wait_for_completion` parameter is deprecated, please use `wait_policy` instead.",
+                "`wait_policy` parameter is deprecated and will be removed in a future release; "
+                "please use `wait_for_completion` (bool) instead.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
-            self.wait_policy = WaitPolicy.WAIT_FOR_COMPLETION if wait_for_completion else None
-
+
+            if wait_for_completion is not None:
+                raise ValueError(
+                    "Cannot specify both `wait_for_completion` and deprecated `wait_policy`. "
+                    "Please use `wait_for_completion` (bool)."
+                )
+
+            self.wait_for_completion = wait_policy in (
+                WaitPolicy.WAIT_FOR_COMPLETION,
+                WaitPolicy.WAIT_FOR_STEPS_COMPLETION,
+            )
 
     @property
     def _hook_parameters(self):
@@ -748,30 +757,32 @@
                 job_flow_id=self._job_flow_id,
                 log_uri=get_log_uri(emr_client=self.hook.conn, job_flow_id=self._job_flow_id),
             )
-        if self.wait_policy:
-            waiter_name = WAITER_POLICY_NAME_MAPPING[self.wait_policy]
-            if self.deferrable:
-                self.defer(
-                    trigger=EmrCreateJobFlowTrigger(
-                        job_flow_id=self._job_flow_id,
-                        aws_conn_id=self.aws_conn_id,
-                        waiter_delay=self.waiter_delay,
-                        waiter_max_attempts=self.waiter_max_attempts,
-                    ),
-                    method_name="execute_complete",
-                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
-                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
-                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
-                )
-            else:
-                self.hook.get_waiter(waiter_name).wait(
-                    ClusterId=self._job_flow_id,
-                    WaiterConfig=prune_dict(
-                        {
-                            "Delay": self.waiter_delay,
-                            "MaxAttempts": self.waiter_max_attempts,
-                        }
-                    ),
-                )
+        if self.wait_for_completion:
+            waiter_name = WAITER_POLICY_NAME_MAPPING[WaitPolicy.WAIT_FOR_COMPLETION]
+
+            if self.deferrable:
+                self.defer(
+                    trigger=EmrCreateJobFlowTrigger(
+                        job_flow_id=self._job_flow_id,
+                        aws_conn_id=self.aws_conn_id,
+                        waiter_delay=self.waiter_delay,
+                        waiter_max_attempts=self.waiter_max_attempts,
+                    ),
+                    method_name="execute_complete",
+                    # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                    # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                    timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+                )
+            else:
+                self.hook.get_waiter(waiter_name).wait(
+                    ClusterId=self._job_flow_id,
+                    WaiterConfig=prune_dict(
+                        {
+                            "Delay": self.waiter_delay,
+                            "MaxAttempts": self.waiter_max_attempts,
+                        }
+                    ),
+                )
         return self._job_flow_id
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
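Net effect: waiting is now controlled by a plain boolean rather than the WaitPolicy enum, with wait_policy surviving only as a deprecated alias. A hedged usage sketch (JOB_FLOW_OVERRIDES is an assumed dict defined elsewhere in your DAG file):

from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

create_job_flow = EmrCreateJobFlowOperator(
    task_id="create_job_flow",
    job_flow_overrides=JOB_FLOW_OVERRIDES,  # assumed: your RunJobFlow request dict
    wait_for_completion=True,  # replaces wait_policy=WaitPolicy.WAIT_FOR_COMPLETION
    deferrable=True,           # hands the wait to a trigger instead of a worker slot
    waiter_delay=60,
    waiter_max_attempts=60,
)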
airflow/providers/amazon/aws/operators/mwaa.py

@@ -19,7 +19,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -46,12 +46,14 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
     :param trigger_run_id: The Run ID. This together with trigger_dag_id are a unique key. (templated)
     :param logical_date: The logical date (previously called execution date). This is the time or interval
         covered by this DAG run, according to the DAG definition. This together with trigger_dag_id are a
-        unique key. (templated)
+        unique key. This field is required if your environment is running with Airflow 3. (templated)
     :param data_interval_start: The beginning of the interval the DAG run covers
     :param data_interval_end: The end of the interval the DAG run covers
     :param conf: Additional configuration parameters. The value of this field can be set only when creating
         the object. (templated)
     :param note: Contains manually entered notes by the user about the DagRun. (templated)
+    :param airflow_version: The Airflow major version the MWAA environment runs.
+        This parameter is only used if the local web token method is used to call Airflow API. (templated)
 
     :param wait_for_completion: Whether to wait for DAG run to stop. (default: False)
     :param waiter_delay: Time in seconds to wait between status checks. (default: 120)
@@ -81,6 +83,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         "data_interval_end",
         "conf",
         "note",
+        "airflow_version",
     )
     template_fields_renderers = {"conf": "json"}
 
@@ -95,6 +98,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         data_interval_end: str | None = None,
         conf: dict | None = None,
         note: str | None = None,
+        airflow_version: Literal[2, 3] | None = None,
         wait_for_completion: bool = False,
         waiter_delay: int = 60,
         waiter_max_attempts: int = 20,
@@ -110,6 +114,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         self.data_interval_end = data_interval_end
         self.conf = conf if conf else {}
         self.note = note
+        self.airflow_version = airflow_version
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
@@ -123,7 +128,10 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
         dag_run_id = validated_event["dag_run_id"]
         self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id)
         return self.hook.invoke_rest_api(
-            env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", method="GET"
+            env_name=self.env_name,
+            path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}",
+            method="GET",
+            airflow_version=self.airflow_version,
         )
 
     def execute(self, context: Context) -> dict:
@@ -146,6 +154,7 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
                 "conf": self.conf,
                 "note": self.note,
             },
+            airflow_version=self.airflow_version,
        )
 
         dag_run_id = response["RestApiResponse"]["dag_run_id"]
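A hedged sketch of the new parameter (environment and DAG ids are illustrative); airflow_version=3 matters when the hook falls back to the local web token method and must target the Airflow 3 REST surface:

from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator

trigger_remote = MwaaTriggerDagRunOperator(
    task_id="trigger_remote_dag",
    env_name="my-mwaa-env",             # illustrative environment name
    trigger_dag_id="target_dag",
    logical_date="{{ logical_date }}",  # required when the remote environment runs Airflow 3
    airflow_version=3,
    wait_for_completion=True,
)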
airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py

@@ -33,7 +33,7 @@ from airflow.providers.amazon.aws.links.sagemaker_unified_studio import (
 from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import (
     SageMakerNotebookJobTrigger,
 )
-from airflow.providers.amazon.version_compat import BaseOperator
+from airflow.providers.common.compat.sdk import BaseOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
airflow/providers/amazon/aws/operators/ssm.py

@@ -20,7 +20,6 @@ from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ssm import SsmHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.ssm import SsmRunCommandTrigger
@@ -36,27 +35,35 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]):
     Executes the SSM Run Command to perform actions on managed instances.
 
     .. seealso::
-        For more information on how to use this operator, take a look at the guide:
+        For more information on how to use this operator, take a look at the
+        guide:
         :ref:`howto/operator:SsmRunCommandOperator`
 
-    :param document_name: The name of the Amazon Web Services Systems Manager document (SSM document) to run.
-    :param run_command_kwargs: Optional parameters to pass to the send_command API.
-
-    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
-    :param waiter_delay: Time in seconds to wait between status checks. (default: 120)
-    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75)
-    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
-        This implies waiting for completion. This mode requires aiobotocore module to be installed.
-        (default: False)
+    :param document_name: The name of the Amazon Web Services Systems Manager
+        document (SSM document) to run.
+    :param run_command_kwargs: Optional parameters to pass to the send_command
+        API.
+
+    :param wait_for_completion: Whether to wait for cluster to stop.
+        (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks.
+        (default: 120)
+    :param waiter_max_attempts: Maximum number of attempts to check for job
+        completion. (default: 75)
+    :param deferrable: If True, the operator will wait asynchronously for the
+        cluster to stop. This implies waiting for completion. This mode
+        requires aiobotocore module to be installed. (default: False)
     :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is ``None`` or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
+        If this is ``None`` or empty then the default boto3 behaviour is used.
+        If running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node).
-    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param region_name: AWS region_name. If not specified then the default
+        boto3 behaviour is used.
     :param verify: Whether or not to verify SSL certificates. See:
         https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
-    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+    :param botocore_config: Configuration dictionary (key-values) for botocore
+        client. See:
         https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
@@ -90,7 +97,7 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]):
         event = validate_execute_complete_event(event)
 
         if event["status"] != "success":
-            raise AirflowException(f"Error while running run command: {event}")
+            raise RuntimeError(f"Error while running run command: {event}")
 
         self.log.info("SSM run command `%s` completed.", event["command_id"])
         return event["command_id"]
@@ -112,6 +119,9 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]):
                     waiter_delay=self.waiter_delay,
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
+                    region_name=self.region_name,
+                    verify=self.verify,
+                    botocore_config=self.botocore_config,
                 ),
                 method_name="execute_complete",
             )
@@ -125,7 +135,102 @@ class SsmRunCommandOperator(AwsBaseOperator[SsmHook]):
         waiter.wait(
             CommandId=command_id,
             InstanceId=instance_id,
-            WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            WaiterConfig={
+                "Delay": self.waiter_delay,
+                "MaxAttempts": self.waiter_max_attempts,
+            },
         )
 
         return command_id
+
+
+class SsmGetCommandInvocationOperator(AwsBaseOperator[SsmHook]):
+    """
+    Retrieves the output and execution details of an SSM command invocation.
+
+    This operator allows you to fetch the standard output, standard error,
+    execution status, and other details from SSM commands. It can be used to
+    retrieve output from commands executed by SsmRunCommandOperator in previous
+    tasks, or from commands executed outside of Airflow entirely.
+
+    The operator returns structured data including stdout, stderr, execution
+    times, and status information for each instance that executed the command.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the
+        guide:
+        :ref:`howto/operator:SsmGetCommandInvocationOperator`
+
+    :param command_id: The ID of the SSM command to retrieve output for.
+    :param instance_id: The ID of the specific instance to retrieve output
+        for. If not provided, retrieves output from all instances that
+        executed the command.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used.
+        If running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default
+        boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore
+        client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = SsmHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "command_id",
+        "instance_id",
+    )
+
+    def __init__(
+        self,
+        *,
+        command_id: str,
+        instance_id: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.command_id = command_id
+        self.instance_id = instance_id
+
+    def execute(self, context: Context) -> dict[str, Any]:
+        """Execute the operator to retrieve command invocation output."""
+        if self.instance_id:
+            self.log.info(
+                "Retrieving output for command %s on instance %s",
+                self.command_id,
+                self.instance_id,
+            )
+            invocations = [{"InstanceId": self.instance_id}]
+        else:
+            self.log.info("Retrieving output for command %s from all instances", self.command_id)
+            response = self.hook.list_command_invocations(self.command_id)
+            invocations = response.get("CommandInvocations", [])
+
+        output_data: dict[str, Any] = {"command_id": self.command_id, "invocations": []}
+
+        for invocation in invocations:
+            instance_id = invocation["InstanceId"]
+            try:
+                invocation_details = self.hook.get_command_invocation(self.command_id, instance_id)
+                output_data["invocations"].append(
+                    {
+                        "instance_id": instance_id,
+                        "status": invocation_details.get("Status", ""),
+                        "response_code": invocation_details.get("ResponseCode", ""),
+                        "standard_output": invocation_details.get("StandardOutputContent", ""),
+                        "standard_error": invocation_details.get("StandardErrorContent", ""),
+                        "execution_start_time": invocation_details.get("ExecutionStartDateTime", ""),
+                        "execution_end_time": invocation_details.get("ExecutionEndDateTime", ""),
+                        "document_name": invocation_details.get("DocumentName", ""),
+                        "comment": invocation_details.get("Comment", ""),
+                    }
+                )
+            except Exception as e:
+                self.log.warning("Failed to get output for instance %s: %s", instance_id, e)
+                output_data["invocations"].append({"instance_id": instance_id, "error": str(e)})
+
+        return output_data
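A hedged sketch chaining the two operators (targets and parameters are illustrative): the run task returns the command id, which feeds the new retrieval operator via XCom:

from airflow.providers.amazon.aws.operators.ssm import (
    SsmGetCommandInvocationOperator,
    SsmRunCommandOperator,
)

run_command = SsmRunCommandOperator(
    task_id="run_uptime",
    document_name="AWS-RunShellScript",
    run_command_kwargs={  # forwarded to the SSM send_command API
        "Targets": [{"Key": "tag:env", "Values": ["staging"]}],  # illustrative target
        "Parameters": {"commands": ["uptime"]},
    },
)

collect_output = SsmGetCommandInvocationOperator(
    task_id="collect_output",
    command_id=run_command.output,  # XComArg: the command id returned above
)

run_command >> collect_output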
airflow/providers/amazon/aws/secrets/secrets_manager.py

@@ -187,10 +187,9 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
         }
 
         for conn_field, extra_words in self.extra_conn_words.items():
-            if conn_field == "user":
-                # Support `user` for backwards compatibility.
-                conn_field = "login"
-            possible_words_for_conn_fields[conn_field].extend(extra_words)
+            # Support `user` for backwards compatibility.
+            conn_field_backcompat = "login" if conn_field == "user" else conn_field
+            possible_words_for_conn_fields[conn_field_backcompat].extend(extra_words)
 
         conn_d: dict[str, Any] = {}
         for conn_field, possible_words in possible_words_for_conn_fields.items():
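For reference, a sketch of the configuration this loop consumes (values are illustrative): extra_conn_words extends the field-name synonyms the backend recognizes inside a secret, and the legacy "user" key still lands on the "login" field, now without mutating the loop variable:

# Illustrative value for backend_kwargs under [secrets] in airflow.cfg:
backend_kwargs = {
    "connections_prefix": "airflow/connections",
    "extra_conn_words": {
        "user": ["username", "db_user"],  # folded into the "login" field by the shim above
        "host": ["endpoint", "server"],
    },
}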
airflow/providers/amazon/aws/sensors/base_aws.py

@@ -25,8 +25,8 @@ from airflow.providers.amazon.aws.utils.mixins import (
     AwsHookType,
     aws_template_fields,
 )
-from airflow.providers.amazon.version_compat import BaseSensorOperator
-from airflow.utils.types import NOTSET, ArgNotSet
+from airflow.providers.amazon.version_compat import NOTSET, ArgNotSet
+from airflow.providers.common.compat.sdk import BaseSensorOperator
 
 
 class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
airflow/providers/amazon/aws/sensors/mwaa.py

@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -51,6 +51,8 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
     :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
         AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param airflow_version: The Airflow major version the MWAA environment runs.
+        This parameter is only used if the local web token method is used to call Airflow API. (templated)
     :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
         module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
@@ -75,6 +77,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         "external_dag_run_id",
         "success_states",
         "failure_states",
+        "airflow_version",
         "deferrable",
         "max_retries",
         "poke_interval",
@@ -88,6 +91,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         external_dag_run_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        airflow_version: Literal[2, 3] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         poke_interval: int = 60,
         max_retries: int = 720,
@@ -104,6 +108,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         self.external_env_name = external_env_name
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
+        self.airflow_version = airflow_version
         self.deferrable = deferrable
         self.poke_interval = poke_interval
         self.max_retries = max_retries
@@ -119,6 +124,7 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
             env_name=self.external_env_name,
             path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}",
             method="GET",
+            airflow_version=self.airflow_version,
         )
 
         # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
@@ -179,6 +185,8 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
         ``{airflow.utils.state.TaskInstanceState.SUCCESS}`` (templated)
     :param failure_states: Collection of task instance states that would make this task marked as failed and raise an
         AirflowException, default is ``{airflow.utils.state.TaskInstanceState.FAILED}`` (templated)
+    :param airflow_version: The Airflow major version the MWAA environment runs.
+        This parameter is only used if the local web token method is used to call Airflow API. (templated)
     :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
         module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
@@ -204,6 +212,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
         "external_task_id",
         "success_states",
         "failure_states",
+        "airflow_version",
         "deferrable",
         "max_retries",
         "poke_interval",
@@ -218,6 +227,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
         external_task_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        airflow_version: Literal[2, 3] | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         poke_interval: int = 60,
         max_retries: int = 720,
@@ -235,6 +245,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
         self.external_task_id = external_task_id
+        self.airflow_version = airflow_version
         self.deferrable = deferrable
         self.poke_interval = poke_interval
         self.max_retries = max_retries
@@ -252,6 +263,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
             env_name=self.external_env_name,
             path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}/taskInstances/{self.external_task_id}",
             method="GET",
+            airflow_version=self.airflow_version,
         )
         # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
         # happened in the API and KeyError would be raised
@@ -278,6 +290,7 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
             env_name=self.external_env_name,
             path=f"/dags/{self.external_dag_id}/dagRuns",
             method="GET",
+            airflow_version=self.airflow_version,
         )
         self.external_dag_run_id = response["RestApiResponse"]["dag_runs"][-1]["dag_run_id"]
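A hedged sketch of the sensor with the new parameter (environment, DAG, and run ids are illustrative; the deferrable path requires aiobotocore):

from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

wait_for_remote = MwaaDagRunSensor(
    task_id="wait_for_remote_dag",
    external_env_name="my-mwaa-env",                          # illustrative
    external_dag_id="target_dag",
    external_dag_run_id="manual__2025-01-01T00:00:00+00:00",  # illustrative run id
    airflow_version=3,  # forwarded to every invoke_rest_api call
    deferrable=True,
)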