apache-airflow-providers-amazon 9.2.0rc2__py3-none-any.whl → 9.4.0rc1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- airflow/providers/amazon/LICENSE +0 -52
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +1 -4
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +90 -106
- airflow/providers/amazon/aws/auth_manager/router/login.py +124 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +2 -2
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/utils.py +2 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +6 -1
- airflow/providers/amazon/aws/hooks/batch_client.py +1 -2
- airflow/providers/amazon/aws/hooks/ecr.py +7 -1
- airflow/providers/amazon/aws/hooks/ecs.py +1 -2
- airflow/providers/amazon/aws/hooks/eks.py +10 -3
- airflow/providers/amazon/aws/hooks/emr.py +20 -0
- airflow/providers/amazon/aws/hooks/mwaa.py +85 -0
- airflow/providers/amazon/aws/hooks/sqs.py +4 -0
- airflow/providers/amazon/aws/hooks/ssm.py +10 -1
- airflow/providers/amazon/aws/links/comprehend.py +41 -0
- airflow/providers/amazon/aws/links/datasync.py +37 -0
- airflow/providers/amazon/aws/links/ec2.py +46 -0
- airflow/providers/amazon/aws/links/sagemaker.py +27 -0
- airflow/providers/amazon/aws/operators/athena.py +7 -5
- airflow/providers/amazon/aws/operators/batch.py +16 -8
- airflow/providers/amazon/aws/operators/bedrock.py +20 -18
- airflow/providers/amazon/aws/operators/comprehend.py +52 -11
- airflow/providers/amazon/aws/operators/datasync.py +40 -2
- airflow/providers/amazon/aws/operators/dms.py +0 -4
- airflow/providers/amazon/aws/operators/ec2.py +50 -0
- airflow/providers/amazon/aws/operators/ecs.py +11 -7
- airflow/providers/amazon/aws/operators/eks.py +17 -17
- airflow/providers/amazon/aws/operators/emr.py +27 -27
- airflow/providers/amazon/aws/operators/glue.py +16 -14
- airflow/providers/amazon/aws/operators/glue_crawler.py +3 -3
- airflow/providers/amazon/aws/operators/glue_databrew.py +5 -5
- airflow/providers/amazon/aws/operators/kinesis_analytics.py +9 -9
- airflow/providers/amazon/aws/operators/lambda_function.py +4 -4
- airflow/providers/amazon/aws/operators/mwaa.py +109 -0
- airflow/providers/amazon/aws/operators/rds.py +16 -16
- airflow/providers/amazon/aws/operators/redshift_cluster.py +15 -15
- airflow/providers/amazon/aws/operators/redshift_data.py +4 -4
- airflow/providers/amazon/aws/operators/sagemaker.py +52 -29
- airflow/providers/amazon/aws/operators/sqs.py +6 -0
- airflow/providers/amazon/aws/operators/step_function.py +4 -4
- airflow/providers/amazon/aws/sensors/ec2.py +3 -3
- airflow/providers/amazon/aws/sensors/emr.py +9 -9
- airflow/providers/amazon/aws/sensors/glue.py +7 -7
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +3 -3
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +3 -3
- airflow/providers/amazon/aws/sensors/sqs.py +6 -5
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +8 -3
- airflow/providers/amazon/aws/triggers/README.md +1 -1
- airflow/providers/amazon/aws/triggers/opensearch_serverless.py +2 -1
- airflow/providers/amazon/aws/triggers/sqs.py +2 -1
- airflow/providers/amazon/aws/utils/sqs.py +6 -4
- airflow/providers/amazon/aws/waiters/dms.json +12 -0
- airflow/providers/amazon/get_provider_info.py +106 -87
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/METADATA +16 -34
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/RECORD +61 -55
- airflow/providers/amazon/aws/auth_manager/views/auth.py +0 -151
- /airflow/providers/amazon/aws/auth_manager/{views → router}/__init__.py +0 -0
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.2.0rc2.dist-info → apache_airflow_providers_amazon-9.4.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/glue.py

@@ -87,6 +87,7 @@ class GlueJobOperator(BaseOperator):
         "script_location",
         "script_args",
         "create_job_kwargs",
+        "run_job_kwargs",
         "s3_bucket",
         "iam_role_name",
         "iam_role_arn",
@@ -95,6 +96,7 @@ class GlueJobOperator(BaseOperator):
     template_fields_renderers = {
         "script_args": "json",
         "create_job_kwargs": "json",
+        "run_job_kwargs": "json",
     }
     ui_color = "#ededed"

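With "run_job_kwargs" now included in the templated fields and rendered as JSON, Jinja expressions can be used inside the extra job-run arguments. A minimal usage sketch (the task id, job name, script location, and argument values below are illustrative, not taken from this diff):

    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    submit_glue_job = GlueJobOperator(
        task_id="submit_glue_job",                    # hypothetical task id
        job_name="my-glue-job",                       # hypothetical Glue job name
        script_location="s3://my-bucket/etl.py",      # hypothetical script path
        run_job_kwargs={
            "Arguments": {"--run_date": "{{ ds }}"},  # rendered by Jinja before the run is started
            "WorkerType": "G.1X",
            "NumberOfWorkers": 2,
        },
    )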
@@ -242,11 +244,11 @@ class GlueJobOperator(BaseOperator):
         return self._job_run_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error in glue job: {event}")
-        return event["value"]
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error in glue job: {validated_event}")
+        return validated_event["value"]

     def on_kill(self):
         """Cancel the running AWS Glue Job."""
@@ -500,18 +502,18 @@ class GlueDataQualityRuleSetEvaluationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
         return evaluation_run_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error: AWS Glue data quality ruleset evaluation run: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality ruleset evaluation run: {validated_event}")

         self.hook.validate_evaluation_run_results(
-            evaluation_run_id=event["evaluation_run_id"],
+            evaluation_run_id=validated_event["evaluation_run_id"],
             show_results=self.show_results,
             verify_result_status=self.verify_result_status,
         )

-        return event["evaluation_run_id"]
+        return validated_event["evaluation_run_id"]


 class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
@@ -648,12 +650,12 @@ class GlueDataQualityRuleRecommendationRunOperator(AwsBaseOperator[GlueDataQualityHook]):
         return recommendation_run_id

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error: AWS Glue data quality rule recommendation run: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error: AWS Glue data quality rule recommendation run: {validated_event}")

         if self.show_results:
-            self.hook.log_recommendation_results(run_id=event["recommendation_run_id"])
+            self.hook.log_recommendation_results(run_id=validated_event["recommendation_run_id"])

-        return event["recommendation_run_id"]
+        return validated_event["recommendation_run_id"]
airflow/providers/amazon/aws/operators/glue_crawler.py

@@ -117,8 +117,8 @@ class GlueCrawlerOperator(AwsBaseOperator[GlueCrawlerHook]):
         return crawler_name

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error in glue crawl: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error in glue crawl: {validated_event}")
         return self.config["Name"]
airflow/providers/amazon/aws/operators/glue_databrew.py

@@ -129,13 +129,13 @@ class GlueDataBrewStartJobOperator(AwsBaseOperator[GlueDataBrewHook]):
         return {"run_id": run_id}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException("Error while running AWS Glue DataBrew job: %s", event)
+        if validated_event["status"] != "success":
+            raise AirflowException("Error while running AWS Glue DataBrew job: %s", validated_event)

-        run_id = event.get("run_id", "")
-        status = event.get("status", "")
+        run_id = validated_event.get("run_id", "")
+        status = validated_event.get("status", "")

         self.log.info("AWS Glue DataBrew runID: %s completed with status: %s", run_id, status)

airflow/providers/amazon/aws/operators/kinesis_analytics.py

@@ -215,20 +215,20 @@ class KinesisAnalyticsV2StartApplicationOperator(AwsBaseOperator[KinesisAnalyticsV2Hook]):
         return {"ApplicationARN": describe_response["ApplicationDetail"]["ApplicationARN"]}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
+        if validated_event["status"] != "success":
             raise AirflowException(
-                "Error while starting AWS Managed Service for Apache Flink application: %s", event
+                "Error while starting AWS Managed Service for Apache Flink application: %s", validated_event
             )

         response = self.hook.conn.describe_application(
-            ApplicationName=event["application_name"],
+            ApplicationName=validated_event["application_name"],
         )

         self.log.info(
             "AWS Managed Service for Apache Flink application %s started successfully.",
-            event["application_name"],
+            validated_event["application_name"],
         )

         return {"ApplicationARN": response["ApplicationDetail"]["ApplicationARN"]}
@@ -332,18 +332,18 @@ class KinesisAnalyticsV2StopApplicationOperator(AwsBaseOperator[KinesisAnalyticsV2Hook]):
         return {"ApplicationARN": describe_response["ApplicationDetail"]["ApplicationARN"]}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, Any]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
+        if validated_event["status"] != "success":
             raise AirflowException("Error while stopping AWS Managed Service for Apache Flink application")

         response = self.hook.conn.describe_application(
-            ApplicationName=event["application_name"],
+            ApplicationName=validated_event["application_name"],
         )

         self.log.info(
             "AWS Managed Service for Apache Flink application %s stopped successfully.",
-            event["application_name"],
+            validated_event["application_name"],
         )

         return {"ApplicationARN": response["ApplicationDetail"]["ApplicationARN"]}
airflow/providers/amazon/aws/operators/lambda_function.py

@@ -145,13 +145,13 @@ class LambdaCreateFunctionOperator(AwsBaseOperator[LambdaHook]):
         return response.get("FunctionArn")

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if not event or event["status"] != "success":
-            raise AirflowException(f"Trigger error: event is {event}")
+        if not validated_event or validated_event["status"] != "success":
+            raise AirflowException(f"Trigger error: event is {validated_event}")

         self.log.info("Lambda function created successfully")
-        return event["function_arn"]
+        return validated_event["function_arn"]


 class LambdaInvokeFunctionOperator(AwsBaseOperator[LambdaHook]):
airflow/providers/amazon/aws/operators/mwaa.py (new file)

@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains AWS MWAA operators."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
+    """
+    Trigger a Dag Run for a Dag in an Amazon MWAA environment.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:MwaaTriggerDagRunOperator`
+
+    :param env_name: The MWAA environment name (templated)
+    :param trigger_dag_id: The ID of the DAG to be triggered (templated)
+    :param trigger_run_id: The Run ID. This together with trigger_dag_id are a unique key. (templated)
+    :param logical_date: The logical date (previously called execution date). This is the time or interval
+        covered by this DAG run, according to the DAG definition. This together with trigger_dag_id are a
+        unique key. (templated)
+    :param data_interval_start: The beginning of the interval the DAG run covers
+    :param data_interval_end: The end of the interval the DAG run covers
+    :param conf: Additional configuration parameters. The value of this field can be set only when creating
+        the object. (templated)
+    :param note: Contains manually entered notes by the user about the DagRun. (templated)
+    """
+
+    aws_hook_class = MwaaHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "env_name",
+        "trigger_dag_id",
+        "trigger_run_id",
+        "logical_date",
+        "data_interval_start",
+        "data_interval_end",
+        "conf",
+        "note",
+    )
+    template_fields_renderers = {"conf": "json"}
+
+    def __init__(
+        self,
+        *,
+        env_name: str,
+        trigger_dag_id: str,
+        trigger_run_id: str | None = None,
+        logical_date: str | None = None,
+        data_interval_start: str | None = None,
+        data_interval_end: str | None = None,
+        conf: dict | None = None,
+        note: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.env_name = env_name
+        self.trigger_dag_id = trigger_dag_id
+        self.trigger_run_id = trigger_run_id
+        self.logical_date = logical_date
+        self.data_interval_start = data_interval_start
+        self.data_interval_end = data_interval_end
+        self.conf = conf if conf else {}
+        self.note = note
+
+    def execute(self, context: Context) -> dict:
+        """
+        Trigger a Dag Run for the Dag in the Amazon MWAA environment.
+
+        :param context: the Context object
+        :return: dict with information about the Dag run
+            For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api`
+        """
+        return self.hook.invoke_rest_api(
+            env_name=self.env_name,
+            path=f"/dags/{self.trigger_dag_id}/dagRuns",
+            method="POST",
+            body={
+                "dag_run_id": self.trigger_run_id,
+                "logical_date": self.logical_date,
+                "data_interval_start": self.data_interval_start,
+                "data_interval_end": self.data_interval_end,
+                "conf": self.conf,
+                "note": self.note,
+            },
+        )
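The new MwaaTriggerDagRunOperator wraps the MWAA InvokeRestApi call shown above. A minimal DAG sketch, assuming an existing MWAA environment and a target DAG id (both names below are illustrative):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator

    with DAG(dag_id="trigger_remote_dag", start_date=datetime(2025, 1, 1), schedule=None):
        trigger_remote = MwaaTriggerDagRunOperator(
            task_id="trigger_remote",
            env_name="my-mwaa-environment",          # hypothetical MWAA environment name
            trigger_dag_id="remote_processing_dag",  # hypothetical DAG id in that environment
            conf={"source": "local-airflow"},        # rendered via the "conf" JSON template field
        )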
airflow/providers/amazon/aws/operators/rds.py

@@ -627,12 +627,12 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
         return json.dumps(create_db_instance, default=str)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"DB instance creation failed: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"DB instance creation failed: {validated_event}")

-        return json.dumps(event["response"], default=str)
+        return json.dumps(validated_event["response"], default=str)


 class RdsDeleteDbInstanceOperator(RdsBaseOperator):
@@ -712,12 +712,12 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
         return json.dumps(delete_db_instance, default=str)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"DB instance deletion failed: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"DB instance deletion failed: {validated_event}")

-        return json.dumps(event["response"], default=str)
+        return json.dumps(validated_event["response"], default=str)


 class RdsStartDbOperator(RdsBaseOperator):
@@ -779,12 +779,12 @@ class RdsStartDbOperator(RdsBaseOperator):
         return json.dumps(start_db_response, default=str)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Failed to start DB: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Failed to start DB: {validated_event}")

-        return json.dumps(event["response"], default=str)
+        return json.dumps(validated_event["response"], default=str)

     def _start_db(self):
         self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
@@ -891,12 +891,12 @@ class RdsStopDbOperator(RdsBaseOperator):
         return json.dumps(stop_db_response, default=str)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Failed to start DB: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Failed to start DB: {validated_event}")

-        return json.dumps(event["response"], default=str)
+        return json.dumps(validated_event["response"], default=str)

     def _stop_db(self):
         self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -321,10 +321,10 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info(cluster)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error creating cluster: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error creating cluster: {validated_event}")


 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
@@ -417,10 +417,10 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator):
         )

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error creating snapshot: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error creating snapshot: {validated_event}")

         self.log.info("Cluster snapshot created.")

@@ -577,10 +577,10 @@ class RedshiftResumeClusterOperator(BaseOperator):
         )

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error resuming cluster: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error resuming cluster: {validated_event}")
         self.log.info("Resumed cluster successfully")

@@ -683,10 +683,10 @@ class RedshiftPauseClusterOperator(BaseOperator):
         )

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error pausing cluster: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error pausing cluster: {validated_event}")
         self.log.info("Paused cluster successfully")

@@ -792,9 +792,9 @@ class RedshiftDeleteClusterOperator(BaseOperator):
         )

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error deleting cluster: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error deleting cluster: {validated_event}")

         self.log.info("Cluster deleted successfully")
airflow/providers/amazon/aws/operators/redshift_data.py

@@ -185,13 +185,13 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
     def execute_complete(
         self, context: Context, event: dict[str, Any] | None = None
     ) -> list[GetStatementResultResponseTypeDef] | list[str]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] == "error":
-            msg = f"context: {context}, error message: {event['message']}"
+        if validated_event["status"] == "error":
+            msg = f"context: {context}, error message: {validated_event['message']}"
             raise AirflowException(msg)

-        statement_id = event["statement_id"]
+        statement_id = validated_event["statement_id"]
         if not statement_id:
             raise AirflowException("statement_id should not be empty.")

airflow/providers/amazon/aws/operators/sagemaker.py

@@ -19,6 +19,7 @@ from __future__ import annotations
 import datetime
 import json
 import time
+import urllib
 from collections.abc import Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable, ClassVar
@@ -34,6 +35,7 @@ from airflow.providers.amazon.aws.hooks.sagemaker import (
     SageMakerHook,
     secondary_training_status_message,
 )
+from airflow.providers.amazon.aws.links.sagemaker import SageMakerTransformJobLink
 from airflow.providers.amazon.aws.triggers.sagemaker import (
     SageMakerPipelineTrigger,
     SageMakerTrigger,
@@ -165,7 +167,11 @@ class SageMakerBaseOperator(BaseOperator):
             if fail_if_exists:
                 raise AirflowException(f"A SageMaker {resource_type} with name {name} already exists.")
             else:
-                name = f"{proposed_name}-{time.time_ns() // 1000000000}"
+                max_name_len = 63
+                timestamp = str(
+                    time.time_ns() // 1000000000
+                )  # only keep the relevant datetime (first 10 digits)
+                name = f"{proposed_name[:max_name_len - len(timestamp) - 1]}-{timestamp}"  # we subtract one to make provision for the dash between the truncated name and timestamp
             self.log.info("Changed %s name to '%s' to avoid collision.", resource_type, name)
         return name

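The new else-branch keeps the generated SageMaker resource name within 63 characters: the 10-digit second-resolution timestamp plus the joining dash leaves at most 52 characters of the proposed name. A standalone sketch of the same arithmetic (not the operator code itself):

    import time

    def unique_name(proposed_name: str, max_name_len: int = 63) -> str:
        timestamp = str(time.time_ns() // 1_000_000_000)  # seconds since the epoch, currently 10 digits
        # Reserve room for the timestamp and the dash, then truncate the proposed name.
        return f"{proposed_name[: max_name_len - len(timestamp) - 1]}-{timestamp}"

    print(len(unique_name("a" * 80)))  # 52 + 1 + 10 = 63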
@@ -346,13 +352,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
         return {"Processing": self.serialized_job}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

-        self.log.info(event["message"])
-        self.serialized_job = serialize(self.hook.describe_processing_job(event["job_name"]))
+        self.log.info(validated_event["message"])
+        self.serialized_job = serialize(self.hook.describe_processing_job(validated_event["job_name"]))
         self.log.info("%s completed successfully.", self.task_id)
         return {"Processing": self.serialized_job}

@@ -599,12 +605,12 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
         }

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

-        response = self.hook.describe_endpoint(event["job_name"])
+        response = self.hook.describe_endpoint(validated_event["job_name"])
         return {
             "EndpointConfig": serialize(self.hook.describe_endpoint_config(response["EndpointConfigName"])),
             "Endpoint": serialize(self.hook.describe_endpoint(response["EndpointName"])),
@@ -655,6 +661,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :return Dict: Returns The ARN of the model created in Amazon SageMaker.
     """

+    operator_extra_links = (SageMakerTransformJobLink(),)
+
     def __init__(
         self,
         *,
@@ -761,6 +769,21 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker transform Job creation failed: {response}")

+        transform_job_url = SageMakerTransformJobLink.format_str.format(
+            aws_domain=SageMakerTransformJobLink.get_aws_domain(self.hook.conn_partition),
+            region_name=self.hook.conn_region_name,
+            job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""),
+        )
+        SageMakerTransformJobLink.persist(
+            context=context,
+            operator=self,
+            region_name=self.hook.conn_region_name,
+            aws_partition=self.hook.conn_partition,
+            job_name=urllib.parse.quote(transform_config["TransformJobName"], safe=""),
+        )
+
+        self.log.info("You can monitor this SageMaker Transform job at %s", transform_job_url)
+
         if self.deferrable and self.wait_for_completion:
             response = self.hook.describe_transform_job(transform_config["TransformJobName"])
             status = response["TransformJobStatus"]
@@ -804,10 +827,10 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
         return self._check_if_resource_exists(model_name, "model", describe_func)

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        self.log.info(event["message"])
-        return self.serialize_result(event["job_name"])
+        self.log.info(validated_event["message"])
+        return self.serialize_result(validated_event["job_name"])

     def serialize_result(self, job_name: str) -> dict[str, dict]:
         job_description = self.hook.describe_transform_job(job_name)
@@ -976,11 +999,11 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
         return {"Tuning": serialize(description)}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
-        return {"Tuning": serialize(self.hook.describe_tuning_job(event["job_name"]))}
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")
+        return {"Tuning": serialize(self.hook.describe_tuning_job(validated_event["job_name"]))}


 class SageMakerModelOperator(SageMakerBaseOperator):
@@ -1188,13 +1211,13 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
         return self.serialize_result(self.config["TrainingJobName"])

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Error while running job: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running job: {validated_event}")

-        self.log.info(event["message"])
-        return self.serialize_result(event["job_name"])
+        self.log.info(validated_event["message"])
+        return self.serialize_result(validated_event["job_name"])

     def serialize_result(self, job_name: str) -> dict[str, dict]:
         self.serialized_training_data = serialize(self.hook.describe_training_job(job_name))
@@ -1330,11 +1353,11 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
         return arn

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Failure during pipeline execution: {event}")
-        return event["value"]
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {validated_event}")
+        return validated_event["value"]


 class SageMakerStopPipelineOperator(SageMakerBaseOperator):
@@ -1425,10 +1448,10 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
         return status

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-        event = validate_execute_complete_event(event)
+        validated_event = validate_execute_complete_event(event)

-        if event["status"] != "success":
-            raise AirflowException(f"Failure during pipeline execution: {event}")
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Failure during pipeline execution: {validated_event}")

         # theoretically we should do a `describe` call to know this,
         # but if we reach this point, this is the only possible status
airflow/providers/amazon/aws/operators/sqs.py

@@ -44,6 +44,8 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
     :param delay_seconds: message delay (templated) (default: 1 second)
     :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None)
         For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
+    :param message_deduplication_id: This applies only to FIFO (first-in-first-out) queues.
+        For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
     :param aws_conn_id: The Airflow connection used for AWS credentials.
         If this is ``None`` or empty then the default boto3 behaviour is used. If
         running Airflow in a distributed manner and aws_conn_id is None or
@@ -63,6 +65,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
         "delay_seconds",
         "message_attributes",
         "message_group_id",
+        "message_deduplication_id",
     )
     template_fields_renderers = {"message_attributes": "json"}
     ui_color = "#6ad3fa"
@@ -75,6 +78,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
         message_attributes: dict | None = None,
         delay_seconds: int = 0,
         message_group_id: str | None = None,
+        message_deduplication_id: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -83,6 +87,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
         self.delay_seconds = delay_seconds
         self.message_attributes = message_attributes or {}
         self.message_group_id = message_group_id
+        self.message_deduplication_id = message_deduplication_id

     def execute(self, context: Context) -> dict:
         """
@@ -98,6 +103,7 @@ class SqsPublishOperator(AwsBaseOperator[SqsHook]):
             delay_seconds=self.delay_seconds,
             message_attributes=self.message_attributes,
             message_group_id=self.message_group_id,
+            message_deduplication_id=self.message_deduplication_id,
         )

         self.log.info("send_message result: %s", result)
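With message_deduplication_id now supported (and templated like message_group_id), duplicate suppression on FIFO queues can be keyed per task run, within SQS's deduplication window. A minimal sketch for a FIFO queue (the queue URL and id values are illustrative); note that for FIFO queues the group id is always required, while the deduplication id is only needed when content-based deduplication is disabled on the queue.

    from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator

    publish_order_event = SqsPublishOperator(
        task_id="publish_order_event",
        sqs_queue="https://sqs.us-east-1.amazonaws.com/123456789012/orders.fifo",  # hypothetical queue URL
        message_content="order batch for {{ ds }}",
        message_group_id="orders",
        message_deduplication_id="orders-{{ ds_nodash }}",  # templated: one message per logical date
    )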