apache-airflow-providers-amazon 7.4.1rc1__py3-none-any.whl → 8.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. airflow/providers/amazon/aws/hooks/athena.py +0 -15
  2. airflow/providers/amazon/aws/hooks/base_aws.py +98 -65
  3. airflow/providers/amazon/aws/hooks/batch_client.py +60 -27
  4. airflow/providers/amazon/aws/hooks/batch_waiters.py +3 -1
  5. airflow/providers/amazon/aws/hooks/emr.py +33 -74
  6. airflow/providers/amazon/aws/hooks/logs.py +22 -4
  7. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -12
  8. airflow/providers/amazon/aws/hooks/sagemaker.py +0 -16
  9. airflow/providers/amazon/aws/links/emr.py +1 -3
  10. airflow/providers/amazon/aws/operators/athena.py +0 -15
  11. airflow/providers/amazon/aws/operators/batch.py +78 -24
  12. airflow/providers/amazon/aws/operators/ecs.py +21 -58
  13. airflow/providers/amazon/aws/operators/eks.py +0 -1
  14. airflow/providers/amazon/aws/operators/emr.py +94 -24
  15. airflow/providers/amazon/aws/operators/lambda_function.py +0 -19
  16. airflow/providers/amazon/aws/operators/rds.py +1 -1
  17. airflow/providers/amazon/aws/operators/redshift_cluster.py +22 -1
  18. airflow/providers/amazon/aws/operators/redshift_data.py +0 -62
  19. airflow/providers/amazon/aws/secrets/secrets_manager.py +0 -17
  20. airflow/providers/amazon/aws/secrets/systems_manager.py +0 -21
  21. airflow/providers/amazon/aws/sensors/dynamodb.py +97 -0
  22. airflow/providers/amazon/aws/sensors/emr.py +1 -2
  23. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +1 -1
  24. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +0 -19
  25. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -7
  26. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +10 -10
  27. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +0 -10
  28. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +0 -11
  29. airflow/providers/amazon/aws/transfers/s3_to_sftp.py +0 -10
  30. airflow/providers/amazon/aws/transfers/sql_to_s3.py +23 -9
  31. airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
  32. airflow/providers/amazon/aws/waiters/base_waiter.py +12 -1
  33. airflow/providers/amazon/aws/waiters/emr-serverless.json +18 -0
  34. airflow/providers/amazon/get_provider_info.py +35 -30
  35. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/METADATA +81 -4
  36. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/RECORD +41 -41
  37. airflow/providers/amazon/aws/operators/aws_lambda.py +0 -29
  38. airflow/providers/amazon/aws/operators/redshift_sql.py +0 -57
  39. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/LICENSE +0 -0
  40. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/NOTICE +0 -0
  41. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/WHEEL +0 -0
  42. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/entry_points.txt +0 -0
  43. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/top_level.txt +0 -0

airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -22,7 +22,10 @@ from typing import TYPE_CHECKING, Any, Sequence
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
-from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftClusterTrigger,
+    RedshiftCreateClusterTrigger,
+)
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
     :param wait_for_completion: Whether wait for the cluster to be in ``available`` state
     :param max_attempt: The maximum number of attempts to be made. Default: 5
     :param poll_interval: The amount of time in seconds to wait between attempts. Default: 60
+    :param deferrable: If True, the operator will run in deferrable mode
     """
 
     template_fields: Sequence[str] = (
@@ -140,6 +144,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
         wait_for_completion: bool = False,
         max_attempt: int = 5,
         poll_interval: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -180,6 +185,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.wait_for_completion = wait_for_completion
         self.max_attempt = max_attempt
         self.poll_interval = poll_interval
+        self.deferrable = deferrable
         self.kwargs = kwargs
 
     def execute(self, context: Context):
@@ -252,6 +258,16 @@ class RedshiftCreateClusterOperator(BaseOperator):
             self.master_user_password,
             params,
         )
+        if self.deferrable:
+            self.defer(
+                trigger=RedshiftCreateClusterTrigger(
+                    cluster_identifier=self.cluster_identifier,
+                    poll_interval=self.poll_interval,
+                    max_attempt=self.max_attempt,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
         if self.wait_for_completion:
             redshift_hook.get_conn().get_waiter("cluster_available").wait(
                 ClusterIdentifier=self.cluster_identifier,
@@ -264,6 +280,11 @@ class RedshiftCreateClusterOperator(BaseOperator):
         self.log.info("Created Redshift cluster %s", self.cluster_identifier)
         self.log.info(cluster)
 
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error creating cluster: {event}")
+        return
+
 
 class RedshiftCreateClusterSnapshotOperator(BaseOperator):
     """

airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING
 
 from airflow.compat.functools import cached_property
@@ -83,7 +82,6 @@ class RedshiftDataOperator(BaseOperator):
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
-        await_result: bool | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -95,16 +93,7 @@ class RedshiftDataOperator(BaseOperator):
         self.secret_arn = secret_arn
         self.statement_name = statement_name
         self.with_event = with_event
-        self.await_result = await_result
         self.wait_for_completion = wait_for_completion
-        if await_result:
-            warnings.warn(
-                f"Parameter `{self.__class__.__name__}.await_result` is deprecated and will be removed "
-                "in a future release. Please use method `wait_for_completion` instead.",
-                DeprecationWarning,
-                stacklevel=2,
-            )
-            self.wait_for_completion = await_result
         if poll_interval > 0:
             self.poll_interval = poll_interval
         else:
@@ -122,57 +111,6 @@ class RedshiftDataOperator(BaseOperator):
         """Create and return an RedshiftDataHook."""
         return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
 
-    def execute_query(self) -> str:
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.statement_id = self.hook.execute_query(
-            database=self.database,
-            sql=self.sql,
-            cluster_identifier=self.cluster_identifier,
-            db_user=self.db_user,
-            parameters=self.parameters,
-            secret_arn=self.secret_arn,
-            statement_name=self.statement_name,
-            with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
-            poll_interval=self.poll_interval,
-        )
-        return self.statement_id
-
-    def execute_batch_query(self) -> str:
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        self.statement_id = self.hook.execute_query(
-            database=self.database,
-            sql=self.sql,
-            cluster_identifier=self.cluster_identifier,
-            db_user=self.db_user,
-            parameters=self.parameters,
-            secret_arn=self.secret_arn,
-            statement_name=self.statement_name,
-            with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
-            poll_interval=self.poll_interval,
-        )
-        return self.statement_id
-
-    def wait_for_results(self, statement_id: str):
-        warnings.warn(
-            "This method is deprecated and has been moved to the hook "
-            "`airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook`.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.hook.wait_for_results(statement_id=statement_id, poll_interval=self.poll_interval)
-
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift"""
         self.log.info("Executing statement: %s", self.sql)

airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -267,23 +267,6 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
         else:
             return secret
 
-    def get_conn_uri(self, conn_id: str) -> str | None:
-        """
-        Return URI representation of Connection conn_id.
-
-        As of Airflow version 2.3.0 this method is deprecated.
-
-        :param conn_id: the connection id
-        :return: deserialized Connection
-        """
-        warnings.warn(
-            f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
-            "in a future release. Please use method `get_conn_value` instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        return self.get_conn_value(conn_id)
-
     def get_variable(self, key: str) -> str | None:
         """
         Get Airflow Variable

airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import re
-import warnings
 
 from airflow.compat.functools import cached_property
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -143,26 +142,6 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
 
         return self._get_secret(self.connections_prefix, conn_id, self.connections_lookup_pattern)
 
-    def get_conn_uri(self, conn_id: str) -> str | None:
-        """
-        Return URI representation of Connection conn_id.
-
-        As of Airflow version 2.3.0 this method is deprecated.
-
-        :param conn_id: the connection id
-        """
-        warnings.warn(
-            f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
-            "in a future release. Please use method `get_conn_value` instead.",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        value = self.get_conn_value(conn_id)
-        if value is None:
-            return None
-
-        return self.deserialize_connection(conn_id, value).get_uri()
-
     def get_variable(self, key: str) -> str | None:
         """
         Get Airflow Variable
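
Both AWS secrets backends now expose only ``get_conn_value``; code that called the removed ``get_conn_uri`` can rebuild the URI with the same helpers the old method used. A rough sketch, assuming the backend is instantiated directly and the prefix and connection id are illustrative:

from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend

backend = SystemsManagerParameterStoreBackend(connections_prefix="/airflow/connections")
value = backend.get_conn_value("my_conn")   # raw secret value (URI or JSON), or None
if value is not None:
    # deserialize_connection + get_uri is what the removed get_conn_uri did internally
    uri = backend.deserialize_connection("my_conn", value).get_uri()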

airflow/providers/amazon/aws/sensors/dynamodb.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class DynamoDBValueSensor(BaseSensorOperator):
+    """
+    Waits for an attribute value to be present for an item in a DynamoDB table.
+
+    :param partition_key_name: DynamoDB partition key name
+    :param partition_key_value: DynamoDB partition key value
+    :param attribute_name: DynamoDB attribute name
+    :param attribute_value: DynamoDB attribute value
+    :param sort_key_name: (optional) DynamoDB sort key name
+    :param sort_key_value: (optional) DynamoDB sort key value
+    """
+
+    def __init__(
+        self,
+        table_name: str,
+        partition_key_name: str,
+        partition_key_value: str,
+        attribute_name: str,
+        attribute_value: str,
+        sort_key_name: str | None = None,
+        sort_key_value: str | None = None,
+        aws_conn_id: str | None = DynamoDBHook.default_conn_name,
+        region_name: str | None = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.table_name = table_name
+        self.partition_key_name = partition_key_name
+        self.partition_key_value = partition_key_value
+        self.attribute_name = attribute_name
+        self.attribute_value = attribute_value
+        self.sort_key_name = sort_key_name
+        self.sort_key_value = sort_key_value
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+
+    def poke(self, context: Context) -> bool:
+        """Test DynamoDB item for matching attribute value"""
+        key = {self.partition_key_name: self.partition_key_value}
+        msg = (
+            f"Checking table {self.table_name} for "
+            + f"item Partition Key: {self.partition_key_name}={self.partition_key_value}"
+        )
+
+        if self.sort_key_name and self.sort_key_value:
+            key = {self.partition_key_name: self.partition_key_value, self.sort_key_name: self.sort_key_value}
+            msg += f"\nSort Key: {self.sort_key_name}={self.sort_key_value}"
+
+        msg += f"\nattribute: {self.attribute_name}={self.attribute_value}"
+
+        self.log.info(msg)
+        table = self.hook.conn.Table(self.table_name)
+        self.log.info("Table: %s", table)
+        self.log.info("Key: %s", key)
+        response = table.get_item(Key=key)
+        try:
+            self.log.info("Response: %s", response)
+            self.log.info("Want: %s = %s", self.attribute_name, self.attribute_value)
+            self.log.info(
+                "Got: {response['Item'][self.attribute_name]} = %s", response["Item"][self.attribute_name]
+            )
+            return response["Item"][self.attribute_name] == self.attribute_value
+        except KeyError:
+            return False
+
+    @cached_property
+    def hook(self) -> DynamoDBHook:
+        """Create and return a DynamoDBHook"""
+        return DynamoDBHook(self.aws_conn_id, region_name=self.region_name)
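
The new ``DynamoDBValueSensor`` drops into a DAG like any other sensor. A minimal sketch inside a DAG definition (the table, key, and attribute values are illustrative):

from airflow.providers.amazon.aws.sensors.dynamodb import DynamoDBValueSensor

wait_for_flag = DynamoDBValueSensor(
    task_id="wait_for_flag",
    table_name="example-table",       # hypothetical table
    partition_key_name="pk",
    partition_key_value="job-123",
    attribute_name="status",
    attribute_value="DONE",
    poke_interval=30,                 # standard BaseSensorOperator kwargs still apply
    timeout=60 * 60,
)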

airflow/providers/amazon/aws/sensors/emr.py
@@ -25,7 +25,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.emr import EmrLogsLink
-from airflow.sensors.base import BaseSensorOperator, poke_mode_only
+from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -451,7 +451,6 @@ class EmrJobFlowSensor(EmrBaseSensor):
         return None
 
 
-@poke_mode_only
 class EmrStepSensor(EmrBaseSensor):
     """
     Asks for the state of the step until it reaches any of the target states.

airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -87,7 +87,7 @@ class DynamoDBToS3Operator(AwsToAwsBaseOperator):
     :param dynamodb_scan_kwargs: kwargs pass to <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.scan>
     :param s3_key_prefix: Prefix of s3 object key
     :param process_func: How we transforms a dynamodb item to bytes. By default we dump the json
-    """  # noqa: E501
+    """
 
     template_fields: Sequence[str] = (
         *AwsToAwsBaseOperator.template_fields,

airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -19,7 +19,6 @@
 from __future__ import annotations
 
 import os
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
@@ -45,11 +44,6 @@ class GCSToS3Operator(BaseOperator):
         For e.g to lists the CSV files from in a directory in GCS you would use
         delimiter='.csv'.
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
-    :param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
-        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
-    :param delegate_to: Google account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
     :param dest_aws_conn_id: The destination S3 connection
     :param dest_s3_key: The base S3 key to be used to store the files. (templated)
     :param dest_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -100,8 +94,6 @@ class GCSToS3Operator(BaseOperator):
         prefix: str | None = None,
         delimiter: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
-        google_cloud_storage_conn_id: str | None = None,
-        delegate_to: str | None = None,
         dest_aws_conn_id: str = "aws_default",
         dest_s3_key: str,
         dest_verify: str | bool | None = None,
@@ -114,20 +106,10 @@ class GCSToS3Operator(BaseOperator):
     ) -> None:
         super().__init__(**kwargs)
 
-        if google_cloud_storage_conn_id:
-            warnings.warn(
-                "The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
-                "the gcp_conn_id parameter.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            gcp_conn_id = google_cloud_storage_conn_id
-
         self.bucket = bucket
         self.prefix = prefix
         self.delimiter = delimiter
         self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
         self.dest_aws_conn_id = dest_aws_conn_id
         self.dest_s3_key = dest_s3_key
         self.dest_verify = dest_verify
@@ -141,7 +123,6 @@ class GCSToS3Operator(BaseOperator):
         # list all files in an Google Cloud Storage bucket
         hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             impersonation_chain=self.google_impersonation_chain,
         )
 
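
Because ``google_cloud_storage_conn_id`` and ``delegate_to`` are removed outright rather than merely deprecated, existing DAGs must switch to ``gcp_conn_id`` (and, where impersonation is still needed, ``google_impersonation_chain``, which is a different mechanism than domain-wide delegation). A rough migration sketch with illustrative bucket and connection ids:

from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

# before (7.x): GCSToS3Operator(..., google_cloud_storage_conn_id="google_cloud_default",
#                               delegate_to="sa@example.com")
gcs_to_s3 = GCSToS3Operator(
    task_id="gcs_to_s3",
    bucket="example-gcs-bucket",                       # hypothetical bucket
    dest_s3_key="s3://example-s3-bucket/data/",
    gcp_conn_id="google_cloud_default",                # replaces google_cloud_storage_conn_id
    google_impersonation_chain="sa@example.com",       # optional; short-term credential impersonation
    dest_aws_conn_id="aws_default",
)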

airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
@@ -47,10 +47,7 @@ class GlacierToGCSOperator(BaseOperator):
     :param object_name: the name of the object to check in the Google cloud
         storage bucket.
     :param gzip: option to compress local file or file data for upload
-    :param chunk_size: size of chunk in bytes the that will downloaded from Glacier vault
-    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
+    :param chunk_size: size of chunk in bytes the that will be downloaded from Glacier vault
     :param google_impersonation_chain: Optional Google service account to impersonate using
         short-term credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -73,7 +70,6 @@ class GlacierToGCSOperator(BaseOperator):
         object_name: str,
         gzip: bool,
         chunk_size: int = 1024,
-        delegate_to: str | None = None,
         google_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
     ) -> None:
@@ -85,14 +81,12 @@ class GlacierToGCSOperator(BaseOperator):
         self.object_name = object_name
         self.gzip = gzip
         self.chunk_size = chunk_size
-        self.delegate_to = delegate_to
         self.impersonation_chain = google_impersonation_chain
 
     def execute(self, context: Context) -> str:
         glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
         gcs_hook = GCSHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             impersonation_chain=self.impersonation_chain,
         )
         job_id = glacier_hook.retrieve_inventory(vault_name=self.vault_name)

airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -22,12 +22,14 @@ import json
 import sys
 from typing import TYPE_CHECKING, Sequence
 
-from airflow.models import BaseOperator, TaskInstance
+from airflow.models import BaseOperator
 from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook
 
 if TYPE_CHECKING:
+    from airflow.models import TaskInstance
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.context import Context
 
 
@@ -71,13 +73,10 @@ class GoogleApiToS3Operator(BaseOperator):
 
         .. note:: This means the response will be a list of responses.
 
-    :param google_api_num_retries: Define the number of retries for the google api requests being made
+    :param google_api_num_retries: Define the number of retries for the Google API requests being made
         if it fails.
     :param s3_overwrite: Specifies whether the s3 file will be overwritten if exists.
     :param gcp_conn_id: The connection ID to use when fetching connection info.
-    :param delegate_to: Google account to impersonate using domain-wide delegation of authority,
-        if any. For this to work, the service account making the request must have
-        domain-wide delegation enabled.
     :param aws_conn_id: The connection id specifying the authentication information for the S3 Bucket.
     :param google_impersonation_chain: Optional Google service account to impersonate using
         short-term credentials, or chained list of accounts required to get the access_token
@@ -113,7 +112,6 @@ class GoogleApiToS3Operator(BaseOperator):
         google_api_num_retries: int = 0,
         s3_overwrite: bool = False,
         gcp_conn_id: str = "google_cloud_default",
-        delegate_to: str | None = None,
         aws_conn_id: str = "aws_default",
         google_impersonation_chain: str | Sequence[str] | None = None,
         **kwargs,
@@ -131,7 +129,6 @@ class GoogleApiToS3Operator(BaseOperator):
         self.google_api_num_retries = google_api_num_retries
         self.s3_overwrite = s3_overwrite
         self.gcp_conn_id = gcp_conn_id
-        self.delegate_to = delegate_to
         self.aws_conn_id = aws_conn_id
         self.google_impersonation_chain = google_impersonation_chain
 
@@ -156,7 +153,6 @@ class GoogleApiToS3Operator(BaseOperator):
     def _retrieve_data_from_google_api(self) -> dict:
         google_discovery_api_hook = GoogleDiscoveryApiHook(
             gcp_conn_id=self.gcp_conn_id,
-            delegate_to=self.delegate_to,
             api_service_name=self.google_api_service_name,
             api_version=self.google_api_service_version,
             impersonation_chain=self.google_impersonation_chain,
@@ -177,7 +173,9 @@ class GoogleApiToS3Operator(BaseOperator):
             replace=self.s3_overwrite,
         )
 
-    def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
+    def _update_google_api_endpoint_params_via_xcom(
+        self, task_instance: TaskInstance | TaskInstancePydantic
+    ) -> None:
 
         if self.google_api_endpoint_params_via_xcom:
             google_api_endpoint_params = task_instance.xcom_pull(
@@ -186,7 +184,9 @@ class GoogleApiToS3Operator(BaseOperator):
             )
             self.google_api_endpoint_params.update(google_api_endpoint_params)
 
-    def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
+    def _expose_google_api_response_via_xcom(
+        self, task_instance: TaskInstance | TaskInstancePydantic, data: dict
+    ) -> None:
         if sys.getsizeof(data) < MAX_XCOM_SIZE:
             task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
         else:

airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
@@ -18,7 +18,6 @@
 """This module allows you to transfer mail attachments from a mail server into s3 bucket."""
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
 from airflow.models import BaseOperator
@@ -28,10 +27,6 @@ from airflow.providers.imap.hooks.imap import ImapHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
 
 class ImapAttachmentToS3Operator(BaseOperator):
     """
@@ -66,15 +61,10 @@ class ImapAttachmentToS3Operator(BaseOperator):
         imap_mail_filter: str = "All",
         s3_overwrite: bool = False,
         imap_conn_id: str = "imap_default",
-        s3_conn_id: str | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.imap_attachment_name = imap_attachment_name
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key

airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 import json
-import warnings
 from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast
 
 from bson import json_util
@@ -31,11 +30,6 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
-
 class MongoToS3Operator(BaseOperator):
     """Operator meant to move data from mongo via pymongo to s3 via boto.
 
@@ -66,7 +60,6 @@ class MongoToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        s3_conn_id: str | None = None,
         mongo_conn_id: str = "mongo_default",
         aws_conn_id: str = "aws_default",
         mongo_collection: str,
@@ -81,10 +74,6 @@ class MongoToS3Operator(BaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.mongo_conn_id = mongo_conn_id
         self.aws_conn_id = aws_conn_id
         self.mongo_db = mongo_db

airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Sequence
 from urllib.parse import urlsplit
@@ -29,10 +28,6 @@ from airflow.providers.ssh.hooks.ssh import SSHHook
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-_DEPRECATION_MSG = (
-    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
-)
-
 
 class S3ToSFTPOperator(BaseOperator):
     """
@@ -62,15 +57,10 @@ class S3ToSFTPOperator(BaseOperator):
         s3_key: str,
         sftp_path: str,
         sftp_conn_id: str = "ssh_default",
-        s3_conn_id: str | None = None,
         aws_conn_id: str = "aws_default",
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if s3_conn_id:
-            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            aws_conn_id = s3_conn_id
-
         self.sftp_conn_id = sftp_conn_id
         self.sftp_path = sftp_path
         self.s3_bucket = s3_bucket
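
The same cleanup applies to the IMAP, Mongo, and SFTP transfers above: the long-deprecated ``s3_conn_id`` argument is gone, so remaining callers should pass ``aws_conn_id``. A minimal sketch with illustrative bucket, key, and path values:

from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator

s3_to_sftp = S3ToSFTPOperator(
    task_id="s3_to_sftp",
    s3_bucket="example-bucket",
    s3_key="data/file.csv",
    sftp_path="/upload/file.csv",
    aws_conn_id="aws_default",      # was previously accepted as s3_conn_id
    sftp_conn_id="ssh_default",
)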

airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -80,6 +80,7 @@ class SqlToS3Operator(BaseOperator):
         CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
+    :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
 
     template_fields: Sequence[str] = (
@@ -107,6 +108,7 @@ class SqlToS3Operator(BaseOperator):
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
         pd_kwargs: dict | None = None,
+        groupby_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -119,6 +121,7 @@ class SqlToS3Operator(BaseOperator):
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.groupby_kwargs = groupby_kwargs or {}
 
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")
@@ -170,15 +173,26 @@ class SqlToS3Operator(BaseOperator):
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
 
-        with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
-
-            self.log.info("Writing data to temp file")
-            getattr(data_df, file_options.function)(tmp_file.name, **self.pd_kwargs)
-
-            self.log.info("Uploading data to S3")
-            s3_conn.load_file(
-                filename=tmp_file.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=self.replace
-            )
+        for group_name, df in self._partition_dataframe(df=data_df):
+            with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
+
+                self.log.info("Writing data to temp file")
+                getattr(df, file_options.function)(tmp_file.name, **self.pd_kwargs)
+
+                self.log.info("Uploading data to S3")
+                object_key = f"{self.s3_key}_{group_name}" if group_name else self.s3_key
+                s3_conn.load_file(
+                    filename=tmp_file.name, key=object_key, bucket_name=self.s3_bucket, replace=self.replace
+                )
+
+    def _partition_dataframe(self, df: DataFrame) -> Iterable[tuple[str, DataFrame]]:
+        """Partition dataframe using pandas groupby() method"""
+        if not self.groupby_kwargs:
+            yield "", df
+        else:
+            grouped_df = df.groupby(**self.groupby_kwargs)
+            for group_label in grouped_df.groups.keys():
+                yield group_label, grouped_df.get_group(group_label).reset_index(drop=True)
 
     def _get_hook(self) -> DbApiHook:
         self.log.debug("Get connection for %s", self.sql_conn_id)
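
As the new ``_partition_dataframe`` helper shows, passing ``groupby_kwargs`` makes the operator upload one object per pandas group, suffixing the group label onto ``s3_key``. A minimal usage sketch (connection ids, query, and column name are illustrative):

from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

sql_to_s3 = SqlToS3Operator(
    task_id="sql_to_s3_partitioned",
    sql_conn_id="postgres_default",
    query="SELECT customer_id, region, amount FROM sales",   # hypothetical query
    s3_bucket="example-bucket",
    s3_key="exports/sales",             # objects land as exports/sales_<region value>
    file_format="csv",
    groupby_kwargs={"by": "region"},    # forwarded verbatim to DataFrame.groupby()
    replace=True,
)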