apache-airflow-providers-amazon 8.17.0rc2__py3-none-any.whl → 8.18.0rc2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (62)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
  3. airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
  4. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
  5. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  6. airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
  7. airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
  8. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
  9. airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
  10. airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
  11. airflow/providers/amazon/aws/hooks/athena.py +15 -2
  12. airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
  13. airflow/providers/amazon/aws/hooks/emr.py +6 -0
  14. airflow/providers/amazon/aws/hooks/logs.py +85 -1
  15. airflow/providers/amazon/aws/hooks/neptune.py +85 -0
  16. airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
  17. airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
  18. airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
  19. airflow/providers/amazon/aws/hooks/s3.py +4 -6
  20. airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
  21. airflow/providers/amazon/aws/links/emr.py +122 -2
  22. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
  23. airflow/providers/amazon/aws/operators/athena.py +4 -1
  24. airflow/providers/amazon/aws/operators/batch.py +5 -6
  25. airflow/providers/amazon/aws/operators/ecs.py +6 -2
  26. airflow/providers/amazon/aws/operators/eks.py +31 -26
  27. airflow/providers/amazon/aws/operators/emr.py +192 -26
  28. airflow/providers/amazon/aws/operators/glue.py +5 -2
  29. airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
  30. airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
  31. airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
  32. airflow/providers/amazon/aws/operators/neptune.py +218 -0
  33. airflow/providers/amazon/aws/operators/rds.py +21 -12
  34. airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
  35. airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
  36. airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
  37. airflow/providers/amazon/aws/operators/step_function.py +4 -1
  38. airflow/providers/amazon/aws/sensors/batch.py +2 -2
  39. airflow/providers/amazon/aws/sensors/ec2.py +4 -2
  40. airflow/providers/amazon/aws/sensors/emr.py +13 -6
  41. airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
  42. airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
  43. airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
  44. airflow/providers/amazon/aws/sensors/s3.py +3 -0
  45. airflow/providers/amazon/aws/sensors/sqs.py +4 -1
  46. airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
  47. airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
  48. airflow/providers/amazon/aws/triggers/neptune.py +115 -0
  49. airflow/providers/amazon/aws/triggers/rds.py +9 -7
  50. airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
  51. airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
  52. airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
  53. airflow/providers/amazon/aws/utils/__init__.py +10 -0
  54. airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
  55. airflow/providers/amazon/aws/utils/mixins.py +5 -1
  56. airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
  57. airflow/providers/amazon/aws/waiters/neptune.json +85 -0
  58. airflow/providers/amazon/get_provider_info.py +26 -2
  59. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
  60. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
  61. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
  62. {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
@@ -27,7 +27,7 @@ import packaging.version
27
27
 
28
28
  __all__ = ["__version__"]
29
29
 
30
- __version__ = "8.17.0"
30
+ __version__ = "8.18.0"
31
31
 
32
32
  try:
33
33
  from airflow import __version__ as airflow_version
@@ -55,7 +55,7 @@ def init_avp(args):
55
55
  if not is_new_policy_store:
56
56
  print(
57
57
  f"Since an existing policy store with description '{args.policy_store_description}' has been found in Amazon Verified Permissions, "
58
- "the CLI nade no changes to this policy store for security reasons. "
58
+ "the CLI made no changes to this policy store for security reasons. "
59
59
  "Any modification to this policy store must be done manually.",
60
60
  )
61
61
  else:
@@ -80,7 +80,7 @@ def update_schema(args):
80
80
 
81
81
 
82
82
  def _get_client():
83
- """Returns Amazon Verified Permissions client."""
83
+ """Return Amazon Verified Permissions client."""
84
84
  region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
85
85
  return boto3.client("verifiedpermissions", region_name=region_name)
86
86
 
@@ -115,7 +115,7 @@ def _create_policy_store(client: BaseClient, args) -> tuple[str | None, bool]:
115
115
  print(f"No policy store with description '{args.policy_store_description}' found, creating one.")
116
116
  if args.dry_run:
117
117
  print(
118
- "Dry run, not creating the policy store with description '{args.policy_store_description}'."
118
+ f"Dry run, not creating the policy store with description '{args.policy_store_description}'."
119
119
  )
120
120
  return None, True
121
121
 
@@ -35,6 +35,14 @@ ARG_DRY_RUN = Arg(
35
35
  action="store_true",
36
36
  )
37
37
 
38
+ # AWS IAM Identity Center
39
+ ARG_INSTANCE_NAME = Arg(("--instance-name",), help="Instance name in Identity Center", default="Airflow")
40
+
41
+ ARG_APPLICATION_NAME = Arg(
42
+ ("--application-name",), help="Application name in Identity Center", default="Airflow"
43
+ )
44
+
45
+
38
46
  # Amazon Verified Permissions
39
47
  ARG_POLICY_STORE_DESCRIPTION = Arg(
40
48
  ("--policy-store-description",), help="Policy store description", default="Airflow"
@@ -47,6 +55,12 @@ ARG_POLICY_STORE_ID = Arg(("--policy-store-id",), help="Policy store ID")
47
55
  ################
48
56
 
49
57
  AWS_AUTH_MANAGER_COMMANDS = (
58
+ ActionCommand(
59
+ name="init-identity-center",
60
+ help="Initialize AWS IAM identity Center resources to be used by AWS manager",
61
+ func=lazy_load_command("airflow.providers.amazon.aws.auth_manager.cli.idc_commands.init_idc"),
62
+ args=(ARG_INSTANCE_NAME, ARG_APPLICATION_NAME, ARG_DRY_RUN, ARG_VERBOSE),
63
+ ),
50
64
  ActionCommand(
51
65
  name="init-avp",
52
66
  help="Initialize Amazon Verified resources to be used by AWS manager",
@@ -0,0 +1,148 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+ """User sub-commands."""
18
+ from __future__ import annotations
19
+
20
+ import logging
21
+ from typing import TYPE_CHECKING
22
+
23
+ import boto3
24
+ from botocore.exceptions import ClientError
25
+
26
+ from airflow.configuration import conf
27
+ from airflow.exceptions import AirflowOptionalProviderFeatureException
28
+ from airflow.providers.amazon.aws.auth_manager.constants import CONF_REGION_NAME_KEY, CONF_SECTION_NAME
29
+ from airflow.utils import cli as cli_utils
30
+
31
+ try:
32
+ from airflow.utils.providers_configuration_loader import providers_configuration_loaded
33
+ except ImportError:
34
+ raise AirflowOptionalProviderFeatureException(
35
+ "Failed to import avp_commands. This feature is only available in Airflow "
36
+ "version >= 2.8.0 where Auth Managers are introduced."
37
+ )
38
+
39
+ if TYPE_CHECKING:
40
+ from botocore.client import BaseClient
41
+
42
+ log = logging.getLogger(__name__)
43
+
44
+
45
+ @cli_utils.action_cli
46
+ @providers_configuration_loaded
47
+ def init_idc(args):
48
+ """Initialize AWS IAM Identity Center resources."""
49
+ client = _get_client()
50
+
51
+ # Create the instance if needed
52
+ instance_arn = _create_instance(client, args)
53
+
54
+ # Create the application if needed
55
+ _create_application(client, instance_arn, args)
56
+
57
+ if not args.dry_run:
58
+ print("AWS IAM Identity Center resources created successfully.")
59
+
60
+
61
+ def _get_client():
62
+ """Return AWS IAM Identity Center client."""
63
+ region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
64
+ return boto3.client("sso-admin", region_name=region_name)
65
+
66
+
67
+ def _create_instance(client: BaseClient, args) -> str | None:
68
+ """Create if needed AWS IAM Identity Center instance."""
69
+ instances = client.list_instances()
70
+
71
+ if args.verbose:
72
+ log.debug("Instances found: %s", instances)
73
+
74
+ if len(instances["Instances"]) > 0:
75
+ print(
76
+ f"There is already an instance configured in AWS IAM Identity Center: '{instances['Instances'][0]['InstanceArn']}'. "
77
+ "No need to create a new one."
78
+ )
79
+ return instances["Instances"][0]["InstanceArn"]
80
+ else:
81
+ print("No instance configured in AWS IAM Identity Center, creating one.")
82
+ if args.dry_run:
83
+ print("Dry run, not creating the instance.")
84
+ return None
85
+
86
+ response = client.create_instance(Name=args.instance_name)
87
+ if args.verbose:
88
+ log.debug("Response from create_instance: %s", response)
89
+
90
+ print(f"Instance created: '{response['InstanceArn']}'")
91
+
92
+ return response["InstanceArn"]
93
+
94
+
95
+ def _create_application(client: BaseClient, instance_arn: str | None, args) -> str | None:
96
+ """Create if needed AWS IAM identity Center application."""
97
+ paginator = client.get_paginator("list_applications")
98
+ pages = paginator.paginate(InstanceArn=instance_arn or "")
99
+ applications = [application for page in pages for application in page["Applications"]]
100
+ existing_applications = [
101
+ application for application in applications if application["Name"] == args.application_name
102
+ ]
103
+
104
+ if args.verbose:
105
+ log.debug("Applications found: %s", applications)
106
+ log.debug("Existing applications found: %s", existing_applications)
107
+
108
+ if len(existing_applications) > 0:
109
+ print(
110
+ f"There is already an application named '{args.application_name}' in AWS IAM Identity Center: '{existing_applications[0]['ApplicationArn']}'. "
111
+ "Using this application."
112
+ )
113
+ return existing_applications[0]["ApplicationArn"]
114
+ else:
115
+ print(f"No application named {args.application_name} found, creating one.")
116
+ if args.dry_run:
117
+ print("Dry run, not creating the application.")
118
+ return None
119
+
120
+ try:
121
+ response = client.create_application(
122
+ ApplicationProviderArn="arn:aws:sso::aws:applicationProvider/custom-saml",
123
+ Description="Application automatically created through the Airflow CLI. This application is used to access Airflow environment.",
124
+ InstanceArn=instance_arn,
125
+ Name=args.application_name,
126
+ PortalOptions={
127
+ "SignInOptions": {
128
+ "Origin": "IDENTITY_CENTER",
129
+ },
130
+ "Visibility": "ENABLED",
131
+ },
132
+ Status="ENABLED",
133
+ )
134
+ if args.verbose:
135
+ log.debug("Response from create_application: %s", response)
136
+ except ClientError as e:
137
+ # This is needed because as of today, the create_application in AWS Identity Center does not support SAML application
138
+ # Remove this part when it is supported
139
+ if "is not supported for this action" in e.response["Error"]["Message"]:
140
+ print(
141
+ "Creation of SAML applications is only supported in AWS console today. "
142
+ "Please create the application through the console."
143
+ )
144
+ raise
145
+
146
+ print(f"Application created: '{response['ApplicationArn']}'")
147
+
148
+ return response["ApplicationArn"]
@@ -71,7 +71,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView):
71
71
  @expose("/login_callback", methods=("GET", "POST"))
72
72
  def login_callback(self):
73
73
  """
74
- Callback where the user is redirected to after successful login.
74
+ Redirect the user to this callback after successful login.
75
75
 
76
76
  CSRF protection needs to be disabled otherwise the callback won't work.
77
77
  """
@@ -37,7 +37,7 @@ COPY <<"EOF" /install_dags_entrypoint.sh
37
37
  #!/bin/bash
38
38
 
39
39
  echo "Downloading DAGs from S3 bucket"
40
- aws s3 sync "$S3_URL" "$CONTAINER_DAG_PATH"
40
+ aws s3 sync "$S3_URI" "$CONTAINER_DAG_PATH"
41
41
 
42
42
  /home/airflow/entrypoint.sh "$@"
43
43
  EOF
@@ -98,8 +98,8 @@ ENV CONTAINER_DAG_PATH=$container_dag_path
98
98
 
99
99
 
100
100
  # Use these arguments to load DAGs onto the container from S3
101
- ARG s3_url
102
- ENV S3_URL=$s3_url
101
+ ARG s3_uri
102
+ ENV S3_URI=$s3_uri
103
103
  # If using S3 bucket as source of DAGs, uncommenting the next ENTRYPOINT command will overwrite this one.
104
104
  ENTRYPOINT ["/usr/bin/dumb-init", "--", "/home/airflow/entrypoint.sh"]
105
105
 
@@ -61,7 +61,7 @@ class BotoTaskSchema(Schema):
61
61
 
62
62
  @post_load
63
63
  def make_task(self, data, **kwargs):
64
- """Overwrites marshmallow load() to return an instance of EcsExecutorTask instead of a dictionary."""
64
+ """Overwrite marshmallow load() to return an EcsExecutorTask instance instead of a dictionary."""
65
65
  # Imported here to avoid circular import.
66
66
  from airflow.providers.amazon.aws.executors.ecs.utils import EcsExecutorTask
67
67
 
@@ -42,9 +42,13 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
42
42
  EcsQueuedTask,
43
43
  EcsTaskCollection,
44
44
  )
45
- from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
45
+ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
46
+ calculate_next_attempt_delay,
47
+ exponential_backoff_retry,
48
+ )
46
49
  from airflow.providers.amazon.aws.hooks.ecs import EcsHook
47
50
  from airflow.utils import timezone
51
+ from airflow.utils.helpers import merge_dicts
48
52
  from airflow.utils.state import State
49
53
 
50
54
  if TYPE_CHECKING:
@@ -108,7 +112,7 @@ class AwsEcsExecutor(BaseExecutor):
108
112
  self.run_task_kwargs = self._load_run_kwargs()
109
113
 
110
114
  def start(self):
111
- """This is called by the scheduler when the Executor is being run for the first time."""
115
+ """Call this when the Executor is run for the first time by the scheduler."""
112
116
  check_health = conf.getboolean(
113
117
  CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
114
118
  )
@@ -213,7 +217,7 @@ class AwsEcsExecutor(BaseExecutor):
213
217
  self.log.exception("Failed to sync %s", self.__class__.__name__)
214
218
 
215
219
  def sync_running_tasks(self):
216
- """Checks and update state on all running tasks."""
220
+ """Check and update state on all running tasks."""
217
221
  all_task_arns = self.active_workers.get_all_arns()
218
222
  if not all_task_arns:
219
223
  self.log.debug("No active Airflow tasks, skipping sync.")
@@ -300,7 +304,14 @@ class AwsEcsExecutor(BaseExecutor):
300
304
  )
301
305
  self.active_workers.increment_failure_count(task_key)
302
306
  self.pending_tasks.appendleft(
303
- EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1)
307
+ EcsQueuedTask(
308
+ task_key,
309
+ task_cmd,
310
+ queue,
311
+ exec_info,
312
+ failure_count + 1,
313
+ timezone.utcnow() + calculate_next_attempt_delay(failure_count),
314
+ )
304
315
  )
305
316
  else:
306
317
  self.log.error(
@@ -313,7 +324,7 @@ class AwsEcsExecutor(BaseExecutor):
313
324
 
314
325
  def attempt_task_runs(self):
315
326
  """
316
- Takes tasks from the pending_tasks queue, and attempts to find an instance to run it on.
327
+ Take tasks from the pending_tasks queue, and attempts to find an instance to run it on.
317
328
 
318
329
  If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
319
330
  there are no EC2 instances available, no task is placed and this function will be
@@ -331,6 +342,8 @@ class AwsEcsExecutor(BaseExecutor):
331
342
  exec_config = ecs_task.executor_config
332
343
  attempt_number = ecs_task.attempt_number
333
344
  _failure_reasons = []
345
+ if timezone.utcnow() < ecs_task.next_attempt_time:
346
+ continue
334
347
  try:
335
348
  run_task_response = self._run_task(task_key, cmd, queue, exec_config)
336
349
  except NoCredentialsError:
@@ -361,6 +374,9 @@ class AwsEcsExecutor(BaseExecutor):
361
374
  # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
362
375
  if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
363
376
  ecs_task.attempt_number += 1
377
+ ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
378
+ attempt_number
379
+ )
364
380
  self.pending_tasks.appendleft(ecs_task)
365
381
  else:
366
382
  self.log.error(
@@ -393,8 +409,8 @@ class AwsEcsExecutor(BaseExecutor):
393
409
  The command and executor config will be placed in the container-override
394
410
  section of the JSON request before calling Boto3's "run_task" function.
395
411
  """
396
- run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
397
- boto_run_task = self.ecs.run_task(**run_task_api)
412
+ run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config)
413
+ boto_run_task = self.ecs.run_task(**run_task_kwargs)
398
414
  run_task_response = BotoRunTaskSchema().load(boto_run_task)
399
415
  return run_task_response
400
416
 
@@ -402,30 +418,32 @@ class AwsEcsExecutor(BaseExecutor):
402
418
  self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
403
419
  ) -> dict:
404
420
  """
405
- Overrides the Airflow command to update the container overrides so kwargs are specific to this task.
421
+ Update the Airflow command by modifying container overrides for task-specific kwargs.
406
422
 
407
423
  One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
408
424
  """
409
- run_task_api = deepcopy(self.run_task_kwargs)
410
- container_override = self.get_container(run_task_api["overrides"]["containerOverrides"])
425
+ run_task_kwargs = deepcopy(self.run_task_kwargs)
426
+ run_task_kwargs = merge_dicts(run_task_kwargs, exec_config)
427
+ container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"])
411
428
  container_override["command"] = cmd
412
- container_override.update(exec_config)
413
429
 
414
430
  # Inject the env variable to configure logging for containerized execution environment
415
431
  if "environment" not in container_override:
416
432
  container_override["environment"] = []
417
433
  container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"})
418
434
 
419
- return run_task_api
435
+ return run_task_kwargs
420
436
 
421
437
  def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
422
438
  """Save the task to be executed in the next sync by inserting the commands into a queue."""
423
439
  if executor_config and ("name" in executor_config or "command" in executor_config):
424
440
  raise ValueError('Executor Config should never override "name" or "command"')
425
- self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1))
441
+ self.pending_tasks.append(
442
+ EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
443
+ )
426
444
 
427
445
  def end(self, heartbeat_interval=10):
428
- """Waits for all currently running tasks to end, and doesn't launch any tasks."""
446
+ """Wait for all currently running tasks to end, and don't launch any tasks."""
429
447
  try:
430
448
  while True:
431
449
  self.sync()
@@ -465,8 +483,13 @@ class AwsEcsExecutor(BaseExecutor):
465
483
  return ecs_executor_run_task_kwargs
466
484
 
467
485
  def get_container(self, container_list):
468
- """Searches task list for core Airflow container."""
486
+ """Search task list for core Airflow container."""
469
487
  for container in container_list:
470
- if container["name"] == self.container_name:
471
- return container
488
+ try:
489
+ if container["name"] == self.container_name:
490
+ return container
491
+ except KeyError:
492
+ raise EcsExecutorException(
493
+ 'container "name" must be provided in "containerOverrides" configuration'
494
+ )
472
495
  raise KeyError(f"No such container found by container name: {self.container_name}")
@@ -23,6 +23,7 @@ Data classes and utility functions used by the ECS executor.
23
23
 
24
24
  from __future__ import annotations
25
25
 
26
+ import datetime
26
27
  from collections import defaultdict
27
28
  from dataclasses import dataclass
28
29
  from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -58,6 +59,7 @@ class EcsQueuedTask:
58
59
  queue: str
59
60
  executor_config: ExecutorConfigType
60
61
  attempt_number: int
62
+ next_attempt_time: datetime.datetime
61
63
 
62
64
 
63
65
  @dataclass
@@ -125,9 +127,9 @@ class EcsExecutorTask:
125
127
 
126
128
  def get_task_state(self) -> str:
127
129
  """
128
- This is the primary logic that handles state in an ECS task.
130
+ Determine the state of an ECS task based on its status and other relevant attributes.
129
131
 
130
- It will determine if a status is:
132
+ It can return one of the following statuses:
131
133
  QUEUED - Task is being provisioned.
132
134
  RUNNING - Task is launched on ECS.
133
135
  REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
@@ -171,7 +173,7 @@ class EcsTaskCollection:
171
173
  exec_config: ExecutorConfigType,
172
174
  attempt_number: int,
173
175
  ):
174
- """Adds a task to the collection."""
176
+ """Add a task to the collection."""
175
177
  arn = task.task_arn
176
178
  self.tasks[arn] = task
177
179
  self.key_to_arn[airflow_task_key] = arn
@@ -180,7 +182,7 @@ class EcsTaskCollection:
180
182
  self.key_to_failure_counts[airflow_task_key] = attempt_number
181
183
 
182
184
  def update_task(self, task: EcsExecutorTask):
183
- """Updates the state of the given task based on task ARN."""
185
+ """Update the state of the given task based on task ARN."""
184
186
  self.tasks[task.task_arn] = task
185
187
 
186
188
  def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
@@ -193,7 +195,7 @@ class EcsTaskCollection:
193
195
  return self.tasks[arn]
194
196
 
195
197
  def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
196
- """Deletes task from collection based off of Airflow Task Instance Key."""
198
+ """Delete task from collection based off of Airflow Task Instance Key."""
197
199
  arn = self.key_to_arn[task_key]
198
200
  task = self.tasks[arn]
199
201
  del self.key_to_arn[task_key]
@@ -225,11 +227,11 @@ class EcsTaskCollection:
225
227
  return self.key_to_task_info[task_key]
226
228
 
227
229
  def __getitem__(self, value):
228
- """Gets a task by AWS ARN."""
230
+ """Get a task by AWS ARN."""
229
231
  return self.task_by_arn(value)
230
232
 
231
233
  def __len__(self):
232
- """Determines the number of tasks in collection."""
234
+ """Determine the number of tasks in collection."""
233
235
  return len(self.tasks)
234
236
 
235
237
 
@@ -25,6 +25,21 @@ from airflow.utils import timezone
25
25
  log = logging.getLogger(__name__)
26
26
 
27
27
 
28
+ def calculate_next_attempt_delay(
29
+ attempt_number: int,
30
+ max_delay: int = 60 * 2,
31
+ exponent_base: int = 4,
32
+ ) -> timedelta:
33
+ """
34
+ Calculate the exponential backoff (in seconds) until the next attempt.
35
+
36
+ :param attempt_number: Number of attempts since last success.
37
+ :param max_delay: Maximum delay in seconds between retries. Default 120.
38
+ :param exponent_base: Exponent base to calculate delay. Default 4.
39
+ """
40
+ return timedelta(seconds=min((exponent_base**attempt_number), max_delay))
41
+
42
+
28
43
  def exponential_backoff_retry(
29
44
  last_attempt_time: datetime,
30
45
  attempts_since_last_successful: int,
@@ -34,7 +49,7 @@ def exponential_backoff_retry(
34
49
  exponent_base: int = 4,
35
50
  ) -> None:
36
51
  """
37
- Retries a callable function with exponential backoff between attempts if it raises an exception.
52
+ Retry a callable function with exponential backoff between attempts if it raises an exception.
38
53
 
39
54
  :param last_attempt_time: Timestamp of last attempt call.
40
55
  :param attempts_since_last_successful: Number of attempts since last success.
@@ -47,8 +62,10 @@ def exponential_backoff_retry(
47
62
  log.error("Max attempts reached. Exiting.")
48
63
  return
49
64
 
50
- delay = min((exponent_base**attempts_since_last_successful), max_delay)
51
- next_retry_time = last_attempt_time + timedelta(seconds=delay)
65
+ next_retry_time = last_attempt_time + calculate_next_attempt_delay(
66
+ attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base
67
+ )
68
+
52
69
  current_time = timezone.utcnow()
53
70
 
54
71
  if current_time >= next_retry_time:
@@ -56,5 +73,7 @@ def exponential_backoff_retry(
56
73
  callable_function()
57
74
  except Exception:
58
75
  log.exception("Error calling %r", callable_function.__name__)
59
- next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
76
+ next_delay = calculate_next_attempt_delay(
77
+ attempts_since_last_successful + 1, max_delay, exponent_base
78
+ )
60
79
  log.info("Waiting for %s seconds before retrying.", next_delay)
@@ -25,7 +25,7 @@ This module contains AWS Athena hook.
25
25
  from __future__ import annotations
26
26
 
27
27
  import warnings
28
- from typing import TYPE_CHECKING, Any
28
+ from typing import TYPE_CHECKING, Any, Collection
29
29
 
30
30
  from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
31
31
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -34,6 +34,19 @@ from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
34
34
  if TYPE_CHECKING:
35
35
  from botocore.paginate import PageIterator
36
36
 
37
+ MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"
38
+
39
+
40
+ def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
41
+ result = ""
42
+ for key, value in params.items():
43
+ if key == "QueryString":
44
+ value = (
45
+ MULTI_LINE_QUERY_LOG_PREFIX + str(value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
46
+ )
47
+ result += f"\t{key}: {value}\n"
48
+ return result.rstrip()
49
+
37
50
 
38
51
  class AthenaHook(AwsBaseHook):
39
52
  """Interact with Amazon Athena.
@@ -115,7 +128,7 @@ class AthenaHook(AwsBaseHook):
115
128
  if client_request_token:
116
129
  params["ClientRequestToken"] = client_request_token
117
130
  if self.log_query:
118
- self.log.info("Running Query with params: %s", params)
131
+ self.log.info("Running Query with params:\n%s", query_params_to_string(params))
119
132
  response = self.get_conn().start_query_execution(**params)
120
133
  query_execution_id = response["QueryExecutionId"]
121
134
  self.log.info("Query execution id: %s", query_execution_id)
@@ -29,7 +29,6 @@ import inspect
29
29
  import json
30
30
  import logging
31
31
  import os
32
- import warnings
33
32
  from copy import deepcopy
34
33
  from functools import cached_property, wraps
35
34
  from pathlib import Path
@@ -44,6 +43,7 @@ import tenacity
44
43
  from botocore.config import Config
45
44
  from botocore.waiter import Waiter, WaiterModel
46
45
  from dateutil.tz import tzlocal
46
+ from deprecated import deprecated
47
47
  from slugify import slugify
48
48
 
49
49
  from airflow.configuration import conf
@@ -314,7 +314,7 @@ class BaseSessionFactory(LoggingMixin):
314
314
  idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
315
315
  self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
316
316
  from requests.adapters import HTTPAdapter
317
- from requests.packages.urllib3.util.retry import Retry
317
+ from urllib3.util.retry import Retry
318
318
 
319
319
  retry_strategy = Retry(**idp_request_retry_kwargs)
320
320
  adapter = HTTPAdapter(max_retries=retry_strategy)
@@ -1020,6 +1020,13 @@ except ImportError:
1020
1020
  pass
1021
1021
 
1022
1022
 
1023
+ @deprecated(
1024
+ reason=(
1025
+ "`airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory` "
1026
+ "has been deprecated and will be removed in future"
1027
+ ),
1028
+ category=AirflowProviderDeprecationWarning,
1029
+ )
1023
1030
  class BaseAsyncSessionFactory(BaseSessionFactory):
1024
1031
  """
1025
1032
  Base AWS Session Factory class to handle aiobotocore session creation.
@@ -1029,12 +1036,6 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
1029
1036
  """
1030
1037
 
1031
1038
  def __init__(self, *args, **kwargs):
1032
- warnings.warn(
1033
- "airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory has been deprecated and "
1034
- "will be removed in future",
1035
- AirflowProviderDeprecationWarning,
1036
- stacklevel=2,
1037
- )
1038
1039
  super().__init__(*args, **kwargs)
1039
1040
 
1040
1041
  async def get_role_credentials(self) -> dict:
@@ -1113,6 +1114,13 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
1113
1114
  return self._get_session_with_assume_role()
1114
1115
 
1115
1116
 
1117
+ @deprecated(
1118
+ reason=(
1119
+ "`airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook` "
1120
+ "has been deprecated and will be removed in future"
1121
+ ),
1122
+ category=AirflowProviderDeprecationWarning,
1123
+ )
1116
1124
  class AwsBaseAsyncHook(AwsBaseHook):
1117
1125
  """Interacts with AWS using aiobotocore asynchronously.
1118
1126
 
@@ -1129,12 +1137,6 @@ class AwsBaseAsyncHook(AwsBaseHook):
1129
1137
  """
1130
1138
 
1131
1139
  def __init__(self, *args, **kwargs):
1132
- warnings.warn(
1133
- "airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook has been deprecated and "
1134
- "will be removed in future",
1135
- AirflowProviderDeprecationWarning,
1136
- stacklevel=2,
1137
- )
1138
1140
  super().__init__(*args, **kwargs)
1139
1141
 
1140
1142
  def get_async_session(self) -> AioSession:
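The constructor warnings in BaseAsyncSessionFactory and AwsBaseAsyncHook are replaced above by class-level @deprecated decorators from the deprecated package. A hedged, self-contained illustration of that pattern, using a stand-in class and warning category rather than the provider's own:

import warnings

from deprecated import deprecated

class ExampleDeprecationWarning(DeprecationWarning):
    """Stand-in for AirflowProviderDeprecationWarning."""

@deprecated(
    reason="ExampleFactory has been deprecated and will be removed in future",
    category=ExampleDeprecationWarning,
)
class ExampleFactory:
    pass

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ExampleFactory()  # instantiating the decorated class emits the warning

print(caught[0].category.__name__)  # ExampleDeprecationWarning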
@@ -383,6 +383,7 @@ class EmrContainerHook(AwsBaseHook):
383
383
  configuration_overrides: dict | None = None,
384
384
  client_request_token: str | None = None,
385
385
  tags: dict | None = None,
386
+ retry_max_attempts: int | None = None,
386
387
  ) -> str:
387
388
  """
388
389
  Submit a job to the EMR Containers API and return the job ID.
@@ -402,6 +403,7 @@ class EmrContainerHook(AwsBaseHook):
402
403
  :param client_request_token: The client idempotency token of the job run request.
403
404
  Use this if you want to specify a unique ID to prevent two jobs from getting started.
404
405
  :param tags: The tags assigned to job runs.
406
+ :param retry_max_attempts: The maximum number of attempts on the job's driver.
405
407
  :return: The ID of the job run request.
406
408
  """
407
409
  params = {
@@ -415,6 +417,10 @@ class EmrContainerHook(AwsBaseHook):
415
417
  }
416
418
  if client_request_token:
417
419
  params["clientToken"] = client_request_token
420
+ if retry_max_attempts:
421
+ params["retryPolicyConfiguration"] = {
422
+ "maxAttempts": retry_max_attempts,
423
+ }
418
424
 
419
425
  response = self.conn.start_job_run(**params)
420
426
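The new retry_max_attempts parameter on EmrContainerHook.submit_job maps to the retryPolicyConfiguration block of start_job_run, as the hunk above shows. A hedged sketch of the resulting request parameters, using placeholder identifiers rather than real resources:

# Placeholder values only; this mirrors how submit_job assembles its request dict.
params = {
    "name": "sample-job-run",
    "virtualClusterId": "vc-1234567890abcdef0",
    "executionRoleArn": "arn:aws:iam::111122223333:role/emr-eks-execution-role",
    "releaseLabel": "emr-6.15.0-latest",
    "jobDriver": {"sparkSubmitJobDriver": {"entryPoint": "s3://my-bucket/app.py"}},
}

retry_max_attempts = 3
if retry_max_attempts:
    # Extra attempts are handled by the EMR on EKS service itself, not by Airflow retries.
    params["retryPolicyConfiguration"] = {"maxAttempts": retry_max_attempts}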