apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
- airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
- airflow/providers/amazon/aws/hooks/athena.py +15 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
- airflow/providers/amazon/aws/hooks/emr.py +6 -0
- airflow/providers/amazon/aws/hooks/logs.py +85 -1
- airflow/providers/amazon/aws/hooks/neptune.py +85 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
- airflow/providers/amazon/aws/hooks/s3.py +4 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
- airflow/providers/amazon/aws/links/emr.py +122 -2
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
- airflow/providers/amazon/aws/operators/athena.py +4 -1
- airflow/providers/amazon/aws/operators/batch.py +5 -6
- airflow/providers/amazon/aws/operators/ecs.py +6 -2
- airflow/providers/amazon/aws/operators/eks.py +31 -26
- airflow/providers/amazon/aws/operators/emr.py +192 -26
- airflow/providers/amazon/aws/operators/glue.py +5 -2
- airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
- airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
- airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
- airflow/providers/amazon/aws/operators/neptune.py +218 -0
- airflow/providers/amazon/aws/operators/rds.py +21 -12
- airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
- airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
- airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
- airflow/providers/amazon/aws/operators/step_function.py +4 -1
- airflow/providers/amazon/aws/sensors/batch.py +2 -2
- airflow/providers/amazon/aws/sensors/ec2.py +4 -2
- airflow/providers/amazon/aws/sensors/emr.py +13 -6
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
- airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
- airflow/providers/amazon/aws/sensors/s3.py +3 -0
- airflow/providers/amazon/aws/sensors/sqs.py +4 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
- airflow/providers/amazon/aws/triggers/neptune.py +115 -0
- airflow/providers/amazon/aws/triggers/rds.py +9 -7
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
- airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
- airflow/providers/amazon/aws/utils/__init__.py +10 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
- airflow/providers/amazon/aws/utils/mixins.py +5 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
- airflow/providers/amazon/aws/waiters/neptune.json +85 -0
- airflow/providers/amazon/get_provider_info.py +26 -2
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0

airflow/providers/amazon/aws/sensors/emr.py

@@ -32,6 +32,7 @@ from airflow.providers.amazon.aws.triggers.emr import (
     EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -335,15 +336,17 @@ class EmrContainerSensor(BaseSensorOperator):
             method_name="execute_complete",
         )

-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
             if self.soft_fail:
                 raise AirflowSkipException(message)
             raise AirflowException(message)
-        else:
-            self.log.info("Job completed.")
+
+        self.log.info("Job completed.")


 class EmrNotebookExecutionSensor(EmrBaseSensor):
@@ -526,7 +529,9 @@ class EmrJobFlowSensor(EmrBaseSensor):
             method_name="execute_complete",
         )

-    def execute_complete(self, context: Context, event=None) -> None:
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -657,7 +662,9 @@ class EmrStepSensor(EmrBaseSensor):
             method_name="execute_complete",
         )

-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
             message = f"Error while running job: {event}"
@@ -665,4 +672,4 @@ class EmrStepSensor(EmrBaseSensor):
                 raise AirflowSkipException(message)
             raise AirflowException(message)

-        self.log.info("Job completed.")
+        self.log.info("Job %s completed.", self.job_flow_id)
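
For DAG authors the deferrable flow is unchanged: the sensor still resumes in `execute_complete`, which now rejects a missing trigger event up front via `validate_execute_complete_event`. A minimal usage sketch, assuming this provider is installed; the DAG ID, cluster ID, and step ID are placeholders, not values from this diff:

```python
# Minimal sketch of a deferrable EMR step sensor; all IDs below are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor

with DAG("emr_step_watch_example", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    watch_step = EmrStepSensor(
        task_id="watch_step",
        job_flow_id="j-XXXXXXXXXXXXX",  # placeholder EMR cluster ID
        step_id="s-XXXXXXXXXXXXX",      # placeholder step ID
        deferrable=True,  # resumes in execute_complete(), which now validates the trigger event
    )
```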

airflow/providers/amazon/aws/sensors/glue_catalog_partition.py

@@ -27,6 +27,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
 from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -111,7 +112,9 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
         return self.hook.check_for_partition(self.database_name, self.table_name, self.expression)

     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:

airflow/providers/amazon/aws/sensors/quicksight.py

@@ -17,10 +17,11 @@
 # under the License.
 from __future__ import annotations

-import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

+from deprecated import deprecated
+
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
@@ -80,24 +81,26 @@ class QuickSightSensor(AwsBaseSensor[QuickSightHook]):
         return quicksight_ingestion_state == self.success_status

     @cached_property
+    @deprecated(
+        reason=(
+            "`QuickSightSensor.quicksight_hook` property is deprecated, "
+            "please use `QuickSightSensor.hook` property instead."
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
     def quicksight_hook(self):
-        warnings.warn(
-            f"`{type(self).__name__}.quicksight_hook` property is deprecated, "
-            f"please use `{type(self).__name__}.hook` property instead.",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         return self.hook

     @cached_property
-    def sts_hook(self):
-        warnings.warn(
-            f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+    @deprecated(
+        reason=(
+            "`QuickSightSensor.sts_hook` property is deprecated and will be removed in the future. "
             "This property used for obtain AWS Account ID, "
-            f"please consider to use `{type(self).__name__}.hook.account_id` instead",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
+            "please consider to use `QuickSightSensor.hook.account_id` instead"
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
+    def sts_hook(self):
         from airflow.providers.amazon.aws.hooks.sts import StsHook

         return StsHook(aws_conn_id=self.aws_conn_id)
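
The `@deprecated` decorator (from the `Deprecated` package) replaces hand-rolled `warnings.warn` calls here and in several other modules in this release. A standalone sketch of the pattern, using made-up names (`Example`, `old_prop`, `new_prop`) rather than provider classes:

```python
# Sketch of the decorator pattern adopted above; Example/old_prop/new_prop are made up.
from deprecated import deprecated

from airflow.exceptions import AirflowProviderDeprecationWarning


class Example:
    @property
    @deprecated(
        reason="`old_prop` is deprecated, please use `new_prop` instead.",
        category=AirflowProviderDeprecationWarning,
    )
    def old_prop(self) -> int:
        return self.new_prop

    @property
    def new_prop(self) -> int:
        return 42


Example().old_prop  # emits AirflowProviderDeprecationWarning when the property is read
```

The warning fires when the decorated callable runs, so property reads behave the same as the old inline `warnings.warn`, while the deprecation message now sits next to the signature.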

airflow/providers/amazon/aws/sensors/redshift_cluster.py

@@ -26,6 +26,7 @@ from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.sensors.base import BaseSensorOperator

 if TYPE_CHECKING:
@@ -88,10 +89,7 @@ class RedshiftClusterSensor(BaseSensorOperator):
         )

     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
-
-            err_msg = "Trigger error: event is None"
-            self.log.error(err_msg)
-            raise AirflowException(err_msg)
+        event = validate_execute_complete_event(event)

         status = event["status"]
         if status == "error":

airflow/providers/amazon/aws/sensors/s3.py

@@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
 from deprecated import deprecated

 from airflow.configuration import conf
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event

 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -371,6 +372,8 @@ class S3KeysUnchangedSensor(BaseSensorOperator):

         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
+        event = validate_execute_complete_event(event)
+
         if event and event["status"] == "error":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             if self.soft_fail:

airflow/providers/amazon/aws/sensors/sqs.py

@@ -28,6 +28,7 @@ from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response

@@ -155,7 +156,9 @@ class SqsSensor(AwsBaseSensor[SqsHook]):
         super().execute(context=context)

     def execute_complete(self, context: Context, event: dict | None = None) -> None:
-
+        event = validate_execute_complete_event(event)
+
+        if event["status"] != "success":
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
             message = f"Trigger error: event is {event}"
             if self.soft_fail:

airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
         You can specify this argument if you want to use a different
         CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
+    :param max_rows_per_file: (optional) argument to set destination file number of rows limit, if source data
+        is larger than that, it will be dispatched into multiple files.
+        Will be ignored if ``groupby_kwargs`` argument is specified.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """
@@ -110,6 +113,7 @@
         aws_conn_id: str = "aws_default",
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
+        max_rows_per_file: int = 0,
         pd_kwargs: dict | None = None,
         groupby_kwargs: dict | None = None,
         **kwargs,
@@ -124,12 +128,19 @@
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.max_rows_per_file = max_rows_per_file
         self.groupby_kwargs = groupby_kwargs or {}
         self.sql_hook_params = sql_hook_params

         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")

+        if self.max_rows_per_file and self.groupby_kwargs:
+            raise AirflowException(
+                "SqlToS3Operator arguments max_rows_per_file and groupby_kwargs "
+                "can not be both specified. Please choose one."
+            )
+
         try:
             self.file_format = FILE_FORMAT[file_format.upper()]
         except KeyError:
@@ -177,10 +188,8 @@
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
         self.log.info("Data from SQL obtained")
-
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
-
         for group_name, df in self._partition_dataframe(df=data_df):
             with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
                 self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@

     def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
+        try:
+            import secrets
+            import string
+
+            import numpy as np
+        except ImportError:
+            pass
+        # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
+        # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
+        random_column_name = ""
+        if self.max_rows_per_file and not self.groupby_kwargs:
+            random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
+            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            self.groupby_kwargs = {"by": random_column_name}
         if not self.groupby_kwargs:
             yield "", df
             return
         for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
                 cast(str, group_label),
-                cast(
+                cast(
+                    "pd.DataFrame",
+                    grouped_df.get_group(group_label)
+                    .drop(random_column_name, axis=1, errors="ignore")
+                    .reset_index(drop=True),
+                ),
             )

     def _get_hook(self) -> DbApiHook:
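
A hedged usage sketch of the new `max_rows_per_file` argument; the connection IDs, query, and S3 location are placeholders. As the constructor check above enforces, the argument is mutually exclusive with `groupby_kwargs`:

```python
# Hypothetical task using max_rows_per_file; conn IDs, query and S3 location are placeholders.
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

export_orders = SqlToS3Operator(
    task_id="export_orders",
    sql_conn_id="my_postgres",     # placeholder source connection
    query="SELECT * FROM orders",  # placeholder query
    s3_bucket="my-bucket",         # placeholder bucket
    s3_key="exports/orders.csv",
    file_format="csv",
    replace=True,
    max_rows_per_file=100_000,     # split the result into files of at most 100k rows
    aws_conn_id="aws_default",
)
```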

airflow/providers/amazon/aws/triggers/neptune.py (new file)

@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class NeptuneClusterAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster is available.
+
+    :param db_cluster_id: Cluster ID to poll.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="cluster_available",
+            waiter_args={"DBClusterIdentifier": db_cluster_id},
+            failure_message="Failed to start Neptune cluster",
+            status_message="Status of Neptune cluster is",
+            status_queries=["DBClusters[0].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+class NeptuneClusterStoppedTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster is stopped.
+
+    :param db_cluster_id: Cluster ID to poll.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="cluster_stopped",
+            waiter_args={"DBClusterIdentifier": db_cluster_id},
+            failure_message="Failed to stop Neptune cluster",
+            status_message="Status of Neptune cluster is",
+            status_queries=["DBClusters[0].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
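
This release also ships `airflow/providers/amazon/aws/operators/neptune.py` and a matching `waiters/neptune.json` (see the file list above), and the operators added there presumably drive these triggers internally. A sketch of the underlying defer/resume handshake with a made-up operator class, not one from the provider:

```python
# Illustrative defer/resume handshake; WaitForNeptuneCluster is made up for this sketch.
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.triggers.neptune import NeptuneClusterAvailableTrigger


class WaitForNeptuneCluster(BaseOperator):
    def __init__(self, *, db_cluster_id: str, aws_conn_id: str = "aws_default", **kwargs):
        super().__init__(**kwargs)
        self.db_cluster_id = db_cluster_id
        self.aws_conn_id = aws_conn_id

    def execute(self, context):
        # Hand the wait over to the triggerer until the "cluster_available" waiter succeeds.
        self.defer(
            trigger=NeptuneClusterAvailableTrigger(
                db_cluster_id=self.db_cluster_id,
                aws_conn_id=self.aws_conn_id,
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # On success the trigger reports the cluster ID under its return_key.
        self.log.info("Neptune cluster %s is available.", (event or {}).get("db_cluster_id"))
```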

airflow/providers/amazon/aws/triggers/rds.py

@@ -16,10 +16,11 @@
 # under the License.
 from __future__ import annotations

-import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Any

+from deprecated import deprecated
+
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
@@ -31,6 +32,13 @@ if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


+@deprecated(
+    reason=(
+        "This trigger is deprecated, please use the other RDS triggers "
+        "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or RdsDbAvailableTrigger"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class RdsDbInstanceTrigger(BaseTrigger):
     """
     Deprecated Trigger for RDS operations. Do not use.
@@ -55,12 +63,6 @@ class RdsDbInstanceTrigger(BaseTrigger):
         region_name: str | None,
         response: dict[str, Any],
     ):
-        warnings.warn(
-            "This trigger is deprecated, please use the other RDS triggers "
-            "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or RdsDbAvailableTrigger",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         self.db_instance_identifier = db_instance_identifier
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts

airflow/providers/amazon/aws/triggers/redshift_cluster.py

@@ -290,7 +290,7 @@ class RedshiftClusterTrigger(BaseTrigger):
         self.poke_interval = poke_interval

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """
+        """Serialize RedshiftClusterTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
             {
@@ -302,7 +302,7 @@ class RedshiftClusterTrigger(BaseTrigger):
         )

     async def run(self) -> AsyncIterator[TriggerEvent]:
-        """
+        """Run async until the cluster status matches the target status."""
         try:
             hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
             while True:

airflow/providers/amazon/aws/triggers/redshift_data.py

@@ -63,7 +63,7 @@ class RedshiftDataTrigger(BaseTrigger):
         self.botocore_config = botocore_config

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """
+        """Serialize RedshiftDataTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
             {

airflow/providers/amazon/aws/triggers/sagemaker.py

@@ -18,6 +18,7 @@
 from __future__ import annotations

 import asyncio
+import time
 from collections import Counter
 from enum import IntEnum
 from functools import cached_property
@@ -26,7 +27,7 @@ from typing import Any, AsyncIterator
 from botocore.exceptions import WaiterError

 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -196,3 +197,83 @@ class SageMakerPipelineTrigger(BaseTrigger):
                 await asyncio.sleep(int(self.waiter_delay))

         raise AirflowException("Waiter error: max attempts reached")
+
+
+class SageMakerTrainingPrintLogTrigger(BaseTrigger):
+    """
+    SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.
+
+    :param job_name: name of the job to check status
+    :param poke_interval: polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for sagemaker
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        poke_interval: float,
+        aws_conn_id: str = "aws_default",
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
+            {
+                "poke_interval": self.poke_interval,
+                "aws_conn_id": self.aws_conn_id,
+                "job_name": self.job_name,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> SageMakerHook:
+        return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make async connection to sagemaker async hook and gets job status for a job submitted by the operator."""
+        stream_names: list[str] = []  # The list of log streams
+        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
+
+        last_description = await self.hook.describe_training_job_async(self.job_name)
+        instance_count = last_description["ResourceConfig"]["InstanceCount"]
+        status = last_description["TrainingJobStatus"]
+        job_already_completed = status not in self.hook.non_terminal_states
+        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
+        last_describe_job_call = time.time()
+        while True:
+            try:
+                (
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                ) = await self.hook.describe_training_job_with_log_async(
+                    self.job_name,
+                    positions,
+                    stream_names,
+                    instance_count,
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                )
+                status = last_description["TrainingJobStatus"]
+                if status in self.hook.non_terminal_states:
+                    await asyncio.sleep(self.poke_interval)
+                elif status in self.hook.failed_states:
+                    reason = last_description.get("FailureReason", "(No reason provided)")
+                    error_message = f"SageMaker job failed because {reason}"
+                    yield TriggerEvent({"status": "error", "message": error_message})
+                else:
+                    billable_seconds = SageMakerHook.count_billable_seconds(
+                        training_start_time=last_description["TrainingStartTime"],
+                        training_end_time=last_description["TrainingEndTime"],
+                        instance_count=instance_count,
+                    )
+                    self.log.info("Billable seconds: %d", billable_seconds)
+                    yield TriggerEvent({"status": "success", "message": last_description})
+            except Exception as e:
+                yield TriggerEvent({"status": "error", "message": str(e)})
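
A quick, hedged check of how the new trigger round-trips through `serialize()`; the job name is a placeholder:

```python
# The serialized form mirrors the constructor arguments shown in the diff above.
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrainingPrintLogTrigger

trigger = SageMakerTrainingPrintLogTrigger(
    job_name="my-training-job",  # placeholder job name
    poke_interval=60,
    aws_conn_id="aws_default",
)
classpath, kwargs = trigger.serialize()
assert classpath.endswith("SageMakerTrainingPrintLogTrigger")
assert kwargs == {"poke_interval": 60, "aws_conn_id": "aws_default", "job_name": "my-training-job"}
```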

airflow/providers/amazon/aws/utils/__init__.py

@@ -20,7 +20,9 @@ import logging
 import re
 from datetime import datetime, timezone
 from enum import Enum
+from typing import Any

+from airflow.exceptions import AirflowException
 from airflow.utils.helpers import prune_dict
 from airflow.version import version

@@ -72,6 +74,14 @@ def get_airflow_version() -> tuple[int, ...]:
     return tuple(int(x) for x in match.groups())


+def validate_execute_complete_event(event: dict[str, Any] | None = None) -> dict[str, Any]:
+    if event is None:
+        err_msg = "Trigger error: event is None"
+        log.error(err_msg)
+        raise AirflowException(err_msg)
+    return event
+
+
 class _StringCompareEnum(Enum):
     """
     An Enum class which can be compared with regular `str` and subclasses.
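
This helper is what the sensors earlier in the diff now call at the top of `execute_complete`. A small behavior sketch, assuming the 8.18.0 provider is installed:

```python
# The helper passes a real event through unchanged and raises on a missing one.
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils import validate_execute_complete_event

assert validate_execute_complete_event({"status": "success"}) == {"status": "success"}

try:
    validate_execute_complete_event(None)
except AirflowException as err:
    print(err)  # Trigger error: event is None
```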

airflow/providers/amazon/aws/utils/connection_wrapper.py

@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any

 from botocore import UNSIGNED
 from botocore.config import Config
+from deprecated import deprecated

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.utils import trim_none_values
@@ -165,7 +166,7 @@ class AwsConnectionWrapper(LoggingMixin):

         return service_config.get("endpoint_url", global_endpoint_url)

-    def __post_init__(self, conn: Connection):
+    def __post_init__(self, conn: Connection | AwsConnectionWrapper | _ConnectionMetadata | None) -> None:
         if isinstance(conn, type(self)):
             # For every field with init=False we copy reference value from original wrapper
             # For every field with init=True we use init values if it not equal default
@@ -192,6 +193,9 @@ class AwsConnectionWrapper(LoggingMixin):
         elif not conn:
             return

+        if TYPE_CHECKING:
+            assert isinstance(conn, (Connection, _ConnectionMetadata))
+
         # Assign attributes from AWS Connection
         self.conn_id = conn.conn_id
         self.conn_type = conn.conn_type or "aws"
@@ -462,6 +466,13 @@ class AwsConnectionWrapper(LoggingMixin):
         return role_arn, assume_role_method, assume_role_kwargs


+@deprecated(
+    reason=(
+        "Use local credentials file is never documented and well tested. "
+        "Obtain credentials by this way deprecated and will be removed in a future releases."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 def _parse_s3_config(
     config_file_name: str, config_format: str | None = "boto", profile: str | None = None
 ) -> tuple[str | None, str | None]:
@@ -474,13 +485,6 @@ def _parse_s3_config(
         Defaults to "boto"
     :param profile: profile name in AWS type config file
     """
-    warnings.warn(
-        "Use local credentials file is never documented and well tested. "
-        "Obtain credentials by this way deprecated and will be removed in a future releases.",
-        AirflowProviderDeprecationWarning,
-        stacklevel=4,
-    )
-
     import configparser

     config = configparser.ConfigParser()
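
The `if TYPE_CHECKING: assert isinstance(...)` lines added to `__post_init__` are a zero-cost way to narrow the widened `conn` annotation for static type checkers; the block never executes at runtime. A generic sketch of the idiom with made-up classes, not provider code:

```python
# Generic sketch of the narrowing idiom; Plain and Wrapper stand in for
# Connection/_ConnectionMetadata and AwsConnectionWrapper, and are made up.
from __future__ import annotations

from typing import TYPE_CHECKING


class Plain:
    name = "plain"


class Wrapper:
    name = "wrapped"


def handle(obj: Plain | Wrapper | None) -> str:
    if isinstance(obj, Wrapper):
        return "copy fields from the other wrapper"
    if not obj:
        return "nothing to do"
    if TYPE_CHECKING:
        # Evaluated only by the type checker (TYPE_CHECKING is False at runtime),
        # so the checker treats obj as Plain from here on without a runtime assert.
        assert isinstance(obj, Plain)
    return obj.name
```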

airflow/providers/amazon/aws/utils/mixins.py

@@ -31,6 +31,7 @@ import warnings
 from functools import cached_property
 from typing import Any, Generic, NamedTuple, TypeVar

+from deprecated import deprecated
 from typing_extensions import final

 from airflow.compat.functools import cache
@@ -160,9 +161,12 @@ class AwsBaseHookMixin(Generic[AwsHookType]):

     @property
     @final
+    @deprecated(
+        reason="`region` is deprecated and will be removed in the future. Please use `region_name` instead.",
+        category=AirflowProviderDeprecationWarning,
+    )
     def region(self) -> str | None:
         """Alias for ``region_name``, used for compatibility (deprecated)."""
-        warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
         return self.region_name
