apache-airflow-providers-amazon 9.5.0rc3__py3-none-any.whl → 9.6.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version
 
 __all__ = ["__version__"]
 
-__version__ = "9.5.0"
+__version__ = "9.6.0"
 
 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"
@@ -87,7 +87,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
 
     @cached_property
     def apiserver_endpoint(self) -> str:
-        return conf.get("api", "base_url")
+        return conf.get("api", "base_url", fallback="/")
 
     def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
        return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
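
A brief aside on the change above (not part of the diff): without a fallback, `conf.get` raises when `[api] base_url` is absent from the configuration, so the auth manager would fail on installs that never set it. A minimal sketch, assuming a stock Airflow config:

    from airflow.configuration import conf

    # With no fallback, a missing [api] base_url raises a configuration error;
    # with fallback="/", the endpoint degrades gracefully to the root path.
    base_url = conf.get("api", "base_url", fallback="/")
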
@@ -80,7 +80,7 @@ def login_callback(request: Request):
         username=saml_auth.get_nameid(),
         email=attributes["email"][0] if "email" in attributes else None,
     )
-    url = conf.get("api", "base_url")
+    url = conf.get("api", "base_url", fallback="/")
     token = get_auth_manager().generate_jwt(user)
     response = RedirectResponse(url=url, status_code=303)
 
@@ -49,12 +49,16 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry impo
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.ecs.utils import (
         CommandType,
@@ -100,6 +104,11 @@ class AwsEcsExecutor(BaseExecutor):
     # AWS limits the maximum number of ARNs in the describe_tasks function.
     DESCRIBE_TASKS_BATCH_SIZE = 99
 
+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.active_workers: EcsTaskCollection = EcsTaskCollection()
@@ -114,6 +123,31 @@ class AwsEcsExecutor(BaseExecutor):
 
         self.run_task_kwargs = self._load_run_kwargs()
 
+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: list[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        # Airflow V3 version
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
     def start(self):
         """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
@@ -462,6 +496,24 @@ class AwsEcsExecutor(BaseExecutor):
         """Save the task to be executed in the next sync by inserting the commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise ValueError(
+                    f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
         self.pending_tasks.append(
             EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
         )
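
For reference, a hedged sketch of the command translation that `execute_async` now performs on Airflow 3; the helper name is illustrative and not part of the provider:

    from airflow.executors.workloads import ExecuteTask

    def workload_to_command(workload: ExecuteTask) -> list[str]:
        # Serialize the workload and wrap it in the Task SDK entrypoint, mirroring
        # the single-element command branch added to AwsEcsExecutor.execute_async.
        return [
            "python",
            "-m",
            "airflow.sdk.execution_time.execute_workload",
            "--json-string",
            workload.model_dump_json(),
        ]
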
@@ -19,12 +19,13 @@ from __future__ import annotations
 
 import asyncio
 import time
+import warnings
 from functools import cached_property
 from typing import Any
 
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
 
@@ -145,7 +146,7 @@ class GlueJobHook(AwsBaseHook):
 
         return config
 
-    def list_jobs(self) -> list:
+    def describe_jobs(self) -> list:
         """
         Get list of Jobs.
 
@@ -154,6 +155,20 @@ class GlueJobHook(AwsBaseHook):
         """
         return self.conn.get_jobs()
 
+    def list_jobs(self) -> list:
+        """
+        Get list of Jobs.
+
+        .. deprecated::
+            - Use :meth:`describe_jobs` instead.
+        """
+        warnings.warn(
+            "The method `list_jobs` is deprecated. Use the method `describe_jobs` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.describe_jobs()
+
     def get_iam_execution_role(self) -> dict:
         try:
             iam_client = self.get_session(region_name=self.region_name).client(
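
A small migration sketch for the deprecation above: callers of `list_jobs` should switch to `describe_jobs`, which returns the same `get_jobs()` output. The connection id below is an illustrative assumption:

    from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

    hook = GlueJobHook(aws_conn_id="aws_default")

    # Before 9.6.0:
    # jobs = hook.list_jobs()  # now emits AirflowProviderDeprecationWarning

    # From 9.6.0 onwards:
    jobs = hook.describe_jobs()
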
@@ -26,7 +26,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 class MwaaHook(AwsBaseHook):
     """
-    Interact with AWS Manager Workflows for Apache Airflow.
+    Interact with AWS Managed Workflows for Apache Airflow.
 
     Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
 
@@ -790,10 +790,6 @@ class S3Hook(AwsBaseHook):
                 "FAILURE: Inactivity Period passed, not enough objects found in %s",
                 path,
             )
-            return {
-                "status": "error",
-                "message": f"FAILURE: Inactivity Period passed, not enough objects found in {path}",
-            }
         return {
             "status": "pending",
             "previous_objects": previous_objects,
@@ -33,6 +33,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -869,3 +870,121 @@ class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
 
         self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"])
         return result
+
+
+class BedrockBatchInferenceOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a batch inference job to invoke a model on multiple prompts.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockBatchInferenceOperator`
+
+    :param job_name: A name to give the batch inference job. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to create the knowledge base. (templated)
+    :param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
+    :param input_uri: The S3 location of the input data. (templated)
+    :param output_uri: The S3 location of the output data. (templated)
+    :param invoke_kwargs: Additional keyword arguments to pass to the API call. (templated)
+
+    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
+        NOTE: The way batch inference jobs work, your jobs are added to a queue and done "eventually"
+        so using deferrable mode is much more practical than using wait_for_completion.
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+        "role_arn",
+        "model_id",
+        "input_uri",
+        "output_uri",
+        "invoke_kwargs",
+    )
+
+    def __init__(
+        self,
+        job_name: str,
+        role_arn: str,
+        model_id: str,
+        input_uri: str,
+        output_uri: str,
+        invoke_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.role_arn = role_arn
+        self.model_id = model_id
+        self.input_uri = input_uri
+        self.output_uri = output_uri
+        self.invoke_kwargs = invoke_kwargs or {}
+
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+        self.activity = "Bedrock batch inference job"
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        validated_event = validate_execute_complete_event(event)
+
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running {self.activity}: {validated_event}")
+
+        self.log.info("%s '%s' complete.", self.activity, validated_event["job_arn"])
+
+        return validated_event["job_arn"]
+
+    def execute(self, context: Context) -> str:
+        response = self.hook.conn.create_model_invocation_job(
+            jobName=self.job_name,
+            roleArn=self.role_arn,
+            modelId=self.model_id,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": self.input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": self.output_uri}},
+            **self.invoke_kwargs,
+        )
+        job_arn = response["jobArn"]
+        self.log.info("%s '%s' started with ARN: %s", self.activity, self.job_name, job_arn)
+
+        task_description = f"for {self.activity} '{self.job_name}' to complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", task_description)
+            self.defer(
+                trigger=BedrockBatchInferenceCompletedTrigger(
+                    job_arn=job_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting %s", task_description)
+            self.hook.get_waiter(waiter_name="batch_inference_complete").wait(
+                jobIdentifier=job_arn,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return job_arn
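
A hedged usage sketch for the new operator; the names, ARNs, URIs, and model id below are illustrative assumptions, not values from the diff:

    from airflow.providers.amazon.aws.operators.bedrock import BedrockBatchInferenceOperator

    batch_infer = BedrockBatchInferenceOperator(
        task_id="bedrock_batch_inference",
        job_name="my-batch-job",
        role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        input_uri="s3://my-bucket/input/records.jsonl",
        output_uri="s3://my-bucket/output/",
        # Batch jobs sit in a queue, so the docstring recommends deferrable mode
        # over wait_for_completion.
        deferrable=True,
    )
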
@@ -20,19 +20,19 @@ from __future__ import annotations
 import json
 from collections.abc import Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbAvailableTrigger,
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
@@ -44,9 +44,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class RdsBaseOperator(BaseOperator):
+class RdsBaseOperator(AwsBaseOperator[RdsHook]):
     """Base operator that implements common functions for all operators."""
 
+    aws_hook_class = RdsHook
     ui_color = "#eeaa88"
     ui_fgcolor = "#ffffff"
 
@@ -63,10 +64,6 @@ class RdsBaseOperator(BaseOperator):
 
         self._await_interval = 60  # seconds
 
-    @cached_property
-    def hook(self) -> RdsHook:
-        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
@@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
+    template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags")
 
     def __init__(
         self,
@@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         Only when db_type='instance'
     :param source_region: The ID of the region that contains the snapshot to be copied
     :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "source_db_snapshot_identifier",
         "target_db_snapshot_identifier",
         "tags",
@@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):
 
     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_snapshot_identifier",)
+    template_fields = aws_template_fields(
+        "db_snapshot_identifier",
+    )
 
     def __init__(
         self,
@@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
     :param waiter_interval: The number of seconds to wait before checking the export status. (default: 30)
     :param waiter_max_attempts: The number of attempts to make before failing. (default: 40)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "export_task_identifier",
         "source_arn",
         "s3_bucket_name",
@@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
     :param check_interval: The amount of time in seconds to wait between attempts
     :param max_attempts: The maximum number of attempts to be made
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("export_task_identifier",)
+    template_fields = aws_template_fields(
+        "export_task_identifier",
+    )
 
     def __init__(
         self,
@@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = (
+    template_fields = aws_template_fields(
         "subscription_name",
         "sns_topic_arn",
         "source_type",
@@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
        :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`
 
    :param subscription_name: The name of the RDS event notification subscription you want to delete
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("subscription_name",)
+    template_fields = aws_template_fields(
+        "subscription_name",
+    )
 
     def __init__(
         self,
@@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
+    template_fields = aws_template_fields(
+        "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
+    )
 
     def __init__(
         self,
@@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_instance_identifier", "rds_kwargs")
+    template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs")
 
     def __init__(
         self,
@@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_type")
 
     def __init__(
         self,
@@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type")
 
     def __init__(
         self,
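
The practical effect of rebasing these operators onto `AwsBaseOperator[RdsHook]` is that the common AWS arguments (`aws_conn_id`, `region_name`, `verify`, `botocore_config`) are now accepted uniformly, as the docstring additions above describe. A sketch with illustrative values:

    from airflow.providers.amazon.aws.operators.rds import RdsStartDbOperator

    start_db = RdsStartDbOperator(
        task_id="start_db",
        db_identifier="my-instance",
        db_type="instance",
        region_name="eu-west-1",
        botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},
    )
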
@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
+    BedrockBatchInferenceScheduledTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.triggers.bedrock import BedrockBaseBatchInferenceTrigger
     from airflow.utils.context import Context
 
 
@@ -368,3 +371,110 @@ class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
             )
         else:
             super().execute(context=context)
+
+
+class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
+    """
+    Poll the batch inference job status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockBatchInferenceSensor`
+
+    :param job_arn: The Amazon Resource Name (ARN) of the batch inference job. (templated)
+    :param success_state: A BedrockBatchInferenceSensor.TargetState; defaults to 'SCHEDULED' (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
+    :param max_retries: Number of times before returning the current state (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    class SuccessState:
+        """
+        Target state for the BedrockBatchInferenceSensor.
+
+        Bedrock adds batch inference jobs to a queue, and they may take some time to complete.
+        If you want to wait for the job to complete, use TargetState.COMPLETED, but if you only want
+        to wait until the service confirms that the job is in the queue, use TargetState.SCHEDULED.
+
+        The normal successful progression of states is:
+            Submitted > Validating > Scheduled > InProgress > PartiallyCompleted > Completed
+        """
+
+        SCHEDULED = "scheduled"
+        COMPLETED = "completed"
+
+    INTERMEDIATE_STATES: tuple[str, ...]  # Defined in __init__ based on target state
+    FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped", "PartiallyCompleted", "Expired")
+    SUCCESS_STATES: tuple[str, ...]  # Defined in __init__ based on target state
+    FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
+    INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of TargetState."
+
+    aws_hook_class = BedrockHook
+
+    template_fields: Sequence[str] = aws_template_fields("job_arn", "success_state")
+
+    def __init__(
+        self,
+        *,
+        job_arn: str,
+        success_state: SuccessState | str = SuccessState.SCHEDULED,
+        poke_interval: int = 120,
+        max_retries: int = 75,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.job_arn = job_arn
+        self.success_state = success_state
+
+        base_success_states: tuple[str, ...] = ("Completed",)
+        base_intermediate_states: tuple[str, ...] = ("Submitted", "InProgress", "Stopping", "Validating")
+        scheduled_state = ("Scheduled",)
+        self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
+
+        if self.success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
+            intermediate_states = base_intermediate_states + scheduled_state
+            success_states = base_success_states
+            self.trigger_class = BedrockBatchInferenceCompletedTrigger
+        elif self.success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
+            intermediate_states = base_intermediate_states
+            success_states = base_success_states + scheduled_state
+            self.trigger_class = BedrockBatchInferenceScheduledTrigger
+        else:
+            raise ValueError(
+                "Success states for BedrockBatchInferenceSensor must be set using a BedrockBatchInferenceSensor.SuccessState"
+            )
+
+        BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states or base_intermediate_states
+        BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or base_success_states
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=self.trigger_class(
+                    job_arn=self.job_arn,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
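
Finally, a hedged sketch of how the new sensor's `success_state` changes its behaviour; the ARN is a placeholder:

    from airflow.providers.amazon.aws.sensors.bedrock import BedrockBatchInferenceSensor

    JOB_ARN = "arn:aws:bedrock:us-east-1:123456789012:model-invocation-job/abc123"

    # Succeeds as soon as Bedrock confirms the job is queued (the default).
    wait_scheduled = BedrockBatchInferenceSensor(
        task_id="wait_scheduled",
        job_arn=JOB_ARN,
        success_state=BedrockBatchInferenceSensor.SuccessState.SCHEDULED,
    )

    # Waits for the job to finish; deferrable mode suits long queue times.
    wait_completed = BedrockBatchInferenceSensor(
        task_id="wait_completed",
        job_arn=JOB_ARN,
        success_state=BedrockBatchInferenceSensor.SuccessState.COMPLETED,
        deferrable=True,
    )
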