apache-airflow-providers-amazon 7.4.1rc1__py3-none-any.whl → 8.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. airflow/providers/amazon/aws/hooks/athena.py +0 -15
  2. airflow/providers/amazon/aws/hooks/base_aws.py +98 -65
  3. airflow/providers/amazon/aws/hooks/batch_client.py +60 -27
  4. airflow/providers/amazon/aws/hooks/batch_waiters.py +3 -1
  5. airflow/providers/amazon/aws/hooks/emr.py +33 -74
  6. airflow/providers/amazon/aws/hooks/logs.py +22 -4
  7. airflow/providers/amazon/aws/hooks/redshift_cluster.py +1 -12
  8. airflow/providers/amazon/aws/hooks/sagemaker.py +0 -16
  9. airflow/providers/amazon/aws/links/emr.py +1 -3
  10. airflow/providers/amazon/aws/operators/athena.py +0 -15
  11. airflow/providers/amazon/aws/operators/batch.py +78 -24
  12. airflow/providers/amazon/aws/operators/ecs.py +21 -58
  13. airflow/providers/amazon/aws/operators/eks.py +0 -1
  14. airflow/providers/amazon/aws/operators/emr.py +94 -24
  15. airflow/providers/amazon/aws/operators/lambda_function.py +0 -19
  16. airflow/providers/amazon/aws/operators/rds.py +1 -1
  17. airflow/providers/amazon/aws/operators/redshift_cluster.py +22 -1
  18. airflow/providers/amazon/aws/operators/redshift_data.py +0 -62
  19. airflow/providers/amazon/aws/secrets/secrets_manager.py +0 -17
  20. airflow/providers/amazon/aws/secrets/systems_manager.py +0 -21
  21. airflow/providers/amazon/aws/sensors/dynamodb.py +97 -0
  22. airflow/providers/amazon/aws/sensors/emr.py +1 -2
  23. airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +1 -1
  24. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +0 -19
  25. airflow/providers/amazon/aws/transfers/glacier_to_gcs.py +1 -7
  26. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +10 -10
  27. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +0 -10
  28. airflow/providers/amazon/aws/transfers/mongo_to_s3.py +0 -11
  29. airflow/providers/amazon/aws/transfers/s3_to_sftp.py +0 -10
  30. airflow/providers/amazon/aws/transfers/sql_to_s3.py +23 -9
  31. airflow/providers/amazon/aws/triggers/redshift_cluster.py +54 -2
  32. airflow/providers/amazon/aws/waiters/base_waiter.py +12 -1
  33. airflow/providers/amazon/aws/waiters/emr-serverless.json +18 -0
  34. airflow/providers/amazon/get_provider_info.py +35 -30
  35. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/METADATA +81 -4
  36. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/RECORD +41 -41
  37. airflow/providers/amazon/aws/operators/aws_lambda.py +0 -29
  38. airflow/providers/amazon/aws/operators/redshift_sql.py +0 -57
  39. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/LICENSE +0 -0
  40. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/NOTICE +0 -0
  41. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/WHEEL +0 -0
  42. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/entry_points.txt +0 -0
  43. {apache_airflow_providers_amazon-7.4.1rc1.dist-info → apache_airflow_providers_amazon-8.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,6 @@ This module contains AWS Athena hook.
24
24
  """
25
25
  from __future__ import annotations
26
26
 
27
- import warnings
28
27
  from time import sleep
29
28
  from typing import Any
30
29
 
@@ -224,7 +223,6 @@ class AthenaHook(AwsBaseHook):
224
223
  def poll_query_status(
225
224
  self,
226
225
  query_execution_id: str,
227
- max_tries: int | None = None,
228
226
  max_polling_attempts: int | None = None,
229
227
  ) -> str | None:
230
228
  """
@@ -232,21 +230,8 @@ class AthenaHook(AwsBaseHook):
232
230
  Returns one of the final states
233
231
 
234
232
  :param query_execution_id: Id of submitted athena query
235
- :param max_tries: Deprecated - Use max_polling_attempts instead
236
233
  :param max_polling_attempts: Number of times to poll for query state before function exits
237
234
  """
238
- if max_tries:
239
- warnings.warn(
240
- f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
241
- f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
242
- DeprecationWarning,
243
- stacklevel=2,
244
- )
245
- if max_polling_attempts and max_polling_attempts != max_tries:
246
- raise Exception("max_polling_attempts must be the same value as max_tries")
247
- else:
248
- max_polling_attempts = max_tries
249
-
250
235
  try_number = 1
251
236
  final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
252
237
  while True:
@@ -30,7 +30,6 @@ import json
30
30
  import logging
31
31
  import os
32
32
  import uuid
33
- import warnings
34
33
  from copy import deepcopy
35
34
  from functools import wraps
36
35
  from os import PathLike
@@ -58,7 +57,6 @@ from airflow.exceptions import (
58
57
  )
59
58
  from airflow.hooks.base import BaseHook
60
59
  from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
61
- from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
62
60
  from airflow.providers_manager import ProvidersManager
63
61
  from airflow.utils.helpers import exactly_one
64
62
  from airflow.utils.log.logging_mixin import LoggingMixin
@@ -72,12 +70,15 @@ if TYPE_CHECKING:
72
70
 
73
71
  class BaseSessionFactory(LoggingMixin):
74
72
  """
75
- Base AWS Session Factory class to handle boto3 session creation.
73
+ Base AWS Session Factory class to handle synchronous and async boto session creation.
76
74
  It can handle most of the AWS supported authentication methods.
77
75
 
78
76
  User can also derive from this class to have full control of boto3 session
79
77
  creation or to support custom federation.
80
78
 
79
+ Note: Not all features implemented for synchronous sessions are available for async
80
+ sessions.
81
+
81
82
  .. seealso::
82
83
  - :ref:`howto/connection:aws:session-factory`
83
84
  """
@@ -127,17 +128,50 @@ class BaseSessionFactory(LoggingMixin):
127
128
  """Assume Role ARN from AWS Connection"""
128
129
  return self.conn.role_arn
129
130
 
130
- def create_session(self) -> boto3.session.Session:
131
- """Create boto3 Session from connection config."""
131
+ def _apply_session_kwargs(self, session):
132
+ if self.conn.session_kwargs.get("profile_name", None) is not None:
133
+ session.set_config_variable("profile", self.conn.session_kwargs["profile_name"])
134
+
135
+ if (
136
+ self.conn.session_kwargs.get("aws_access_key_id", None)
137
+ or self.conn.session_kwargs.get("aws_secret_access_key", None)
138
+ or self.conn.session_kwargs.get("aws_session_token", None)
139
+ ):
140
+ session.set_credentials(
141
+ self.conn.session_kwargs["aws_access_key_id"],
142
+ self.conn.session_kwargs["aws_secret_access_key"],
143
+ self.conn.session_kwargs["aws_session_token"],
144
+ )
145
+
146
+ if self.conn.session_kwargs.get("region_name", None) is not None:
147
+ session.set_config_variable("region", self.conn.session_kwargs["region_name"])
148
+
149
+ def get_async_session(self):
150
+ from aiobotocore.session import get_session as async_get_session
151
+
152
+ return async_get_session()
153
+
154
+ def create_session(self, deferrable: bool = False) -> boto3.session.Session:
155
+ """Create boto3 or aiobotocore Session from connection config."""
132
156
  if not self.conn:
133
157
  self.log.info(
134
158
  "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
135
159
  "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
136
160
  self.region_name,
137
161
  )
138
- return boto3.session.Session(region_name=self.region_name)
162
+ if deferrable:
163
+ session = self.get_async_session()
164
+ self._apply_session_kwargs(session)
165
+ return session
166
+ else:
167
+ return boto3.session.Session(region_name=self.region_name)
139
168
  elif not self.role_arn:
140
- return self.basic_session
169
+ if deferrable:
170
+ session = self.get_async_session()
171
+ self._apply_session_kwargs(session)
172
+ return session
173
+ else:
174
+ return self.basic_session
141
175
 
142
176
  # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
143
177
  # to create the initial boto3 session.
@@ -150,12 +184,18 @@ class BaseSessionFactory(LoggingMixin):
150
184
  assume_session_kwargs = {}
151
185
  if self.conn.region_name:
152
186
  assume_session_kwargs["region_name"] = self.conn.region_name
153
- return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs)
187
+ return self._create_session_with_assume_role(
188
+ session_kwargs=assume_session_kwargs, deferrable=deferrable
189
+ )
154
190
 
155
191
  def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
156
192
  return boto3.session.Session(**session_kwargs)
157
193
 
158
- def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
194
+ def _create_session_with_assume_role(
195
+ self, session_kwargs: dict[str, Any], deferrable: bool = False
196
+ ) -> boto3.session.Session:
197
+ from aiobotocore.session import get_session as async_get_session
198
+
159
199
  if self.conn.assume_role_method == "assume_role_with_web_identity":
160
200
  # Deferred credentials have no initial credentials
161
201
  credential_fetcher = self._get_web_identity_credential_fetcher()
@@ -172,10 +212,10 @@ class BaseSessionFactory(LoggingMixin):
172
212
  method="sts-assume-role",
173
213
  )
174
214
 
175
- session = botocore.session.get_session()
215
+ session = async_get_session() if deferrable else botocore.session.get_session()
216
+
176
217
  session._credentials = credentials
177
- region_name = self.basic_session.region_name
178
- session.set_config_variable("region", region_name)
218
+ session.set_config_variable("region", self.basic_session.region_name)
179
219
 
180
220
  return boto3.session.Session(botocore_session=session, **session_kwargs)
181
221
 
@@ -362,34 +402,6 @@ class BaseSessionFactory(LoggingMixin):
362
402
  def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
363
403
  return slugify(role_session_name, regex_pattern=r"[^\w+=,.@-]+")
364
404
 
365
- def _get_region_name(self) -> str | None:
366
- warnings.warn(
367
- "`BaseSessionFactory._get_region_name` method deprecated and will be removed "
368
- "in a future releases. Please use `BaseSessionFactory.region_name` property instead.",
369
- DeprecationWarning,
370
- stacklevel=2,
371
- )
372
- return self.region_name
373
-
374
- def _read_role_arn_from_extra_config(self) -> str | None:
375
- warnings.warn(
376
- "`BaseSessionFactory._read_role_arn_from_extra_config` method deprecated and will be removed "
377
- "in a future releases. Please use `BaseSessionFactory.role_arn` property instead.",
378
- DeprecationWarning,
379
- stacklevel=2,
380
- )
381
- return self.role_arn
382
-
383
- def _read_credentials_from_connection(self) -> tuple[str | None, str | None]:
384
- warnings.warn(
385
- "`BaseSessionFactory._read_credentials_from_connection` method deprecated and will be removed "
386
- "in a future releases. Please use `BaseSessionFactory.conn.aws_access_key_id` and "
387
- "`BaseSessionFactory.aws_secret_access_key` properties instead.",
388
- DeprecationWarning,
389
- stacklevel=2,
390
- )
391
- return self.conn.aws_access_key_id, self.conn.aws_secret_access_key
392
-
393
405
 
394
406
  class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
395
407
  """
@@ -531,13 +543,8 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
531
543
  try:
532
544
  connection = self.get_connection(self.aws_conn_id)
533
545
  except AirflowNotFoundException:
534
- warnings.warn(
535
- f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
536
- "This behaviour is deprecated and will be removed in a future releases. "
537
- "Please provide existed AWS connection ID or if required boto3 credential strategy "
538
- "explicit set AWS Connection ID to None.",
539
- DeprecationWarning,
540
- stacklevel=2,
546
+ self.log.warning(
547
+ "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
541
548
  )
542
549
 
543
550
  return AwsConnectionWrapper(
@@ -564,11 +571,11 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
564
571
  """Verify or not SSL certificates boto3 client/resource read-only property."""
565
572
  return self.conn_config.verify
566
573
 
567
- def get_session(self, region_name: str | None = None) -> boto3.session.Session:
574
+ def get_session(self, region_name: str | None = None, deferrable: bool = False) -> boto3.session.Session:
568
575
  """Get the underlying boto3.session.Session(region_name=region_name)."""
569
576
  return SessionFactory(
570
577
  conn=self.conn_config, region_name=region_name, config=self.config
571
- ).create_session()
578
+ ).create_session(deferrable=deferrable)
572
579
 
573
580
  def _get_config(self, config: Config | None = None) -> Config:
574
581
  """
@@ -591,10 +598,19 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
591
598
  self,
592
599
  region_name: str | None = None,
593
600
  config: Config | None = None,
601
+ deferrable: bool = False,
594
602
  ) -> boto3.client:
595
603
  """Get the underlying boto3 client using boto3 session"""
596
604
  client_type = self.client_type
597
- session = self.get_session(region_name=region_name)
605
+ session = self.get_session(region_name=region_name, deferrable=deferrable)
606
+ if not isinstance(session, boto3.session.Session):
607
+ return session.create_client(
608
+ client_type,
609
+ endpoint_url=self.conn_config.endpoint_url,
610
+ config=self._get_config(config),
611
+ verify=self.verify,
612
+ )
613
+
598
614
  return session.client(
599
615
  client_type,
600
616
  endpoint_url=self.conn_config.endpoint_url,
@@ -634,6 +650,14 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
634
650
  else:
635
651
  return self.get_resource_type(region_name=self.region_name)
636
652
 
653
+ @property
654
+ def async_conn(self):
655
+ """Get an aiobotocore client to use for async operations."""
656
+ if not self.client_type:
657
+ raise ValueError("client_type must be specified.")
658
+
659
+ return self.get_client_type(region_name=self.region_name, deferrable=True)
660
+
637
661
  @cached_property
638
662
  def conn_client_meta(self) -> ClientMeta:
639
663
  """Get botocore client metadata from Hook connection (cached)."""
@@ -730,17 +754,6 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
730
754
 
731
755
  return retry_decorator
732
756
 
733
- def _get_credentials(self, region_name: str | None) -> tuple[boto3.session.Session, str | None]:
734
- warnings.warn(
735
- "`AwsGenericHook._get_credentials` method deprecated and will be removed in a future releases. "
736
- "Please use `AwsGenericHook.get_session` method and "
737
- "`AwsGenericHook.conn_config.endpoint_url` property instead.",
738
- DeprecationWarning,
739
- stacklevel=2,
740
- )
741
-
742
- return self.get_session(region_name=region_name), self.conn_config.endpoint_url
743
-
744
757
  @staticmethod
745
758
  def get_ui_field_behaviour() -> dict[str, Any]:
746
759
  """Returns custom UI field behaviour for AWS Connection."""
@@ -794,21 +807,39 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
794
807
 
795
808
  @cached_property
796
809
  def waiter_path(self) -> PathLike[str] | None:
797
- path = Path(__file__).parents[1].joinpath(f"waiters/{self.client_type}.json").resolve()
810
+ filename = self.client_type if self.client_type else self.resource_type
811
+ path = Path(__file__).parents[1].joinpath(f"waiters/{filename}.json").resolve()
798
812
  return path if path.exists() else None
799
813
 
800
- def get_waiter(self, waiter_name: str, parameters: dict[str, str] | None = None) -> Waiter:
814
+ def get_waiter(
815
+ self,
816
+ waiter_name: str,
817
+ parameters: dict[str, str] | None = None,
818
+ deferrable: bool = False,
819
+ client=None,
820
+ ) -> Waiter:
801
821
  """
802
822
  First checks if there is a custom waiter with the provided waiter_name and
803
823
  uses that if it exists, otherwise it will check the service client for a
804
824
  waiter that matches the name and pass that through.
805
825
 
826
+ If `deferrable` is True, the waiter will be an AIOWaiter, generated from the
827
+ client that is passed as a parameter. If `deferrable` is True, `client` must be
828
+ provided.
829
+
806
830
  :param waiter_name: The name of the waiter. The name should exactly match the
807
831
  name of the key in the waiter model file (typically this is CamelCase).
808
832
  :param parameters: will scan the waiter config for the keys of that dict, and replace them with the
809
833
  corresponding value. If a custom waiter has such keys to be expanded, they need to be provided
810
834
  here.
835
+ :param deferrable: If True, the waiter is going to be an async custom waiter.
836
+
811
837
  """
838
+ from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
839
+
840
+ if deferrable and not client:
841
+ raise ValueError("client must be provided for a deferrable waiter.")
842
+ client = client or self.conn
812
843
  if self.waiter_path and (waiter_name in self._list_custom_waiters()):
813
844
  # Technically if waiter_name is in custom_waiters then self.waiter_path must
814
845
  # exist but MyPy doesn't like the fact that self.waiter_path could be None.
@@ -816,7 +847,9 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
816
847
  config = json.loads(config_file.read())
817
848
 
818
849
  config = self._apply_parameters_value(config, waiter_name, parameters)
819
- return BaseBotoWaiter(client=self.conn, model_config=config).waiter(waiter_name)
850
+ return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
851
+ waiter_name
852
+ )
820
853
  # If there is no custom waiter found for the provided name,
821
854
  # then try checking the service's official waiters.
822
855
  return self.conn.get_waiter(waiter_name)
@@ -985,7 +1018,7 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
985
1018
  aio_session.set_config_variable("region", region_name)
986
1019
  return aio_session
987
1020
 
988
- def create_session(self) -> AioSession:
1021
+ def create_session(self, deferrable: bool = False) -> AioSession:
989
1022
  """Create aiobotocore Session from connection and config."""
990
1023
  if not self._conn:
991
1024
  self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
@@ -414,43 +414,76 @@ class BatchClientHook(AwsBaseHook):
414
414
  return matching_jobs[0]
415
415
 
416
416
  def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
417
+ all_info = self.get_job_all_awslogs_info(job_id)
418
+ if not all_info:
419
+ return None
420
+ if len(all_info) > 1:
421
+ self.log.warning(
422
+ f"AWS Batch job ({job_id}) has more than one log stream, " f"only returning the first one."
423
+ )
424
+ return all_info[0]
425
+
426
+ def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
417
427
  """
418
428
  Parse job description to extract AWS CloudWatch information.
419
429
 
420
430
  :param job_id: AWS Batch Job ID
421
431
  """
422
- job_container_desc = self.get_job_description(job_id=job_id).get("container", {})
423
- log_configuration = job_container_desc.get("logConfiguration", {})
424
-
425
- # In case if user select other "logDriver" rather than "awslogs"
426
- # than CloudWatch logging should be disabled.
427
- # If user not specify anything than expected that "awslogs" will use
428
- # with default settings:
429
- # awslogs-group = /aws/batch/job
430
- # awslogs-region = `same as AWS Batch Job region`
431
- log_driver = log_configuration.get("logDriver", "awslogs")
432
- if log_driver != "awslogs":
432
+ job_desc = self.get_job_description(job_id=job_id)
433
+
434
+ job_node_properties = job_desc.get("nodeProperties", {})
435
+ job_container_desc = job_desc.get("container", {})
436
+
437
+ if job_node_properties:
438
+ # one log config per node
439
+ log_configs = [
440
+ p.get("container", {}).get("logConfiguration", {})
441
+ for p in job_node_properties.get("nodeRangeProperties", {})
442
+ ]
443
+ # one stream name per attempt
444
+ stream_names = [a.get("container", {}).get("logStreamName") for a in job_desc.get("attempts", [])]
445
+ elif job_container_desc:
446
+ log_configs = [job_container_desc.get("logConfiguration", {})]
447
+ stream_name = job_container_desc.get("logStreamName")
448
+ stream_names = [stream_name] if stream_name is not None else []
449
+ else:
450
+ raise AirflowException(
451
+ f"AWS Batch job ({job_id}) is not a supported job type. "
452
+ "Supported job types: container, array, multinode."
453
+ )
454
+
455
+ # If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
456
+ if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
433
457
  self.log.warning(
434
- "AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver
458
+ f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
435
459
  )
436
- return None
460
+ return []
437
461
 
438
- awslogs_stream_name = job_container_desc.get("logStreamName")
439
- if not awslogs_stream_name:
440
- # In case of call this method on very early stage of running AWS Batch
441
- # there is possibility than AWS CloudWatch Stream Name not exists yet.
442
- # AWS CloudWatch Stream Name also not created in case of misconfiguration.
443
- self.log.warning("AWS Batch job (%s) doesn't create AWS CloudWatch Stream.", job_id)
444
- return None
462
+ if not stream_names:
463
+ # If this method is called very early after starting the AWS Batch job,
464
+ # there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
465
+ # This can also happen in case of misconfiguration.
466
+ self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
467
+ return []
445
468
 
446
469
  # Try to get user-defined log configuration options
447
- log_options = log_configuration.get("options", {})
448
-
449
- return {
450
- "awslogs_stream_name": awslogs_stream_name,
451
- "awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"),
452
- "awslogs_region": log_options.get("awslogs-region", self.conn_region_name),
453
- }
470
+ log_options = [c.get("options", {}) for c in log_configs]
471
+
472
+ # cross stream names with options (i.e. attempts X nodes) to generate all log infos
473
+ result = []
474
+ for stream in stream_names:
475
+ for option in log_options:
476
+ result.append(
477
+ {
478
+ "awslogs_stream_name": stream,
479
+ # If the user did not specify anything, the default settings are:
480
+ # awslogs-group = /aws/batch/job
481
+ # awslogs-region = `same as AWS Batch Job region`
482
+ "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
483
+ "awslogs_region": option.get("awslogs-region", self.conn_region_name),
484
+ }
485
+ )
486
+ return result
454
487
 
455
488
  @staticmethod
456
489
  def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
@@ -138,7 +138,9 @@ class BatchWaitersHook(BatchClientHook):
138
138
  """
139
139
  return self._waiter_model
140
140
 
141
- def get_waiter(self, waiter_name: str, _: dict[str, str] | None = None) -> botocore.waiter.Waiter:
141
+ def get_waiter(
142
+ self, waiter_name: str, _: dict[str, str] | None = None, deferrable: bool = False, client=None
143
+ ) -> botocore.waiter.Waiter:
142
144
  """
143
145
  Get an AWS Batch service waiter, using the configured ``.waiter_model``.
144
146
 
@@ -20,14 +20,12 @@ from __future__ import annotations
20
20
  import json
21
21
  import warnings
22
22
  from time import sleep
23
- from typing import Any, Callable
23
+ from typing import Any
24
24
 
25
25
  from botocore.exceptions import ClientError
26
26
 
27
- from airflow.compat.functools import cached_property
28
27
  from airflow.exceptions import AirflowException, AirflowNotFoundException
29
28
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
30
- from airflow.providers.amazon.aws.utils.waiter import get_state, waiter
31
29
  from airflow.utils.helpers import prune_dict
32
30
 
33
31
 
@@ -254,66 +252,41 @@ class EmrServerlessHook(AwsBaseHook):
254
252
  kwargs["client_type"] = "emr-serverless"
255
253
  super().__init__(*args, **kwargs)
256
254
 
257
- @cached_property
258
- def conn(self):
259
- """Get the underlying boto3 EmrServerlessAPIService client (cached)"""
260
- return super().conn
261
-
262
- # This method should be replaced with boto waiters which would implement timeouts and backoff nicely.
263
- def waiter(
264
- self,
265
- get_state_callable: Callable,
266
- get_state_args: dict,
267
- parse_response: list,
268
- desired_state: set,
269
- failure_states: set,
270
- object_type: str,
271
- action: str,
272
- countdown: int = 25 * 60,
273
- check_interval_seconds: int = 60,
274
- ) -> None:
255
+ def cancel_running_jobs(self, application_id: str, waiter_config: dict = {}):
275
256
  """
276
- Will run the sensor until it turns True.
277
-
278
- :param get_state_callable: A callable to run until it returns True
279
- :param get_state_args: Arguments to pass to get_state_callable
280
- :param parse_response: Dictionary keys to extract state from response of get_state_callable
281
- :param desired_state: Wait until the getter returns this value
282
- :param failure_states: A set of states which indicate failure and should throw an
283
- exception if any are reached before the desired_state
284
- :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
285
- :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
286
- :param countdown: Total amount of time the waiter should wait for the desired state
287
- before timing out (in seconds). Defaults to 25 * 60 seconds.
288
- :param check_interval_seconds: Number of seconds waiter should wait before attempting
289
- to retry get_state_callable. Defaults to 60 seconds.
257
+ List all jobs in an intermediate state and cancel them.
258
+ Then wait for those jobs to reach a terminal state.
259
+ Note: if new jobs are triggered while this operation is ongoing,
260
+ it's going to time out and return an error.
290
261
  """
291
- warnings.warn(
292
- """This method is deprecated.
293
- Please use `airflow.providers.amazon.aws.utils.waiter.waiter`.""",
294
- DeprecationWarning,
295
- stacklevel=2,
296
- )
297
- waiter(
298
- get_state_callable=get_state_callable,
299
- get_state_args=get_state_args,
300
- parse_response=parse_response,
301
- desired_state=desired_state,
302
- failure_states=failure_states,
303
- object_type=object_type,
304
- action=action,
305
- countdown=countdown,
306
- check_interval_seconds=check_interval_seconds,
307
- )
308
-
309
- def get_state(self, response, keys) -> str:
310
- warnings.warn(
311
- """This method is deprecated.
312
- Please use `airflow.providers.amazon.aws.utils.waiter.get_state`.""",
313
- DeprecationWarning,
314
- stacklevel=2,
262
+ paginator = self.conn.get_paginator("list_job_runs")
263
+ results_per_response = 50
264
+ iterator = paginator.paginate(
265
+ applicationId=application_id,
266
+ states=list(self.JOB_INTERMEDIATE_STATES),
267
+ PaginationConfig={
268
+ "PageSize": results_per_response,
269
+ },
315
270
  )
316
- return get_state(response=response, keys=keys)
271
+ count = 0
272
+ for r in iterator:
273
+ job_ids = [jr["id"] for jr in r["jobRuns"]]
274
+ count += len(job_ids)
275
+ if len(job_ids) > 0:
276
+ self.log.info(
277
+ "Cancelling %s pending job(s) for the application %s so that it can be stopped",
278
+ len(job_ids),
279
+ application_id,
280
+ )
281
+ for job_id in job_ids:
282
+ self.conn.cancel_job_run(applicationId=application_id, jobRunId=job_id)
283
+ if count > 0:
284
+ self.log.info("now waiting for the %s cancelled job(s) to terminate", count)
285
+ self.get_waiter("no_job_running").wait(
286
+ applicationId=application_id,
287
+ states=list(self.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})),
288
+ WaiterConfig=waiter_config,
289
+ )
317
290
 
318
291
 
319
292
  class EmrContainerHook(AwsBaseHook):
@@ -483,7 +456,6 @@ class EmrContainerHook(AwsBaseHook):
483
456
  def poll_query_status(
484
457
  self,
485
458
  job_id: str,
486
- max_tries: int | None = None,
487
459
  poll_interval: int = 30,
488
460
  max_polling_attempts: int | None = None,
489
461
  ) -> str | None:
@@ -492,22 +464,9 @@ class EmrContainerHook(AwsBaseHook):
492
464
  Returns one of the final states.
493
465
 
494
466
  :param job_id: The ID of the job run request.
495
- :param max_tries: Deprecated - Use max_polling_attempts instead
496
467
  :param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
497
468
  :param max_polling_attempts: Number of times to poll for query state before function exits
498
469
  """
499
- if max_tries:
500
- warnings.warn(
501
- f"Method `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
502
- "in a future release. Please use method `max_polling_attempts` instead.",
503
- DeprecationWarning,
504
- stacklevel=2,
505
- )
506
- if max_polling_attempts and max_polling_attempts != max_tries:
507
- raise Exception("max_polling_attempts must be the same value as max_tries")
508
- else:
509
- max_polling_attempts = max_tries
510
-
511
470
  try_number = 1
512
471
  final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
513
472
 
@@ -25,6 +25,14 @@ from typing import Generator
25
25
 
26
26
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
27
27
 
28
+ # Guidance received from the AWS team regarding the correct way to check for the end of a stream is that the
29
+ # value of the nextForwardToken is the same in subsequent calls.
30
+ # The issue with this approach is, it can take a huge amount of time (e.g. 20 seconds) to retrieve logs using
31
+ # this approach. As an intermediate solution, we decided to stop fetching logs if 3 consecutive responses
32
+ # are empty.
33
+ # See PR https://github.com/apache/airflow/pull/20814
34
+ NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD = 3
35
+
28
36
 
29
37
  class AwsLogsHook(AwsBaseHook):
30
38
  """
@@ -69,14 +77,15 @@ class AwsLogsHook(AwsBaseHook):
69
77
  | 'message' (str): The log event data.
70
78
  | 'ingestionTime' (int): The time in milliseconds the event was ingested.
71
79
  """
80
+ num_consecutive_empty_response = 0
72
81
  next_token = None
73
82
  while True:
74
83
  if next_token is not None:
75
- token_arg: dict[str, str] | None = {"nextToken": next_token}
84
+ token_arg: dict[str, str] = {"nextToken": next_token}
76
85
  else:
77
86
  token_arg = {}
78
87
 
79
- response = self.get_conn().get_log_events(
88
+ response = self.conn.get_log_events(
80
89
  logGroupName=log_group,
81
90
  logStreamName=log_stream_name,
82
91
  startTime=start_time,
@@ -96,7 +105,16 @@ class AwsLogsHook(AwsBaseHook):
96
105
 
97
106
  yield from events
98
107
 
99
- if next_token != response["nextForwardToken"]:
100
- next_token = response["nextForwardToken"]
108
+ if not event_count:
109
+ num_consecutive_empty_response += 1
110
+ if num_consecutive_empty_response >= NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD:
111
+ # Exit if there are more than NUM_CONSECUTIVE_EMPTY_RESPONSE_EXIT_THRESHOLD consecutive
112
+ # empty responses
113
+ return
114
+ elif next_token != response["nextForwardToken"]:
115
+ num_consecutive_empty_response = 0
101
116
  else:
117
+ # Exit if the value of nextForwardToken is same in subsequent calls
102
118
  return
119
+
120
+ next_token = response["nextForwardToken"]