apache-airflow-providers-amazon 8.19.0rc1__py3-none-any.whl → 8.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/avp/entities.py +4 -2
- airflow/providers/amazon/aws/auth_manager/avp/facade.py +22 -7
- airflow/providers/amazon/aws/auth_manager/{cli → avp}/schema.json +34 -2
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +91 -170
- airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +7 -32
- airflow/providers/amazon/aws/auth_manager/cli/definition.py +1 -1
- airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +1 -0
- airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
- airflow/providers/amazon/aws/executors/batch/__init__.py +16 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor.py +420 -0
- airflow/providers/amazon/aws/executors/batch/batch_executor_config.py +87 -0
- airflow/providers/amazon/aws/executors/batch/boto_schema.py +67 -0
- airflow/providers/amazon/aws/executors/batch/utils.py +160 -0
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +61 -18
- airflow/providers/amazon/aws/executors/ecs/utils.py +8 -13
- airflow/providers/amazon/aws/executors/utils/base_config_keys.py +25 -0
- airflow/providers/amazon/aws/hooks/athena.py +1 -0
- airflow/providers/amazon/aws/hooks/base_aws.py +1 -0
- airflow/providers/amazon/aws/hooks/batch_client.py +4 -3
- airflow/providers/amazon/aws/hooks/batch_waiters.py +1 -0
- airflow/providers/amazon/aws/hooks/bedrock.py +59 -0
- airflow/providers/amazon/aws/hooks/chime.py +1 -0
- airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/hooks/datasync.py +1 -0
- airflow/providers/amazon/aws/hooks/dynamodb.py +1 -0
- airflow/providers/amazon/aws/hooks/eks.py +1 -0
- airflow/providers/amazon/aws/hooks/glue.py +13 -5
- airflow/providers/amazon/aws/hooks/glue_catalog.py +1 -0
- airflow/providers/amazon/aws/hooks/kinesis.py +1 -0
- airflow/providers/amazon/aws/hooks/lambda_function.py +1 -0
- airflow/providers/amazon/aws/hooks/rds.py +1 -0
- airflow/providers/amazon/aws/hooks/s3.py +24 -30
- airflow/providers/amazon/aws/hooks/ses.py +1 -0
- airflow/providers/amazon/aws/hooks/sns.py +1 -0
- airflow/providers/amazon/aws/hooks/sqs.py +1 -0
- airflow/providers/amazon/aws/operators/athena.py +2 -2
- airflow/providers/amazon/aws/operators/base_aws.py +4 -1
- airflow/providers/amazon/aws/operators/batch.py +4 -2
- airflow/providers/amazon/aws/operators/bedrock.py +252 -0
- airflow/providers/amazon/aws/operators/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/operators/datasync.py +1 -0
- airflow/providers/amazon/aws/operators/ecs.py +9 -10
- airflow/providers/amazon/aws/operators/eks.py +1 -0
- airflow/providers/amazon/aws/operators/emr.py +57 -7
- airflow/providers/amazon/aws/operators/s3.py +1 -0
- airflow/providers/amazon/aws/operators/sns.py +1 -0
- airflow/providers/amazon/aws/operators/sqs.py +1 -0
- airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -0
- airflow/providers/amazon/aws/secrets/systems_manager.py +1 -0
- airflow/providers/amazon/aws/sensors/base_aws.py +4 -1
- airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
- airflow/providers/amazon/aws/sensors/cloud_formation.py +1 -0
- airflow/providers/amazon/aws/sensors/eks.py +3 -4
- airflow/providers/amazon/aws/sensors/sqs.py +2 -1
- airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -2
- airflow/providers/amazon/aws/transfers/base.py +1 -0
- airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -0
- airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -0
- airflow/providers/amazon/aws/transfers/redshift_to_s3.py +21 -19
- airflow/providers/amazon/aws/triggers/bedrock.py +61 -0
- airflow/providers/amazon/aws/triggers/eks.py +1 -1
- airflow/providers/amazon/aws/triggers/redshift_cluster.py +1 -0
- airflow/providers/amazon/aws/triggers/s3.py +4 -2
- airflow/providers/amazon/aws/triggers/sagemaker.py +6 -4
- airflow/providers/amazon/aws/utils/emailer.py +1 -0
- airflow/providers/amazon/aws/waiters/bedrock.json +42 -0
- airflow/providers/amazon/get_provider_info.py +86 -1
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/METADATA +10 -9
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/RECORD +77 -66
- /airflow/providers/amazon/aws/executors/{ecs/Dockerfile → Dockerfile} +0 -0
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/executors/batch/utils.py (new file)
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict, List
+
+from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
+from airflow.utils.state import State
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstance import TaskInstanceKey
+
+CommandType = List[str]
+ExecutorConfigType = Dict[str, Any]
+
+CONFIG_GROUP_NAME = "aws_batch_executor"
+
+CONFIG_DEFAULTS = {
+    "conn_id": "aws_default",
+    "max_submit_job_attempts": "3",
+    "check_health_on_startup": "True",
+}
+
+
+@dataclass
+class BatchQueuedJob:
+    """Represents a Batch job that is queued. The job will be run in the next heartbeat."""
+
+    key: TaskInstanceKey
+    command: CommandType
+    queue: str
+    executor_config: ExecutorConfigType
+    attempt_number: int
+    next_attempt_time: datetime.datetime
+
+
+@dataclass
+class BatchJobInfo:
+    """Contains information about a currently running Batch job."""
+
+    cmd: CommandType
+    queue: str
+    config: ExecutorConfigType
+
+
+class BatchJob:
+    """Data Transfer Object for an AWS Batch Job."""
+
+    STATE_MAPPINGS = {
+        "SUBMITTED": State.QUEUED,
+        "PENDING": State.QUEUED,
+        "RUNNABLE": State.QUEUED,
+        "STARTING": State.QUEUED,
+        "RUNNING": State.RUNNING,
+        "SUCCEEDED": State.SUCCESS,
+        "FAILED": State.FAILED,
+    }
+
+    def __init__(self, job_id: str, status: str, status_reason: str | None = None):
+        self.job_id = job_id
+        self.status = status
+        self.status_reason = status_reason
+
+    def get_job_state(self) -> str:
+        """Return the state of the job."""
+        return self.STATE_MAPPINGS.get(self.status, State.QUEUED)
+
+    def __repr__(self):
+        """Return a visual representation of the Job status."""
+        return f"({self.job_id} -> {self.status}, {self.get_job_state()})"
+
+
+class BatchJobCollection:
+    """A collection to manage running Batch Jobs."""
+
+    def __init__(self):
+        self.key_to_id: dict[TaskInstanceKey, str] = {}
+        self.id_to_key: dict[str, TaskInstanceKey] = {}
+        self.id_to_failure_counts: dict[str, int] = defaultdict(int)
+        self.id_to_job_info: dict[str, BatchJobInfo] = {}
+
+    def add_job(
+        self,
+        job_id: str,
+        airflow_task_key: TaskInstanceKey,
+        airflow_cmd: CommandType,
+        queue: str,
+        exec_config: ExecutorConfigType,
+        attempt_number: int,
+    ):
+        """Add a job to the collection."""
+        self.key_to_id[airflow_task_key] = job_id
+        self.id_to_key[job_id] = airflow_task_key
+        self.id_to_failure_counts[job_id] = attempt_number
+        self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)
+
+    def pop_by_id(self, job_id: str) -> TaskInstanceKey:
+        """Delete job from collection based off of Batch Job ID."""
+        task_key = self.id_to_key[job_id]
+        del self.key_to_id[task_key]
+        del self.id_to_key[job_id]
+        del self.id_to_failure_counts[job_id]
+        return task_key
+
+    def failure_count_by_id(self, job_id: str) -> int:
+        """Get the number of times a job has failed given a Batch Job Id."""
+        return self.id_to_failure_counts[job_id]
+
+    def increment_failure_count(self, job_id: str):
+        """Increment the failure counter given a Batch Job Id."""
+        self.id_to_failure_counts[job_id] += 1
+
+    def get_all_jobs(self) -> list[str]:
+        """Get all AWS ARNs in collection."""
+        return list(self.id_to_key.keys())
+
+    def __len__(self):
+        """Return the number of jobs in collection."""
+        return len(self.key_to_id)
+
+
+class BatchSubmitJobKwargsConfigKeys(BaseConfigKeys):
+    """Keys loaded into the config which are valid Batch submit_job kwargs."""
+
+    JOB_NAME = "job_name"
+    JOB_QUEUE = "job_queue"
+    JOB_DEFINITION = "job_definition"
+    EKS_PROPERTIES_OVERRIDE = "eks_properties_override"
+    NODE_OVERRIDE = "node_override"
+
+
+class AllBatchConfigKeys(BatchSubmitJobKwargsConfigKeys):
+    """All keys loaded into the config which are related to the Batch Executor."""
+
+    MAX_SUBMIT_JOB_ATTEMPTS = "max_submit_job_attempts"
+    AWS_CONN_ID = "conn_id"
+    SUBMIT_JOB_KWARGS = "submit_job_kwargs"
+    REGION_NAME = "region_name"
+    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+
+
+class BatchExecutorException(Exception):
+    """Thrown when something unexpected has occurred within the AWS Batch ecosystem."""
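The new module is a set of plain data containers plus a registry keyed by both Batch job id and Airflow task key. As a minimal sketch of how the pieces fit together (the job id, DAG/task names and run id below are illustrative placeholders, not values taken from the diff):

    from airflow.models.taskinstance import TaskInstanceKey
    from airflow.providers.amazon.aws.executors.batch.utils import BatchJob, BatchJobCollection
    from airflow.utils.state import State

    collection = BatchJobCollection()
    key = TaskInstanceKey(dag_id="example_dag", task_id="example_task", run_id="manual__2024-01-01", try_number=1)
    collection.add_job(
        job_id="00000000-aaaa-bbbb-cccc-000000000000",  # placeholder Batch job id
        airflow_task_key=key,
        airflow_cmd=["airflow", "tasks", "run", "example_dag", "example_task", "manual__2024-01-01"],
        queue="default",
        exec_config={},
        attempt_number=1,
    )
    assert len(collection) == 1
    assert collection.failure_count_by_id("00000000-aaaa-bbbb-cccc-000000000000") == 1

    # BatchJob maps raw AWS Batch statuses onto Airflow State values.
    job = BatchJob(job_id="00000000-aaaa-bbbb-cccc-000000000000", status="RUNNABLE")
    assert job.get_job_state() == State.QUEUED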
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -26,7 +26,7 @@ from __future__ import annotations
 import time
 from collections import defaultdict, deque
 from copy import deepcopy
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Sequence

 from botocore.exceptions import ClientError, NoCredentialsError

@@ -47,12 +47,13 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State

 if TYPE_CHECKING:
-    from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.ecs.utils import (
         CommandType,
         ExecutorConfigType,
@@ -182,7 +183,7 @@ class AwsEcsExecutor(BaseExecutor):
             AllEcsConfigKeys.AWS_CONN_ID,
             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
         )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
+        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.attempts_since_last_successful_connection += 1
         self.last_connection_reload = timezone.utcnow()
@@ -208,7 +209,7 @@ class AwsEcsExecutor(BaseExecutor):
             if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                 self.IS_BOTO_CONNECTION_HEALTHY = False
                 self.log.warning(
-                    f"AWS credentials are either missing or expired: {error}.\nRetrying connection"
+                    "AWS credentials are either missing or expired: %s.\nRetrying connection", error
                 )

         except Exception:
@@ -240,21 +241,19 @@ class AwsEcsExecutor(BaseExecutor):
             # Get state of current task.
             task_state = task.get_task_state()
             task_key = self.active_workers.arn_to_key[task.task_arn]
+
             # Mark finished tasks as either a success/failure.
-            if task_state == State.FAILED:
-                self.fail(task_key)
+            if task_state == State.FAILED or task_state == State.REMOVED:
                 self.__log_container_failures(task_arn=task.task_arn)
-            elif task_state == State.SUCCESS:
-                self.success(task_key)
-            elif task_state == State.REMOVED:
                 self.__handle_failed_task(task.task_arn, task.stopped_reason)
-
+            elif task_state == State.SUCCESS:
                 self.log.debug(
                     "Airflow task %s marked as %s after running on ECS Task (arn) %s",
                     task_key,
                     task_state,
                     task.task_arn,
                 )
+                self.success(task_key)
             self.active_workers.pop_by_key(task_key)

     def __describe_tasks(self, task_arns):
@@ -287,7 +286,14 @@ class AwsEcsExecutor(BaseExecutor):
             )

     def __handle_failed_task(self, task_arn: str, reason: str):
-        """
+        """
+        If an API failure occurs, the task is rescheduled.
+
+        This function will determine whether the task has been attempted the appropriate number
+        of times, and determine whether the task should be marked failed or not. The task will
+        be removed active_workers, and marked as FAILED, or set into pending_tasks depending on
+        how many times it has been retried.
+        """
         task_key = self.active_workers.arn_to_key[task_arn]
         task_info = self.active_workers.info_by_key(task_key)
         task_cmd = task_info.cmd
@@ -303,8 +309,7 @@ class AwsEcsExecutor(BaseExecutor):
                 self.__class__.MAX_RUN_TASK_ATTEMPTS,
                 task_arn,
             )
-            self.
-            self.pending_tasks.appendleft(
+            self.pending_tasks.append(
                 EcsQueuedTask(
                     task_key,
                     task_cmd,
@@ -320,8 +325,8 @@ class AwsEcsExecutor(BaseExecutor):
                 task_key,
                 failure_count,
             )
-            self.active_workers.pop_by_key(task_key)
             self.fail(task_key)
+            self.active_workers.pop_by_key(task_key)

     def attempt_task_runs(self):
         """
@@ -344,16 +349,17 @@ class AwsEcsExecutor(BaseExecutor):
             attempt_number = ecs_task.attempt_number
             _failure_reasons = []
             if timezone.utcnow() < ecs_task.next_attempt_time:
+                self.pending_tasks.append(ecs_task)
                 continue
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
-                self.pending_tasks.appendleft(ecs_task)
+                self.pending_tasks.append(ecs_task)
                 raise
             except ClientError as e:
                 error_code = e.response["Error"]["Code"]
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
-                    self.pending_tasks.appendleft(ecs_task)
+                    self.pending_tasks.append(ecs_task)
                     raise
                 _failure_reasons.append(str(e))
             except Exception as e:
@@ -373,12 +379,12 @@ class AwsEcsExecutor(BaseExecutor):
                 for reason in _failure_reasons:
                     failure_reasons[reason] += 1
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
-                if int(attempt_number)
+                if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
                     ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                         attempt_number
                     )
-                    self.pending_tasks.appendleft(ecs_task)
+                    self.pending_tasks.append(ecs_task)
                 else:
                     self.log.error(
                         "ECS task %s has failed a maximum of %s times. Marking as failed",
@@ -394,6 +400,7 @@ class AwsEcsExecutor(BaseExecutor):
             else:
                 task = run_task_response["tasks"][0]
                 self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
+                self.queued(task_key, task.task_arn)
         if failure_reasons:
             self.log.error(
                 "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
@@ -494,3 +501,39 @@ class AwsEcsExecutor(BaseExecutor):
                     'container "name" must be provided in "containerOverrides" configuration'
                 )
         raise KeyError(f"No such container found by container name: {self.container_name}")
+
+    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
+        """
+        Adopt task instances which have an external_executor_id (the ECS task ARN).
+
+        Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
+        """
+        with Stats.timer("ecs_executor.adopt_task_instances.duration"):
+            adopted_tis: list[TaskInstance] = []
+
+            if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
+                task_descriptions = self.__describe_tasks(task_arns).get("tasks", [])
+
+                for task in task_descriptions:
+                    ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
+                    self.active_workers.add_task(
+                        task,
+                        ti.key,
+                        ti.queue,
+                        ti.command_as_list(),
+                        ti.executor_config,
+                        ti.prev_attempted_tries,
+                    )
+                    adopted_tis.append(ti)
+
+            if adopted_tis:
+                tasks = [f"{task} in state {task.state}" for task in adopted_tis]
+                task_instance_str = "\n\t".join(tasks)
+                self.log.info(
+                    "Adopted the following %d tasks from a dead executor:\n\t%s",
+                    len(adopted_tis),
+                    task_instance_str,
+                )
+
+            not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+            return not_adopted_tis
airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List

 from inflection import camelize

+from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
 from airflow.utils.state import State

 if TYPE_CHECKING:
@@ -71,36 +72,28 @@ class EcsTaskInfo:
     config: ExecutorConfigType


-class BaseConfigKeys:
-    """Base Implementation of the Config Keys class. Implements iteration for child classes to inherit."""
-
-    def __iter__(self):
-        """Return an iterator of values of non dunder attributes of Config Keys."""
-        return iter({value for (key, value) in self.__class__.__dict__.items() if not key.startswith("__")})
-
-
 class RunTaskKwargsConfigKeys(BaseConfigKeys):
     """Keys loaded into the config which are valid ECS run_task kwargs."""

     ASSIGN_PUBLIC_IP = "assign_public_ip"
     CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
     CLUSTER = "cluster"
+    CONTAINER_NAME = "container_name"
     LAUNCH_TYPE = "launch_type"
     PLATFORM_VERSION = "platform_version"
     SECURITY_GROUPS = "security_groups"
     SUBNETS = "subnets"
     TASK_DEFINITION = "task_definition"
-    CONTAINER_NAME = "container_name"


 class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
     """All keys loaded into the config which are related to the ECS Executor."""

-    MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
     AWS_CONN_ID = "conn_id"
-    RUN_TASK_KWARGS = "run_task_kwargs"
-    REGION_NAME = "region_name"
     CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+    MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
+    REGION_NAME = "region_name"
+    RUN_TASK_KWARGS = "run_task_kwargs"


 class EcsExecutorException(Exception):
@@ -108,7 +101,7 @@ class EcsExecutorException:


 class EcsExecutorTask:
-    """Data Transfer Object for an ECS
+    """Data Transfer Object for an ECS Task."""

     def __init__(
         self,
@@ -118,6 +111,7 @@ class EcsExecutorTask:
         containers: list[dict[str, Any]],
         started_at: Any | None = None,
         stopped_reason: str | None = None,
+        external_executor_id: str | None = None,
     ):
         self.task_arn = task_arn
         self.last_status = last_status
@@ -125,6 +119,7 @@ class EcsExecutorTask:
         self.containers = containers
         self.started_at = started_at
         self.stopped_reason = stopped_reason
+        self.external_executor_id = external_executor_id

     def get_task_state(self) -> str:
         """
airflow/providers/amazon/aws/executors/utils/base_config_keys.py (new file)
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+
+class BaseConfigKeys:
+    """Base Implementation of the Config Keys class. Implements iteration for child classes to inherit."""
+
+    def __iter__(self):
+        """Return an iterator of values of non dunder attributes of Config Keys."""
+        return iter({value for (key, value) in self.__class__.__dict__.items() if not key.startswith("__")})
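This class moved out of the ECS utils so both executors can share it. A quick sketch of the iteration contract it provides (the DemoConfigKeys subclass below is hypothetical and only used for illustration; the base class comes from the diff):

    from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys

    class DemoConfigKeys(BaseConfigKeys):
        """Hypothetical subclass used only to illustrate the iteration behaviour."""

        AWS_CONN_ID = "conn_id"
        REGION_NAME = "region_name"

    # Iterating an instance yields the string values of its non-dunder class attributes.
    print(sorted(DemoConfigKeys()))  # ['conn_id', 'region_name']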
airflow/providers/amazon/aws/hooks/batch_client.py
@@ -24,6 +24,7 @@ A client for AWS Batch services.
 - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
 - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
 """
+
 from __future__ import annotations

 import itertools
@@ -438,7 +439,7 @@ class BatchClientHook(AwsBaseHook):
             return None
         if len(all_info) > 1:
             self.log.warning(
-                f"AWS Batch job ({job_id}) has more than one log stream, only returning the first one."
+                "AWS Batch job (%s) has more than one log stream, only returning the first one.", job_id
             )
         return all_info[0]

@@ -474,7 +475,7 @@ class BatchClientHook(AwsBaseHook):
         # If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
         if any(c.get("logDriver", "awslogs") != "awslogs" for c in log_configs):
             self.log.warning(
-                f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
+                "AWS Batch job (%s) uses non-aws log drivers. AWS CloudWatch logging disabled.", job_id
             )
             return []

@@ -482,7 +483,7 @@ class BatchClientHook(AwsBaseHook):
             # If this method is called very early after starting the AWS Batch job,
             # there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
             # This can also happen in case of misconfiguration.
-            self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
+            self.log.warning("AWS Batch job (%s) doesn't have any AWS CloudWatch Stream.", job_id)
             return []

         # Try to get user-defined log configuration options
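These hunks switch the hook's warnings to lazy %-style logging, so the message is only interpolated when the record is actually emitted. As a generic illustration of the pattern, independent of the hook itself (the job id is a placeholder):

    import logging

    log = logging.getLogger(__name__)
    job_id = "example-job-id"  # placeholder

    # Eager: the f-string is built even when WARNING records are filtered out.
    log.warning(f"AWS Batch job ({job_id}) has more than one log stream, only returning the first one.")

    # Lazy: %-style arguments are interpolated only if the record is emitted.
    log.warning("AWS Batch job (%s) has more than one log stream, only returning the first one.", job_id)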
airflow/providers/amazon/aws/hooks/bedrock.py (new file)
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class BedrockHook(AwsBaseHook):
+    """
+    Interact with Amazon Bedrock.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock") <Bedrock.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    client_type = "bedrock"
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = self.client_type
+        super().__init__(*args, **kwargs)
+
+
+class BedrockRuntimeHook(AwsBaseHook):
+    """
+    Interact with the Amazon Bedrock Runtime.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    client_type = "bedrock-runtime"
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = self.client_type
+        super().__init__(*args, **kwargs)
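Both hooks are thin wrappers that only set client_type, so Bedrock API calls go through the underlying boto3 client exposed as .conn. A hedged usage sketch, assuming the standard boto3 bedrock-runtime invoke_model call (the connection id, model id, and prompt are illustrative, not part of the diff):

    import json

    from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook

    runtime_hook = BedrockRuntimeHook(aws_conn_id="aws_default")
    response = runtime_hook.conn.invoke_model(
        modelId="amazon.titan-text-express-v1",  # illustrative model id
        body=json.dumps({"inputText": "Summarize Apache Airflow in one sentence."}),
        accept="application/json",
        contentType="application/json",
    )
    # The response body is a streaming object; read and decode it as JSON.
    print(json.loads(response["body"].read()))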
airflow/providers/amazon/aws/hooks/glue.py
@@ -19,12 +19,13 @@ from __future__ import annotations

 import asyncio
 import time
+from functools import cached_property

-import boto3
 from botocore.exceptions import ClientError

 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

 DEFAULT_LOG_SUFFIX = "output"
 ERROR_LOG_SUFFIX = "error"
@@ -213,6 +214,13 @@ class GlueJobHook(AwsBaseHook):
             job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
             return job_run["JobRun"]["JobRunState"]

+    @cached_property
+    def logs_hook(self):
+        """Returns an AwsLogsHook instantiated with the parameters of the GlueJobHook."""
+        return AwsLogsHook(
+            aws_conn_id=self.aws_conn_id, region_name=self.region_name, verify=self.verify, config=self.config
+        )
+
     def print_job_logs(
         self,
         job_name: str,
@@ -225,7 +233,7 @@ class GlueJobHook(AwsBaseHook):
         :param continuation_tokens: the tokens where to resume from when reading logs.
             The object gets updated with the new tokens by this method.
         """
-        log_client = boto3.client("logs")
+        log_client = self.logs_hook.get_conn()
         paginator = log_client.get_paginator("filter_log_events")

         def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
@@ -245,8 +253,9 @@ class GlueJobHook(AwsBaseHook):
                 if e.response["Error"]["Code"] == "ResourceNotFoundException":
                     # we land here when the log groups/streams don't exist yet
                     self.log.warning(
-                        "No new Glue driver logs so far.\
-
+                        "No new Glue driver logs so far.\n"
+                        "If this persists, check the CloudWatch dashboard at: %r.",
+                        f"https://{self.conn_region_name}.console.aws.amazon.com/cloudwatch/home",
                     )
                 else:
                     raise
@@ -263,7 +272,6 @@ class GlueJobHook(AwsBaseHook):
         log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
         log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
         log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
-
         # one would think that the error log group would contain only errors, but it actually contains
         # a lot of interesting logs too, so it's valuable to have both
         continuation_tokens.output_stream_continuation = display_logs_from(
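With this change the CloudWatch Logs client used for Glue job logs is built from an AwsLogsHook that inherits the Glue hook's own connection settings, rather than an ad-hoc boto3 client. A hedged sketch of what that looks like from user code (the connection id and region are placeholders):

    from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

    hook = GlueJobHook(aws_conn_id="aws_default", region_name="eu-west-1")  # illustrative settings

    # The cached property returns an AwsLogsHook that shares the Glue hook's
    # conn_id/region/verify/config, so the CloudWatch Logs client below uses
    # the same credentials as the Glue client itself.
    log_client = hook.logs_hook.get_conn()
    paginator = log_client.get_paginator("filter_log_events")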
|