apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (62)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
  3. airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
  4. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
  5. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
  7. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  8. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
  9. airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
  10. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
  11. airflow/providers/amazon/aws/hooks/athena.py +15 -2
  12. airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
  13. airflow/providers/amazon/aws/hooks/emr.py +6 -0
  14. airflow/providers/amazon/aws/hooks/logs.py +85 -1
  15. airflow/providers/amazon/aws/hooks/neptune.py +85 -0
  16. airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
  17. airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
  18. airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
  19. airflow/providers/amazon/aws/hooks/s3.py +4 -6
  20. airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
  21. airflow/providers/amazon/aws/links/emr.py +122 -2
  22. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  23. airflow/providers/amazon/aws/operators/athena.py +4 -1
  24. airflow/providers/amazon/aws/operators/batch.py +5 -6
  25. airflow/providers/amazon/aws/operators/ecs.py +6 -2
  26. airflow/providers/amazon/aws/operators/eks.py +31 -26
  27. airflow/providers/amazon/aws/operators/emr.py +192 -26
  28. airflow/providers/amazon/aws/operators/glue.py +5 -2
  29. airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
  30. airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
  31. airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
  32. airflow/providers/amazon/aws/operators/neptune.py +218 -0
  33. airflow/providers/amazon/aws/operators/rds.py +21 -12
  34. airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
  35. airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
  36. airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
  37. airflow/providers/amazon/aws/operators/step_function.py +4 -1
  38. airflow/providers/amazon/aws/sensors/batch.py +2 -2
  39. airflow/providers/amazon/aws/sensors/ec2.py +4 -2
  40. airflow/providers/amazon/aws/sensors/emr.py +13 -6
  41. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
  42. airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
  43. airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
  44. airflow/providers/amazon/aws/sensors/s3.py +3 -0
  45. airflow/providers/amazon/aws/sensors/sqs.py +4 -1
  46. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  47. airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
  48. airflow/providers/amazon/aws/triggers/neptune.py +115 -0
  49. airflow/providers/amazon/aws/triggers/rds.py +9 -7
  50. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
  51. airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
  52. airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
  53. airflow/providers/amazon/aws/utils/__init__.py +10 -0
  54. airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
  55. airflow/providers/amazon/aws/utils/mixins.py +5 -1
  56. airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
  57. airflow/providers/amazon/aws/waiters/neptune.json +85 -0
  58. airflow/providers/amazon/get_provider_info.py +26 -2
  59. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
  60. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
  61. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
  62. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/neptune.py (new file)
@@ -0,0 +1,218 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, Any, Sequence
+
+ from airflow.configuration import conf
+ from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+ from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+ from airflow.providers.amazon.aws.triggers.neptune import (
+     NeptuneClusterAvailableTrigger,
+     NeptuneClusterStoppedTrigger,
+ )
+ from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+ if TYPE_CHECKING:
+     from airflow.utils.context import Context
+
+
+ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
+     """Starts an Amazon Neptune DB cluster.
+
+     Amazon Neptune Database is a serverless graph database designed for superior scalability
+     and availability. Neptune Database provides built-in security, continuous backups, and
+     integrations with other AWS services
+
+     .. seealso::
+         For more information on how to use this operator, take a look at the guide:
+         :ref:`howto/operator:NeptuneStartDbClusterOperator`
+
+     :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be started.
+     :param wait_for_completion: Whether to wait for the cluster to start. (default: True)
+     :param deferrable: If True, the operator will wait asynchronously for the cluster to start.
+         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+         (default: False)
+     :param waiter_delay: Time in seconds to wait between status checks.
+     :param waiter_max_attempts: Maximum number of attempts to check for job completion.
+     :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+     :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+
+     :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+         https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+     :return: dictionary with Neptune cluster id
+     """
+
+     aws_hook_class = NeptuneHook
+     template_fields: Sequence[str] = aws_template_fields("cluster_id")
+
+     def __init__(
+         self,
+         db_cluster_id: str,
+         wait_for_completion: bool = True,
+         waiter_delay: int = 30,
+         waiter_max_attempts: int = 60,
+         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.cluster_id = db_cluster_id
+         self.wait_for_completion = wait_for_completion
+         self.deferrable = deferrable
+         self.delay = waiter_delay
+         self.max_attempts = waiter_max_attempts
+
+     def execute(self, context: Context) -> dict[str, str]:
+         self.log.info("Starting Neptune cluster: %s", self.cluster_id)
+
+         # Check to make sure the cluster is not already available.
+         status = self.hook.get_cluster_status(self.cluster_id)
+         if status.lower() in NeptuneHook.AVAILABLE_STATES:
+             self.log.info("Neptune cluster %s is already available.", self.cluster_id)
+             return {"db_cluster_id": self.cluster_id}
+
+         resp = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
+         status = resp.get("DBClusters", {}).get("Status", "Unknown")
+
+         if self.deferrable:
+             self.log.info("Deferring for cluster start: %s", self.cluster_id)
+
+             self.defer(
+                 trigger=NeptuneClusterAvailableTrigger(
+                     aws_conn_id=self.aws_conn_id,
+                     db_cluster_id=self.cluster_id,
+                     waiter_delay=self.delay,
+                     waiter_max_attempts=self.max_attempts,
+                 ),
+                 method_name="execute_complete",
+             )
+
+         elif self.wait_for_completion:
+             self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
+             self.hook.wait_for_cluster_availability(self.cluster_id, self.delay, self.max_attempts)
+
+         return {"db_cluster_id": self.cluster_id}
+
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+         status = ""
+         cluster_id = ""
+
+         if event:
+             status = event.get("status", "")
+             cluster_id = event.get("cluster_id", "")
+
+         self.log.info("Neptune cluster %s available with status: %s", cluster_id, status)
+
+         return {"db_cluster_id": cluster_id}
+
+
+ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
+     """
+     Stops an Amazon Neptune DB cluster.
+
+     Amazon Neptune Database is a serverless graph database designed for superior scalability
+     and availability. Neptune Database provides built-in security, continuous backups, and
+     integrations with other AWS services
+
+     .. seealso::
+         For more information on how to use this operator, take a look at the guide:
+         :ref:`howto/operator:NeptuneStartDbClusterOperator`
+
+     :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be stopped.
+     :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
+     :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
+         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+         (default: False)
+     :param waiter_delay: Time in seconds to wait between status checks.
+     :param waiter_max_attempts: Maximum number of attempts to check for job completion.
+     :param aws_conn_id: The Airflow connection used for AWS credentials.
+         If this is ``None`` or empty then the default boto3 behaviour is used. If
+         running Airflow in a distributed manner and aws_conn_id is None or
+         empty, then default boto3 configuration would be used (and must be
+         maintained on each worker node).
+     :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+     :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+         https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+     :return: dictionary with Neptune cluster id
+     """
+
+     aws_hook_class = NeptuneHook
+     template_fields: Sequence[str] = aws_template_fields("cluster_id")
+
+     def __init__(
+         self,
+         db_cluster_id: str,
+         wait_for_completion: bool = True,
+         waiter_delay: int = 30,
+         waiter_max_attempts: int = 60,
+         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.cluster_id = db_cluster_id
+         self.wait_for_completion = wait_for_completion
+         self.deferrable = deferrable
+         self.delay = waiter_delay
+         self.max_attempts = waiter_max_attempts
+
+     def execute(self, context: Context) -> dict[str, str]:
+         self.log.info("Stopping Neptune cluster: %s", self.cluster_id)
+
+         # Check to make sure the cluster is not already stopped.
+         status = self.hook.get_cluster_status(self.cluster_id)
+         if status.lower() in NeptuneHook.STOPPED_STATES:
+             self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
+             return {"db_cluster_id": self.cluster_id}
+
+         resp = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+         status = resp.get("DBClusters", {}).get("Status", "Unknown")
+
+         if self.deferrable:
+             self.log.info("Deferring for cluster stop: %s", self.cluster_id)
+
+             self.defer(
+                 trigger=NeptuneClusterStoppedTrigger(
+                     aws_conn_id=self.aws_conn_id,
+                     db_cluster_id=self.cluster_id,
+                     waiter_delay=self.delay,
+                     waiter_max_attempts=self.max_attempts,
+                 ),
+                 method_name="execute_complete",
+             )
+
+         elif self.wait_for_completion:
+             self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
+             self.hook.wait_for_cluster_stopped(self.cluster_id, self.delay, self.max_attempts)
+
+         return {"db_cluster_id": self.cluster_id}
+
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
+         status = ""
+         cluster_id = ""
+
+         if event:
+             status = event.get("status", "")
+             cluster_id = event.get("cluster_id", "")
+
+         self.log.info("Neptune cluster %s stopped with status: %s", cluster_id, status)
+
+         return {"db_cluster_id": cluster_id}
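The two operators added above can be wired into a DAG like any other AWS operator. The following is a minimal usage sketch, not taken from the provider's own examples; the DAG id, schedule, and cluster identifier are illustrative placeholders:

    from __future__ import annotations

    import datetime

    from airflow.models.dag import DAG
    from airflow.providers.amazon.aws.operators.neptune import (
        NeptuneStartDbClusterOperator,
        NeptuneStopDbClusterOperator,
    )

    with DAG(
        dag_id="example_neptune_cluster",  # placeholder DAG id
        start_date=datetime.datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # Start the cluster and block until it reports an available state.
        start_cluster = NeptuneStartDbClusterOperator(
            task_id="start_neptune_cluster",
            db_cluster_id="my-neptune-cluster",  # placeholder cluster identifier
        )

        # Stop the cluster asynchronously; deferrable mode requires aiobotocore.
        stop_cluster = NeptuneStopDbClusterOperator(
            task_id="stop_neptune_cluster",
            db_cluster_id="my-neptune-cluster",
            deferrable=True,
        )

        start_cluster >> stop_cluster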
airflow/providers/amazon/aws/operators/rds.py
@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.rds import (
      RdsDbDeletedTrigger,
      RdsDbStoppedTrigger,
  )
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.rds import RdsDbType
  from airflow.providers.amazon.aws.utils.tags import format_tags
  from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -637,11 +638,13 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
              )
          return json.dumps(create_db_instance, default=str)

-     def execute_complete(self, context, event=None) -> str:
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"DB instance creation failed: {event}")
-         else:
-             return json.dumps(event["response"], default=str)
+
+         return json.dumps(event["response"], default=str)


  class RdsDeleteDbInstanceOperator(RdsBaseOperator):
@@ -720,11 +723,13 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
              )
          return json.dumps(delete_db_instance, default=str)

-     def execute_complete(self, context, event=None) -> str:
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"DB instance deletion failed: {event}")
-         else:
-             return json.dumps(event["response"], default=str)
+
+         return json.dumps(event["response"], default=str)


  class RdsStartDbOperator(RdsBaseOperator):
@@ -786,10 +791,12 @@ class RdsStartDbOperator(RdsBaseOperator):
          return json.dumps(start_db_response, default=str)

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-         if event is None or event["status"] != "success":
+         event = validate_execute_complete_event(event)
+
+         if event["status"] != "success":
              raise AirflowException(f"Failed to start DB: {event}")
-         else:
-             return json.dumps(event["response"], default=str)
+
+         return json.dumps(event["response"], default=str)

      def _start_db(self):
          self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
@@ -883,10 +890,12 @@ class RdsStopDbOperator(RdsBaseOperator):
          return json.dumps(stop_db_response, default=str)

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-         if event is None or event["status"] != "success":
+         event = validate_execute_complete_event(event)
+
+         if event["status"] != "success":
              raise AirflowException(f"Failed to start DB: {event}")
-         else:
-             return json.dumps(event["response"], default=str)
+
+         return json.dumps(event["response"], default=str)

      def _stop_db(self):
          self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -31,6 +31,7 @@ from airflow.providers.amazon.aws.triggers.redshift_cluster import (
      RedshiftPauseClusterTrigger,
      RedshiftResumeClusterTrigger,
  )
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event

  if TYPE_CHECKING:
      from airflow.utils.context import Context
@@ -314,10 +315,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
          self.log.info("Created Redshift cluster %s", self.cluster_identifier)
          self.log.info(cluster)

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error creating cluster: {event}")
-         return


  class RedshiftCreateClusterSnapshotOperator(BaseOperator):
@@ -409,12 +411,13 @@ class RedshiftCreateClusterSnapshotOperator(BaseOperator):
                  },
              )

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error creating snapshot: {event}")
-         else:
-             self.log.info("Cluster snapshot created.")
-         return
+
+         self.log.info("Cluster snapshot created.")


  class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
@@ -569,10 +572,7 @@ class RedshiftResumeClusterOperator(BaseOperator):
              )

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-         if event is None:
-             err_msg = "Trigger error: event is None"
-             self.log.info(err_msg)
-             raise AirflowException(err_msg)
+         event = validate_execute_complete_event(event)

          if event["status"] != "success":
              raise AirflowException(f"Error resuming cluster: {event}")
@@ -659,10 +659,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
              )

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-         if event is None:
-             err_msg = "Trigger error: event is None"
-             self.log.info(err_msg)
-             raise AirflowException(err_msg)
+         event = validate_execute_complete_event(event)

          if event["status"] != "success":
              raise AirflowException(f"Error pausing cluster: {event}")
@@ -767,10 +764,7 @@ class RedshiftDeleteClusterOperator(BaseOperator):
              )

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-         if event is None:
-             err_msg = "Trigger error: event is None"
-             self.log.info(err_msg)
-             raise AirflowException(err_msg)
+         event = validate_execute_complete_event(event)

          if event["status"] != "success":
              raise AirflowException(f"Error deleting cluster: {event}")
airflow/providers/amazon/aws/operators/redshift_data.py
@@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
  from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
  from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
@@ -170,10 +171,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
      def execute_complete(
          self, context: Context, event: dict[str, Any] | None = None
      ) -> GetStatementResultResponseTypeDef | str:
-         if event is None:
-             err_msg = "Trigger error: event is None"
-             self.log.info(err_msg)
-             raise AirflowException(err_msg)
+         event = validate_execute_complete_event(event)

          if event["status"] == "error":
              msg = f"context: {context}, error message: {event['message']}"
airflow/providers/amazon/aws/operators/sagemaker.py
@@ -29,12 +29,17 @@ from airflow.configuration import conf
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
  from airflow.models import BaseOperator
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
- from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+ from airflow.providers.amazon.aws.hooks.sagemaker import (
+     LogState,
+     SageMakerHook,
+     secondary_training_status_message,
+ )
  from airflow.providers.amazon.aws.triggers.sagemaker import (
      SageMakerPipelineTrigger,
+     SageMakerTrainingPrintLogTrigger,
      SageMakerTrigger,
  )
- from airflow.providers.amazon.aws.utils import trim_none_values
+ from airflow.providers.amazon.aws.utils import trim_none_values, validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
  from airflow.providers.amazon.aws.utils.tags import format_tags
  from airflow.utils.helpers import prune_dict
@@ -310,11 +315,13 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
          self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
          return {"Processing": self.serialized_job}

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error while running job: {event}")
-         else:
-             self.log.info(event["message"])
+
+         self.log.info(event["message"])
          self.serialized_job = serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))
          self.log.info("%s completed successfully.", self.task_id)
          return {"Processing": self.serialized_job}
@@ -561,7 +568,9 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
              "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
          }

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error while running job: {event}")
          endpoint_info = self.config.get("Endpoint", self.config)
@@ -744,10 +753,7 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
          return self.serialize_result()

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
-         if event is None:
-             err_msg = "Trigger error: event is None"
-             self.log.error(err_msg)
-             raise AirflowException(err_msg)
+         event = validate_execute_complete_event(event)

          self.log.info(event["message"])
          return self.serialize_result()
@@ -899,9 +905,11 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
                      aws_conn_id=self.aws_conn_id,
                  ),
                  method_name="execute_complete",
-                 timeout=datetime.timedelta(seconds=self.max_ingestion_time)
-                 if self.max_ingestion_time is not None
-                 else None,
+                 timeout=(
+                     datetime.timedelta(seconds=self.max_ingestion_time)
+                     if self.max_ingestion_time is not None
+                     else None
+                 ),
              )
              description = {}  # never executed but makes static checkers happy
          elif self.wait_for_completion:
@@ -917,7 +925,9 @@ class SageMakerTuningOperator(SageMakerBaseOperator):

          return {"Tuning": serialize(description)}

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error while running job: {event}")
          return {
@@ -1085,28 +1095,77 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
              raise AirflowException(f"Sagemaker Training Job creation failed: {response}")

          if self.deferrable and self.wait_for_completion:
-             self.defer(
-                 timeout=self.execution_timeout,
-                 trigger=SageMakerTrigger(
+             description = self.hook.describe_training_job(self.config["TrainingJobName"])
+             status = description["TrainingJobStatus"]
+
+             if self.print_log:
+                 instance_count = description["ResourceConfig"]["InstanceCount"]
+                 last_describe_job_call = time.monotonic()
+                 job_already_completed = status not in self.hook.non_terminal_states
+                 _, description, last_describe_job_call = self.hook.describe_training_job_with_log(
+                     self.config["TrainingJobName"],
+                     {},
+                     [],
+                     instance_count,
+                     LogState.COMPLETE if job_already_completed else LogState.TAILING,
+                     description,
+                     last_describe_job_call,
+                 )
+                 self.log.info(secondary_training_status_message(description, None))
+
+             if status in self.hook.failed_states:
+                 reason = description.get("FailureReason", "(No reason provided)")
+                 raise AirflowException(f"SageMaker job failed because {reason}")
+             elif status == "Completed":
+                 log_message = f"{self.task_id} completed successfully."
+                 if self.print_log:
+                     billable_seconds = SageMakerHook.count_billable_seconds(
+                         training_start_time=description["TrainingStartTime"],
+                         training_end_time=description["TrainingEndTime"],
+                         instance_count=instance_count,
+                     )
+                     log_message = f"Billable seconds: {billable_seconds}\n{log_message}"
+                 self.log.info(log_message)
+                 return {"Training": serialize(description)}
+
+             timeout = self.execution_timeout
+             if self.max_ingestion_time:
+                 timeout = datetime.timedelta(seconds=self.max_ingestion_time)
+
+             trigger: SageMakerTrainingPrintLogTrigger | SageMakerTrigger
+             if self.print_log:
+                 trigger = SageMakerTrainingPrintLogTrigger(
+                     job_name=self.config["TrainingJobName"],
+                     poke_interval=self.check_interval,
+                     aws_conn_id=self.aws_conn_id,
+                 )
+             else:
+                 trigger = SageMakerTrigger(
                      job_name=self.config["TrainingJobName"],
                      job_type="Training",
                      poke_interval=self.check_interval,
                      max_attempts=self.max_attempts,
                      aws_conn_id=self.aws_conn_id,
-                 ),
+                 )
+
+             self.defer(
+                 timeout=timeout,
+                 trigger=trigger,
                  method_name="execute_complete",
              )

-         self.serialized_training_data = serialize(
-             self.hook.describe_training_job(self.config["TrainingJobName"])
-         )
-         return {"Training": self.serialized_training_data}
+         return self.serialize_result()
+
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
+         event = validate_execute_complete_event(event)

-     def execute_complete(self, context, event=None):
          if event["status"] != "success":
              raise AirflowException(f"Error while running job: {event}")
-         else:
-             self.log.info(event["message"])
+
+         self.log.info(event["message"])
+         return self.serialize_result()
+
+     def serialize_result(self) -> dict[str, dict]:
          self.serialized_training_data = serialize(
              self.hook.describe_training_job(self.config["TrainingJobName"])
          )
@@ -1237,7 +1296,9 @@ class SageMakerStartPipelineOperator(SageMakerBaseOperator):
          return arn

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-         if event is None or event["status"] != "success":
+         event = validate_execute_complete_event(event)
+
+         if event["status"] != "success":
              raise AirflowException(f"Failure during pipeline execution: {event}")
          return event["value"]

@@ -1330,12 +1391,14 @@ class SageMakerStopPipelineOperator(SageMakerBaseOperator):
          return status

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
-         if event is None or event["status"] != "success":
+         event = validate_execute_complete_event(event)
+
+         if event["status"] != "success":
              raise AirflowException(f"Failure during pipeline execution: {event}")
-         else:
-             # theoretically we should do a `describe` call to know this,
-             # but if we reach this point, this is the only possible status
-             return "Stopped"
+
+         # theoretically we should do a `describe` call to know this,
+         # but if we reach this point, this is the only possible status
+         return "Stopped"


  class SageMakerRegisterModelVersionOperator(SageMakerBaseOperator):
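The SageMakerTrainingOperator change above lets a deferrable task keep streaming training logs by selecting the new SageMakerTrainingPrintLogTrigger when print_log is enabled. A brief usage sketch follows; the training configuration is a placeholder, not part of this diff:

    from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator

    # Placeholder: supply a full CreateTrainingJob request body here.
    TRAINING_CONFIG: dict = {}

    train_model = SageMakerTrainingOperator(
        task_id="train_model",
        config=TRAINING_CONFIG,
        wait_for_completion=True,
        deferrable=True,    # frees the worker slot while the job runs
        print_log=True,     # with deferrable=True, picks SageMakerTrainingPrintLogTrigger
        check_interval=60,  # passed to the trigger as poke_interval
    )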
airflow/providers/amazon/aws/operators/step_function.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.links.step_function import (
  )
  from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
  from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
@@ -129,7 +130,9 @@ class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
          return execution_arn

      def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-         if event is None or event["status"] != "success":
+         event = validate_execute_complete_event(event)
+
+         if event["status"] != "success":
              raise AirflowException(f"Trigger error: event is {event}")

          self.log.info("State Machine execution completed successfully")
airflow/providers/amazon/aws/sensors/batch.py
@@ -60,8 +60,8 @@ class BatchSensor(BaseSensorOperator):
          aws_conn_id: str = "aws_default",
          region_name: str | None = None,
          deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
-         poke_interval: float = 5,
-         max_retries: int = 5,
+         poke_interval: float = 30,
+         max_retries: int = 4200,
          **kwargs,
      ):
          super().__init__(**kwargs)
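The BatchSensor change above raises the defaults from a 5-second poll with 5 retries to a 30-second poll with up to 4200 retries (roughly 35 hours of waiting). DAGs that want the previous, tighter behaviour can pass the old values explicitly; a small sketch with placeholder task and job ids:

    from airflow.providers.amazon.aws.sensors.batch import BatchSensor

    wait_for_batch_job = BatchSensor(
        task_id="wait_for_batch_job",                   # placeholder task id
        job_id="00000000-0000-0000-0000-000000000000",  # placeholder AWS Batch job id
        deferrable=True,
        poke_interval=5,  # pre-8.18.0 default
        max_retries=5,    # pre-8.18.0 default
    )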
airflow/providers/amazon/aws/sensors/ec2.py
@@ -24,6 +24,7 @@ from airflow.configuration import conf
  from airflow.exceptions import AirflowException, AirflowSkipException
  from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
  from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.sensors.base import BaseSensorOperator

  if TYPE_CHECKING:
@@ -92,11 +93,12 @@ class EC2InstanceStateSensor(BaseSensorOperator):
          self.log.info("instance state: %s", instance_state)
          return instance_state == self.target_state

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
              message = f"Error: {event}"
              if self.soft_fail:
                  raise AirflowSkipException(message)
              raise AirflowException(message)
-         return