apache-airflow-providers-amazon 7.4.1rc1__py3-none-any.whl → 8.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/aws/hooks/athena.py +0 -15
- airflow/providers/amazon/aws/hooks/base_aws.py +98 -65
- airflow/providers/amazon/aws/hooks/batch_client.py +60 -27
- airflow/providers/amazon/aws/hooks/batch_waiters.py +3 -1
- airflow/providers/amazon/aws/hooks/emr.py +33 -74
- airflow/providers/amazon/aws/hooks/logs.py +22 -4
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -12
- airflow/providers/amazon/aws/hooks/sagemaker.py +0 -16
- airflow/providers/amazon/aws/links/emr.py +1 -3
- airflow/providers/amazon/aws/operators/athena.py +0 -15
- airflow/providers/amazon/aws/operators/batch.py +78 -24
- airflow/providers/amazon/aws/operators/ecs.py +21 -58
- airflow/providers/amazon/aws/operators/eks.py +0 -1
- airflow/providers/amazon/aws/operators/emr.py +94 -24
- airflow/providers/amazon/aws/operators/lambda_function.py +0 -19
- airflow/providers/amazon/aws/operators/rds.py +1 -1
- airflow/providers/amazon/aws/operators/redshift_cluster.py +22 -1
- airflow/providers/amazon/aws/operators/redshift_data.py +0 -62
- airflow/providers/amazon/aws/secrets/secrets_manager.py +0 -17
- airflow/providers/amazon/aws/secrets/systems_manager.py +0 -21
- airflow/providers/amazon/aws/sensors/dynamodb.py +97 -0
- airflow/providers/amazon/aws/sensors/emr.py +1 -2
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +0 -19
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -7
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +10 -10
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +0 -10
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +0 -11
- airflow/providers/amazon/aws/transfers/s3_to_sftp.py +0 -10
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +23 -9
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
- airflow/providers/amazon/aws/waiters/base_waiter.py +12 -1
- airflow/providers/amazon/aws/waiters/emr-serverless.json +18 -0
- airflow/providers/amazon/get_provider_info.py +35 -30
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/METADATA +81 -4
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/RECORD +41 -41
- airflow/providers/amazon/aws/operators/aws_lambda.py +0 -29
- airflow/providers/amazon/aws/operators/redshift_sql.py +0 -57
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/top_level.txt +0 -0
airflow/providers/amazon/aws/operators/redshift_cluster.py

@@ -22,7 +22,10 @@ from typing import TYPE_CHECKING, Any, Sequence
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
-from airflow.providers.amazon.aws.triggers.redshift_cluster import
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftClusterTrigger,
+    RedshiftCreateClusterTrigger,
+)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
     :param wait_for_completion: Whether wait for the cluster to be in ``available`` state
     :param max_attempt: The maximum number of attempts to be made. Default: 5
     :param poll_interval: The amount of time in seconds to wait between attempts. Default: 60
+    :param deferrable: If True, the operator will run in deferrable mode
     """
 
     template_fields: Sequence[str] = (
@@ -140,6 +144,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
         wait_for_completion: bool = False,
         max_attempt: int = 5,
         poll_interval: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -180,6 +185,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.max_attempt = max_attempt
         self.poll_interval = poll_interval
+        self.deferrable = deferrable
         self.kwargs = kwargs
 
     def execute(self, context: Context):
@@ -252,6 +258,16 @@ class RedshiftCreateClusterOperator(BaseOperator):
             self.master_user_password,
             params,
         )
+        if self.deferrable:
+            self.defer(
+                trigger=RedshiftCreateClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
+                    poll_interval=self.poll_interval,
+                    max_attempt=self.max_attempt,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
         if self.wait_for_completion:
             redshift_hook.get_conn().get_waiter("cluster_available").wait(
                 ClusterIdentifier=self.cluster_identifier,
@@ -264,6 +280,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info("Created Redshift cluster %s", self.cluster_identifier)
         self.log.info(cluster)
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error creating cluster: {event}")
+        return
+
 
 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
     """
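For context, a minimal usage sketch of the new deferrable path. The DAG id, cluster settings, and connection id below are illustrative placeholders, not taken from this diff; with deferrable=True the operator hands polling off to RedshiftCreateClusterTrigger and resumes in execute_complete().

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_cluster import RedshiftCreateClusterOperator

with DAG(dag_id="example_redshift_create_cluster", start_date=datetime(2023, 1, 1)) as dag:
    create_cluster = RedshiftCreateClusterOperator(
        task_id="create_cluster",
        cluster_identifier="example-cluster",   # placeholder identifier
        node_type="dc2.large",                  # assumed required cluster parameters
        master_username="awsuser",
        master_user_password="change-me",       # supply via a secrets backend in practice
        deferrable=True,                        # new in 8.0.0: defer to RedshiftCreateClusterTrigger
        poll_interval=60,
        max_attempt=5,
        aws_conn_id="aws_default",
    )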
airflow/providers/amazon/aws/operators/redshift_data.py

@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING
 
 from airflow.compat.functools import cached_property
@@ -83,7 +82,6 @@ class RedshiftDataOperator(BaseOperator):
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
-        await_result: bool | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -95,16 +93,7 @@ class RedshiftDataOperator(BaseOperator):
         self.secret_arn = secret_arn
         self.statement_name = statement_name
         self.with_event = with_event
-        self.await_result = await_result
         self.wait_for_completion = wait_for_completion
-        if await_result:
-            warnings.warn(
-                f"Parameter `{self.__class__.__name__}.await_result` is deprecated and will be removed "
-                "in a future release. Please use method `wait_for_completion` instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-            self.wait_for_completion = await_result
         if poll_interval > 0:
             self.poll_interval = poll_interval
         else:
@@ -122,57 +111,6 @@ class RedshiftDataOperator(BaseOperator):
         """Create and return an RedshiftDataHook."""
         return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
 
-    def execute_query(self) -> str:
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.statement_id = self.hook.execute_query(
-            database=self.database,
-            sql=self.sql,
-            cluster_identifier=self.cluster_identifier,
-            db_user=self.db_user,
-            parameters=self.parameters,
-            secret_arn=self.secret_arn,
-            statement_name=self.statement_name,
-            with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
-            poll_interval=self.poll_interval,
-        )
-        return self.statement_id
-
-    def execute_batch_query(self) -> str:
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.statement_id = self.hook.execute_query(
-            database=self.database,
-            sql=self.sql,
-            cluster_identifier=self.cluster_identifier,
-            db_user=self.db_user,
-            parameters=self.parameters,
-            secret_arn=self.secret_arn,
-            statement_name=self.statement_name,
-            with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
-            poll_interval=self.poll_interval,
-        )
-        return self.statement_id
-
-    def wait_for_results(self, statement_id: str):
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.hook.wait_for_results(statement_id=statement_id, poll_interval=self.poll_interval)
-
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift"""
         self.log.info("Executing statement: %s", self.sql)
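The removed operator-level helpers were thin wrappers around the hook. A hedged sketch of the replacement call, using the same keyword arguments the removed code forwarded; connection id, database, cluster, and SQL below are placeholders:

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook(aws_conn_id="aws_default")

# Same keyword arguments the removed RedshiftDataOperator.execute_query() passed through.
statement_id = hook.execute_query(
    database="dev",                         # placeholder database
    sql="SELECT 1",
    cluster_identifier="example-cluster",   # placeholder cluster
    wait_for_completion=True,
    poll_interval=10,
)
print(statement_id)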
airflow/providers/amazon/aws/secrets/secrets_manager.py

@@ -267,23 +267,6 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
         else:
             return secret
 
-    def get_conn_uri(self, conn_id: str) -> str | None:
-        """
-        Return URI representation of Connection conn_id.
-
-        As of Airflow version 2.3.0 this method is deprecated.
-
-        :param conn_id: the connection id
-        :return: deserialized Connection
-        """
-        warnings.warn(
-            f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
-            "in a future release. Please use method `get_conn_value` instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.get_conn_value(conn_id)
-
     def get_variable(self, key: str) -> str | None:
         """
         Get Airflow Variable
airflow/providers/amazon/aws/secrets/systems_manager.py

@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import re
-import warnings
 
 from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -143,26 +142,6 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
 
         return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern)
 
-    def get_conn_uri(self, conn_id: str) -> str | None:
-        """
-        Return URI representation of Connection conn_id.
-
-        As of Airflow version 2.3.0 this method is deprecated.
-
-        :param conn_id: the connection id
-        """
-        warnings.warn(
-            f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
-            "in a future release. Please use method `get_conn_value` instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        value = self.get_conn_value(conn_id)
-        if value is None:
-            return None
-
-        return self.deserialize_connection(conn_id, value).get_uri()
-
     def get_variable(self, key: str) -> str | None:
         """
         Get Airflow Variable
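With get_conn_uri() gone from both AWS secrets backends, callers read secrets through get_conn_value() and let Airflow deserialize them. A minimal sketch; the prefix and connection id are placeholders:

from airflow.providers.amazon.aws.secrets.secrets_manager import SecretsManagerBackend

backend = SecretsManagerBackend(connections_prefix="airflow/connections")  # placeholder prefix
# get_conn_value() returns the raw stored secret (URI or JSON string) or None;
# Airflow itself turns it into a Connection object.
raw_secret = backend.get_conn_value("my_conn")
print(raw_secret)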
airflow/providers/amazon/aws/sensors/dynamodb.py (new file)

@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class DynamoDBValueSensor(BaseSensorOperator):
+    """
+    Waits for an attribute value to be present for an item in a DynamoDB table.
+
+    :param partition_key_name: DynamoDB partition key name
+    :param partition_key_value: DynamoDB partition key value
+    :param attribute_name: DynamoDB attribute name
+    :param attribute_value: DynamoDB attribute value
+    :param sort_key_name: (optional) DynamoDB sort key name
+    :param sort_key_value: (optional) DynamoDB sort key value
+    """
+
+    def __init__(
+        self,
+        table_name: str,
+        partition_key_name: str,
+        partition_key_value: str,
+        attribute_name: str,
+        attribute_value: str,
+        sort_key_name: str | None = None,
+        sort_key_value: str | None = None,
+        aws_conn_id: str | None = DynamoDBHook.default_conn_name,
+        region_name: str | None = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.table_name = table_name
+        self.partition_key_name = partition_key_name
+        self.partition_key_value = partition_key_value
+        self.attribute_name = attribute_name
+        self.attribute_value = attribute_value
+        self.sort_key_name = sort_key_name
+        self.sort_key_value = sort_key_value
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def poke(self, context: Context) -> bool:
+        """Test DynamoDB item for matching attribute value"""
+        key = {self.partition_key_name: self.partition_key_value}
+        msg = (
+            f"Checking table {self.table_name} for "
+            + f"item Partition Key: {self.partition_key_name}={self.partition_key_value}"
+        )
+
+        if self.sort_key_name and self.sort_key_value:
+            key = {self.partition_key_name: self.partition_key_value, self.sort_key_name: self.sort_key_value}
+            msg += f"\nSort Key: {self.sort_key_name}={self.sort_key_value}"
+
+        msg += f"\nattribute: {self.attribute_name}={self.attribute_value}"
+
+        self.log.info(msg)
+        table = self.hook.conn.Table(self.table_name)
+        self.log.info("Table: %s", table)
+        self.log.info("Key: %s", key)
+        response = table.get_item(Key=key)
+        try:
+            self.log.info("Response: %s", response)
+            self.log.info("Want: %s = %s", self.attribute_name, self.attribute_value)
+            self.log.info(
+                "Got: {response['Item'][self.attribute_name]} = %s", response["Item"][self.attribute_name]
+            )
+            return response["Item"][self.attribute_name] == self.attribute_value
+        except KeyError:
+            return False
+
+    @cached_property
+    def hook(self) -> DynamoDBHook:
+        """Create and return a DynamoDBHook"""
+        return DynamoDBHook(self.aws_conn_id, region_name=self.region_name)
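A hedged usage sketch for the new sensor; the table, keys, and values below are made-up placeholders. The task keeps poking until the item's attribute equals the expected value.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.dynamodb import DynamoDBValueSensor

with DAG(dag_id="example_dynamodb_value_sensor", start_date=datetime(2023, 1, 1)) as dag:
    wait_for_status = DynamoDBValueSensor(
        task_id="wait_for_status",
        table_name="orders",             # placeholder table
        partition_key_name="order_id",
        partition_key_value="12345",
        attribute_name="status",
        attribute_value="complete",
        sort_key_name=None,              # optional sort key unused here
        poke_interval=30,                # standard BaseSensorOperator knobs
        timeout=60 * 10,
    )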
airflow/providers/amazon/aws/sensors/emr.py

@@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.emr import EmrLogsLink
-from airflow.sensors.base import BaseSensorOperator
+from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -451,7 +451,6 @@ class EmrJobFlowSensor(EmrBaseSensor):
         return None
 
 
-@poke_mode_only
 class EmrStepSensor(EmrBaseSensor):
     """
     Asks for the state of the step until it reaches any of the target states.
airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py

@@ -87,7 +87,7 @@ class DynamoDBToS3Operator(AwsToAwsBaseOperator):
     :param dynamodb_scan_kwargs: kwargs pass to <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.scan>
     :param s3_key_prefix: Prefix of s3 object key
     :param process_func: How we transforms a dynamodb item to bytes. By default we dump the json
-    """
+    """
 
     template_fields: Sequence[str] = (
         *AwsToAwsBaseOperator.template_fields,
airflow/providers/amazon/aws/transfers/gcs_to_s3.py

@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import os
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
@@ -45,11 +44,6 @@ class GCSToS3Operator(BaseOperator):
         For e.g to lists the CSV files from in a directory in GCS you would use
         delimiter='.csv'.
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
-    :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
-        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
-    :param delegate_to: Google account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
     :param dest_aws_conn_id: The destination S3 connection
     :param dest_s3_key: The base S3 key to be used to store the files. (templated)
     :param dest_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -100,8 +94,6 @@ class GCSToS3Operator(BaseOperator):
         prefix: str | None = None,
         delimiter: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        google_cloud_storage_conn_id: str | None = None,
-        delegate_to: str | None = None,
         dest_aws_conn_id: str = "aws_default",
         dest_s3_key: str,
         dest_verify: str | bool | None = None,
@@ -114,20 +106,10 @@ class GCSToS3Operator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
 
-        if google_cloud_storage_conn_id:
-            warnings.warn(
-                "The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
-                "the gcp_conn_id parameter.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            gcp_conn_id = google_cloud_storage_conn_id
-
         self.bucket = bucket
         self.prefix = prefix
         self.delimiter = delimiter
         self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
         self.dest_aws_conn_id = dest_aws_conn_id
         self.dest_s3_key = dest_s3_key
         self.dest_verify = dest_verify
@@ -141,7 +123,6 @@ class GCSToS3Operator(BaseOperator):
         # list all files in an Google Cloud Storage bucket
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             impersonation_chain=self.google_impersonation_chain,
         )
 
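With delegate_to removed, a hedged migration sketch relying on google_impersonation_chain instead of domain-wide delegation; bucket, key, and service account below are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

with DAG(dag_id="example_gcs_to_s3", start_date=datetime(2023, 1, 1)) as dag:
    gcs_to_s3 = GCSToS3Operator(
        task_id="gcs_to_s3",
        bucket="example-gcs-bucket",                      # placeholder source bucket
        prefix="exports/",
        gcp_conn_id="google_cloud_default",
        dest_aws_conn_id="aws_default",
        dest_s3_key="s3://example-s3-bucket/exports/",    # placeholder destination
        # Replaces the removed delegate_to: short-lived credential impersonation.
        google_impersonation_chain="transfer-sa@example-project.iam.gserviceaccount.com",
    )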
airflow/providers/amazon/aws/transfers/glacier_to_gcs.py

@@ -47,10 +47,7 @@ class GlacierToGCSOperator(BaseOperator):
     :param object_name: the name of the object to check in the Google cloud
         storage bucket.
     :param gzip: option to compress local file or file data for upload
-    :param chunk_size: size of chunk in bytes the that will downloaded from Glacier vault
-    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
+    :param chunk_size: size of chunk in bytes the that will be downloaded from Glacier vault
     :param google_impersonation_chain: Optional Google service account to impersonate using
         short-term credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -73,7 +70,6 @@ class GlacierToGCSOperator(BaseOperator):
         object_name: str,
         gzip: bool,
         chunk_size: int = 1024,
-        delegate_to: str | None = None,
         google_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
@@ -85,14 +81,12 @@ class GlacierToGCSOperator(BaseOperator):
         self.object_name = object_name
         self.gzip = gzip
         self.chunk_size = chunk_size
-        self.delegate_to = delegate_to
         self.impersonation_chain = google_impersonation_chain
 
     def execute(self, context: Context) -> str:
         glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
         gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
         job_id = glacier_hook.retrieve_inventory(vault_name=self.vault_name)
airflow/providers/amazon/aws/transfers/google_api_to_s3.py

@@ -22,12 +22,14 @@ import json
 import sys
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator
 from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook
 
 if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.context import Context
 
 
@@ -71,13 +73,10 @@ class GoogleApiToS3Operator(BaseOperator):
 
     .. note:: This means the response will be a list of responses.
 
-    :param google_api_num_retries: Define the number of retries for the
+    :param google_api_num_retries: Define the number of retries for the Google API requests being made
         if it fails.
    :param s3_overwrite: Specifies whether the s3 file will be overwritten if exists.
     :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :param delegate_to: Google account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
     :param aws_conn_id: The connection id specifying the authentication information for the S3 Bucket.
     :param google_impersonation_chain: Optional Google service account to impersonate using
         short-term credentials, or chained list of accounts required to get the access_token
@@ -113,7 +112,6 @@ class GoogleApiToS3Operator(BaseOperator):
         google_api_num_retries: int = 0,
         s3_overwrite: bool = False,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         aws_conn_id: str = "aws_default",
         google_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
@@ -131,7 +129,6 @@ class GoogleApiToS3Operator(BaseOperator):
         self.google_api_num_retries = google_api_num_retries
         self.s3_overwrite = s3_overwrite
         self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
         self.aws_conn_id = aws_conn_id
         self.google_impersonation_chain = google_impersonation_chain
 
@@ -156,7 +153,6 @@ class GoogleApiToS3Operator(BaseOperator):
     def _retrieve_data_from_google_api(self) -> dict:
         google_discovery_api_hook = GoogleDiscoveryApiHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             api_service_name=self.google_api_service_name,
             api_version=self.google_api_service_version,
             impersonation_chain=self.google_impersonation_chain,
@@ -177,7 +173,9 @@ class GoogleApiToS3Operator(BaseOperator):
             replace=self.s3_overwrite,
         )
 
-    def _update_google_api_endpoint_params_via_xcom(
+    def _update_google_api_endpoint_params_via_xcom(
+        self, task_instance: TaskInstance | TaskInstancePydantic
+    ) -> None:
 
         if self.google_api_endpoint_params_via_xcom:
             google_api_endpoint_params = task_instance.xcom_pull(
@@ -186,7 +184,9 @@ class GoogleApiToS3Operator(BaseOperator):
             )
             self.google_api_endpoint_params.update(google_api_endpoint_params)
 
-    def _expose_google_api_response_via_xcom(
+    def _expose_google_api_response_via_xcom(
+        self, task_instance: TaskInstance | TaskInstancePydantic, data: dict
+    ) -> None:
         if sys.getsizeof(data) < MAX_XCOM_SIZE:
             task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
         else:
airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py

@@ -18,7 +18,6 @@
 """This module allows you to transfer mail attachments from a mail server into s3 bucket."""
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
@@ -28,10 +27,6 @@ from airflow.providers.imap.hooks.imap import ImapHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
 
 class ImapAttachmentToS3Operator(BaseOperator):
     """
@@ -66,15 +61,10 @@ class ImapAttachmentToS3Operator(BaseOperator):
         imap_mail_filter: str = "All",
         s3_overwrite: bool = False,
         imap_conn_id: str = "imap_default",
-        s3_conn_id: str | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.imap_attachment_name = imap_attachment_name
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
airflow/providers/amazon/aws/transfers/mongo_to_s3.py

@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import json
-import warnings
 from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast
 
 from bson import json_util
@@ -31,11 +30,6 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
-
 class MongoToS3Operator(BaseOperator):
     """Operator meant to move data from mongo via pymongo to s3 via boto.
 
@@ -66,7 +60,6 @@ class MongoToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        s3_conn_id: str | None = None,
         mongo_conn_id: str = "mongo_default",
         aws_conn_id: str = "aws_default",
         mongo_collection: str,
@@ -81,10 +74,6 @@ class MongoToS3Operator(BaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.mongo_conn_id = mongo_conn_id
         self.aws_conn_id = aws_conn_id
         self.mongo_db = mongo_db
airflow/providers/amazon/aws/transfers/s3_to_sftp.py

@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
 from urllib.parse import urlsplit
@@ -29,10 +28,6 @@ from airflow.providers.ssh.hooks.ssh import SSHHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
 
 class S3ToSFTPOperator(BaseOperator):
     """
@@ -62,15 +57,10 @@ class S3ToSFTPOperator(BaseOperator):
         s3_key: str,
         sftp_path: str,
         sftp_conn_id: str = "ssh_default",
-        s3_conn_id: str | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.sftp_conn_id = sftp_conn_id
         self.sftp_path = sftp_path
         self.s3_bucket = s3_bucket
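The same s3_conn_id removal applies across the IMAP, Mongo, and SFTP transfers above. A hedged sketch of the post-8.0.0 call passing aws_conn_id directly; the bucket, key, and paths are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator

with DAG(dag_id="example_s3_to_sftp", start_date=datetime(2023, 1, 1)) as dag:
    s3_to_sftp = S3ToSFTPOperator(
        task_id="s3_to_sftp",
        s3_bucket="example-bucket",      # placeholder bucket
        s3_key="exports/data.csv",
        sftp_path="/upload/data.csv",
        sftp_conn_id="ssh_default",
        aws_conn_id="aws_default",       # formerly also accepted via the removed s3_conn_id
    )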
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -80,6 +80,7 @@ class SqlToS3Operator(BaseOperator):
         CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
+    :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
 
     template_fields: Sequence[str] = (
@@ -107,6 +108,7 @@ class SqlToS3Operator(BaseOperator):
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
         pd_kwargs: dict | None = None,
+        groupby_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -119,6 +121,7 @@ class SqlToS3Operator(BaseOperator):
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.groupby_kwargs = groupby_kwargs or {}
 
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -170,15 +173,26 @@ class SqlToS3Operator(BaseOperator):
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
 
-
-
-
-
-
-
-
-
-
+        for group_name, df in self._partition_dataframe(df=data_df):
+            with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
+
+                self.log.info("Writing data to temp file")
+                getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)
+
+                self.log.info("Uploading data to S3")
+                object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+                s3_conn.load_file(
+                    filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+                )
+
+    def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
+        """Partition dataframe using pandas groupby() method"""
+        if not self.groupby_kwargs:
+            yield "", df
+        else:
+            grouped_df = df.groupby(**self.groupby_kwargs)
+            for group_label in grouped_df.groups.keys():
+                yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)