apache-airflow-providers-amazon 8.24.0rc1__py3-none-any.whl → 8.25.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/LICENSE +4 -4
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
- airflow/providers/amazon/aws/hooks/comprehend.py +33 -0
- airflow/providers/amazon/aws/hooks/glue.py +123 -0
- airflow/providers/amazon/aws/hooks/redshift_sql.py +8 -1
- airflow/providers/amazon/aws/operators/bedrock.py +6 -20
- airflow/providers/amazon/aws/operators/comprehend.py +148 -1
- airflow/providers/amazon/aws/operators/emr.py +38 -30
- airflow/providers/amazon/aws/operators/glue.py +408 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
- airflow/providers/amazon/aws/sensors/comprehend.py +112 -1
- airflow/providers/amazon/aws/sensors/glue.py +260 -2
- airflow/providers/amazon/aws/sensors/s3.py +35 -5
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
- airflow/providers/amazon/aws/triggers/comprehend.py +36 -0
- airflow/providers/amazon/aws/triggers/glue.py +76 -2
- airflow/providers/amazon/aws/utils/__init__.py +2 -3
- airflow/providers/amazon/aws/waiters/comprehend.json +55 -0
- airflow/providers/amazon/aws/waiters/glue.json +98 -0
- airflow/providers/amazon/get_provider_info.py +20 -13
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/METADATA +22 -21
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/RECORD +26 -26
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/comprehend.py (+112 -1)

@@ -23,7 +23,10 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
-from airflow.providers.amazon.aws.triggers.comprehend import
+from airflow.providers.amazon.aws.triggers.comprehend import (
+    ComprehendCreateDocumentClassifierCompletedTrigger,
+    ComprehendPiiEntitiesDetectionJobCompletedTrigger,
+)
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
@@ -145,3 +148,111 @@ class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor
         return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
             "PiiEntitiesDetectionJobProperties"
         ]["JobStatus"]
+
+
+class ComprehendCreateDocumentClassifierCompletedSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    Poll the state of the document classifier until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendCreateDocumentClassifierCompletedSensor`
+
+    :param document_classifier_arn: The arn of the Comprehend document classifier.
+    :param fail_on_warnings: If set to True, the document classifier training job will throw an error when the
+        status is TRAINED_WITH_WARNING. (default False)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = (
+        "SUBMITTED",
+        "TRAINING",
+    )
+    FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "STOP_REQUESTED",
+        "STOPPED",
+        "IN_ERROR",
+    )
+    SUCCESS_STATES: tuple[str, ...] = ("TRAINED", "TRAINED_WITH_WARNING")
+    FAILURE_MESSAGE = "Comprehend document classifier failed."
+
+    template_fields: Sequence[str] = aws_template_fields("document_classifier_arn")
+
+    def __init__(
+        self,
+        *,
+        document_classifier_arn: str,
+        fail_on_warnings: bool = False,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.document_classifier_arn = document_classifier_arn
+        self.fail_on_warnings = fail_on_warnings
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=ComprehendCreateDocumentClassifierCompletedTrigger(
+                    document_classifier_arn=self.document_classifier_arn,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        status = self.hook.conn.describe_document_classifier(
+            DocumentClassifierArn=self.document_classifier_arn
+        )["DocumentClassifierProperties"]["Status"]
+
+        self.log.info(
+            "Poking for AWS Comprehend document classifier arn: %s status: %s",
+            self.document_classifier_arn,
+            status,
+        )
+
+        if status in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_document_classifier_training_status(
+                document_classifier_arn=self.document_classifier_arn, fail_on_warnings=self.fail_on_warnings
+            )
+
+            self.log.info("Comprehend document classifier `%s` complete.", self.document_classifier_arn)
+
+            return True
+
+        return False
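For orientation, the new sensor can be dropped into a DAG like any other AWS sensor. The following is a minimal sketch based only on the signature shown above; the DAG id, task id, and classifier ARN are placeholders, not values from the release.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.comprehend import (
    ComprehendCreateDocumentClassifierCompletedSensor,
)

with DAG(
    dag_id="example_comprehend_classifier",  # placeholder DAG id
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # Waits until the classifier reaches TRAINED or TRAINED_WITH_WARNING;
    # with deferrable=True the wait happens in the triggerer, not a worker slot.
    wait_for_classifier = ComprehendCreateDocumentClassifierCompletedSensor(
        task_id="wait_for_classifier",
        document_classifier_arn="arn:aws:comprehend:us-east-1:123456789012:document-classifier/example",  # placeholder
        fail_on_warnings=False,
        deferrable=True,
        poke_interval=120,
        max_retries=75,
    )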
airflow/providers/amazon/aws/sensors/glue.py (+260 -2)

@@ -18,10 +18,18 @@
 from __future__ import annotations

 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence

+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -91,3 +99,253 @@ class GlueJobSensor(BaseSensorOperator)
                     run_id=self.run_id,
                     continuation_tokens=self.next_log_tokens,
                 )
+
+
+class GlueDataQualityRuleSetEvaluationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality ruleset evaluation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleSetEvaluationRunSensor`
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param verify_result_status: Validate all the ruleset rules evaluation run results,
+        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("evaluation_run_id")
+
+    def __init__(
+        self,
+        *,
+        evaluation_run_id: str,
+        show_results: bool = True,
+        verify_result_status: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.evaluation_run_id = evaluation_run_id
+        self.show_results = show_results
+        self.verify_result_status = verify_result_status
+        self.aws_conn_id = aws_conn_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=self.evaluation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality ruleset evaluation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        self.log.info("AWS Glue data quality ruleset evaluation run completed.")
+
+    def poke(self, context: Context):
+        self.log.info(
+            "Poking for AWS Glue data quality ruleset evaluation run RunId: %s", self.evaluation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_ruleset_evaluation_run(RunId=self.evaluation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=self.evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s Run State: %s",
+                self.evaluation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality ruleset evaluation run RunId: {self.evaluation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
+
+
+class GlueDataQualityRuleRecommendationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality recommendation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleRecommendationRunSensor`
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("recommendation_run_id")
+
+    def __init__(
+        self,
+        *,
+        recommendation_run_id: str,
+        show_results: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.recommendation_run_id = recommendation_run_id
+        self.show_results = show_results
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=self.recommendation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality recommendation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+        self.log.info("AWS Glue data quality recommendation run completed.")
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for AWS Glue data quality recommendation run RunId: %s", self.recommendation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_rule_recommendation_run(RunId=self.recommendation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s Run State: %s",
+                self.recommendation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality recommendation run RunId: {self.recommendation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
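As a rough usage sketch for the new Glue data quality sensors (placed inside a DAG definition; the upstream task id and XCom wiring are assumptions for illustration, not part of this diff):

from airflow.providers.amazon.aws.sensors.glue import GlueDataQualityRuleSetEvaluationRunSensor

# evaluation_run_id is templated (see template_fields above), so it can be
# pulled from an upstream task that started the ruleset evaluation run.
wait_for_evaluation_run = GlueDataQualityRuleSetEvaluationRunSensor(
    task_id="wait_for_evaluation_run",
    evaluation_run_id="{{ task_instance.xcom_pull(task_ids='start_evaluation_run') }}",  # placeholder upstream task
    show_results=True,          # log each rule's result once the run succeeds
    verify_result_status=True,  # raise if any rule ends in Fail or Error
    deferrable=True,            # wait in the triggerer via the new trigger
    poke_interval=120,
    max_retries=60,
)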
airflow/providers/amazon/aws/sensors/s3.py (+35 -5)

@@ -78,6 +78,11 @@ class S3KeySensor(BaseSensorOperator):
         CA cert bundle than the one used by botocore.
     :param deferrable: Run operator in the deferrable mode
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
+        all available attributes.
+        Default value: "Size".
+        If the requested attribute is not found, the key is still included and the value is None.
     """

     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -93,6 +98,7 @@ class S3KeySensor(BaseSensorOperator):
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         use_regex: bool = False,
+        metadata_keys: list[str] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,14 +110,14 @@ class S3KeySensor(BaseSensorOperator):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

         """
-        Set variable `files` which contains a list of dict which contains
-        If needed we might want to add other attributes later
+        Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
             'Size': int
         }]
@@ -123,8 +129,21 @@ class S3KeySensor(BaseSensorOperator):
             if not key_matches:
                 return False

-            # Reduce the set of metadata to
-            files = [
+            # Reduce the set of metadata to requested attributes
+            files = []
+            for f in key_matches:
+                metadata = {}
+                if "*" in self.metadata_keys:
+                    metadata = self.hook.head_object(f["Key"], bucket_name)
+                else:
+                    for key in self.metadata_keys:
+                        try:
+                            metadata[key] = f[key]
+                        except KeyError:
+                            # supplied key might be from head_object response
+                            self.log.info("Key %s not found in response, performing head_object", key)
+                            metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
+                files.append(metadata)
         elif self.use_regex:
             keys = self.hook.get_file_metadata("", bucket_name)
             key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
@@ -134,7 +153,18 @@ class S3KeySensor(BaseSensorOperator):
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
                 return False
-            files = [{"Size": obj["ContentLength"]}]
+            metadata = {}
+            if "*" in self.metadata_keys:
+                metadata = self.hook.head_object(key, bucket_name)
+
+            else:
+                for key in self.metadata_keys:
+                    # backwards compatibility with original implementation
+                    if key == "Size":
+                        metadata[key] = obj.get("ContentLength")
+                    else:
+                        metadata[key] = obj.get(key, None)
+            files = [metadata]

         if self.check_fn is not None:
             return self.check_fn(files)
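A hedged sketch of what the new metadata_keys parameter enables: check_fn now receives whichever head_object attributes were requested, instead of only Size. The bucket, key pattern, and attribute choices below are illustrative, not values from the release.

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def _all_files_ready(files: list[dict]) -> bool:
    # Each dict carries the requested metadata keys; an attribute that could
    # not be resolved is still present with a value of None.
    return all(f.get("Size") and f.get("ContentType") is not None for f in files)


wait_for_exports = S3KeySensor(
    task_id="wait_for_exports",
    bucket_name="example-bucket",           # placeholder
    bucket_key="exports/*.csv",             # placeholder
    wildcard_match=True,
    metadata_keys=["Size", "ContentType"],  # ContentType is fetched via head_object
    check_fn=_all_files_ready,
)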
airflow/providers/amazon/aws/transfers/redshift_to_s3.py (+1 -1)

@@ -128,7 +128,7 @@ class RedshiftToS3Operator(BaseOperator):
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
     ) -> str:
         # Un-escape already escaped queries
-        select_query = re.sub(r"''(
+        select_query = re.sub(r"''(.+?)''", r"'\1'", select_query)
         return f"""
                     UNLOAD ($${select_query}$$)
                     TO 's3://{self.s3_bucket}/{s3_key}'
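For reference, the corrected un-escape step behaves as in this standalone sketch (the sample query is illustrative):

import re

query = "SELECT * FROM events WHERE status = ''active''"
# Doubled single quotes produced by earlier escaping are collapsed back to
# single quotes before the query is embedded in UNLOAD ($$ ... $$).
print(re.sub(r"''(.+?)''", r"'\1'", query))
# -> SELECT * FROM events WHERE status = 'active'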
airflow/providers/amazon/aws/triggers/comprehend.py (+36)

@@ -59,3 +59,39 @@ class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):

     def hook(self) -> AwsGenericHook:
         return ComprehendHook(aws_conn_id=self.aws_conn_id)
+
+
+class ComprehendCreateDocumentClassifierCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Comprehend document classifier is complete.
+
+    :param document_classifier_arn: The arn of the Comprehend document classifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_classifier_arn: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"document_classifier_arn": document_classifier_arn},
+            waiter_name="create_document_classifier_complete",
+            waiter_args={"DocumentClassifierArn": document_classifier_arn},
+            failure_message="Comprehend create document classifier failed.",
+            status_message="Status of Comprehend create document classifier is",
+            status_queries=["DocumentClassifierProperties.Status"],
+            return_key="document_classifier_arn",
+            return_value=document_classifier_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return ComprehendHook(aws_conn_id=self.aws_conn_id)
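Beyond the bundled sensor, the new trigger could also back a custom deferrable operator. The operator below is purely illustrative (its name and error handling are not part of the provider); it relies on the return_key/return_value pair shown above to read the ARN back from the trigger event.

from __future__ import annotations

from typing import Any

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.triggers.comprehend import (
    ComprehendCreateDocumentClassifierCompletedTrigger,
)
from airflow.utils.context import Context


class WaitForClassifierOperator(BaseOperator):
    """Illustrative operator that defers until a Comprehend classifier is trained."""

    def __init__(self, *, document_classifier_arn: str, **kwargs):
        super().__init__(**kwargs)
        self.document_classifier_arn = document_classifier_arn

    def execute(self, context: Context):
        # Hand the wait over to the triggerer instead of blocking a worker.
        self.defer(
            trigger=ComprehendCreateDocumentClassifierCompletedTrigger(
                document_classifier_arn=self.document_classifier_arn,
                waiter_delay=120,
                waiter_max_attempts=75,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context: Context, event: dict[str, Any] | None = None):
        if event and event.get("status") == "success":
            return event["document_classifier_arn"]
        raise RuntimeError(f"Comprehend document classifier did not complete: {event}")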
airflow/providers/amazon/aws/triggers/glue.py (+76 -2)

@@ -19,10 +19,14 @@ from __future__ import annotations

 import asyncio
 from functools import cached_property
-from typing import Any, AsyncIterator
+from typing import TYPE_CHECKING, Any, AsyncIterator

-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -148,3 +152,73 @@ class GlueCatalogPartitionTrigger(BaseTrigger):
                     break
                 else:
                     await asyncio.sleep(self.waiter_delay)
+
+
+class GlueDataQualityRuleSetEvaluationRunCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a AWS Glue data quality evaluation run complete.
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        evaluation_run_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__(
+            serialized_fields={"evaluation_run_id": evaluation_run_id},
+            waiter_name="data_quality_ruleset_evaluation_run_complete",
+            waiter_args={"RunId": evaluation_run_id},
+            failure_message="AWS Glue data quality ruleset evaluation run failed.",
+            status_message="Status of AWS Glue data quality ruleset evaluation run is",
+            status_queries=["Status"],
+            return_key="evaluation_run_id",
+            return_value=evaluation_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return GlueDataQualityHook(aws_conn_id=self.aws_conn_id)
+
+
+class GlueDataQualityRuleRecommendationRunCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a AWS Glue data quality recommendation run complete.
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        recommendation_run_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__(
+            serialized_fields={"recommendation_run_id": recommendation_run_id},
+            waiter_name="data_quality_rule_recommendation_run_complete",
+            waiter_args={"RunId": recommendation_run_id},
+            failure_message="AWS Glue data quality recommendation run failed.",
+            status_message="Status of AWS Glue data quality recommendation run is",
+            status_queries=["Status"],
+            return_key="recommendation_run_id",
+            return_value=recommendation_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return GlueDataQualityHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/utils/__init__.py (+2 -3)

@@ -20,10 +20,9 @@ import logging
 import re
 from datetime import datetime, timezone
 from enum import Enum
+from importlib import metadata
 from typing import Any

-import importlib_metadata
-
 from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
@@ -78,7 +77,7 @@ def get_airflow_version() -> tuple[int, ...]:

 def get_botocore_version() -> tuple[int, ...]:
     """Return the version number of the installed botocore package in the form of a tuple[int,...]."""
-    return tuple(map(int,
+    return tuple(map(int, metadata.version("botocore").split(".")[:3]))


 def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]: