apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +1 -1
  3. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  4. airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
  5. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
  7. airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
  8. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
  9. airflow/providers/amazon/aws/hooks/athena.py +15 -2
  10. airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
  11. airflow/providers/amazon/aws/hooks/logs.py +85 -1
  12. airflow/providers/amazon/aws/hooks/neptune.py +85 -0
  13. airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
  14. airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
  15. airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
  16. airflow/providers/amazon/aws/hooks/s3.py +4 -6
  17. airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
  18. airflow/providers/amazon/aws/operators/eks.py +8 -6
  19. airflow/providers/amazon/aws/operators/neptune.py +218 -0
  20. airflow/providers/amazon/aws/operators/sagemaker.py +74 -15
  21. airflow/providers/amazon/aws/sensors/batch.py +2 -2
  22. airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
  23. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  24. airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
  25. airflow/providers/amazon/aws/triggers/neptune.py +115 -0
  26. airflow/providers/amazon/aws/triggers/rds.py +9 -7
  27. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
  28. airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
  29. airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
  30. airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
  31. airflow/providers/amazon/aws/utils/mixins.py +5 -1
  32. airflow/providers/amazon/aws/waiters/neptune.json +85 -0
  33. airflow/providers/amazon/get_provider_info.py +22 -2
  34. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/METADATA +6 -6
  35. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/RECORD +37 -33
  36. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/WHEEL +0 -0
  37. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py
@@ -27,7 +27,7 @@ import packaging.version
 
 __all__ = ["__version__"]
 
-__version__ = "8.17.0"
+__version__ = "8.18.0"
 
 try:
     from airflow import __version__ as airflow_version
airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py
@@ -80,7 +80,7 @@ def update_schema(args):
 
 
 def _get_client():
-    """Returns Amazon Verified Permissions client."""
+    """Return Amazon Verified Permissions client."""
    region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
    return boto3.client("verifiedpermissions", region_name=region_name)
 
airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -71,7 +71,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView):
     @expose("/login_callback", methods=("GET", "POST"))
     def login_callback(self):
         """
-        Callback where the user is redirected to after successful login.
+        Redirect the user to this callback after successful login.
 
         CSRF protection needs to be disabled otherwise the callback won't work.
         """
airflow/providers/amazon/aws/executors/ecs/Dockerfile
@@ -37,7 +37,7 @@ COPY <<"EOF" /install_dags_entrypoint.sh
 #!/bin/bash
 
 echo "Downloading DAGs from S3 bucket"
-aws s3 sync "$S3_URL" "$CONTAINER_DAG_PATH"
+aws s3 sync "$S3_URI" "$CONTAINER_DAG_PATH"
 
 /home/airflow/entrypoint.sh "$@"
 EOF
@@ -98,8 +98,8 @@ ENV CONTAINER_DAG_PATH=$container_dag_path
 
 
 # Use these arguments to load DAGs onto the container from S3
-ARG s3_url
-ENV S3_URL=$s3_url
+ARG s3_uri
+ENV S3_URI=$s3_uri
 # If using S3 bucket as source of DAGs, uncommenting the next ENTRYPOINT command will overwrite this one.
 ENTRYPOINT ["/usr/bin/dumb-init", "--", "/home/airflow/entrypoint.sh"]
 
airflow/providers/amazon/aws/executors/ecs/boto_schema.py
@@ -61,7 +61,7 @@ class BotoTaskSchema(Schema):
 
     @post_load
     def make_task(self, data, **kwargs):
-        """Overwrites marshmallow load() to return an instance of EcsExecutorTask instead of a dictionary."""
+        """Overwrite marshmallow load() to return an EcsExecutorTask instance instead of a dictionary."""
         # Imported here to avoid circular import.
         from airflow.providers.amazon.aws.executors.ecs.utils import EcsExecutorTask
 
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -42,9 +42,13 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
     EcsQueuedTask,
     EcsTaskCollection,
 )
-from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils import timezone
+from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
@@ -108,7 +112,7 @@ class AwsEcsExecutor(BaseExecutor):
         self.run_task_kwargs = self._load_run_kwargs()
 
     def start(self):
-        """This is called by the scheduler when the Executor is being run for the first time."""
+        """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
@@ -213,7 +217,7 @@ class AwsEcsExecutor(BaseExecutor):
             self.log.exception("Failed to sync %s", self.__class__.__name__)
 
     def sync_running_tasks(self):
-        """Checks and update state on all running tasks."""
+        """Check and update state on all running tasks."""
         all_task_arns = self.active_workers.get_all_arns()
         if not all_task_arns:
             self.log.debug("No active Airflow tasks, skipping sync.")
@@ -300,7 +304,14 @@ class AwsEcsExecutor(BaseExecutor):
             )
             self.active_workers.increment_failure_count(task_key)
             self.pending_tasks.appendleft(
-                EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1)
+                EcsQueuedTask(
+                    task_key,
+                    task_cmd,
+                    queue,
+                    exec_info,
+                    failure_count + 1,
+                    timezone.utcnow() + calculate_next_attempt_delay(failure_count),
+                )
             )
         else:
             self.log.error(
@@ -313,7 +324,7 @@ class AwsEcsExecutor(BaseExecutor):
 
     def attempt_task_runs(self):
         """
-        Takes tasks from the pending_tasks queue, and attempts to find an instance to run it on.
+        Take tasks from the pending_tasks queue, and attempts to find an instance to run it on.
 
         If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
         there are no EC2 instances available, no task is placed and this function will be
@@ -331,6 +342,8 @@ class AwsEcsExecutor(BaseExecutor):
             exec_config = ecs_task.executor_config
             attempt_number = ecs_task.attempt_number
             _failure_reasons = []
+            if timezone.utcnow() < ecs_task.next_attempt_time:
+                continue
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
@@ -361,6 +374,9 @@ class AwsEcsExecutor(BaseExecutor):
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
                 if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
+                    ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                        attempt_number
+                    )
                     self.pending_tasks.appendleft(ecs_task)
                 else:
                     self.log.error(
@@ -393,8 +409,8 @@ class AwsEcsExecutor(BaseExecutor):
         The command and executor config will be placed in the container-override
         section of the JSON request before calling Boto3's "run_task" function.
         """
-        run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
-        boto_run_task = self.ecs.run_task(**run_task_api)
+        run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config)
+        boto_run_task = self.ecs.run_task(**run_task_kwargs)
         run_task_response = BotoRunTaskSchema().load(boto_run_task)
         return run_task_response
 
@@ -402,30 +418,32 @@ class AwsEcsExecutor(BaseExecutor):
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
     ) -> dict:
         """
-        Overrides the Airflow command to update the container overrides so kwargs are specific to this task.
+        Update the Airflow command by modifying container overrides for task-specific kwargs.
 
         One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
         """
-        run_task_api = deepcopy(self.run_task_kwargs)
-        container_override = self.get_container(run_task_api["overrides"]["containerOverrides"])
+        run_task_kwargs = deepcopy(self.run_task_kwargs)
+        run_task_kwargs = merge_dicts(run_task_kwargs, exec_config)
+        container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"])
         container_override["command"] = cmd
-        container_override.update(exec_config)
 
         # Inject the env variable to configure logging for containerized execution environment
         if "environment" not in container_override:
             container_override["environment"] = []
         container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"})
 
-        return run_task_api
+        return run_task_kwargs
 
     def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
         """Save the task to be executed in the next sync by inserting the commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
-        self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1))
+        self.pending_tasks.append(
+            EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
+        )
 
     def end(self, heartbeat_interval=10):
-        """Waits for all currently running tasks to end, and doesn't launch any tasks."""
+        """Wait for all currently running tasks to end, and don't launch any tasks."""
         try:
             while True:
                 self.sync()
@@ -465,8 +483,13 @@ class AwsEcsExecutor(BaseExecutor):
         return ecs_executor_run_task_kwargs
 
     def get_container(self, container_list):
-        """Searches task list for core Airflow container."""
+        """Search task list for core Airflow container."""
         for container in container_list:
-            if container["name"] == self.container_name:
-                return container
+            try:
+                if container["name"] == self.container_name:
+                    return container
+            except KeyError:
+                raise EcsExecutorException(
+                    'container "name" must be provided in "containerOverrides" configuration'
+                )
         raise KeyError(f"No such container found by container name: {self.container_name}")
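
Note: the `_run_task_kwargs` change above replaces the flat `container_override.update(exec_config)` with a recursive `merge_dicts` over the whole run_task payload, so an `executor_config` can now add or override nested keys without clobbering their siblings. A minimal sketch of the difference, using hypothetical values rather than the executor's real defaults:

    # Hypothetical payload; illustrates why the deep merge matters here.
    from airflow.utils.helpers import merge_dicts

    run_task_kwargs = {
        "cluster": "my-cluster",
        "overrides": {"containerOverrides": [{"name": "container", "command": []}]},
    }
    exec_config = {"overrides": {"cpu": "1024"}, "tags": [{"key": "team", "value": "data"}]}

    merged = merge_dicts(run_task_kwargs, exec_config)
    # merged["overrides"] keeps "containerOverrides" and gains "cpu"; a plain
    # dict.update() on the container override could not express that.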
airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -23,6 +23,7 @@ Data classes and utility functions used by the ECS executor.
 
 from __future__ import annotations
 
+import datetime
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -58,6 +59,7 @@ class EcsQueuedTask:
     queue: str
     executor_config: ExecutorConfigType
     attempt_number: int
+    next_attempt_time: datetime.datetime
 
 
 @dataclass
@@ -125,9 +127,9 @@ class EcsExecutorTask:
 
     def get_task_state(self) -> str:
         """
-        This is the primary logic that handles state in an ECS task.
+        Determine the state of an ECS task based on its status and other relevant attributes.
 
-        It will determine if a status is:
+        It can return one of the following statuses:
         QUEUED - Task is being provisioned.
         RUNNING - Task is launched on ECS.
         REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
@@ -171,7 +173,7 @@ class EcsTaskCollection:
         exec_config: ExecutorConfigType,
         attempt_number: int,
     ):
-        """Adds a task to the collection."""
+        """Add a task to the collection."""
         arn = task.task_arn
         self.tasks[arn] = task
         self.key_to_arn[airflow_task_key] = arn
@@ -180,7 +182,7 @@ class EcsTaskCollection:
         self.key_to_failure_counts[airflow_task_key] = attempt_number
 
     def update_task(self, task: EcsExecutorTask):
-        """Updates the state of the given task based on task ARN."""
+        """Update the state of the given task based on task ARN."""
         self.tasks[task.task_arn] = task
 
     def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
@@ -193,7 +195,7 @@ class EcsTaskCollection:
         return self.tasks[arn]
 
     def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
-        """Deletes task from collection based off of Airflow Task Instance Key."""
+        """Delete task from collection based off of Airflow Task Instance Key."""
         arn = self.key_to_arn[task_key]
         task = self.tasks[arn]
         del self.key_to_arn[task_key]
@@ -225,11 +227,11 @@ class EcsTaskCollection:
         return self.key_to_task_info[task_key]
 
     def __getitem__(self, value):
-        """Gets a task by AWS ARN."""
+        """Get a task by AWS ARN."""
         return self.task_by_arn(value)
 
     def __len__(self):
-        """Determines the number of tasks in collection."""
+        """Determine the number of tasks in collection."""
         return len(self.tasks)
 
 
airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
@@ -25,6 +25,21 @@ from airflow.utils import timezone
 log = logging.getLogger(__name__)
 
 
+def calculate_next_attempt_delay(
+    attempt_number: int,
+    max_delay: int = 60 * 2,
+    exponent_base: int = 4,
+) -> timedelta:
+    """
+    Calculate the exponential backoff (in seconds) until the next attempt.
+
+    :param attempt_number: Number of attempts since last success.
+    :param max_delay: Maximum delay in seconds between retries. Default 120.
+    :param exponent_base: Exponent base to calculate delay. Default 4.
+    """
+    return timedelta(seconds=min((exponent_base**attempt_number), max_delay))
+
+
 def exponential_backoff_retry(
     last_attempt_time: datetime,
     attempts_since_last_successful: int,
@@ -34,7 +49,7 @@ def exponential_backoff_retry(
     exponent_base: int = 4,
 ) -> None:
     """
-    Retries a callable function with exponential backoff between attempts if it raises an exception.
+    Retry a callable function with exponential backoff between attempts if it raises an exception.
 
     :param last_attempt_time: Timestamp of last attempt call.
     :param attempts_since_last_successful: Number of attempts since last success.
@@ -47,8 +62,10 @@ def exponential_backoff_retry(
         log.error("Max attempts reached. Exiting.")
         return
 
-    delay = min((exponent_base**attempts_since_last_successful), max_delay)
-    next_retry_time = last_attempt_time + timedelta(seconds=delay)
+    next_retry_time = last_attempt_time + calculate_next_attempt_delay(
+        attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base
+    )
+
     current_time = timezone.utcnow()
 
     if current_time >= next_retry_time:
@@ -56,5 +73,7 @@ def exponential_backoff_retry(
             callable_function()
         except Exception:
             log.exception("Error calling %r", callable_function.__name__)
-            next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
+            next_delay = calculate_next_attempt_delay(
+                attempts_since_last_successful + 1, max_delay, exponent_base
+            )
             log.info("Waiting for %s seconds before retrying.", next_delay)
airflow/providers/amazon/aws/hooks/athena.py
@@ -25,7 +25,7 @@ This module contains AWS Athena hook.
 from __future__ import annotations
 
 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Collection
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -34,6 +34,19 @@ from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 if TYPE_CHECKING:
     from botocore.paginate import PageIterator
 
+MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"
+
+
+def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
+    result = ""
+    for key, value in params.items():
+        if key == "QueryString":
+            value = (
+                MULTI_LINE_QUERY_LOG_PREFIX + str(value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
+            )
+        result += f"\t{key}: {value}\n"
+    return result.rstrip()
+
 
 class AthenaHook(AwsBaseHook):
     """Interact with Amazon Athena.
@@ -115,7 +128,7 @@ class AthenaHook(AwsBaseHook):
         if client_request_token:
             params["ClientRequestToken"] = client_request_token
         if self.log_query:
-            self.log.info("Running Query with params: %s", params)
+            self.log.info("Running Query with params:\n%s", query_params_to_string(params))
         response = self.get_conn().start_query_execution(**params)
         query_execution_id = response["QueryExecutionId"]
         self.log.info("Query execution id: %s", query_execution_id)
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -29,7 +29,6 @@ import inspect
 import json
 import logging
 import os
-import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
 from pathlib import Path
@@ -44,6 +43,7 @@ import tenacity
 from botocore.config import Config
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
+from deprecated import deprecated
 from slugify import slugify
 
 from airflow.configuration import conf
@@ -314,7 +314,7 @@ class BaseSessionFactory(LoggingMixin):
             idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
             self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
             from requests.adapters import HTTPAdapter
-            from requests.packages.urllib3.util.retry import Retry
+            from urllib3.util.retry import Retry
 
             retry_strategy = Retry(**idp_request_retry_kwargs)
             adapter = HTTPAdapter(max_retries=retry_strategy)
@@ -1020,6 +1020,13 @@ except ImportError:
     pass
 
 
+@deprecated(
+    reason=(
+        "airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class BaseAsyncSessionFactory(BaseSessionFactory):
     """
     Base AWS Session Factory class to handle aiobotocore session creation.
@@ -1029,12 +1036,6 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
     """
 
     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         super().__init__(*args, **kwargs)
 
     async def get_role_credentials(self) -> dict:
@@ -1113,6 +1114,13 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
         return self._get_session_with_assume_role()
 
 
+@deprecated(
+    reason=(
+        "airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AwsBaseAsyncHook(AwsBaseHook):
     """Interacts with AWS using aiobotocore asynchronously.
 
@@ -1129,12 +1137,6 @@ class AwsBaseAsyncHook(AwsBaseHook):
     """
 
     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         super().__init__(*args, **kwargs)
 
     def get_async_session(self) -> AioSession:
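
Note: the removed `warnings.warn(...)` calls are replaced by the `@deprecated` decorator from the `Deprecated` package, which wraps the class so the same `AirflowProviderDeprecationWarning` is still emitted when the class is instantiated. A minimal sketch of that behaviour with a stand-in class (not the provider's real hook):

    import warnings

    from deprecated import deprecated

    @deprecated(reason="will be removed in future", category=DeprecationWarning)
    class OldHook:
        pass

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        OldHook()  # the warning fires at instantiation, as it did from __init__ before
    assert caught and issubclass(caught[0].category, DeprecationWarning)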
airflow/providers/amazon/aws/hooks/logs.py
@@ -17,8 +17,11 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import Generator
+from typing import Any, AsyncGenerator, Generator
+
+from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -151,3 +154,84 @@ class AwsLogsHook(AwsBaseHook):
                 num_consecutive_empty_response = 0
 
             continuation_token.value = response["nextForwardToken"]
+
+    async def describe_log_streams_async(
+        self, log_group: str, stream_prefix: str, order_by: str, count: int
+    ) -> dict[str, Any] | None:
+        """Async function to get the list of log streams for the specified log group.
+
+        You can list all the log streams or filter the results by prefix. You can also control
+        how the results are ordered.
+
+        :param log_group: The name of the log group.
+        :param stream_prefix: The prefix to match.
+        :param order_by: If the value is LogStreamName, the results are ordered by log stream name.
+            If the value is LastEventTime, the results are ordered by the event time. The default value is LogStreamName.
+        :param count: The maximum number of items returned
+        """
+        async with self.async_conn as client:
+            try:
+                response: dict[str, Any] = await client.describe_log_streams(
+                    logGroupName=log_group,
+                    logStreamNamePrefix=stream_prefix,
+                    orderBy=order_by,
+                    limit=count,
+                )
+                return response
+            except ClientError as error:
+                # On the very first training job run on an account, there's no log group until
+                # the container starts logging, so ignore any errors thrown about that
+                if error.response["Error"]["Code"] == "ResourceNotFoundException":
+                    return None
+                raise error
+
+    async def get_log_events_async(
+        self,
+        log_group: str,
+        log_stream_name: str,
+        start_time: int = 0,
+        skip: int = 0,
+        start_from_head: bool = True,
+    ) -> AsyncGenerator[Any, dict[str, Any]]:
+        """Yield all the available items in a single log stream.
+
+        :param log_group: The name of the log group.
+        :param log_stream_name: The name of the specific stream.
+        :param start_time: The time stamp value to start reading the logs from (default: 0).
+        :param skip: The number of log entries to skip at the start (default: 0).
+            This is for when there are multiple entries at the same timestamp.
+        :param start_from_head: whether to start from the beginning (True) of the log or
+            at the end of the log (False).
+        """
+        next_token = None
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+
+            async with self.async_conn as client:
+                response = await client.get_log_events(
+                    logGroupName=log_group,
+                    logStreamName=log_stream_name,
+                    startTime=start_time,
+                    startFromHead=start_from_head,
+                    **token_arg,
+                )
+
+            events = response["events"]
+            event_count = len(events)
+
+            if event_count > skip:
+                events = events[skip:]
+                skip = 0
+            else:
+                skip -= event_count
+                events = []
+
+            for event in events:
+                await asyncio.sleep(1)
+                yield event
+
+            if next_token != response["nextForwardToken"]:
+                next_token = response["nextForwardToken"]
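
Note: a rough usage sketch of the new async helpers follows; the log group, stream prefix and connection id are hypothetical, and `async_conn` requires the optional aiobotocore dependency:

    import asyncio

    from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

    async def tail_first_stream():
        hook = AwsLogsHook(aws_conn_id="aws_default")
        streams = await hook.describe_log_streams_async(
            log_group="/aws/sagemaker/TrainingJobs",
            stream_prefix="my-job",
            order_by="LogStreamName",
            count=1,
        )
        if not streams or not streams.get("logStreams"):
            return
        stream = streams["logStreams"][0]["logStreamName"]
        async for event in hook.get_log_events_async("/aws/sagemaker/TrainingJobs", stream):
            print(event["timestamp"], event["message"])
            break  # the generator polls forever; stop after the first event in this sketch

    asyncio.run(tail_first_stream())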
airflow/providers/amazon/aws/hooks/neptune.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class NeptuneHook(AwsBaseHook):
+    """
+    Interact with Amazon Neptune.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    AVAILABLE_STATES = ["available"]
+    STOPPED_STATES = ["stopped"]
+
+    def __init__(self, *args, **kwargs):
+        kwargs["client_type"] = "neptune"
+        super().__init__(*args, **kwargs)
+
+    def wait_for_cluster_availability(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+        """
+        Wait for Neptune cluster to start.
+
+        :param cluster_id: The ID of the cluster to wait for.
+        :param delay: Time in seconds to delay between polls.
+        :param max_attempts: Maximum number of attempts to poll for completion.
+        :return: The status of the cluster.
+        """
+        self.get_waiter("cluster_available").wait(
+            DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+        )
+
+        status = self.get_cluster_status(cluster_id)
+        self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+        return status
+
+    def wait_for_cluster_stopped(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
+        """
+        Wait for Neptune cluster to stop.
+
+        :param cluster_id: The ID of the cluster to wait for.
+        :param delay: Time in seconds to delay between polls.
+        :param max_attempts: Maximum number of attempts to poll for completion.
+        :return: The status of the cluster.
+        """
+        self.get_waiter("cluster_stopped").wait(
+            DBClusterIdentifier=cluster_id, WaiterConfig={"Delay": delay, "MaxAttempts": max_attempts}
+        )
+
+        status = self.get_cluster_status(cluster_id)
+        self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
+
+        return status
+
+    def get_cluster_status(self, cluster_id: str) -> str:
+        """
+        Get the status of a Neptune cluster.
+
+        :param cluster_id: The ID of the cluster to get the status of.
+        :return: The status of the cluster.
+        """
+        return self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)["DBClusters"][0]["Status"]
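
Note: `NeptuneHook` is the plumbing behind the new Neptune operators and triggers in this release; the `cluster_available` and `cluster_stopped` waiters are presumably the ones defined in the new `waiters/neptune.json`. A short usage sketch with a hypothetical cluster identifier:

    from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook

    hook = NeptuneHook(aws_conn_id="aws_default")
    if hook.get_cluster_status("my-neptune-cluster") not in NeptuneHook.AVAILABLE_STATES:
        # blocks while the boto waiter polls, up to delay * max_attempts seconds
        hook.wait_for_cluster_availability("my-neptune-cluster", delay=30, max_attempts=20)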
airflow/providers/amazon/aws/hooks/quicksight.py
@@ -18,10 +18,10 @@
 from __future__ import annotations
 
 import time
-import warnings
 from functools import cached_property
 
 from botocore.exceptions import ClientError
+from deprecated import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -172,14 +172,15 @@ class QuickSightHook(AwsBaseHook):
         return status
 
     @cached_property
-    def sts_hook(self):
-        warnings.warn(
-            f"`{type(self).__name__}.sts_hook` property is deprecated and will be removed in the future. "
+    @deprecated(
+        reason=(
+            "`QuickSightHook.sts_hook` property is deprecated and will be removed in the future. "
             "This property used for obtain AWS Account ID, "
-            f"please consider to use `{type(self).__name__}.account_id` instead",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
+            "please consider to use `QuickSightHook.account_id` instead"
+        ),
+        category=AirflowProviderDeprecationWarning,
+    )
+    def sts_hook(self):
         from airflow.providers.amazon.aws.hooks.sts import StsHook
 
         return StsHook(aws_conn_id=self.aws_conn_id)