apache-airflow-providers-amazon 8.23.0rc1__py3-none-any.whl → 8.24.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
Files changed (24)
  1. airflow/providers/amazon/LICENSE +4 -4
  2. airflow/providers/amazon/__init__.py +1 -1
  3. airflow/providers/amazon/aws/hooks/base_aws.py +8 -3
  4. airflow/providers/amazon/aws/hooks/batch_client.py +3 -0
  5. airflow/providers/amazon/aws/hooks/dynamodb.py +34 -1
  6. airflow/providers/amazon/aws/hooks/glue.py +123 -0
  7. airflow/providers/amazon/aws/operators/batch.py +8 -0
  8. airflow/providers/amazon/aws/operators/bedrock.py +6 -20
  9. airflow/providers/amazon/aws/operators/ecs.py +5 -5
  10. airflow/providers/amazon/aws/operators/emr.py +38 -30
  11. airflow/providers/amazon/aws/operators/glue.py +408 -2
  12. airflow/providers/amazon/aws/operators/sagemaker.py +85 -12
  13. airflow/providers/amazon/aws/sensors/glue.py +260 -2
  14. airflow/providers/amazon/aws/sensors/s3.py +35 -5
  15. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +0 -1
  16. airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py +257 -0
  17. airflow/providers/amazon/aws/triggers/glue.py +76 -2
  18. airflow/providers/amazon/aws/waiters/dynamodb.json +37 -0
  19. airflow/providers/amazon/aws/waiters/glue.json +98 -0
  20. airflow/providers/amazon/get_provider_info.py +26 -13
  21. {apache_airflow_providers_amazon-8.23.0rc1.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/METADATA +22 -21
  22. {apache_airflow_providers_amazon-8.23.0rc1.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/RECORD +24 -23
  23. {apache_airflow_providers_amazon-8.23.0rc1.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/WHEEL +0 -0
  24. {apache_airflow_providers_amazon-8.23.0rc1.dist-info → apache_airflow_providers_amazon-8.24.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/glue.py
@@ -18,10 +18,18 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.glue import (
+    GlueDataQualityRuleRecommendationRunCompleteTrigger,
+    GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -91,3 +99,253 @@ class GlueJobSensor(BaseSensorOperator):
                     run_id=self.run_id,
                     continuation_tokens=self.next_log_tokens,
                 )
+
+
+class GlueDataQualityRuleSetEvaluationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality ruleset evaluation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleSetEvaluationRunSensor`
+
+    :param evaluation_run_id: The AWS Glue data quality ruleset evaluation run identifier.
+    :param verify_result_status: Validate all the ruleset rules evaluation run results,
+        If any of the rule status is Fail or Error then an exception is thrown. (default: True)
+    :param show_results: Displays all the ruleset rules evaluation run results. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("evaluation_run_id")
+
+    def __init__(
+        self,
+        *,
+        evaluation_run_id: str,
+        show_results: bool = True,
+        verify_result_status: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.evaluation_run_id = evaluation_run_id
+        self.show_results = show_results
+        self.verify_result_status = verify_result_status
+        self.aws_conn_id = aws_conn_id
+        self.max_retries = max_retries
+        self.poke_interval = poke_interval
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleSetEvaluationRunCompleteTrigger(
+                    evaluation_run_id=self.evaluation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality ruleset evaluation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        self.hook.validate_evaluation_run_results(
+            evaluation_run_id=event["evaluation_run_id"],
+            show_results=self.show_results,
+            verify_result_status=self.verify_result_status,
+        )
+
+        self.log.info("AWS Glue data quality ruleset evaluation run completed.")
+
+    def poke(self, context: Context):
+        self.log.info(
+            "Poking for AWS Glue data quality ruleset evaluation run RunId: %s", self.evaluation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_ruleset_evaluation_run(RunId=self.evaluation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            self.hook.validate_evaluation_run_results(
+                evaluation_run_id=self.evaluation_run_id,
+                show_results=self.show_results,
+                verify_result_status=self.verify_result_status,
+            )
+
+            self.log.info(
+                "AWS Glue data quality ruleset evaluation run completed RunId: %s Run State: %s",
+                self.evaluation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality ruleset evaluation run RunId: {self.evaluation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
+
+
+class GlueDataQualityRuleRecommendationRunSensor(AwsBaseSensor[GlueDataQualityHook]):
+    """
+    Waits for an AWS Glue data quality recommendation run to reach any of the status below.
+
+    'FAILED', 'STOPPED', 'STOPPING', 'TIMEOUT', 'SUCCEEDED'
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:GlueDataQualityRuleRecommendationRunSensor`
+
+    :param recommendation_run_id: The AWS Glue data quality rule recommendation run identifier.
+    :param show_results: Displays the recommended ruleset (a set of rules), when recommendation run completes. (default: True)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    SUCCESS_STATES = ("SUCCEEDED",)
+
+    FAILURE_STATES = ("FAILED", "STOPPED", "STOPPING", "TIMEOUT")
+
+    aws_hook_class = GlueDataQualityHook
+    template_fields: Sequence[str] = aws_template_fields("recommendation_run_id")
+
+    def __init__(
+        self,
+        *,
+        recommendation_run_id: str,
+        show_results: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.recommendation_run_id = recommendation_run_id
+        self.show_results = show_results
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueDataQualityRuleRecommendationRunCompleteTrigger(
+                    recommendation_run_id=self.recommendation_run_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
+            message = f"Error: AWS Glue data quality recommendation run: {event}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+
+        if self.show_results:
+            self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+        self.log.info("AWS Glue data quality recommendation run completed.")
+
+    def poke(self, context: Context) -> bool:
+        self.log.info(
+            "Poking for AWS Glue data quality recommendation run RunId: %s", self.recommendation_run_id
+        )
+
+        response = self.hook.conn.get_data_quality_rule_recommendation_run(RunId=self.recommendation_run_id)
+
+        status = response.get("Status")
+
+        if status in self.SUCCESS_STATES:
+            if self.show_results:
+                self.hook.log_recommendation_results(run_id=self.recommendation_run_id)
+
+            self.log.info(
+                "AWS Glue data quality recommendation run completed RunId: %s Run State: %s",
+                self.recommendation_run_id,
+                response["Status"],
+            )
+
+            return True
+
+        elif status in self.FAILURE_STATES:
+            job_error_message = (
+                f"Error: AWS Glue data quality recommendation run RunId: {self.recommendation_run_id} Run "
+                f"Status: {status}"
+                f": {response.get('ErrorString')}"
+            )
+            self.log.info(job_error_message)
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(job_error_message)
+            raise AirflowException(job_error_message)
+        else:
+            return False
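The two sensors above poll (or, in deferrable mode, trigger on) the Glue Data Quality run status and validate the results on success. As a minimal usage sketch (the DAG scaffolding, task ids, and the upstream task that starts the evaluation run are illustrative assumptions, not part of this diff):

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.glue import GlueDataQualityRuleSetEvaluationRunSensor

    with DAG(dag_id="glue_dq_example", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
        wait_for_evaluation = GlueDataQualityRuleSetEvaluationRunSensor(
            task_id="wait_for_evaluation",
            # The run id would normally come from the task that started the evaluation run;
            # the upstream task id "start_evaluation" is an assumption for illustration.
            evaluation_run_id="{{ task_instance.xcom_pull(task_ids='start_evaluation') }}",
            show_results=True,
            verify_result_status=True,
            poke_interval=60,
            max_retries=30,
            deferrable=True,  # requires aiobotocore to be installed
        )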
airflow/providers/amazon/aws/sensors/s3.py
@@ -78,6 +78,11 @@ class S3KeySensor(BaseSensorOperator):
             CA cert bundle than the one used by botocore.
     :param deferrable: Run operator in the deferrable mode
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
+        all available attributes.
+        Default value: "Size".
+        If the requested attribute is not found, the key is still included and the value is None.
     """
 
     template_fields: Sequence[str] = ("bucket_key", "bucket_name")
@@ -93,6 +98,7 @@ class S3KeySensor(BaseSensorOperator):
         verify: str | bool | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         use_regex: bool = False,
+        metadata_keys: list[str] | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -104,14 +110,14 @@ class S3KeySensor(BaseSensorOperator):
         self.verify = verify
         self.deferrable = deferrable
         self.use_regex = use_regex
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
 
     def _check_key(self, key):
         bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
         self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
 
         """
-        Set variable `files` which contains a list of dict which contains only the size
-        If needed we might want to add other attributes later
+        Set variable `files` which contains a list of dict which contains attributes defined by the user
         Format: [{
             'Size': int
         }]
@@ -123,8 +129,21 @@ class S3KeySensor(BaseSensorOperator):
             if not key_matches:
                 return False
 
-            # Reduce the set of metadata to size only
-            files = [{"Size": f["Size"]} for f in key_matches]
+            # Reduce the set of metadata to requested attributes
+            files = []
+            for f in key_matches:
+                metadata = {}
+                if "*" in self.metadata_keys:
+                    metadata = self.hook.head_object(f["Key"], bucket_name)
+                else:
+                    for key in self.metadata_keys:
+                        try:
+                            metadata[key] = f[key]
+                        except KeyError:
+                            # supplied key might be from head_object response
+                            self.log.info("Key %s not found in response, performing head_object", key)
+                            metadata[key] = self.hook.head_object(f["Key"], bucket_name).get(key, None)
+                files.append(metadata)
         elif self.use_regex:
             keys = self.hook.get_file_metadata("", bucket_name)
             key_matches = [k for k in keys if re.match(pattern=key, string=k["Key"])]
@@ -134,7 +153,18 @@ class S3KeySensor(BaseSensorOperator):
             obj = self.hook.head_object(key, bucket_name)
             if obj is None:
                 return False
-            files = [{"Size": obj["ContentLength"]}]
+            metadata = {}
+            if "*" in self.metadata_keys:
+                metadata = self.hook.head_object(key, bucket_name)
+
+            else:
+                for key in self.metadata_keys:
+                    # backwards compatibility with original implementation
+                    if key == "Size":
+                        metadata[key] = obj.get("ContentLength")
+                    else:
+                        metadata[key] = obj.get(key, None)
+            files = [metadata]
 
         if self.check_fn is not None:
             return self.check_fn(files)
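With this change, check_fn receives whichever head_object attributes were requested via metadata_keys (with "Size" still mapped from ContentLength for backwards compatibility, and missing attributes returned as None). A hedged sketch of how the new parameter might be used; the bucket, key, and check function are assumptions for illustration:

    from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

    def check_fn(files: list[dict]) -> bool:
        # Each entry carries the requested attributes; a missing attribute comes back as None.
        return all((f.get("Size") or 0) > 0 for f in files)

    wait_for_object = S3KeySensor(
        task_id="wait_for_object",
        bucket_name="example-bucket",            # assumption
        bucket_key="incoming/data.csv",          # assumption
        metadata_keys=["Size", "LastModified"],  # pass ["*"] to forward the full head_object response
        check_fn=check_fn,
    )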
@@ -105,7 +105,6 @@ class DynamoDBToS3Operator(AwsToAwsBaseOperator):
         "file_size",
         "dynamodb_scan_kwargs",
         "s3_key_prefix",
-        "process_func",
         "export_time",
         "export_format",
         "check_interval",
airflow/providers/amazon/aws/transfers/s3_to_dynamodb.py (new file)
@@ -0,0 +1,257 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Literal, Sequence, TypedDict
+
+from botocore.exceptions import ClientError, WaiterError
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class AttributeDefinition(TypedDict):
+    """Attribute Definition Type."""
+
+    AttributeName: str
+    AttributeType: Literal["S", "N", "B"]
+
+
+class KeySchema(TypedDict):
+    """Key Schema Type."""
+
+    AttributeName: str
+    KeyType: Literal["HASH", "RANGE"]
+
+
+class S3ToDynamoDBOperator(BaseOperator):
+    """Load Data from S3 into a DynamoDB.
+
+    Data stored in S3 can be uploaded to a new or existing DynamoDB. Supported file formats CSV, DynamoDB JSON and
+    Amazon ION.
+
+
+    :param s3_bucket: The S3 bucket that is imported
+    :param s3_key: Key prefix that imports single or multiple objects from S3
+    :param dynamodb_table_name: Name of the table that shall be created
+    :param dynamodb_key_schema: Primary key and sort key. Each element represents one primary key
+        attribute. AttributeName is the name of the attribute. KeyType is the role for the attribute. Valid values
+        HASH or RANGE
+    :param dynamodb_attributes: Name of the attributes of a table. AttributeName is the name for the attribute
+        AttributeType is the data type for the attribute. Valid values for AttributeType are
+          S - attribute is of type String
+          N - attribute is of type Number
+          B - attribute is of type Binary
+    :param dynamodb_tmp_table_prefix: Prefix for the temporary DynamoDB table
+    :param delete_on_error: If set, the new DynamoDB table will be deleted in case of import errors
+    :param use_existing_table: Whether to import to an existing non new DynamoDB table. If set to
+        true data is loaded first into a temporary DynamoDB table (using the AWS ImportTable Service),
+        then retrieved as chunks into memory and loaded into the target table. If set to false, a new
+        DynamoDB table will be created and S3 data is bulk loaded by the AWS ImportTable Service.
+    :param input_format: The format for the imported data. Valid values for InputFormat are CSV, DYNAMODB_JSON
+        or ION
+    :param billing_mode: Billing mode for the table. Valid values are PROVISIONED or PAY_PER_REQUEST
+    :param on_demand_throughput: Extra options for maximum number of read and write units
+    :param import_table_kwargs: Any additional optional import table parameters to pass, such as ClientToken,
+        InputCompressionType, or InputFormatOptions. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html
+    :param import_table_creation_kwargs: Any additional optional import table creation parameters to pass, such as
+        ProvisionedThroughput, SSESpecification, or GlobalSecondaryIndexes. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/import_table.html
+    :param wait_for_completion: Whether to wait for cluster to stop
+    :param check_interval: Time in seconds to wait between status checks
+    :param max_attempts: Maximum number of attempts to check for job completion
+    :param aws_conn_id: The reference to the AWS connection details
+    """
+
+    template_fields: Sequence[str] = (
+        "s3_bucket",
+        "s3_key",
+        "dynamodb_table_name",
+        "dynamodb_key_schema",
+        "dynamodb_attributes",
+        "dynamodb_tmp_table_prefix",
+        "delete_on_error",
+        "use_existing_table",
+        "input_format",
+        "billing_mode",
+        "import_table_kwargs",
+        "import_table_creation_kwargs",
+    )
+    ui_color = "#e2e8f0"
+
+    def __init__(
+        self,
+        *,
+        s3_bucket: str,
+        s3_key: str,
+        dynamodb_table_name: str,
+        dynamodb_key_schema: list[KeySchema],
+        dynamodb_attributes: list[AttributeDefinition] | None = None,
+        dynamodb_tmp_table_prefix: str = "tmp",
+        delete_on_error: bool = False,
+        use_existing_table: bool = False,
+        input_format: Literal["CSV", "DYNAMODB_JSON", "ION"] = "DYNAMODB_JSON",
+        billing_mode: Literal["PROVISIONED", "PAY_PER_REQUEST"] = "PAY_PER_REQUEST",
+        import_table_kwargs: dict[str, Any] | None = None,
+        import_table_creation_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        check_interval: int = 30,
+        max_attempts: int = 240,
+        aws_conn_id: str | None = "aws_default",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.dynamodb_table_name = dynamodb_table_name
+        self.dynamodb_attributes = dynamodb_attributes
+        self.dynamodb_tmp_table_prefix = dynamodb_tmp_table_prefix
+        self.delete_on_error = delete_on_error
+        self.use_existing_table = use_existing_table
+        self.dynamodb_key_schema = dynamodb_key_schema
+        self.input_format = input_format
+        self.billing_mode = billing_mode
+        self.import_table_kwargs = import_table_kwargs
+        self.import_table_creation_kwargs = import_table_creation_kwargs
+        self.wait_for_completion = wait_for_completion
+        self.check_interval = check_interval
+        self.max_attempts = max_attempts
+        self.aws_conn_id = aws_conn_id
+
+    @property
+    def tmp_table_name(self):
+        """Temporary table name."""
+        return f"{self.dynamodb_tmp_table_prefix}_{self.dynamodb_table_name}"
+
+    def _load_into_new_table(self, table_name: str, delete_on_error: bool) -> str:
+        """
+        Import S3 key or keys into a new DynamoDB table.
+
+        :param table_name: Name of the table that shall be created
+        :param delete_on_error: If set, the new DynamoDB table will be deleted in case of import errors
+        :return: The Amazon resource number (ARN)
+        """
+        dynamodb_hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
+        client = dynamodb_hook.client
+
+        import_table_config = self.import_table_kwargs or {}
+        import_table_creation_config = self.import_table_creation_kwargs or {}
+
+        try:
+            response = client.import_table(
+                S3BucketSource={
+                    "S3Bucket": self.s3_bucket,
+                    "S3KeyPrefix": self.s3_key,
+                },
+                InputFormat=self.input_format,
+                TableCreationParameters={
+                    "TableName": table_name,
+                    "AttributeDefinitions": self.dynamodb_attributes,
+                    "KeySchema": self.dynamodb_key_schema,
+                    "BillingMode": self.billing_mode,
+                    **import_table_creation_config,
+                },
+                **import_table_config,
+            )
+        except ClientError as e:
+            self.log.error("Error: failed to load from S3 into DynamoDB table. Error: %s", str(e))
+            raise AirflowException(f"S3 load into DynamoDB table failed with error: {e}")
+
+        if response["ImportTableDescription"]["ImportStatus"] == "FAILED":
+            raise AirflowException(
+                "S3 into Dynamodb job creation failed. Code: "
+                f"{response['ImportTableDescription']['FailureCode']}. "
+                f"Failure: {response['ImportTableDescription']['FailureMessage']}"
+            )
+
+        if self.wait_for_completion:
+            self.log.info("Waiting for S3 into Dynamodb job to complete")
+            waiter = dynamodb_hook.get_waiter("import_table")
+            try:
+                waiter.wait(
+                    ImportArn=response["ImportTableDescription"]["ImportArn"],
+                    WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_attempts},
+                )
+            except WaiterError:
+                status, error_code, error_msg = dynamodb_hook.get_import_status(
+                    response["ImportTableDescription"]["ImportArn"]
+                )
+                if delete_on_error:
+                    client.delete_table(TableName=table_name)
+                raise AirflowException(
+                    f"S3 import into Dynamodb job failed: Status: {status}. Error: {error_code}. Error message: {error_msg}"
+                )
+        return response["ImportTableDescription"]["ImportArn"]
+
+    def _load_into_existing_table(self) -> str:
+        """
+        Import S3 key or keys in an existing DynamoDB table.
+
+        :return:The Amazon resource number (ARN)
+        """
+        if not self.wait_for_completion:
+            raise ValueError("wait_for_completion must be set to True when loading into an existing table")
+        table_keys = [key["AttributeName"] for key in self.dynamodb_key_schema]
+
+        dynamodb_hook = DynamoDBHook(
+            aws_conn_id=self.aws_conn_id, table_name=self.dynamodb_table_name, table_keys=table_keys
+        )
+        client = dynamodb_hook.client
+
+        self.log.info("Loading from S3 into a tmp DynamoDB table %s", self.tmp_table_name)
+        self._load_into_new_table(table_name=self.tmp_table_name, delete_on_error=self.delete_on_error)
+        total_items = 0
+        try:
+            paginator = client.get_paginator("scan")
+            paginate = paginator.paginate(
+                TableName=self.tmp_table_name,
+                Select="ALL_ATTRIBUTES",
+                ReturnConsumedCapacity="NONE",
+                ConsistentRead=True,
+            )
+            self.log.info(
+                "Loading data from %s to %s DynamoDB table", self.tmp_table_name, self.dynamodb_table_name
+            )
+            for page in paginate:
+                total_items += page.get("Count", 0)
+                dynamodb_hook.write_batch_data(items=page["Items"])
+            self.log.info("Number of items loaded: %s", total_items)
+        finally:
+            self.log.info("Delete tmp DynamoDB table %s", self.tmp_table_name)
+            client.delete_table(TableName=self.tmp_table_name)
+        return dynamodb_hook.get_conn().Table(self.dynamodb_table_name).table_arn
+
+    def execute(self, context: Context) -> str:
+        """
+        Execute S3 to DynamoDB Job from Airflow.
+
+        :param context: The current context of the task instance
+        :return: The Amazon resource number (ARN)
+        """
+        if self.use_existing_table:
+            self.log.info("Loading from S3 into existing DynamoDB table %s", self.dynamodb_table_name)
+            return self._load_into_existing_table()
+        self.log.info("Loading from S3 into new DynamoDB table %s", self.dynamodb_table_name)
+        return self._load_into_new_table(
+            table_name=self.dynamodb_table_name, delete_on_error=self.delete_on_error
+        )
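The new transfer operator either bulk-imports the S3 data into a freshly created table via the DynamoDB ImportTable service, or stages it in a temporary table and copies it into an existing one. A minimal usage sketch (the bucket, key prefix, table name, and key schema below are illustrative assumptions):

    from airflow.providers.amazon.aws.transfers.s3_to_dynamodb import S3ToDynamoDBOperator

    import_users = S3ToDynamoDBOperator(
        task_id="import_users",
        s3_bucket="example-bucket",    # assumption
        s3_key="exports/users/",       # key prefix, assumption
        dynamodb_table_name="users",   # assumption
        dynamodb_key_schema=[{"AttributeName": "user_id", "KeyType": "HASH"}],
        dynamodb_attributes=[{"AttributeName": "user_id", "AttributeType": "S"}],
        input_format="DYNAMODB_JSON",
        use_existing_table=False,  # create a new table via the DynamoDB ImportTable service
        wait_for_completion=True,
    )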