apache-airflow-providers-amazon 8.21.0rc1__py3-none-any.whl → 8.22.0__py3-none-any.whl

This diff shows the changes between these two publicly released package versions, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -18,14 +18,16 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Sequence
+from typing import TYPE_CHECKING, Any, Sequence, TypeVar
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
     BedrockCustomizeModelCompletedTrigger,
+    BedrockIngestionJobTrigger,
+    BedrockKnowledgeBaseActiveTrigger,
     BedrockProvisionModelThroughputCompletedTrigger,
 )
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -34,7 +36,10 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
+_GenericBedrockHook = TypeVar("_GenericBedrockHook", BedrockAgentHook, BedrockHook)
+
+
+class BedrockBaseSensor(AwsBaseSensor[_GenericBedrockHook]):
     """
     General sensor behavior for Amazon Bedrock.
 
@@ -57,7 +62,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
     SUCCESS_STATES: tuple[str, ...] = ()
     FAILURE_MESSAGE = ""
 
-    aws_hook_class = BedrockHook
+    aws_hook_class: type[_GenericBedrockHook]
     ui_color = "#66c3ff"
 
     def __init__(
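The refactor above makes the base sensor generic over a constrained TypeVar instead of a single hook type: the base class now only declares aws_hook_class: type[_GenericBedrockHook], and each concrete sensor both parameterizes BedrockBaseSensor[...] and assigns the matching hook. A minimal sketch of the same pattern, with illustrative names that are not part of the provider:

    from __future__ import annotations

    from typing import Generic, TypeVar


    class AgentHook: ...


    class ModelHook: ...


    # Constrained TypeVar: only these two hook types are valid parameters.
    _H = TypeVar("_H", AgentHook, ModelHook)


    class BaseSensor(Generic[_H]):
        # Declared but not assigned here; each subclass supplies the concrete hook.
        aws_hook_class: type[_H]

        @property
        def hook(self) -> _H:
            return self.aws_hook_class()


    class AgentSensor(BaseSensor[AgentHook]):
        aws_hook_class = AgentHook  # type checkers now resolve .hook to AgentHook
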
@@ -68,7 +73,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
         super().__init__(**kwargs)
         self.deferrable = deferrable
 
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context, **kwargs) -> bool:
         state = self.get_state()
         if state in self.FAILURE_STATES:
             # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
@@ -83,7 +88,7 @@ class BedrockBaseSensor(AwsBaseSensor[BedrockHook]):
         """Implement in subclasses."""
 
 
-class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
+class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor[BedrockHook]):
     """
     Poll the state of the model customization job until it reaches a terminal state; fails if the job fails.
 
@@ -115,6 +120,8 @@ class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
     SUCCESS_STATES: tuple[str, ...] = ("Completed",)
     FAILURE_MESSAGE = "Bedrock model customization job sensor failed."
 
+    aws_hook_class = BedrockHook
+
     template_fields: Sequence[str] = aws_template_fields("job_name")
 
     def __init__(
@@ -148,7 +155,7 @@ class BedrockCustomizeModelCompletedSensor(BedrockBaseSensor):
         return self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"]
 
 
-class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
+class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor[BedrockHook]):
     """
     Poll the provisioned model throughput job until it reaches a terminal state; fails if the job fails.
 
@@ -180,6 +187,8 @@ class BedrockProvisionModelThroughputCompletedSensor(BedrockBaseSensor):
     SUCCESS_STATES: tuple[str, ...] = ("InService",)
     FAILURE_MESSAGE = "Bedrock provision model throughput sensor failed."
 
+    aws_hook_class = BedrockHook
+
     template_fields: Sequence[str] = aws_template_fields("model_id")
 
     def __init__(
@@ -211,3 +220,153 @@
             )
         else:
             super().execute(context=context)
+
+
+class BedrockKnowledgeBaseActiveSensor(BedrockBaseSensor[BedrockAgentHook]):
+    """
+    Poll the Knowledge Base status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockKnowledgeBaseActiveSensor`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 5)
+    :param max_retries: Number of times before returning the current state (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("CREATING", "UPDATING")
+    FAILURE_STATES: tuple[str, ...] = ("DELETING", "FAILED")
+    SUCCESS_STATES: tuple[str, ...] = ("ACTIVE",)
+    FAILURE_MESSAGE = "Bedrock Knowledge Base Active sensor failed."
+
+    aws_hook_class = BedrockAgentHook
+
+    template_fields: Sequence[str] = aws_template_fields("knowledge_base_id")
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        poke_interval: int = 5,
+        max_retries: int = 24,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.knowledge_base_id = knowledge_base_id
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_knowledge_base(knowledgeBaseId=self.knowledge_base_id)["knowledgeBase"][
+            "status"
+        ]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockKnowledgeBaseActiveTrigger(
+                    knowledge_base_id=self.knowledge_base_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
+
+
+class BedrockIngestionJobSensor(BedrockBaseSensor[BedrockAgentHook]):
+    """
+    Poll the ingestion job status until it reaches a terminal state; fails if creation fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:BedrockIngestionJobSensor`
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information. (templated)
+    :param data_source_id: The unique identifier of the data source in the ingestion job. (templated)
+    :param ingestion_job_id: The unique identifier of the ingestion job. (templated)
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state (default: 10)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("STARTING", "IN_PROGRESS")
+    FAILURE_STATES: tuple[str, ...] = ("FAILED",)
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETE",)
+    FAILURE_MESSAGE = "Bedrock ingestion job sensor failed."
+
+    aws_hook_class = BedrockAgentHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "knowledge_base_id", "data_source_id", "ingestion_job_id"
+    )
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        data_source_id: str,
+        ingestion_job_id: str,
+        poke_interval: int = 60,
+        max_retries: int = 10,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.knowledge_base_id = knowledge_base_id
+        self.data_source_id = data_source_id
+        self.ingestion_job_id = ingestion_job_id
+
+    def get_state(self) -> str:
+        return self.hook.conn.get_ingestion_job(
+            knowledgeBaseId=self.knowledge_base_id,
+            ingestionJobId=self.ingestion_job_id,
+            dataSourceId=self.data_source_id,
+        )["ingestionJob"]["status"]
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=BedrockIngestionJobTrigger(
+                    knowledge_base_id=self.knowledge_base_id,
+                    ingestion_job_id=self.ingestion_job_id,
+                    data_source_id=self.data_source_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
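Taken together, the two new sensors are meant to sit after the steps that create a knowledge base and start an ingestion job. A minimal DAG sketch under the assumption that the IDs are already known; the literal IDs and the dag_id below are placeholders, not part of this release:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.sensors.bedrock import (
        BedrockIngestionJobSensor,
        BedrockKnowledgeBaseActiveSensor,
    )

    with DAG(
        dag_id="example_bedrock_knowledge_base",
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # Blocks until the knowledge base leaves CREATING/UPDATING; raises on DELETING/FAILED.
        await_kb = BedrockKnowledgeBaseActiveSensor(
            task_id="await_knowledge_base",
            knowledge_base_id="EXAMPLE_KB_ID",  # placeholder; typically pulled from a create step
            deferrable=True,  # polling is handed to the triggerer via BedrockKnowledgeBaseActiveTrigger
        )

        # Blocks until the ingestion job reports COMPLETE; raises on FAILED.
        await_ingestion = BedrockIngestionJobSensor(
            task_id="await_ingestion_job",
            knowledge_base_id="EXAMPLE_KB_ID",
            data_source_id="EXAMPLE_DS_ID",  # placeholder
            ingestion_job_id="EXAMPLE_JOB_ID",  # placeholder
        )

        await_kb >> await_ingestion
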
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.opensearch_serverless import (
+    OpenSearchServerlessCollectionActiveTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.helpers import exactly_one
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class OpenSearchServerlessCollectionActiveSensor(AwsBaseSensor[OpenSearchServerlessHook]):
+    """
+    Poll the state of the Collection until it reaches a terminal state; fails if the query fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:OpenSearchServerlessCollectionAvailableSensor`
+
+    :param collection_id: A collection ID. You can't provide a name and an ID in the same request.
+    :param collection_name: A collection name. You can't provide a name and an ID in the same request.
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 10)
+    :param max_retries: Number of times before returning the current state (default: 60)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES = ("CREATING",)
+    FAILURE_STATES = (
+        "DELETING",
+        "FAILED",
+    )
+    SUCCESS_STATES = ("ACTIVE",)
+    FAILURE_MESSAGE = "OpenSearch Serverless Collection sensor failed"
+
+    aws_hook_class = OpenSearchServerlessHook
+    template_fields: Sequence[str] = aws_template_fields(
+        "collection_id",
+        "collection_name",
+    )
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        *,
+        collection_id: str | None = None,
+        collection_name: str | None = None,
+        poke_interval: int = 10,
+        max_retries: int = 60,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        if not exactly_one(collection_id is None, collection_name is None):
+            raise AttributeError("Either collection_ids or collection_names must be provided, not both.")
+        self.collection_id = collection_id
+        self.collection_name = collection_name
+
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        call_args = (
+            {"ids": [str(self.collection_id)]}
+            if self.collection_id
+            else {"names": [str(self.collection_name)]}
+        )
+        state = self.hook.conn.batch_get_collection(**call_args)["collectionDetails"][0]["status"]
+
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        if state in self.INTERMEDIATE_STATES:
+            return False
+        return True
+
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=OpenSearchServerlessCollectionActiveTrigger(
+                    collection_id=self.collection_id,
+                    collection_name=self.collection_name,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="poke",
+            )
+        else:
+            super().execute(context=context)
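A short usage sketch for the new collection sensor, assuming it is importable from airflow.providers.amazon.aws.sensors.opensearch_serverless (mirroring the trigger import shown in the file above); the collection name is a placeholder:

    from airflow.providers.amazon.aws.sensors.opensearch_serverless import (
        OpenSearchServerlessCollectionActiveSensor,
    )

    # Exactly one of collection_id / collection_name may be given; passing both
    # (or neither) raises AttributeError in __init__.
    await_collection = OpenSearchServerlessCollectionActiveSensor(
        task_id="await_collection",
        collection_name="my-vector-collection",  # placeholder
        poke_interval=10,
        max_retries=60,
        deferrable=True,  # defers to OpenSearchServerlessCollectionActiveTrigger
    )
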
@@ -18,7 +18,7 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
-from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
+from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
 
 if TYPE_CHECKING:
@@ -61,11 +61,49 @@ class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger):
         return BedrockHook(aws_conn_id=self.aws_conn_id)
 
 
+class BedrockKnowledgeBaseActiveTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock Knowledge Base reaches the ACTIVE state.
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 5)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 24)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        waiter_delay: int = 5,
+        waiter_max_attempts: int = 24,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={"knowledge_base_id": knowledge_base_id},
+            waiter_name="knowledge_base_active",
+            waiter_args={"knowledgeBaseId": knowledge_base_id},
+            failure_message="Bedrock Knowledge Base creation failed.",
+            status_message="Status of Bedrock Knowledge Base job is",
+            status_queries=["status"],
+            return_key="knowledge_base_id",
+            return_value=knowledge_base_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
+
+
 class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger):
     """
     Trigger when a provisioned throughput job is complete.
 
     :param provisioned_model_id: The ARN or name of the provisioned throughput.
+
     :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120)
     :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
     :param aws_conn_id: The Airflow connection used for AWS credentials.
@@ -95,3 +133,52 @@ class BedrockProvisionModelThroughputCompletedTrigger(AwsBaseWaiterTrigger):
 
     def hook(self) -> AwsGenericHook:
         return BedrockHook(aws_conn_id=self.aws_conn_id)
+
+
+class BedrockIngestionJobTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when a Bedrock ingestion job reaches the COMPLETE state.
+
+    :param knowledge_base_id: The unique identifier of the knowledge base for which to get information.
+    :param data_source_id: The unique identifier of the data source in the ingestion job.
+    :param ingestion_job_id: The unique identifier of the ingestion job.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 10)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        knowledge_base_id: str,
+        data_source_id: str,
+        ingestion_job_id: str,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 10,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        super().__init__(
+            serialized_fields={
+                "knowledge_base_id": knowledge_base_id,
+                "data_source_id": data_source_id,
+                "ingestion_job_id": ingestion_job_id,
+            },
+            waiter_name="ingestion_job_complete",
+            waiter_args={
+                "knowledgeBaseId": knowledge_base_id,
+                "dataSourceId": data_source_id,
+                "ingestionJobId": ingestion_job_id,
+            },
+            failure_message="Bedrock ingestion job creation failed.",
+            status_message="Status of Bedrock ingestion job is",
+            status_queries=["status"],
+            return_key="ingestion_job_id",
+            return_value=ingestion_job_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return BedrockAgentHook(aws_conn_id=self.aws_conn_id)
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.utils.helpers import exactly_one
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
+class OpenSearchServerlessCollectionActiveTrigger(AwsBaseWaiterTrigger):
+    """
+    Trigger when an Amazon OpenSearch Serverless Collection reaches the ACTIVE state.
+
+    :param collection_id: A collection ID. You can't provide a name and an ID in the same request.
+    :param collection_name: A collection name. You can't provide a name and an ID in the same request.
+
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 20)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    """
+
+    def __init__(
+        self,
+        *,
+        collection_id: str | None = None,
+        collection_name: str | None = None,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        aws_conn_id: str | None = None,
+    ) -> None:
+        if not exactly_one(collection_id is None, collection_name is None):
+            raise AttributeError("Either collection_ids or collection_names must be provided, not both.")
+
+        super().__init__(
+            serialized_fields={"collection_id": collection_id, "collection_name": collection_name},
+            waiter_name="collection_available",
+            waiter_args={"ids": [collection_id]} if collection_id else {"names": [collection_name]},
+            failure_message="OpenSearch Serverless Collection creation failed.",
+            status_message="Status of OpenSearch Serverless Collection is",
+            status_queries=["status"],
+            return_key="collection_id" if collection_id else "collection_name",
+            return_value=collection_id if collection_id else collection_name,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+        )
+
+    def hook(self) -> AwsGenericHook:
+        return OpenSearchServerlessHook(aws_conn_id=self.aws_conn_id)
@@ -0,0 +1,73 @@
+{
+    "version": 2,
+    "waiters": {
+        "knowledge_base_active": {
+            "delay": 5,
+            "maxAttempts": 24,
+            "operation": "getKnowledgeBase",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "ACTIVE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "CREATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "DELETING",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "UPDATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "knowledgeBase.status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        },
+        "ingestion_job_complete": {
+            "delay": 60,
+            "maxAttempts": 10,
+            "operation": "getIngestionJob",
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "COMPLETE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "STARTING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "IN_PROGRESS",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "ingestionJob.status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}
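These are ordinary botocore waiter definitions; the waiter_name values used by the triggers above ("knowledge_base_active", "ingestion_job_complete") refer to these entries. A sketch of driving the same waiter synchronously through the hook, assuming the provider's custom-waiter loading resolves the name against this file for the bedrock-agent client:

    from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook

    hook = BedrockAgentHook(aws_conn_id="aws_default")
    # Polls getKnowledgeBase until the status is ACTIVE, retrying on CREATING/UPDATING
    # and failing fast on DELETING/FAILED, per the acceptors defined above.
    hook.get_waiter("knowledge_base_active").wait(
        knowledgeBaseId="EXAMPLE_KB_ID",  # placeholder
        WaiterConfig={"Delay": 5, "MaxAttempts": 24},
    )
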
@@ -0,0 +1,36 @@
+{
+    "version": 2,
+    "waiters": {
+        "collection_available": {
+            "operation": "BatchGetCollection",
+            "delay": 10,
+            "maxAttempts": 120,
+            "acceptors": [
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "ACTIVE",
+                    "state": "success"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "DELETING",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "CREATING",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "path",
+                    "argument": "collectionDetails[0].status",
+                    "expected": "FAILED",
+                    "state": "failure"
+                }
+            ]
+        }
+    }
+}