apache-airflow-providers-amazon: 8.21.0 → 8.22.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airflow/providers/amazon/__init__.py +3 -6
- airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +7 -2
- airflow/providers/amazon/aws/hooks/bedrock.py +20 -0
- airflow/providers/amazon/aws/hooks/opensearch_serverless.py +39 -0
- airflow/providers/amazon/aws/hooks/s3.py +2 -2
- airflow/providers/amazon/aws/operators/bedrock.py +314 -1
- airflow/providers/amazon/aws/operators/s3.py +27 -7
- airflow/providers/amazon/aws/sensors/bedrock.py +166 -7
- airflow/providers/amazon/aws/sensors/opensearch_serverless.py +129 -0
- airflow/providers/amazon/aws/triggers/bedrock.py +88 -1
- airflow/providers/amazon/aws/triggers/opensearch_serverless.py +68 -0
- airflow/providers/amazon/aws/waiters/bedrock-agent.json +73 -0
- airflow/providers/amazon/aws/waiters/opensearchserverless.json +36 -0
- airflow/providers/amazon/get_provider_info.py +22 -2
- {apache_airflow_providers_amazon-8.21.0.dist-info → apache_airflow_providers_amazon-8.22.0.dist-info}/METADATA +8 -8
- {apache_airflow_providers_amazon-8.21.0.dist-info → apache_airflow_providers_amazon-8.22.0.dist-info}/RECORD +18 -13
- {apache_airflow_providers_amazon-8.21.0.dist-info → apache_airflow_providers_amazon-8.22.0.dist-info}/WHEEL +0 -0
- {apache_airflow_providers_amazon-8.21.0.dist-info → apache_airflow_providers_amazon-8.22.0.dist-info}/entry_points.txt +0 -0
airflow/providers/amazon/aws/sensors/bedrock.py

@@ -18,14 +18,16 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Sequence, TypeVar
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
     BedrockCustomizeModelCompletedTrigger,
+    BedrockIngestionJobTrigger,
+    BedrockKnowledgeBaseActiveTrigger,
     BedrockProvisionModelThroughputCompletedTrigger,
 )
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -34,7 +36,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
+_GenericBedrockHook = TypeVar("_GenericBedrockHook", BedrockAgentHook, BedrockHook)
+
+
+class BedrockBaseSensor(AwsBaseSensor[_GenericBedrockHook]):
     """
     General sensor behavior for Amazon Bedrock.
 
@@ -57,7 +62,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
     SUCCESS_STATES: tuple[str, ...] = ()
     FAILURE_MESSAGE = ""
 
-    aws_hook_class = BedrockHook
+    aws_hook_class: type[_GenericBedrockHook]
     ui_color = "#66c3ff"
 
     def __init__(
@@ -68,7 +73,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
         super().__init__(**kwargs)
         self.deferrable = deferrable
 
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context, **kwargs) -> bool:
         state = self.get_state()
         if state in self.FAILURE_STATES:
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
@@ -83,7 +88,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
         """Implement in subclasses."""
 
 
-class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
+class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor[BedrockHook]):
     """
     Poll the state of the model customization job until it reaches a terminal state; fails if the job fails.
 
@@ -115,6 +120,8 @@ class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
     SUCCESS_STATES: tuple[str, ...] = ("Completed",)
     FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
 
+    aws_hook_class = BedrockHook
+
     template_fields: Sequence[str] = aws_template_fields("job_name")
 
     def __init__(
@@ -148,7 +155,7 @@ class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
         return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
 
 
-class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
+class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor[BedrockHook]):
     """
     Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails.
 
@@ -180,6 +187,8 @@ class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
     SUCCESS_STATES: tuple[str, ...] = ("InService",)
     FAILURE_MESSAGE = "Bedrock provision model throughput sensor failed."
 
+    aws_hook_class = BedrockHook
+
     template_fields: Sequence[str] = aws_template_fields("model_id")
 
     def __init__(
@@ -211,3 +220,153 @@ class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
             )
         else:
             super().execute(context=context)
+
+
+class BedrockKnowledgeBaseActiveSensor(BedrockBaseSensor[BedrockAgentHook]):
+    """
+    Poll the Knowledge Base status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockKnowledgeBaseActiveSensor`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
+    :param max_retries: Number of times before returning the current state (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("CREATING", "UPDATING")
+    FAILURE_STATES: tuple[str, ...] = ("DELETING", "FAILED")
+    SUCCESS_STATES: tuple[str, ...] = ("ACTIVE",)
+    FAILURE_MESSAGE = "Bedrock Knowledge Base Active sensor failed."
+
+    aws_hook_class = BedrockAgentHook
+
+    template_fields: Sequence[str] = aws_template_fields("knowledge_base_id")
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        poke_interval: int = 5,
+        max_retries: int = 24,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.knowledge_base_id = knowledge_base_id
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_knowledge_base(knowledgeBaseId=self.knowledge_base_id)["knowledgeBase"][
+            "status"
+        ]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockKnowledgeBaseActiveTrigger(
+                    knowledge_base_id=self.knowledge_base_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+
+class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
+    """
+    Poll the ingestion job status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockIngestionJobSensor`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated)
+    :param data_source_id: The unique identifier of the data source in the ingestion job. (templated)
+    :param ingestion_job_id: The unique identifier of the ingestion job. (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state (default: 10)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "IN_PROGRESS")
+    FAILURE_STATES: tuple[str, ...] = ("FAILED",)
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETE",)
+    FAILURE_MESSAGE = "Bedrock ingestion job sensor failed."
+
+    aws_hook_class = BedrockAgentHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "knowledge_base_id", "data_source_id", "ingestion_job_id"
+    )
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        data_source_id: str,
+        ingestion_job_id: str,
+        poke_interval: int = 60,
+        max_retries: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.knowledge_base_id = knowledge_base_id
+        self.data_source_id = data_source_id
+        self.ingestion_job_id = ingestion_job_id
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_ingestion_job(
+            knowledgeBaseId=self.knowledge_base_id,
+            ingestionJobId=self.ingestion_job_id,
+            dataSourceId=self.data_source_id,
+        )["ingestionJob"]["status"]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockIngestionJobTrigger(
+                    knowledge_base_id=self.knowledge_base_id,
+                    ingestion_job_id=self.ingestion_job_id,
+                    data_source_id=self.data_source_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
airflow/providers/amazon/aws/sensors/opensearch_serverless.py (new file)

@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.opensearch_serverless import (
+    OpenSearchServerlessCollectionActiveTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.helpers import exactly_one
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class OpenSearchServerlessCollectionActiveSensor(AwsBaseSensor[OpenSearchServerlessHook]):
+    """
+    Poll the state of the Collection until it reaches a terminal state; fails if the query fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:OpenSearchServerlessCollectionAvailableSensor`
+
+    :param collection_id: A collection ID. You can't provide a name and an ID in the same request.
+    :param collection_name: A collection name. You can't provide a name and an ID in the same request.
+
+    :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 10)
+    :param max_retries: Number of times before returning the current state (default: 60)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES = ("CREATING",)
+    FAILURE_STATES = (
+        "DELETING",
+        "FAILED",
+    )
+    SUCCESS_STATES = ("ACTIVE",)
+    FAILURE_MESSAGE = "OpenSearch Serverless Collection sensor failed"
+
+    aws_hook_class = OpenSearchServerlessHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "collection_id",
+        "collection_name",
+    )
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        *,
+        collection_id: str | None = None,
+        collection_name: str | None = None,
+        poke_interval: int = 10,
+        max_retries: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        if not exactly_one(collection_id is None, collection_name is None):
+            raise AttributeError("Either collection_ids or collection_names must be provided, not both.")
+        self.collection_id = collection_id
+        self.collection_name = collection_name
+
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        call_args = (
+            {"ids": [str(self.collection_id)]}
+            if self.collection_id
+            else {"names": [str(self.collection_name)]}
+        )
+        state = self.hook.conn.batch_get_collection(**call_args)["collectionDetails"][0]["status"]
+
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if state in self.INTERMEDIATE_STATES:
+            return False
+        return True
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=OpenSearchServerlessCollectionActiveTrigger(
+                    collection_id=self.collection_id,
+                    collection_name=self.collection_name,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
airflow/providers/amazon/aws/triggers/bedrock.py

@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
 if TYPE_CHECKING:
@@ -61,11 +61,49 @@ class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger):
         return BedrockHook(aws_conn_id=self.aws_conn_id)
 
 
+class BedrockKnowledgeBaseActiveTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock Knowledge Base reaches the ACTIVE state.
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 5)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        waiter_delay: int = 5,
+        waiter_max_attempts: int = 24,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"knowledge_base_id": knowledge_base_id},
+            waiter_name="knowledge_base_active",
+            waiter_args={"knowledgeBaseId": knowledge_base_id},
+            failure_message="Bedrock Knowledge Base creation failed.",
+            status_message="Status of Bedrock Knowledge Base job is",
+            status_queries=["status"],
+            return_key="knowledge_base_id",
+            return_value=knowledge_base_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
+
+
 class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger):
     """
     Trigger when a provisioned throughput job is complete.
 
     :param provisioned_model_id: The ARN or name of the provisioned throughput.
+
     :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
     :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
     :param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -95,3 +133,52 @@ class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger):
 
     def hook(self) -> AwsGenericHook:
         return BedrockHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock ingestion job reaches the COMPLETE state.
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information.
+    :param data_source_id: The unique identifier of the data source in the ingestion job.
+    :param ingestion_job_id: The unique identifier of the ingestion job.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 10)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        data_source_id: str,
+        ingestion_job_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={
+                "knowledge_base_id": knowledge_base_id,
+                "data_source_id": data_source_id,
+                "ingestion_job_id": ingestion_job_id,
+            },
+            waiter_name="ingestion_job_complete",
+            waiter_args={
+                "knowledgeBaseId": knowledge_base_id,
+                "dataSourceId": data_source_id,
+                "ingestionJobId": ingestion_job_id,
+            },
+            failure_message="Bedrock ingestion job creation failed.",
+            status_message="Status of Bedrock ingestion job is",
+            status_queries=["status"],
+            return_key="ingestion_job_id",
+            return_value=ingestion_job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/triggers/opensearch_serverless.py (new file)

@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.helpers import exactly_one
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class OpenSearchServerlessCollectionActiveTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an Amazon OpenSearch Serverless Collection reaches the ACTIVE state.
+
+    :param collection_id: A collection ID. You can't provide a name and an ID in the same request.
+    :param collection_name: A collection name. You can't provide a name and an ID in the same request.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 20)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        collection_id: str | None = None,
+        collection_name: str | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        if not exactly_one(collection_id is None, collection_name is None):
+            raise AttributeError("Either collection_ids or collection_names must be provided, not both.")
+
+        super().__init__(
+            serialized_fields={"collection_id": collection_id, "collection_name": collection_name},
+            waiter_name="collection_available",
+            waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]},
+            failure_message="OpenSearch Serverless Collection creation failed.",
+            status_message="Status of OpenSearch Serverless Collection is",
+            status_queries=["status"],
+            return_key="collection_id" if collection_id else "collection_name",
+            return_value=collection_id if collection_id else collection_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return OpenSearchServerlessHook(aws_conn_id=self.aws_conn_id)
airflow/providers/amazon/aws/waiters/bedrock-agent.json (new file)

@@ -0,0 +1,73 @@
+{
+    "version": 2,
+    "waiters": {
+        "knowledge_base_active": {
+            "delay": 5,
+            "maxAttempts": 24,
+            "operation": "getKnowledgeBase",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "ACTIVE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "CREATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "DELETING",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "UPDATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        },
+        "ingestion_job_complete": {
+            "delay": 60,
+            "maxAttempts": 10,
+            "operation": "getIngestionJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "COMPLETE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "STARTING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "IN_PROGRESS",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
airflow/providers/amazon/aws/waiters/opensearchserverless.json (new file)

@@ -0,0 +1,36 @@
+{
+    "version": 2,
+    "waiters": {
+        "collection_available": {
+            "operation": "BatchGetCollection",
+            "delay": 10,
+            "maxAttempts": 120,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "ACTIVE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "DELETING",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "CREATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}