apache-airflow-providers-amazon 8.29.0rc1__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/{datasets → assets}/s3.py +10 -6
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +5 -11
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +2 -5
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +0 -6
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/hooks/athena.py +3 -17
- airflow/providers/amazon/aws/hooks/base_aws.py +4 -162
- airflow/providers/amazon/aws/hooks/logs.py +1 -20
- airflow/providers/amazon/aws/hooks/quicksight.py +1 -17
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +6 -120
- airflow/providers/amazon/aws/hooks/redshift_data.py +52 -14
- airflow/providers/amazon/aws/hooks/s3.py +24 -27
- airflow/providers/amazon/aws/hooks/sagemaker.py +4 -48
- airflow/providers/amazon/aws/log/s3_task_handler.py +1 -6
- airflow/providers/amazon/aws/operators/appflow.py +1 -10
- airflow/providers/amazon/aws/operators/batch.py +1 -29
- airflow/providers/amazon/aws/operators/datasync.py +1 -8
- airflow/providers/amazon/aws/operators/ecs.py +1 -25
- airflow/providers/amazon/aws/operators/eks.py +7 -46
- airflow/providers/amazon/aws/operators/emr.py +16 -232
- airflow/providers/amazon/aws/operators/glue_databrew.py +1 -10
- airflow/providers/amazon/aws/operators/rds.py +3 -17
- airflow/providers/amazon/aws/operators/redshift_data.py +18 -3
- airflow/providers/amazon/aws/operators/s3.py +12 -2
- airflow/providers/amazon/aws/operators/sagemaker.py +10 -32
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -40
- airflow/providers/amazon/aws/sensors/batch.py +1 -8
- airflow/providers/amazon/aws/sensors/dms.py +1 -8
- airflow/providers/amazon/aws/sensors/dynamodb.py +22 -8
- airflow/providers/amazon/aws/sensors/emr.py +0 -7
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +1 -8
- airflow/providers/amazon/aws/sensors/glue_crawler.py +1 -8
- airflow/providers/amazon/aws/sensors/quicksight.py +1 -29
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +1 -8
- airflow/providers/amazon/aws/sensors/s3.py +1 -8
- airflow/providers/amazon/aws/sensors/sagemaker.py +2 -9
- airflow/providers/amazon/aws/sensors/sqs.py +1 -8
- airflow/providers/amazon/aws/sensors/step_function.py +1 -8
- airflow/providers/amazon/aws/transfers/base.py +1 -14
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +5 -33
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +15 -10
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +6 -6
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +3 -6
- airflow/providers/amazon/aws/triggers/batch.py +1 -168
- airflow/providers/amazon/aws/triggers/eks.py +1 -20
- airflow/providers/amazon/aws/triggers/emr.py +0 -32
- airflow/providers/amazon/aws/triggers/glue_crawler.py +0 -11
- airflow/providers/amazon/aws/triggers/glue_databrew.py +0 -21
- airflow/providers/amazon/aws/triggers/rds.py +0 -79
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +5 -64
- airflow/providers/amazon/aws/triggers/sagemaker.py +2 -93
- airflow/providers/amazon/aws/utils/asset_compat_lineage_collector.py +106 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +4 -164
- airflow/providers/amazon/aws/utils/mixins.py +1 -23
- airflow/providers/amazon/aws/utils/openlineage.py +3 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +1 -1
- airflow/providers/amazon/get_provider_info.py +13 -4
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/METADATA +12 -13
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/RECORD +64 -64
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +0 -149
- /airflow/providers/amazon/aws/{datasets → assets}/__init__.py +0 -0
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.29.0rc1.dist-info → apache_airflow_providers_amazon-9.0.0.dist-info}/entry_points.txt +0 -0
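Note on the ``datasets`` → ``assets`` module rename visible in the file list above: DAG code importing from the old path needs updating. A minimal compatibility sketch; the imported symbol name is an assumption, since this diff does not show the module's public API:

    # Symbol names are illustrative only; check the module for the actual public API.
    try:
        from airflow.providers.amazon.aws.assets.s3 import create_asset  # provider >= 9.0.0
    except ImportError:
        from airflow.providers.amazon.aws.datasets.s3 import create_dataset as create_asset  # provider 8.x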
airflow/providers/amazon/aws/transfers/base.py

@@ -19,18 +19,12 @@
 
 from __future__ import annotations
 
-import warnings
 from typing import Sequence
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.utils.types import NOTSET, ArgNotSet
 
-_DEPRECATION_MSG = (
-    "The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
-)
-
 
 class AwsToAwsBaseOperator(BaseOperator):
     """
@@ -43,8 +37,6 @@ class AwsToAwsBaseOperator(BaseOperator):
        would be used (and must be maintained on each worker node).
     :param dest_aws_conn_id: The Airflow connection used for AWS credentials
        to access S3. If this is not set then the source_aws_conn_id connection is used.
-    :param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id).
-
     """
 
     template_fields: Sequence[str] = (
@@ -57,17 +49,12 @@ class AwsToAwsBaseOperator(BaseOperator):
         *,
         source_aws_conn_id: str | None = AwsBaseHook.default_conn_name,
         dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
-        aws_conn_id: str | None | ArgNotSet = NOTSET,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.source_aws_conn_id = source_aws_conn_id
         self.dest_aws_conn_id = dest_aws_conn_id
-        if not isinstance(aws_conn_id, ArgNotSet):
-            warnings.warn(_DEPRECATION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
-            self.source_aws_conn_id = aws_conn_id
-        else:
-            self.source_aws_conn_id = source_aws_conn_id
+        self.source_aws_conn_id = source_aws_conn_id
         if isinstance(dest_aws_conn_id, ArgNotSet):
             self.dest_aws_conn_id = self.source_aws_conn_id
         else:
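The ``aws_conn_id`` parameter was removed from ``AwsToAwsBaseOperator``, so subclasses such as ``DynamoDBToS3Operator`` now accept only ``source_aws_conn_id``/``dest_aws_conn_id``. A migration sketch with illustrative values:

    from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

    backup = DynamoDBToS3Operator(
        task_id="dynamodb_to_s3",
        dynamodb_table_name="my_table",     # hypothetical table
        s3_bucket_name="my-backup-bucket",  # hypothetical bucket
        source_aws_conn_id="aws_source",    # was aws_conn_id="aws_source" in 8.x
        dest_aws_conn_id="aws_dest",        # optional; defaults to source_aws_conn_id
    )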
airflow/providers/amazon/aws/transfers/gcs_to_s3.py

@@ -20,12 +20,11 @@
 from __future__ import annotations
 
 import os
-import warnings
 from typing import TYPE_CHECKING, Sequence
 
 from packaging.version import Version
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.google.cloud.hooks.gcs import GCSHook
@@ -43,12 +42,8 @@ class GCSToS3Operator(BaseOperator):
         :ref:`howto/operator:GCSToS3Operator`
 
     :param gcs_bucket: The Google Cloud Storage bucket to find the objects. (templated)
-    :param bucket: (Deprecated) Use ``gcs_bucket`` instead.
     :param prefix: Prefix string which filters objects whose name begin with
        this prefix. (templated)
-    :param delimiter: (Deprecated) The delimiter by which you want to filter the objects. (templated)
-        For e.g to lists the CSV files from in a directory in GCS you would use
-        delimiter='.csv'.
     :param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
     :param dest_aws_conn_id: The destination S3 connection
     :param dest_s3_key: The base S3 key to be used to store the files. (templated)
@@ -91,7 +86,6 @@ class GCSToS3Operator(BaseOperator):
     template_fields: Sequence[str] = (
         "gcs_bucket",
         "prefix",
-        "delimiter",
         "dest_s3_key",
         "google_impersonation_chain",
         "gcp_user_project",
@@ -101,10 +95,8 @@ class GCSToS3Operator(BaseOperator):
     def __init__(
         self,
         *,
-        gcs_bucket: str | None = None,
-        bucket: str | None = None,
+        gcs_bucket: str,
         prefix: str | None = None,
-        delimiter: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         dest_aws_conn_id: str | None = "aws_default",
         dest_s3_key: str,
@@ -119,17 +111,7 @@ class GCSToS3Operator(BaseOperator):
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        if bucket:
-            warnings.warn(
-                "The ``bucket`` parameter is deprecated and will be removed in a future version. "
-                "Please use ``gcs_bucket`` instead.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-        self.gcs_bucket = gcs_bucket or bucket
-        if not (bucket or gcs_bucket):
-            raise ValueError("You must pass either ``bucket`` or ``gcs_bucket``.")
-
+        self.gcs_bucket = gcs_bucket
         self.prefix = prefix
         self.gcp_conn_id = gcp_conn_id
         self.dest_aws_conn_id = dest_aws_conn_id
@@ -149,18 +131,10 @@ class GCSToS3Operator(BaseOperator):
                 self.__is_match_glob_supported = False
         except ImportError:  # __version__ was added in 10.1.0, so this means it's < 10.3.0
             self.__is_match_glob_supported = False
-        if not self.__is_match_glob_supported:
-            if delimiter:
-                warnings.warn(
-                    "Usage of 'delimiter' is deprecated, please use 'match_glob' instead",
-                    AirflowProviderDeprecationWarning,
-                    stacklevel=2,
-                )
-            elif match_glob:
+        if not self.__is_match_glob_supported and match_glob:
             raise AirflowException(
                 "The 'match_glob' parameter requires 'apache-airflow-providers-google>=10.3.0'."
             )
-        self.delimiter = delimiter
         self.match_glob = match_glob
         self.gcp_user_project = gcp_user_project
 
@@ -172,16 +146,14 @@ class GCSToS3Operator(BaseOperator):
         )
 
         self.log.info(
-            "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
+            "Getting list of the files. Bucket: %s; Prefix: %s",
             self.gcs_bucket,
-            self.delimiter,
             self.prefix,
         )
 
         list_kwargs = {
             "bucket_name": self.gcs_bucket,
             "prefix": self.prefix,
-            "delimiter": self.delimiter,
             "user_project": self.gcp_user_project,
         }
         if self.__is_match_glob_supported:
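``GCSToS3Operator`` dropped the deprecated ``bucket`` and ``delimiter`` parameters: ``gcs_bucket`` is now required, and suffix filtering goes through ``match_glob`` (which, per the check above, needs apache-airflow-providers-google>=10.3.0). A sketch with illustrative names:

    from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator

    sync = GCSToS3Operator(
        task_id="gcs_to_s3",
        gcs_bucket="my-gcs-bucket",          # was bucket="my-gcs-bucket" in 8.x
        dest_s3_key="s3://my-bucket/data/",  # hypothetical destination key
        match_glob="**/*.csv",               # was delimiter=".csv" in 8.x
    )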
airflow/providers/amazon/aws/transfers/redshift_to_s3.py

@@ -45,7 +45,8 @@ class RedshiftToS3Operator(BaseOperator):
     :param s3_key: reference to a specific S3 key. If ``table_as_file_name`` is set
        to False, this param must include the desired file name
     :param schema: reference to a specific schema in redshift database,
-        used when ``table`` param provided and ``select_query`` param not provided
+        used when ``table`` param provided and ``select_query`` param not provided.
+        Do not provide when unloading a temporary table
     :param table: reference to a specific table in redshift database,
        used when ``schema`` param provided and ``select_query`` param not provided
     :param select_query: custom select query to fetch data from redshift database,
@@ -55,8 +56,8 @@ class RedshiftToS3Operator(BaseOperator):
        If the AWS connection contains 'aws_iam_role' in ``extras``
        the operator will use AWS STS credentials with a token
        https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
-    :param verify: Whether or not to verify SSL certificates for S3 connection.
-        By default SSL certificates are verified.
+    :param verify: Whether to verify SSL certificates for S3 connection.
+        By default, SSL certificates are verified.
        You can provide the following values:
 
        - ``False``: do not validate SSL certificates. SSL will still be used
@@ -67,7 +68,7 @@ class RedshiftToS3Operator(BaseOperator):
        CA cert bundle than the one used by botocore.
     :param unload_options: reference to a list of UNLOAD options
     :param autocommit: If set to True it will automatically commit the UNLOAD statement.
-        Otherwise it will be committed right before the redshift connection gets closed.
+        Otherwise, it will be committed right before the redshift connection gets closed.
     :param include_header: If set to True the s3 file contains the header columns.
     :param parameters: (optional) the parameters to render the SQL query with.
     :param table_as_file_name: If set to True, the s3 file will be named as the table.
@@ -141,9 +142,15 @@ class RedshiftToS3Operator(BaseOperator):
 
     @property
     def default_select_query(self) -> str | None:
-        if self.schema and self.table:
-            return f"SELECT * FROM {self.schema}.{self.table}"
-        return None
+        if not self.table:
+            return None
+
+        if self.schema:
+            table = f"{self.schema}.{self.table}"
+        else:
+            # Relevant when unloading a temporary table
+            table = self.table
+        return f"SELECT * FROM {table}"
 
     def execute(self, context: Context) -> None:
         if self.table and self.table_as_file_name:
@@ -152,9 +159,7 @@ class RedshiftToS3Operator(BaseOperator):
         self.select_query = self.select_query or self.default_select_query
 
         if self.select_query is None:
-            raise ValueError(
-                "Please provide both `schema` and `table` params or `select_query` to fetch the data."
-            )
+            raise ValueError("Please specify either a table or `select_query` to fetch the data.")
 
         if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
             self.unload_options = [*self.unload_options, "HEADER"]
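The reworked ``default_select_query`` makes ``schema`` optional so a temporary table created earlier in the same Redshift session can be unloaded. An illustrative sketch:

    from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator

    unload = RedshiftToS3Operator(
        task_id="unload_temp_table",
        s3_bucket="my-bucket",   # hypothetical
        s3_key="exports/",       # hypothetical
        table="staging_orders",  # temp table, so no schema= argument;
    )                            # renders "SELECT * FROM staging_orders"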
airflow/providers/amazon/aws/transfers/s3_to_redshift.py

@@ -28,7 +28,6 @@ from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
-
 AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"]
 
 
@@ -40,17 +39,18 @@ class S3ToRedshiftOperator(BaseOperator):
     For more information on how to use this operator, take a look at the guide:
         :ref:`howto/operator:S3ToRedshiftOperator`
 
-    :param schema: reference to a specific schema in redshift database
     :param table: reference to a specific table in redshift database
     :param s3_bucket: reference to a specific S3 bucket
     :param s3_key: key prefix that selects single or multiple objects from S3
+    :param schema: reference to a specific schema in redshift database.
+        Do not provide when copying into a temporary table
     :param redshift_conn_id: reference to a specific redshift database OR a redshift data-api connection
     :param aws_conn_id: reference to a specific S3 connection
        If the AWS connection contains 'aws_iam_role' in ``extras``
        the operator will use AWS STS credentials with a token
        https://docs.aws.amazon.com/redshift/latest/dg/copy-parameters-authorization.html#copy-credentials
-    :param verify: Whether or not to verify SSL certificates for S3 connection.
-        By default SSL certificates are verified.
+    :param verify: Whether to verify SSL certificates for S3 connection.
+        By default, SSL certificates are verified.
        You can provide the following values:
 
        - ``False``: do not validate SSL certificates. SSL will still be used
@@ -87,10 +87,10 @@ class S3ToRedshiftOperator(BaseOperator):
     def __init__(
         self,
         *,
-        schema: str,
         table: str,
         s3_bucket: str,
         s3_key: str,
+        schema: str | None = None,
         redshift_conn_id: str = "redshift_default",
         aws_conn_id: str | None = "aws_default",
         verify: bool | str | None = None,
@@ -160,7 +160,7 @@ class S3ToRedshiftOperator(BaseOperator):
         credentials_block = build_credentials_block(credentials)
 
         copy_options = "\n\t\t\t".join(self.copy_options)
-        destination = f"{self.schema}.{self.table}"
+        destination = f"{self.schema}.{self.table}" if self.schema else self.table
         copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination
 
         copy_statement = self._build_copy_query(
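Mirroring the change above, ``S3ToRedshiftOperator`` now treats ``schema`` as optional and builds the COPY destination without a schema qualifier when it is omitted, e.g. for temporary tables. Illustrative values:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    copy = S3ToRedshiftOperator(
        task_id="s3_to_redshift",
        table="staging_orders",        # temp table -> omit schema=
        s3_bucket="my-bucket",         # hypothetical
        s3_key="incoming/orders.csv",  # hypothetical
    )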
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -223,12 +223,9 @@ class SqlToS3Operator(BaseOperator):
         for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
                 cast(str, group_label),
-                cast(
-                    "pd.DataFrame",
-                    grouped_df.get_group(group_label)
-                    .drop(random_column_name, axis=1, errors="ignore")
-                    .reset_index(drop=True),
-                ),
+                grouped_df.get_group(group_label)
+                .drop(random_column_name, axis=1, errors="ignore")
+                .reset_index(drop=True),
             )
 
     def _get_hook(self) -> DbApiHook:
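The refactor above appears to drop only a typing ``cast`` wrapper; the pandas chain itself is unchanged. A standalone sketch of what each iteration yields (column names are illustrative):

    import pandas as pd

    df = pd.DataFrame({"region": ["eu", "eu", "us"], "value": [1, 2, 3]})
    grouped_df = df.groupby("region")
    for group_label in grouped_df.groups:
        frame = (
            grouped_df.get_group(group_label)
            .drop("region", axis=1, errors="ignore")  # drop the grouping column
            .reset_index(drop=True)                   # renumber rows from 0
        )
        print(group_label, frame.to_dict("records"))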
airflow/providers/amazon/aws/triggers/batch.py

@@ -16,182 +16,15 @@
 # under the License.
 from __future__ import annotations
 
-import asyncio
-import itertools
-from functools import cached_property
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
-from botocore.exceptions import WaiterError
-from deprecated import deprecated
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 
 
-@deprecated(reason="use BatchJobTrigger instead", category=AirflowProviderDeprecationWarning)
-class BatchOperatorTrigger(BaseTrigger):
-    """
-    Asynchronously poll the boto3 API and wait for the Batch job to be in the `SUCCEEDED` state.
-
-    :param job_id: A unique identifier for the cluster.
-    :param max_retries: The maximum number of attempts to be made.
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-    :param region_name: region name to use in AWS Hook
-    :param poll_interval: The amount of time in seconds to wait between attempts.
-    """
-
-    def __init__(
-        self,
-        job_id: str | None = None,
-        max_retries: int = 10,
-        aws_conn_id: str | None = "aws_default",
-        region_name: str | None = None,
-        poll_interval: int = 30,
-    ):
-        super().__init__()
-        self.job_id = job_id
-        self.max_retries = max_retries
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.poll_interval = poll_interval
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serialize BatchOperatorTrigger arguments and classpath."""
-        return (
-            "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger",
-            {
-                "job_id": self.job_id,
-                "max_retries": self.max_retries,
-                "aws_conn_id": self.aws_conn_id,
-                "region_name": self.region_name,
-                "poll_interval": self.poll_interval,
-            },
-        )
-
-    @cached_property
-    def hook(self) -> BatchClientHook:
-        return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-    async def run(self):
-        async with self.hook.async_conn as client:
-            waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
-            for attempt in range(1, 1 + self.max_retries):
-                try:
-                    await waiter.wait(
-                        jobs=[self.job_id],
-                        WaiterConfig={
-                            "Delay": self.poll_interval,
-                            "MaxAttempts": 1,
-                        },
-                    )
-                except WaiterError as error:
-                    if "terminal failure" in str(error):
-                        yield TriggerEvent(
-                            {"status": "failure", "message": f"Delete Cluster Failed: {error}"}
-                        )
-                        break
-                    self.log.info(
-                        "Job status is %s. Retrying attempt %s/%s",
-                        error.last_response["jobs"][0]["status"],
-                        attempt,
-                        self.max_retries,
-                    )
-                    await asyncio.sleep(int(self.poll_interval))
-                else:
-                    yield TriggerEvent({"status": "success", "job_id": self.job_id})
-                    break
-            else:
-                yield TriggerEvent({"status": "failure", "message": "Job Failed - max attempts reached."})
-
-
-@deprecated(reason="use BatchJobTrigger instead", category=AirflowProviderDeprecationWarning)
-class BatchSensorTrigger(BaseTrigger):
-    """
-    Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
-
-    BatchSensorTrigger is fired as deferred class with params to poll the job state in Triggerer.
-
-    :param job_id: the job ID, to poll for job completion or not
-    :param region_name: AWS region name to use
-        Override the region_name in connection (if provided)
-    :param aws_conn_id: connection id of AWS credentials / region name. If None,
-        credential boto3 strategy will be used
-    :param poke_interval: polling period in seconds to check for the status of the job
-    """
-
-    def __init__(
-        self,
-        job_id: str,
-        region_name: str | None,
-        aws_conn_id: str | None = "aws_default",
-        poke_interval: float = 5,
-    ):
-        super().__init__()
-        self.job_id = job_id
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
-        self.poke_interval = poke_interval
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        """Serialize BatchSensorTrigger arguments and classpath."""
-        return (
-            "airflow.providers.amazon.aws.triggers.batch.BatchSensorTrigger",
-            {
-                "job_id": self.job_id,
-                "aws_conn_id": self.aws_conn_id,
-                "region_name": self.region_name,
-                "poke_interval": self.poke_interval,
-            },
-        )
-
-    @cached_property
-    def hook(self) -> BatchClientHook:
-        return BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-    async def run(self):
-        """
-        Make async connection using aiobotocore library to AWS Batch, periodically poll for the job status.
-
-        The status that indicates job completion are: 'SUCCEEDED'|'FAILED'.
-        """
-        async with self.hook.async_conn as client:
-            waiter = self.hook.get_waiter("batch_job_complete", deferrable=True, client=client)
-            for attempt in itertools.count(1):
-                try:
-                    await waiter.wait(
-                        jobs=[self.job_id],
-                        WaiterConfig={
-                            "Delay": int(self.poke_interval),
-                            "MaxAttempts": 1,
-                        },
-                    )
-                except WaiterError as error:
-                    if "error" in str(error):
-                        yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"})
-                        break
-                    self.log.info(
-                        "Job response is %s. Retrying attempt %s",
-                        error.last_response["Error"]["Message"],
-                        attempt,
-                    )
-                    await asyncio.sleep(int(self.poke_interval))
-                else:
-                    break
-
-        yield TriggerEvent(
-            {
-                "status": "success",
-                "job_id": self.job_id,
-                "message": f"Job {self.job_id} Succeeded",
-            }
-        )
-
-
 class BatchJobTrigger(AwsBaseWaiterTrigger):
     """
     Checks for the status of a submitted job_id to AWS Batch until it reaches a failure or a success state.
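Code that constructed the removed triggers directly must switch to ``BatchJobTrigger``. Its constructor is not shown in this diff, so the parameter names below are assumptions based on the other ``AwsBaseWaiterTrigger`` subclasses in this release:

    from airflow.providers.amazon.aws.triggers.batch import BatchJobTrigger

    trigger = BatchJobTrigger(
        job_id="1234-abcd",        # hypothetical Batch job id
        region_name="eu-west-1",   # hypothetical
        aws_conn_id="aws_default",
        waiter_delay=30,           # roughly replaces poll_interval / poke_interval
        waiter_max_attempts=10,    # roughly replaces max_retries
    )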
airflow/providers/amazon/aws/triggers/eks.py

@@ -16,12 +16,11 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Any
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.eks import EksHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
@@ -235,17 +234,8 @@ class EksCreateFargateProfileTrigger(AwsBaseWaiterTrigger):
         waiter_delay: int,
         waiter_max_attempts: int,
         aws_conn_id: str | None,
-        region: str | None = None,
         region_name: str | None = None,
     ):
-        if region is not None:
-            warnings.warn(
-                "please use region_name param instead of region",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            region_name = region
-
         super().__init__(
             serialized_fields={"cluster_name": cluster_name, "fargate_profile_name": fargate_profile_name},
             waiter_name="fargate_profile_active",
@@ -282,17 +272,8 @@ class EksDeleteFargateProfileTrigger(AwsBaseWaiterTrigger):
         waiter_delay: int,
         waiter_max_attempts: int,
         aws_conn_id: str | None,
-        region: str | None = None,
         region_name: str | None = None,
     ):
-        if region is not None:
-            warnings.warn(
-                "please use region_name param instead of region",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            region_name = region
-
         super().__init__(
             serialized_fields={"cluster_name": cluster_name, "fargate_profile_name": fargate_profile_name},
             waiter_name="fargate_profile_deleted",
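Both Fargate-profile triggers lose the deprecated ``region`` alias; only ``region_name`` remains. Illustrative usage (cluster and profile names are hypothetical):

    from airflow.providers.amazon.aws.triggers.eks import EksCreateFargateProfileTrigger

    trigger = EksCreateFargateProfileTrigger(
        cluster_name="my-cluster",          # hypothetical
        fargate_profile_name="my-profile",  # hypothetical
        waiter_delay=30,
        waiter_max_attempts=10,
        aws_conn_id="aws_default",
        region_name="eu-west-1",            # was region="eu-west-1" in 8.x
    )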
airflow/providers/amazon/aws/triggers/emr.py

@@ -17,10 +17,8 @@
 from __future__ import annotations
 
 import sys
-import warnings
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
@@ -81,21 +79,10 @@ class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         job_flow_id: str,
-        poll_interval: int | None = None,  # deprecated
-        max_attempts: int | None = None,  # deprecated
         aws_conn_id: str | None = None,
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempts",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = poll_interval or waiter_delay
-            waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"job_flow_id": job_flow_id},
             waiter_name="job_flow_waiting",
@@ -131,21 +118,10 @@ class EmrTerminateJobFlowTrigger(AwsBaseWaiterTrigger):
     def __init__(
         self,
         job_flow_id: str,
-        poll_interval: int | None = None,  # deprecated
-        max_attempts: int | None = None,  # deprecated
         aws_conn_id: str | None = None,
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
     ):
-        if poll_interval is not None or max_attempts is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval "
-                "and waiter_max_attempts instead of max_attempts",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = poll_interval or waiter_delay
-            waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"job_flow_id": job_flow_id},
             waiter_name="job_flow_terminated",
@@ -183,17 +159,9 @@ class EmrContainerTrigger(AwsBaseWaiterTrigger):
         virtual_cluster_id: str,
         job_id: str,
         aws_conn_id: str | None = "aws_default",
-        poll_interval: int | None = None,  # deprecated
         waiter_delay: int = 30,
         waiter_max_attempts: int = sys.maxsize,
     ):
-        if poll_interval is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = poll_interval or waiter_delay
         super().__init__(
             serialized_fields={"virtual_cluster_id": virtual_cluster_id, "job_id": job_id},
             waiter_name="container_job_complete",
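The EMR triggers likewise accept only the waiter_* spellings now. A sketch against the signature shown above:

    from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger

    trigger = EmrCreateJobFlowTrigger(
        job_flow_id="j-ABC123",  # hypothetical cluster id
        aws_conn_id="aws_default",
        waiter_delay=30,         # was poll_interval=30 in 8.x
        waiter_max_attempts=60,  # was max_attempts=60 in 8.x
    )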
airflow/providers/amazon/aws/triggers/glue_crawler.py

@@ -16,10 +16,8 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
@@ -32,26 +30,17 @@ class GlueCrawlerCompleteTrigger(AwsBaseWaiterTrigger):
     Watches for a glue crawl, triggers when it finishes.
 
     :param crawler_name: name of the crawler to watch
-    :param poll_interval: The amount of time in seconds to wait between attempts.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     """
 
     def __init__(
         self,
         crawler_name: str,
-        poll_interval: int | None = None,
         aws_conn_id: str | None = "aws_default",
         waiter_delay: int = 5,
         waiter_max_attempts: int = 1500,
         **kwargs,
     ):
-        if poll_interval is not None:
-            warnings.warn(
-                "please use waiter_delay instead of poll_interval.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = poll_interval or waiter_delay
         super().__init__(
             serialized_fields={"crawler_name": crawler_name},
             waiter_name="crawler_ready",
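``GlueCrawlerCompleteTrigger`` usage after the removal (crawler name is illustrative):

    from airflow.providers.amazon.aws.triggers.glue_crawler import GlueCrawlerCompleteTrigger

    trigger = GlueCrawlerCompleteTrigger(
        crawler_name="my-crawler",  # hypothetical
        aws_conn_id="aws_default",
        waiter_delay=5,             # was poll_interval=5 in 8.x
    )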
airflow/providers/amazon/aws/triggers/glue_databrew.py

@@ -17,9 +17,6 @@
 
 from __future__ import annotations
 
-import warnings
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.glue_databrew import GlueDataBrewHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
@@ -30,9 +27,7 @@ class GlueDataBrewJobCompleteTrigger(AwsBaseWaiterTrigger):
 
     :param job_name: Glue DataBrew job name
     :param run_id: the ID of the specific run to watch for that job
-    :param delay: Number of seconds to wait between two checks.(Deprecated).
     :param waiter_delay: Number of seconds to wait between two checks. Default is 30 seconds.
-    :param max_attempts: Maximum number of attempts to wait for the job to complete.(Deprecated).
     :param waiter_max_attempts: Maximum number of attempts to wait for the job to complete. Default is 60 attempts.
     :param aws_conn_id: The Airflow connection used for AWS credentials.
     """
@@ -41,27 +36,11 @@ class GlueDataBrewJobCompleteTrigger(AwsBaseWaiterTrigger):
         self,
         job_name: str,
         run_id: str,
-        delay: int | None = None,
-        max_attempts: int | None = None,
         waiter_delay: int = 30,
         waiter_max_attempts: int = 60,
         aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
-        if delay is not None:
-            warnings.warn(
-                "please use `waiter_delay` instead of delay.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_delay = delay or waiter_delay
-        if max_attempts is not None:
-            warnings.warn(
-                "please use `waiter_max_attempts` instead of max_attempts.",
-                AirflowProviderDeprecationWarning,
-                stacklevel=2,
-            )
-            waiter_max_attempts = max_attempts or waiter_max_attempts
         super().__init__(
             serialized_fields={"job_name": job_name, "run_id": run_id},
             waiter_name="job_complete",
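And the equivalent migration for the DataBrew trigger (job and run identifiers are illustrative):

    from airflow.providers.amazon.aws.triggers.glue_databrew import GlueDataBrewJobCompleteTrigger

    trigger = GlueDataBrewJobCompleteTrigger(
        job_name="my-databrew-job",  # hypothetical
        run_id="db_1234",            # hypothetical
        waiter_delay=30,             # was delay=30 in 8.x
        waiter_max_attempts=60,      # was max_attempts=60 in 8.x
    )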