apache-airflow-providers-amazon 8.19.0rc1__py3-none-any.whl → 8.20.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (77)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +4 -2
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +22 -7
  4. airflow/providers/amazon/aws/auth_manager/{cli → avp}/schema.json +34 -2
  5. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +91 -170
  6. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +7 -32
  7. airflow/providers/amazon/aws/auth_manager/cli/definition.py +1 -1
  8. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +1 -0
  9. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  10. airflow/providers/amazon/aws/executors/batch/__init__.py +16 -0
  11. airflow/providers/amazon/aws/executors/batch/batch_executor.py +420 -0
  12. airflow/providers/amazon/aws/executors/batch/batch_executor_config.py +87 -0
  13. airflow/providers/amazon/aws/executors/batch/boto_schema.py +67 -0
  14. airflow/providers/amazon/aws/executors/batch/utils.py +160 -0
  15. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +61 -18
  16. airflow/providers/amazon/aws/executors/ecs/utils.py +8 -13
  17. airflow/providers/amazon/aws/executors/utils/base_config_keys.py +25 -0
  18. airflow/providers/amazon/aws/hooks/athena.py +1 -0
  19. airflow/providers/amazon/aws/hooks/base_aws.py +1 -0
  20. airflow/providers/amazon/aws/hooks/batch_client.py +4 -3
  21. airflow/providers/amazon/aws/hooks/batch_waiters.py +1 -0
  22. airflow/providers/amazon/aws/hooks/bedrock.py +59 -0
  23. airflow/providers/amazon/aws/hooks/chime.py +1 -0
  24. airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -0
  25. airflow/providers/amazon/aws/hooks/datasync.py +1 -0
  26. airflow/providers/amazon/aws/hooks/dynamodb.py +1 -0
  27. airflow/providers/amazon/aws/hooks/eks.py +1 -0
  28. airflow/providers/amazon/aws/hooks/glue.py +13 -5
  29. airflow/providers/amazon/aws/hooks/glue_catalog.py +1 -0
  30. airflow/providers/amazon/aws/hooks/kinesis.py +1 -0
  31. airflow/providers/amazon/aws/hooks/lambda_function.py +1 -0
  32. airflow/providers/amazon/aws/hooks/rds.py +1 -0
  33. airflow/providers/amazon/aws/hooks/s3.py +24 -30
  34. airflow/providers/amazon/aws/hooks/ses.py +1 -0
  35. airflow/providers/amazon/aws/hooks/sns.py +1 -0
  36. airflow/providers/amazon/aws/hooks/sqs.py +1 -0
  37. airflow/providers/amazon/aws/operators/athena.py +2 -2
  38. airflow/providers/amazon/aws/operators/base_aws.py +4 -1
  39. airflow/providers/amazon/aws/operators/batch.py +4 -2
  40. airflow/providers/amazon/aws/operators/bedrock.py +252 -0
  41. airflow/providers/amazon/aws/operators/cloud_formation.py +1 -0
  42. airflow/providers/amazon/aws/operators/datasync.py +1 -0
  43. airflow/providers/amazon/aws/operators/ecs.py +9 -10
  44. airflow/providers/amazon/aws/operators/eks.py +1 -0
  45. airflow/providers/amazon/aws/operators/emr.py +57 -7
  46. airflow/providers/amazon/aws/operators/s3.py +1 -0
  47. airflow/providers/amazon/aws/operators/sns.py +1 -0
  48. airflow/providers/amazon/aws/operators/sqs.py +1 -0
  49. airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -0
  50. airflow/providers/amazon/aws/secrets/systems_manager.py +1 -0
  51. airflow/providers/amazon/aws/sensors/base_aws.py +4 -1
  52. airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
  53. airflow/providers/amazon/aws/sensors/cloud_formation.py +1 -0
  54. airflow/providers/amazon/aws/sensors/eks.py +3 -4
  55. airflow/providers/amazon/aws/sensors/sqs.py +2 -1
  56. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -2
  57. airflow/providers/amazon/aws/transfers/base.py +1 -0
  58. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -0
  59. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -0
  60. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -0
  61. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -0
  62. airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -0
  63. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -0
  64. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +21 -19
  65. airflow/providers/amazon/aws/triggers/bedrock.py +61 -0
  66. airflow/providers/amazon/aws/triggers/eks.py +1 -1
  67. airflow/providers/amazon/aws/triggers/redshift_cluster.py +1 -0
  68. airflow/providers/amazon/aws/triggers/s3.py +4 -2
  69. airflow/providers/amazon/aws/triggers/sagemaker.py +6 -4
  70. airflow/providers/amazon/aws/utils/emailer.py +1 -0
  71. airflow/providers/amazon/aws/waiters/bedrock.json +42 -0
  72. airflow/providers/amazon/get_provider_info.py +86 -1
  73. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/METADATA +10 -9
  74. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/RECORD +77 -66
  75. /airflow/providers/amazon/aws/executors/{ecs/Dockerfile → Dockerfile} +0 -0
  76. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/WHEEL +0 -0
  77. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/executors/batch/batch_executor.py
@@ -0,0 +1,420 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ """AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job."""
+
+ from __future__ import annotations
+
+ import time
+ from collections import defaultdict, deque
+ from copy import deepcopy
+ from typing import TYPE_CHECKING, Any, Dict, List
+
+ from botocore.exceptions import ClientError, NoCredentialsError
+
+ from airflow.configuration import conf
+ from airflow.exceptions import AirflowException
+ from airflow.executors.base_executor import BaseExecutor
+ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+     calculate_next_attempt_delay,
+     exponential_backoff_retry,
+ )
+ from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+ from airflow.utils import timezone
+ from airflow.utils.helpers import merge_dicts
+
+ if TYPE_CHECKING:
+     from airflow.models.taskinstance import TaskInstanceKey
+
+ from airflow.providers.amazon.aws.executors.batch.boto_schema import (
+     BatchDescribeJobsResponseSchema,
+     BatchSubmitJobResponseSchema,
+ )
+ from airflow.providers.amazon.aws.executors.batch.utils import (
+     CONFIG_DEFAULTS,
+     CONFIG_GROUP_NAME,
+     AllBatchConfigKeys,
+     BatchJob,
+     BatchJobCollection,
+     BatchQueuedJob,
+ )
+ from airflow.utils.state import State
+
+ CommandType = List[str]
+ ExecutorConfigType = Dict[str, Any]
+
+ INVALID_CREDENTIALS_EXCEPTIONS = [
+     "ExpiredTokenException",
+     "InvalidClientTokenId",
+     "UnrecognizedClientException",
+ ]
+
+
+ class AwsBatchExecutor(BaseExecutor):
+     """
+     The Airflow Scheduler creates a shell command, and passes it to the executor.
+
+     This Batch Executor simply runs said airflow command in a resource provisioned and managed
+     by AWS Batch. It then periodically checks in with the launched jobs (via job-ids) to
+     determine the status.
+     The `submit_job_kwargs` is a dictionary that should match the kwargs for the
+     SubmitJob definition per AWS' documentation (see below).
+     For maximum flexibility, individual tasks can specify `executor_config` as a dictionary, with keys that
+     match the request syntax for the SubmitJob definition per AWS' documentation (see link below). The
+     `executor_config` will update the `submit_job_kwargs` dictionary when calling the task. This allows
+     individual jobs to specify CPU, memory, GPU, env variables, etc.
+     Prerequisite: proper configuration of Boto3 library
+     .. seealso:: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html for
+        authentication and access-key management. You can store an environment variable, set up aws config
+        from console, or use IAM roles.
+     .. seealso:: https://docs.aws.amazon.com/batch/latest/APIReference/API_SubmitJob.html for an
+        Airflow TaskInstance's executor_config.
+     """
+
+     # Maximum number of retries to submit a Batch Job.
+     MAX_SUBMIT_JOB_ATTEMPTS = conf.get(
+         CONFIG_GROUP_NAME,
+         AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS,
+         fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.MAX_SUBMIT_JOB_ATTEMPTS],
+     )
+
+     # AWS only allows a maximum number of JOBs in the describe_jobs function
+     DESCRIBE_JOBS_BATCH_SIZE = 99
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.active_workers = BatchJobCollection()
+         self.pending_jobs: deque = deque()
+         self.attempts_since_last_successful_connection = 0
+         self.load_batch_connection(check_connection=False)
+         self.IS_BOTO_CONNECTION_HEALTHY = False
+         self.submit_job_kwargs = self._load_submit_kwargs()
+
+     def check_health(self):
+         """Make a test API call to check the health of the Batch Executor."""
+         success_status = "succeeded."
+         status = success_status
+
+         try:
+             invalid_job_id = "a" * 32
+             self.batch.describe_jobs(jobs=[invalid_job_id])
+             # If an empty response was received, then that is considered to be the success case.
+         except ClientError as ex:
+             error_code = ex.response["Error"]["Code"]
+             error_message = ex.response["Error"]["Message"]
+             status = f"failed because: {error_code}: {error_message}. "
+         except Exception as e:
+             # Any non-ClientError exceptions. This can include Botocore exceptions for example
+             status = f"failed because: {e}. "
+         finally:
+             msg_prefix = "Batch Executor health check has %s"
+             if status == success_status:
+                 self.IS_BOTO_CONNECTION_HEALTHY = True
+                 self.log.info(msg_prefix, status)
+             else:
+                 msg_error_suffix = (
+                     "The Batch executor will not be able to run Airflow tasks until the issue is addressed."
+                 )
+                 raise AirflowException(msg_prefix % status + msg_error_suffix)
+
+     def start(self):
+         """Call this when the Executor is run for the first time by the scheduler."""
+         check_health = conf.getboolean(
+             CONFIG_GROUP_NAME, AllBatchConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
+         )
+
+         if not check_health:
+             return
+
+         self.log.info("Starting Batch Executor and determining health...")
+         try:
+             self.check_health()
+         except AirflowException:
+             self.log.error("Stopping the Airflow Scheduler from starting until the issue is resolved.")
+             raise
+
+     def load_batch_connection(self, check_connection: bool = True):
+         self.log.info("Loading Connection information")
+         aws_conn_id = conf.get(
+             CONFIG_GROUP_NAME,
+             AllBatchConfigKeys.AWS_CONN_ID,
+             fallback=CONFIG_DEFAULTS[AllBatchConfigKeys.AWS_CONN_ID],
+         )
+         region_name = conf.get(CONFIG_GROUP_NAME, AllBatchConfigKeys.REGION_NAME, fallback=None)
+         self.batch = BatchClientHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
+         self.attempts_since_last_successful_connection += 1
+         self.last_connection_reload = timezone.utcnow()
+
+         if check_connection:
+             self.check_health()
+             self.attempts_since_last_successful_connection = 0
+
+     def sync(self):
+         """Sync will get called periodically by the heartbeat method in the scheduler."""
+         if not self.IS_BOTO_CONNECTION_HEALTHY:
+             exponential_backoff_retry(
+                 self.last_connection_reload,
+                 self.attempts_since_last_successful_connection,
+                 self.load_batch_connection,
+             )
+             if not self.IS_BOTO_CONNECTION_HEALTHY:
+                 return
+         try:
+             self.sync_running_jobs()
+             self.attempt_submit_jobs()
+         except (ClientError, NoCredentialsError) as error:
+             error_code = error.response["Error"]["Code"]
+             if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                 self.IS_BOTO_CONNECTION_HEALTHY = False
+                 self.log.warning(
+                     "AWS credentials are either missing or expired: %s.\nRetrying connection", error
+                 )
+         except Exception:
+             # We catch any and all exceptions because otherwise they would bubble
+             # up and kill the scheduler process
+             self.log.exception("Failed to sync %s", self.__class__.__name__)
+
+     def sync_running_jobs(self):
+         all_job_ids = self.active_workers.get_all_jobs()
+         if not all_job_ids:
+             self.log.debug("No active Airflow tasks, skipping sync")
+             return
+         describe_job_response = self._describe_jobs(all_job_ids)
+
+         self.log.debug("Active Workers: %s", describe_job_response)
+
+         for job in describe_job_response:
+             if job.get_job_state() == State.FAILED:
+                 self._handle_failed_job(job)
+             elif job.get_job_state() == State.SUCCESS:
+                 task_key = self.active_workers.pop_by_id(job.job_id)
+                 self.success(task_key)
+
+     def _handle_failed_job(self, job):
+         """
+         Handle a failed AWS Batch job.
+
+         If an API failure occurs when running a Batch job, the job is rescheduled.
+         """
+         # A failed job here refers to a job that has been marked Failed by AWS Batch, which is not
+         # necessarily the same as being marked Failed by Airflow. AWS Batch will mark a job Failed
+         # if the job fails before the Airflow process on the container has started. These failures
+         # can be caused by a Batch API failure, container misconfiguration etc.
+         # If the container is able to start up and run the Airflow process, any failures after that
+         # (i.e. DAG failures) will not be marked as Failed by AWS Batch, because Batch only assumes
+         # responsibility for ensuring the process started. Failures in the DAG will be caught by
+         # Airflow, which will be handled separately.
+         job_info = self.active_workers.id_to_job_info[job.job_id]
+         task_key = self.active_workers.id_to_key[job.job_id]
+         task_cmd = job_info.cmd
+         queue = job_info.queue
+         exec_info = job_info.config
+         failure_count = self.active_workers.failure_count_by_id(job_id=job.job_id)
+         if int(failure_count) < int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
+             self.log.warning(
+                 "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
+                 task_key,
+                 job.status_reason,
+                 failure_count,
+                 self.__class__.MAX_SUBMIT_JOB_ATTEMPTS,
+                 job.job_id,
+             )
+             self.active_workers.increment_failure_count(job_id=job.job_id)
+             self.active_workers.pop_by_id(job.job_id)
+             self.pending_jobs.append(
+                 BatchQueuedJob(
+                     task_key,
+                     task_cmd,
+                     queue,
+                     exec_info,
+                     failure_count + 1,
+                     timezone.utcnow() + calculate_next_attempt_delay(failure_count),
+                 )
+             )
+         else:
+             self.log.error(
+                 "Airflow task %s has failed a maximum of %s times. Marking as failed",
+                 task_key,
+                 failure_count,
+             )
+             self.active_workers.pop_by_id(job.job_id)
+             self.fail(task_key)
+
+     def attempt_submit_jobs(self):
+         """
+         Attempt to submit all jobs submitted to the Executor.
+
+         For each iteration of the sync() method, every pending job is submitted to Batch.
+         If a job fails validation, it will be put at the back of the queue to be reattempted
+         in the next iteration of the sync() method, unless it has exceeded the maximum number of
+         attempts. If a job exceeds the maximum number of attempts, it is removed from the queue.
+         """
+         failure_reasons = defaultdict(int)
+         for _ in range(len(self.pending_jobs)):
+             batch_job = self.pending_jobs.popleft()
+             key = batch_job.key
+             cmd = batch_job.command
+             queue = batch_job.queue
+             exec_config = batch_job.executor_config
+             attempt_number = batch_job.attempt_number
+             _failure_reason = []
+             if timezone.utcnow() < batch_job.next_attempt_time:
+                 self.pending_jobs.append(batch_job)
+                 continue
+             try:
+                 submit_job_response = self._submit_job(key, cmd, queue, exec_config or {})
+             except NoCredentialsError:
+                 self.pending_jobs.append(batch_job)
+                 raise
+             except ClientError as e:
+                 error_code = e.response["Error"]["Code"]
+                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
+                     self.pending_jobs.append(batch_job)
+                     raise
+                 _failure_reason.append(str(e))
+             except Exception as e:
+                 _failure_reason.append(str(e))
+
+             if _failure_reason:
+                 for reason in _failure_reason:
+                     failure_reasons[reason] += 1
+
+                 if attempt_number >= int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
+                     self.log.error(
+                         "This job has been unsuccessfully attempted too many times (%s). Dropping the task.",
+                         attempt_number,
+                     )
+                     self.fail(key=key)
+                 else:
+                     batch_job.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                         attempt_number
+                     )
+                     batch_job.attempt_number += 1
+                     self.pending_jobs.append(batch_job)
+             else:
+                 # Success case
+                 self.active_workers.add_job(
+                     job_id=submit_job_response["job_id"],
+                     airflow_task_key=key,
+                     airflow_cmd=cmd,
+                     queue=queue,
+                     exec_config=exec_config,
+                     attempt_number=attempt_number,
+                 )
+         if failure_reasons:
+             self.log.error(
+                 "Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
+                 dict(failure_reasons),
+             )
+
+     def _describe_jobs(self, job_ids) -> list[BatchJob]:
+         all_jobs = []
+         for i in range(0, len(job_ids), self.__class__.DESCRIBE_JOBS_BATCH_SIZE):
+             batched_job_ids = job_ids[i : i + self.__class__.DESCRIBE_JOBS_BATCH_SIZE]
+             if not batched_job_ids:
+                 continue
+             boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids)
+
+             describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
+             all_jobs.extend(describe_tasks_response["jobs"])
+         return all_jobs
+
+     def execute_async(self, key: TaskInstanceKey, command: CommandType, queue=None, executor_config=None):
+         """Save the task to be executed in the next sync, which submits it with Boto3's SubmitJob API."""
+         if executor_config and "command" in executor_config:
+             raise ValueError('Executor Config should never override "command"')
+
+         self.pending_jobs.append(
+             BatchQueuedJob(
+                 key=key,
+                 command=command,
+                 queue=queue,
+                 executor_config=executor_config or {},
+                 attempt_number=1,
+                 next_attempt_time=timezone.utcnow(),
+             )
+         )
+
+     def _submit_job(
+         self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
+     ) -> str:
+         """
+         Override the submit_job_kwargs and call the boto3 submit_job API endpoint.
+
+         The command and executor config will be placed in the container-override section of the JSON request,
+         before calling Boto3's "submit_job" function.
+         """
+         submit_job_api = self._submit_job_kwargs(key, cmd, queue, exec_config)
+
+         boto_submit_job = self.batch.submit_job(**submit_job_api)
+         submit_job_response = BatchSubmitJobResponseSchema().load(boto_submit_job)
+         return submit_job_response
+
+     def _submit_job_kwargs(
+         self, key: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
+     ) -> dict:
+         """
+         Override the Airflow command to update the container overrides so kwargs are specific to this task.
+
+         One last chance to modify Boto3's "submit_job" kwarg params before it gets passed into the Boto3
+         client. For the latest kwarg parameters:
+         .. seealso:: https://docs.aws.amazon.com/batch/latest/APIReference/API_SubmitJob.html
+         """
+         submit_job_api = deepcopy(self.submit_job_kwargs)
+         submit_job_api = merge_dicts(submit_job_api, exec_config)
+         submit_job_api["containerOverrides"]["command"] = cmd
+         if "environment" not in submit_job_api["containerOverrides"]:
+             submit_job_api["containerOverrides"]["environment"] = []
+         submit_job_api["containerOverrides"]["environment"].append(
+             {"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}
+         )
+         return submit_job_api
+
+     def end(self, heartbeat_interval=10):
+         """Wait for all currently running tasks to end and prevent any new jobs from running."""
+         try:
+             while True:
+                 self.sync()
+                 if not self.active_workers:
+                     break
+                 time.sleep(heartbeat_interval)
+         except Exception:
+             # This should never happen because sync() should never raise an exception.
+             self.log.exception("Failed to end %s", self.__class__.__name__)
+
+     def terminate(self):
+         """Kill all Batch Jobs by calling Boto3's TerminateJob API."""
+         try:
+             for job_id in self.active_workers.get_all_jobs():
+                 self.batch.terminate_job(jobId=job_id, reason="Airflow Executor received a SIGTERM")
+             self.end()
+         except Exception:
+             # We catch any and all exceptions because otherwise they would bubble
+             # up and kill the scheduler process.
+             self.log.exception("Failed to terminate %s", self.__class__.__name__)
+
+     @staticmethod
+     def _load_submit_kwargs() -> dict:
+         from airflow.providers.amazon.aws.executors.batch.batch_executor_config import build_submit_kwargs
+
+         submit_kwargs = build_submit_kwargs()
+
+         if "containerOverrides" not in submit_kwargs or "command" not in submit_kwargs["containerOverrides"]:
+             raise KeyError(
+                 'SubmitJob API needs kwargs["containerOverrides"]["command"] field,'
+                 " and value should be NULL or empty."
+             )
+         return submit_kwargs
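
The class docstring above notes that each task may pass an `executor_config` whose keys follow the SubmitJob request syntax and that it is merged into `submit_job_kwargs` when the job is launched. A minimal sketch of what that looks like from the DAG author's side, assuming the Batch executor is configured; the task name and override values below are illustrative, not provider defaults, and `command` must not be set here because `execute_async()` rejects it:

    from airflow.decorators import task

    @task(
        executor_config={
            "containerOverrides": {
                # Per-task resource overrides, in AWS Batch SubmitJob request syntax.
                "resourceRequirements": [
                    {"type": "VCPU", "value": "2"},
                    {"type": "MEMORY", "value": "4096"},
                ],
                # Extra environment variables for this task's container (example values).
                "environment": [{"name": "MY_FLAG", "value": "1"}],
            }
        }
    )
    def heavy_task():
        ...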
airflow/providers/amazon/aws/executors/batch/batch_executor_config.py
@@ -0,0 +1,87 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+
+ """
+ AWS Batch Executor configuration.
+
+ This is the configuration for calling the Batch ``submit_job`` function. The AWS Batch Executor calls
+ Boto3's ``submit_job(**kwargs)`` function with the kwargs templated by this dictionary. See the URL
+ below for documentation on the parameters accepted by the Boto3 submit_job function.
+
+ .. seealso::
+     https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch/client/submit_job.html
+
+ """
+
+ from __future__ import annotations
+
+ import json
+ from json import JSONDecodeError
+ from typing import TYPE_CHECKING
+
+ from airflow.configuration import conf
+ from airflow.providers.amazon.aws.executors.batch.utils import (
+     CONFIG_GROUP_NAME,
+     AllBatchConfigKeys,
+     BatchSubmitJobKwargsConfigKeys,
+ )
+ from airflow.providers.amazon.aws.executors.ecs.utils import camelize_dict_keys
+ from airflow.utils.helpers import prune_dict
+
+
+ def _fetch_templated_kwargs() -> dict[str, str]:
+     submit_job_kwargs_value = conf.get(
+         CONFIG_GROUP_NAME, AllBatchConfigKeys.SUBMIT_JOB_KWARGS, fallback=dict()
+     )
+     return json.loads(str(submit_job_kwargs_value))
+
+
+ def _fetch_config_values() -> dict[str, str]:
+     return prune_dict(
+         {key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in BatchSubmitJobKwargsConfigKeys()}
+     )
+
+
+ def build_submit_kwargs() -> dict:
+     job_kwargs = _fetch_config_values()
+     job_kwargs.update(_fetch_templated_kwargs())
+
+     if "containerOverrides" not in job_kwargs:
+         job_kwargs["containerOverrides"] = {}  # type: ignore
+     job_kwargs["containerOverrides"]["command"] = []  # type: ignore
+
+     if "nodeOverrides" in job_kwargs:
+         raise KeyError("Multi-node jobs are not currently supported.")
+     if "eksPropertiesOverride" in job_kwargs:
+         raise KeyError("Eks jobs are not currently supported.")
+
+     if TYPE_CHECKING:
+         assert isinstance(job_kwargs, dict)
+     # some checks with some helpful errors
+     if "containerOverrides" not in job_kwargs or "command" not in job_kwargs["containerOverrides"]:
+         raise KeyError(
+             'SubmitJob API needs kwargs["containerOverrides"]["command"] field,'
+             " and value should be NULL or empty."
+         )
+     job_kwargs = camelize_dict_keys(job_kwargs)
+
+     try:
+         json.loads(json.dumps(job_kwargs))
+     except JSONDecodeError:
+         raise ValueError("AWS Batch Executor config values must be JSON serializable.")
+
+     return job_kwargs
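
build_submit_kwargs() above combines per-option config values (via _fetch_config_values()) with the JSON blob from the submit_job_kwargs option (via _fetch_templated_kwargs()), then guarantees an empty containerOverrides.command placeholder that the executor later replaces with the actual Airflow task command. A rough sketch of that merge logic with made-up values (the real function reads these from Airflow config and camelizes the key names with camelize_dict_keys()):

    import json

    # Hypothetical values for the individual executor options (job name, queue, definition).
    config_values = {"jobName": "airflow-task", "jobQueue": "my-queue", "jobDefinition": "my-job-def"}
    # Hypothetical JSON blob a user might put in the submit_job_kwargs option.
    templated = json.loads('{"containerOverrides": {"resourceRequirements": [{"type": "VCPU", "value": "1"}]}}')

    # Templated kwargs win over individual options, mirroring job_kwargs.update(...) above.
    job_kwargs = {**config_values, **templated}
    # Reserve the command placeholder; AwsBatchExecutor._submit_job_kwargs() fills it in per task.
    job_kwargs.setdefault("containerOverrides", {})["command"] = []
    print(json.dumps(job_kwargs, indent=2))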
airflow/providers/amazon/aws/executors/batch/boto_schema.py
@@ -0,0 +1,67 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from marshmallow import EXCLUDE, Schema, fields, post_load
+
+ from airflow.providers.amazon.aws.executors.batch.utils import BatchJob
+
+
+ class BatchSubmitJobResponseSchema(Schema):
+     """API Response for SubmitJob."""
+
+     # The unique identifier for the job.
+     job_id = fields.String(data_key="jobId", required=True)
+
+     class Meta:
+         """Options object for a Schema. See Schema.Meta for more details and valid values."""
+
+         unknown = EXCLUDE
+
+
+ class BatchJobDetailSchema(Schema):
+     """API Response for Describe Jobs."""
+
+     # The unique identifier for the job.
+     job_id = fields.String(data_key="jobId", required=True)
+     # The current status for the job:
+     # 'SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'
+     status = fields.String(required=True)
+     # A short, human-readable string to provide additional details about the current status of the job.
+     status_reason = fields.String(data_key="statusReason")
+
+     @post_load
+     def make_job(self, data, **kwargs):
+         """Overwrite marshmallow load() to return an instance of BatchJob instead of a dictionary."""
+         return BatchJob(**data)
+
+     class Meta:
+         """Options object for a Schema. See Schema.Meta for more details and valid values."""
+
+         unknown = EXCLUDE
+
+
+ class BatchDescribeJobsResponseSchema(Schema):
+     """API Response for Describe Jobs."""
+
+     # The list of jobs
+     jobs = fields.List(fields.Nested(BatchJobDetailSchema), required=True)
+
+     class Meta:
+         """Options object for a Schema. See Schema.Meta for more details and valid values."""
+
+         unknown = EXCLUDE
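
A small usage sketch of the schemas above, using a made-up DescribeJobs-style payload. Loading it through BatchDescribeJobsResponseSchema yields BatchJob objects (via the post_load hook) rather than plain dicts, exposing the job_id and get_job_state() accessors that the executor's sync loop relies on:

    from airflow.providers.amazon.aws.executors.batch.boto_schema import BatchDescribeJobsResponseSchema

    sample_response = {
        "jobs": [
            {
                "jobId": "a1b2c3d4-5678-90ab-cdef-111122223333",
                "status": "SUCCEEDED",
                "statusReason": "Essential container in task exited",
                "jobQueue": "my-queue",  # unknown fields are dropped because Meta.unknown = EXCLUDE
            }
        ]
    }

    jobs = BatchDescribeJobsResponseSchema().load(sample_response)["jobs"]
    job = jobs[0]  # a BatchJob instance from utils.py, not a plain dict
    print(job.job_id, job.get_job_state())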