apache-airflow-providers-amazon 9.5.0rc3__py3-none-any.whl → 9.6.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +1 -1
- airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +1 -1
- airflow/providers/amazon/aws/auth_manager/router/login.py +1 -1
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +52 -0
- airflow/providers/amazon/aws/hooks/glue.py +17 -2
- airflow/providers/amazon/aws/hooks/mwaa.py +1 -1
- airflow/providers/amazon/aws/hooks/s3.py +0 -4
- airflow/providers/amazon/aws/operators/bedrock.py +119 -0
- airflow/providers/amazon/aws/operators/rds.py +83 -18
- airflow/providers/amazon/aws/sensors/bedrock.py +110 -0
- airflow/providers/amazon/aws/sensors/rds.py +23 -20
- airflow/providers/amazon/aws/triggers/bedrock.py +98 -0
- airflow/providers/amazon/aws/utils/waiter_with_logging.py +9 -1
- airflow/providers/amazon/aws/waiters/bedrock.json +134 -0
- airflow/providers/amazon/get_provider_info.py +0 -124
- {apache_airflow_providers_amazon-9.5.0rc3.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/METADATA +18 -18
- {apache_airflow_providers_amazon-9.5.0rc3.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/RECORD +19 -19
- {apache_airflow_providers_amazon-9.5.0rc3.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/WHEEL +1 -1
- {apache_airflow_providers_amazon-9.5.0rc3.dist-info → apache_airflow_providers_amazon-9.6.0rc1.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/__init__.py
```diff
@@ -29,7 +29,7 @@ from airflow import __version__ as airflow_version

 __all__ = ["__version__"]

-__version__ = "9.5.0"
+__version__ = "9.6.0"

 if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
     "2.9.0"
```
airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
```diff
@@ -87,7 +87,7 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):

     @cached_property
     def apiserver_endpoint(self) -> str:
-        return conf.get("api", "base_url")
+        return conf.get("api", "base_url", fallback="/")

     def deserialize_user(self, token: dict[str, Any]) -> AwsAuthManagerUser:
         return AwsAuthManagerUser(user_id=token.pop("sub"), **token)
```
airflow/providers/amazon/aws/auth_manager/router/login.py
```diff
@@ -80,7 +80,7 @@ def login_callback(request: Request):
         username=saml_auth.get_nameid(),
         email=attributes["email"][0] if "email" in attributes else None,
     )
-    url = conf.get("api", "base_url")
+    url = conf.get("api", "base_url", fallback="/")
    token = get_auth_manager().generate_jwt(user)
    response = RedirectResponse(url=url, status_code=303)

```
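Both auth-manager call sites now pass a `fallback`, so a missing `[api] base_url` option no longer raises at startup or login time. A minimal sketch of the fallback semantics, using the standard library's `configparser` (which Airflow's config parser builds on); the section contents here are hypothetical:

```python
# Minimal sketch of conf.get(..., fallback=...) semantics using the standard
# library's configparser, which Airflow's AirflowConfigParser extends.
from configparser import ConfigParser

conf = ConfigParser()
conf.read_dict({"api": {}})  # an [api] section with no base_url option set

# Without a fallback, a missing option raises configparser.NoOptionError;
# with one, the caller degrades gracefully to the given default.
print(conf.get("api", "base_url", fallback="/"))  # -> "/"
```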
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
```diff
@@ -49,12 +49,16 @@ from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
     exponential_backoff_retry,
 )
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.helpers import merge_dicts
 from airflow.utils.state import State

 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.executors import workloads
     from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
     from airflow.providers.amazon.aws.executors.ecs.utils import (
         CommandType,
```
```diff
@@ -100,6 +104,11 @@ class AwsEcsExecutor(BaseExecutor):
     # AWS limits the maximum number of ARNs in the describe_tasks function.
     DESCRIBE_TASKS_BATCH_SIZE = 99

+    if TYPE_CHECKING and AIRFLOW_V_3_0_PLUS:
+        # In the v3 path, we store workloads, not commands as strings.
+        # TODO: TaskSDK: move this type change into BaseExecutor
+        queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: ignore[assignment]
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.active_workers: EcsTaskCollection = EcsTaskCollection()
```
```diff
@@ -114,6 +123,31 @@ class AwsEcsExecutor(BaseExecutor):

         self.run_task_kwargs = self._load_run_kwargs()

+    def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
+        from airflow.executors import workloads
+
+        if not isinstance(workload, workloads.ExecuteTask):
+            raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(workload)}")
+        ti = workload.ti
+        self.queued_tasks[ti.key] = workload
+
+    def _process_workloads(self, workloads: list[workloads.All]) -> None:
+        from airflow.executors.workloads import ExecuteTask
+
+        # Airflow V3 version
+        for w in workloads:
+            if not isinstance(w, ExecuteTask):
+                raise RuntimeError(f"{type(self)} cannot handle workloads of type {type(w)}")
+
+            command = [w]
+            key = w.ti.key
+            queue = w.ti.queue
+            executor_config = w.ti.executor_config or {}
+
+            del self.queued_tasks[key]
+            self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)  # type: ignore[arg-type]
+            self.running.add(key)
+
     def start(self):
         """Call this when the Executor is run for the first time by the scheduler."""
         check_health = conf.getboolean(
```
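These two methods are the Airflow 3 entry points: the scheduler hands the executor whole `ExecuteTask` workload objects, which are staged in `queued_tasks` and later unpacked into the key/queue/executor_config triple that `execute_async` already understands. A simplified, self-contained sketch of that handoff pattern (all names below are stand-ins, not the provider's classes):

```python
# Simplified stand-in sketch of the queue -> process handoff the new
# queue_workload/_process_workloads methods implement; not the provider's code.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class TIKey:  # stand-in for TaskInstanceKey
    dag_id: str
    task_id: str


@dataclass
class Workload:  # stand-in for workloads.ExecuteTask
    key: TIKey
    queue: str
    executor_config: dict = field(default_factory=dict)


class MiniExecutor:
    def __init__(self) -> None:
        self.queued_tasks: dict[TIKey, Workload] = {}
        self.running: set[TIKey] = set()

    def queue_workload(self, workload: Workload) -> None:
        # Stage the whole workload object, keyed by its task-instance key.
        self.queued_tasks[workload.key] = workload

    def _process_workloads(self, workloads: list[Workload]) -> None:
        for w in workloads:
            del self.queued_tasks[w.key]  # no longer pending
            self.execute_async(w)         # hand off to the backend
            self.running.add(w.key)       # now tracked as running

    def execute_async(self, w: Workload) -> None:
        print(f"launching {w.key} on queue {w.queue!r}")


executor = MiniExecutor()
wl = Workload(TIKey("example_dag", "example_task"), "default")
executor.queue_workload(wl)
executor._process_workloads([wl])
```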
```diff
@@ -462,6 +496,24 @@ class AwsEcsExecutor(BaseExecutor):
         """Save the task to be executed in the next sync by inserting the commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
+        if len(command) == 1:
+            from airflow.executors.workloads import ExecuteTask
+
+            if isinstance(command[0], ExecuteTask):
+                workload = command[0]
+                ser_input = workload.model_dump_json()
+                command = [
+                    "python",
+                    "-m",
+                    "airflow.sdk.execution_time.execute_workload",
+                    "--json-string",
+                    ser_input,
+                ]
+            else:
+                raise ValueError(
+                    f"EcsExecutor doesn't know how to handle workload of type: {type(command[0])}"
+                )
+
         self.pending_tasks.append(
             EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
         )
```
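On the Airflow 3 path the executor no longer receives an `airflow tasks run ...` argv; it serializes the `ExecuteTask` workload to JSON and launches the Task SDK entrypoint inside the container. A sketch of that serialization step with a stand-in pydantic model (the real workloads are pydantic models, hence `model_dump_json`); assumes pydantic v2 is installed:

```python
# Stand-in sketch of the command the executor builds for the ECS container.
# FakeWorkload is hypothetical; the real type is airflow.executors.workloads.ExecuteTask.
from pydantic import BaseModel


class FakeWorkload(BaseModel):
    dag_id: str
    task_id: str


workload = FakeWorkload(dag_id="example_dag", task_id="example_task")
command = [
    "python",
    "-m",
    "airflow.sdk.execution_time.execute_workload",
    "--json-string",
    workload.model_dump_json(),  # same serialization call the executor uses
]
print(command)
```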
airflow/providers/amazon/aws/hooks/glue.py
```diff
@@ -19,12 +19,13 @@ from __future__ import annotations

 import asyncio
 import time
+import warnings
 from functools import cached_property
 from typing import Any

 from botocore.exceptions import ClientError

-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

```
```diff
@@ -145,7 +146,7 @@ class GlueJobHook(AwsBaseHook):

         return config

-    def list_jobs(self) -> list:
+    def describe_jobs(self) -> list:
         """
         Get list of Jobs.

```
```diff
@@ -154,6 +155,20 @@ class GlueJobHook(AwsBaseHook):
         """
         return self.conn.get_jobs()

+    def list_jobs(self) -> list:
+        """
+        Get list of Jobs.
+
+        .. deprecated::
+            - Use :meth:`describe_jobs` instead.
+        """
+        warnings.warn(
+            "The method `list_jobs` is deprecated. Use the method `describe_jobs` instead.",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.describe_jobs()
+
     def get_iam_execution_role(self) -> dict:
         try:
             iam_client = self.get_session(region_name=self.region_name).client(
```
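For callers, the rename is a drop-in migration: the old name still works but emits `AirflowProviderDeprecationWarning`. A hedged sketch of both paths (the connection id is a placeholder, and either call ultimately hits the Glue API, so valid AWS credentials are assumed):

```python
import warnings

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

hook = GlueJobHook(aws_conn_id="aws_default")  # placeholder connection id

jobs = hook.describe_jobs()  # preferred spelling from 9.6.0 on

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    jobs = hook.list_jobs()  # still works, but warns
    assert any("describe_jobs" in str(w.message) for w in caught)
```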
airflow/providers/amazon/aws/hooks/mwaa.py
```diff
@@ -26,7 +26,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

 class MwaaHook(AwsBaseHook):
     """
-    Interact with AWS
+    Interact with AWS Managed Workflows for Apache Airflow.

     Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`

```
airflow/providers/amazon/aws/hooks/s3.py
```diff
@@ -790,10 +790,6 @@ class S3Hook(AwsBaseHook):
                 "FAILURE: Inactivity Period passed, not enough objects found in %s",
                 path,
             )
-            return {
-                "status": "error",
-                "message": f"FAILURE: Inactivity Period passed, not enough objects found in {path}",
-            }
         return {
             "status": "pending",
             "previous_objects": previous_objects,
```
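With the early error return gone, this branch now falls through to the `pending` payload instead of reporting an error, leaving failure handling to the caller that drives the polling; judging by the surrounding code this helper is the async keys-unchanged check, likely exercised via the sensor's deferrable path. For context, a hedged usage sketch of that sensor (bucket and prefix are placeholders):

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor

wait_for_stable_prefix = S3KeysUnchangedSensor(
    task_id="wait_for_stable_prefix",
    bucket_name="my-bucket",  # placeholder
    prefix="incoming/",       # placeholder
    inactivity_period=600,    # seconds with no new objects before success
    min_objects=1,            # fewer than this after the period is a failure
    deferrable=True,          # the async code path changed in this hunk
)
```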
airflow/providers/amazon/aws/operators/bedrock.py
```diff
@@ -33,6 +33,7 @@ from airflow.providers.amazon.aws.hooks.bedrock import (
 )
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
```
```diff
@@ -869,3 +870,121 @@ class BedrockRetrieveOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):

         self.log.info("\nQuery: %s\nRetrieved: %s", self.retrieval_query, result["retrievalResults"])
         return result
+
+
+class BedrockBatchInferenceOperator(AwsBaseOperator[BedrockHook]):
+    """
+    Create a batch inference job to invoke a model on multiple prompts.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:BedrockBatchInferenceOperator`
+
+    :param job_name: A name to give the batch inference job. (templated)
+    :param role_arn: The ARN of the IAM role with permissions to create the knowledge base. (templated)
+    :param model_id: Name or ARN of the model to associate with this provisioned throughput. (templated)
+    :param input_uri: The S3 location of the input data. (templated)
+    :param output_uri: The S3 location of the output data. (templated)
+    :param invoke_kwargs: Additional keyword arguments to pass to the API call. (templated)
+
+    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
+        NOTE: The way batch inference jobs work, your jobs are added to a queue and done "eventually"
+        so using deferrable mode is much more practical than using wait_for_completion.
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 10)
+    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    aws_hook_class = BedrockHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "job_name",
+        "role_arn",
+        "model_id",
+        "input_uri",
+        "output_uri",
+        "invoke_kwargs",
+    )
+
+    def __init__(
+        self,
+        job_name: str,
+        role_arn: str,
+        model_id: str,
+        input_uri: str,
+        output_uri: str,
+        invoke_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.job_name = job_name
+        self.role_arn = role_arn
+        self.model_id = model_id
+        self.input_uri = input_uri
+        self.output_uri = output_uri
+        self.invoke_kwargs = invoke_kwargs or {}
+
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+        self.activity = "Bedrock batch inference job"
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        validated_event = validate_execute_complete_event(event)
+
+        if validated_event["status"] != "success":
+            raise AirflowException(f"Error while running {self.activity}: {validated_event}")
+
+        self.log.info("%s '%s' complete.", self.activity, validated_event["job_arn"])
+
+        return validated_event["job_arn"]
+
+    def execute(self, context: Context) -> str:
+        response = self.hook.conn.create_model_invocation_job(
+            jobName=self.job_name,
+            roleArn=self.role_arn,
+            modelId=self.model_id,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": self.input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": self.output_uri}},
+            **self.invoke_kwargs,
+        )
+        job_arn = response["jobArn"]
+        self.log.info("%s '%s' started with ARN: %s", self.activity, self.job_name, job_arn)
+
+        task_description = f"for {self.activity} '{self.job_name}' to complete."
+        if self.deferrable:
+            self.log.info("Deferring %s", task_description)
+            self.defer(
+                trigger=BedrockBatchInferenceCompletedTrigger(
+                    job_arn=job_arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting %s", task_description)
+            self.hook.get_waiter(waiter_name="batch_inference_complete").wait(
+                jobIdentifier=job_arn,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return job_arn
```
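A hedged DAG-level usage sketch of the new operator; the role ARN, model id, and S3 URIs below are placeholders. Given the queue-and-wait nature of batch inference called out in the docstring, deferrable mode is the practical choice:

```python
from airflow.providers.amazon.aws.operators.bedrock import BedrockBatchInferenceOperator

batch_infer = BedrockBatchInferenceOperator(
    task_id="bedrock_batch_inference",
    job_name="my-batch-job",
    role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",  # placeholder
    model_id="anthropic.claude-3-haiku-20240307-v1:0",           # placeholder
    input_uri="s3://my-bucket/batch/input/records.jsonl",        # placeholder
    output_uri="s3://my-bucket/batch/output/",                   # placeholder
    deferrable=True,  # jobs queue "eventually"; deferring frees the worker slot
)
```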
airflow/providers/amazon/aws/operators/rds.py
```diff
@@ -20,19 +20,19 @@ from __future__ import annotations

 import json
 from collections.abc import Sequence
 from datetime import timedelta
-from functools import cached_property
 from typing import TYPE_CHECKING, Any

 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.rds import RdsHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.rds import (
     RdsDbAvailableTrigger,
     RdsDbDeletedTrigger,
     RdsDbStoppedTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
```
```diff
@@ -44,9 +44,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context


-class RdsBaseOperator(BaseOperator):
+class RdsBaseOperator(AwsBaseOperator[RdsHook]):
     """Base operator that implements common functions for all operators."""

+    aws_hook_class = RdsHook
     ui_color = "#eeaa88"
     ui_fgcolor = "#ffffff"

```
```diff
@@ -63,10 +64,6 @@ class RdsBaseOperator(BaseOperator):

         self._await_interval = 60  # seconds

-    @cached_property
-    def hook(self) -> RdsHook:
-        return RdsHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
     def execute(self, context: Context) -> str:
         """Different implementations for snapshots, tasks and events."""
         raise NotImplementedError
```
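Rebasing `RdsBaseOperator` onto `AwsBaseOperator[RdsHook]` drops the hand-rolled `hook` property and gives every RDS operator the common AWS arguments (`region_name`, `verify`, `botocore_config`) documented in the hunks that follow. A hedged sketch of what that enables (identifiers are placeholders):

```python
from airflow.providers.amazon.aws.operators.rds import RdsCreateDbSnapshotOperator

create_snapshot = RdsCreateDbSnapshotOperator(
    task_id="create_snapshot",
    db_type="instance",
    db_identifier="my-db-instance",           # placeholder
    db_snapshot_identifier="my-db-snapshot",  # placeholder
    region_name="eu-west-1",                  # now handled by the AwsBaseOperator base
    botocore_config={"retries": {"mode": "standard", "max_attempts": 10}},
)
```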
```diff
@@ -92,9 +89,19 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
+    template_fields = aws_template_fields("db_snapshot_identifier", "db_identifier", "tags")

     def __init__(
         self,
```
```diff
@@ -167,9 +174,14 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
         Only when db_type='instance'
     :param source_region: The ID of the region that contains the snapshot to be copied
     :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "source_db_snapshot_identifier",
         "target_db_snapshot_identifier",
         "tags",
```
```diff
@@ -260,9 +272,16 @@ class RdsDeleteDbSnapshotOperator(RdsBaseOperator):

     :param db_type: Type of the DB - either "instance" or "cluster"
     :param db_snapshot_identifier: The identifier for the DB instance or DB cluster snapshot
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_snapshot_identifier",)
+    template_fields = aws_template_fields(
+        "db_snapshot_identifier",
+    )

     def __init__(
         self,
```
```diff
@@ -319,9 +338,14 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
     :param waiter_interval: The number of seconds to wait before checking the export status. (default: 30)
     :param waiter_max_attempts: The number of attempts to make before failing. (default: 40)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "export_task_identifier",
         "source_arn",
         "s3_bucket_name",
```
```diff
@@ -394,9 +418,16 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
     :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
     :param check_interval: The amount of time in seconds to wait between attempts
     :param max_attempts: The maximum number of attempts to be made
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("export_task_identifier",)
+    template_fields = aws_template_fields(
+        "export_task_identifier",
+    )

     def __init__(
         self,
```
```diff
@@ -450,9 +481,14 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
     :param tags: A dictionary of tags or a list of tags in format `[{"Key": "...", "Value": "..."},]`
         `USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
     :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = (
+    template_fields = aws_template_fields(
         "subscription_name",
         "sns_topic_arn",
         "source_type",
```
```diff
@@ -513,9 +549,16 @@ class RdsDeleteEventSubscriptionOperator(RdsBaseOperator):
         :ref:`howto/operator:RdsDeleteEventSubscriptionOperator`

     :param subscription_name: The name of the RDS event notification subscription you want to delete
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("subscription_name",)
+    template_fields = aws_template_fields(
+        "subscription_name",
+    )

     def __init__(
         self,
```
```diff
@@ -560,9 +603,16 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
+    template_fields = aws_template_fields(
+        "db_instance_identifier", "db_instance_class", "engine", "rds_kwargs"
+    )

     def __init__(
         self,
```
```diff
@@ -652,9 +702,14 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False)
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_instance_identifier", "rds_kwargs")
+    template_fields = aws_template_fields("db_instance_identifier", "rds_kwargs")

     def __init__(
         self,
```
```diff
@@ -735,9 +790,14 @@ class RdsStartDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_type")

     def __init__(
         self,
```
```diff
@@ -832,9 +892,14 @@ class RdsStopDbOperator(RdsBaseOperator):
     :param waiter_max_attempts: The maximum number of attempts to check DB instance state
     :param deferrable: If True, the operator will wait asynchronously for the DB instance to be created.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """

-    template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+    template_fields = aws_template_fields("db_identifier", "db_snapshot_identifier", "db_type")

     def __init__(
         self,
```
airflow/providers/amazon/aws/sensors/bedrock.py
```diff
@@ -26,6 +26,8 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
+    BedrockBatchInferenceCompletedTrigger,
+    BedrockBatchInferenceScheduledTrigger,
     BedrockCustomizeModelCompletedTrigger,
     BedrockIngestionJobTrigger,
     BedrockKnowledgeBaseActiveTrigger,
```
```diff
@@ -34,6 +36,7 @@ from airflow.providers.amazon.aws.triggers.bedrock import (
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

 if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.triggers.bedrock import BedrockBaseBatchInferenceTrigger
     from airflow.utils.context import Context

```
|
@@ -368,3 +371,110 @@ class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
|
|
368
371
|
)
|
369
372
|
else:
|
370
373
|
super().execute(context=context)
|
374
|
+
|
375
|
+
|
376
|
+
class BedrockBatchInferenceSensor(BedrockBaseSensor[BedrockHook]):
|
377
|
+
"""
|
378
|
+
Poll the batch inference job status until it reaches a terminal state; fails if creation fails.
|
379
|
+
|
380
|
+
.. seealso::
|
381
|
+
For more information on how to use this sensor, take a look at the guide:
|
382
|
+
:ref:`howto/sensor:BedrockBatchInferenceSensor`
|
383
|
+
|
384
|
+
:param job_arn: The Amazon Resource Name (ARN) of the batch inference job. (templated)
|
385
|
+
:param success_state: A BedrockBatchInferenceSensor.TargetState; defaults to 'SCHEDULED' (templated)
|
386
|
+
|
387
|
+
:param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
|
388
|
+
module to be installed.
|
389
|
+
(default: False, but can be overridden in config file by setting default_deferrable to True)
|
390
|
+
:param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
|
391
|
+
:param max_retries: Number of times before returning the current state (default: 24)
|
392
|
+
:param aws_conn_id: The Airflow connection used for AWS credentials.
|
393
|
+
If this is ``None`` or empty then the default boto3 behaviour is used. If
|
394
|
+
running Airflow in a distributed manner and aws_conn_id is None or
|
395
|
+
empty, then default boto3 configuration would be used (and must be
|
396
|
+
maintained on each worker node).
|
397
|
+
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
|
398
|
+
:param verify: Whether or not to verify SSL certificates. See:
|
399
|
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
|
400
|
+
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
|
401
|
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
|
402
|
+
"""
|
403
|
+
|
404
|
+
class SuccessState:
|
405
|
+
"""
|
406
|
+
Target state for the BedrockBatchInferenceSensor.
|
407
|
+
|
408
|
+
Bedrock adds batch inference jobs to a queue, and they may take some time to complete.
|
409
|
+
If you want to wait for the job to complete, use TargetState.COMPLETED, but if you only want
|
410
|
+
to wait until the service confirms that the job is in the queue, use TargetState.SCHEDULED.
|
411
|
+
|
412
|
+
The normal successful progression of states is:
|
413
|
+
Submitted > Validating > Scheduled > InProgress > PartiallyCompleted > Completed
|
414
|
+
"""
|
415
|
+
|
416
|
+
SCHEDULED = "scheduled"
|
417
|
+
COMPLETED = "completed"
|
418
|
+
|
419
|
+
INTERMEDIATE_STATES: tuple[str, ...] # Defined in __init__ based on target state
|
420
|
+
FAILURE_STATES: tuple[str, ...] = ("Failed", "Stopped", "PartiallyCompleted", "Expired")
|
421
|
+
SUCCESS_STATES: tuple[str, ...] # Defined in __init__ based on target state
|
422
|
+
FAILURE_MESSAGE = "Bedrock batch inference job sensor failed."
|
423
|
+
INVALID_SUCCESS_STATE_MESSAGE = "success_state must be an instance of TargetState."
|
424
|
+
|
425
|
+
aws_hook_class = BedrockHook
|
426
|
+
|
427
|
+
template_fields: Sequence[str] = aws_template_fields("job_arn", "success_state")
|
428
|
+
|
429
|
+
def __init__(
|
430
|
+
self,
|
431
|
+
*,
|
432
|
+
job_arn: str,
|
433
|
+
success_state: SuccessState | str = SuccessState.SCHEDULED,
|
434
|
+
poke_interval: int = 120,
|
435
|
+
max_retries: int = 75,
|
436
|
+
**kwargs,
|
437
|
+
) -> None:
|
438
|
+
super().__init__(**kwargs)
|
439
|
+
self.poke_interval = poke_interval
|
440
|
+
self.max_retries = max_retries
|
441
|
+
self.job_arn = job_arn
|
442
|
+
self.success_state = success_state
|
443
|
+
|
444
|
+
base_success_states: tuple[str, ...] = ("Completed",)
|
445
|
+
base_intermediate_states: tuple[str, ...] = ("Submitted", "InProgress", "Stopping", "Validating")
|
446
|
+
scheduled_state = ("Scheduled",)
|
447
|
+
self.trigger_class: type[BedrockBaseBatchInferenceTrigger]
|
448
|
+
|
449
|
+
if self.success_state == BedrockBatchInferenceSensor.SuccessState.COMPLETED:
|
450
|
+
intermediate_states = base_intermediate_states + scheduled_state
|
451
|
+
success_states = base_success_states
|
452
|
+
self.trigger_class = BedrockBatchInferenceCompletedTrigger
|
453
|
+
elif self.success_state == BedrockBatchInferenceSensor.SuccessState.SCHEDULED:
|
454
|
+
intermediate_states = base_intermediate_states
|
455
|
+
success_states = base_success_states + scheduled_state
|
456
|
+
self.trigger_class = BedrockBatchInferenceScheduledTrigger
|
457
|
+
else:
|
458
|
+
raise ValueError(
|
459
|
+
"Success states for BedrockBatchInferenceSensor must be set using a BedrockBatchInferenceSensor.SuccessState"
|
460
|
+
)
|
461
|
+
|
462
|
+
BedrockBatchInferenceSensor.INTERMEDIATE_STATES = intermediate_states or base_intermediate_states
|
463
|
+
BedrockBatchInferenceSensor.SUCCESS_STATES = success_states or base_success_states
|
464
|
+
|
465
|
+
def get_state(self) -> str:
|
466
|
+
return self.hook.conn.get_model_invocation_job(jobIdentifier=self.job_arn)["status"]
|
467
|
+
|
468
|
+
def execute(self, context: Context) -> Any:
|
469
|
+
if self.deferrable:
|
470
|
+
self.defer(
|
471
|
+
trigger=self.trigger_class(
|
472
|
+
job_arn=self.job_arn,
|
473
|
+
waiter_delay=int(self.poke_interval),
|
474
|
+
waiter_max_attempts=self.max_retries,
|
475
|
+
aws_conn_id=self.aws_conn_id,
|
476
|
+
),
|
477
|
+
method_name="poke",
|
478
|
+
)
|
479
|
+
else:
|
480
|
+
super().execute(context=context)
|