apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
- airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
- airflow/providers/amazon/aws/hooks/athena.py +15 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
- airflow/providers/amazon/aws/hooks/emr.py +6 -0
- airflow/providers/amazon/aws/hooks/logs.py +85 -1
- airflow/providers/amazon/aws/hooks/neptune.py +85 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
- airflow/providers/amazon/aws/hooks/s3.py +4 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
- airflow/providers/amazon/aws/links/emr.py +122 -2
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
- airflow/providers/amazon/aws/operators/athena.py +4 -1
- airflow/providers/amazon/aws/operators/batch.py +5 -6
- airflow/providers/amazon/aws/operators/ecs.py +6 -2
- airflow/providers/amazon/aws/operators/eks.py +31 -26
- airflow/providers/amazon/aws/operators/emr.py +192 -26
- airflow/providers/amazon/aws/operators/glue.py +5 -2
- airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
- airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
- airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
- airflow/providers/amazon/aws/operators/neptune.py +218 -0
- airflow/providers/amazon/aws/operators/rds.py +21 -12
- airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
- airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
- airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
- airflow/providers/amazon/aws/operators/step_function.py +4 -1
- airflow/providers/amazon/aws/sensors/batch.py +2 -2
- airflow/providers/amazon/aws/sensors/ec2.py +4 -2
- airflow/providers/amazon/aws/sensors/emr.py +13 -6
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
- airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
- airflow/providers/amazon/aws/sensors/s3.py +3 -0
- airflow/providers/amazon/aws/sensors/sqs.py +4 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
- airflow/providers/amazon/aws/triggers/neptune.py +115 -0
- airflow/providers/amazon/aws/triggers/rds.py +9 -7
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
- airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
- airflow/providers/amazon/aws/utils/__init__.py +10 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
- airflow/providers/amazon/aws/utils/mixins.py +5 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
- airflow/providers/amazon/aws/waiters/neptune.json +85 -0
- airflow/providers/amazon/get_provider_info.py +26 -2
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/logs.py

@@ -17,8 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ class AwsLogsHook(AwsBaseHook):
                     num_consecutive_empty_response = 0
 
             continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Async function to get the list of log streams for the specified log group.
+
+        You can list all the log streams or filter the results by prefix. You can also control
+        how the results are ordered.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: If the value is LogStreamName, the results are ordered by log stream name.
+            If the value is LastEventTime, the results are ordered by the event time. The default value is LogStreamName.
+        :param count: The maximum number of items returned
+        """
+        async with self.async_conn as client:
+            try:
+                response: dict[str, Any] = await client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=stream_prefix,
+                    orderBy=order_by,
+                    limit=count,
+                )
+                return response
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so ignore any errors thrown about that
+                if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                    return None
+                raise error
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[Any, dict[str, Any]]:
+        """Yield all the available items in a single log stream.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: whether to start from the beginning (True) of the log or
+            at the end of the log (False).
+        """
+        next_token = None
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+                events = response["events"]
+                event_count = len(events)
+
+                if event_count > skip:
+                    events = events[skip:]
+                    skip = 0
+                else:
+                    skip -= event_count
+                    events = []
+
+                for event in events:
+                    await asyncio.sleep(1)
+                    yield event
+
+                if next_token != response["nextForwardToken"]:
+                    next_token = response["nextForwardToken"]
airflow/providers/amazon/aws/hooks/neptune.py

@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class NeptuneHook(AwsBaseHook):
+    """
+    Interact with Amazon Neptune.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    AVAILABLE_STATES = ["available"]
+    STOPPED_STATES = ["stopped"]
+
+    def __init__(self, *args, **kwargs):
+        kwargs["client_type"] = "neptune"
+        super().__init__(*args, **kwargs)
+
+    def wait_for_cluster_availability(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+        """
+        Wait for Neptune cluster to start.
+
+        :param cluster_id: The ID of the cluster to wait for.
+        :param delay: Time in seconds to delay between polls.
+        :param max_attempts: Maximum number of attempts to poll for completion.
+        :return: The status of the cluster.
+        """
+        self.get_waiter("cluster_available").wait(
+            DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+        )
+
+        status = self.get_cluster_status(cluster_id)
+        self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+        return status
+
+    def wait_for_cluster_stopped(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+        """
+        Wait for Neptune cluster to stop.
+
+        :param cluster_id: The ID of the cluster to wait for.
+        :param delay: Time in seconds to delay between polls.
+        :param max_attempts: Maximum number of attempts to poll for completion.
+        :return: The status of the cluster.
+        """
+        self.get_waiter("cluster_stopped").wait(
+            DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+        )
+
+        status = self.get_cluster_status(cluster_id)
+        self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+        return status
+
+    def get_cluster_status(self, cluster_id: str) -> str:
+        """
+        Get the status of a Neptune cluster.
+
+        :param cluster_id: The ID of the cluster to get the status of.
+        :return: The status of the cluster.
+        """
+        return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
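
A rough usage sketch of the new hook. The connection ID and cluster identifier are placeholders, and the start/stop request itself would normally come from the new Neptune operators added in this release:

    from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook

    # Placeholder connection ID and cluster identifier.
    hook = NeptuneHook(aws_conn_id="aws_default")

    status = hook.get_cluster_status(cluster_id="my-neptune-cluster")
    print(f"Current status: {status}")

    # After a start request has been issued elsewhere, block until the custom
    # "cluster_available" waiter reports the cluster as available.
    if status not in NeptuneHook.AVAILABLE_STATES:
        status = hook.wait_for_cluster_availability(cluster_id="my-neptune-cluster", delay=30, max_attempts=60)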
airflow/providers/amazon/aws/hooks/quicksight.py

@@ -18,10 +18,10 @@
 from __future__ import annotations
 
 import time
-import warnings
 from functools import cached_property
 
 from botocore.exceptions import ClientError
+from deprecated import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -172,14 +172,15 @@ class QuickSightHook(AwsBaseHook):
         return status
 
     @cached_property
-
-
-
+    @deprecated(
+        reason=(
+            "`QuickSightHook.sts_hook` property is deprecated and will be removed in the future. "
            "This property used for obtain AWS Account ID, "
-
-
-
-
+            "please consider to use `QuickSightHook.account_id` instead"
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
+    def sts_hook(self):
         from airflow.providers.amazon.aws.hooks.sts import StsHook
 
         return StsHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/hooks/redshift_cluster.py

@@ -17,10 +17,10 @@
 from __future__ import annotations
 
 import asyncio
-import warnings
 from typing import Any, Sequence
 
 import botocore.exceptions
+from deprecated import deprecated
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
@@ -195,16 +195,17 @@ class RedshiftHook(AwsBaseHook):
         return None
 
 
+@deprecated(
+    reason=(
+        "`airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook` "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class RedshiftAsyncHook(AwsBaseAsyncHook):
     """Interact with AWS Redshift using aiobotocore library."""
 
     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         kwargs["client_type"] = "redshift"
         super().__init__(*args, **kwargs)
 
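
Both hooks above replace in-body `warnings.warn(...)` calls with the `deprecated` decorator from the Deprecated library. A minimal sketch of that pattern; the class name and message here are illustrative only, not code from the provider:

    from deprecated import deprecated

    from airflow.exceptions import AirflowProviderDeprecationWarning


    @deprecated(
        reason="`OldHook` has been deprecated and will be removed in future; use `NewHook` instead",
        category=AirflowProviderDeprecationWarning,
    )
    class OldHook:
        pass


    OldHook()  # instantiating the decorated class (or calling a decorated function) emits the warning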
airflow/providers/amazon/aws/hooks/redshift_sql.py

@@ -200,7 +200,7 @@ class RedshiftSQLHook(DbApiHook):
         return redshift_connector.connect(**conn_kwargs)
 
     def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
-        """
+        """Return Redshift specific information for OpenLineage."""
         from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
         authority = self._get_openlineage_redshift_authority_part(connection)
@@ -252,9 +252,9 @@
         return hostname
 
     def get_openlineage_database_dialect(self, connection: Connection) -> str:
-        """
+        """Return redshift dialect."""
         return "redshift"
 
     def get_openlineage_default_schema(self) -> str | None:
-        """
+        """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
         return self.get_first("SELECT CURRENT_SCHEMA();")[0]
airflow/providers/amazon/aws/hooks/s3.py

@@ -1369,6 +1369,10 @@ class S3Hook(AwsBaseHook):
         """
         Download a file from the S3 location to the local file system.
 
+        Note:
+            This function shadows the 'download_file' method of S3 API, but it is not the same.
+            If you want to use the original method from S3 API, please use 'S3Hook.get_conn().download_file()'
+
         .. seealso::
             - :external+boto3:py:meth:`S3.Object.download_fileobj`
 
@@ -1386,12 +1390,6 @@
             Default: True.
         :return: the file name.
         """
-        self.log.info(
-            "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
-            "want to use the original method from S3 API, please call "
-            "'S3Hook.get_conn().download_file()'"
-        )
-
         self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)
 
         try:
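
The shadowing note moves from a runtime log message into the docstring. The distinction it describes, sketched with placeholder bucket and key names:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")

    # Airflow helper: downloads the object to a local (by default auto-generated) path and returns that file name.
    local_file = hook.download_file(key="data/report.csv", bucket_name="my-bucket")

    # Underlying boto3 S3 client method, for callers who want the original S3 API behaviour.
    hook.get_conn().download_file(Bucket="my-bucket", Key="data/report.csv", Filename="/tmp/report.csv")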
airflow/providers/amazon/aws/hooks/sagemaker.py

@@ -26,8 +26,9 @@ import warnings
 from collections import Counter, namedtuple
 from datetime import datetime
 from functools import partial
-from typing import Any, Callable, Generator, cast
+from typing import Any, AsyncGenerator, Callable, Generator, cast
 
+from asgiref.sync import sync_to_async
 from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@
             max_ingestion_time,
         )
 
-
-        describe_response["
-
-
+        billable_seconds = SageMakerHook.count_billable_seconds(
+            training_start_time=describe_response["TrainingStartTime"],
+            training_end_time=describe_response["TrainingEndTime"],
+            instance_count=describe_response["ResourceConfig"]["InstanceCount"],
+        )
+        self.log.info("Billable seconds: %d", billable_seconds)
 
         return response
 
@@ -811,10 +814,12 @@
         if status in failed_states:
             reason = last_description.get("FailureReason", "(No reason provided)")
             raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
-
-        last_description["
-
-
+        billable_seconds = SageMakerHook.count_billable_seconds(
+            training_start_time=last_description["TrainingStartTime"],
+            training_end_time=last_description["TrainingEndTime"],
+            instance_count=instance_count,
+        )
+        self.log.info("Billable seconds: %d", billable_seconds)
 
     def list_training_jobs(
         self, name_contains: str | None = None, max_results: int | None = None, **kwargs
@@ -1300,3 +1305,125 @@
         if "BestCandidate" in res:
             return res["BestCandidate"]
         return None
+
+    @staticmethod
+    def count_billable_seconds(
+        training_start_time: datetime, training_end_time: datetime, instance_count: int
+    ) -> int:
+        billable_time = (training_end_time - training_start_time) * instance_count
+        return int(billable_time.total_seconds()) + 1
+
+    async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
+        """
+        Return the training job info associated with the name.
+
+        :param job_name: the name of the training job
+        """
+        async with self.async_conn as client:
+            response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
+            return response
+
+    async def describe_training_job_with_log_async(
+        self,
+        job_name: str,
+        positions: dict[str, Any],
+        stream_names: list[str],
+        instance_count: int,
+        state: int,
+        last_description: dict[str, Any],
+        last_describe_job_call: float,
+    ) -> tuple[int, dict[str, Any], float]:
+        """
+        Return the training job info associated with job_name and print CloudWatch logs.
+
+        :param job_name: name of the job to check status
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        :param stream_names: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param instance_count: Count of the instance created for the job initially
+        :param state: log state
+        :param last_description: Latest description of the training job
+        :param last_describe_job_call: previous job called time
+        """
+        log_group = "/aws/sagemaker/TrainingJobs"
+
+        if len(stream_names) < instance_count:
+            logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+            streams = await logs_hook.describe_log_streams_async(
+                log_group=log_group,
+                stream_prefix=job_name + "/",
+                order_by="LogStreamName",
+                count=instance_count,
+            )
+
+            stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+            positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
+
+        if len(stream_names) > 0:
+            async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
+                self.log.info(event["message"])
+                ts, count = positions[stream_names[idx]]
+                if event["timestamp"] == ts:
+                    positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
+                else:
+                    positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
+
+        if state == LogState.COMPLETE:
+            return state, last_description, last_describe_job_call
+
+        if state == LogState.JOB_COMPLETE:
+            state = LogState.COMPLETE
+        elif time.time() - last_describe_job_call >= 30:
+            description = await self.describe_training_job_async(job_name)
+            last_describe_job_call = time.time()
+
+            if await sync_to_async(secondary_training_status_changed)(description, last_description):
+                self.log.info(
+                    await sync_to_async(secondary_training_status_message)(description, last_description)
+                )
+            last_description = description
+
+            status = description["TrainingJobStatus"]
+
+            if status not in self.non_terminal_states:
+                state = LogState.JOB_COMPLETE
+        return state, last_description, last_describe_job_call
+
+    async def get_multi_stream(
+        self, log_group: str, streams: list[str], positions: dict[str, Any]
+    ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
+        """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
+
+        :param log_group: The name of the log group.
+        :param streams: A list of the log stream names. The position of the stream in this list is
+            the stream number.
+        :param positions: A list of pairs of (timestamp, skip) which represents the last record
+            read from each stream.
+        """
+        positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+        events: list[Any | None] = []
+
+        logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        event_iters = [
+            logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+            for s in streams
+        ]
+        for event_stream in event_iters:
+            if not event_stream:
+                events.append(None)
+                continue
+
+            try:
+                events.append(await event_stream.__anext__())
+            except StopAsyncIteration:
+                events.append(None)
+
+        while any(events):
+            i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
+            yield i, events[i]
+
+            try:
+                events[i] = await event_iters[i].__anext__()
+            except StopAsyncIteration:
+                events[i] = None
airflow/providers/amazon/aws/links/emr.py

@@ -17,8 +17,10 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING, Any
+from urllib.parse import ParseResult, quote_plus, urlparse
 
 from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
 from airflow.utils.helpers import exactly_one
@@ -28,7 +30,7 @@ if TYPE_CHECKING:
 
 
 class EmrClusterLink(BaseAwsLink):
-    """Helper class for constructing
+    """Helper class for constructing Amazon EMR Cluster Link."""
 
     name = "EMR Cluster"
     key = "emr_cluster"
@@ -36,7 +38,7 @@ class EmrClusterLink(BaseAwsLink):
 
 
 class EmrLogsLink(BaseAwsLink):
-    """Helper class for constructing
+    """Helper class for constructing Amazon EMR Logs Link."""
 
     name = "EMR Cluster Logs"
     key = "emr_logs"
@@ -48,6 +50,49 @@ class EmrLogsLink(BaseAwsLink):
         return super().format_link(**kwargs)
 
 
+def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id: str) -> str:
+    """
+    Retrieve the S3 URI to EMR Serverless Job logs.
+
+    Any EMR Serverless job may have a different S3 logging location (or none), which is an S3 URI.
+    The logging location is then {s3_uri}/applications/{application_id}/jobs/{job_run_id}.
+    """
+    return f"{s3_log_uri}/applications/{application_id}/jobs/{job_run_id}"
+
+
+def get_serverless_dashboard_url(
+    *,
+    aws_conn_id: str | None = None,
+    emr_serverless_client: boto3.client = None,
+    application_id: str,
+    job_run_id: str,
+) -> ParseResult | None:
+    """
+    Retrieve the URL to EMR Serverless dashboard.
+
+    The URL is a one-use, ephemeral link that expires in 1 hour and is accessible without authentication.
+
+    Either an AWS connection ID or existing EMR Serverless client must be passed.
+    If the connection ID is passed, a client is generated using that connection.
+    """
+    if not exactly_one(aws_conn_id, emr_serverless_client):
+        raise AirflowException("Requires either an AWS connection ID or an EMR Serverless Client.")
+
+    if aws_conn_id:
+        # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt
+        # so that the rest of the links load in a reasonable time frame.
+        hook = EmrServerlessHook(aws_conn_id=aws_conn_id, config={"retries": {"total_max_attempts": 1}})
+        emr_serverless_client = hook.conn
+
+    response = emr_serverless_client.get_dashboard_for_job_run(
+        applicationId=application_id, jobRunId=job_run_id
+    )
+    if "url" not in response:
+        return None
+    log_uri = urlparse(response["url"])
+    return log_uri
+
+
 def get_log_uri(
     *, cluster: dict[str, Any] | None = None, emr_client: boto3.client = None, job_flow_id: str | None = None
 ) -> str | None:
@@ -66,3 +111,78 @@ def get_log_uri(
         return None
     log_uri = S3Hook.parse_s3_url(cluster_info["LogUri"])
     return "/".join(log_uri)
+
+
+class EmrServerlessLogsLink(BaseAwsLink):
+    """Helper class for constructing Amazon EMR Serverless link to Spark stdout logs."""
+
+    name = "Spark Driver stdout"
+    key = "emr_serverless_logs"
+
+    def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+        if not application_id or not job_run_id:
+            return ""
+        url = get_serverless_dashboard_url(
+            aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
+        )
+        if url:
+            return url._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl()
+        else:
+            return ""
+
+
+class EmrServerlessDashboardLink(BaseAwsLink):
+    """Helper class for constructing Amazon EMR Serverless Dashboard Link."""
+
+    name = "EMR Serverless Dashboard"
+    key = "emr_serverless_dashboard"
+
+    def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+        if not application_id or not job_run_id:
+            return ""
+        url = get_serverless_dashboard_url(
+            aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
+        )
+        if url:
+            return url.geturl()
+        else:
+            return ""
+
+
+class EmrServerlessS3LogsLink(BaseAwsLink):
+    """Helper class for constructing link to S3 console for Amazon EMR Serverless Logs."""
+
+    name = "S3 Logs"
+    key = "emr_serverless_s3_logs"
+    format_str = BASE_AWS_CONSOLE_LINK + (
+        "/s3/buckets/{bucket_name}?region={region_name}"
+        "&prefix={prefix}/applications/{application_id}/jobs/{job_run_id}/"
+    )
+
+    def format_link(self, **kwargs) -> str:
+        bucket, prefix = S3Hook.parse_s3_url(kwargs["log_uri"])
+        kwargs["bucket_name"] = bucket
+        kwargs["prefix"] = prefix.rstrip("/")
+        return super().format_link(**kwargs)
+
+
+class EmrServerlessCloudWatchLogsLink(BaseAwsLink):
+    """
+    Helper class for constructing link to CloudWatch console for Amazon EMR Serverless Logs.
+
+    This is a deep link that filters on a specific job run.
+    """
+
+    name = "CloudWatch Logs"
+    key = "emr_serverless_cloudwatch_logs"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK
+        + "/cloudwatch/home?region={region_name}#logsV2:log-groups/log-group/{awslogs_group}{stream_prefix}"
+    )
+
+    def format_link(self, **kwargs) -> str:
+        kwargs["awslogs_group"] = quote_plus(kwargs["awslogs_group"])
+        kwargs["stream_prefix"] = quote_plus("?logStreamNameFilter=").replace("%", "$") + quote_plus(
+            kwargs["stream_prefix"]
+        )
+        return super().format_link(**kwargs)
airflow/providers/amazon/aws/log/cloudwatch_task_handler.py

@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-from datetime import date, datetime, timedelta
+from datetime import date, datetime, timedelta, timezone
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
@@ -163,7 +163,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
         return "\n".join(self._event_to_str(event) for event in events)
 
     def _event_to_str(self, event: dict) -> str:
-        event_dt = datetime.
+        event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
         formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
         message = event["message"]
         return f"[{formatted_event_dt}] {message}"
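
The handler change swaps the previous timestamp conversion for an explicitly UTC-aware one. The new expression in isolation, with an arbitrary epoch-millisecond value:

    from datetime import datetime, timezone

    event = {"timestamp": 1_700_000_000_000, "message": "example"}  # CloudWatch timestamps are epoch milliseconds
    event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
    print(event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3])  # 2023-11-14 22:13:20,000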
airflow/providers/amazon/aws/operators/athena.py

@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -179,7 +180,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
 
         return self.query_execution_id
 
-    def execute_complete(self, context, event=None):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+
         if event["status"] != "success":
             raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
         return event["value"]