apache_airflow_providers_amazon-8.17.0rc2-py3-none-any.whl → apache_airflow_providers_amazon-8.18.0rc2-py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as published to their respective public registries. The information is provided for informational purposes only.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +3 -3
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +14 -0
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +148 -0
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/Dockerfile +3 -3
- airflow/providers/amazon/aws/executors/ecs/boto_schema.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +40 -17
- airflow/providers/amazon/aws/executors/ecs/utils.py +9 -7
- airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +23 -4
- airflow/providers/amazon/aws/hooks/athena.py +15 -2
- airflow/providers/amazon/aws/hooks/base_aws.py +16 -14
- airflow/providers/amazon/aws/hooks/emr.py +6 -0
- airflow/providers/amazon/aws/hooks/logs.py +85 -1
- airflow/providers/amazon/aws/hooks/neptune.py +85 -0
- airflow/providers/amazon/aws/hooks/quicksight.py +9 -8
- airflow/providers/amazon/aws/hooks/redshift_cluster.py +8 -7
- airflow/providers/amazon/aws/hooks/redshift_sql.py +3 -3
- airflow/providers/amazon/aws/hooks/s3.py +4 -6
- airflow/providers/amazon/aws/hooks/sagemaker.py +136 -9
- airflow/providers/amazon/aws/links/emr.py +122 -2
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +2 -2
- airflow/providers/amazon/aws/operators/athena.py +4 -1
- airflow/providers/amazon/aws/operators/batch.py +5 -6
- airflow/providers/amazon/aws/operators/ecs.py +6 -2
- airflow/providers/amazon/aws/operators/eks.py +31 -26
- airflow/providers/amazon/aws/operators/emr.py +192 -26
- airflow/providers/amazon/aws/operators/glue.py +5 -2
- airflow/providers/amazon/aws/operators/glue_crawler.py +5 -2
- airflow/providers/amazon/aws/operators/glue_databrew.py +5 -2
- airflow/providers/amazon/aws/operators/lambda_function.py +3 -0
- airflow/providers/amazon/aws/operators/neptune.py +218 -0
- airflow/providers/amazon/aws/operators/rds.py +21 -12
- airflow/providers/amazon/aws/operators/redshift_cluster.py +12 -18
- airflow/providers/amazon/aws/operators/redshift_data.py +2 -4
- airflow/providers/amazon/aws/operators/sagemaker.py +94 -31
- airflow/providers/amazon/aws/operators/step_function.py +4 -1
- airflow/providers/amazon/aws/sensors/batch.py +2 -2
- airflow/providers/amazon/aws/sensors/ec2.py +4 -2
- airflow/providers/amazon/aws/sensors/emr.py +13 -6
- airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +4 -1
- airflow/providers/amazon/aws/sensors/quicksight.py +17 -14
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +2 -4
- airflow/providers/amazon/aws/sensors/s3.py +3 -0
- airflow/providers/amazon/aws/sensors/sqs.py +4 -1
- airflow/providers/amazon/aws/transfers/s3_to_redshift.py +1 -0
- airflow/providers/amazon/aws/transfers/sql_to_s3.py +31 -3
- airflow/providers/amazon/aws/triggers/neptune.py +115 -0
- airflow/providers/amazon/aws/triggers/rds.py +9 -7
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +2 -2
- airflow/providers/amazon/aws/triggers/redshift_data.py +1 -1
- airflow/providers/amazon/aws/triggers/sagemaker.py +82 -1
- airflow/providers/amazon/aws/utils/__init__.py +10 -0
- airflow/providers/amazon/aws/utils/connection_wrapper.py +12 -8
- airflow/providers/amazon/aws/utils/mixins.py +5 -1
- airflow/providers/amazon/aws/utils/task_log_fetcher.py +2 -2
- airflow/providers/amazon/aws/waiters/neptune.json +85 -0
- airflow/providers/amazon/get_provider_info.py +26 -2
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/METADATA +6 -6
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/RECORD +62 -57
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.17.0rc2.dist-info → apache_airflow_providers_amazon-8.18.0rc2.dist-info}/entry_points.txt +0 -0
--- a/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py
+++ b/airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py
@@ -55,7 +55,7 @@ def init_avp(args):
     if not is_new_policy_store:
         print(
             f"Since an existing policy store with description '{args.policy_store_description}' has been found in Amazon Verified Permissions, "
-            "the CLI
+            "the CLI made no changes to this policy store for security reasons. "
             "Any modification to this policy store must be done manually.",
         )
     else:
@@ -80,7 +80,7 @@ def update_schema(args):


 def _get_client():
-    """
+    """Return Amazon Verified Permissions client."""
     region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
     return boto3.client("verifiedpermissions", region_name=region_name)

@@ -115,7 +115,7 @@ def _create_policy_store(client: BaseClient, args) -> tuple[str | None, bool]:
     print(f"No policy store with description '{args.policy_store_description}' found, creating one.")
     if args.dry_run:
         print(
-            "Dry run, not creating the policy store with description '{args.policy_store_description}'."
+            f"Dry run, not creating the policy store with description '{args.policy_store_description}'."
         )
         return None, True

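The only functional change in `_create_policy_store` is the added `f` prefix on the dry-run message: without it, Python prints the braces literally instead of interpolating the attribute. A minimal illustration (the `args` object below is a stand-in, not part of the provider code):

    from argparse import Namespace

    args = Namespace(policy_store_description="Airflow")

    # Without the f prefix the placeholder is emitted verbatim:
    print("Dry run, not creating the policy store with description '{args.policy_store_description}'.")
    # -> ... description '{args.policy_store_description}'.

    # With the f prefix the attribute is interpolated:
    print(f"Dry run, not creating the policy store with description '{args.policy_store_description}'.")
    # -> ... description 'Airflow'.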
--- a/airflow/providers/amazon/aws/auth_manager/cli/definition.py
+++ b/airflow/providers/amazon/aws/auth_manager/cli/definition.py
@@ -35,6 +35,14 @@ ARG_DRY_RUN = Arg(
     action="store_true",
 )

+# AWS IAM Identity Center
+ARG_INSTANCE_NAME = Arg(("--instance-name",), help="Instance name in Identity Center", default="Airflow")
+
+ARG_APPLICATION_NAME = Arg(
+    ("--application-name",), help="Application name in Identity Center", default="Airflow"
+)
+
+
 # Amazon Verified Permissions
 ARG_POLICY_STORE_DESCRIPTION = Arg(
     ("--policy-store-description",), help="Policy store description", default="Airflow"
@@ -47,6 +55,12 @@ ARG_POLICY_STORE_ID = Arg(("--policy-store-id",), help="Policy store ID")
 ################

 AWS_AUTH_MANAGER_COMMANDS = (
+    ActionCommand(
+        name="init-identity-center",
+        help="Initialize AWS IAM identity Center resources to be used by AWS manager",
+        func=lazy_load_command("airflow.providers.amazon.aws.auth_manager.cli.idc_commands.init_idc"),
+        args=(ARG_INSTANCE_NAME, ARG_APPLICATION_NAME, ARG_DRY_RUN, ARG_VERBOSE),
+    ),
     ActionCommand(
         name="init-avp",
         help="Initialize Amazon Verified resources to be used by AWS manager",
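These `Arg` entries are consumed by Airflow's CLI parser when it builds the new sub-command. Roughly what they translate to in plain argparse terms (the `aws-auth-manager` group name and the `-v/--verbose` spelling are assumptions here, not shown in this diff):

    import argparse

    # Approximation of the parser Airflow builds from the Arg definitions above.
    parser = argparse.ArgumentParser(prog="airflow aws-auth-manager init-identity-center")
    parser.add_argument("--instance-name", help="Instance name in Identity Center", default="Airflow")
    parser.add_argument("--application-name", help="Application name in Identity Center", default="Airflow")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")

    args = parser.parse_args(["--dry-run"])
    assert args.instance_name == "Airflow" and args.dry_run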
--- /dev/null
+++ b/airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py
@@ -0,0 +1,148 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""User sub-commands."""
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import boto3
+from botocore.exceptions import ClientError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowOptionalProviderFeatureException
+from airflow.providers.amazon.aws.auth_manager.constants import CONF_REGION_NAME_KEY, CONF_SECTION_NAME
+from airflow.utils import cli as cli_utils
+
+try:
+    from airflow.utils.providers_configuration_loader import providers_configuration_loaded
+except ImportError:
+    raise AirflowOptionalProviderFeatureException(
+        "Failed to import avp_commands. This feature is only available in Airflow "
+        "version >= 2.8.0 where Auth Managers are introduced."
+    )
+
+if TYPE_CHECKING:
+    from botocore.client import BaseClient
+
+log = logging.getLogger(__name__)
+
+
+@cli_utils.action_cli
+@providers_configuration_loaded
+def init_idc(args):
+    """Initialize AWS IAM Identity Center resources."""
+    client = _get_client()
+
+    # Create the instance if needed
+    instance_arn = _create_instance(client, args)
+
+    # Create the application if needed
+    _create_application(client, instance_arn, args)
+
+    if not args.dry_run:
+        print("AWS IAM Identity Center resources created successfully.")
+
+
+def _get_client():
+    """Return AWS IAM Identity Center client."""
+    region_name = conf.get(CONF_SECTION_NAME, CONF_REGION_NAME_KEY)
+    return boto3.client("sso-admin", region_name=region_name)
+
+
+def _create_instance(client: BaseClient, args) -> str | None:
+    """Create if needed AWS IAM Identity Center instance."""
+    instances = client.list_instances()
+
+    if args.verbose:
+        log.debug("Instances found: %s", instances)
+
+    if len(instances["Instances"]) > 0:
+        print(
+            f"There is already an instance configured in AWS IAM Identity Center: '{instances['Instances'][0]['InstanceArn']}'. "
+            "No need to create a new one."
+        )
+        return instances["Instances"][0]["InstanceArn"]
+    else:
+        print("No instance configured in AWS IAM Identity Center, creating one.")
+        if args.dry_run:
+            print("Dry run, not creating the instance.")
+            return None
+
+        response = client.create_instance(Name=args.instance_name)
+        if args.verbose:
+            log.debug("Response from create_instance: %s", response)
+
+        print(f"Instance created: '{response['InstanceArn']}'")
+
+        return response["InstanceArn"]
+
+
+def _create_application(client: BaseClient, instance_arn: str | None, args) -> str | None:
+    """Create if needed AWS IAM identity Center application."""
+    paginator = client.get_paginator("list_applications")
+    pages = paginator.paginate(InstanceArn=instance_arn or "")
+    applications = [application for page in pages for application in page["Applications"]]
+    existing_applications = [
+        application for application in applications if application["Name"] == args.application_name
+    ]
+
+    if args.verbose:
+        log.debug("Applications found: %s", applications)
+        log.debug("Existing applications found: %s", existing_applications)
+
+    if len(existing_applications) > 0:
+        print(
+            f"There is already an application named '{args.application_name}' in AWS IAM Identity Center: '{existing_applications[0]['ApplicationArn']}'. "
+            "Using this application."
+        )
+        return existing_applications[0]["ApplicationArn"]
+    else:
+        print(f"No application named {args.application_name} found, creating one.")
+        if args.dry_run:
+            print("Dry run, not creating the application.")
+            return None
+
+        try:
+            response = client.create_application(
+                ApplicationProviderArn="arn:aws:sso::aws:applicationProvider/custom-saml",
+                Description="Application automatically created through the Airflow CLI. This application is used to access Airflow environment.",
+                InstanceArn=instance_arn,
+                Name=args.application_name,
+                PortalOptions={
+                    "SignInOptions": {
+                        "Origin": "IDENTITY_CENTER",
+                    },
+                    "Visibility": "ENABLED",
+                },
+                Status="ENABLED",
+            )
+            if args.verbose:
+                log.debug("Response from create_application: %s", response)
+        except ClientError as e:
+            # This is needed because as of today, the create_application in AWS Identity Center does not support SAML application
+            # Remove this part when it is supported
+            if "is not supported for this action" in e.response["Error"]["Message"]:
+                print(
+                    "Creation of SAML applications is only supported in AWS console today. "
+                    "Please create the application through the console."
+                )
+            raise
+
+        print(f"Application created: '{response['ApplicationArn']}'")
+
+        return response["ApplicationArn"]
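`_create_application` relies on the standard boto3 paginator pattern to walk every page of `list_applications` before filtering by name. The same pattern in isolation, with a hard-coded region and a placeholder instance ARN (both assumptions for the sketch):

    import boto3

    client = boto3.client("sso-admin", region_name="us-east-1")

    # Placeholder; in the command above this comes from list_instances().
    instance_arn = "arn:aws:sso:::instance/ssoins-EXAMPLE"

    paginator = client.get_paginator("list_applications")
    applications = [
        app
        for page in paginator.paginate(InstanceArn=instance_arn)
        for app in page["Applications"]
    ]
    matching = [app for app in applications if app["Name"] == "Airflow"]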
--- a/airflow/providers/amazon/aws/auth_manager/views/auth.py
+++ b/airflow/providers/amazon/aws/auth_manager/views/auth.py
@@ -71,7 +71,7 @@ class AwsAuthManagerAuthenticationViews(AirflowBaseView):
     @expose("/login_callback", methods=("GET", "POST"))
     def login_callback(self):
         """
-
+        Redirect the user to this callback after successful login.

         CSRF protection needs to be disabled otherwise the callback won't work.
         """
--- a/airflow/providers/amazon/aws/executors/ecs/Dockerfile
+++ b/airflow/providers/amazon/aws/executors/ecs/Dockerfile
@@ -37,7 +37,7 @@ COPY <<"EOF" /install_dags_entrypoint.sh
 #!/bin/bash

 echo "Downloading DAGs from S3 bucket"
-aws s3 sync "$
+aws s3 sync "$S3_URI" "$CONTAINER_DAG_PATH"

 /home/airflow/entrypoint.sh "$@"
 EOF
@@ -98,8 +98,8 @@ ENV CONTAINER_DAG_PATH=$container_dag_path


 # Use these arguments to load DAGs onto the container from S3
-ARG
-ENV
+ARG s3_uri
+ENV S3_URI=$s3_uri
 # If using S3 bucket as source of DAGs, uncommenting the next ENTRYPOINT command will overwrite this one.
 ENTRYPOINT ["/usr/bin/dumb-init", "--", "/home/airflow/entrypoint.sh"]

--- a/airflow/providers/amazon/aws/executors/ecs/boto_schema.py
+++ b/airflow/providers/amazon/aws/executors/ecs/boto_schema.py
@@ -61,7 +61,7 @@ class BotoTaskSchema(Schema):

     @post_load
     def make_task(self, data, **kwargs):
-        """
+        """Overwrite marshmallow load() to return an EcsExecutorTask instance instead of a dictionary."""
         # Imported here to avoid circular import.
         from airflow.providers.amazon.aws.executors.ecs.utils import EcsExecutorTask

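The new one-line docstring describes what the `@post_load` hook is for: marshmallow calls it after `Schema.load()` has produced a plain dict, so the schema can hand back a domain object instead. A generic, stand-alone illustration of the pattern (a toy schema, not the provider's):

    from marshmallow import Schema, fields, post_load


    class Task:
        def __init__(self, task_arn, last_status):
            self.task_arn = task_arn
            self.last_status = last_status


    class TaskSchema(Schema):
        task_arn = fields.String()
        last_status = fields.String()

        @post_load
        def make_task(self, data, **kwargs):
            # Receives the already-deserialized dict and returns the object load() will yield.
            return Task(**data)


    task = TaskSchema().load({"task_arn": "arn:aws:ecs:...", "last_status": "RUNNING"})
    assert isinstance(task, Task)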
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -42,9 +42,13 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
     EcsQueuedTask,
     EcsTaskCollection,
 )
-from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils import timezone
+from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State

 if TYPE_CHECKING:
@@ -108,7 +112,7 @@ class AwsEcsExecutor(BaseExecutor):
         self.run_task_kwargs = self._load_run_kwargs()

     def start(self):
-        """
+        """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
@@ -213,7 +217,7 @@ class AwsEcsExecutor(BaseExecutor):
             self.log.exception("Failed to sync %s", self.__class__.__name__)

     def sync_running_tasks(self):
-        """
+        """Check and update state on all running tasks."""
         all_task_arns = self.active_workers.get_all_arns()
         if not all_task_arns:
             self.log.debug("No active Airflow tasks, skipping sync.")
@@ -300,7 +304,14 @@ class AwsEcsExecutor(BaseExecutor):
             )
             self.active_workers.increment_failure_count(task_key)
             self.pending_tasks.appendleft(
-                EcsQueuedTask(
+                EcsQueuedTask(
+                    task_key,
+                    task_cmd,
+                    queue,
+                    exec_info,
+                    failure_count + 1,
+                    timezone.utcnow() + calculate_next_attempt_delay(failure_count),
+                )
             )
         else:
             self.log.error(
@@ -313,7 +324,7 @@ class AwsEcsExecutor(BaseExecutor):

     def attempt_task_runs(self):
         """
-
+        Take tasks from the pending_tasks queue, and attempts to find an instance to run it on.

         If the launch type is EC2, this will attempt to place tasks on empty EC2 instances. If
         there are no EC2 instances available, no task is placed and this function will be
@@ -331,6 +342,8 @@ class AwsEcsExecutor(BaseExecutor):
             exec_config = ecs_task.executor_config
             attempt_number = ecs_task.attempt_number
             _failure_reasons = []
+            if timezone.utcnow() < ecs_task.next_attempt_time:
+                continue
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
@@ -361,6 +374,9 @@ class AwsEcsExecutor(BaseExecutor):
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
                 if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
+                    ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                        attempt_number
+                    )
                     self.pending_tasks.appendleft(ecs_task)
                 else:
                     self.log.error(
@@ -393,8 +409,8 @@ class AwsEcsExecutor(BaseExecutor):
         The command and executor config will be placed in the container-override
         section of the JSON request before calling Boto3's "run_task" function.
         """
-
-        boto_run_task = self.ecs.run_task(**
+        run_task_kwargs = self._run_task_kwargs(task_id, cmd, queue, exec_config)
+        boto_run_task = self.ecs.run_task(**run_task_kwargs)
         run_task_response = BotoRunTaskSchema().load(boto_run_task)
         return run_task_response

@@ -402,30 +418,32 @@ class AwsEcsExecutor(BaseExecutor):
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
     ) -> dict:
         """
-
+        Update the Airflow command by modifying container overrides for task-specific kwargs.

         One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
         """
-
-
+        run_task_kwargs = deepcopy(self.run_task_kwargs)
+        run_task_kwargs = merge_dicts(run_task_kwargs, exec_config)
+        container_override = self.get_container(run_task_kwargs["overrides"]["containerOverrides"])
         container_override["command"] = cmd
-        container_override.update(exec_config)

         # Inject the env variable to configure logging for containerized execution environment
         if "environment" not in container_override:
             container_override["environment"] = []
         container_override["environment"].append({"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"})

-        return
+        return run_task_kwargs

     def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
         """Save the task to be executed in the next sync by inserting the commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
-        self.pending_tasks.append(
+        self.pending_tasks.append(
+            EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
+        )

     def end(self, heartbeat_interval=10):
-        """
+        """Wait for all currently running tasks to end, and don't launch any tasks."""
         try:
             while True:
                 self.sync()
@@ -465,8 +483,13 @@ class AwsEcsExecutor(BaseExecutor):
         return ecs_executor_run_task_kwargs

     def get_container(self, container_list):
-        """
+        """Search task list for core Airflow container."""
         for container in container_list:
-
-
+            try:
+                if container["name"] == self.container_name:
+                    return container
+            except KeyError:
+                raise EcsExecutorException(
+                    'container "name" must be provided in "containerOverrides" configuration'
+                )
         raise KeyError(f"No such container found by container name: {self.container_name}")
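The rewritten `_run_task_kwargs` now deep-merges the task's `executor_config` into a copy of the template `run_task` kwargs with `airflow.utils.helpers.merge_dicts`, instead of flatly updating the container override dict. A small sketch of the difference that makes (the kwargs below are a trimmed, hypothetical template, not the executor's full configuration):

    from airflow.utils.helpers import merge_dicts

    base_kwargs = {
        "cluster": "my-cluster",
        "networkConfiguration": {"awsvpcConfiguration": {"assignPublicIp": "DISABLED"}},
    }
    executor_config = {
        "capacityProviderStrategy": [{"capacityProvider": "FARGATE_SPOT", "weight": 1}],
        "networkConfiguration": {"awsvpcConfiguration": {"assignPublicIp": "ENABLED"}},
    }

    merged = merge_dicts(base_kwargs, executor_config)
    # Nested dicts are merged key by key, so the override only touches assignPublicIp...
    assert merged["networkConfiguration"]["awsvpcConfiguration"]["assignPublicIp"] == "ENABLED"
    # ...while new top-level keys are added alongside the untouched "cluster" entry.
    assert merged["capacityProviderStrategy"] and merged["cluster"] == "my-cluster"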
--- a/airflow/providers/amazon/aws/executors/ecs/utils.py
+++ b/airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -23,6 +23,7 @@ Data classes and utility functions used by the ECS executor.

 from __future__ import annotations

+import datetime
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -58,6 +59,7 @@ class EcsQueuedTask:
     queue: str
     executor_config: ExecutorConfigType
     attempt_number: int
+    next_attempt_time: datetime.datetime


 @dataclass
@@ -125,9 +127,9 @@ class EcsExecutorTask:

     def get_task_state(self) -> str:
         """
-
+        Determine the state of an ECS task based on its status and other relevant attributes.

-        It
+        It can return one of the following statuses:
             QUEUED - Task is being provisioned.
             RUNNING - Task is launched on ECS.
             REMOVED - Task provisioning has failed for some reason. See `stopped_reason`.
@@ -171,7 +173,7 @@ class EcsTaskCollection:
         exec_config: ExecutorConfigType,
         attempt_number: int,
     ):
-        """
+        """Add a task to the collection."""
         arn = task.task_arn
         self.tasks[arn] = task
         self.key_to_arn[airflow_task_key] = arn
@@ -180,7 +182,7 @@ class EcsTaskCollection:
         self.key_to_failure_counts[airflow_task_key] = attempt_number

     def update_task(self, task: EcsExecutorTask):
-        """
+        """Update the state of the given task based on task ARN."""
         self.tasks[task.task_arn] = task

     def task_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
@@ -193,7 +195,7 @@ class EcsTaskCollection:
         return self.tasks[arn]

     def pop_by_key(self, task_key: TaskInstanceKey) -> EcsExecutorTask:
-        """
+        """Delete task from collection based off of Airflow Task Instance Key."""
         arn = self.key_to_arn[task_key]
         task = self.tasks[arn]
         del self.key_to_arn[task_key]
@@ -225,11 +227,11 @@ class EcsTaskCollection:
         return self.key_to_task_info[task_key]

     def __getitem__(self, value):
-        """
+        """Get a task by AWS ARN."""
         return self.task_by_arn(value)

     def __len__(self):
-        """
+        """Determine the number of tasks in collection."""
         return len(self.tasks)


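The new `next_attempt_time` field is what lets the executor postpone a requeued task: `execute_async` stamps brand-new tasks with `timezone.utcnow()`, requeued failures get `utcnow() + calculate_next_attempt_delay(...)`, and `attempt_task_runs` skips any task whose time has not yet arrived. A sketch of the construction, mirroring `execute_async` (the key and command values are placeholders):

    from airflow.providers.amazon.aws.executors.ecs.utils import EcsQueuedTask
    from airflow.utils import timezone

    # Placeholders for a real TaskInstanceKey and "airflow tasks run ..." command list.
    key = ("dag_id", "task_id", "run_id", 1, -1)
    command = ["airflow", "tasks", "run", "dag_id", "task_id", "run_id"]

    queued = EcsQueuedTask(key, command, "default", {}, 1, timezone.utcnow())
    assert queued.attempt_number == 1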
--- a/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
+++ b/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
@@ -25,6 +25,21 @@ from airflow.utils import timezone
 log = logging.getLogger(__name__)


+def calculate_next_attempt_delay(
+    attempt_number: int,
+    max_delay: int = 60 * 2,
+    exponent_base: int = 4,
+) -> timedelta:
+    """
+    Calculate the exponential backoff (in seconds) until the next attempt.
+
+    :param attempt_number: Number of attempts since last success.
+    :param max_delay: Maximum delay in seconds between retries. Default 120.
+    :param exponent_base: Exponent base to calculate delay. Default 4.
+    """
+    return timedelta(seconds=min((exponent_base**attempt_number), max_delay))
+
+
 def exponential_backoff_retry(
     last_attempt_time: datetime,
     attempts_since_last_successful: int,
@@ -34,7 +49,7 @@ def exponential_backoff_retry(
     exponent_base: int = 4,
 ) -> None:
     """
-
+    Retry a callable function with exponential backoff between attempts if it raises an exception.

     :param last_attempt_time: Timestamp of last attempt call.
     :param attempts_since_last_successful: Number of attempts since last success.
@@ -47,8 +62,10 @@ def exponential_backoff_retry(
         log.error("Max attempts reached. Exiting.")
         return

-
-
+    next_retry_time = last_attempt_time + calculate_next_attempt_delay(
+        attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base
+    )
+
     current_time = timezone.utcnow()

     if current_time >= next_retry_time:
@@ -56,5 +73,7 @@ def exponential_backoff_retry(
             callable_function()
         except Exception:
             log.exception("Error calling %r", callable_function.__name__)
-            next_delay =
+            next_delay = calculate_next_attempt_delay(
+                attempts_since_last_successful + 1, max_delay, exponent_base
+            )
             log.info("Waiting for %s seconds before retrying.", next_delay)
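With the defaults above (`exponent_base=4`, `max_delay=60 * 2`), the delay between attempts grows as 4, 16 and 64 seconds and is then capped at 120 seconds:

    from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
        calculate_next_attempt_delay,
    )

    for attempt in range(1, 5):
        print(attempt, calculate_next_attempt_delay(attempt).total_seconds())
    # 1 4.0
    # 2 16.0
    # 3 64.0
    # 4 120.0  (4**4 = 256, capped at max_delay)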
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -25,7 +25,7 @@ This module contains AWS Athena hook.
 from __future__ import annotations

 import warnings
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Collection

 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -34,6 +34,19 @@ from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 if TYPE_CHECKING:
     from botocore.paginate import PageIterator

+MULTI_LINE_QUERY_LOG_PREFIX = "\n\t\t"
+
+
+def query_params_to_string(params: dict[str, str | Collection[str]]) -> str:
+    result = ""
+    for key, value in params.items():
+        if key == "QueryString":
+            value = (
+                MULTI_LINE_QUERY_LOG_PREFIX + str(value).replace("\n", MULTI_LINE_QUERY_LOG_PREFIX).rstrip()
+            )
+        result += f"\t{key}: {value}\n"
+    return result.rstrip()
+

 class AthenaHook(AwsBaseHook):
     """Interact with Amazon Athena.
@@ -115,7 +128,7 @@ class AthenaHook(AwsBaseHook):
         if client_request_token:
             params["ClientRequestToken"] = client_request_token
         if self.log_query:
-            self.log.info("Running Query with params
+            self.log.info("Running Query with params:\n%s", query_params_to_string(params))
         response = self.get_conn().start_query_execution(**params)
         query_execution_id = response["QueryExecutionId"]
         self.log.info("Query execution id: %s", query_execution_id)
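The new module-level `query_params_to_string` helper is what makes the `log_query` output readable: every parameter becomes a tab-indented `key: value` line, and a multi-line `QueryString` is pushed onto its own indented block. For example:

    from airflow.providers.amazon.aws.hooks.athena import query_params_to_string

    params = {
        "QueryString": "SELECT *\nFROM my_table\nLIMIT 10",
        "WorkGroup": "primary",
    }
    print(query_params_to_string(params))
    # The query is printed on tab-indented lines under its key, while
    # scalar parameters such as WorkGroup stay on a single "key: value" line.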
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -29,7 +29,6 @@ import inspect
 import json
 import logging
 import os
-import warnings
 from copy import deepcopy
 from functools import cached_property, wraps
 from pathlib import Path
@@ -44,6 +43,7 @@ import tenacity
 from botocore.config import Config
 from botocore.waiter import Waiter, WaiterModel
 from dateutil.tz import tzlocal
+from deprecated import deprecated
 from slugify import slugify

 from airflow.configuration import conf
@@ -314,7 +314,7 @@ class BaseSessionFactory(LoggingMixin):
             idp_request_retry_kwargs = saml_config["idp_request_retry_kwargs"]
             self.log.info("idp_request_retry_kwargs= %s", idp_request_retry_kwargs)
             from requests.adapters import HTTPAdapter
-            from
+            from urllib3.util.retry import Retry

             retry_strategy = Retry(**idp_request_retry_kwargs)
             adapter = HTTPAdapter(max_retries=retry_strategy)
@@ -1020,6 +1020,13 @@ except ImportError:
     pass


+@deprecated(
+    reason=(
+        "`airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory` "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class BaseAsyncSessionFactory(BaseSessionFactory):
     """
     Base AWS Session Factory class to handle aiobotocore session creation.
@@ -1029,12 +1036,6 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
     """

     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.BaseAsyncSessionFactory has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         super().__init__(*args, **kwargs)

     async def get_role_credentials(self) -> dict:
@@ -1113,6 +1114,13 @@ class BaseAsyncSessionFactory(BaseSessionFactory):
         return self._get_session_with_assume_role()


+@deprecated(
+    reason=(
+        "`airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook` "
+        "has been deprecated and will be removed in future"
+    ),
+    category=AirflowProviderDeprecationWarning,
+)
 class AwsBaseAsyncHook(AwsBaseHook):
     """Interacts with AWS using aiobotocore asynchronously.

@@ -1129,12 +1137,6 @@ class AwsBaseAsyncHook(AwsBaseHook):
     """

     def __init__(self, *args, **kwargs):
-        warnings.warn(
-            "airflow.providers.amazon.aws.hook.base_aws.AwsBaseAsyncHook has been deprecated and "
-            "will be removed in future",
-            AirflowProviderDeprecationWarning,
-            stacklevel=2,
-        )
         super().__init__(*args, **kwargs)

     def get_async_session(self) -> AioSession:
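The hand-rolled `warnings.warn(...)` calls in `__init__` are replaced by the `deprecated` package's class decorator, which emits the configured warning category whenever the decorated class is instantiated. The same mechanism in isolation (the classes below are toys, not the provider's):

    import warnings

    from deprecated import deprecated


    class MyDeprecationWarning(DeprecationWarning):
        pass


    @deprecated(reason="use NewThing instead", category=MyDeprecationWarning)
    class OldThing:
        pass


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        OldThing()

    assert caught[0].category is MyDeprecationWarning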
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -383,6 +383,7 @@ class EmrContainerHook(AwsBaseHook):
         configuration_overrides: dict | None = None,
         client_request_token: str | None = None,
         tags: dict | None = None,
+        retry_max_attempts: int | None = None,
     ) -> str:
         """
         Submit a job to the EMR Containers API and return the job ID.
@@ -402,6 +403,7 @@ class EmrContainerHook(AwsBaseHook):
         :param client_request_token: The client idempotency token of the job run request.
             Use this if you want to specify a unique ID to prevent two jobs from getting started.
         :param tags: The tags assigned to job runs.
+        :param retry_max_attempts: The maximum number of attempts on the job's driver.
         :return: The ID of the job run request.
         """
         params = {
@@ -415,6 +417,10 @@ class EmrContainerHook(AwsBaseHook):
         }
         if client_request_token:
             params["clientToken"] = client_request_token
+        if retry_max_attempts:
+            params["retryPolicyConfiguration"] = {
+                "maxAttempts": retry_max_attempts,
+            }

         response = self.conn.start_job_run(**params)

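With the new argument, the hook simply forwards a driver retry policy to EMR on EKS. A rough sketch of the equivalent raw boto3 call it ends up building (all identifiers below are placeholders):

    import boto3

    client = boto3.client("emr-containers", region_name="us-east-1")

    response = client.start_job_run(
        name="example-job",                          # placeholder
        virtualClusterId="vc-1234567890abcdef0",     # placeholder
        executionRoleArn="arn:aws:iam::123456789012:role/emr-eks-job-role",  # placeholder
        releaseLabel="emr-6.15.0-latest",            # placeholder
        jobDriver={"sparkSubmitJobDriver": {"entryPoint": "s3://bucket/app.py"}},
        # The new retry_max_attempts hook argument maps 1:1 onto this block:
        retryPolicyConfiguration={"maxAttempts": 3},
    )
    job_run_id = response["id"]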