apache-airflow-providers-amazon 8.23.0__py3-none-any.whl → 8.24.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- airflow/providers/amazon/LICENSE +4 -4
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
- airflow/providers/amazon/aws/hooks/batch_client.py +3 -0
- airflow/providers/amazon/aws/hooks/dynamodb.py +34 -1
- airflow/providers/amazon/aws/hooks/glue.py +123 -0
- airflow/providers/amazon/aws/operators/batch.py +8 -0
- airflow/providers/amazon/aws/operators/bedrock.py +6 -20
- airflow/providers/amazon/aws/operators/ecs.py +5 -5
- airflow/providers/amazon/aws/operators/emr.py +38 -30
- airflow/providers/amazon/aws/operators/glue.py +408 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
- airflow/providers/amazon/aws/sensors/glue.py +260 -2
- airflow/providers/amazon/aws/sensors/s3.py +35 -5
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
- airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +257 -0
- airflow/providers/amazon/aws/triggers/glue.py +76 -2
- airflow/providers/amazon/aws/waiters/dynamodb.json +37 -0
- airflow/providers/amazon/aws/waiters/glue.json +98 -0
- airflow/providers/amazon/get_provider_info.py +26 -13
- {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/METADATA +19 -18
- {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/RECORD +24 -23
- {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.23.0.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/glue.py

```diff
@@ -18,10 +18,18 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -91,3 +99,253 @@ class GlueJobSensor(BaseSensorOperator):
                     run_id=self.run_id,
                     continuation_tokens=self.next_log_tokens,
                 )
+
+
+class GlueDataQualityRuleSetEvaluationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality ruleset evaluation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleSetEvaluationRunSensor`
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param verify_result_status: Validate all the ruleset rules evaluation run results,
+        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("evaluation_run_id")
+
+    def __init__(
+        self,
+        *,
+        evaluation_run_id: str,
+        show_results: bool = True,
+        verify_result_status: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.evaluation_run_id = evaluation_run_id
+        self.show_results = show_results
+        self.verify_result_status = verify_result_status
+        self.aws_conn_id = aws_conn_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=self.evaluation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality ruleset evaluation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        self.log.info("AWS Glue data quality ruleset evaluation run completed.")
+
+    def poke(self, context: Context):
+        self.log.info(
+            "Poking for AWS Glue data quality ruleset evaluation run RunId: %s", self.evaluation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_ruleset_evaluation_run(RunId=self.evaluation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=self.evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s Run State: %s",
+                self.evaluation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality ruleset evaluation run RunId: {self.evaluation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
+
+
+class GlueDataQualityRuleRecommendationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality recommendation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleRecommendationRunSensor`
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("recommendation_run_id")
+
+    def __init__(
+        self,
+        *,
+        recommendation_run_id: str,
+        show_results: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.recommendation_run_id = recommendation_run_id
+        self.show_results = show_results
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=self.recommendation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality recommendation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+        self.log.info("AWS Glue data quality recommendation run completed.")
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for AWS Glue data quality recommendation run RunId: %s", self.recommendation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_rule_recommendation_run(RunId=self.recommendation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s Run State: %s",
+                self.recommendation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality recommendation run RunId: {self.recommendation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
```
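The new Glue data quality sensors can be wired into a DAG along these lines. This is a minimal sketch, assuming the evaluation and recommendation runs are started by upstream tasks that push their run IDs to XCom; the DAG id, task ids and XCom references are illustrative, not part of the provider.

```python
# Sketch only: polling Glue data quality runs started elsewhere in the DAG.
# The upstream tasks "start_evaluation_run" and "start_recommendation_run" are assumed to exist.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.glue import (
    GlueDataQualityRuleRecommendationRunSensor,
    GlueDataQualityRuleSetEvaluationRunSensor,
)

with DAG(
    dag_id="example_glue_data_quality",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # Waits for a ruleset evaluation run; because verify_result_status defaults to True,
    # the task fails if any rule ends up in Fail or Error.
    wait_for_evaluation = GlueDataQualityRuleSetEvaluationRunSensor(
        task_id="wait_for_evaluation",
        evaluation_run_id="{{ ti.xcom_pull(task_ids='start_evaluation_run') }}",
        poke_interval=60,
        max_retries=30,
    )

    # Waits for a rule recommendation run while deferred to the triggerer
    # instead of occupying a worker slot (requires aiobotocore).
    wait_for_recommendation = GlueDataQualityRuleRecommendationRunSensor(
        task_id="wait_for_recommendation",
        recommendation_run_id="{{ ti.xcom_pull(task_ids='start_recommendation_run') }}",
        deferrable=True,
    )
```

In deferrable mode the sensor passes `poke_interval` and `max_retries` to the trigger as `waiter_delay` and `waiter_max_attempts`, so the same polling settings apply whether or not the task defers.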
airflow/providers/amazon/aws/sensors/s3.py

```diff
@@ -78,6 +78,11 @@ class S3KeySensor(BaseSensorOperator):
         CA cert bundle than the one used by botocore.
     :param deferrable: Run operator in the deferrable mode
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
+        all available attributes.
+        Default value: "Size".
+        If the requested attribute is not found, the key is still included and the value is None.
     """
 
     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -93,6 +98,7 @@ class S3KeySensor(BaseSensorOperator):
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         use_regex: bool = False,
+        metadata_keys: list[str] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,14 +110,14 @@ class S3KeySensor(BaseSensorOperator):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
 
     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
 
         """
-        Set variable `files` which contains a list of dict which contains
-        If needed we might want to add other attributes later
+        Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
             'Size': int
         }]
@@ -123,8 +129,21 @@ class S3KeySensor(BaseSensorOperator):
             if not key_matches:
                 return False
 
-            # Reduce the set of metadata to
-            files = [
+            # Reduce the set of metadata to requested attributes
+            files = []
+            for f in key_matches:
+                metadata = {}
+                if "*" in self.metadata_keys:
+                    metadata = self.hook.head_object(f["Key"], bucket_name)
+                else:
+                    for key in self.metadata_keys:
+                        try:
+                            metadata[key] = f[key]
+                        except KeyError:
+                            # supplied key might be from head_object response
+                            self.log.info("Key %s not found in response, performing head_object", key)
+                            metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
+                files.append(metadata)
         elif self.use_regex:
             keys = self.hook.get_file_metadata("", bucket_name)
             key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
@@ -134,7 +153,18 @@ class S3KeySensor(BaseSensorOperator):
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
                 return False
-            files = [{"Size": obj["ContentLength"]}]
+            metadata = {}
+            if "*" in self.metadata_keys:
+                metadata = self.hook.head_object(key, bucket_name)
+
+            else:
+                for key in self.metadata_keys:
+                    # backwards compatibility with original implementation
+                    if key == "Size":
+                        metadata[key] = obj.get("ContentLength")
+                    else:
+                        metadata[key] = obj.get(key, None)
+            files = [metadata]
 
         if self.check_fn is not None:
             return self.check_fn(files)
```
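A minimal sketch of how the new `metadata_keys` parameter might be combined with `check_fn`; the bucket, key pattern and size threshold are illustrative.

```python
# Sketch only: gather Size and LastModified for every matched key and let check_fn decide.
# Attributes missing from the listing fall back to head_object; if still absent they come back as None.
from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def _big_enough(files: list[dict]) -> bool:
    # Each dict carries exactly the attributes named in metadata_keys.
    return all(f.get("Size") is not None and f["Size"] > 1024 for f in files)


with DAG(dag_id="example_s3_metadata_keys", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    wait_for_report = S3KeySensor(
        task_id="wait_for_report",
        bucket_name="my-data-bucket",
        bucket_key="reports/*.csv",
        wildcard_match=True,
        metadata_keys=["Size", "LastModified"],  # "*" would return every head_object attribute
        check_fn=_big_enough,
    )
```

When a requested attribute is not present in the listing the sensor falls back to `head_object`, so asking for attributes outside the list response adds one extra S3 call per matched key.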
airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py (new file)

```diff
@@ -0,0 +1,257 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, Sequence, TypedDict
+
+from botocore.exceptions import ClientError, WaiterError
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class AttributeDefinition(TypedDict):
+    """Attribute Definition Type."""
+
+    AttributeName: str
+    AttributeType: Literal["S", "N", "B"]
+
+
+class KeySchema(TypedDict):
+    """Key Schema Type."""
+
+    AttributeName: str
+    KeyType: Literal["HASH", "RANGE"]
+
+
+class S3ToDynamoDBOperator(BaseOperator):
+    """Load Data from S3 into a DynamoDB.
+
+    Data stored in S3 can be uploaded to a new or existing DynamoDB. Supported file formats CSV, DynamoDB JSON and
+    Amazon ION.
+
+
+    :param s3_bucket: The S3 bucket that is imported
+    :param s3_key: Key prefix that imports single or multiple objects from S3
+    :param dynamodb_table_name: Name of the table that shall be created
+    :param dynamodb_key_schema: Primary key and sort key. Each element represents one primary key
+        attribute. AttributeName is the name of the attribute. KeyType is the role for the attribute. Valid values
+        HASH or RANGE
+    :param dynamodb_attributes: Name of the attributes of a table. AttributeName is the name for the attribute
+        AttributeType is the data type for the attribute. Valid values for AttributeType are
+        S - attribute is of type String
+        N - attribute is of type Number
+        B - attribute is of type Binary
+    :param dynamodb_tmp_table_prefix: Prefix for the temporary DynamoDB table
+    :param delete_on_error: If set, the new DynamoDB table will be deleted in case of import errors
+    :param use_existing_table: Whether to import to an existing non new DynamoDB table. If set to
+        true data is loaded first into a temporary DynamoDB table (using the AWS ImportTable Service),
+        then retrieved as chunks into memory and loaded into the target table. If set to false, a new
+        DynamoDB table will be created and S3 data is bulk loaded by the AWS ImportTable Service.
+    :param input_format: The format for the imported data. Valid values for InputFormat are CSV, DYNAMODB_JSON
+        or ION
+    :param billing_mode: Billing mode for the table. Valid values are PROVISIONED or PAY_PER_REQUEST
+    :param on_demand_throughput: Extra options for maximum number of read and write units
+    :param import_table_kwargs: Any additional optional import table parameters to pass, such as ClientToken,
+        InputCompressionType, or InputFormatOptions. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html
+    :param import_table_creation_kwargs: Any additional optional import table creation parameters to pass, such as
+        ProvisionedThroughput, SSESpecification, or GlobalSecondaryIndexes. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html
+    :param wait_for_completion: Whether to wait for cluster to stop
+    :param check_interval: Time in seconds to wait between status checks
+    :param max_attempts: Maximum number of attempts to check for job completion
+    :param aws_conn_id: The reference to the AWS connection details
+    """
+
+    template_fields: Sequence[str] = (
+        "s3_bucket",
+        "s3_key",
+        "dynamodb_table_name",
+        "dynamodb_key_schema",
+        "dynamodb_attributes",
+        "dynamodb_tmp_table_prefix",
+        "delete_on_error",
+        "use_existing_table",
+        "input_format",
+        "billing_mode",
+        "import_table_kwargs",
+        "import_table_creation_kwargs",
+    )
+    ui_color = "#e2e8f0"
+
+    def __init__(
+        self,
+        *,
+        s3_bucket: str,
+        s3_key: str,
+        dynamodb_table_name: str,
+        dynamodb_key_schema: list[KeySchema],
+        dynamodb_attributes: list[AttributeDefinition] | None = None,
+        dynamodb_tmp_table_prefix: str = "tmp",
+        delete_on_error: bool = False,
+        use_existing_table: bool = False,
+        input_format: Literal["CSV", "DYNAMODB_JSON", "ION"] = "DYNAMODB_JSON",
+        billing_mode: Literal["PROVISIONED", "PAY_PER_REQUEST"] = "PAY_PER_REQUEST",
+        import_table_kwargs: dict[str, Any] | None = None,
+        import_table_creation_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        max_attempts: int = 240,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.dynamodb_table_name = dynamodb_table_name
+        self.dynamodb_attributes = dynamodb_attributes
+        self.dynamodb_tmp_table_prefix = dynamodb_tmp_table_prefix
+        self.delete_on_error = delete_on_error
+        self.use_existing_table = use_existing_table
+        self.dynamodb_key_schema = dynamodb_key_schema
+        self.input_format = input_format
+        self.billing_mode = billing_mode
+        self.import_table_kwargs = import_table_kwargs
+        self.import_table_creation_kwargs = import_table_creation_kwargs
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    @property
+    def tmp_table_name(self):
+        """Temporary table name."""
+        return f"{self.dynamodb_tmp_table_prefix}_{self.dynamodb_table_name}"
+
+    def _load_into_new_table(self, table_name: str, delete_on_error: bool) -> str:
+        """
+        Import S3 key or keys into a new DynamoDB table.
+
+        :param table_name: Name of the table that shall be created
+        :param delete_on_error: If set, the new DynamoDB table will be deleted in case of import errors
+        :return: The Amazon resource number (ARN)
+        """
+        dynamodb_hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
+        client = dynamodb_hook.client
+
+        import_table_config = self.import_table_kwargs or {}
+        import_table_creation_config = self.import_table_creation_kwargs or {}
+
+        try:
+            response = client.import_table(
+                S3BucketSource={
+                    "S3Bucket": self.s3_bucket,
+                    "S3KeyPrefix": self.s3_key,
+                },
+                InputFormat=self.input_format,
+                TableCreationParameters={
+                    "TableName": table_name,
+                    "AttributeDefinitions": self.dynamodb_attributes,
+                    "KeySchema": self.dynamodb_key_schema,
+                    "BillingMode": self.billing_mode,
+                    **import_table_creation_config,
+                },
+                **import_table_config,
+            )
+        except ClientError as e:
+            self.log.error("Error: failed to load from S3 into DynamoDB table. Error: %s", str(e))
+            raise AirflowException(f"S3 load into DynamoDB table failed with error: {e}")
+
+        if response["ImportTableDescription"]["ImportStatus"] == "FAILED":
+            raise AirflowException(
+                "S3 into Dynamodb job creation failed. Code: "
+                f"{response['ImportTableDescription']['FailureCode']}. "
+                f"Failure: {response['ImportTableDescription']['FailureMessage']}"
+            )
+
+        if self.wait_for_completion:
+            self.log.info("Waiting for S3 into Dynamodb job to complete")
+            waiter = dynamodb_hook.get_waiter("import_table")
+            try:
+                waiter.wait(
+                    ImportArn=response["ImportTableDescription"]["ImportArn"],
+                    WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_attempts},
+                )
+            except WaiterError:
+                status, error_code, error_msg = dynamodb_hook.get_import_status(
+                    response["ImportTableDescription"]["ImportArn"]
+                )
+                if delete_on_error:
+                    client.delete_table(TableName=table_name)
+                raise AirflowException(
+                    f"S3 import into Dynamodb job failed: Status: {status}. Error: {error_code}. Error message: {error_msg}"
+                )
+        return response["ImportTableDescription"]["ImportArn"]
+
+    def _load_into_existing_table(self) -> str:
+        """
+        Import S3 key or keys in an existing DynamoDB table.
+
+        :return:The Amazon resource number (ARN)
+        """
+        if not self.wait_for_completion:
+            raise ValueError("wait_for_completion must be set to True when loading into an existing table")
+        table_keys = [key["AttributeName"] for key in self.dynamodb_key_schema]
+
+        dynamodb_hook = DynamoDBHook(
+            aws_conn_id=self.aws_conn_id, table_name=self.dynamodb_table_name, table_keys=table_keys
+        )
+        client = dynamodb_hook.client
+
+        self.log.info("Loading from S3 into a tmp DynamoDB table %s", self.tmp_table_name)
+        self._load_into_new_table(table_name=self.tmp_table_name, delete_on_error=self.delete_on_error)
+        total_items = 0
+        try:
+            paginator = client.get_paginator("scan")
+            paginate = paginator.paginate(
+                TableName=self.tmp_table_name,
+                Select="ALL_ATTRIBUTES",
+                ReturnConsumedCapacity="NONE",
+                ConsistentRead=True,
+            )
+            self.log.info(
+                "Loading data from %s to %s DynamoDB table", self.tmp_table_name, self.dynamodb_table_name
+            )
+            for page in paginate:
+                total_items += page.get("Count", 0)
+                dynamodb_hook.write_batch_data(items=page["Items"])
+            self.log.info("Number of items loaded: %s", total_items)
+        finally:
+            self.log.info("Delete tmp DynamoDB table %s", self.tmp_table_name)
+            client.delete_table(TableName=self.tmp_table_name)
+        return dynamodb_hook.get_conn().Table(self.dynamodb_table_name).table_arn
+
+    def execute(self, context: Context) -> str:
+        """
+        Execute S3 to DynamoDB Job from Airflow.
+
+        :param context: The current context of the task instance
+        :return: The Amazon resource number (ARN)
+        """
+        if self.use_existing_table:
+            self.log.info("Loading from S3 into existing DynamoDB table %s", self.dynamodb_table_name)
+            return self._load_into_existing_table()
+        self.log.info("Loading from S3 into new DynamoDB table %s", self.dynamodb_table_name)
+        return self._load_into_new_table(
+            table_name=self.dynamodb_table_name, delete_on_error=self.delete_on_error
+        )
```