apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +1 -1
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
- airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
- airflow/providers/amazon/aws/hooks/athena.py +15 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
- airflow/providers/amazon/aws/hooks/logs.py +85 -1
- airflow/providers/amazon/aws/hooks/neptune.py +85 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
- airflow/providers/amazon/aws/hooks/s3.py +4 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
- airflow/providers/amazon/aws/operators/eks.py +8 -6
- airflow/providers/amazon/aws/operators/neptune.py +218 -0
- airflow/providers/amazon/aws/operators/sagemaker.py +74 -15
- airflow/providers/amazon/aws/sensors/batch.py +2 -2
- airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
- airflow/providers/amazon/aws/triggers/neptune.py +115 -0
- airflow/providers/amazon/aws/triggers/rds.py +9 -7
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
- airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
- airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
- airflow/providers/amazon/aws/utils/mixins.py +5 -1
- airflow/providers/amazon/aws/waiters/neptune.json +85 -0
- airflow/providers/amazon/get_provider_info.py +22 -2
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/RECORD +37 -33
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/quicksight.py

@@ -17,10 +17,11 @@
 # under the License.
 from __future__ import annotations

-import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

+from deprecated import deprecated
+
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
@@ -80,24 +81,26 @@ class QuickSightSensor(AwsBaseSensor[QuickSightHook]):
         return quicksight_ingestion_state == self.success_status

     @cached_property
+    @deprecated(
+        reason=(
+            "`QuickSightSensor.quicksight_hook` property is deprecated, "
+            "please use `QuickSightSensor.hook` property instead."
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
     def quicksight_hook(self):
-        warnings.warn(
-            f"`{type(self).__name__}.quicksight_hook` property is deprecated, "
-            f"please use `{type(self).__name__}.hook` property instead.",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         return self.hook

     @cached_property
-    def sts_hook(self):
-        warnings.warn(
-            f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+    @deprecated(
+        reason=(
+            "`QuickSightSensor.sts_hook` property is deprecated and will be removed in the future. "
             "This property used for obtain AWS Account ID, "
-            f"please consider to use `{type(self).__name__}.hook.account_id` instead",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
+            "please consider to use `QuickSightSensor.hook.account_id` instead"
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
+    def sts_hook(self):
         from airflow.providers.amazon.aws.hooks.sts import StsHook

         return StsHook(aws_conn_id=self.aws_conn_id)
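The recurring change in this release, visible above and repeated in the rds, connection_wrapper and mixins hunks further down, is the replacement of hand-written warnings.warn() blocks with the @deprecated decorator from the Deprecated package. A minimal, self-contained sketch of that pattern, assuming apache-airflow and Deprecated are installed (DemoSensor and legacy_hook are invented names; only the decorator usage mirrors the provider code):

from deprecated import deprecated

from airflow.exceptions import AirflowProviderDeprecationWarning


class DemoSensor:
    """Stand-in class; only the decorator stacking mirrors QuickSightSensor."""

    @property
    def hook(self) -> str:
        return "the real hook"

    @property
    @deprecated(
        reason="`legacy_hook` property is deprecated, please use `hook` property instead.",
        category=AirflowProviderDeprecationWarning,
    )
    def legacy_hook(self) -> str:
        # The warning is now emitted by the decorator on access,
        # replacing the explicit warnings.warn(..., stacklevel=2) call.
        return self.hook


if __name__ == "__main__":
    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        DemoSensor().legacy_hook
    print(caught[0].category.__name__, caught[0].message)

In the real sensor the outer decorator is @cached_property rather than @property, so the deprecation warning fires only on the first access, after which the cached hook is returned directly.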
airflow/providers/amazon/aws/transfers/sql_to_s3.py

@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
         You can specify this argument if you want to use a different
         CA cert bundle than the one used by botocore.
     :param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
+    :param max_rows_per_file: (optional) argument to set destination file number of rows limit, if source data
+        is larger than that, it will be dispatched into multiple files.
+        Will be ignored if ``groupby_kwargs`` argument is specified.
     :param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
     :param groupby_kwargs: argument to include in DataFrame ``groupby()``.
     """

@@ -110,6 +113,7 @@ class SqlToS3Operator(BaseOperator):
         aws_conn_id: str = "aws_default",
         verify: bool | str | None = None,
         file_format: Literal["csv", "json", "parquet"] = "csv",
+        max_rows_per_file: int = 0,
         pd_kwargs: dict | None = None,
         groupby_kwargs: dict | None = None,
         **kwargs,

@@ -124,12 +128,19 @@ class SqlToS3Operator(BaseOperator):
         self.replace = replace
         self.pd_kwargs = pd_kwargs or {}
         self.parameters = parameters
+        self.max_rows_per_file = max_rows_per_file
         self.groupby_kwargs = groupby_kwargs or {}
         self.sql_hook_params = sql_hook_params

         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException("The argument path_or_buf is not allowed, please remove it")

+        if self.max_rows_per_file and self.groupby_kwargs:
+            raise AirflowException(
+                "SqlToS3Operator arguments max_rows_per_file and groupby_kwargs "
+                "can not be both specified. Please choose one."
+            )
+
         try:
             self.file_format = FILE_FORMAT[file_format.upper()]
         except KeyError:

@@ -177,10 +188,8 @@
         s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
         data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
         self.log.info("Data from SQL obtained")
-
         self._fix_dtypes(data_df, self.file_format)
         file_options = FILE_OPTIONS_MAP[self.file_format]
-
         for group_name, df in self._partition_dataframe(df=data_df):
             with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
                 self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@

     def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
         """Partition dataframe using pandas groupby() method."""
+        try:
+            import secrets
+            import string
+
+            import numpy as np
+        except ImportError:
+            pass
+        # if max_rows_per_file argument is specified, a temporary column with a random unusual name will be
+        # added to the dataframe. This column is used to dispatch the dataframe into smaller ones using groupby()
+        random_column_name = ""
+        if self.max_rows_per_file and not self.groupby_kwargs:
+            random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
+            df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
+            self.groupby_kwargs = {"by": random_column_name}
         if not self.groupby_kwargs:
             yield "", df
             return
         for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
             yield (
                 cast(str, group_label),
-                cast("pd.DataFrame", grouped_df.get_group(group_label).reset_index(drop=True)),
+                cast(
+                    "pd.DataFrame",
+                    grouped_df.get_group(group_label)
+                    .drop(random_column_name, axis=1, errors="ignore")
+                    .reset_index(drop=True),
+                ),
             )

     def _get_hook(self) -> DbApiHook:
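For reference, the row-chunking trick that _partition_dataframe now applies when max_rows_per_file is set can be reproduced standalone with pandas and numpy; the 3-row limit and the fixed column name below are arbitrary example values, not provider defaults:

import numpy as np
import pandas as pd

max_rows_per_file = 3
df = pd.DataFrame({"value": range(8)})

# The operator generates a random 20-letter column name; a fixed name is used here.
chunk_col = "_chunk_index"
df[chunk_col] = np.arange(len(df)) // max_rows_per_file

for label, chunk in df.groupby(chunk_col):
    chunk = chunk.drop(chunk_col, axis=1).reset_index(drop=True)
    print(f"chunk {label}: {len(chunk)} rows")  # prints 3, 3, 2

Because the temporary column is fed into the existing groupby() path, max_rows_per_file and a user-supplied groupby_kwargs are mutually exclusive, which is why the constructor above now raises when both are given.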
airflow/providers/amazon/aws/triggers/neptune.py (new file)

@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class NeptuneClusterAvailableTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster is available.
+
+    :param db_cluster_id: Cluster ID to poll.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="cluster_available",
+            waiter_args={"DBClusterIdentifier": db_cluster_id},
+            failure_message="Failed to start Neptune cluster",
+            status_message="Status of Neptune cluster is",
+            status_queries=["DBClusters[0].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+
+class NeptuneClusterStoppedTrigger(AwsBaseWaiterTrigger):
+    """
+    Triggers when a Neptune Cluster is stopped.
+
+    :param db_cluster_id: Cluster ID to poll.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: AWS region name (example: us-east-1)
+    """
+
+    def __init__(
+        self,
+        *,
+        db_cluster_id: str,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
+        aws_conn_id: str | None = None,
+        region_name: str | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"db_cluster_id": db_cluster_id},
+            waiter_name="cluster_stopped",
+            waiter_args={"DBClusterIdentifier": db_cluster_id},
+            failure_message="Failed to stop Neptune cluster",
+            status_message="Status of Neptune cluster is",
+            status_queries=["DBClusters[0].Status"],
+            return_key="db_cluster_id",
+            return_value=db_cluster_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            **kwargs,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return NeptuneHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
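The triggers above only define the waiter wiring; the operators that defer onto them live in the new airflow/providers/amazon/aws/operators/neptune.py, which is listed at the top of this diff but not shown here. A hedged sketch of how a deferrable operator would hand the wait off to NeptuneClusterAvailableTrigger (the operator name, cluster ID, connection ID and timing values are placeholders, not provider code):

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.triggers.neptune import NeptuneClusterAvailableTrigger


class WaitForNeptuneClusterOperator(BaseOperator):
    """Illustrative only: waits for a Neptune cluster to become available."""

    def __init__(self, *, db_cluster_id: str, **kwargs):
        super().__init__(**kwargs)
        self.db_cluster_id = db_cluster_id

    def execute(self, context):
        # Hand the wait to the triggerer instead of blocking a worker slot.
        self.defer(
            trigger=NeptuneClusterAvailableTrigger(
                db_cluster_id=self.db_cluster_id,
                waiter_delay=30,
                waiter_max_attempts=60,
                aws_conn_id="aws_default",
            ),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # On success the base waiter trigger emits an event carrying the
        # cluster id under "db_cluster_id" (return_key/return_value above).
        self.log.info("Neptune cluster %s is available", event["db_cluster_id"])
        return event["db_cluster_id"]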
airflow/providers/amazon/aws/triggers/rds.py

@@ -16,10 +16,11 @@
 # under the License.
 from __future__ import annotations

-import warnings
 from functools import cached_property
 from typing import TYPE_CHECKING, Any

+from deprecated import deprecated
+
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger

@@ -31,6 +32,13 @@ if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


+@deprecated(
+    reason=(
+        "This trigger is deprecated, please use the other RDS triggers "
+        "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or RdsDbAvailableTrigger"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class RdsDbInstanceTrigger(BaseTrigger):
     """
     Deprecated Trigger for RDS operations. Do not use.

@@ -55,12 +63,6 @@ class RdsDbInstanceTrigger(BaseTrigger):
         region_name: str | None,
         response: dict[str, Any],
     ):
-        warnings.warn(
-            "This trigger is deprecated, please use the other RDS triggers "
-            "such as RdsDbDeletedTrigger, RdsDbStoppedTrigger or RdsDbAvailableTrigger",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         self.db_instance_identifier = db_instance_identifier
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
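Here the decorator is applied to the class itself rather than to a property, so the warning now fires when RdsDbInstanceTrigger is instantiated instead of from inside __init__. A small self-contained sketch of that behaviour (the class below is a stand-in, and plain DeprecationWarning is used so the snippet runs with only the Deprecated package installed; the provider passes AirflowProviderDeprecationWarning):

import warnings

from deprecated import deprecated


@deprecated(
    reason="This trigger is deprecated, please use the newer trigger classes instead.",
    category=DeprecationWarning,
)
class LegacyTrigger:
    """Stand-in for a deprecated trigger class."""

    def __init__(self, identifier: str):
        self.identifier = identifier


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LegacyTrigger("demo-db")  # constructing the class emits the warning

print(caught[0].category.__name__, "-", caught[0].message)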
airflow/providers/amazon/aws/triggers/redshift_cluster.py

@@ -290,7 +290,7 @@ class RedshiftClusterTrigger(BaseTrigger):
         self.poke_interval = poke_interval

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """
+        """Serialize RedshiftClusterTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
             {

@@ -302,7 +302,7 @@ class RedshiftClusterTrigger(BaseTrigger):
         )

     async def run(self) -> AsyncIterator[TriggerEvent]:
-        """
+        """Run async until the cluster status matches the target status."""
         try:
             hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
             while True:

airflow/providers/amazon/aws/triggers/redshift_data.py

@@ -63,7 +63,7 @@ class RedshiftDataTrigger(BaseTrigger):
         self.botocore_config = botocore_config

     def serialize(self) -> tuple[str, dict[str, Any]]:
-        """
+        """Serialize RedshiftDataTrigger arguments and classpath."""
         return (
             "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
             {
airflow/providers/amazon/aws/triggers/sagemaker.py

@@ -18,6 +18,7 @@
 from __future__ import annotations

 import asyncio
+import time
 from collections import Counter
 from enum import IntEnum
 from functools import cached_property

@@ -26,7 +27,7 @@ from typing import Any, AsyncIterator
 from botocore.exceptions import WaiterError

 from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
+from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -196,3 +197,83 @@ class SageMakerPipelineTrigger(BaseTrigger):
             await asyncio.sleep(int(self.waiter_delay))

         raise AirflowException("Waiter error: max attempts reached")
+
+
+class SageMakerTrainingPrintLogTrigger(BaseTrigger):
+    """
+    SageMakerTrainingPrintLogTrigger is fired as deferred class with params to run the task in triggerer.
+
+    :param job_name: name of the job to check status
+    :param poke_interval: polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for sagemaker
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        poke_interval: float,
+        aws_conn_id: str = "aws_default",
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.poke_interval = poke_interval
+        self.aws_conn_id = aws_conn_id
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize SageMakerTrainingPrintLogTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger",
+            {
+                "poke_interval": self.poke_interval,
+                "aws_conn_id": self.aws_conn_id,
+                "job_name": self.job_name,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> SageMakerHook:
+        return SageMakerHook(aws_conn_id=self.aws_conn_id)
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Make async connection to sagemaker async hook and gets job status for a job submitted by the operator."""
+        stream_names: list[str] = []  # The list of log streams
+        positions: dict[str, Any] = {}  # The current position in each stream, map of stream name -> position
+
+        last_description = await self.hook.describe_training_job_async(self.job_name)
+        instance_count = last_description["ResourceConfig"]["InstanceCount"]
+        status = last_description["TrainingJobStatus"]
+        job_already_completed = status not in self.hook.non_terminal_states
+        state = LogState.COMPLETE if job_already_completed else LogState.TAILING
+        last_describe_job_call = time.time()
+        while True:
+            try:
+                (
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                ) = await self.hook.describe_training_job_with_log_async(
+                    self.job_name,
+                    positions,
+                    stream_names,
+                    instance_count,
+                    state,
+                    last_description,
+                    last_describe_job_call,
+                )
+                status = last_description["TrainingJobStatus"]
+                if status in self.hook.non_terminal_states:
+                    await asyncio.sleep(self.poke_interval)
+                elif status in self.hook.failed_states:
+                    reason = last_description.get("FailureReason", "(No reason provided)")
+                    error_message = f"SageMaker job failed because {reason}"
+                    yield TriggerEvent({"status": "error", "message": error_message})
+                else:
+                    billable_seconds = SageMakerHook.count_billable_seconds(
+                        training_start_time=last_description["TrainingStartTime"],
+                        training_end_time=last_description["TrainingEndTime"],
+                        instance_count=instance_count,
+                    )
+                    self.log.info("Billable seconds: %d", billable_seconds)
+                    yield TriggerEvent({"status": "success", "message": last_description})
+            except Exception as e:
+                yield TriggerEvent({"status": "error", "message": str(e)})
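The new trigger is driven entirely by the fields shown in its serialize() method, which is the contract the triggerer uses to re-create it. A quick sketch of that round trip, assuming provider 8.18.0 is installed (the job name and connection ID are placeholders, and no AWS call is made):

from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrainingPrintLogTrigger

trigger = SageMakerTrainingPrintLogTrigger(
    job_name="my-training-job",  # placeholder job name
    poke_interval=30.0,
    aws_conn_id="aws_default",
)

classpath, kwargs = trigger.serialize()
print(classpath)  # airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrainingPrintLogTrigger
print(kwargs)     # {'poke_interval': 30.0, 'aws_conn_id': 'aws_default', 'job_name': 'my-training-job'}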
airflow/providers/amazon/aws/utils/connection_wrapper.py

@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any

 from botocore import UNSIGNED
 from botocore.config import Config
+from deprecated import deprecated

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.utils import trim_none_values

@@ -165,7 +166,7 @@ class AwsConnectionWrapper(LoggingMixin):

         return service_config.get("endpoint_url", global_endpoint_url)

-    def __post_init__(self, conn: Connection):
+    def __post_init__(self, conn: Connection | AwsConnectionWrapper | _ConnectionMetadata | None) -> None:
         if isinstance(conn, type(self)):
             # For every field with init=False we copy reference value from original wrapper
             # For every field with init=True we use init values if it not equal default

@@ -192,6 +193,9 @@ class AwsConnectionWrapper(LoggingMixin):
         elif not conn:
             return

+        if TYPE_CHECKING:
+            assert isinstance(conn, (Connection, _ConnectionMetadata))
+
         # Assign attributes from AWS Connection
         self.conn_id = conn.conn_id
         self.conn_type = conn.conn_type or "aws"

@@ -462,6 +466,13 @@
         return role_arn, assume_role_method, assume_role_kwargs


+@deprecated(
+    reason=(
+        "Use local credentials file is never documented and well tested. "
+        "Obtain credentials by this way deprecated and will be removed in a future releases."
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 def _parse_s3_config(
     config_file_name: str, config_format: str | None = "boto", profile: str | None = None
 ) -> tuple[str | None, str | None]:

@@ -474,13 +485,6 @@ def _parse_s3_config(
         Defaults to "boto"
     :param profile: profile name in AWS type config file
     """
-    warnings.warn(
-        "Use local credentials file is never documented and well tested. "
-        "Obtain credentials by this way deprecated and will be removed in a future releases.",
-        AirflowProviderDeprecationWarning,
-        stacklevel=4,
-    )
-
     import configparser

     config = configparser.ConfigParser()
airflow/providers/amazon/aws/utils/mixins.py

@@ -31,6 +31,7 @@ import warnings
 from functools import cached_property
 from typing import Any, Generic, NamedTuple, TypeVar

+from deprecated import deprecated
 from typing_extensions import final

 from airflow.compat.functools import cache

@@ -160,9 +161,12 @@ class AwsBaseHookMixin(Generic[AwsHookType]):

     @property
     @final
+    @deprecated(
+        reason="`region` is deprecated and will be removed in the future. Please use `region_name` instead.",
+        category=AirflowProviderDeprecationWarning,
+    )
     def region(self) -> str | None:
         """Alias for ``region_name``, used for compatibility (deprecated)."""
-        warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
         return self.region_name

airflow/providers/amazon/aws/waiters/neptune.json (new file)

@@ -0,0 +1,85 @@
+{
+    "version": 2,
+    "waiters": {
+        "cluster_available": {
+            "operation": "DescribeDBClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "available",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "migration-failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "stopped",
+                    "state": "retry"
+                }
+            ]
+        },
+        "cluster_stopped": {
+            "operation": "DescribeDBClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "DBClusters[0].Status",
+                    "expected": "migration-failed",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
airflow/providers/amazon/get_provider_info.py

@@ -28,8 +28,9 @@ def get_provider_info():
         "name": "Amazon",
         "description": "Amazon integration (including `Amazon Web Services (AWS) <https://aws.amazon.com/>`__).\n",
         "state": "ready",
-        "source-date-epoch":
+        "source-date-epoch": 1707636119,
         "versions": [
+            "8.18.0",
             "8.17.0",
             "8.16.0",
             "8.15.0",

@@ -107,7 +108,7 @@ def get_provider_info():
         "devel-dependencies": [
             "aiobotocore>=2.7.0",
             "aws_xray_sdk>=2.12.0",
-            "moto[cloudformation,glue]>=
+            "moto[cloudformation,glue]>=5.0.0",
             "mypy-boto3-appflow>=1.33.0",
             "mypy-boto3-rds>=1.33.0",
             "mypy-boto3-redshift-data>=1.33.0",

@@ -383,6 +384,13 @@ def get_provider_info():
                 "logo": "/integration-logos/aws/Amazon-Verified-Permissions.png",
                 "tags": ["aws"],
             },
+            {
+                "integration-name": "Amazon Neptune",
+                "external-doc-url": "https://aws.amazon.com/neptune/",
+                "logo": "/integration-logos/aws/Amazon-Neptune_64.png",
+                "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/neptune.rst"],
+                "tags": ["aws"],
+            },
         ],
         "operators": [
             {

@@ -491,6 +499,10 @@ def get_provider_info():
                 "integration-name": "AWS Glue DataBrew",
                 "python-modules": ["airflow.providers.amazon.aws.operators.glue_databrew"],
             },
+            {
+                "integration-name": "Amazon Neptune",
+                "python-modules": ["airflow.providers.amazon.aws.operators.neptune"],
+            },
         ],
         "sensors": [
             {

@@ -730,6 +742,10 @@ def get_provider_info():
                 "integration-name": "Amazon Verified Permissions",
                 "python-modules": ["airflow.providers.amazon.aws.hooks.verified_permissions"],
             },
+            {
+                "integration-name": "Amazon Neptune",
+                "python-modules": ["airflow.providers.amazon.aws.hooks.neptune"],
+            },
         ],
         "triggers": [
             {

@@ -802,6 +818,10 @@ def get_provider_info():
                 "integration-name": "AWS Glue DataBrew",
                 "python-modules": ["airflow.providers.amazon.aws.triggers.glue_databrew"],
             },
+            {
+                "integration-name": "Amazon Neptune",
+                "python-modules": ["airflow.providers.amazon.aws.triggers.neptune"],
+            },
         ],
         "transfers": [
             {