apache-airflow-providers-amazon 7.4.1rc1__py3-none-any.whl → 8.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/aws/hooks/athena.py +0 -15
- airflow/providers/amazon/aws/hooks/base_aws.py +98 -65
- airflow/providers/amazon/aws/hooks/batch_client.py +60 -27
- airflow/providers/amazon/aws/hooks/batch_waiters.py +3 -1
- airflow/providers/amazon/aws/hooks/emr.py +33 -74
- airflow/providers/amazon/aws/hooks/logs.py +22 -4
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -12
- airflow/providers/amazon/aws/hooks/sagemaker.py +0 -16
- airflow/providers/amazon/aws/links/emr.py +1 -3
- airflow/providers/amazon/aws/operators/athena.py +0 -15
- airflow/providers/amazon/aws/operators/batch.py +78 -24
- airflow/providers/amazon/aws/operators/ecs.py +21 -58
- airflow/providers/amazon/aws/operators/eks.py +0 -1
- airflow/providers/amazon/aws/operators/emr.py +94 -24
- airflow/providers/amazon/aws/operators/lambda_function.py +0 -19
- airflow/providers/amazon/aws/operators/rds.py +1 -1
- airflow/providers/amazon/aws/operators/redshift_cluster.py +22 -1
- airflow/providers/amazon/aws/operators/redshift_data.py +0 -62
- airflow/providers/amazon/aws/secrets/secrets_manager.py +0 -17
- airflow/providers/amazon/aws/secrets/systems_manager.py +0 -21
- airflow/providers/amazon/aws/sensors/dynamodb.py +97 -0
- airflow/providers/amazon/aws/sensors/emr.py +1 -2
- airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +1 -1
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +0 -19
- airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -7
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +10 -10
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +0 -10
- airflow/providers/amazon/aws/transfers/mongo_to_s3.py +0 -11
- airflow/providers/amazon/aws/transfers/s3_to_sftp.py +0 -10
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +23 -9
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
- airflow/providers/amazon/aws/waiters/base_waiter.py +12 -1
- airflow/providers/amazon/aws/waiters/emr-serverless.json +18 -0
- airflow/providers/amazon/get_provider_info.py +35 -30
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/METADATA +81 -4
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/RECORD +41 -41
- airflow/providers/amazon/aws/operators/aws_lambda.py +0 -29
- airflow/providers/amazon/aws/operators/redshift_sql.py +0 -57
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/LICENSE +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/NOTICE +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/entry_points.txt +0 -0
- {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,6 @@ This module contains AWS Athena hook.
|
|
24
24
|
"""
|
25
25
|
from __future__ import annotations
|
26
26
|
|
27
|
-
import warnings
|
28
27
|
from time import sleep
|
29
28
|
from typing import Any
|
30
29
|
|
@@ -224,7 +223,6 @@ class AthenaHook(AwsBaseHook):
|
|
224
223
|
def poll_query_status(
|
225
224
|
self,
|
226
225
|
query_execution_id: str,
|
227
|
-
max_tries: int | None = None,
|
228
226
|
max_polling_attempts: int | None = None,
|
229
227
|
) -> str | None:
|
230
228
|
"""
|
@@ -232,21 +230,8 @@ class AthenaHook(AwsBaseHook):
|
|
232
230
|
Returns one of the final states
|
233
231
|
|
234
232
|
:param query_execution_id: Id of submitted athena query
|
235
|
-
:param max_tries: Deprecated - Use max_polling_attempts instead
|
236
233
|
:param max_polling_attempts: Number of times to poll for query state before function exits
|
237
234
|
"""
|
238
|
-
if max_tries:
|
239
|
-
warnings.warn(
|
240
|
-
f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
|
241
|
-
f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
|
242
|
-
DeprecationWarning,
|
243
|
-
stacklevel=2,
|
244
|
-
)
|
245
|
-
if max_polling_attempts and max_polling_attempts != max_tries:
|
246
|
-
raise Exception("max_polling_attempts must be the same value as max_tries")
|
247
|
-
else:
|
248
|
-
max_polling_attempts = max_tries
|
249
|
-
|
250
235
|
try_number = 1
|
251
236
|
final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
|
252
237
|
while True:
|
@@ -30,7 +30,6 @@ import json
|
|
30
30
|
import logging
|
31
31
|
import os
|
32
32
|
import uuid
|
33
|
-
import warnings
|
34
33
|
from copy import deepcopy
|
35
34
|
from functools import wraps
|
36
35
|
from os import PathLike
|
@@ -58,7 +57,6 @@ from airflow.exceptions import (
|
|
58
57
|
)
|
59
58
|
from airflow.hooks.base import BaseHook
|
60
59
|
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
|
61
|
-
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
|
62
60
|
from airflow.providers_manager import ProvidersManager
|
63
61
|
from airflow.utils.helpers import exactly_one
|
64
62
|
from airflow.utils.log.logging_mixin import LoggingMixin
|
@@ -72,12 +70,15 @@ if TYPE_CHECKING:
|
|
72
70
|
|
73
71
|
class BaseSessionFactory(LoggingMixin):
|
74
72
|
"""
|
75
|
-
Base AWS Session Factory class to handle
|
73
|
+
Base AWS Session Factory class to handle synchronous and async boto session creation.
|
76
74
|
It can handle most of the AWS supported authentication methods.
|
77
75
|
|
78
76
|
User can also derive from this class to have full control of boto3 session
|
79
77
|
creation or to support custom federation.
|
80
78
|
|
79
|
+
Note: Not all features implemented for synchronous sessions are available for async
|
80
|
+
sessions.
|
81
|
+
|
81
82
|
.. seealso::
|
82
83
|
- :ref:`howto/connection:aws:session-factory`
|
83
84
|
"""
|
@@ -127,17 +128,50 @@ class BaseSessionFactory(LoggingMixin):
|
|
127
128
|
"""Assume Role ARN from AWS Connection"""
|
128
129
|
return self.conn.role_arn
|
129
130
|
|
130
|
-
def
|
131
|
-
""
|
131
|
+
def _apply_session_kwargs(self, session):
|
132
|
+
if self.conn.session_kwargs.get("profile_name", None) is not None:
|
133
|
+
session.set_config_variable("profile", self.conn.session_kwargs["profile_name"])
|
134
|
+
|
135
|
+
if (
|
136
|
+
self.conn.session_kwargs.get("aws_access_key_id", None)
|
137
|
+
or self.conn.session_kwargs.get("aws_secret_access_key", None)
|
138
|
+
or self.conn.session_kwargs.get("aws_session_token", None)
|
139
|
+
):
|
140
|
+
session.set_credentials(
|
141
|
+
self.conn.session_kwargs["aws_access_key_id"],
|
142
|
+
self.conn.session_kwargs["aws_secret_access_key"],
|
143
|
+
self.conn.session_kwargs["aws_session_token"],
|
144
|
+
)
|
145
|
+
|
146
|
+
if self.conn.session_kwargs.get("region_name", None) is not None:
|
147
|
+
session.set_config_variable("region", self.conn.session_kwargs["region_name"])
|
148
|
+
|
149
|
+
def get_async_session(self):
|
150
|
+
from aiobotocore.session import get_session as async_get_session
|
151
|
+
|
152
|
+
return async_get_session()
|
153
|
+
|
154
|
+
def create_session(self, deferrable: bool = False) -> boto3.session.Session:
|
155
|
+
"""Create boto3 or aiobotocore Session from connection config."""
|
132
156
|
if not self.conn:
|
133
157
|
self.log.info(
|
134
158
|
"No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
|
135
159
|
"See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
|
136
160
|
self.region_name,
|
137
161
|
)
|
138
|
-
|
162
|
+
if deferrable:
|
163
|
+
session = self.get_async_session()
|
164
|
+
self._apply_session_kwargs(session)
|
165
|
+
return session
|
166
|
+
else:
|
167
|
+
return boto3.session.Session(region_name=self.region_name)
|
139
168
|
elif not self.role_arn:
|
140
|
-
|
169
|
+
if deferrable:
|
170
|
+
session = self.get_async_session()
|
171
|
+
self._apply_session_kwargs(session)
|
172
|
+
return session
|
173
|
+
else:
|
174
|
+
return self.basic_session
|
141
175
|
|
142
176
|
# Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
|
143
177
|
# to create the initial boto3 session.
|
@@ -150,12 +184,18 @@ class BaseSessionFactory(LoggingMixin):
|
|
150
184
|
assume_session_kwargs = {}
|
151
185
|
if self.conn.region_name:
|
152
186
|
assume_session_kwargs["region_name"] = self.conn.region_name
|
153
|
-
return self._create_session_with_assume_role(
|
187
|
+
return self._create_session_with_assume_role(
|
188
|
+
session_kwargs=assume_session_kwargs, deferrable=deferrable
|
189
|
+
)
|
154
190
|
|
155
191
|
def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
|
156
192
|
return boto3.session.Session(**session_kwargs)
|
157
193
|
|
158
|
-
def _create_session_with_assume_role(
|
194
|
+
def _create_session_with_assume_role(
|
195
|
+
self, session_kwargs: dict[str, Any], deferrable: bool = False
|
196
|
+
) -> boto3.session.Session:
|
197
|
+
from aiobotocore.session import get_session as async_get_session
|
198
|
+
|
159
199
|
if self.conn.assume_role_method == "assume_role_with_web_identity":
|
160
200
|
# Deferred credentials have no initial credentials
|
161
201
|
credential_fetcher = self._get_web_identity_credential_fetcher()
|
@@ -172,10 +212,10 @@ class BaseSessionFactory(LoggingMixin):
|
|
172
212
|
method="sts-assume-role",
|
173
213
|
)
|
174
214
|
|
175
|
-
session = botocore.session.get_session()
|
215
|
+
session = async_get_session() if deferrable else botocore.session.get_session()
|
216
|
+
|
176
217
|
session._credentials = credentials
|
177
|
-
|
178
|
-
session.set_config_variable("region", region_name)
|
218
|
+
session.set_config_variable("region", self.basic_session.region_name)
|
179
219
|
|
180
220
|
return boto3.session.Session(botocore_session=session, **session_kwargs)
|
181
221
|
|
@@ -362,34 +402,6 @@ class BaseSessionFactory(LoggingMixin):
|
|
362
402
|
def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
|
363
403
|
return slugify(role_session_name, regex_pattern=r"[^\w+=,.@-]+")
|
364
404
|
|
365
|
-
def _get_region_name(self) -> str | None:
|
366
|
-
warnings.warn(
|
367
|
-
"`BaseSessionFactory._get_region_name` method deprecated and will be removed "
|
368
|
-
"in a future releases. Please use `BaseSessionFactory.region_name` property instead.",
|
369
|
-
DeprecationWarning,
|
370
|
-
stacklevel=2,
|
371
|
-
)
|
372
|
-
return self.region_name
|
373
|
-
|
374
|
-
def _read_role_arn_from_extra_config(self) -> str | None:
|
375
|
-
warnings.warn(
|
376
|
-
"`BaseSessionFactory._read_role_arn_from_extra_config` method deprecated and will be removed "
|
377
|
-
"in a future releases. Please use `BaseSessionFactory.role_arn` property instead.",
|
378
|
-
DeprecationWarning,
|
379
|
-
stacklevel=2,
|
380
|
-
)
|
381
|
-
return self.role_arn
|
382
|
-
|
383
|
-
def _read_credentials_from_connection(self) -> tuple[str | None, str | None]:
|
384
|
-
warnings.warn(
|
385
|
-
"`BaseSessionFactory._read_credentials_from_connection` method deprecated and will be removed "
|
386
|
-
"in a future releases. Please use `BaseSessionFactory.conn.aws_access_key_id` and "
|
387
|
-
"`BaseSessionFactory.aws_secret_access_key` properties instead.",
|
388
|
-
DeprecationWarning,
|
389
|
-
stacklevel=2,
|
390
|
-
)
|
391
|
-
return self.conn.aws_access_key_id, self.conn.aws_secret_access_key
|
392
|
-
|
393
405
|
|
394
406
|
class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
395
407
|
"""
|
@@ -531,13 +543,8 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
531
543
|
try:
|
532
544
|
connection = self.get_connection(self.aws_conn_id)
|
533
545
|
except AirflowNotFoundException:
|
534
|
-
|
535
|
-
|
536
|
-
"This behaviour is deprecated and will be removed in a future releases. "
|
537
|
-
"Please provide existed AWS connection ID or if required boto3 credential strategy "
|
538
|
-
"explicit set AWS Connection ID to None.",
|
539
|
-
DeprecationWarning,
|
540
|
-
stacklevel=2,
|
546
|
+
self.log.warning(
|
547
|
+
"Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
|
541
548
|
)
|
542
549
|
|
543
550
|
return AwsConnectionWrapper(
|
@@ -564,11 +571,11 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
564
571
|
"""Verify or not SSL certificates boto3 client/resource read-only property."""
|
565
572
|
return self.conn_config.verify
|
566
573
|
|
567
|
-
def get_session(self, region_name: str | None = None) -> boto3.session.Session:
|
574
|
+
def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
|
568
575
|
"""Get the underlying boto3.session.Session(region_name=region_name)."""
|
569
576
|
return SessionFactory(
|
570
577
|
conn=self.conn_config, region_name=region_name, config=self.config
|
571
|
-
).create_session()
|
578
|
+
).create_session(deferrable=deferrable)
|
572
579
|
|
573
580
|
def _get_config(self, config: Config | None = None) -> Config:
|
574
581
|
"""
|
@@ -591,10 +598,19 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
591
598
|
self,
|
592
599
|
region_name: str | None = None,
|
593
600
|
config: Config | None = None,
|
601
|
+
deferrable: bool = False,
|
594
602
|
) -> boto3.client:
|
595
603
|
"""Get the underlying boto3 client using boto3 session"""
|
596
604
|
client_type = self.client_type
|
597
|
-
session = self.get_session(region_name=region_name)
|
605
|
+
session = self.get_session(region_name=region_name, deferrable=deferrable)
|
606
|
+
if not isinstance(session, boto3.session.Session):
|
607
|
+
return session.create_client(
|
608
|
+
client_type,
|
609
|
+
endpoint_url=self.conn_config.endpoint_url,
|
610
|
+
config=self._get_config(config),
|
611
|
+
verify=self.verify,
|
612
|
+
)
|
613
|
+
|
598
614
|
return session.client(
|
599
615
|
client_type,
|
600
616
|
endpoint_url=self.conn_config.endpoint_url,
|
@@ -634,6 +650,14 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
634
650
|
else:
|
635
651
|
return self.get_resource_type(region_name=self.region_name)
|
636
652
|
|
653
|
+
@property
|
654
|
+
def async_conn(self):
|
655
|
+
"""Get an aiobotocore client to use for async operations."""
|
656
|
+
if not self.client_type:
|
657
|
+
raise ValueError("client_type must be specified.")
|
658
|
+
|
659
|
+
return self.get_client_type(region_name=self.region_name, deferrable=True)
|
660
|
+
|
637
661
|
@cached_property
|
638
662
|
def conn_client_meta(self) -> ClientMeta:
|
639
663
|
"""Get botocore client metadata from Hook connection (cached)."""
|
@@ -730,17 +754,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
730
754
|
|
731
755
|
return retry_decorator
|
732
756
|
|
733
|
-
def _get_credentials(self, region_name: str | None) -> tuple[boto3.session.Session, str | None]:
|
734
|
-
warnings.warn(
|
735
|
-
"`AwsGenericHook._get_credentials` method deprecated and will be removed in a future releases. "
|
736
|
-
"Please use `AwsGenericHook.get_session` method and "
|
737
|
-
"`AwsGenericHook.conn_config.endpoint_url` property instead.",
|
738
|
-
DeprecationWarning,
|
739
|
-
stacklevel=2,
|
740
|
-
)
|
741
|
-
|
742
|
-
return self.get_session(region_name=region_name), self.conn_config.endpoint_url
|
743
|
-
|
744
757
|
@staticmethod
|
745
758
|
def get_ui_field_behaviour() -> dict[str, Any]:
|
746
759
|
"""Returns custom UI field behaviour for AWS Connection."""
|
@@ -794,21 +807,39 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
794
807
|
|
795
808
|
@cached_property
|
796
809
|
def waiter_path(self) -> PathLike[str] | None:
|
797
|
-
|
810
|
+
filename = self.client_type if self.client_type else self.resource_type
|
811
|
+
path = Path(__file__).parents[1].joinpath(f"waiters/{filename}.json").resolve()
|
798
812
|
return path if path.exists() else None
|
799
813
|
|
800
|
-
def get_waiter(
|
814
|
+
def get_waiter(
|
815
|
+
self,
|
816
|
+
waiter_name: str,
|
817
|
+
parameters: dict[str, str] | None = None,
|
818
|
+
deferrable: bool = False,
|
819
|
+
client=None,
|
820
|
+
) -> Waiter:
|
801
821
|
"""
|
802
822
|
First checks if there is a custom waiter with the provided waiter_name and
|
803
823
|
uses that if it exists, otherwise it will check the service client for a
|
804
824
|
waiter that matches the name and pass that through.
|
805
825
|
|
826
|
+
If `deferrable` is True, the waiter will be an AIOWaiter, generated from the
|
827
|
+
client that is passed as a parameter. If `deferrable` is True, `client` must be
|
828
|
+
provided.
|
829
|
+
|
806
830
|
:param waiter_name: The name of the waiter. The name should exactly match the
|
807
831
|
name of the key in the waiter model file (typically this is CamelCase).
|
808
832
|
:param parameters: will scan the waiter config for the keys of that dict, and replace them with the
|
809
833
|
corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
|
810
834
|
here.
|
835
|
+
:param deferrable: If True, the waiter is going to be an async custom waiter.
|
836
|
+
|
811
837
|
"""
|
838
|
+
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
|
839
|
+
|
840
|
+
if deferrable and not client:
|
841
|
+
raise ValueError("client must be provided for a deferrable waiter.")
|
842
|
+
client = client or self.conn
|
812
843
|
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
|
813
844
|
# Technically if waiter_name is in custom_waiters then self.waiter_path must
|
814
845
|
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
|
@@ -816,7 +847,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
|
|
816
847
|
config = json.loads(config_file.read())
|
817
848
|
|
818
849
|
config = self._apply_parameters_value(config, waiter_name, parameters)
|
819
|
-
return BaseBotoWaiter(client=
|
850
|
+
return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
|
851
|
+
waiter_name
|
852
|
+
)
|
820
853
|
# If there is no custom waiter found for the provided name,
|
821
854
|
# then try checking the service's official waiters.
|
822
855
|
return self.conn.get_waiter(waiter_name)
|
@@ -985,7 +1018,7 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
|
|
985
1018
|
aio_session.set_config_variable("region", region_name)
|
986
1019
|
return aio_session
|
987
1020
|
|
988
|
-
def create_session(self) -> AioSession:
|
1021
|
+
def create_session(self, deferrable: bool = False) -> AioSession:
|
989
1022
|
"""Create aiobotocore Session from connection and config."""
|
990
1023
|
if not self._conn:
|
991
1024
|
self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
|
@@ -414,43 +414,76 @@ class BatchClientHook(AwsBaseHook):
|
|
414
414
|
return matching_jobs[0]
|
415
415
|
|
416
416
|
def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
|
417
|
+
all_info = self.get_job_all_awslogs_info(job_id)
|
418
|
+
if not all_info:
|
419
|
+
return None
|
420
|
+
if len(all_info) > 1:
|
421
|
+
self.log.warning(
|
422
|
+
f"AWS Batch job ({job_id}) has more than one log stream, " f"only returning the first one."
|
423
|
+
)
|
424
|
+
return all_info[0]
|
425
|
+
|
426
|
+
def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
|
417
427
|
"""
|
418
428
|
Parse job description to extract AWS CloudWatch information.
|
419
429
|
|
420
430
|
:param job_id: AWS Batch Job ID
|
421
431
|
"""
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
432
|
-
|
432
|
+
job_desc = self.get_job_description(job_id=job_id)
|
433
|
+
|
434
|
+
job_node_properties = job_desc.get("nodeProperties", {})
|
435
|
+
job_container_desc = job_desc.get("container", {})
|
436
|
+
|
437
|
+
if job_node_properties:
|
438
|
+
# one log config per node
|
439
|
+
log_configs = [
|
440
|
+
p.get("container", {}).get("logConfiguration", {})
|
441
|
+
for p in job_node_properties.get("nodeRangeProperties", {})
|
442
|
+
]
|
443
|
+
# one stream name per attempt
|
444
|
+
stream_names = [a.get("container", {}).get("logStreamName") for a in job_desc.get("attempts", [])]
|
445
|
+
elif job_container_desc:
|
446
|
+
log_configs = [job_container_desc.get("logConfiguration", {})]
|
447
|
+
stream_name = job_container_desc.get("logStreamName")
|
448
|
+
stream_names = [stream_name] if stream_name is not None else []
|
449
|
+
else:
|
450
|
+
raise AirflowException(
|
451
|
+
f"AWS Batch job ({job_id}) is not a supported job type. "
|
452
|
+
"Supported job types: container, array, multinode."
|
453
|
+
)
|
454
|
+
|
455
|
+
# If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
|
456
|
+
if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
|
433
457
|
self.log.warning(
|
434
|
-
"AWS Batch job (
|
458
|
+
f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
|
435
459
|
)
|
436
|
-
return
|
460
|
+
return []
|
437
461
|
|
438
|
-
|
439
|
-
|
440
|
-
#
|
441
|
-
#
|
442
|
-
|
443
|
-
|
444
|
-
return None
|
462
|
+
if not stream_names:
|
463
|
+
# If this method is called very early after starting the AWS Batch job,
|
464
|
+
# there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
|
465
|
+
# This can also happen in case of misconfiguration.
|
466
|
+
self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
|
467
|
+
return []
|
445
468
|
|
446
469
|
# Try to get user-defined log configuration options
|
447
|
-
log_options =
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
|
470
|
+
log_options = [c.get("options", {}) for c in log_configs]
|
471
|
+
|
472
|
+
# cross stream names with options (i.e. attempts X nodes) to generate all log infos
|
473
|
+
result = []
|
474
|
+
for stream in stream_names:
|
475
|
+
for option in log_options:
|
476
|
+
result.append(
|
477
|
+
{
|
478
|
+
"awslogs_stream_name": stream,
|
479
|
+
# If the user did not specify anything, the default settings are:
|
480
|
+
# awslogs-group = /aws/batch/job
|
481
|
+
# awslogs-region = `same as AWS Batch Job region`
|
482
|
+
"awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
|
483
|
+
"awslogs_region": option.get("awslogs-region", self.conn_region_name),
|
484
|
+
}
|
485
|
+
)
|
486
|
+
return result
|
454
487
|
|
455
488
|
@staticmethod
|
456
489
|
def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
|
@@ -138,7 +138,9 @@ class BatchWaitersHook(BatchClientHook):
|
|
138
138
|
"""
|
139
139
|
return self._waiter_model
|
140
140
|
|
141
|
-
def get_waiter(
|
141
|
+
def get_waiter(
|
142
|
+
self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None
|
143
|
+
) -> botocore.waiter.Waiter:
|
142
144
|
"""
|
143
145
|
Get an AWS Batch service waiter, using the configured ``.waiter_model``.
|
144
146
|
|
@@ -20,14 +20,12 @@ from __future__ import annotations
|
|
20
20
|
import json
|
21
21
|
import warnings
|
22
22
|
from time import sleep
|
23
|
-
from typing import Any
|
23
|
+
from typing import Any
|
24
24
|
|
25
25
|
from botocore.exceptions import ClientError
|
26
26
|
|
27
|
-
from airflow.compat.functools import cached_property
|
28
27
|
from airflow.exceptions import AirflowException, AirflowNotFoundException
|
29
28
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
30
|
-
from airflow.providers.amazon.aws.utils.waiter import get_state, waiter
|
31
29
|
from airflow.utils.helpers import prune_dict
|
32
30
|
|
33
31
|
|
@@ -254,66 +252,41 @@ class EmrServerlessHook(AwsBaseHook):
|
|
254
252
|
kwargs["client_type"] = "emr-serverless"
|
255
253
|
super().__init__(*args, **kwargs)
|
256
254
|
|
257
|
-
|
258
|
-
def conn(self):
|
259
|
-
"""Get the underlying boto3 EmrServerlessAPIService client (cached)"""
|
260
|
-
return super().conn
|
261
|
-
|
262
|
-
# This method should be replaced with boto waiters which would implement timeouts and backoff nicely.
|
263
|
-
def waiter(
|
264
|
-
self,
|
265
|
-
get_state_callable: Callable,
|
266
|
-
get_state_args: dict,
|
267
|
-
parse_response: list,
|
268
|
-
desired_state: set,
|
269
|
-
failure_states: set,
|
270
|
-
object_type: str,
|
271
|
-
action: str,
|
272
|
-
countdown: int = 25 * 60,
|
273
|
-
check_interval_seconds: int = 60,
|
274
|
-
) -> None:
|
255
|
+
def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
|
275
256
|
"""
|
276
|
-
|
277
|
-
|
278
|
-
:
|
279
|
-
|
280
|
-
:param parse_response: Dictionary keys to extract state from response of get_state_callable
|
281
|
-
:param desired_state: Wait until the getter returns this value
|
282
|
-
:param failure_states: A set of states which indicate failure and should throw an
|
283
|
-
exception if any are reached before the desired_state
|
284
|
-
:param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
|
285
|
-
:param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
|
286
|
-
:param countdown: Total amount of time the waiter should wait for the desired state
|
287
|
-
before timing out (in seconds). Defaults to 25 * 60 seconds.
|
288
|
-
:param check_interval_seconds: Number of seconds waiter should wait before attempting
|
289
|
-
to retry get_state_callable. Defaults to 60 seconds.
|
257
|
+
List all jobs in an intermediate state and cancel them.
|
258
|
+
Then wait for those jobs to reach a terminal state.
|
259
|
+
Note: if new jobs are triggered while this operation is ongoing,
|
260
|
+
it's going to time out and return an error.
|
290
261
|
"""
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
get_state_args=get_state_args,
|
300
|
-
parse_response=parse_response,
|
301
|
-
desired_state=desired_state,
|
302
|
-
failure_states=failure_states,
|
303
|
-
object_type=object_type,
|
304
|
-
action=action,
|
305
|
-
countdown=countdown,
|
306
|
-
check_interval_seconds=check_interval_seconds,
|
307
|
-
)
|
308
|
-
|
309
|
-
def get_state(self, response, keys) -> str:
|
310
|
-
warnings.warn(
|
311
|
-
"""This method is deprecated.
|
312
|
-
Please use `airflow.providers.amazon.aws.utils.waiter.get_state`.""",
|
313
|
-
DeprecationWarning,
|
314
|
-
stacklevel=2,
|
262
|
+
paginator = self.conn.get_paginator("list_job_runs")
|
263
|
+
results_per_response = 50
|
264
|
+
iterator = paginator.paginate(
|
265
|
+
applicationId=application_id,
|
266
|
+
states=list(self.JOB_INTERMEDIATE_STATES),
|
267
|
+
PaginationConfig={
|
268
|
+
"PageSize": results_per_response,
|
269
|
+
},
|
315
270
|
)
|
316
|
-
|
271
|
+
count = 0
|
272
|
+
for r in iterator:
|
273
|
+
job_ids = [jr["id"] for jr in r["jobRuns"]]
|
274
|
+
count += len(job_ids)
|
275
|
+
if len(job_ids) > 0:
|
276
|
+
self.log.info(
|
277
|
+
"Cancelling %s pending job(s) for the application %s so that it can be stopped",
|
278
|
+
len(job_ids),
|
279
|
+
application_id,
|
280
|
+
)
|
281
|
+
for job_id in job_ids:
|
282
|
+
self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
|
283
|
+
if count > 0:
|
284
|
+
self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
|
285
|
+
self.get_waiter("no_job_running").wait(
|
286
|
+
applicationId=application_id,
|
287
|
+
states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
|
288
|
+
WaiterConfig=waiter_config,
|
289
|
+
)
|
317
290
|
|
318
291
|
|
319
292
|
class EmrContainerHook(AwsBaseHook):
|
@@ -483,7 +456,6 @@ class EmrContainerHook(AwsBaseHook):
|
|
483
456
|
def poll_query_status(
|
484
457
|
self,
|
485
458
|
job_id: str,
|
486
|
-
max_tries: int | None = None,
|
487
459
|
poll_interval: int = 30,
|
488
460
|
max_polling_attempts: int | None = None,
|
489
461
|
) -> str | None:
|
@@ -492,22 +464,9 @@ class EmrContainerHook(AwsBaseHook):
|
|
492
464
|
Returns one of the final states.
|
493
465
|
|
494
466
|
:param job_id: The ID of the job run request.
|
495
|
-
:param max_tries: Deprecated - Use max_polling_attempts instead
|
496
467
|
:param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
|
497
468
|
:param max_polling_attempts: Number of times to poll for query state before function exits
|
498
469
|
"""
|
499
|
-
if max_tries:
|
500
|
-
warnings.warn(
|
501
|
-
f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
|
502
|
-
"in a future release. Please use method `max_polling_attempts` instead.",
|
503
|
-
DeprecationWarning,
|
504
|
-
stacklevel=2,
|
505
|
-
)
|
506
|
-
if max_polling_attempts and max_polling_attempts != max_tries:
|
507
|
-
raise Exception("max_polling_attempts must be the same value as max_tries")
|
508
|
-
else:
|
509
|
-
max_polling_attempts = max_tries
|
510
|
-
|
511
470
|
try_number = 1
|
512
471
|
final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
|
513
472
|
|
@@ -25,6 +25,14 @@ from typing import Generator
|
|
25
25
|
|
26
26
|
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
27
27
|
|
28
|
+
# Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
|
29
|
+
# value of the nextForwardToken is the same in subsequent calls.
|
30
|
+
# The issue with this approach is, it can take a huge amount of time (e.g. 20 seconds) to retrieve logs using
|
31
|
+
# this approach. As an intermediate solution, we decided to stop fetching logs if 3 consecutive responses
|
32
|
+
# are empty.
|
33
|
+
# See PR https://github.com/apache/airflow/pull/20814
|
34
|
+
NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3
|
35
|
+
|
28
36
|
|
29
37
|
class AwsLogsHook(AwsBaseHook):
|
30
38
|
"""
|
@@ -69,14 +77,15 @@ class AwsLogsHook(AwsBaseHook):
|
|
69
77
|
| 'message' (str): The log event data.
|
70
78
|
| 'ingestionTime' (int): The time in milliseconds the event was ingested.
|
71
79
|
"""
|
80
|
+
num_consecutive_empty_response = 0
|
72
81
|
next_token = None
|
73
82
|
while True:
|
74
83
|
if next_token is not None:
|
75
|
-
token_arg: dict[str, str]
|
84
|
+
token_arg: dict[str, str] = {"nextToken": next_token}
|
76
85
|
else:
|
77
86
|
token_arg = {}
|
78
87
|
|
79
|
-
response = self.
|
88
|
+
response = self.conn.get_log_events(
|
80
89
|
logGroupName=log_group,
|
81
90
|
logStreamName=log_stream_name,
|
82
91
|
startTime=start_time,
|
@@ -96,7 +105,16 @@ class AwsLogsHook(AwsBaseHook):
|
|
96
105
|
|
97
106
|
yield from events
|
98
107
|
|
99
|
-
if
|
100
|
-
|
108
|
+
if not event_count:
|
109
|
+
num_consecutive_empty_response += 1
|
110
|
+
if num_consecutive_empty_response >= NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD:
|
111
|
+
# Exit if there are more than NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD consecutive
|
112
|
+
# empty responses
|
113
|
+
return
|
114
|
+
elif next_token != response["nextForwardToken"]:
|
115
|
+
num_consecutive_empty_response = 0
|
101
116
|
else:
|
117
|
+
# Exit if the value of nextForwardToken is same in subsequent calls
|
102
118
|
return
|
119
|
+
|
120
|
+
next_token = response["nextForwardToken"]
|