apache-airflow-providers-amazon 8.19.0rc1__py3-none-any.whl → 8.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. airflow/providers/amazon/__init__.py +1 -1
  2. airflow/providers/amazon/aws/auth_manager/avp/entities.py +4 -2
  3. airflow/providers/amazon/aws/auth_manager/avp/facade.py +22 -7
  4. airflow/providers/amazon/aws/auth_manager/{cli → avp}/schema.json +34 -2
  5. airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +91 -170
  6. airflow/providers/amazon/aws/auth_manager/cli/avp_commands.py +7 -32
  7. airflow/providers/amazon/aws/auth_manager/cli/definition.py +1 -1
  8. airflow/providers/amazon/aws/auth_manager/cli/idc_commands.py +1 -0
  9. airflow/providers/amazon/aws/auth_manager/views/auth.py +1 -1
  10. airflow/providers/amazon/aws/executors/batch/__init__.py +16 -0
  11. airflow/providers/amazon/aws/executors/batch/batch_executor.py +420 -0
  12. airflow/providers/amazon/aws/executors/batch/batch_executor_config.py +87 -0
  13. airflow/providers/amazon/aws/executors/batch/boto_schema.py +67 -0
  14. airflow/providers/amazon/aws/executors/batch/utils.py +160 -0
  15. airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +61 -18
  16. airflow/providers/amazon/aws/executors/ecs/utils.py +8 -13
  17. airflow/providers/amazon/aws/executors/utils/base_config_keys.py +25 -0
  18. airflow/providers/amazon/aws/hooks/athena.py +1 -0
  19. airflow/providers/amazon/aws/hooks/base_aws.py +1 -0
  20. airflow/providers/amazon/aws/hooks/batch_client.py +4 -3
  21. airflow/providers/amazon/aws/hooks/batch_waiters.py +1 -0
  22. airflow/providers/amazon/aws/hooks/bedrock.py +59 -0
  23. airflow/providers/amazon/aws/hooks/chime.py +1 -0
  24. airflow/providers/amazon/aws/hooks/cloud_formation.py +1 -0
  25. airflow/providers/amazon/aws/hooks/datasync.py +1 -0
  26. airflow/providers/amazon/aws/hooks/dynamodb.py +1 -0
  27. airflow/providers/amazon/aws/hooks/eks.py +1 -0
  28. airflow/providers/amazon/aws/hooks/glue.py +13 -5
  29. airflow/providers/amazon/aws/hooks/glue_catalog.py +1 -0
  30. airflow/providers/amazon/aws/hooks/kinesis.py +1 -0
  31. airflow/providers/amazon/aws/hooks/lambda_function.py +1 -0
  32. airflow/providers/amazon/aws/hooks/rds.py +1 -0
  33. airflow/providers/amazon/aws/hooks/s3.py +24 -30
  34. airflow/providers/amazon/aws/hooks/ses.py +1 -0
  35. airflow/providers/amazon/aws/hooks/sns.py +1 -0
  36. airflow/providers/amazon/aws/hooks/sqs.py +1 -0
  37. airflow/providers/amazon/aws/operators/athena.py +2 -2
  38. airflow/providers/amazon/aws/operators/base_aws.py +4 -1
  39. airflow/providers/amazon/aws/operators/batch.py +4 -2
  40. airflow/providers/amazon/aws/operators/bedrock.py +252 -0
  41. airflow/providers/amazon/aws/operators/cloud_formation.py +1 -0
  42. airflow/providers/amazon/aws/operators/datasync.py +1 -0
  43. airflow/providers/amazon/aws/operators/ecs.py +9 -10
  44. airflow/providers/amazon/aws/operators/eks.py +1 -0
  45. airflow/providers/amazon/aws/operators/emr.py +57 -7
  46. airflow/providers/amazon/aws/operators/s3.py +1 -0
  47. airflow/providers/amazon/aws/operators/sns.py +1 -0
  48. airflow/providers/amazon/aws/operators/sqs.py +1 -0
  49. airflow/providers/amazon/aws/secrets/secrets_manager.py +1 -0
  50. airflow/providers/amazon/aws/secrets/systems_manager.py +1 -0
  51. airflow/providers/amazon/aws/sensors/base_aws.py +4 -1
  52. airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
  53. airflow/providers/amazon/aws/sensors/cloud_formation.py +1 -0
  54. airflow/providers/amazon/aws/sensors/eks.py +3 -4
  55. airflow/providers/amazon/aws/sensors/sqs.py +2 -1
  56. airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py +4 -2
  57. airflow/providers/amazon/aws/transfers/base.py +1 -0
  58. airflow/providers/amazon/aws/transfers/exasol_to_s3.py +1 -0
  59. airflow/providers/amazon/aws/transfers/gcs_to_s3.py +1 -0
  60. airflow/providers/amazon/aws/transfers/google_api_to_s3.py +1 -0
  61. airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py +1 -0
  62. airflow/providers/amazon/aws/transfers/http_to_s3.py +1 -0
  63. airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py +1 -0
  64. airflow/providers/amazon/aws/transfers/redshift_to_s3.py +21 -19
  65. airflow/providers/amazon/aws/triggers/bedrock.py +61 -0
  66. airflow/providers/amazon/aws/triggers/eks.py +1 -1
  67. airflow/providers/amazon/aws/triggers/redshift_cluster.py +1 -0
  68. airflow/providers/amazon/aws/triggers/s3.py +4 -2
  69. airflow/providers/amazon/aws/triggers/sagemaker.py +6 -4
  70. airflow/providers/amazon/aws/utils/emailer.py +1 -0
  71. airflow/providers/amazon/aws/waiters/bedrock.json +42 -0
  72. airflow/providers/amazon/get_provider_info.py +86 -1
  73. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/METADATA +10 -9
  74. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/RECORD +77 -66
  75. /airflow/providers/amazon/aws/executors/{ecs/Dockerfile → Dockerfile} +0 -0
  76. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/WHEEL +0 -0
  77. {apache_airflow_providers_amazon-8.19.0rc1.dist-info → apache_airflow_providers_amazon-8.20.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/executors/batch/utils.py
@@ -0,0 +1,160 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ import datetime
+ from collections import defaultdict
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Any, Dict, List
+
+ from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
+ from airflow.utils.state import State
+
+ if TYPE_CHECKING:
+     from airflow.models.taskinstance import TaskInstanceKey
+
+ CommandType = List[str]
+ ExecutorConfigType = Dict[str, Any]
+
+ CONFIG_GROUP_NAME = "aws_batch_executor"
+
+ CONFIG_DEFAULTS = {
+     "conn_id": "aws_default",
+     "max_submit_job_attempts": "3",
+     "check_health_on_startup": "True",
+ }
+
+
+ @dataclass
+ class BatchQueuedJob:
+     """Represents a Batch job that is queued. The job will be run in the next heartbeat."""
+
+     key: TaskInstanceKey
+     command: CommandType
+     queue: str
+     executor_config: ExecutorConfigType
+     attempt_number: int
+     next_attempt_time: datetime.datetime
+
+
+ @dataclass
+ class BatchJobInfo:
+     """Contains information about a currently running Batch job."""
+
+     cmd: CommandType
+     queue: str
+     config: ExecutorConfigType
+
+
+ class BatchJob:
+     """Data Transfer Object for an AWS Batch Job."""
+
+     STATE_MAPPINGS = {
+         "SUBMITTED": State.QUEUED,
+         "PENDING": State.QUEUED,
+         "RUNNABLE": State.QUEUED,
+         "STARTING": State.QUEUED,
+         "RUNNING": State.RUNNING,
+         "SUCCEEDED": State.SUCCESS,
+         "FAILED": State.FAILED,
+     }
+
+     def __init__(self, job_id: str, status: str, status_reason: str | None = None):
+         self.job_id = job_id
+         self.status = status
+         self.status_reason = status_reason
+
+     def get_job_state(self) -> str:
+         """Return the state of the job."""
+         return self.STATE_MAPPINGS.get(self.status, State.QUEUED)
+
+     def __repr__(self):
+         """Return a visual representation of the Job status."""
+         return f"({self.job_id} -> {self.status}, {self.get_job_state()})"
+
+
+ class BatchJobCollection:
+     """A collection to manage running Batch Jobs."""
+
+     def __init__(self):
+         self.key_to_id: dict[TaskInstanceKey, str] = {}
+         self.id_to_key: dict[str, TaskInstanceKey] = {}
+         self.id_to_failure_counts: dict[str, int] = defaultdict(int)
+         self.id_to_job_info: dict[str, BatchJobInfo] = {}
+
+     def add_job(
+         self,
+         job_id: str,
+         airflow_task_key: TaskInstanceKey,
+         airflow_cmd: CommandType,
+         queue: str,
+         exec_config: ExecutorConfigType,
+         attempt_number: int,
+     ):
+         """Add a job to the collection."""
+         self.key_to_id[airflow_task_key] = job_id
+         self.id_to_key[job_id] = airflow_task_key
+         self.id_to_failure_counts[job_id] = attempt_number
+         self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)
+
+     def pop_by_id(self, job_id: str) -> TaskInstanceKey:
+         """Delete job from collection based off of Batch Job ID."""
+         task_key = self.id_to_key[job_id]
+         del self.key_to_id[task_key]
+         del self.id_to_key[job_id]
+         del self.id_to_failure_counts[job_id]
+         return task_key
+
+     def failure_count_by_id(self, job_id: str) -> int:
+         """Get the number of times a job has failed given a Batch Job Id."""
+         return self.id_to_failure_counts[job_id]
+
+     def increment_failure_count(self, job_id: str):
+         """Increment the failure counter given a Batch Job Id."""
+         self.id_to_failure_counts[job_id] += 1
+
+     def get_all_jobs(self) -> list[str]:
+         """Get all AWS ARNs in collection."""
+         return list(self.id_to_key.keys())
+
+     def __len__(self):
+         """Return the number of jobs in collection."""
+         return len(self.key_to_id)
+
+
+ class BatchSubmitJobKwargsConfigKeys(BaseConfigKeys):
+     """Keys loaded into the config which are valid Batch submit_job kwargs."""
+
+     JOB_NAME = "job_name"
+     JOB_QUEUE = "job_queue"
+     JOB_DEFINITION = "job_definition"
+     EKS_PROPERTIES_OVERRIDE = "eks_properties_override"
+     NODE_OVERRIDE = "node_override"
+
+
+ class AllBatchConfigKeys(BatchSubmitJobKwargsConfigKeys):
+     """All keys loaded into the config which are related to the Batch Executor."""
+
+     MAX_SUBMIT_JOB_ATTEMPTS = "max_submit_job_attempts"
+     AWS_CONN_ID = "conn_id"
+     SUBMIT_JOB_KWARGS = "submit_job_kwargs"
+     REGION_NAME = "region_name"
+     CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+
+
+ class BatchExecutorException(Exception):
+     """Thrown when something unexpected has occurred within the AWS Batch ecosystem."""
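Illustrative usage, not part of the diff: a minimal sketch of how these bookkeeping objects fit together in the new Batch executor. The task key, job id, and command below are placeholder values.

from airflow.models.taskinstance import TaskInstanceKey
from airflow.providers.amazon.aws.executors.batch.utils import BatchJob, BatchJobCollection
from airflow.utils.state import State

# Hypothetical task key and Batch job id, for illustration only.
task_key = TaskInstanceKey(dag_id="example_dag", task_id="example_task", run_id="manual__1", try_number=1)

collection = BatchJobCollection()
collection.add_job(
    job_id="abc-123",
    airflow_task_key=task_key,
    airflow_cmd=["airflow", "tasks", "run"],
    queue="default",
    exec_config={},
    attempt_number=1,
)

# A status string from Batch describe_jobs is translated into an Airflow state.
job = BatchJob(job_id="abc-123", status="RUNNABLE")
assert job.get_job_state() == State.QUEUED

# Failed submissions bump a per-job counter; the executor gives up once
# max_submit_job_attempts is reached and removes the job from the collection.
collection.increment_failure_count("abc-123")
if collection.failure_count_by_id("abc-123") >= 3:
    collection.pop_by_id("abc-123")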
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -26,7 +26,7 @@ from __future__ import annotations
  import time
  from collections import defaultdict, deque
  from copy import deepcopy
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Sequence

  from botocore.exceptions import ClientError, NoCredentialsError

@@ -47,12 +47,13 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
      exponential_backoff_retry,
  )
  from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+ from airflow.stats import Stats
  from airflow.utils import timezone
  from airflow.utils.helpers import merge_dicts
  from airflow.utils.state import State

  if TYPE_CHECKING:
-     from airflow.models.taskinstance import TaskInstanceKey
+     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
      from airflow.providers.amazon.aws.executors.ecs.utils import (
          CommandType,
          ExecutorConfigType,
@@ -182,7 +183,7 @@ class AwsEcsExecutor(BaseExecutor):
              AllEcsConfigKeys.AWS_CONN_ID,
              fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
          )
-         region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME)
+         region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
          self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
          self.attempts_since_last_successful_connection += 1
          self.last_connection_reload = timezone.utcnow()
@@ -208,7 +209,7 @@ class AwsEcsExecutor(BaseExecutor):
              if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                  self.IS_BOTO_CONNECTION_HEALTHY = False
                  self.log.warning(
-                     f"AWS credentials are either missing or expired: {error}.\nRetrying connection"
+                     "AWS credentials are either missing or expired: %s.\nRetrying connection", error
                  )

          except Exception:
@@ -240,21 +241,19 @@ class AwsEcsExecutor(BaseExecutor):
          # Get state of current task.
          task_state = task.get_task_state()
          task_key = self.active_workers.arn_to_key[task.task_arn]
+
          # Mark finished tasks as either a success/failure.
-         if task_state == State.FAILED:
-             self.fail(task_key)
+         if task_state == State.FAILED or task_state == State.REMOVED:
              self.__log_container_failures(task_arn=task.task_arn)
-         elif task_state == State.SUCCESS:
-             self.success(task_key)
-         elif task_state == State.REMOVED:
              self.__handle_failed_task(task.task_arn, task.stopped_reason)
-         if task_state in (State.FAILED, State.SUCCESS):
+         elif task_state == State.SUCCESS:
              self.log.debug(
                  "Airflow task %s marked as %s after running on ECS Task (arn) %s",
                  task_key,
                  task_state,
                  task.task_arn,
              )
+             self.success(task_key)
              self.active_workers.pop_by_key(task_key)

      def __describe_tasks(self, task_arns):
@@ -287,7 +286,14 @@ class AwsEcsExecutor(BaseExecutor):
              )

      def __handle_failed_task(self, task_arn: str, reason: str):
-         """If an API failure occurs, the task is rescheduled."""
+         """
+         If an API failure occurs, the task is rescheduled.
+
+         This function will determine whether the task has been attempted the appropriate number
+         of times, and determine whether the task should be marked failed or not. The task will
+         be removed active_workers, and marked as FAILED, or set into pending_tasks depending on
+         how many times it has been retried.
+         """
          task_key = self.active_workers.arn_to_key[task_arn]
          task_info = self.active_workers.info_by_key(task_key)
          task_cmd = task_info.cmd
@@ -303,8 +309,7 @@ class AwsEcsExecutor(BaseExecutor):
                  self.__class__.MAX_RUN_TASK_ATTEMPTS,
                  task_arn,
              )
-             self.active_workers.increment_failure_count(task_key)
-             self.pending_tasks.appendleft(
+             self.pending_tasks.append(
                  EcsQueuedTask(
                      task_key,
                      task_cmd,
@@ -320,8 +325,8 @@ class AwsEcsExecutor(BaseExecutor):
                  task_key,
                  failure_count,
              )
-             self.active_workers.pop_by_key(task_key)
              self.fail(task_key)
+             self.active_workers.pop_by_key(task_key)

      def attempt_task_runs(self):
          """
@@ -344,16 +349,17 @@ class AwsEcsExecutor(BaseExecutor):
              attempt_number = ecs_task.attempt_number
              _failure_reasons = []
              if timezone.utcnow() < ecs_task.next_attempt_time:
+                 self.pending_tasks.append(ecs_task)
                  continue
              try:
                  run_task_response = self._run_task(task_key, cmd, queue, exec_config)
              except NoCredentialsError:
-                 self.pending_tasks.appendleft(ecs_task)
+                 self.pending_tasks.append(ecs_task)
                  raise
              except ClientError as e:
                  error_code = e.response["Error"]["Code"]
                  if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
-                     self.pending_tasks.appendleft(ecs_task)
+                     self.pending_tasks.append(ecs_task)
                      raise
                  _failure_reasons.append(str(e))
              except Exception as e:
@@ -373,12 +379,12 @@ class AwsEcsExecutor(BaseExecutor):
                  for reason in _failure_reasons:
                      failure_reasons[reason] += 1
                  # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
-                 if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+                 if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                      ecs_task.attempt_number += 1
                      ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                          attempt_number
                      )
-                     self.pending_tasks.appendleft(ecs_task)
+                     self.pending_tasks.append(ecs_task)
                  else:
                      self.log.error(
                          "ECS task %s has failed a maximum of %s times. Marking as failed",
@@ -394,6 +400,7 @@ class AwsEcsExecutor(BaseExecutor):
              else:
                  task = run_task_response["tasks"][0]
                  self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
+                 self.queued(task_key, task.task_arn)
          if failure_reasons:
              self.log.error(
                  "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
@@ -494,3 +501,39 @@ class AwsEcsExecutor(BaseExecutor):
                  'container "name" must be provided in "containerOverrides" configuration'
              )
          raise KeyError(f"No such container found by container name: {self.container_name}")
+
+     def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
+         """
+         Adopt task instances which have an external_executor_id (the ECS task ARN).
+
+         Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
+         """
+         with Stats.timer("ecs_executor.adopt_task_instances.duration"):
+             adopted_tis: list[TaskInstance] = []
+
+             if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
+                 task_descriptions = self.__describe_tasks(task_arns).get("tasks", [])
+
+                 for task in task_descriptions:
+                     ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
+                     self.active_workers.add_task(
+                         task,
+                         ti.key,
+                         ti.queue,
+                         ti.command_as_list(),
+                         ti.executor_config,
+                         ti.prev_attempted_tries,
+                     )
+                     adopted_tis.append(ti)
+
+             if adopted_tis:
+                 tasks = [f"{task} in state {task.state}" for task in adopted_tis]
+                 task_instance_str = "\n\t".join(tasks)
+                 self.log.info(
+                     "Adopted the following %d tasks from a dead executor:\n\t%s",
+                     len(adopted_tis),
+                     task_instance_str,
+                 )
+
+             not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
+             return not_adopted_tis
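The appendleft() → append() changes above affect retry ordering: pending_tasks is a collections.deque that attempt_task_runs() drains from the left with popleft(), so a re-queued attempt now rejoins the back of the queue instead of being retried ahead of newer tasks. A stand-alone sketch of the difference, using a plain deque and no Airflow objects:

from collections import deque

pending = deque(["task_a", "task_b", "task_c"])

# Drain from the left, as the executor does with popleft().
retry = pending.popleft()  # "task_a" fails to launch and must be retried

# Old behaviour: appendleft() puts the retry back at the front,
# so it is attempted again before "task_b" and "task_c".
pending.appendleft(retry)
assert list(pending) == ["task_a", "task_b", "task_c"]

pending.popleft()  # remove it again for the comparison

# New behaviour: append() sends the retry to the back,
# letting the other queued tasks launch first.
pending.append(retry)
assert list(pending) == ["task_b", "task_c", "task_a"]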
airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -30,6 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List

  from inflection import camelize

+ from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
  from airflow.utils.state import State

  if TYPE_CHECKING:
@@ -71,36 +72,28 @@ class EcsTaskInfo:
      config: ExecutorConfigType


- class BaseConfigKeys:
-     """Base Implementation of the Config Keys class. Implements iteration for child classes to inherit."""
-
-     def __iter__(self):
-         """Return an iterator of values of non dunder attributes of Config Keys."""
-         return iter({value for (key, value) in self.__class__.__dict__.items() if not key.startswith("__")})
-
-
  class RunTaskKwargsConfigKeys(BaseConfigKeys):
      """Keys loaded into the config which are valid ECS run_task kwargs."""

      ASSIGN_PUBLIC_IP = "assign_public_ip"
      CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
      CLUSTER = "cluster"
+     CONTAINER_NAME = "container_name"
      LAUNCH_TYPE = "launch_type"
      PLATFORM_VERSION = "platform_version"
      SECURITY_GROUPS = "security_groups"
      SUBNETS = "subnets"
      TASK_DEFINITION = "task_definition"
-     CONTAINER_NAME = "container_name"


  class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
      """All keys loaded into the config which are related to the ECS Executor."""

-     MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
      AWS_CONN_ID = "conn_id"
-     RUN_TASK_KWARGS = "run_task_kwargs"
-     REGION_NAME = "region_name"
      CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
+     MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
+     REGION_NAME = "region_name"
+     RUN_TASK_KWARGS = "run_task_kwargs"


  class EcsExecutorException(Exception):
@@ -108,7 +101,7 @@ class EcsExecutorException(Exception):


  class EcsExecutorTask:
-     """Data Transfer Object for an ECS Fargate Task."""
+     """Data Transfer Object for an ECS Task."""

      def __init__(
          self,
@@ -118,6 +111,7 @@ class EcsExecutorTask:
          containers: list[dict[str, Any]],
          started_at: Any | None = None,
          stopped_reason: str | None = None,
+         external_executor_id: str | None = None,
      ):
          self.task_arn = task_arn
          self.last_status = last_status
@@ -125,6 +119,7 @@ class EcsExecutorTask:
          self.containers = containers
          self.started_at = started_at
          self.stopped_reason = stopped_reason
+         self.external_executor_id = external_executor_id

      def get_task_state(self) -> str:
          """
airflow/providers/amazon/aws/executors/utils/base_config_keys.py
@@ -0,0 +1,25 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+
+ class BaseConfigKeys:
+     """Base Implementation of the Config Keys class. Implements iteration for child classes to inherit."""
+
+     def __iter__(self):
+         """Return an iterator of values of non dunder attributes of Config Keys."""
+         return iter({value for (key, value) in self.__class__.__dict__.items() if not key.startswith("__")})
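For context, iterating a BaseConfigKeys subclass yields the values of its class attributes, which are the option names the executors look up in their config section. A small sketch; the subclass here is illustrative, the real ones are RunTaskKwargsConfigKeys, AllEcsConfigKeys, AllBatchConfigKeys, and so on.

from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys


class ExampleConfigKeys(BaseConfigKeys):
    """Illustrative subclass, not part of the provider."""

    CLUSTER = "cluster"
    REGION_NAME = "region_name"


# Iterating an instance yields the attribute *values* (the option names),
# which the executors then read from their Airflow config section.
assert set(ExampleConfigKeys()) == {"cluster", "region_name"}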
airflow/providers/amazon/aws/hooks/athena.py
@@ -22,6 +22,7 @@ This module contains AWS Athena hook.

      PageIterator
  """
+
  from __future__ import annotations

  import warnings
airflow/providers/amazon/aws/hooks/base_aws.py
@@ -22,6 +22,7 @@ This module contains Base AWS Hook.
      For more information on how to use this hook, take a look at the guide:
      :ref:`howto/connection:aws`
  """
+
  from __future__ import annotations

  import datetime
airflow/providers/amazon/aws/hooks/batch_client.py
@@ -24,6 +24,7 @@ A client for AWS Batch services.
      - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
      - https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
  """
+
  from __future__ import annotations

  import itertools
@@ -438,7 +439,7 @@ class BatchClientHook(AwsBaseHook):
              return None
          if len(all_info) > 1:
              self.log.warning(
-                 f"AWS Batch job ({job_id}) has more than one log stream, only returning the first one."
+                 "AWS Batch job (%s) has more than one log stream, only returning the first one.", job_id
              )
          return all_info[0]

@@ -474,7 +475,7 @@ class BatchClientHook(AwsBaseHook):
          # If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
          if any(c.get("logDriver", "awslogs") != "awslogs" for c in log_configs):
              self.log.warning(
-                 f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
+                 "AWS Batch job (%s) uses non-aws log drivers. AWS CloudWatch logging disabled.", job_id
              )
              return []

@@ -482,7 +483,7 @@ class BatchClientHook(AwsBaseHook):
              # If this method is called very early after starting the AWS Batch job,
              # there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
              # This can also happen in case of misconfiguration.
-             self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
+             self.log.warning("AWS Batch job (%s) doesn't have any AWS CloudWatch Stream.", job_id)
              return []

          # Try to get user-defined log configuration options
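The batch_client.py hunks above (and similar ones elsewhere in this release) replace f-strings inside log calls with %-style arguments, so the message is only interpolated when the record is actually emitted. A generic before/after sketch, independent of the hook:

import logging

log = logging.getLogger(__name__)
job_id = "abc-123"

# Eager: the f-string is built even if WARNING records are filtered out.
log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")

# Lazy: logging interpolates %s only when the record is actually handled.
log.warning("AWS Batch job (%s) doesn't have any AWS CloudWatch Stream.", job_id)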
airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -23,6 +23,7 @@ AWS Batch service waiters.
      - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters
      - https://github.com/boto/botocore/blob/develop/botocore/waiter.py
  """
+
  from __future__ import annotations

  import json
airflow/providers/amazon/aws/hooks/bedrock.py
@@ -0,0 +1,59 @@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ from __future__ import annotations
+
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+ class BedrockHook(AwsBaseHook):
+     """
+     Interact with Amazon Bedrock.
+
+     Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock") <Bedrock.Client>`.
+
+     Additional arguments (such as ``aws_conn_id``) may be specified and
+     are passed down to the underlying AwsBaseHook.
+
+     .. seealso::
+         - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+     """
+
+     client_type = "bedrock"
+
+     def __init__(self, *args, **kwargs) -> None:
+         kwargs["client_type"] = self.client_type
+         super().__init__(*args, **kwargs)
+
+
+ class BedrockRuntimeHook(AwsBaseHook):
+     """
+     Interact with the Amazon Bedrock Runtime.
+
+     Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`.
+
+     Additional arguments (such as ``aws_conn_id``) may be specified and
+     are passed down to the underlying AwsBaseHook.
+
+     .. seealso::
+         - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+     """
+
+     client_type = "bedrock-runtime"
+
+     def __init__(self, *args, **kwargs) -> None:
+         kwargs["client_type"] = self.client_type
+         super().__init__(*args, **kwargs)
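The new hooks only pin client_type; connection lookup, region, and retry handling come from AwsBaseHook. A hedged usage sketch: list_foundation_models and invoke_model are standard boto3 Bedrock client calls, but the connection id, model id, and request body below are placeholders, and each model family defines its own payload format.

import json

from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook

# Control-plane client: list the foundation models visible to the connection.
bedrock = BedrockHook(aws_conn_id="aws_default").conn
models = bedrock.list_foundation_models()["modelSummaries"]

# Runtime client: invoke a model. The model id and body are placeholders.
runtime = BedrockRuntimeHook(aws_conn_id="aws_default").conn
response = runtime.invoke_model(
    modelId="amazon.titan-text-express-v1",
    body=json.dumps({"inputText": "Hello, Bedrock"}),
)
print(json.loads(response["body"].read()))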
airflow/providers/amazon/aws/hooks/chime.py
@@ -17,6 +17,7 @@
  # under the License.

  """This module contains a web hook for Chime."""
+
  from __future__ import annotations

  import json
airflow/providers/amazon/aws/hooks/cloud_formation.py
@@ -16,6 +16,7 @@
  # specific language governing permissions and limitations
  # under the License.
  """This module contains AWS CloudFormation Hook."""
+
  from __future__ import annotations

  from typing import TYPE_CHECKING
airflow/providers/amazon/aws/hooks/datasync.py
@@ -15,6 +15,7 @@
  # specific language governing permissions and limitations
  # under the License.
  """Interact with AWS DataSync, using the AWS ``boto3`` library."""
+
  from __future__ import annotations

  import time
airflow/providers/amazon/aws/hooks/dynamodb.py
@@ -16,6 +16,7 @@
  # specific language governing permissions and limitations
  # under the License.
  """This module contains the Amazon DynamoDB Hook."""
+
  from __future__ import annotations

  from typing import Iterable
airflow/providers/amazon/aws/hooks/eks.py
@@ -15,6 +15,7 @@
  # specific language governing permissions and limitations
  # under the License.
  """Interact with Amazon EKS, using the boto3 library."""
+
  from __future__ import annotations

  import base64
airflow/providers/amazon/aws/hooks/glue.py
@@ -19,12 +19,13 @@ from __future__ import annotations

  import asyncio
  import time
+ from functools import cached_property

- import boto3
  from botocore.exceptions import ClientError

  from airflow.exceptions import AirflowException
  from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+ from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

  DEFAULT_LOG_SUFFIX = "output"
  ERROR_LOG_SUFFIX = "error"
@@ -213,6 +214,13 @@ class GlueJobHook(AwsBaseHook):
              job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
          return job_run["JobRun"]["JobRunState"]

+     @cached_property
+     def logs_hook(self):
+         """Returns an AwsLogsHook instantiated with the parameters of the GlueJobHook."""
+         return AwsLogsHook(
+             aws_conn_id=self.aws_conn_id, region_name=self.region_name, verify=self.verify, config=self.config
+         )
+
      def print_job_logs(
          self,
          job_name: str,
@@ -225,7 +233,7 @@ class GlueJobHook(AwsBaseHook):
          :param continuation_tokens: the tokens where to resume from when reading logs.
              The object gets updated with the new tokens by this method.
          """
-         log_client = boto3.client("logs")
+         log_client = self.logs_hook.get_conn()
          paginator = log_client.get_paginator("filter_log_events")

          def display_logs_from(log_group: str, continuation_token: str | None) -> str | None:
@@ -245,8 +253,9 @@ class GlueJobHook(AwsBaseHook):
                  if e.response["Error"]["Code"] == "ResourceNotFoundException":
                      # we land here when the log groups/streams don't exist yet
                      self.log.warning(
-                         "No new Glue driver logs so far.\nIf this persists, check the CloudWatch dashboard "
-                         f"at: https://{self.conn_region_name}.console.aws.amazon.com/cloudwatch/home"
+                         "No new Glue driver logs so far.\n"
+                         "If this persists, check the CloudWatch dashboard at: %r.",
+                         f"https://{self.conn_region_name}.console.aws.amazon.com/cloudwatch/home",
                      )
                  else:
                      raise
@@ -263,7 +272,6 @@ class GlueJobHook(AwsBaseHook):
          log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
          log_group_default = f"{log_group_prefix}/{DEFAULT_LOG_SUFFIX}"
          log_group_error = f"{log_group_prefix}/{ERROR_LOG_SUFFIX}"
-
          # one would think that the error log group would contain only errors, but it actually contains
          # a lot of interesting logs too, so it's valuable to have both
          continuation_tokens.output_stream_continuation = display_logs_from(
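The glue.py change swaps a bare boto3.client("logs") for a client built via AwsLogsHook from the GlueJobHook's own aws_conn_id, region_name, verify, and config, so CloudWatch reads honour the same Airflow connection as the Glue calls. The cached_property pattern in isolation, using stand-in hook classes rather than the provider's real ones:

from __future__ import annotations

from functools import cached_property


class ChildHook:
    """Stand-in for AwsLogsHook: remembers how it was constructed."""

    def __init__(self, aws_conn_id: str, region_name: str | None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name


class ParentHook:
    """Stand-in for GlueJobHook: exposes a dependent hook as a cached property."""

    def __init__(self, aws_conn_id: str = "aws_default", region_name: str | None = None):
        self.aws_conn_id = aws_conn_id
        self.region_name = region_name

    @cached_property
    def logs_hook(self) -> ChildHook:
        # Built lazily, once, and reuses the parent's connection parameters
        # instead of falling back to the default boto3 credential chain.
        return ChildHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)


hook = ParentHook(aws_conn_id="my_aws_conn", region_name="eu-west-1")
assert hook.logs_hook is hook.logs_hook  # cached: same object on every access
assert hook.logs_hook.aws_conn_id == "my_aws_conn"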
airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -16,6 +16,7 @@
  # specific language governing permissions and limitations
  # under the License.
  """This module contains AWS Glue Catalog Hook."""
+
  from __future__ import annotations

  from typing import Any