apache-airflow-providers-amazon 9.8.0rc1__py3-none-any.whl → 9.9.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/executors/aws_lambda/__init__.py +21 -0
- airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile +107 -0
- airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py +16 -0
- airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py +129 -0
- airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +479 -0
- airflow/providers/amazon/aws/executors/aws_lambda/utils.py +70 -0
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +4 -8
- airflow/providers/amazon/aws/hooks/base_aws.py +20 -4
- airflow/providers/amazon/aws/hooks/eks.py +14 -5
- airflow/providers/amazon/aws/hooks/s3.py +101 -34
- airflow/providers/amazon/aws/hooks/sns.py +10 -1
- airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +12 -5
- airflow/providers/amazon/aws/operators/batch.py +1 -2
- airflow/providers/amazon/aws/operators/cloud_formation.py +0 -2
- airflow/providers/amazon/aws/operators/comprehend.py +0 -2
- airflow/providers/amazon/aws/operators/dms.py +0 -2
- airflow/providers/amazon/aws/operators/ecs.py +1 -1
- airflow/providers/amazon/aws/operators/eks.py +13 -0
- airflow/providers/amazon/aws/operators/emr.py +4 -4
- airflow/providers/amazon/aws/operators/glue.py +0 -6
- airflow/providers/amazon/aws/operators/rds.py +0 -4
- airflow/providers/amazon/aws/operators/redshift_cluster.py +90 -63
- airflow/providers/amazon/aws/operators/sns.py +15 -1
- airflow/providers/amazon/aws/sensors/redshift_cluster.py +13 -10
- airflow/providers/amazon/get_provider_info.py +68 -0
- {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/METADATA +15 -19
- {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/RECORD +31 -25
- {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py (new file)
@@ -0,0 +1,479 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import time
+from collections import deque
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from boto3.session import NoCredentialsError
+from botocore.utils import ClientError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.executors.base_executor import BaseExecutor
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.amazon.aws.executors.aws_lambda.utils import (
+    CONFIG_GROUP_NAME,
+    INVALID_CREDENTIALS_EXCEPTIONS,
+    AllLambdaConfigKeys,
+    CommandType,
+    LambdaQueuedTask,
+)
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.stats import Stats
+from airflow.utils import timezone
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
+    from airflow.models.taskinstance import TaskInstance
+
+
+class AwsLambdaExecutor(BaseExecutor):
+    """
+    An Airflow Executor that submits tasks to AWS Lambda asynchronously.
+
+    When execute_async() is called, the executor invokes a specified AWS Lambda function (asynchronously)
+    with a payload that includes the task command and a unique task key.
+
+    The Lambda function writes its result directly to an SQS queue, which is then polled by this executor
+    to update task state in Airflow.
+    """
+
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pending_tasks: deque = deque()
+        self.running_tasks: dict[str, TaskInstanceKey] = {}
+        self.lambda_function_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME)
+        self.sqs_queue_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL)
+        self.dlq_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL)
+        self.qualifier = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER, fallback=None)
+        # Maximum number of retries to invoke Lambda.
+        self.max_invoke_attempts = conf.get(
+            CONFIG_GROUP_NAME,
+            AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS,
+        )
+
+        self.attempts_since_last_successful_connection = 0
+        self.IS_BOTO_CONNECTION_HEALTHY = False
+        self.load_connections(check_connection=False)
+
+    def start(self):
+        """Call this when the Executor is run for the first time by the scheduler."""
+        check_health = conf.getboolean(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP)
+
+        if not check_health:
+            return
+
+        self.log.info("Starting Lambda Executor and determining health...")
+        try:
+            self.check_health()
+        except AirflowException:
+            self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
+            raise
+
+    def check_health(self):
+        """
+        Check the health of the Lambda and SQS connections.
+
+        For lambda: Use get_function to test if the lambda connection works and the function can be
+        described.
+        For SQS: Use get_queue_attributes is used as a close analog to describe to test if the SQS
+        connection is working.
+        """
+        self.IS_BOTO_CONNECTION_HEALTHY = False
+
+        def _check_queue(queue_url):
+            sqs_get_queue_attrs_response = self.sqs_client.get_queue_attributes(
+                QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages"]
+            )
+            approx_num_msgs = sqs_get_queue_attrs_response.get("Attributes").get(
+                "ApproximateNumberOfMessages"
+            )
+            self.log.info(
+                "SQS connection is healthy and queue %s is present with %s messages.",
+                queue_url,
+                approx_num_msgs,
+            )
+
+        self.log.info("Checking Lambda and SQS connections")
+        try:
+            # Check Lambda health
+            lambda_get_response = self.lambda_client.get_function(FunctionName=self.lambda_function_name)
+            if self.lambda_function_name not in lambda_get_response["Configuration"]["FunctionName"]:
+                raise AirflowException("Lambda function %s not found.", self.lambda_function_name)
+            self.log.info(
+                "Lambda connection is healthy and function %s is present.", self.lambda_function_name
+            )
+
+            # Check SQS results queue
+            _check_queue(self.sqs_queue_url)
+            # Check SQS dead letter queue
+            _check_queue(self.dlq_url)
+
+            # If we reach this point, both connections are healthy and all resources are present
+            self.IS_BOTO_CONNECTION_HEALTHY = True
+        except Exception:
+            self.log.exception("Lambda Executor health check failed")
+            raise AirflowException(
+                "The Lambda executor will not be able to run Airflow tasks until the issue is addressed."
+            )
+
+    def load_connections(self, check_connection: bool = True):
+        """
+        Retrieve the AWS connection via Hooks to leverage the Airflow connection system.
+
+        :param check_connection: If True, check the health of the connection after loading it.
+        """
+        self.log.info("Loading Connections")
+        aws_conn_id = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.AWS_CONN_ID)
+        region_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME, fallback=None)
+        self.sqs_client = SqsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+        self.lambda_client = LambdaHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+
+        self.attempts_since_last_successful_connection += 1
+        self.last_connection_reload = timezone.utcnow()
+
+        if check_connection:
+            self.check_health()
+            self.attempts_since_last_successful_connection = 0
+
+    def sync(self):
+        """
+        Sync the executor with the current state of tasks.
+
+        Check in on currently running tasks and attempt to run any new tasks that have been queued.
+        """
+        if not self.IS_BOTO_CONNECTION_HEALTHY:
+            exponential_backoff_retry(
+                self.last_connection_reload,
+                self.attempts_since_last_successful_connection,
+                self.load_connections,
+            )
+            if not self.IS_BOTO_CONNECTION_HEALTHY:
+                return
+        try:
+            self.sync_running_tasks()
+            self.attempt_task_runs()
+        except (ClientError, NoCredentialsError) as error:
+            error_code = error.response["Error"]["Code"]
+            if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                self.IS_BOTO_CONNECTION_HEALTHY = False
+                self.log.warning(
+                    "AWS credentials are either missing or expired: %s.\nRetrying connection", error
+                )
+        except Exception:
+            self.log.exception("An error occurred while syncing tasks")
+
+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
+    def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
+        """
+        Save the task to be executed in the next sync by inserting the commands into a queue.
+
+        :param key: A unique task key (typically a tuple identifying the task instance).
+        :param command: The shell command string to execute.
+        :param executor_config: (Unused) to keep the same signature as the base.
+        :param queue: (Unused) to keep the same signature as the base.
+        """
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise RuntimeError(
+                    f"LambdaExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
+        self.pending_tasks.append(
+            LambdaQueuedTask(
+                key, command, queue if queue else "", executor_config or {}, 1, timezone.utcnow()
+            )
+        )
+
+    def attempt_task_runs(self):
+        """
+        Attempt to run tasks that are queued in the pending_tasks.
+
+        Each task is submitted to AWS Lambda with a payload containing the task key and command.
+        The task key is used to track the task's state in Airflow.
+        """
+        queue_len = len(self.pending_tasks)
+        for _ in range(queue_len):
+            task_to_run = self.pending_tasks.popleft()
+            task_key = task_to_run.key
+            cmd = task_to_run.command
+            attempt_number = task_to_run.attempt_number
+            failure_reasons = []
+            ser_task_key = json.dumps(task_key._asdict())
+            payload = {
+                "task_key": ser_task_key,
+                "command": cmd,
+            }
+            if timezone.utcnow() < task_to_run.next_attempt_time:
+                self.pending_tasks.append(task_to_run)
+                continue
+
+            self.log.info("Submitting task %s to Lambda function %s", task_key, self.lambda_function_name)
+
+            try:
+                invoke_kwargs = {
+                    "FunctionName": self.lambda_function_name,
+                    "InvocationType": "Event",
+                    "Payload": json.dumps(payload),
+                }
+                if self.qualifier:
+                    invoke_kwargs["Qualifier"] = self.qualifier
+                response = self.lambda_client.invoke(**invoke_kwargs)
+            except NoCredentialsError:
+                self.pending_tasks.append(task_to_run)
+                raise
+            except ClientError as e:
+                error_code = e.response["Error"]["Code"]
+                if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                    self.pending_tasks.append(task_to_run)
+                    raise
+                failure_reasons.append(str(e))
+            except Exception as e:
+                # Failed to even get a response back from the Boto3 API or something else went
+                # wrong. For any possible failure we want to add the exception reasons to the
+                # failure list so that it is logged to the user and most importantly the task is
+                # added back to the pending list to be retried later.
+                failure_reasons.append(str(e))
+
+            if failure_reasons:
+                # Make sure the number of attempts does not exceed max invoke attempts
+                if int(attempt_number) < int(self.max_invoke_attempts):
+                    task_to_run.attempt_number += 1
+                    task_to_run.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                        attempt_number
+                    )
+                    self.pending_tasks.append(task_to_run)
+                else:
+                    reasons_str = ", ".join(failure_reasons)
+                    self.log.error(
+                        "Lambda invoke %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
+                        task_key,
+                        attempt_number,
+                        reasons_str,
+                    )
+                    self.log_task_event(
+                        event="lambda invoke failure",
+                        ti_key=task_key,
+                        extra=(
+                            f"Task could not be queued after {attempt_number} attempts. "
+                            f"Marking as failed. Reasons: {reasons_str}"
+                        ),
+                    )
+                    self.fail(task_key)
+            else:
+                status_code = response.get("StatusCode")
+                self.log.info("Invoked Lambda for task %s with status %s", task_key, status_code)
+                self.running_tasks[ser_task_key] = task_key
+                # Add the serialized task key as the info, this will be assigned on the ti as the external_executor_id
+                self.running_state(task_key, ser_task_key)
+
+    def sync_running_tasks(self):
+        """
+        Poll the SQS queue for messages indicating task completion.
+
+        Each message is expected to contain a JSON payload with 'task_key' and 'return_code'.
+        Based on the return code, update the task state accordingly.
+        """
+        if not len(self.running_tasks):
+            self.log.debug("No running tasks to process.")
+            return
+
+        self.process_queue(self.sqs_queue_url)
+        if self.dlq_url and self.running_tasks:
+            self.process_queue(self.dlq_url)
+
+    def process_queue(self, queue_url: str):
+        """
+        Poll the SQS queue for messages indicating task completion.
+
+        Each message is expected to contain a JSON payload with 'task_key' and 'return_code'.
+
+        Based on the return code, update the task state accordingly.
+        """
+        response = self.sqs_client.receive_message(
+            QueueUrl=queue_url,
+            MaxNumberOfMessages=10,
+        )
+
+        messages = response.get("Messages", [])
+        # Pagination? Maybe we don't need it. Since we always delete messages after looking at them.
+        # But then that may delete messages that could have been adopted. Let's leave it for now and see how it goes.
+        if messages and queue_url == self.dlq_url:
+            self.log.warning("%d messages received from the dead letter queue", len(messages))
+
+        for message in messages:
+            receipt_handle = message["ReceiptHandle"]
+            body = json.loads(message["Body"])
+            return_code = body.get("return_code")
+            ser_task_key = body.get("task_key")
+            # Fetch the real task key from the running_tasks dict, using the serialized task key.
+            try:
+                task_key = self.running_tasks[ser_task_key]
+            except KeyError:
+                self.log.warning(
+                    "Received task %s from the queue which is not found in running tasks. Removing message.",
+                    ser_task_key,
+                )
+                task_key = None
+
+            if task_key:
+                if return_code == 0:
+                    self.success(task_key)
+                    self.log.info(
+                        "Successful Lambda invocation for task %s received from SQS queue.", task_key
+                    )
+                else:
+                    # In this case the Lambda likely started but failed at run time since we got a non-zero
+                    # return code. We could consider retrying these tasks within the executor, because this _likely_
+                    # means the Airflow task did not run to completion, however we can't be sure (maybe the
+                    # lambda runtime code has a bug and is returning a non-zero when it actually passed?). So
+                    # perhaps not retrying is the safest option.
+                    self.fail(task_key)
+                    self.log.error(
+                        "Lambda invocation for task: %s has failed to run with return code %s",
+                        task_key,
+                        return_code,
+                    )
+                # Remove the task from the tracking mapping.
+                self.running_tasks.pop(ser_task_key)
+
+            # Delete the message from the queue.
+            self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
+
+    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
+        """
+        Adopt task instances which have an external_executor_id (the serialized task key).
+
+        Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
+
+        :param tis: The task instances to adopt.
+        """
+        with Stats.timer("lambda_executor.adopt_task_instances.duration"):
+            adopted_tis: list[TaskInstance] = []
+
+            if serialized_task_keys := [
+                (ti, ti.external_executor_id) for ti in tis if ti.external_executor_id
+            ]:
+                for ti, ser_task_key in serialized_task_keys:
+                    try:
+                        task_key = TaskInstanceKey.from_dict(json.loads(ser_task_key))
+                    except Exception:
+                        # If that task fails to deserialize, we should just skip it.
+                        self.log.exception(
+                            "Task failed to be adopted because the key could not be deserialized"
+                        )
+                        continue
+                    self.running_tasks[ser_task_key] = task_key
+                    adopted_tis.append(ti)
+
+            if adopted_tis:
+                tasks = [f"{task} in state {task.state}" for task in adopted_tis]
+                task_instance_str = "\n\t".join(tasks)
+                self.log.info(
+                    "Adopted the following %d tasks from a dead executor:\n\t%s",
+                    len(adopted_tis),
+                    task_instance_str,
+                )
+
+            not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+            return not_adopted_tis
+
+    def end(self, heartbeat_interval=10):
+        """
+        End execution. Poll until all outstanding tasks are marked as completed.
+
+        This is a blocking call and async Lambda tasks can not be cancelled, so this will wait until
+        all tasks are either completed or the timeout is reached.
+
+        :param heartbeat_interval: The interval in seconds to wait between checks for task completion.
+        """
+        self.log.info("Received signal to end, waiting for outstanding tasks to finish.")
+        time_to_wait = int(conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT))
+        start_time = timezone.utcnow()
+        while True:
+            if time_to_wait:
+                current_time = timezone.utcnow()
+                elapsed_time = (current_time - start_time).total_seconds()
+                if elapsed_time > time_to_wait:
+                    self.log.warning(
+                        "Timed out waiting for tasks to finish. Some tasks may not be handled gracefully"
+                        " as the executor is force ending due to timeout."
+                    )
+                    break
+            self.sync()
+            if not self.running_tasks:
+                self.log.info("All tasks completed; executor ending.")
+                break
+            self.log.info("Waiting for %d task(s) to complete.", len(self.running_tasks))
+            time.sleep(heartbeat_interval)
+
+    def terminate(self):
+        """Get called when the daemon receives a SIGTERM."""
+        self.log.warning("Terminating Lambda executor. In-flight tasks cannot be stopped.")
airflow/providers/amazon/aws/executors/aws_lambda/utils.py (new file)
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstancekey import TaskInstanceKey
+
+
+CONFIG_GROUP_NAME = "aws_lambda_executor"
+INVALID_CREDENTIALS_EXCEPTIONS = [
+    "ExpiredTokenException",
+    "InvalidClientTokenId",
+    "UnrecognizedClientException",
+]
+
+
+@dataclass
+class LambdaQueuedTask:
+    """Represents a Lambda task that is queued. The task will be run in the next heartbeat."""
+
+    key: TaskInstanceKey
+    command: CommandType
+    queue: str
+    executor_config: ExecutorConfigType
+    attempt_number: int
+    next_attempt_time: datetime.datetime
+
+
+class InvokeLambdaKwargsConfigKeys(BaseConfigKeys):
+    """Config keys loaded which are valid lambda invoke args."""
+
+    FUNCTION_NAME = "function_name"
+    QUALIFIER = "function_qualifier"
+
+
+class AllLambdaConfigKeys(InvokeLambdaKwargsConfigKeys):
+    """All config keys which are related to the Lambda Executor."""
+
+    AWS_CONN_ID = "conn_id"
+    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+    MAX_INVOKE_ATTEMPTS = "max_run_task_attempts"
+    REGION_NAME = "region_name"
+    QUEUE_URL = "queue_url"
+    DLQ_URL = "dead_letter_queue_url"
+    END_WAIT_TIMEOUT = "end_wait_timeout"


+CommandType = Sequence[str]
+ExecutorConfigType = dict[str, Any]
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -131,7 +131,7 @@ class AwsEcsExecutor(BaseExecutor):
         ti = workload.ti
         self.queued_tasks[ti.key] = workload
 
-    def _process_workloads(self, workloads:
+    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
         from airflow.executors.workloads import ExecuteTask
 
         # Airflow V3 version
airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -101,14 +101,10 @@ def build_task_kwargs() -> dict:
     # The executor will overwrite the 'command' property during execution. Must always be the first container!
     task_kwargs["overrides"]["containerOverrides"][0]["command"] = []  # type: ignore
 
-
-
-
-
-            # Surrounding parens are for the walrus operator to function correctly along with the None check
-            (assign_public_ip := task_kwargs.pop(AllEcsConfigKeys.ASSIGN_PUBLIC_IP, None)) is not None,
-        ]
-    ):
+    subnets = task_kwargs.pop(AllEcsConfigKeys.SUBNETS, None)
+    security_groups = task_kwargs.pop(AllEcsConfigKeys.SECURITY_GROUPS, None)
+    assign_public_ip = task_kwargs.pop(AllEcsConfigKeys.ASSIGN_PUBLIC_IP, None)
+    if subnets or security_groups or assign_public_ip != "False":
         network_config = prune_dict(
             {
                 "awsvpcConfiguration": {
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -58,12 +58,16 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
+from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
 
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.exceptions import AirflowRuntimeError
+
 if TYPE_CHECKING:
     from aiobotocore.session import AioSession
     from botocore.client import ClientMeta
@@ -603,12 +607,24 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """Get the Airflow Connection object and wrap it in helper (cached)."""
         connection = None
         if self.aws_conn_id:
+            possible_exceptions: tuple[type[Exception], ...]
+
+            if AIRFLOW_V_3_0_PLUS:
+                possible_exceptions = (AirflowNotFoundException, AirflowRuntimeError)
+            else:
+                possible_exceptions = (AirflowNotFoundException,)
+
             try:
                 connection = self.get_connection(self.aws_conn_id)
-            except
-
-
-                )
+            except possible_exceptions as e:
+                if isinstance(
+                    e, AirflowNotFoundException
+                ) or f"Connection with ID {self.aws_conn_id} not found" in str(e):
+                    self.log.warning(
+                        "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
+                    )
+                else:
+                    raise
 
         return AwsConnectionWrapper(
             conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
airflow/providers/amazon/aws/hooks/eks.py
@@ -79,16 +79,25 @@ class NodegroupStates(Enum):
 
 
 COMMAND = """
+            export PYTHON_OPERATORS_VIRTUAL_ENV_MODE=1
             output=$({python_executable} -m airflow.providers.amazon.aws.utils.eks_get_token \
                 --cluster-name {eks_cluster_name} {args} 2>&1)
 
-
-
-
+            status=$?
+
+            if [ "$status" -ne 0 ]; then
+                printf '%s' "$output" >&2
+                exit "$status"
             fi
 
-
-
+            # Use pure bash below to parse so that it's posix compliant
+
+            last_line=${{output##*$'\\n'}}  # strip everything up to the last newline
+
+            timestamp=${{last_line#expirationTimestamp: }}  # drop the label
+            timestamp=${{timestamp%%,*}}  # keep up to the first comma
+
+            token=${{last_line##*, token: }}  # text after ", token: "
 
             json_string=$(printf '{{"kind": "ExecCredential","apiVersion": \
             "client.authentication.k8s.io/v1alpha1","spec": {{}},"status": \