apache-airflow-providers-amazon 8.24.0rc1__py3-none-any.whl → 8.25.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. airflow/providers/amazon/LICENSE +4 -4
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
  4. airflow/providers/amazon/aws/hooks/comprehend.py +33 -0
  5. airflow/providers/amazon/aws/hooks/glue.py +123 -0
  6. airflow/providers/amazon/aws/hooks/redshift_sql.py +8 -1
  7. airflow/providers/amazon/aws/operators/bedrock.py +6 -20
  8. airflow/providers/amazon/aws/operators/comprehend.py +148 -1
  9. airflow/providers/amazon/aws/operators/emr.py +38 -30
  10. airflow/providers/amazon/aws/operators/glue.py +408 -2
  11. airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
  12. airflow/providers/amazon/aws/sensors/comprehend.py +112 -1
  13. airflow/providers/amazon/aws/sensors/glue.py +260 -2
  14. airflow/providers/amazon/aws/sensors/s3.py +35 -5
  15. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
  16. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +1 -1
  17. airflow/providers/amazon/aws/triggers/comprehend.py +36 -0
  18. airflow/providers/amazon/aws/triggers/glue.py +76 -2
  19. airflow/providers/amazon/aws/utils/__init__.py +2 -3
  20. airflow/providers/amazon/aws/waiters/comprehend.json +55 -0
  21. airflow/providers/amazon/aws/waiters/glue.json +98 -0
  22. airflow/providers/amazon/get_provider_info.py +20 -13
  23. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/METADATA +22 -21
  24. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/RECORD +26 -26
  25. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/WHEEL +0 -0
  26. {apache_airflow_providers_amazon-8.24.0rc1.dist-info → apache_airflow_providers_amazon-8.25.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/comprehend.py
@@ -23,7 +23,10 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
-from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.triggers.comprehend import (
+    ComprehendCreateDocumentClassifierCompletedTrigger,
+    ComprehendPiiEntitiesDetectionJobCompletedTrigger,
+)
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
@@ -145,3 +148,111 @@ class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor
         return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[
             "PiiEntitiesDetectionJobProperties"
         ]["JobStatus"]
+
+
+class ComprehendCreateDocumentClassifierCompletedSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    Poll the state of the document classifier until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendCreateDocumentClassifierCompletedSensor`
+
+    :param document_classifier_arn: The arn of the Comprehend document classifier.
+    :param fail_on_warnings: If set to True, the document classifier training job will throw an error when the
+        status is TRAINED_WITH_WARNING. (default False)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = (
+        "SUBMITTED",
+        "TRAINING",
+    )
+    FAILURE_STATES: tuple[str, ...] = (
+        "DELETING",
+        "STOP_REQUESTED",
+        "STOPPED",
+        "IN_ERROR",
+    )
+    SUCCESS_STATES: tuple[str, ...] = ("TRAINED", "TRAINED_WITH_WARNING")
+    FAILURE_MESSAGE = "Comprehend document classifier failed."
+
+    template_fields: Sequence[str] = aws_template_fields("document_classifier_arn")
+
+    def __init__(
+        self,
+        *,
+        document_classifier_arn: str,
+        fail_on_warnings: bool = False,
+        max_retries: int = 75,
+        poke_interval: int = 120,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        aws_conn_id: str | None = "aws_default",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.document_classifier_arn = document_classifier_arn
+        self.fail_on_warnings = fail_on_warnings
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=ComprehendCreateDocumentClassifierCompletedTrigger(
+                    document_classifier_arn=self.document_classifier_arn,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        status = self.hook.conn.describe_document_classifier(
+            DocumentClassifierArn=self.document_classifier_arn
+        )["DocumentClassifierProperties"]["Status"]
+
+        self.log.info(
+            "Poking for AWS Comprehend document classifier arn: %s status: %s",
+            self.document_classifier_arn,
+            status,
+        )
+
+        if status in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_document_classifier_training_status(
+                document_classifier_arn=self.document_classifier_arn, fail_on_warnings=self.fail_on_warnings
+            )
+
+            self.log.info("Comprehend document classifier `%s` complete.", self.document_classifier_arn)
+
+            return True
+
+        return False
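For orientation, a minimal DAG snippet showing how the new sensor might be wired up. The DAG id, task id, and classifier ARN below are placeholders, not part of this release; in practice the ARN would usually come from the upstream ComprehendCreateDocumentClassifierOperator via XCom:

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.comprehend import (
    ComprehendCreateDocumentClassifierCompletedSensor,
)

with DAG(dag_id="example_comprehend_classifier", start_date=datetime(2024, 1, 1), schedule=None):
    wait_for_classifier = ComprehendCreateDocumentClassifierCompletedSensor(
        task_id="wait_for_classifier",
        document_classifier_arn="arn:aws:comprehend:us-east-1:123456789012:document-classifier/example",
        fail_on_warnings=True,  # also fail when training ends in TRAINED_WITH_WARNING
        deferrable=True,        # hand polling off to the triggerer via the new waiter-based trigger
        poke_interval=120,
        max_retries=75,
    )
```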
airflow/providers/amazon/aws/sensors/glue.py
@@ -18,10 +18,18 @@
 from __future__ import annotations

 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence

+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -91,3 +99,253 @@ class GlueJobSensor(BaseSensorOperator):
                     run_id=self.run_id,
                     continuation_tokens=self.next_log_tokens,
                 )
+
+
+class GlueDataQualityRuleSetEvaluationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality ruleset evaluation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleSetEvaluationRunSensor`
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param verify_result_status: Validate all the ruleset rules evaluation run results,
+        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("evaluation_run_id")
+
+    def __init__(
+        self,
+        *,
+        evaluation_run_id: str,
+        show_results: bool = True,
+        verify_result_status: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.evaluation_run_id = evaluation_run_id
+        self.show_results = show_results
+        self.verify_result_status = verify_result_status
+        self.aws_conn_id = aws_conn_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=self.evaluation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality ruleset evaluation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        self.log.info("AWS Glue data quality ruleset evaluation run completed.")
+
+    def poke(self, context: Context):
+        self.log.info(
+            "Poking for AWS Glue data quality ruleset evaluation run RunId: %s", self.evaluation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_ruleset_evaluation_run(RunId=self.evaluation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=self.evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s Run State: %s",
+                self.evaluation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality ruleset evaluation run RunId: {self.evaluation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
+
+
+class GlueDataQualityRuleRecommendationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality recommendation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleRecommendationRunSensor`
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("recommendation_run_id")
+
+    def __init__(
+        self,
+        *,
+        recommendation_run_id: str,
+        show_results: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.recommendation_run_id = recommendation_run_id
+        self.show_results = show_results
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=self.recommendation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality recommendation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+        self.log.info("AWS Glue data quality recommendation run completed.")
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for AWS Glue data quality recommendation run RunId: %s", self.recommendation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_rule_recommendation_run(RunId=self.recommendation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s Run State: %s",
+                self.recommendation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality recommendation run RunId: {self.recommendation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
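As a rough sketch of how the two new sensors fit into a DAG (DAG id, task ids, and run ids are placeholders; real run ids would normally be pulled from the corresponding Glue data quality operators via XCom):

```python
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue import (
    GlueDataQualityRuleRecommendationRunSensor,
    GlueDataQualityRuleSetEvaluationRunSensor,
)

with DAG(dag_id="example_glue_data_quality", start_date=datetime(2024, 1, 1), schedule=None):
    wait_for_recommendation_run = GlueDataQualityRuleRecommendationRunSensor(
        task_id="wait_for_recommendation_run",
        recommendation_run_id="drrr-example",  # placeholder run id
        show_results=True,
    )
    wait_for_evaluation_run = GlueDataQualityRuleSetEvaluationRunSensor(
        task_id="wait_for_evaluation_run",
        evaluation_run_id="dqrun-example",  # placeholder run id
        verify_result_status=True,  # fail the task if any rule result is Fail or Error
        deferrable=True,            # poll from the triggerer via the new waiter-based trigger
    )

    wait_for_recommendation_run >> wait_for_evaluation_run
```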
airflow/providers/amazon/aws/sensors/s3.py
@@ -78,6 +78,11 @@ class S3KeySensor(BaseSensorOperator):
        CA cert bundle than the one used by botocore.
    :param deferrable: Run operator in the deferrable mode
    :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
+        all available attributes.
+        Default value: "Size".
+        If the requested attribute is not found, the key is still included and the value is None.
    """

    template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -93,6 +98,7 @@ class S3KeySensor(BaseSensorOperator):
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         use_regex: bool = False,
+        metadata_keys: list[str] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,14 +110,14 @@ class S3KeySensor(BaseSensorOperator):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

         """
-        Set variable `files` which contains a list of dict which contains only the size
-        If needed we might want to add other attributes later
+        Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
             'Size': int
         }]
@@ -123,8 +129,21 @@ class S3KeySensor(BaseSensorOperator):
             if not key_matches:
                 return False

-            # Reduce the set of metadata to size only
-            files = [{"Size": f["Size"]} for f in key_matches]
+            # Reduce the set of metadata to requested attributes
+            files = []
+            for f in key_matches:
+                metadata = {}
+                if "*" in self.metadata_keys:
+                    metadata = self.hook.head_object(f["Key"], bucket_name)
+                else:
+                    for key in self.metadata_keys:
+                        try:
+                            metadata[key] = f[key]
+                        except KeyError:
+                            # supplied key might be from head_object response
+                            self.log.info("Key %s not found in response, performing head_object", key)
+                            metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
+                files.append(metadata)
         elif self.use_regex:
             keys = self.hook.get_file_metadata("", bucket_name)
             key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
@@ -134,7 +153,18 @@ class S3KeySensor(BaseSensorOperator):
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
                 return False
-            files = [{"Size": obj["ContentLength"]}]
+            metadata = {}
+            if "*" in self.metadata_keys:
+                metadata = self.hook.head_object(key, bucket_name)
+
+            else:
+                for key in self.metadata_keys:
+                    # backwards compatibility with original implementation
+                    if key == "Size":
+                        metadata[key] = obj.get("ContentLength")
+                    else:
+                        metadata[key] = obj.get(key, None)
+            files = [metadata]

         if self.check_fn is not None:
             return self.check_fn(files)
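A short sketch of how the new metadata_keys parameter pairs with check_fn; the bucket, key pattern, and size threshold below are illustrative only. Attributes that are not present in the list_keys metadata, such as ContentType, fall back to head_object per the logic above:

```python
from __future__ import annotations

from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def _big_csv_files(files: list[dict]) -> bool:
    # Each dict carries exactly the attributes requested via metadata_keys.
    return all(f["Size"] > 10 * 1024 * 1024 and f["ContentType"] == "text/csv" for f in files)


# Inside a DAG definition:
wait_for_report = S3KeySensor(
    task_id="wait_for_report",
    bucket_name="example-bucket",
    bucket_key="reports/*.csv",
    wildcard_match=True,
    metadata_keys=["Size", "ContentType"],  # default is ["Size"]; "*" returns all head_object attributes
    check_fn=_big_csv_files,
)
```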
airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -105,7 +105,6 @@ class DynamoDBToS3Operator(AwsToAwsBaseOperator):
         "file_size",
         "dynamodb_scan_kwargs",
         "s3_key_prefix",
-        "process_func",
         "export_time",
         "export_format",
         "check_interval",
airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -128,7 +128,7 @@ class RedshiftToS3Operator(BaseOperator):
         self, credentials_block: str, select_query: str, s3_key: str, unload_options: str
     ) -> str:
         # Un-escape already escaped queries
-        select_query = re.sub(r"''(.+)''", r"'\1'", select_query)
+        select_query = re.sub(r"''(.+?)''", r"'\1'", select_query)
         return f"""
                    UNLOAD ($${select_query}$$)
                    TO 's3://{self.s3_bucket}/{s3_key}'
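The switch from `.+` to the non-greedy `.+?` matters when the query contains more than one already-escaped literal: the greedy pattern spans from the first `''` to the last one and leaves stray quotes behind, while the non-greedy version un-escapes each literal separately. A quick illustration (the query text is made up):

```python
import re

query = "SELECT * FROM t WHERE a = ''x'' AND b = ''y''"

# Greedy pattern: one match spanning both literals, producing a mangled result.
print(re.sub(r"''(.+)''", r"'\1'", query))
# SELECT * FROM t WHERE a = 'x'' AND b = ''y'

# Non-greedy pattern (the fix): each escaped literal is un-escaped on its own.
print(re.sub(r"''(.+?)''", r"'\1'", query))
# SELECT * FROM t WHERE a = 'x' AND b = 'y'
```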
airflow/providers/amazon/aws/triggers/comprehend.py
@@ -59,3 +59,39 @@ class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger):

     def hook(self) -> AwsGenericHook:
         return ComprehendHook(aws_conn_id=self.aws_conn_id)
+
+
+class ComprehendCreateDocumentClassifierCompletedTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Comprehend document classifier is complete.
+
+    :param document_classifier_arn: The arn of the Comprehend document classifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        document_classifier_arn: str,
+        waiter_delay: int = 120,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ) -> None:
+        super().__init__(
+            serialized_fields={"document_classifier_arn": document_classifier_arn},
+            waiter_name="create_document_classifier_complete",
+            waiter_args={"DocumentClassifierArn": document_classifier_arn},
+            failure_message="Comprehend create document classifier failed.",
+            status_message="Status of Comprehend create document classifier is",
+            status_queries=["DocumentClassifierProperties.Status"],
+            return_key="document_classifier_arn",
+            return_value=document_classifier_arn,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return ComprehendHook(aws_conn_id=self.aws_conn_id)
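The trigger delegates to the custom boto3 waiter named create_document_classifier_complete, presumably defined in the waiters/comprehend.json added in this release (see the file list above). Roughly the same wait can be expressed synchronously through the hook, which is what the trigger automates on the triggerer in deferrable mode (the ARN is a placeholder):

```python
from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook

hook = ComprehendHook(aws_conn_id="aws_default")
# Same custom waiter name the trigger passes as waiter_name above.
hook.get_waiter("create_document_classifier_complete").wait(
    DocumentClassifierArn="arn:aws:comprehend:us-east-1:123456789012:document-classifier/example",
    WaiterConfig={"Delay": 120, "MaxAttempts": 75},
)
```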
airflow/providers/amazon/aws/triggers/glue.py
@@ -19,10 +19,14 @@ from __future__ import annotations

 import asyncio
 from functools import cached_property
-from typing import Any, AsyncIterator
+from typing import TYPE_CHECKING, Any, AsyncIterator

-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -148,3 +152,73 @@ class GlueCatalogPartitionTrigger(BaseTrigger):
                     break
                 else:
                     await asyncio.sleep(self.waiter_delay)
+
+
+class GlueDataQualityRuleSetEvaluationRunCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a AWS Glue data quality evaluation run complete.
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        evaluation_run_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__(
+            serialized_fields={"evaluation_run_id": evaluation_run_id},
+            waiter_name="data_quality_ruleset_evaluation_run_complete",
+            waiter_args={"RunId": evaluation_run_id},
+            failure_message="AWS Glue data quality ruleset evaluation run failed.",
+            status_message="Status of AWS Glue data quality ruleset evaluation run is",
+            status_queries=["Status"],
+            return_key="evaluation_run_id",
+            return_value=evaluation_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return GlueDataQualityHook(aws_conn_id=self.aws_conn_id)
+
+
+class GlueDataQualityRuleRecommendationRunCompleteTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a AWS Glue data quality recommendation run complete.
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        recommendation_run_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+    ):
+        super().__init__(
+            serialized_fields={"recommendation_run_id": recommendation_run_id},
+            waiter_name="data_quality_rule_recommendation_run_complete",
+            waiter_args={"RunId": recommendation_run_id},
+            failure_message="AWS Glue data quality recommendation run failed.",
+            status_message="Status of AWS Glue data quality recommendation run is",
+            status_queries=["Status"],
+            return_key="recommendation_run_id",
+            return_value=recommendation_run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return GlueDataQualityHook(aws_conn_id=self.aws_conn_id)
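Both triggers inherit AwsBaseWaiterTrigger, so on success they emit a TriggerEvent carrying the configured return_key/return_value, which is what the sensors' execute_complete methods read back. A sketch of the expected event payload (the run id is a placeholder):

```python
# Payload shape implied by return_key="evaluation_run_id" and the sensor's execute_complete():
event = {"status": "success", "evaluation_run_id": "dqrun-example"}

# The deferred GlueDataQualityRuleSetEvaluationRunSensor then re-validates the results:
#     self.hook.validate_evaluation_run_results(
#         evaluation_run_id=event["evaluation_run_id"], ...
#     )
```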
airflow/providers/amazon/aws/utils/__init__.py
@@ -20,10 +20,9 @@ import logging
 import re
 from datetime import datetime, timezone
 from enum import Enum
+from importlib import metadata
 from typing import Any

-import importlib_metadata
-
 from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version
@@ -78,7 +77,7 @@ def get_airflow_version() -> tuple[int, ...]:

 def get_botocore_version() -> tuple[int, ...]:
     """Return the version number of the installed botocore package in the form of a tuple[int,...]."""
-    return tuple(map(int, importlib_metadata.version("botocore").split(".")[:3]))
+    return tuple(map(int, metadata.version("botocore").split(".")[:3]))


 def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
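This drops the importlib_metadata backport in favour of the standard library importlib.metadata, which is available on every Python version the provider supports; the lookup itself is unchanged. A quick check (assuming botocore is installed; the printed version is only an example):

```python
from importlib import metadata

# Same call the helper now makes.
print(metadata.version("botocore"))  # e.g. "1.34.106"
print(tuple(map(int, metadata.version("botocore").split(".")[:3])))  # e.g. (1, 34, 106)
```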