apache-airflow-providers-amazon 8.22.0__py3-none-any.whl → 8.23.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +47 -3
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -0
- airflow/providers/amazon/aws/hooks/bedrock.py +20 -0
- airflow/providers/amazon/aws/hooks/comprehend.py +37 -0
- airflow/providers/amazon/aws/hooks/neptune.py +36 -1
- airflow/providers/amazon/aws/operators/athena.py +1 -1
- airflow/providers/amazon/aws/operators/batch.py +1 -3
- airflow/providers/amazon/aws/operators/bedrock.py +218 -2
- airflow/providers/amazon/aws/operators/comprehend.py +192 -0
- airflow/providers/amazon/aws/operators/emr.py +21 -11
- airflow/providers/amazon/aws/operators/neptune.py +128 -21
- airflow/providers/amazon/aws/operators/sagemaker.py +10 -14
- airflow/providers/amazon/aws/sensors/comprehend.py +147 -0
- airflow/providers/amazon/aws/sensors/emr.py +8 -0
- airflow/providers/amazon/aws/triggers/comprehend.py +61 -0
- airflow/providers/amazon/aws/triggers/neptune.py +45 -0
- airflow/providers/amazon/aws/triggers/sagemaker.py +1 -1
- airflow/providers/amazon/aws/utils/__init__.py +7 -0
- airflow/providers/amazon/aws/waiters/comprehend.json +49 -0
- airflow/providers/amazon/get_provider_info.py +25 -1
- {apache_airflow_providers_amazon-8.22.0.dist-info → apache_airflow_providers_amazon-8.23.0.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.22.0.dist-info → apache_airflow_providers_amazon-8.23.0.dist-info}/RECORD +25 -20
- {apache_airflow_providers_amazon-8.22.0.dist-info → apache_airflow_providers_amazon-8.23.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.22.0.dist-info → apache_airflow_providers_amazon-8.23.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/operators/emr.py:

```diff
@@ -617,6 +617,14 @@ class EmrContainerOperator(BaseOperator):
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_polling_attempts,
+                )
+                if self.max_polling_attempts
+                else EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
```
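The added lines make `waiter_max_attempts` conditional: the trigger receives `max_polling_attempts` when it is set, and is otherwise constructed without the argument so the waiter's own default cap applies. A minimal, self-contained sketch of that conditional-construction pattern (all names here are illustrative stand-ins, not provider code):

```python
from dataclasses import dataclass


@dataclass
class Trigger:  # stand-in for EmrContainerTrigger
    job_id: str
    waiter_delay: int
    waiter_max_attempts: int = 720  # pretend this is the waiter's default


def build_trigger(job_id: str, delay: int, max_attempts: int | None) -> Trigger:
    # Pass max_attempts only when the caller supplied one; otherwise let the
    # trigger fall back to its default instead of receiving None.
    return (
        Trigger(job_id=job_id, waiter_delay=delay, waiter_max_attempts=max_attempts)
        if max_attempts
        else Trigger(job_id=job_id, waiter_delay=delay)
    )


print(build_trigger("job-1", 30, None).waiter_max_attempts)  # -> 720
```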
Also in airflow/providers/amazon/aws/operators/emr.py, `EmrCreateJobFlowOperator` deprecates `waiter_check_interval_seconds` in favor of `waiter_delay`; the old parameter's default of 60 is preserved as a fallback in the new `else` branch:

```diff
@@ -734,10 +742,20 @@ class EmrCreateJobFlowOperator(BaseOperator):
         waiter_max_attempts: int | None = None,
         waiter_delay: int | None = None,
         waiter_countdown: int | None = None,
-        waiter_check_interval_seconds: int = 60,
+        waiter_check_interval_seconds: int | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ):
+        if waiter_check_interval_seconds:
+            warnings.warn(
+                "The parameter `waiter_check_interval_seconds` has been deprecated to "
+                "standardize naming conventions. Please `use waiter_delay instead`. In the "
+                "future this will default to None and defer to the waiter's default value.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            waiter_check_interval_seconds = 60
         if waiter_countdown:
             warnings.warn(
                 "The parameter waiter_countdown has been deprecated to standardize "
@@ -749,15 +767,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
             # waiter_countdown defaults to never timing out, which is not supported
             # by boto waiters, so we will set it here to "a very long time" for now.
             waiter_max_attempts = (waiter_countdown or 999) // waiter_check_interval_seconds
-
-            warnings.warn(
-                "The parameter waiter_check_interval_seconds has been deprecated to "
-                "standardize naming conventions. Please use waiter_delay instead. In the "
-                "future this will default to None and defer to the waiter's default value.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = waiter_check_interval_seconds
+
         super().__init__(**kwargs)
         self.aws_conn_id = aws_conn_id
         self.emr_conn_id = emr_conn_id
@@ -765,7 +775,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
         self.region_name = region_name
         self.wait_for_completion = wait_for_completion
         self.waiter_max_attempts = waiter_max_attempts or 60
-        self.waiter_delay = waiter_delay or
+        self.waiter_delay = waiter_delay or waiter_check_interval_seconds or 60
         self.deferrable = deferrable
 
     @cached_property
```
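For DAG authors the migration is mechanical: replace the deprecated argument with `waiter_delay`. A hedged sketch (the task ids and empty overrides dict are illustrative):

```python
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator

# Before: emits AirflowProviderDeprecationWarning as of 8.23.0.
create_job_flow_old = EmrCreateJobFlowOperator(
    task_id="create_job_flow_old",
    job_flow_overrides={},  # illustrative
    waiter_check_interval_seconds=30,
)

# After: the standardized parameter name.
create_job_flow = EmrCreateJobFlowOperator(
    task_id="create_job_flow",
    job_flow_overrides={},  # illustrative
    waiter_delay=30,
)
```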
airflow/providers/amazon/aws/operators/neptune.py:

```diff
@@ -19,11 +19,15 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any, Sequence
 
+from botocore.exceptions import ClientError
+
 from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.neptune import (
     NeptuneClusterAvailableTrigger,
+    NeptuneClusterInstancesAvailableTrigger,
     NeptuneClusterStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -32,6 +36,50 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+def handle_waitable_exception(
+    operator: NeptuneStartDbClusterOperator | NeptuneStopDbClusterOperator, err: str
+):
+    """
+    Handle client exceptions for invalid cluster or invalid instance status that are temporary.
+
+    After status change, it's possible to retry. Waiter will handle terminal status.
+    """
+    code = err
+
+    if code in ("InvalidDBInstanceStateFault", "InvalidDBInstanceState"):
+        if operator.deferrable:
+            operator.log.info("Deferring until instances become available: %s", operator.cluster_id)
+            operator.defer(
+                trigger=NeptuneClusterInstancesAvailableTrigger(
+                    aws_conn_id=operator.aws_conn_id,
+                    db_cluster_id=operator.cluster_id,
+                    region_name=operator.region_name,
+                    botocore_config=operator.botocore_config,
+                    verify=operator.verify,
+                ),
+                method_name="execute",
+            )
+        else:
+            operator.log.info("Need to wait for instances to become available: %s", operator.cluster_id)
+            operator.hook.wait_for_cluster_instance_availability(cluster_id=operator.cluster_id)
+    if code in ["InvalidClusterState", "InvalidDBClusterStateFault"]:
+        if operator.deferrable:
+            operator.log.info("Deferring until cluster becomes available: %s", operator.cluster_id)
+            operator.defer(
+                trigger=NeptuneClusterAvailableTrigger(
+                    aws_conn_id=operator.aws_conn_id,
+                    db_cluster_id=operator.cluster_id,
+                    region_name=operator.region_name,
+                    botocore_config=operator.botocore_config,
+                    verify=operator.verify,
+                ),
+                method_name="execute",
+            )
+        else:
+            operator.log.info("Need to wait for cluster to become available: %s", operator.cluster_id)
+            operator.hook.wait_for_cluster_availability(operator.cluster_id)
+
+
 class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
     """Starts an Amazon Neptune DB cluster.
 
```
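Note the `method_name="execute"` in both defer calls: instead of resuming in a completion callback, the operator re-enters `execute()` once the trigger fires and retries the original request. That is why `execute()` grows the extra `event`/`**kwargs` parameters in the hunks below. A runnable toy of this wait-then-retry re-entry loop (all names illustrative; in Airflow, `defer()` raises `TaskDeferred` and the scheduler later re-invokes `method_name`):

```python
class TransientStateError(Exception):
    pass


class Resource:
    def __init__(self) -> None:
        self.attempts = 0

    def start(self) -> str:
        self.attempts += 1
        if self.attempts < 2:  # first call hits a transient invalid state
            raise TransientStateError("cluster not ready for the request")
        return "started"


def execute(resource: Resource, event: dict | None = None) -> str:
    try:
        return resource.start()
    except TransientStateError:
        # Stand-in for defer(trigger=..., method_name="execute"): wait for the
        # resource to settle, then re-enter execute() from the top.
        event = {"status": "available"}  # what the trigger would hand back
        return execute(resource, event)


print(execute(Resource()))  # -> "started" on the second pass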
Still in airflow/providers/amazon/aws/operators/neptune.py, both operators gain this retry-on-transient-error flow:

```diff
@@ -78,10 +126,10 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         self.cluster_id = db_cluster_id
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
-        self.
-        self.
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
 
-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
         self.log.info("Starting Neptune cluster: %s", self.cluster_id)
 
         # Check to make sure the cluster is not already available.
@@ -89,9 +137,32 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         if status.lower() in NeptuneHook.AVAILABLE_STATES:
             self.log.info("Neptune cluster %s is already available.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-
-
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to start the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be started", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the start request.
+        This loop covers the case where the cluster is not available and also the case where
+        the cluster is available, but one or more of the instances are in an invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+        try:
+            self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to start the cluster: %s", code)
+
+            if code in ["InvalidDBInstanceState", "InvalidClusterState", "InvalidDBClusterStateFault"]:
+                handle_waitable_exception(operator=self, err=code)
+
+            else:
+                # re raise for any other type of client error
+                raise
 
         if self.deferrable:
             self.log.info("Deferring for cluster start: %s", self.cluster_id)
@@ -100,15 +171,17 @@ class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
                 trigger=NeptuneClusterAvailableTrigger(
                     aws_conn_id=self.aws_conn_id,
                     db_cluster_id=self.cluster_id,
-                    waiter_delay=self.
-                    waiter_max_attempts=self.
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
                 ),
                 method_name="execute_complete",
             )
 
         elif self.wait_for_completion:
             self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
-            self.hook.wait_for_cluster_availability(
+            self.hook.wait_for_cluster_availability(
+                self.cluster_id, self.waiter_delay, self.waiter_max_attempts
+            )
 
         return {"db_cluster_id": self.cluster_id}
 
@@ -171,20 +244,53 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
         self.cluster_id = db_cluster_id
         self.wait_for_completion = wait_for_completion
         self.deferrable = deferrable
-        self.
-        self.
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
 
-    def execute(self, context: Context) -> dict[str, str]:
+    def execute(self, context: Context, event: dict[str, Any] | None = None, **kwargs) -> dict[str, str]:
         self.log.info("Stopping Neptune cluster: %s", self.cluster_id)
 
-        # Check to make sure the cluster is not already stopped
+        # Check to make sure the cluster is not already stopped or that its not in a bad state
         status = self.hook.get_cluster_status(self.cluster_id)
+        self.log.info("Current status: %s", status)
+
         if status.lower() in NeptuneHook.STOPPED_STATES:
             self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
             return {"db_cluster_id": self.cluster_id}
-
-
-
+        elif status.lower() in NeptuneHook.ERROR_STATES:
+            # some states will not allow you to stop the cluster
+            self.log.error(
+                "Neptune cluster %s is in error state %s and cannot be stopped", self.cluster_id, status
+            )
+            raise AirflowException(f"Neptune cluster {self.cluster_id} is in error state {status}")
+
+        """
+        A cluster and its instances must be in a valid state to send the stop request.
+        This loop covers the case where the cluster is not available and also the case where
+        the cluster is available, but one or more of the instances are in an invalid state.
+        If either are in an invalid state, wait for the availability and retry.
+        Let the waiters handle retries and detecting the error states.
+        """
+
+        try:
+            self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
+
+        # cluster must be in available state to stop it
+        except ClientError as ex:
+            code = ex.response["Error"]["Code"]
+            self.log.warning("Received client error when attempting to stop the cluster: %s", code)
+
+            # these can be handled by a waiter
+            if code in [
+                "InvalidDBInstanceState",
+                "InvalidDBInstanceStateFault",
+                "InvalidClusterState",
+                "InvalidDBClusterStateFault",
+            ]:
+                handle_waitable_exception(self, code)
+            else:
+                # re raise for any other type of client error
+                raise
 
         if self.deferrable:
             self.log.info("Deferring for cluster stop: %s", self.cluster_id)
@@ -193,22 +299,23 @@ class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
                 trigger=NeptuneClusterStoppedTrigger(
                     aws_conn_id=self.aws_conn_id,
                     db_cluster_id=self.cluster_id,
-                    waiter_delay=self.
-                    waiter_max_attempts=self.
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
                 ),
                 method_name="execute_complete",
             )
 
         elif self.wait_for_completion:
-            self.log.info("Waiting for Neptune cluster %s to
-
+            self.log.info("Waiting for Neptune cluster %s to stop.", self.cluster_id)
+
+            self.hook.wait_for_cluster_stopped(self.cluster_id, self.waiter_delay, self.waiter_max_attempts)
 
         return {"db_cluster_id": self.cluster_id}
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
         status = ""
         cluster_id = ""
-
+        self.log.info(event)
         if event:
             status = event.get("status", "")
             cluster_id = event.get("cluster_id", "")
```
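Taken together, the Neptune operators now wait out transient `InvalidDBClusterStateFault`/`InvalidDBInstanceState` errors instead of failing the task. A hedged usage sketch (DAG id, dates, and cluster id are illustrative):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.neptune import (
    NeptuneStartDbClusterOperator,
    NeptuneStopDbClusterOperator,
)

with DAG("neptune_example", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    start_cluster = NeptuneStartDbClusterOperator(
        task_id="start_cluster",
        db_cluster_id="my-neptune-cluster",  # illustrative
        deferrable=True,  # transient states defer instead of blocking a worker
    )
    stop_cluster = NeptuneStopDbClusterOperator(
        task_id="stop_cluster",
        db_cluster_id="my-neptune-cluster",
        wait_for_completion=True,
    )
    start_cluster >> stop_cluster
```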
airflow/providers/amazon/aws/operators/sagemaker.py:

```diff
@@ -750,20 +750,18 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return self.serialize_result()
+        return self.serialize_result(transform_config["TransformJobName"])
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
 
         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])
 
-    def serialize_result(self) -> dict[str, dict]:
-
-        self.serialized_model = serialize(self.hook.describe_model(
-        self.serialized_transform = serialize(
-            self.hook.describe_transform_job(transform_config["TransformJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        job_description = self.hook.describe_transform_job(job_name)
+        self.serialized_model = serialize(self.hook.describe_model(job_description["ModelName"]))
+        self.serialized_transform = serialize(job_description)
         return {"Model": self.serialized_model, "Transform": self.serialized_transform}
 
     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
@@ -1154,7 +1152,7 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
                 method_name="execute_complete",
             )
 
-        return self.serialize_result()
+        return self.serialize_result(self.config["TrainingJobName"])
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         event = validate_execute_complete_event(event)
@@ -1163,12 +1161,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
             raise AirflowException(f"Error while running job: {event}")
 
         self.log.info(event["message"])
-        return self.serialize_result()
+        return self.serialize_result(event["job_name"])
 
-    def serialize_result(self) -> dict[str, dict]:
-        self.serialized_training_data = serialize(
-            self.hook.describe_training_job(self.config["TrainingJobName"])
-        )
+    def serialize_result(self, job_name: str) -> dict[str, dict]:
+        self.serialized_training_data = serialize(self.hook.describe_training_job(job_name))
         return {"Training": self.serialized_training_data}
 
     def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage:
```
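These hunks thread the job name into `serialize_result()` explicitly. The motivation: on the deferred path, `execute_complete()` runs in a fresh process where locals such as `transform_config` no longer exist, so the job name has to travel inside the trigger event instead; the matching `SageMakerTrigger` hunk further down adds `"job_name"` to the event payload. A toy sketch of the round trip (names illustrative, not provider code):

```python
# Toy round trip: the trigger event carries everything the resumed task needs.
def trigger_fires(job_name: str) -> dict:
    return {"status": "success", "message": "Job completed.", "job_name": job_name}


def execute_complete(event: dict) -> dict:
    # No access to the original execute() locals here; rely on the event alone.
    return {"Transform": {"TransformJobName": event["job_name"]}}


print(execute_complete(trigger_fires("my-transform-job")))
```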
airflow/providers/amazon/aws/sensors/comprehend.py (new file):

```diff
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    General sensor behavior for Amazon Comprehend.
+
+    Subclasses must implement following methods:
+        - ``get_state()``
+
+    Subclasses must set the following fields:
+        - ``INTERMEDIATE_STATES``
+        - ``FAILURE_STATES``
+        - ``SUCCESS_STATES``
+        - ``FAILURE_MESSAGE``
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        state = self.get_state()
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
+
+    @abc.abstractmethod
+    def get_state(self) -> str:
+        """Implement in subclasses."""
+
+
+class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor):
+    """
+    Poll the state of the pii entities detection job until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor`
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("IN_PROGRESS",)
+    FAILURE_STATES: tuple[str, ...] = ("FAILED", "STOP_REQUESTED", "STOPPED")
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETED",)
+    FAILURE_MESSAGE = "Comprehend start pii entities detection job sensor failed."
+
+    template_fields: Sequence[str] = aws_template_fields("job_id")
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.job_id = job_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+                    job_id=self.job_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def get_state(self) -> str:
+        return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
+            "PiiEntitiesDetectionJobProperties"
+        ]["JobStatus"]
```
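A hedged usage sketch for the new sensor. It assumes an upstream task pushed the Comprehend job id to XCom, e.g. the start operator added in operators/comprehend.py per the file list (assumed here to be named `ComprehendStartPiiEntitiesDetectionJobOperator`, matching the sensor's naming):

```python
from airflow.providers.amazon.aws.sensors.comprehend import (
    ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
)

await_pii_job = ComprehendStartPiiEntitiesDetectionJobCompletedSensor(
    task_id="await_pii_job",
    # job_id is a template field, so it can be pulled from an upstream task.
    job_id="{{ ti.xcom_pull(task_ids='start_pii_job') }}",  # task id illustrative
    deferrable=True,
    poke_interval=120,
    max_retries=75,
)
```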
The matching change in airflow/providers/amazon/aws/sensors/emr.py, where `EmrContainerSensor` applies the same conditional trigger construction as the operator above:

```diff
@@ -354,6 +354,14 @@ class EmrContainerSensor(BaseSensorOperator):
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    waiter_max_attempts=self.max_retries,
+                )
+                if self.max_retries
+                else EmrContainerTrigger(
+                    virtual_cluster_id=self.virtual_cluster_id,
+                    job_id=self.job_id,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=self.poll_interval,
                 ),
                 method_name="execute_complete",
             )
```
airflow/providers/amazon/aws/triggers/comprehend.py (new file):

```diff
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+
+class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Comprehend pii entities detection job is complete.
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        job_id: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"job_id": job_id},
+            waiter_name="pii_entities_detection_job_complete",
+            waiter_args={"JobId": job_id},
+            failure_message="Comprehend start pii entities detection job failed.",
+            status_message="Status of Comprehend start pii entities detection job is",
+            status_queries=["PiiEntitiesDetectionJobProperties.JobStatus"],
+            return_key="job_id",
+            return_value=job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return ComprehendHook(aws_conn_id=self.aws_conn_id)
```
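The `waiter_name="pii_entities_detection_job_complete"` refers to a custom waiter shipped in the new `airflow/providers/amazon/aws/waiters/comprehend.json` from the file list above. That file's contents are not shown in this diff; as a hedged illustration only, a botocore custom waiter for this job would have roughly the following shape, here expressed as the dict that botocore's `WaiterModel` consumes (acceptor states inferred from the sensor's state tuples, not copied from the shipped JSON):

```python
from botocore.waiter import WaiterModel

# Hypothetical waiter definition; the real waiters/comprehend.json may differ.
waiter_config = {
    "version": 2,
    "waiters": {
        "pii_entities_detection_job_complete": {
            "operation": "DescribePiiEntitiesDetectionJob",
            "delay": 120,
            "maxAttempts": 75,
            "acceptors": [
                {
                    "matcher": "path",
                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
                    "expected": "COMPLETED",
                    "state": "success",
                },
                {
                    "matcher": "path",
                    "argument": "PiiEntitiesDetectionJobProperties.JobStatus",
                    "expected": "FAILED",
                    "state": "failure",
                },
            ],
        }
    },
}

model = WaiterModel(waiter_config)
print(model.get_waiter("pii_entities_detection_job_complete").operation)
```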
airflow/providers/amazon/aws/triggers/neptune.py:

```diff
@@ -113,3 +113,48 @@ class NeptuneClusterStoppedTrigger(AwsBaseWaiterTrigger):
             verify=self.verify,
             config=self.botocore_config,
         )
+
+
+class NeptuneClusterInstancesAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster Instance is available.
+
+    :param db_cluster_id: Cluster ID to wait on instances from
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="db_instance_available",
+            waiter_args={"Filters": [{"Name": "db-cluster-id", "Values": [db_cluster_id]}]},
+            failure_message="Failed to start Neptune instances",
+            status_message="Status of Neptune instances are",
+            status_queries=["DBInstances[].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
```
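`db_instance_available` here is the standard boto3 waiter on the Neptune client rather than a custom waiter JSON; the `db-cluster-id` filter scopes it to every instance in the cluster. Outside Airflow, the equivalent direct call looks like this sketch (cluster id and waiter timing are illustrative, and credentials/region come from the environment):

```python
import boto3

neptune = boto3.client("neptune")
waiter = neptune.get_waiter("db_instance_available")
waiter.wait(
    # Wait on all instances belonging to one cluster, not a single instance.
    Filters=[{"Name": "db-cluster-id", "Values": ["my-neptune-cluster"]}],
    WaiterConfig={"Delay": 30, "MaxAttempts": 60},
)
```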
And the trigger-side half of the SageMaker `serialize_result` change, in airflow/providers/amazon/aws/triggers/sagemaker.py:

```diff
@@ -121,7 +121,7 @@ class SageMakerTrigger(BaseTrigger):
             status_message=f"{self.job_type} job not done yet",
             status_args=[self._get_response_status_key(self.job_type)],
         )
-        yield TriggerEvent({"status": "success", "message": "Job completed."})
+        yield TriggerEvent({"status": "success", "message": "Job completed.", "job_name": self.job_name})
 
 
 class SageMakerPipelineTrigger(BaseTrigger):
```
airflow/providers/amazon/aws/utils/__init__.py:

```diff
@@ -22,6 +22,8 @@ from datetime import datetime, timezone
 from enum import Enum
 from typing import Any
 
+import importlib_metadata
+
 from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
@@ -74,6 +76,11 @@ def get_airflow_version() -> tuple[int, ...]:
     return tuple(int(x) for x in match.groups())
 
 
+def get_botocore_version() -> tuple[int, ...]:
+    """Return the version number of the installed botocore package in the form of a tuple[int,...]."""
+    return tuple(map(int, importlib_metadata.version("botocore").split(".")[:3]))
+
+
 def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
     if event is None:
         err_msg = "Trigger error: event is None"
```