apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
  3. airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
  4. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
  5. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
  7. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  8. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
  9. airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
  10. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
  11. airflow/providers/amazon/aws/hooks/athena.py +15 -2
  12. airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
  13. airflow/providers/amazon/aws/hooks/emr.py +6 -0
  14. airflow/providers/amazon/aws/hooks/logs.py +85 -1
  15. airflow/providers/amazon/aws/hooks/neptune.py +85 -0
  16. airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
  17. airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
  18. airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
  19. airflow/providers/amazon/aws/hooks/s3.py +4 -6
  20. airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
  21. airflow/providers/amazon/aws/links/emr.py +122 -2
  22. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  23. airflow/providers/amazon/aws/operators/athena.py +4 -1
  24. airflow/providers/amazon/aws/operators/batch.py +5 -6
  25. airflow/providers/amazon/aws/operators/ecs.py +6 -2
  26. airflow/providers/amazon/aws/operators/eks.py +31 -26
  27. airflow/providers/amazon/aws/operators/emr.py +192 -26
  28. airflow/providers/amazon/aws/operators/glue.py +5 -2
  29. airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
  30. airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
  31. airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
  32. airflow/providers/amazon/aws/operators/neptune.py +218 -0
  33. airflow/providers/amazon/aws/operators/rds.py +21 -12
  34. airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
  35. airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
  36. airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
  37. airflow/providers/amazon/aws/operators/step_function.py +4 -1
  38. airflow/providers/amazon/aws/sensors/batch.py +2 -2
  39. airflow/providers/amazon/aws/sensors/ec2.py +4 -2
  40. airflow/providers/amazon/aws/sensors/emr.py +13 -6
  41. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
  42. airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
  43. airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
  44. airflow/providers/amazon/aws/sensors/s3.py +3 -0
  45. airflow/providers/amazon/aws/sensors/sqs.py +4 -1
  46. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  47. airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
  48. airflow/providers/amazon/aws/triggers/neptune.py +115 -0
  49. airflow/providers/amazon/aws/triggers/rds.py +9 -7
  50. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
  51. airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
  52. airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
  53. airflow/providers/amazon/aws/utils/__init__.py +10 -0
  54. airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
  55. airflow/providers/amazon/aws/utils/mixins.py +5 -1
  56. airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
  57. airflow/providers/amazon/aws/waiters/neptune.json +85 -0
  58. airflow/providers/amazon/get_provider_info.py +26 -2
  59. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
  60. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
  61. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
  62. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
  # under the License.
  from __future__ import annotations

+ import asyncio
  import warnings
- from typing import Generator
+ from typing import Any, AsyncGenerator, Generator
+
+ from botocore.exceptions import ClientError

  from airflow.exceptions import AirflowProviderDeprecationWarning
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ class AwsLogsHook(AwsBaseHook):
                  num_consecutive_empty_response = 0

              continuation_token.value = response["nextForwardToken"]
+
+     async def describe_log_streams_async(
+         self, log_group: str, stream_prefix: str, order_by: str, count: int
+     ) -> dict[str, Any] | None:
+         """Async function to get the list of log streams for the specified log group.
+
+         You can list all the log streams or filter the results by prefix. You can also control
+         how the results are ordered.
+
+         :param log_group: The name of the log group.
+         :param stream_prefix: The prefix to match.
+         :param order_by: If the value is LogStreamName , the results are ordered by log stream name.
+             If the value is LastEventTime , the results are ordered by the event time. The default value is LogStreamName.
+         :param count: The maximum number of items returned
+         """
+         async with self.async_conn as client:
+             try:
+                 response: dict[str, Any] = await client.describe_log_streams(
+                     logGroupName=log_group,
+                     logStreamNamePrefix=stream_prefix,
+                     orderBy=order_by,
+                     limit=count,
+                 )
+                 return response
+             except ClientError as error:
+                 # On the very first training job run on an account, there's no log group until
+                 # the container starts logging, so ignore any errors thrown about that
+                 if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                     return None
+                 raise error
+
+     async def get_log_events_async(
+         self,
+         log_group: str,
+         log_stream_name: str,
+         start_time: int = 0,
+         skip: int = 0,
+         start_from_head: bool = True,
+     ) -> AsyncGenerator[Any, dict[str, Any]]:
+         """Yield all the available items in a single log stream.
+
+         :param log_group: The name of the log group.
+         :param log_stream_name: The name of the specific stream.
+         :param start_time: The time stamp value to start reading the logs from (default: 0).
+         :param skip: The number of log entries to skip at the start (default: 0).
+             This is for when there are multiple entries at the same timestamp.
+         :param start_from_head: whether to start from the beginning (True) of the log or
+             at the end of the log (False).
+         """
+         next_token = None
+         while True:
+             if next_token is not None:
+                 token_arg: dict[str, str] = {"nextToken": next_token}
+             else:
+                 token_arg = {}
+
+             async with self.async_conn as client:
+                 response = await client.get_log_events(
+                     logGroupName=log_group,
+                     logStreamName=log_stream_name,
+                     startTime=start_time,
+                     startFromHead=start_from_head,
+                     **token_arg,
+                 )
+
+                 events = response["events"]
+                 event_count = len(events)
+
+                 if event_count > skip:
+                     events = events[skip:]
+                     skip = 0
+                 else:
+                     skip -= event_count
+                     events = []
+
+                 for event in events:
+                     await asyncio.sleep(1)
+                     yield event
+
+                 if next_token != response["nextForwardToken"]:
+                     next_token = response["nextForwardToken"]
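
A minimal sketch of how the new async log helpers might be consumed; the connection ID, region, log group, and stream prefix below are placeholders, and note that `get_log_events_async` tails the stream indefinitely, sleeping one second per yielded event:

    import asyncio

    from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook


    async def tail_training_logs() -> None:
        hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")
        # Returns None when the log group does not exist yet (ResourceNotFoundException is swallowed).
        streams = await hook.describe_log_streams_async(
            log_group="/aws/sagemaker/TrainingJobs",
            stream_prefix="my-training-job/",
            order_by="LogStreamName",
            count=1,
        )
        if not streams or not streams["logStreams"]:
            return
        async for event in hook.get_log_events_async(
            log_group="/aws/sagemaker/TrainingJobs",
            log_stream_name=streams["logStreams"][0]["logStreamName"],
        ):
            print(event["timestamp"], event["message"])


    asyncio.run(tail_training_logs())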
airflow/providers/amazon/aws/hooks/neptune.py (new file)
@@ -0,0 +1,85 @@
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ from __future__ import annotations
+
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+ class NeptuneHook(AwsBaseHook):
+     """
+     Interact with Amazon Neptune.
+
+     Additional arguments (such as ``aws_conn_id``) may be specified and
+     are passed down to the underlying AwsBaseHook.
+
+     .. seealso::
+         - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+     """
+
+     AVAILABLE_STATES = ["available"]
+     STOPPED_STATES = ["stopped"]
+
+     def __init__(self, *args, **kwargs):
+         kwargs["client_type"] = "neptune"
+         super().__init__(*args, **kwargs)
+
+     def wait_for_cluster_availability(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+         """
+         Wait for Neptune cluster to start.
+
+         :param cluster_id: The ID of the cluster to wait for.
+         :param delay: Time in seconds to delay between polls.
+         :param max_attempts: Maximum number of attempts to poll for completion.
+         :return: The status of the cluster.
+         """
+         self.get_waiter("cluster_available").wait(
+             DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+         )
+
+         status = self.get_cluster_status(cluster_id)
+         self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+         return status
+
+     def wait_for_cluster_stopped(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+         """
+         Wait for Neptune cluster to stop.
+
+         :param cluster_id: The ID of the cluster to wait for.
+         :param delay: Time in seconds to delay between polls.
+         :param max_attempts: Maximum number of attempts to poll for completion.
+         :return: The status of the cluster.
+         """
+         self.get_waiter("cluster_stopped").wait(
+             DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+         )
+
+         status = self.get_cluster_status(cluster_id)
+         self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+         return status
+
+     def get_cluster_status(self, cluster_id: str) -> str:
+         """
+         Get the status of a Neptune cluster.
+
+         :param cluster_id: The ID of the cluster to get the status of.
+         :return: The status of the cluster.
+         """
+         return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
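
A hedged usage sketch of the new hook. The cluster identifier is a placeholder, and `start_db_cluster` is the plain boto3 Neptune call made through the hook's client rather than anything added by this release:

    from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook

    hook = NeptuneHook(aws_conn_id="aws_default")
    cluster_id = "my-neptune-cluster"  # placeholder

    if hook.get_cluster_status(cluster_id) in NeptuneHook.STOPPED_STATES:
        # Start the cluster via the underlying boto3 client, then block on the new custom waiter.
        hook.get_conn().start_db_cluster(DBClusterIdentifier=cluster_id)
        hook.wait_for_cluster_availability(cluster_id, delay=30, max_attempts=60)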
airflow/providers/amazon/aws/hooks/quicksight.py
@@ -18,10 +18,10 @@
  from __future__ import annotations

  import time
- import warnings
  from functools import cached_property

  from botocore.exceptions import ClientError
+ from deprecated import deprecated

  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -172,14 +172,15 @@ class QuickSightHook(AwsBaseHook):
          return status

      @cached_property
-     def sts_hook(self):
-         warnings.warn(
-             f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+     @deprecated(
+         reason=(
+             "`QuickSightHook.sts_hook` property is deprecated and will be removed in the future. "
              "This property used for obtain AWS Account ID, "
-             f"please consider to use `{type(self).__name__}.account_id` instead",
-             AirflowProviderDeprecationWarning,
-             stacklevel=2,
-         )
+             "please consider to use `QuickSightHook.account_id` instead"
+         ),
+         category=AirflowProviderDeprecationWarning,
+     )
+     def sts_hook(self):
          from airflow.providers.amazon.aws.hooks.sts import StsHook

          return StsHook(aws_conn_id=self.aws_conn_id)
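
Several hooks in this release swap manual `warnings.warn` calls for the `deprecated` decorator. A standalone sketch of the stacking order used above; the class and attribute names here are invented for illustration:

    from functools import cached_property

    from deprecated import deprecated

    from airflow.exceptions import AirflowProviderDeprecationWarning


    class Example:
        @cached_property
        @deprecated(reason="use `Example.new_value` instead", category=AirflowProviderDeprecationWarning)
        def old_value(self) -> int:
            return self.new_value

        @cached_property
        def new_value(self) -> int:
            return 42


    Example().old_value  # emits AirflowProviderDeprecationWarning on first access, then the value is cached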
airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -17,10 +17,10 @@
  from __future__ import annotations

  import asyncio
- import warnings
  from typing import Any, Sequence

  import botocore.exceptions
+ from deprecated import deprecated

  from airflow.exceptions import AirflowProviderDeprecationWarning
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook
@@ -195,16 +195,17 @@ class RedshiftHook(AwsBaseHook):
          return None


+ @deprecated(
+     reason=(
+         "`airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook` "
+         "has been deprecated and will be removed in future"
+     ),
+     category=AirflowProviderDeprecationWarning,
+ )
  class RedshiftAsyncHook(AwsBaseAsyncHook):
      """Interact with AWS Redshift using aiobotocore library."""

      def __init__(self, *args, **kwargs):
-         warnings.warn(
-             "airflow.providers.amazon.aws.hook.base_aws.RedshiftAsyncHook has been deprecated and "
-             "will be removed in future",
-             AirflowProviderDeprecationWarning,
-             stacklevel=2,
-         )
          kwargs["client_type"] = "redshift"
          super().__init__(*args, **kwargs)

airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -200,7 +200,7 @@ class RedshiftSQLHook(DbApiHook):
          return redshift_connector.connect(**conn_kwargs)

      def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
-         """Returns Redshift specific information for OpenLineage."""
+         """Return Redshift specific information for OpenLineage."""
          from airflow.providers.openlineage.sqlparser import DatabaseInfo

          authority = self._get_openlineage_redshift_authority_part(connection)
@@ -252,9 +252,9 @@ class RedshiftSQLHook(DbApiHook):
          return hostname

      def get_openlineage_database_dialect(self, connection: Connection) -> str:
-         """Returns redshift dialect."""
+         """Return redshift dialect."""
          return "redshift"

      def get_openlineage_default_schema(self) -> str | None:
-         """Returns current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
+         """Return current schema. This is usually changed with ``SEARCH_PATH`` parameter."""
          return self.get_first("SELECT CURRENT_SCHEMA();")[0]
airflow/providers/amazon/aws/hooks/s3.py
@@ -1369,6 +1369,10 @@ class S3Hook(AwsBaseHook):
          """
          Download a file from the S3 location to the local file system.

+         Note:
+             This function shadows the 'download_file' method of S3 API, but it is not the same.
+             If you want to use the original method from S3 API, please use 'S3Hook.get_conn().download_file()'
+
          .. seealso::
              - :external+boto3:py:meth:`S3.Object.download_fileobj`

@@ -1386,12 +1390,6 @@ class S3Hook(AwsBaseHook):
              Default: True.
          :return: the file name.
          """
-         self.log.info(
-             "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
-             "want to use the original method from S3 API, please call "
-             "'S3Hook.get_conn().download_file()'"
-         )
-
          self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)

          try:
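
For context on the note that moved from a runtime log message into the docstring, a sketch of the two call styles it distinguishes; the bucket and key names are placeholders:

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")

    # Airflow helper: writes to a local file (a temp file unless local_path is given) and returns its name.
    local_file = hook.download_file(key="data/report.csv", bucket_name="my-bucket")

    # Raw boto3 S3 API of the same name: you supply the exact destination yourself.
    hook.get_conn().download_file(Bucket="my-bucket", Key="data/report.csv", Filename="/tmp/report.csv")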
airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -26,8 +26,9 @@ import warnings
  from collections import Counter, namedtuple
  from datetime import datetime
  from functools import partial
- from typing import Any, Callable, Generator, cast
+ from typing import Any, AsyncGenerator, Callable, Generator, cast

+ from asgiref.sync import sync_to_async
  from botocore.exceptions import ClientError

  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
@@ -310,10 +311,12 @@ class SageMakerHook(AwsBaseHook):
              max_ingestion_time,
          )

-         billable_time = (
-             describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"]
-         ) * describe_response["ResourceConfig"]["InstanceCount"]
-         self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+         billable_seconds = SageMakerHook.count_billable_seconds(
+             training_start_time=describe_response["TrainingStartTime"],
+             training_end_time=describe_response["TrainingEndTime"],
+             instance_count=describe_response["ResourceConfig"]["InstanceCount"],
+         )
+         self.log.info("Billable seconds: %d", billable_seconds)

          return response

@@ -811,10 +814,12 @@ class SageMakerHook(AwsBaseHook):
          if status in failed_states:
              reason = last_description.get("FailureReason", "(No reason provided)")
              raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
-         billable_time = (
-             last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
-         ) * instance_count
-         self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
+         billable_seconds = SageMakerHook.count_billable_seconds(
+             training_start_time=last_description["TrainingStartTime"],
+             training_end_time=last_description["TrainingEndTime"],
+             instance_count=instance_count,
+         )
+         self.log.info("Billable seconds: %d", billable_seconds)

      def list_training_jobs(
          self, name_contains: str | None = None, max_results: int | None = None, **kwargs
@@ -1300,3 +1305,125 @@ class SageMakerHook(AwsBaseHook):
          if "BestCandidate" in res:
              return res["BestCandidate"]
          return None
+
+     @staticmethod
+     def count_billable_seconds(
+         training_start_time: datetime, training_end_time: datetime, instance_count: int
+     ) -> int:
+         billable_time = (training_end_time - training_start_time) * instance_count
+         return int(billable_time.total_seconds()) + 1
+
+     async def describe_training_job_async(self, job_name: str) -> dict[str, Any]:
+         """
+         Return the training job info associated with the name.
+
+         :param job_name: the name of the training job
+         """
+         async with self.async_conn as client:
+             response: dict[str, Any] = await client.describe_training_job(TrainingJobName=job_name)
+             return response
+
+     async def describe_training_job_with_log_async(
+         self,
+         job_name: str,
+         positions: dict[str, Any],
+         stream_names: list[str],
+         instance_count: int,
+         state: int,
+         last_description: dict[str, Any],
+         last_describe_job_call: float,
+     ) -> tuple[int, dict[str, Any], float]:
+         """
+         Return the training job info associated with job_name and print CloudWatch logs.
+
+         :param job_name: name of the job to check status
+         :param positions: A list of pairs of (timestamp, skip) which represents the last record
+             read from each stream.
+         :param stream_names: A list of the log stream names. The position of the stream in this list is
+             the stream number.
+         :param instance_count: Count of the instance created for the job initially
+         :param state: log state
+         :param last_description: Latest description of the training job
+         :param last_describe_job_call: previous job called time
+         """
+         log_group = "/aws/sagemaker/TrainingJobs"
+
+         if len(stream_names) < instance_count:
+             logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+             streams = await logs_hook.describe_log_streams_async(
+                 log_group=log_group,
+                 stream_prefix=job_name + "/",
+                 order_by="LogStreamName",
+                 count=instance_count,
+             )
+
+             stream_names = [s["logStreamName"] for s in streams["logStreams"]] if streams else []
+             positions.update([(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions])
+
+         if len(stream_names) > 0:
+             async for idx, event in self.get_multi_stream(log_group, stream_names, positions):
+                 self.log.info(event["message"])
+                 ts, count = positions[stream_names[idx]]
+                 if event["timestamp"] == ts:
+                     positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
+                 else:
+                     positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
+
+         if state == LogState.COMPLETE:
+             return state, last_description, last_describe_job_call
+
+         if state == LogState.JOB_COMPLETE:
+             state = LogState.COMPLETE
+         elif time.time() - last_describe_job_call >= 30:
+             description = await self.describe_training_job_async(job_name)
+             last_describe_job_call = time.time()
+
+             if await sync_to_async(secondary_training_status_changed)(description, last_description):
+                 self.log.info(
+                     await sync_to_async(secondary_training_status_message)(description, last_description)
+                 )
+                 last_description = description
+
+             status = description["TrainingJobStatus"]
+
+             if status not in self.non_terminal_states:
+                 state = LogState.JOB_COMPLETE
+         return state, last_description, last_describe_job_call
+
+     async def get_multi_stream(
+         self, log_group: str, streams: list[str], positions: dict[str, Any]
+     ) -> AsyncGenerator[Any, tuple[int, Any | None]]:
+         """Iterate over the available events coming and interleaving the events from each stream so they're yielded in timestamp order.
+
+         :param log_group: The name of the log group.
+         :param streams: A list of the log stream names. The position of the stream in this list is
+             the stream number.
+         :param positions: A list of pairs of (timestamp, skip) which represents the last record
+             read from each stream.
+         """
+         positions = positions or {s: Position(timestamp=0, skip=0) for s in streams}
+         events: list[Any | None] = []
+
+         logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+         event_iters = [
+             logs_hook.get_log_events_async(log_group, s, positions[s].timestamp, positions[s].skip)
+             for s in streams
+         ]
+         for event_stream in event_iters:
+             if not event_stream:
+                 events.append(None)
+                 continue
+
+             try:
+                 events.append(await event_stream.__anext__())
+             except StopAsyncIteration:
+                 events.append(None)
+
+         while any(events):
+             i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
+             yield i, events[i]
+
+             try:
+                 events[i] = await event_iters[i].__anext__()
+             except StopAsyncIteration:
+                 events[i] = None
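
A quick check of the arithmetic now centralized in `count_billable_seconds`; the timestamps and instance count below are made up:

    from datetime import datetime, timedelta, timezone

    from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

    start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
    end = start + timedelta(minutes=10)

    # 600 elapsed seconds * 2 instances, plus the one-second round-up -> 1201
    assert SageMakerHook.count_billable_seconds(
        training_start_time=start, training_end_time=end, instance_count=2
    ) == 1201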
airflow/providers/amazon/aws/links/emr.py
@@ -17,8 +17,10 @@
  from __future__ import annotations

  from typing import TYPE_CHECKING, Any
+ from urllib.parse import ParseResult, quote_plus, urlparse

  from airflow.exceptions import AirflowException
+ from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
  from airflow.providers.amazon.aws.hooks.s3 import S3Hook
  from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
  from airflow.utils.helpers import exactly_one
@@ -28,7 +30,7 @@ if TYPE_CHECKING:


  class EmrClusterLink(BaseAwsLink):
-     """Helper class for constructing AWS EMR Cluster Link."""
+     """Helper class for constructing Amazon EMR Cluster Link."""

      name = "EMR Cluster"
      key = "emr_cluster"
@@ -36,7 +38,7 @@ class EmrClusterLink(BaseAwsLink):


  class EmrLogsLink(BaseAwsLink):
-     """Helper class for constructing AWS EMR Logs Link."""
+     """Helper class for constructing Amazon EMR Logs Link."""

      name = "EMR Cluster Logs"
      key = "emr_logs"
@@ -48,6 +50,49 @@ class EmrLogsLink(BaseAwsLink):
          return super().format_link(**kwargs)


+ def get_serverless_log_uri(*, s3_log_uri: str, application_id: str, job_run_id: str) -> str:
+     """
+     Retrieve the S3 URI to EMR Serverless Job logs.
+
+     Any EMR Serverless job may have a different S3 logging location (or none), which is an S3 URI.
+     The logging location is then {s3_uri}/applications/{application_id}/jobs/{job_run_id}.
+     """
+     return f"{s3_log_uri}/applications/{application_id}/jobs/{job_run_id}"
+
+
+ def get_serverless_dashboard_url(
+     *,
+     aws_conn_id: str | None = None,
+     emr_serverless_client: boto3.client = None,
+     application_id: str,
+     job_run_id: str,
+ ) -> ParseResult | None:
+     """
+     Retrieve the URL to EMR Serverless dashboard.
+
+     The URL is a one-use, ephemeral link that expires in 1 hour and is accessible without authentication.
+
+     Either an AWS connection ID or existing EMR Serverless client must be passed.
+     If the connection ID is passed, a client is generated using that connection.
+     """
+     if not exactly_one(aws_conn_id, emr_serverless_client):
+         raise AirflowException("Requires either an AWS connection ID or an EMR Serverless Client.")
+
+     if aws_conn_id:
+         # If get_dashboard_for_job_run fails for whatever reason, fail after 1 attempt
+         # so that the rest of the links load in a reasonable time frame.
+         hook = EmrServerlessHook(aws_conn_id=aws_conn_id, config={"retries": {"total_max_attempts": 1}})
+         emr_serverless_client = hook.conn
+
+     response = emr_serverless_client.get_dashboard_for_job_run(
+         applicationId=application_id, jobRunId=job_run_id
+     )
+     if "url" not in response:
+         return None
+     log_uri = urlparse(response["url"])
+     return log_uri
+
+
  def get_log_uri(
      *, cluster: dict[str, Any] | None = None, emr_client: boto3.client = None, job_flow_id: str | None = None
  ) -> str | None:
@@ -66,3 +111,78 @@ def get_log_uri(
          return None
      log_uri = S3Hook.parse_s3_url(cluster_info["LogUri"])
      return "/".join(log_uri)
+
+
+ class EmrServerlessLogsLink(BaseAwsLink):
+     """Helper class for constructing Amazon EMR Serverless link to Spark stdout logs."""
+
+     name = "Spark Driver stdout"
+     key = "emr_serverless_logs"
+
+     def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+         if not application_id or not job_run_id:
+             return ""
+         url = get_serverless_dashboard_url(
+             aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
+         )
+         if url:
+             return url._replace(path="/logs/SPARK_DRIVER/stdout.gz").geturl()
+         else:
+             return ""
+
+
+ class EmrServerlessDashboardLink(BaseAwsLink):
+     """Helper class for constructing Amazon EMR Serverless Dashboard Link."""
+
+     name = "EMR Serverless Dashboard"
+     key = "emr_serverless_dashboard"
+
+     def format_link(self, application_id: str | None = None, job_run_id: str | None = None, **kwargs) -> str:
+         if not application_id or not job_run_id:
+             return ""
+         url = get_serverless_dashboard_url(
+             aws_conn_id=kwargs.get("conn_id"), application_id=application_id, job_run_id=job_run_id
+         )
+         if url:
+             return url.geturl()
+         else:
+             return ""
+
+
+ class EmrServerlessS3LogsLink(BaseAwsLink):
+     """Helper class for constructing link to S3 console for Amazon EMR Serverless Logs."""
+
+     name = "S3 Logs"
+     key = "emr_serverless_s3_logs"
+     format_str = BASE_AWS_CONSOLE_LINK + (
+         "/s3/buckets/{bucket_name}?region={region_name}"
+         "&prefix={prefix}/applications/{application_id}/jobs/{job_run_id}/"
+     )
+
+     def format_link(self, **kwargs) -> str:
+         bucket, prefix = S3Hook.parse_s3_url(kwargs["log_uri"])
+         kwargs["bucket_name"] = bucket
+         kwargs["prefix"] = prefix.rstrip("/")
+         return super().format_link(**kwargs)
+
+
+ class EmrServerlessCloudWatchLogsLink(BaseAwsLink):
+     """
+     Helper class for constructing link to CloudWatch console for Amazon EMR Serverless Logs.
+
+     This is a deep link that filters on a specific job run.
+     """
+
+     name = "CloudWatch Logs"
+     key = "emr_serverless_cloudwatch_logs"
+     format_str = (
+         BASE_AWS_CONSOLE_LINK
+         + "/cloudwatch/home?region={region_name}#logsV2:log-groups/log-group/{awslogs_group}{stream_prefix}"
+     )
+
+     def format_link(self, **kwargs) -> str:
+         kwargs["awslogs_group"] = quote_plus(kwargs["awslogs_group"])
+         kwargs["stream_prefix"] = quote_plus("?logStreamNameFilter=").replace("%", "$") + quote_plus(
+             kwargs["stream_prefix"]
+         )
+         return super().format_link(**kwargs)
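
A small sketch of the plain string helper added above; the bucket, application ID, and job-run ID are placeholders:

    from airflow.providers.amazon.aws.links.emr import get_serverless_log_uri

    uri = get_serverless_log_uri(
        s3_log_uri="s3://my-logs-bucket/emr-serverless",
        application_id="00fexample1234",
        job_run_id="00frexample5678",
    )
    assert uri == "s3://my-logs-bucket/emr-serverless/applications/00fexample1234/jobs/00frexample5678"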
airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -17,7 +17,7 @@
  # under the License.
  from __future__ import annotations

- from datetime import date, datetime, timedelta
+ from datetime import date, datetime, timedelta, timezone
  from functools import cached_property
  from typing import TYPE_CHECKING, Any

@@ -163,7 +163,7 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
          return "\n".join(self._event_to_str(event) for event in events)

      def _event_to_str(self, event: dict) -> str:
-         event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
+         event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
          formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
          message = event["message"]
          return f"[{formatted_event_dt}] {message}"
airflow/providers/amazon/aws/operators/athena.py
@@ -26,6 +26,7 @@ from airflow.providers.amazon.aws.hooks.athena import AthenaHook
  from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
  from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
  from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
+ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
  from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

  if TYPE_CHECKING:
@@ -179,7 +180,9 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):

          return self.query_execution_id

-     def execute_complete(self, context, event=None):
+     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+         event = validate_execute_complete_event(event)
+
          if event["status"] != "success":
              raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
          return event["value"]