apache-airflow-providers-amazon 8.22.0rc1__py3-none-any.whl → 8.23.0rc1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (25)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/executors/batch/batch_executor.py +47 -3
  3. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -0
  4. airflow/providers/amazon/aws/hooks/bedrock.py +20 -0
  5. airflow/providers/amazon/aws/hooks/comprehend.py +37 -0
  6. airflow/providers/amazon/aws/hooks/neptune.py +36 -1
  7. airflow/providers/amazon/aws/operators/athena.py +1 -1
  8. airflow/providers/amazon/aws/operators/batch.py +1 -3
  9. airflow/providers/amazon/aws/operators/bedrock.py +218 -2
  10. airflow/providers/amazon/aws/operators/comprehend.py +192 -0
  11. airflow/providers/amazon/aws/operators/emr.py +21 -11
  12. airflow/providers/amazon/aws/operators/neptune.py +128 -21
  13. airflow/providers/amazon/aws/operators/sagemaker.py +10 -14
  14. airflow/providers/amazon/aws/sensors/comprehend.py +147 -0
  15. airflow/providers/amazon/aws/sensors/emr.py +8 -0
  16. airflow/providers/amazon/aws/triggers/comprehend.py +61 -0
  17. airflow/providers/amazon/aws/triggers/neptune.py +45 -0
  18. airflow/providers/amazon/aws/triggers/sagemaker.py +1 -1
  19. airflow/providers/amazon/aws/utils/__init__.py +7 -0
  20. airflow/providers/amazon/aws/waiters/comprehend.json +49 -0
  21. airflow/providers/amazon/get_provider_info.py +25 -1
  22. {apache_airflow_providers_amazon-8.22.0rc1.dist-info → apache_airflow_providers_amazon-8.23.0rc1.dist-info}/METADATA +6 -6
  23. {apache_airflow_providers_amazon-8.22.0rc1.dist-info → apache_airflow_providers_amazon-8.23.0rc1.dist-info}/RECORD +25 -20
  24. {apache_airflow_providers_amazon-8.22.0rc1.dist-info → apache_airflow_providers_amazon-8.23.0rc1.dist-info}/WHEEL +0 -0
  25. {apache_airflow_providers_amazon-8.22.0rc1.dist-info → apache_airflow_providers_amazon-8.23.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py

@@ -617,6 +617,14 @@ class EmrContainerOperator(BaseOperator):
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_polling_attempts,
+                )
+                if self.max_polling_attempts
+                else EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
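With this change, waiter_max_attempts is forwarded to EmrContainerTrigger only when max_polling_attempts is set; otherwise the trigger keeps its own default. A minimal, hypothetical usage sketch (cluster, role, and job driver values below are placeholders, not taken from this diff):

from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator

run_emr_containers_job = EmrContainerOperator(
    task_id="run_emr_containers_job",
    name="sample-spark-job",  # placeholder
    virtual_cluster_id="vc-1234567890abcdef0",  # placeholder
    execution_role_arn="arn:aws:iam::111122223333:role/emr-eks-role",  # placeholder
    release_label="emr-6.15.0-latest",  # placeholder
    job_driver={"sparkSubmitJobDriver": {"entryPoint": "s3://my-bucket/app.py"}},  # placeholder
    deferrable=True,
    poll_interval=30,
    max_polling_attempts=120,  # now also passed to the trigger as waiter_max_attempts
)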
@@ -734,10 +742,20 @@ class EmrCreateJobFlowOperator(BaseOperator):
         waiter_max_attempts: int | None = None,
         waiter_delay: int | None = None,
         waiter_countdown: int | None = None,
-        waiter_check_interval_seconds: int = 60,
+        waiter_check_interval_seconds: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ):
+        if waiter_check_interval_seconds:
+            warnings.warn(
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
+                "future this will default to None and defer to the waiter's default value.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            waiter_check_interval_seconds = 60
         if waiter_countdown:
             warnings.warn(
                 "The parameter waiter_countdown has been deprecated to standardize "
@@ -749,15 +767,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
             # waiter_countdown defaults to never timing out, which is not supported
             # by boto waiters, so we will set it here to "a very long time" for now.
             waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
-        if waiter_check_interval_seconds:
-            warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
-                "future this will default to None and defer to the waiter's default value.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.aws_conn_id = aws_conn_id
         self.emr_conn_id = emr_conn_id
@@ -765,7 +775,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
         self.region_name = region_name
         self.wait_for_completion = wait_for_completion
         self.waiter_max_attempts = waiter_max_attempts or 60
-        self.waiter_delay = waiter_delay or 30
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
         self.deferrable = deferrable

     @cached_property
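Net effect of the three hunks above: waiter_check_interval_seconds and waiter_countdown still work but emit deprecation warnings, and waiter_delay now falls back to waiter_check_interval_seconds (then 60) instead of 30. A short migration sketch, assuming a hypothetical JOB_FLOW_OVERRIDES cluster definition:

from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

create_emr_cluster = EmrCreateJobFlowOperator(
    task_id="create_emr_cluster",
    job_flow_overrides=JOB_FLOW_OVERRIDES,  # placeholder cluster spec
    wait_for_completion=True,
    waiter_delay=60,         # preferred over the deprecated waiter_check_interval_seconds
    waiter_max_attempts=60,  # preferred over the deprecated waiter_countdown
)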
airflow/providers/amazon/aws/operators/neptune.py

@@ -19,11 +19,15 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, Any, Sequence

+from botocore.exceptions import ClientError
+
 from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.neptune import (
     NeptuneClusterAvailableTrigger,
+    NeptuneClusterInstancesAvailableTrigger,
     NeptuneClusterStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -32,6 +36,50 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context


+def handle_waitable_exception(
+    operator: NeptuneStartDbClusterOperator | NeptuneStopDbClusterOperator, err: str
+):
+    """
+    Handle client exceptions for invalid cluster or invalid instance status that are temporary.
+
+    After status change, it's possible to retry. Waiter will handle terminal status.
+    """
+    code = err
+
+    if code in ("InvalidDBInstanceStateFault", "InvalidDBInstanceState"):
+        if operator.deferrable:
+            operator.log.info("Deferring until instances become available: %s", operator.cluster_id)
+            operator.defer(
+                trigger=NeptuneClusterInstancesAvailableTrigger(
+                    aws_conn_id=operator.aws_conn_id,
+                    db_cluster_id=operator.cluster_id,
+                    region_name=operator.region_name,
+                    botocore_config=operator.botocore_config,
+                    verify=operator.verify,
+                ),
+                method_name="execute",
+            )
+        else:
+            operator.log.info("Need to wait for instances to become available: %s", operator.cluster_id)
+            operator.hook.wait_for_cluster_instance_availability(cluster_id=operator.cluster_id)
+    if code in ["InvalidClusterState", "InvalidDBClusterStateFault"]:
+        if operator.deferrable:
+            operator.log.info("Deferring until cluster becomes available: %s", operator.cluster_id)
+            operator.defer(
+                trigger=NeptuneClusterAvailableTrigger(
+                    aws_conn_id=operator.aws_conn_id,
+                    db_cluster_id=operator.cluster_id,
+                    region_name=operator.region_name,
+                    botocore_config=operator.botocore_config,
+                    verify=operator.verify,
+                ),
+                method_name="execute",
+            )
+        else:
+            operator.log.info("Need to wait for cluster to become available: %s", operator.cluster_id)
+            operator.hook.wait_for_cluster_availability(operator.cluster_id)
+
+
 class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
     """Starts an Amazon Neptune DB cluster.

@@ -78,10 +126,10 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         self.cluster_id = db_cluster_id
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
-        self.delay = waiter_delay
-        self.max_attempts = waiter_max_attempts
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts

-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
         self.log.info("Starting Neptune cluster: %s", self.cluster_id)

         # Check to make sure the cluster is not already available.
@@ -89,9 +137,32 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         if status.lower() in NeptuneHook.AVAILABLE_STATES:
             self.log.info("Neptune cluster %s is already available.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-        resp = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
-        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to start the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the start request.
+        This loop covers the case where the cluster is not available and also the case where
+        the cluster is available, but one or more of the instances are in an invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+        try:
+            self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to start the cluster: %s", code)
+
+            if code in ["InvalidDBInstanceState", "InvalidClusterState", "InvalidDBClusterStateFault"]:
+                handle_waitable_exception(operator=self, err=code)
+
+            else:
+                # re raise for any other type of client error
+                raise

         if self.deferrable:
             self.log.info("Deferring for cluster start: %s", self.cluster_id)
@@ -100,15 +171,17 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
                 trigger=NeptuneClusterAvailableTrigger(
                     aws_conn_id=self.aws_conn_id,
                     db_cluster_id=self.cluster_id,
-                    waiter_delay=self.delay,
-                    waiter_max_attempts=self.max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
                 ),
                 method_name="execute_complete",
             )

         elif self.wait_for_completion:
             self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
-            self.hook.wait_for_cluster_availability(self.cluster_id, self.delay, self.max_attempts)
+            self.hook.wait_for_cluster_availability(
+                self.cluster_id, self.waiter_delay, self.waiter_max_attempts
+            )

         return {"db_cluster_id": self.cluster_id}

@@ -171,20 +244,53 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         self.cluster_id = db_cluster_id
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
-        self.delay = waiter_delay
-        self.max_attempts = waiter_max_attempts
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts

-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
         self.log.info("Stopping Neptune cluster: %s", self.cluster_id)

-        # Check to make sure the cluster is not already stopped.
+        # Check to make sure the cluster is not already stopped or that its not in a bad state
         status = self.hook.get_cluster_status(self.cluster_id)
+        self.log.info("Current status: %s", status)
+
         if status.lower() in NeptuneHook.STOPPED_STATES:
             self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-        resp = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
-        status = resp.get("DBClusters", {}).get("Status", "Unknown")
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to stop the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the stop request.
+        This loop covers the case where the cluster is not available and also the case where
+        the cluster is available, but one or more of the instances are in an invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+
+        try:
+            self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+
+        # cluster must be in available state to stop it
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to stop the cluster: %s", code)
+
+            # these can be handled by a waiter
+            if code in [
+                "InvalidDBInstanceState",
+                "InvalidDBInstanceStateFault",
+                "InvalidClusterState",
+                "InvalidDBClusterStateFault",
+            ]:
+                handle_waitable_exception(self, code)
+            else:
+                # re raise for any other type of client error
+                raise

         if self.deferrable:
             self.log.info("Deferring for cluster stop: %s", self.cluster_id)
@@ -193,22 +299,23 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
                 trigger=NeptuneClusterStoppedTrigger(
                     aws_conn_id=self.aws_conn_id,
                     db_cluster_id=self.cluster_id,
-                    waiter_delay=self.delay,
-                    waiter_max_attempts=self.max_attempts,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
                 ),
                 method_name="execute_complete",
             )

         elif self.wait_for_completion:
-            self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
-            self.hook.wait_for_cluster_stopped(self.cluster_id, self.delay, self.max_attempts)
+            self.log.info("Waiting for Neptune cluster %s to stop.", self.cluster_id)
+
+            self.hook.wait_for_cluster_stopped(self.cluster_id, self.waiter_delay, self.waiter_max_attempts)

         return {"db_cluster_id": self.cluster_id}

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
         status = ""
         cluster_id = ""
-
+        self.log.info(event)
         if event:
             status = event.get("status", "")
             cluster_id = event.get("cluster_id", "")
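The reworked Neptune operators above retry through waiters or triggers when the cluster or its instances are in a transient invalid state. A hedged usage sketch; the cluster identifier is a placeholder and the default AWS connection is assumed:

from airflow.providers.amazon.aws.operators.neptune import (
    NeptuneStartDbClusterOperator,
    NeptuneStopDbClusterOperator,
)

start_neptune = NeptuneStartDbClusterOperator(
    task_id="start_neptune_cluster",
    db_cluster_id="my-neptune-cluster",  # placeholder
    deferrable=True,  # defers via NeptuneClusterAvailableTrigger
    waiter_delay=30,
    waiter_max_attempts=60,
)

stop_neptune = NeptuneStopDbClusterOperator(
    task_id="stop_neptune_cluster",
    db_cluster_id="my-neptune-cluster",  # placeholder
    wait_for_completion=True,  # blocks on wait_for_cluster_stopped()
)

start_neptune >> stop_neptune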
airflow/providers/amazon/aws/operators/sagemaker.py

@@ -750,20 +750,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )

-        return self.serialize_result()
+        return self.serialize_result(transform_config["TransformJobName"])

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)

         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])

-    def serialize_result(self) -> dict[str, dict]:
-        transform_config = self.config.get("Transform", self.config)
-        self.serialized_model = serialize(self.hook.describe_model(transform_config["ModelName"]))
-        self.serialized_transform = serialize(
-            self.hook.describe_transform_job(transform_config["TransformJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        job_description = self.hook.describe_transform_job(job_name)
+        self.serialized_model = serialize(self.hook.describe_model(job_description["ModelName"]))
+        self.serialized_transform = serialize(job_description)
         return {"Model": self.serialized_model, "Transform": self.serialized_transform}

     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
@@ -1154,7 +1152,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )

-        return self.serialize_result()
+        return self.serialize_result(self.config["TrainingJobName"])

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
@@ -1163,12 +1161,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")

         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])

-    def serialize_result(self) -> dict[str, dict]:
-        self.serialized_training_data = serialize(
-            self.hook.describe_training_job(self.config["TrainingJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        self.serialized_training_data = serialize(self.hook.describe_training_job(job_name))
         return {"Training": self.serialized_training_data}

     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
airflow/providers/amazon/aws/sensors/comprehend.py (new file)

@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    General sensor behavior for Amazon Comprehend.
+
+    Subclasses must implement following methods:
+        - ``get_state()``
+
+    Subclasses must set the following fields:
+        - ``INTERMEDIATE_STATES``
+        - ``FAILURE_STATES``
+        - ``SUCCESS_STATES``
+        - ``FAILURE_MESSAGE``
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        state = self.get_state()
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
+
+    @abc.abstractmethod
+    def get_state(self) -> str:
+        """Implement in subclasses."""
+
+
+class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor):
+    """
+    Poll the state of the pii entities detection job until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor`
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("IN_PROGRESS",)
+    FAILURE_STATES: tuple[str, ...] = ("FAILED", "STOP_REQUESTED", "STOPPED")
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETED",)
+    FAILURE_MESSAGE = "Comprehend start pii entities detection job sensor failed."
+
+    template_fields: Sequence[str] = aws_template_fields("job_id")
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+                    job_id=self.job_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def get_state(self) -> str:
+        return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
+            "PiiEntitiesDetectionJobProperties"
+        ]["JobStatus"]
airflow/providers/amazon/aws/sensors/emr.py

@@ -354,6 +354,14 @@ class EmrContainerSensor(BaseSensorOperator):
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_retries,
+                )
+                if self.max_retries
+                else EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
airflow/providers/amazon/aws/triggers/comprehend.py (new file)

@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Comprehend pii entities detection job is complete.
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"job_id": job_id},
+            waiter_name="pii_entities_detection_job_complete",
+            waiter_args={"JobId": job_id},
+            failure_message="Comprehend start pii entities detection job failed.",
+            status_message="Status of Comprehend start pii entities detection job is",
+            status_queries=["PiiEntitiesDetectionJobProperties.JobStatus"],
+            return_key="job_id",
+            return_value=job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return ComprehendHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/triggers/neptune.py

@@ -113,3 +113,48 @@ class NeptuneClusterStoppedTrigger(AwsBaseWaiterTrigger):
             verify=self.verify,
             config=self.botocore_config,
         )
+
+
+class NeptuneClusterInstancesAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster Instance is available.
+
+    :param db_cluster_id: Cluster ID to wait on instances from
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="db_instance_available",
+            waiter_args={"Filters": [{"Name": "db-cluster-id", "Values": [db_cluster_id]}]},
+            failure_message="Failed to start Neptune instances",
+            status_message="Status of Neptune instances are",
+            status_queries=["DBInstances[].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
airflow/providers/amazon/aws/triggers/sagemaker.py

@@ -121,7 +121,7 @@ class SageMakerTrigger(BaseTrigger):
             status_message=f"{self.job_type} job not done yet",
             status_args=[self._get_response_status_key(self.job_type)],
         )
-        yield TriggerEvent({"status": "success", "message": "Job completed."})
+        yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})


 class SageMakerPipelineTrigger(BaseTrigger):
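Adding job_name to the event is what lets the SageMaker operators above call serialize_result(event["job_name"]) in execute_complete. A sketch of the payload the trigger now yields; the job name below is illustrative only:

event = {"status": "success", "message": "Job completed.", "job_name": "my-training-job"}  # illustrative
# execute_complete() passes event["job_name"] to serialize_result(), which then calls
# describe_training_job() or describe_transform_job() for that job.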
airflow/providers/amazon/aws/utils/__init__.py

@@ -22,6 +22,8 @@ from datetime import datetime, timezone
 from enum import Enum
 from typing import Any

+import importlib_metadata
+
 from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
@@ -74,6 +76,11 @@ def get_airflow_version() -> tuple[int, ...]:
     return tuple(int(x) for x in match.groups())


+def get_botocore_version() -> tuple[int, ...]:
+    """Return the version number of the installed botocore package in the form of a tuple[int,...]."""
+    return tuple(map(int, importlib_metadata.version("botocore").split(".")[:3]))
+
+
 def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
     if event is None:
         err_msg = "Trigger error: event is None"
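The new helper returns the installed botocore version as an integer tuple, which makes version gating straightforward. A hedged example; the threshold below is made up for illustration:

from airflow.providers.amazon.aws.utils import get_botocore_version

# Hypothetical guard: only take a newer code path when the installed botocore is recent enough.
if get_botocore_version() >= (1, 34, 0):
    use_new_api = True
else:
    use_new_api = False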