apache-airflow-providers-amazon 9.8.0rc1__py3-none-any.whl → 9.9.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/executors/aws_lambda/__init__.py +21 -0
  3. airflow/providers/amazon/aws/executors/aws_lambda/docker/Dockerfile +107 -0
  4. airflow/providers/amazon/aws/executors/aws_lambda/docker/__init__.py +16 -0
  5. airflow/providers/amazon/aws/executors/aws_lambda/docker/app.py +129 -0
  6. airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py +479 -0
  7. airflow/providers/amazon/aws/executors/aws_lambda/utils.py +70 -0
  8. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +1 -1
  9. airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +4 -8
  10. airflow/providers/amazon/aws/hooks/base_aws.py +20 -4
  11. airflow/providers/amazon/aws/hooks/eks.py +14 -5
  12. airflow/providers/amazon/aws/hooks/s3.py +101 -34
  13. airflow/providers/amazon/aws/hooks/sns.py +10 -1
  14. airflow/providers/amazon/aws/log/cloudwatch_task_handler.py +12 -5
  15. airflow/providers/amazon/aws/operators/batch.py +1 -2
  16. airflow/providers/amazon/aws/operators/cloud_formation.py +0 -2
  17. airflow/providers/amazon/aws/operators/comprehend.py +0 -2
  18. airflow/providers/amazon/aws/operators/dms.py +0 -2
  19. airflow/providers/amazon/aws/operators/ecs.py +1 -1
  20. airflow/providers/amazon/aws/operators/eks.py +13 -0
  21. airflow/providers/amazon/aws/operators/emr.py +4 -4
  22. airflow/providers/amazon/aws/operators/glue.py +0 -6
  23. airflow/providers/amazon/aws/operators/rds.py +0 -4
  24. airflow/providers/amazon/aws/operators/redshift_cluster.py +90 -63
  25. airflow/providers/amazon/aws/operators/sns.py +15 -1
  26. airflow/providers/amazon/aws/sensors/redshift_cluster.py +13 -10
  27. airflow/providers/amazon/get_provider_info.py +68 -0
  28. {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/METADATA +15 -19
  29. {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/RECORD +31 -25
  30. {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/WHEEL +0 -0
  31. {apache_airflow_providers_amazon-9.8.0rc1.dist-info → apache_airflow_providers_amazon-9.9.0rc1.dist-info}/entry_points.txt +0 -0
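The headline change in 9.9.0rc1 is a new AWS Lambda executor (files 2-7 above). Its settings live under the aws_lambda_executor config section defined in utils.py below, so they can be supplied through Airflow's standard AIRFLOW__<SECTION>__<KEY> environment variables. The snippet below is only a minimal sketch of that wiring: the function name, queue URLs and connection id are placeholders, and since the get_provider_info.py changes are collapsed in this view, the full class path is used for [core] executor rather than any short alias the provider may register.

    import os

    # Hypothetical values; only the section and key names come from utils.py below.
    os.environ["AIRFLOW__CORE__EXECUTOR"] = (
        "airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.AwsLambdaExecutor"
    )
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__FUNCTION_NAME"] = "airflow-task-runner"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__QUEUE_URL"] = "https://sqs.us-east-1.amazonaws.com/123456789012/airflow-results"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__DEAD_LETTER_QUEUE_URL"] = "https://sqs.us-east-1.amazonaws.com/123456789012/airflow-results-dlq"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__CONN_ID"] = "aws_default"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__MAX_RUN_TASK_ATTEMPTS"] = "3"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__END_WAIT_TIMEOUT"] = "60"
    os.environ["AIRFLOW__AWS_LAMBDA_EXECUTOR__CHECK_HEALTH_ON_STARTUP"] = "True"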
airflow/providers/amazon/aws/executors/aws_lambda/lambda_executor.py
@@ -0,0 +1,479 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import time
+from collections import deque
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from botocore.exceptions import NoCredentialsError
+from botocore.exceptions import ClientError
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.executors.base_executor import BaseExecutor
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.providers.amazon.aws.executors.aws_lambda.utils import (
+    CONFIG_GROUP_NAME,
+    INVALID_CREDENTIALS_EXCEPTIONS,
+    AllLambdaConfigKeys,
+    CommandType,
+    LambdaQueuedTask,
+)
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.hooks.sqs import SqsHook
+from airflow.stats import Stats
+from airflow.utils import timezone
+
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
+    from airflow.models.taskinstance import TaskInstance
+
+
+class AwsLambdaExecutor(BaseExecutor):
+    """
+    An Airflow Executor that submits tasks to AWS Lambda asynchronously.
+
+    When execute_async() is called, the executor invokes a specified AWS Lambda function (asynchronously)
+    with a payload that includes the task command and a unique task key.
+
+    The Lambda function writes its result directly to an SQS queue, which is then polled by this executor
+    to update task state in Airflow.
+    """
+
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pending_tasks: deque = deque()
+        self.running_tasks: dict[str, TaskInstanceKey] = {}
+        self.lambda_function_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME)
+        self.sqs_queue_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL)
+        self.dlq_url = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL)
+        self.qualifier = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUALIFIER, fallback=None)
+        # Maximum number of retries to invoke Lambda.
+        self.max_invoke_attempts = conf.get(
+            CONFIG_GROUP_NAME,
+            AllLambdaConfigKeys.MAX_INVOKE_ATTEMPTS,
+        )
+
+        self.attempts_since_last_successful_connection = 0
+        self.IS_BOTO_CONNECTION_HEALTHY = False
+        self.load_connections(check_connection=False)
+
+    def start(self):
+        """Call this when the Executor is run for the first time by the scheduler."""
+        check_health = conf.getboolean(CONFIG_GROUP_NAME, AllLambdaConfigKeys.CHECK_HEALTH_ON_STARTUP)
+
+        if not check_health:
+            return
+
+        self.log.info("Starting Lambda Executor and determining health...")
+        try:
+            self.check_health()
+        except AirflowException:
+            self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
+            raise
+
+    def check_health(self):
+        """
+        Check the health of the Lambda and SQS connections.
+
+        For lambda: Use get_function to test if the lambda connection works and the function can be
+        described.
+        For SQS: Use get_queue_attributes is used as a close analog to describe to test if the SQS
+        connection is working.
+        """
+        self.IS_BOTO_CONNECTION_HEALTHY = False
+
+        def _check_queue(queue_url):
+            sqs_get_queue_attrs_response = self.sqs_client.get_queue_attributes(
+                QueueUrl=queue_url, AttributeNames=["ApproximateNumberOfMessages"]
+            )
+            approx_num_msgs = sqs_get_queue_attrs_response.get("Attributes").get(
+                "ApproximateNumberOfMessages"
+            )
+            self.log.info(
+                "SQS connection is healthy and queue %s is present with %s messages.",
+                queue_url,
+                approx_num_msgs,
+            )
+
+        self.log.info("Checking Lambda and SQS connections")
+        try:
+            # Check Lambda health
+            lambda_get_response = self.lambda_client.get_function(FunctionName=self.lambda_function_name)
+            if self.lambda_function_name not in lambda_get_response["Configuration"]["FunctionName"]:
+                raise AirflowException("Lambda function %s not found.", self.lambda_function_name)
+            self.log.info(
+                "Lambda connection is healthy and function %s is present.", self.lambda_function_name
+            )
+
+            # Check SQS results queue
+            _check_queue(self.sqs_queue_url)
+            # Check SQS dead letter queue
+            _check_queue(self.dlq_url)
+
+            # If we reach this point, both connections are healthy and all resources are present
+            self.IS_BOTO_CONNECTION_HEALTHY = True
+        except Exception:
+            self.log.exception("Lambda Executor health check failed")
+            raise AirflowException(
+                "The Lambda executor will not be able to run Airflow tasks until the issue is addressed."
+            )
+
+    def load_connections(self, check_connection: bool = True):
+        """
+        Retrieve the AWS connection via Hooks to leverage the Airflow connection system.
+
+        :param check_connection: If True, check the health of the connection after loading it.
+        """
+        self.log.info("Loading Connections")
+        aws_conn_id = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.AWS_CONN_ID)
+        region_name = conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.REGION_NAME, fallback=None)
+        self.sqs_client = SqsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+        self.lambda_client = LambdaHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+
+        self.attempts_since_last_successful_connection += 1
+        self.last_connection_reload = timezone.utcnow()
+
+        if check_connection:
+            self.check_health()
+            self.attempts_since_last_successful_connection = 0
+
+    def sync(self):
+        """
+        Sync the executor with the current state of tasks.
+
+        Check in on currently running tasks and attempt to run any new tasks that have been queued.
+        """
+        if not self.IS_BOTO_CONNECTION_HEALTHY:
+            exponential_backoff_retry(
+                self.last_connection_reload,
+                self.attempts_since_last_successful_connection,
+                self.load_connections,
+            )
+            if not self.IS_BOTO_CONNECTION_HEALTHY:
+                return
+        try:
+            self.sync_running_tasks()
+            self.attempt_task_runs()
+        except (ClientError, NoCredentialsError) as error:
+            error_code = error.response["Error"]["Code"]
+            if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                self.IS_BOTO_CONNECTION_HEALTHY = False
+                self.log.warning(
+                    "AWS credentials are either missing or expired: %s.\nRetrying connection", error
+                )
+        except Exception:
+            self.log.exception("An error occurred while syncing tasks")
+
+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
+    def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
+        """
+        Save the task to be executed in the next sync by inserting the commands into a queue.
+
+        :param key: A unique task key (typically a tuple identifying the task instance).
+        :param command: The shell command string to execute.
+        :param executor_config: (Unused) to keep the same signature as the base.
+        :param queue: (Unused) to keep the same signature as the base.
+        """
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise RuntimeError(
+                    f"LambdaExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
+        self.pending_tasks.append(
+            LambdaQueuedTask(
+                key, command, queue if queue else "", executor_config or {}, 1, timezone.utcnow()
+            )
+        )
+
+    def attempt_task_runs(self):
+        """
+        Attempt to run tasks that are queued in the pending_tasks.
+
+        Each task is submitted to AWS Lambda with a payload containing the task key and command.
+        The task key is used to track the task's state in Airflow.
+        """
+        queue_len = len(self.pending_tasks)
+        for _ in range(queue_len):
+            task_to_run = self.pending_tasks.popleft()
+            task_key = task_to_run.key
+            cmd = task_to_run.command
+            attempt_number = task_to_run.attempt_number
+            failure_reasons = []
+            ser_task_key = json.dumps(task_key._asdict())
+            payload = {
+                "task_key": ser_task_key,
+                "command": cmd,
+            }
+            if timezone.utcnow() < task_to_run.next_attempt_time:
+                self.pending_tasks.append(task_to_run)
+                continue
+
+            self.log.info("Submitting task %s to Lambda function %s", task_key, self.lambda_function_name)
+
+            try:
+                invoke_kwargs = {
+                    "FunctionName": self.lambda_function_name,
+                    "InvocationType": "Event",
+                    "Payload": json.dumps(payload),
+                }
+                if self.qualifier:
+                    invoke_kwargs["Qualifier"] = self.qualifier
+                response = self.lambda_client.invoke(**invoke_kwargs)
+            except NoCredentialsError:
+                self.pending_tasks.append(task_to_run)
+                raise
+            except ClientError as e:
+                error_code = e.response["Error"]["Code"]
+                if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                    self.pending_tasks.append(task_to_run)
+                    raise
+                failure_reasons.append(str(e))
+            except Exception as e:
+                # Failed to even get a response back from the Boto3 API or something else went
+                # wrong. For any possible failure we want to add the exception reasons to the
+                # failure list so that it is logged to the user and most importantly the task is
+                # added back to the pending list to be retried later.
+                failure_reasons.append(str(e))
+
+            if failure_reasons:
+                # Make sure the number of attempts does not exceed max invoke attempts
+                if int(attempt_number) < int(self.max_invoke_attempts):
+                    task_to_run.attempt_number += 1
+                    task_to_run.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                        attempt_number
+                    )
+                    self.pending_tasks.append(task_to_run)
+                else:
+                    reasons_str = ", ".join(failure_reasons)
+                    self.log.error(
+                        "Lambda invoke %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
+                        task_key,
+                        attempt_number,
+                        reasons_str,
+                    )
+                    self.log_task_event(
+                        event="lambda invoke failure",
+                        ti_key=task_key,
+                        extra=(
+                            f"Task could not be queued after {attempt_number} attempts. "
+                            f"Marking as failed. Reasons: {reasons_str}"
+                        ),
+                    )
+                    self.fail(task_key)
+            else:
+                status_code = response.get("StatusCode")
+                self.log.info("Invoked Lambda for task %s with status %s", task_key, status_code)
+                self.running_tasks[ser_task_key] = task_key
+                # Add the serialized task key as the info, this will be assigned on the ti as the external_executor_id
+                self.running_state(task_key, ser_task_key)
+
+    def sync_running_tasks(self):
+        """
+        Poll the SQS queue for messages indicating task completion.
+
+        Each message is expected to contain a JSON payload with 'task_key' and 'return_code'.
+        Based on the return code, update the task state accordingly.
+        """
+        if not len(self.running_tasks):
+            self.log.debug("No running tasks to process.")
+            return
+
+        self.process_queue(self.sqs_queue_url)
+        if self.dlq_url and self.running_tasks:
+            self.process_queue(self.dlq_url)
+
+    def process_queue(self, queue_url: str):
+        """
+        Poll the SQS queue for messages indicating task completion.
+
+        Each message is expected to contain a JSON payload with 'task_key' and 'return_code'.
+
+        Based on the return code, update the task state accordingly.
+        """
+        response = self.sqs_client.receive_message(
+            QueueUrl=queue_url,
+            MaxNumberOfMessages=10,
+        )
+
+        messages = response.get("Messages", [])
+        # Pagination? Maybe we don't need it. Since we always delete messages after looking at them.
+        # But then that may delete messages that could have been adopted. Let's leave it for now and see how it goes.
+        if messages and queue_url == self.dlq_url:
+            self.log.warning("%d messages received from the dead letter queue", len(messages))
+
+        for message in messages:
+            receipt_handle = message["ReceiptHandle"]
+            body = json.loads(message["Body"])
+            return_code = body.get("return_code")
+            ser_task_key = body.get("task_key")
+            # Fetch the real task key from the running_tasks dict, using the serialized task key.
+            try:
+                task_key = self.running_tasks[ser_task_key]
+            except KeyError:
+                self.log.warning(
+                    "Received task %s from the queue which is not found in running tasks. Removing message.",
+                    ser_task_key,
+                )
+                task_key = None
+
+            if task_key:
+                if return_code == 0:
+                    self.success(task_key)
+                    self.log.info(
+                        "Successful Lambda invocation for task %s received from SQS queue.", task_key
+                    )
+                else:
+                    # In this case the Lambda likely started but failed at run time since we got a non-zero
+                    # return code. We could consider retrying these tasks within the executor, because this _likely_
+                    # means the Airflow task did not run to completion, however we can't be sure (maybe the
+                    # lambda runtime code has a bug and is returning a non-zero when it actually passed?). So
+                    # perhaps not retrying is the safest option.
+                    self.fail(task_key)
+                    self.log.error(
+                        "Lambda invocation for task: %s has failed to run with return code %s",
+                        task_key,
+                        return_code,
+                    )
+                # Remove the task from the tracking mapping.
+                self.running_tasks.pop(ser_task_key)
+
+            # Delete the message from the queue.
+            self.sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
+
+    def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
+        """
+        Adopt task instances which have an external_executor_id (the serialized task key).
+
+        Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
+
+        :param tis: The task instances to adopt.
+        """
+        with Stats.timer("lambda_executor.adopt_task_instances.duration"):
+            adopted_tis: list[TaskInstance] = []
+
+            if serialized_task_keys := [
+                (ti, ti.external_executor_id) for ti in tis if ti.external_executor_id
+            ]:
+                for ti, ser_task_key in serialized_task_keys:
+                    try:
+                        task_key = TaskInstanceKey.from_dict(json.loads(ser_task_key))
+                    except Exception:
+                        # If that task fails to deserialize, we should just skip it.
+                        self.log.exception(
+                            "Task failed to be adopted because the key could not be deserialized"
+                        )
+                        continue
+                    self.running_tasks[ser_task_key] = task_key
+                    adopted_tis.append(ti)
+
+            if adopted_tis:
+                tasks = [f"{task} in state {task.state}" for task in adopted_tis]
+                task_instance_str = "\n\t".join(tasks)
+                self.log.info(
+                    "Adopted the following %d tasks from a dead executor:\n\t%s",
+                    len(adopted_tis),
+                    task_instance_str,
+                )
+
+            not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+            return not_adopted_tis
+
+    def end(self, heartbeat_interval=10):
+        """
+        End execution. Poll until all outstanding tasks are marked as completed.
+
+        This is a blocking call and async Lambda tasks can not be cancelled, so this will wait until
+        all tasks are either completed or the timeout is reached.
+
+        :param heartbeat_interval: The interval in seconds to wait between checks for task completion.
+        """
+        self.log.info("Received signal to end, waiting for outstanding tasks to finish.")
+        time_to_wait = int(conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.END_WAIT_TIMEOUT))
+        start_time = timezone.utcnow()
+        while True:
+            if time_to_wait:
+                current_time = timezone.utcnow()
+                elapsed_time = (current_time - start_time).total_seconds()
+                if elapsed_time > time_to_wait:
+                    self.log.warning(
+                        "Timed out waiting for tasks to finish. Some tasks may not be handled gracefully"
+                        " as the executor is force ending due to timeout."
+                    )
+                    break
+            self.sync()
+            if not self.running_tasks:
+                self.log.info("All tasks completed; executor ending.")
+                break
+            self.log.info("Waiting for %d task(s) to complete.", len(self.running_tasks))
+            time.sleep(heartbeat_interval)
+
+    def terminate(self):
+        """Get called when the daemon receives a SIGTERM."""
+        self.log.warning("Terminating Lambda executor. In-flight tasks cannot be stopped.")
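For context on the SQS contract the executor above relies on: attempt_task_runs() invokes the function with a JSON payload of {"task_key": ..., "command": ...}, and process_queue() expects each result message body to carry that same serialized task key plus a return_code. The Lambda-side handler shipped in docker/app.py is not reproduced in this diff view; the snippet below is only an illustrative sketch of a reply that would satisfy process_queue(), with the queue URL and key fields as placeholder values.

    import json
    import boto3

    # Placeholder values; in practice the payload arrives in the Lambda event built by attempt_task_runs().
    result_queue_url = "https://sqs.us-east-1.amazonaws.com/123456789012/airflow-results"
    ser_task_key = json.dumps(
        {"dag_id": "example_dag", "task_id": "example_task", "run_id": "manual__2025-01-01", "try_number": 1, "map_index": -1}
    )

    # A zero return_code makes the executor call success(); anything else is marked failed.
    boto3.client("sqs").send_message(
        QueueUrl=result_queue_url,
        MessageBody=json.dumps({"task_key": ser_task_key, "return_code": 0}),
    )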
airflow/providers/amazon/aws/executors/aws_lambda/utils.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
+
+if TYPE_CHECKING:
+    from airflow.models.taskinstancekey import TaskInstanceKey
+
+
+CONFIG_GROUP_NAME = "aws_lambda_executor"
+INVALID_CREDENTIALS_EXCEPTIONS = [
+    "ExpiredTokenException",
+    "InvalidClientTokenId",
+    "UnrecognizedClientException",
+]
+
+
+@dataclass
+class LambdaQueuedTask:
+    """Represents a Lambda task that is queued. The task will be run in the next heartbeat."""
+
+    key: TaskInstanceKey
+    command: CommandType
+    queue: str
+    executor_config: ExecutorConfigType
+    attempt_number: int
+    next_attempt_time: datetime.datetime
+
+
+class InvokeLambdaKwargsConfigKeys(BaseConfigKeys):
+    """Config keys loaded which are valid lambda invoke args."""
+
+    FUNCTION_NAME = "function_name"
+    QUALIFIER = "function_qualifier"
+
+
+class AllLambdaConfigKeys(InvokeLambdaKwargsConfigKeys):
+    """All config keys which are related to the Lambda Executor."""
+
+    AWS_CONN_ID = "conn_id"
+    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+    MAX_INVOKE_ATTEMPTS = "max_run_task_attempts"
+    REGION_NAME = "region_name"
+    QUEUE_URL = "queue_url"
+    DLQ_URL = "dead_letter_queue_url"
+    END_WAIT_TIMEOUT = "end_wait_timeout"
+
+
+CommandType = Sequence[str]
+ExecutorConfigType = dict[str, Any]
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -131,7 +131,7 @@ class AwsEcsExecutor(BaseExecutor):
         ti = workload.ti
         self.queued_tasks[ti.key] = workload
 
-    def _process_workloads(self, workloads: list[workloads.All]) -> None:
+    def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
         from airflow.executors.workloads import ExecuteTask
 
         # Airflow V3 version
airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -101,14 +101,10 @@ def build_task_kwargs() -> dict:
     # The executor will overwrite the 'command' property during execution. Must always be the first container!
     task_kwargs["overrides"]["containerOverrides"][0]["command"] = []  # type: ignore
 
-    if any(
-        [
-            subnets := task_kwargs.pop(AllEcsConfigKeys.SUBNETS, None),
-            security_groups := task_kwargs.pop(AllEcsConfigKeys.SECURITY_GROUPS, None),
-            # Surrounding parens are for the walrus operator to function correctly along with the None check
-            (assign_public_ip := task_kwargs.pop(AllEcsConfigKeys.ASSIGN_PUBLIC_IP, None)) is not None,
-        ]
-    ):
+    subnets = task_kwargs.pop(AllEcsConfigKeys.SUBNETS, None)
+    security_groups = task_kwargs.pop(AllEcsConfigKeys.SECURITY_GROUPS, None)
+    assign_public_ip = task_kwargs.pop(AllEcsConfigKeys.ASSIGN_PUBLIC_IP, None)
+    if subnets or security_groups or assign_public_ip != "False":
         network_config = prune_dict(
             {
                 "awsvpcConfiguration": {
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -58,12 +58,16 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
 from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
 from airflow.providers.amazon.aws.utils.suppress import return_on_error
+from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.helpers import exactly_one
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
 
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk.exceptions import AirflowRuntimeError
+
 if TYPE_CHECKING:
     from aiobotocore.session import AioSession
     from botocore.client import ClientMeta
@@ -603,12 +607,24 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
         """Get the Airflow Connection object and wrap it in helper (cached)."""
         connection = None
         if self.aws_conn_id:
+            possible_exceptions: tuple[type[Exception], ...]
+
+            if AIRFLOW_V_3_0_PLUS:
+                possible_exceptions = (AirflowNotFoundException, AirflowRuntimeError)
+            else:
+                possible_exceptions = (AirflowNotFoundException,)
+
             try:
                 connection = self.get_connection(self.aws_conn_id)
-            except AirflowNotFoundException:
-                self.log.warning(
-                    "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
-                )
+            except possible_exceptions as e:
+                if isinstance(
+                    e, AirflowNotFoundException
+                ) or f"Connection with ID {self.aws_conn_id} not found" in str(e):
+                    self.log.warning(
+                        "Unable to find AWS Connection ID '%s', switching to empty.", self.aws_conn_id
+                    )
+                else:
+                    raise
 
         return AwsConnectionWrapper(
             conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
airflow/providers/amazon/aws/hooks/eks.py
@@ -79,16 +79,25 @@ class NodegroupStates(Enum):
 
 
 COMMAND = """
+            export PYTHON_OPERATORS_VIRTUAL_ENV_MODE=1
             output=$({python_executable} -m airflow.providers.amazon.aws.utils.eks_get_token \
                 --cluster-name {eks_cluster_name} {args} 2>&1)
 
-            if [ $? -ne 0 ]; then
-                echo "Error running the script"
-                exit 1
+            status=$?
+
+            if [ "$status" -ne 0 ]; then
+                printf '%s' "$output" >&2
+                exit "$status"
             fi
 
-            expiration_timestamp=$(echo "$output" | grep -oP 'expirationTimestamp: \\K[^,]+')
-            token=$(echo "$output" | grep -oP 'token: \\K[^,]+')
+            # Use pure bash below to parse so that it's posix compliant
+
+            last_line=${{output##*$'\\n'}}  # strip everything up to the last newline
+
+            timestamp=${{last_line#expirationTimestamp: }}  # drop the label
+            timestamp=${{timestamp%%,*}}  # keep up to the first comma
+
+            token=${{last_line##*, token: }}  # text after ", token: "
 
             json_string=$(printf '{{"kind": "ExecCredential","apiVersion": \
                 "client.authentication.k8s.io/v1alpha1","spec": {{}},"status": \