amazon-bedrock-haystack 3.5.2__py3-none-any.whl → 3.6.1__py3-none-any.whl
This diff shows the changes between two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- {amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/METADATA +2 -2
- {amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/RECORD +12 -12
- haystack_integrations/common/amazon_bedrock/utils.py +3 -3
- haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py +5 -2
- haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py +5 -2
- haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +87 -36
- haystack_integrations/components/generators/amazon_bedrock/chat/utils.py +368 -184
- haystack_integrations/components/generators/amazon_bedrock/generator.py +3 -3
- haystack_integrations/components/rankers/amazon_bedrock/__init__.py +2 -2
- haystack_integrations/components/rankers/amazon_bedrock/ranker.py +30 -8
- {amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/WHEEL +0 -0
- {amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/licenses/LICENSE.txt +0 -0
{amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: amazon-bedrock-haystack
-Version: 3.5.2
+Version: 3.6.1
 Summary: An integration of Amazon Bedrock as an AmazonBedrockGenerator component.
 Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme
 Project-URL: Issues, https://github.com/deepset-ai/haystack-core-integrations/issues
@@ -21,7 +21,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
 Requires-Python: >=3.9
 Requires-Dist: aioboto3>=14.0.0
 Requires-Dist: boto3>=1.28.57
-Requires-Dist: haystack-ai
+Requires-Dist: haystack-ai>=2.13.1
 Description-Content-Type: text/markdown
 
 # amazon-bedrock-haystack
{amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/RECORD
RENAMED
@@ -1,18 +1,18 @@
 haystack_integrations/common/amazon_bedrock/__init__.py,sha256=6GZ8Y3Lw0rLOsOAqi6Tu5mZC977UzQvgDxKpOWr8IQw,110
 haystack_integrations/common/amazon_bedrock/errors.py,sha256=ReheDbY7L3EJkWcUoih6lWHjbPHg2TlUs9SnXIKK7Gg,744
-haystack_integrations/common/amazon_bedrock/utils.py,sha256=…
+haystack_integrations/common/amazon_bedrock/utils.py,sha256=dHUWzHYT0A8_eLDpVkwDhmDpprYbFlWsGg0FOS0uF0I,2720
 haystack_integrations/components/embedders/amazon_bedrock/__init__.py,sha256=CFqYmAVq2aavlMkZHYScKHOTwwETdRzRZITMqGhJ9Kw,298
-haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py,sha256=…
-haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py,sha256=…
+haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py,sha256=ZdtM6HpHQwbKfjmfOK6gkIQPPbI0n8_pWRrR6lyXmr8,13321
+haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py,sha256=gpvu6IMoycUXrn4r1OH5yEIheiDxHf2T5fdJJO4DfW0,9202
 haystack_integrations/components/generators/amazon_bedrock/__init__.py,sha256=lv4NouIVm78YavUssWQrHHP_81u-7j21qW8v1kZMJPQ,284
 haystack_integrations/components/generators/amazon_bedrock/adapters.py,sha256=cnlfmie4HfEX4nipSXSDk_3koy7HYZ-ezimGN6BozQ0,19543
-haystack_integrations/components/generators/amazon_bedrock/generator.py,sha256=…
+haystack_integrations/components/generators/amazon_bedrock/generator.py,sha256=NgywyiKYazEbsLAcGcOPUT4blWhYYOJ9WjO-HWDvu7I,14576
 haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py,sha256=6GZ8Y3Lw0rLOsOAqi6Tu5mZC977UzQvgDxKpOWr8IQw,110
-haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=…
-haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=…
-haystack_integrations/components/rankers/amazon_bedrock/__init__.py,sha256=…
-haystack_integrations/components/rankers/amazon_bedrock/ranker.py,sha256=…
-amazon_bedrock_haystack-3.5.2.dist-info/METADATA,sha256=…
-amazon_bedrock_haystack-3.5.2.dist-info/WHEEL,sha256=…
-amazon_bedrock_haystack-3.5.2.dist-info/licenses/LICENSE.txt,sha256=…
-amazon_bedrock_haystack-3.5.2.dist-info/RECORD,,
+haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=M9I0sB8LFrXCgoyr5ik2ZPPHyB0b4nJFKX7GARKsk8Y,23384
+haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=kBSaU_ZqzL-7a7nplezjb4XRBy51pt-4VULoX5lq21A,21148
+haystack_integrations/components/rankers/amazon_bedrock/__init__.py,sha256=Zrc3BSVkEaXYpliEi6hKG9bqW4J7DNk93p50SuoyT1Q,107
+haystack_integrations/components/rankers/amazon_bedrock/ranker.py,sha256=x4QEVkbFM-jMFHx-xmk571wtrohnPLtkIWMhCyg4_II,12278
+amazon_bedrock_haystack-3.6.1.dist-info/METADATA,sha256=oRIb-2Nv642N0wmapQgykocVVjxzwM7rG1_TNOoF6Vs,2225
+amazon_bedrock_haystack-3.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+amazon_bedrock_haystack-3.6.1.dist-info/licenses/LICENSE.txt,sha256=B05uMshqTA74s-0ltyHKI6yoPfJ3zYgQbvcXfDVGFf8,10280
+amazon_bedrock_haystack-3.6.1.dist-info/RECORD,,
haystack_integrations/common/amazon_bedrock/utils.py
CHANGED
@@ -1,6 +1,6 @@
-from typing import Optional
+from typing import Optional, Union
 
-import aioboto3
+import aioboto3
 import boto3
 from botocore.exceptions import BotoCoreError
@@ -23,7 +23,7 @@ def get_aws_session(
     aws_profile_name: Optional[str] = None,
     async_mode: bool = False,
     **kwargs,
-):
+) -> Union[boto3.Session, aioboto3.Session]:
     """
     Creates an AWS Session with the given parameters.
    Checks if the provided AWS credentials are valid and can be used to connect to AWS.
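As a usage note, the new annotation makes the sync/async split explicit at the call site. A minimal sketch, assuming only the parameters visible in the hunk above (any other credential kwargs the function accepts are not shown here):

```python
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

# async_mode=False (default): a regular boto3.Session
session = get_aws_session(aws_profile_name="default")
client = session.client("bedrock-runtime")  # standard boto3 client

# async_mode=True: an aioboto3.Session, whose client() is an async context manager
async_session = get_aws_session(aws_profile_name="default", async_mode=True)
# async with async_session.client("bedrock-runtime") as aclient:
#     ...
```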
haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py
CHANGED
@@ -75,7 +75,7 @@ class AmazonBedrockDocumentEmbedder:
         embedding_separator: str = "\n",
         boto3_config: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ):
+    ) -> None:
         """
         Initializes the AmazonBedrockDocumentEmbedder with the provided parameters. The parameters are passed to the
         Amazon Bedrock client.
@@ -234,7 +234,7 @@ class AmazonBedrockDocumentEmbedder:
         return documents
 
     @component.output_types(documents=List[Document])
-    def run(self, documents: List[Document]):
+    def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
         """Embed the provided `Document`s using the specified model.
 
         :param documents: The `Document`s to embed.
@@ -253,6 +253,9 @@ class AmazonBedrockDocumentEmbedder:
             documents_with_embeddings = self._embed_cohere(documents=documents)
         elif "titan" in self.model:
             documents_with_embeddings = self._embed_titan(documents=documents)
+        else:
+            msg = f"Model {self.model} is not supported. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}."
+            raise ValueError(msg)
 
         return {"documents": documents_with_embeddings}
haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py
CHANGED
@@ -64,7 +64,7 @@ class AmazonBedrockTextEmbedder:
         aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False),  # noqa: B008
         boto3_config: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ):
+    ) -> None:
         """
         Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
         Amazon Bedrock client.
@@ -126,7 +126,7 @@ class AmazonBedrockTextEmbedder:
             raise AmazonBedrockConfigurationError(msg) from exception
 
     @component.output_types(embedding=List[float])
-    def run(self, text: str):
+    def run(self, text: str) -> Dict[str, List[float]]:
         """Embeds the input text using the Amazon Bedrock model.
 
         :param text: The input text to embed.
@@ -173,6 +173,9 @@ class AmazonBedrockTextEmbedder:
             embedding = response_body["embeddings"][0]
         elif "titan" in self.model:
             embedding = response_body["embedding"]
+        else:
+            msg = f"Unsupported model {self.model}. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}"
+            raise ValueError(msg)
 
         return {"embedding": embedding}
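Both embedders gain the same fix: an unrecognized model ID now fails fast with a `ValueError` instead of silently falling through with an unbound result. A minimal sketch of the dispatch pattern; the helpers and the model list below are stand-ins, not the package's own:

```python
# Stand-in for the package's SUPPORTED_EMBEDDING_MODELS constant
SUPPORTED_EMBEDDING_MODELS = ["cohere.embed-english-v3", "amazon.titan-embed-text-v1"]

def embed_cohere(text: str) -> list:  # stand-in for the real Cohere call
    return [0.0, 0.0]

def embed_titan(text: str) -> list:  # stand-in for the real Titan call
    return [0.0, 0.0]

def embed(model: str, text: str) -> list:
    if "cohere" in model:
        return embed_cohere(text)
    elif "titan" in model:
        return embed_titan(text)
    else:
        # New in 3.6.x: unknown models raise instead of failing later
        msg = f"Unsupported model {model}. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}"
        raise ValueError(msg)

embed("amazon.titan-embed-text-v1", "hello")  # -> [0.0, 0.0]
# embed("mistral.mistral-7b", "hello")        # -> ValueError
```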
haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
CHANGED
@@ -1,20 +1,21 @@
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import aioboto3
 from botocore.config import Config
 from botocore.eventstream import EventStream
 from botocore.exceptions import ClientError
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, StreamingCallbackT, select_streaming_callback
-from haystack.tools import …
+from haystack.tools import (
+    Tool,
+    Toolset,
+    _check_duplicate_tool_names,
+    deserialize_tools_or_toolset_inplace,
+    serialize_tools_or_toolset,
+)
 from haystack.utils.auth import Secret, deserialize_secrets_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
 
-# Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
-try:
-    from haystack.tools import deserialize_tools_or_toolset_inplace
-except ImportError:
-    from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
-
 from haystack_integrations.common.amazon_bedrock.errors import (
     AmazonBedrockConfigurationError,
     AmazonBedrockInferenceError,
@@ -140,8 +141,8 @@ class AmazonBedrockChatGenerator:
         stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
         boto3_config: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
-    ):
+        tools: Optional[Union[List[Tool], Toolset]] = None,
+    ) -> None:
         """
         Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
         Amazon Bedrock client.
@@ -172,7 +173,7 @@ class AmazonBedrockChatGenerator:
            [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches
            the streaming mode on.
         :param boto3_config: The configuration for the boto3 client.
-        :param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
+        :param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
 
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -190,7 +191,7 @@ class AmazonBedrockChatGenerator:
         self.stop_words = stop_words or []
         self.streaming_callback = streaming_callback
         self.boto3_config = boto3_config
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))  # handles Toolset as well
         self.tools = tools
 
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
@@ -228,7 +229,18 @@ class AmazonBedrockChatGenerator:
         self.stop_words = stop_words or []
         self.async_session = None
 
-    def _get_async_session(self):
+    def _get_async_session(self) -> aioboto3.Session:
+        """
+        Initializes and returns an asynchronous AWS session for accessing Amazon Bedrock.
+
+        If the session is already created, it is reused. Otherwise, a new session is created using the provided AWS
+        credentials and configuration.
+
+        :returns:
+            An async-compatible boto3 session configured for use with Amazon Bedrock.
+        :raises AmazonBedrockConfigurationError:
+            If unable to establish an async session due to misconfiguration.
+        """
         if self.async_session:
             return self.async_session
 
@@ -260,7 +272,6 @@ class AmazonBedrockChatGenerator:
             Dictionary with serialized data.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
@@ -273,7 +284,7 @@ class AmazonBedrockChatGenerator:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             boto3_config=self.boto3_config,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
         )
 
     @classmethod
@@ -301,19 +312,27 @@ class AmazonBedrockChatGenerator:
         messages: List[ChatMessage],
         streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
         requires_async: bool = False,
     ) -> Tuple[Dict[str, Any], Optional[StreamingCallbackT]]:
         """
-        Prepares …
-
-        :param …
-        :param …
-
+        Prepares and formats parameters required to call the Amazon Bedrock Converse API.
+
+        This includes merging default and runtime generation parameters, formatting messages and tools, and
+        selecting the appropriate streaming callback.
+
+        :param messages: List of `ChatMessage` objects representing the conversation history.
+        :param streaming_callback: Optional streaming callback provided at runtime.
+        :param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
+            - `maxTokens`: Maximum number of tokens to generate.
+            - `stopSequences`: List of stop sequences to stop generation.
+            - `temperature`: Sampling temperature.
+            - `topP`: Nucleus sampling parameter.
+        :param tools: Optional list of Tool objects or a Toolset that the model can use.
+        :param requires_async: Boolean flag to indicate if an async-compatible streaming callback function is needed.
+
+        :returns:
+            A tuple of (API-ready parameter dictionary, streaming callback function).
         """
         generation_kwargs = generation_kwargs or {}
 
@@ -331,9 +350,12 @@ class AmazonBedrockChatGenerator:
 
         # Handle tools - either toolConfig or Haystack Tool objects but not both
         tools = tools or self.tools
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
         tool_config = merged_kwargs.pop("toolConfig", None)
         if tools:
+            # Convert Toolset to list if needed
+            if isinstance(tools, Toolset):
+                tools = list(tools)
             # Format Haystack tools to Bedrock format
             tool_config = _format_tools(tools)
 
|
|
|
369
391
|
messages: List[ChatMessage],
|
|
370
392
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
371
393
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
|
372
|
-
tools: Optional[List[Tool]] = None,
|
|
373
|
-
):
|
|
394
|
+
tools: Optional[Union[List[Tool], Toolset]] = None,
|
|
395
|
+
) -> Dict[str, List[ChatMessage]]:
|
|
396
|
+
"""
|
|
397
|
+
Executes a synchronous inference call to the Amazon Bedrock model using the Converse API.
|
|
398
|
+
|
|
399
|
+
Supports both standard and streaming responses depending on whether a streaming callback is provided.
|
|
400
|
+
|
|
401
|
+
:param messages: A list of `ChatMessage` objects forming the chat history.
|
|
402
|
+
:param streaming_callback: Optional callback for handling streaming outputs.
|
|
403
|
+
:param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
|
|
404
|
+
- `maxTokens`: Maximum number of tokens to generate.
|
|
405
|
+
- `stopSequences`: List of stop sequences to stop generation.
|
|
406
|
+
- `temperature`: Sampling temperature.
|
|
407
|
+
- `topP`: Nucleus sampling parameter.
|
|
408
|
+
:param tools: Optional list of Tools that the model may call during execution.
|
|
409
|
+
|
|
410
|
+
:returns:
|
|
411
|
+
A dictionary containing the model-generated replies under the `"replies"` key.
|
|
412
|
+
:raises AmazonBedrockInferenceError:
|
|
413
|
+
If the Bedrock inference API call fails.
|
|
414
|
+
"""
|
|
374
415
|
params, callback = self._prepare_request_params(
|
|
375
416
|
messages=messages,
|
|
376
417
|
streaming_callback=streaming_callback,
|
|
@@ -402,16 +443,26 @@ class AmazonBedrockChatGenerator:
|
|
|
402
443
|
messages: List[ChatMessage],
|
|
403
444
|
streaming_callback: Optional[StreamingCallbackT] = None,
|
|
404
445
|
generation_kwargs: Optional[Dict[str, Any]] = None,
|
|
405
|
-
tools: Optional[List[Tool]] = None,
|
|
406
|
-
):
|
|
446
|
+
tools: Optional[Union[List[Tool], Toolset]] = None,
|
|
447
|
+
) -> Dict[str, List[ChatMessage]]:
|
|
407
448
|
"""
|
|
408
|
-
|
|
449
|
+
Executes an asynchronous inference call to the Amazon Bedrock model using the Converse API.
|
|
450
|
+
|
|
451
|
+
Designed for use cases where non-blocking or concurrent execution is desired.
|
|
409
452
|
|
|
410
|
-
:param messages:
|
|
411
|
-
:param streaming_callback: Optional callback
|
|
412
|
-
:param generation_kwargs: Optional dictionary of generation parameters.
|
|
413
|
-
|
|
414
|
-
|
|
453
|
+
:param messages: A list of `ChatMessage` objects forming the chat history.
|
|
454
|
+
:param streaming_callback: Optional async-compatible callback for handling streaming outputs.
|
|
455
|
+
:param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
|
|
456
|
+
- `maxTokens`: Maximum number of tokens to generate.
|
|
457
|
+
- `stopSequences`: List of stop sequences to stop generation.
|
|
458
|
+
- `temperature`: Sampling temperature.
|
|
459
|
+
- `topP`: Nucleus sampling parameter.
|
|
460
|
+
:param tools: Optional list of Tool objects or a Toolset that the model can use.
|
|
461
|
+
|
|
462
|
+
:returns:
|
|
463
|
+
A dictionary containing the model-generated replies under the `"replies"` key.
|
|
464
|
+
:raises AmazonBedrockInferenceError:
|
|
465
|
+
If the Bedrock inference API call fails.
|
|
415
466
|
"""
|
|
416
467
|
params, callback = self._prepare_request_params(
|
|
417
468
|
messages=messages,
|
|
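A call-site sketch for the documented parameters, reusing the hypothetical `generator` from the earlier sketch; the `generation_kwargs` keys match the Converse API names in the docstrings above, and valid AWS credentials are assumed:

```python
from haystack.dataclasses import ChatMessage

result = generator.run(
    messages=[ChatMessage.from_user("What is the weather in Paris?")],
    generation_kwargs={"maxTokens": 256, "temperature": 0.2, "topP": 0.9},
)
reply = result["replies"][0]
# Either plain text or one or more tool calls, depending on the model's choice
print(reply.tool_calls or reply.text)
```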
haystack_integrations/components/generators/amazon_bedrock/chat/utils.py
CHANGED
@@ -1,17 +1,30 @@
 import json
-from …
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Tuple
 
 from botocore.eventstream import EventStream
-from haystack …
+from haystack import logging
+from haystack.dataclasses import (
+    AsyncStreamingCallbackT,
+    ChatMessage,
+    ChatRole,
+    StreamingChunk,
+    SyncStreamingCallbackT,
+    ToolCall,
+)
 from haystack.tools import Tool
 
+logger = logging.getLogger(__name__)
 
+
+# Haystack to Bedrock util methods
 def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]:
     """
     Format Haystack Tool(s) to Amazon Bedrock toolConfig format.
 
     :param tools: List of Tool objects to format
-    :…
+    :returns:
+        Dictionary in Bedrock toolConfig format or None if no tools are provided
     """
     if not tools:
         return None
@@ -25,69 +38,160 @@ def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]:
     return {"tools": tool_specs} if tool_specs else None
 
 
+def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]:
+    """
+    Format a Haystack ChatMessage containing tool calls into Bedrock format.
+
+    :param tool_call_message: ChatMessage object containing tool calls to be formatted.
+    :returns:
+        Dictionary representing the tool call message in Bedrock's expected format
+    """
+    content = []
+    # Tool call message can contain text
+    if tool_call_message.text:
+        content.append({"text": tool_call_message.text})
+
+    for tool_call in tool_call_message.tool_calls:
+        content.append(
+            {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
+        )
+    return {"role": tool_call_message.role.value, "content": content}
+
+
+def _format_tool_result_message(tool_call_result_message: ChatMessage) -> Dict[str, Any]:
+    """
+    Format a Haystack ChatMessage containing tool call results into Bedrock format.
+
+    :param tool_call_result_message: ChatMessage object containing tool call results to be formatted.
+    :returns: Dictionary representing the tool result message in Bedrock's expected format
+    """
+    # Assuming tool call result messages will only contain tool results
+    tool_results = []
+    for result in tool_call_result_message.tool_call_results:
+        try:
+            json_result = json.loads(result.result)
+            content = [{"json": json_result}]
+        except json.JSONDecodeError:
+            content = [{"text": result.result}]
+
+        tool_results.append(
+            {
+                "toolResult": {
+                    "toolUseId": result.origin.id,
+                    "content": content,
+                    **({"status": "error"} if result.error else {}),
+                }
+            }
+        )
+    # role must be user
+    return {"role": "user", "content": tool_results}
+
+
+def _repair_tool_result_messages(bedrock_formatted_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Repair and reorganize tool result messages to maintain proper ordering and grouping.
+
+    Ensures tool result messages are properly grouped in the same way as their corresponding tool call messages
+    and maintains the original message ordering.
+
+    :param bedrock_formatted_messages: List of Bedrock-formatted messages that may need repair.
+    :returns: List of properly organized Bedrock-formatted messages with correctly grouped tool results.
+    """
+    tool_call_messages = []
+    tool_result_messages = []
+    for idx, msg in enumerate(bedrock_formatted_messages):
+        content = msg.get("content", [])
+        if content:
+            if any("toolUse" in c for c in content):
+                tool_call_messages.append((idx, msg))
+            elif any("toolResult" in c for c in content):
+                tool_result_messages.append((idx, msg))
+
+    # Determine the tool call IDs for each tool call message
+    group_to_tool_call_ids: Dict[int, Any] = {idx: [] for idx, _ in tool_call_messages}
+    for idx, tool_call in tool_call_messages:
+        tool_use_contents = [c for c in tool_call["content"] if "toolUse" in c]
+        for content in tool_use_contents:
+            group_to_tool_call_ids[idx].append(content["toolUse"]["toolUseId"])
+
+    # Regroups the tool_result_prompts based on the tool_call_prompts
+    # Makes sure to:
+    # - Within the new group the tool call IDs of the tool result messages are in the same order as the tool call
+    #   messages
+    # - The tool result messages are in the same order as the original message list
+    repaired_tool_result_prompts = []
+    for tool_call_ids in group_to_tool_call_ids.values():
+        regrouped_tool_result = []
+        original_idx = None
+        for tool_call_id in tool_call_ids:
+            for idx, tool_result in tool_result_messages:
+                tool_result_contents = [c for c in tool_result["content"] if "toolResult" in c]
+                for content in tool_result_contents:
+                    if content["toolResult"]["toolUseId"] == tool_call_id:
+                        regrouped_tool_result.append(content)
+                        # Keep track of the original index of the last tool result message
+                        original_idx = idx
+        if regrouped_tool_result and original_idx is not None:
+            repaired_tool_result_prompts.append((original_idx, {"role": "user", "content": regrouped_tool_result}))
+
+    # Remove the tool result messages from bedrock_formatted_messages
+    bedrock_formatted_messages_minus_tool_results: List[Tuple[int, Any]] = []
+    for idx, msg in enumerate(bedrock_formatted_messages):
+        # Assumes the content of tool result messages only contains 'toolResult': {...} objects (e.g. no 'text')
+        if msg.get("content") and "toolResult" not in msg["content"][0]:
+            bedrock_formatted_messages_minus_tool_results.append((idx, msg))
+
+    # Add the repaired tool result messages and sort to maintain the correct order
+    repaired_bedrock_formatted_messages = bedrock_formatted_messages_minus_tool_results + repaired_tool_result_prompts
+    repaired_bedrock_formatted_messages.sort(key=lambda x: x[0])
+
+    # Drop the index and return only the messages
+    return [msg for _, msg in repaired_bedrock_formatted_messages]
+
+
 def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
     """
-    Format a list of ChatMessages to the format expected by Bedrock API.
-    Separates system messages and handles tool results and tool calls.
+    Format a list of Haystack ChatMessages to the format expected by Bedrock API.
 
+    Processes and separates system messages from other message types and handles special formatting for tool calls
+    and tool results.
+
+    :param messages: List of ChatMessage objects to format for Bedrock API.
+    :returns: Tuple containing (system_prompts, non_system_messages) in Bedrock format,
+        where system_prompts is a list of system message dictionaries and
+        non_system_messages is a list of properly formatted message dictionaries.
     """
+    # Separate system messages, tool calls, and tool results
     system_prompts = []
-    non_system_messages = []
+    bedrock_formatted_messages = []
     for msg in messages:
         if msg.is_from(ChatRole.SYSTEM):
+            # Assuming system messages can only contain text
+            # Don't need to track idx since system_messages are handled separately
             system_prompts.append({"text": msg.text})
-            continue
-
-        if msg.tool_call_results:
-            tool_results = []
-            for result in msg.tool_call_results:
-                try:
-                    json_result = json.loads(result.result)
-                    content = [{"json": json_result}]
-                except json.JSONDecodeError:
-                    content = [{"text": result.result}]
-
-                tool_results.append(
-                    {
-                        "toolResult": {
-                            "toolUseId": result.origin.id,
-                            "content": content,
-                            **({"status": "error"} if result.error else {}),
-                        }
-                    }
-                )
-            non_system_messages.append({"role": "user", "content": tool_results})
-            continue
-
-        content = []
-        # Handle text content
-        if msg.text:
-            content.append({"text": msg.text})
-
-        # Handle tool calls
-        if msg.tool_calls:
-            for tool_call in msg.tool_calls:
-                content.append(
-                    {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
-                )
-
-        if content:  # Only add message if it has content
-            non_system_messages.append({"role": msg.role.value, "content": content})
+        elif msg.tool_calls:
+            bedrock_formatted_messages.append(_format_tool_call_message(msg))
+        elif msg.tool_call_results:
+            bedrock_formatted_messages.append(_format_tool_result_message(msg))
+        else:
+            # regular user or assistant messages with only text content
+            bedrock_formatted_messages.append({"role": msg.role.value, "content": [{"text": msg.text}]})
 
-    return system_prompts, non_system_messages
+    repaired_bedrock_formatted_messages = _repair_tool_result_messages(bedrock_formatted_messages)
+    return system_prompts, repaired_bedrock_formatted_messages
 
 
+# Bedrock to Haystack util method
 def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]:
     """
-    Parse a Bedrock response …
+    Parse a Bedrock API response into Haystack ChatMessage objects.
 
+    Extracts text content, tool calls, and metadata from the Bedrock response and converts them into the appropriate
+    Haystack format.
+
+    :param response_body: Raw JSON response from Bedrock API.
+    :param model: The model ID used for generation, included in message metadata.
+    :returns: List of ChatMessage objects containing the assistant's response(s) with appropriate metadata.
     """
     replies = []
     if "output" in response_body and "message" in response_body["output"]:
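For orientation, these are the Bedrock-side shapes the new helpers emit; all IDs and values below are invented for illustration. Tool calls ride under the assistant role as `toolUse` blocks, and tool results must come back under the user role as `toolResult` blocks:

```python
# Output shape of _format_tool_call_message (values invented)
tool_call_msg = {
    "role": "assistant",
    "content": [
        {"text": "Let me check the weather."},
        {"toolUse": {"toolUseId": "call_1", "name": "get_weather", "input": {"city": "Paris"}}},
    ],
}

# Output shape of _format_tool_result_message (values invented);
# JSON-parseable results become {"json": ...}, everything else {"text": ...}
tool_result_msg = {
    "role": "user",
    "content": [
        {"toolResult": {"toolUseId": "call_1", "content": [{"json": {"forecast": "sunny"}}]}},
    ],
}
```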
@@ -130,9 +234,209 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]:
     return replies
 
 
+# Bedrock streaming to Haystack util methods
+def _convert_event_to_streaming_chunk(event: Dict[str, Any], model: str) -> StreamingChunk:
+    """
+    Convert a Bedrock streaming event to a Haystack StreamingChunk.
+
+    Handles different event types (contentBlockStart, contentBlockDelta, messageStop, metadata) and extracts relevant
+    information to create StreamingChunk objects in the same format used by Haystack's OpenAIChatGenerator.
+
+    :param event: Dictionary containing a Bedrock streaming event.
+    :param model: The model ID used for generation, included in chunk metadata.
+    :returns: StreamingChunk object containing the content and metadata extracted from the event.
+    """
+    # Initialize an empty StreamingChunk to return if no relevant event is found
+    # (e.g. for messageStart and contentBlockStop)
+    streaming_chunk = StreamingChunk(
+        content="", meta={"model": model, "received_at": datetime.now(timezone.utc).isoformat()}
+    )
+
+    if "contentBlockStart" in event:
+        # contentBlockStart always has the key "contentBlockIndex"
+        block_start = event["contentBlockStart"]
+        block_idx = block_start["contentBlockIndex"]
+        if "start" in block_start and "toolUse" in block_start["start"]:
+            tool_start = block_start["start"]["toolUse"]
+            streaming_chunk = StreamingChunk(
+                content="",
+                meta={
+                    "model": model,
+                    # This is always 0 b/c it represents the choice index
+                    "index": 0,
+                    # We follow the same format used in the OpenAIChatGenerator
+                    "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
+                        {
+                            "index": block_idx,  # int
+                            "id": tool_start["toolUseId"],  # Optional[str]
+                            "function": {  # Optional[ChoiceDeltaToolCallFunction]
+                                # Will accumulate deltas as string
+                                "arguments": "",  # Optional[str]
+                                "name": tool_start["name"],  # Optional[str]
+                            },
+                            "type": "function",  # Optional[Literal["function"]]
+                        }
+                    ],
+                    "finish_reason": None,
+                    "received_at": datetime.now(timezone.utc).isoformat(),
+                },
+            )
+
+    elif "contentBlockDelta" in event:
+        # contentBlockDelta always has the key "contentBlockIndex" and "delta"
+        block_idx = event["contentBlockDelta"]["contentBlockIndex"]
+        delta = event["contentBlockDelta"]["delta"]
+        # This is for accumulating text deltas
+        if "text" in delta:
+            streaming_chunk = StreamingChunk(
+                content=delta["text"],
+                meta={
+                    "model": model,
+                    # This is always 0 b/c it represents the choice index
+                    "index": 0,
+                    "tool_calls": None,
+                    "finish_reason": None,
+                    "received_at": datetime.now(timezone.utc).isoformat(),
+                },
+            )
+        # This only occurs when accumulating the arguments for a toolUse
+        # The content_block for this tool should already exist at this point
+        elif "toolUse" in delta:
+            streaming_chunk = StreamingChunk(
+                content="",
+                meta={
+                    "model": model,
+                    # This is always 0 b/c it represents the choice index
+                    "index": 0,
+                    "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
+                        {
+                            "index": block_idx,  # int
+                            "id": None,  # Optional[str]
+                            "function": {  # Optional[ChoiceDeltaToolCallFunction]
+                                # Will accumulate deltas as string
+                                "arguments": delta["toolUse"].get("input", ""),  # Optional[str]
+                                "name": None,  # Optional[str]
+                            },
+                            "type": "function",  # Optional[Literal["function"]]
+                        }
+                    ],
+                    "finish_reason": None,
+                    "received_at": datetime.now(timezone.utc).isoformat(),
+                },
+            )
+
+    elif "messageStop" in event:
+        finish_reason = event["messageStop"].get("stopReason")
+        streaming_chunk = StreamingChunk(
+            content="",
+            meta={
+                "model": model,
+                # This is always 0 b/c it represents the choice index
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": finish_reason,
+                "received_at": datetime.now(timezone.utc).isoformat(),
+            },
+        )
+
+    elif "metadata" in event and "usage" in event["metadata"]:
+        metadata = event["metadata"]
+        streaming_chunk = StreamingChunk(
+            content="",
+            meta={
+                "model": model,
+                # This is always 0 b/c it represents the choice index
+                "index": 0,
+                "tool_calls": None,
+                "finish_reason": None,
+                "received_at": datetime.now(timezone.utc).isoformat(),
+                "usage": {
+                    "prompt_tokens": metadata["usage"].get("inputTokens", 0),
+                    "completion_tokens": metadata["usage"].get("outputTokens", 0),
+                    "total_tokens": metadata["usage"].get("totalTokens", 0),
+                },
+            },
+        )
+
+    return streaming_chunk
+
+
+def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+    """
+    Converts a list of streaming chunks into a ChatMessage object.
+
+    The function processes streaming chunks to build a ChatMessage object, including extracting and constructing
+    tool calls, managing metadata such as model type, finish reason, and usage information.
+    The tool call processing handles accumulating data across the chunks and attempts to parse JSON-formatted
+    arguments for tool calls.
+
+    :param chunks: A list of StreamingChunk objects representing parts of the assistant's response.
+
+    :returns:
+        A ChatMessage object constructed from the streaming chunks, containing the aggregated text, processed tool
+        calls, and metadata.
+    """
+    text = "".join([chunk.content for chunk in chunks])
+
+    # Process tool calls if present in any chunk
+    tool_calls = []
+    tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
+    for chunk_payload in chunks:
+        tool_calls_meta = chunk_payload.meta.get("tool_calls")
+        if tool_calls_meta is not None:
+            for delta in tool_calls_meta:
+                # We use the index of the tool call to track it across chunks since the ID is not always provided
+                if delta["index"] not in tool_call_data:
+                    tool_call_data[delta["index"]] = {"id": "", "name": "", "arguments": ""}
+
+                # Save the ID if present
+                if delta.get("id"):
+                    tool_call_data[delta["index"]]["id"] = delta["id"]
+
+                if delta.get("function"):
+                    if delta["function"].get("name"):
+                        tool_call_data[delta["index"]]["name"] += delta["function"]["name"]
+                    if delta["function"].get("arguments"):
+                        tool_call_data[delta["index"]]["arguments"] += delta["function"]["arguments"]
+
+    # Convert accumulated tool call data into ToolCall objects
+    for call_data in tool_call_data.values():
+        try:
+            arguments = json.loads(call_data["arguments"])
+            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+        except json.JSONDecodeError:
+            logger.warning(
+                "Amazon Bedrock returned a malformed JSON string for tool call arguments. This tool call will be "
+                "skipped. Tool call ID: {tool_id}, Tool name: {tool_name}, Arguments: {tool_arguments}",
+                tool_id=call_data["id"],
+                tool_name=call_data["name"],
+                tool_arguments=call_data["arguments"],
+            )
+
+    # finish_reason can appear in different places so we look for the last one
+    finish_reasons = [
+        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+    ]
+    finish_reason = finish_reasons[-1] if finish_reasons else None
+
+    # usage is usually last but we look for it as well
+    usages = [chunk.meta.get("usage") for chunk in chunks if chunk.meta.get("usage") is not None]
+    usage = usages[-1] if usages else None
+
+    meta = {
+        "model": chunks[-1].meta["model"],
+        "index": 0,
+        "finish_reason": finish_reason,
+        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
+        "usage": usage,
+    }
+
+    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
+
+
 def _parse_streaming_response(
     response_stream: EventStream,
-    streaming_callback: …
+    streaming_callback: SyncStreamingCallbackT,
     model: str,
 ) -> List[ChatMessage]:
     """
@@ -143,78 +447,18 @@ def _parse_streaming_response(
     :param model: The model ID used for generation
     :return: List of ChatMessage objects
     """
-    replies = []
-    current_content = ""
-    current_tool_call: Optional[Dict[str, Any]] = None
-    base_meta = {"model": model, "index": 0}
-
+    chunks: List[StreamingChunk] = []
     for event in response_stream:
-        if "contentBlockStart" in event:
-            block_start = event["contentBlockStart"]
-            if "start" in block_start and "toolUse" in block_start["start"]:
-                tool_start = block_start["start"]["toolUse"]
-                current_tool_call = {
-                    "id": tool_start["toolUseId"],
-                    "name": tool_start["name"],
-                    "arguments": "",  # Will accumulate deltas as string
-                }
-
-        elif "contentBlockDelta" in event:
-            delta = event["contentBlockDelta"]["delta"]
-            if "text" in delta:
-                delta_text = delta["text"]
-                current_content += delta_text
-                streaming_chunk = StreamingChunk(content=delta_text, meta={})
-                streaming_callback(streaming_chunk)
-            elif "toolUse" in delta and current_tool_call:
-                # Accumulate tool use input deltas
-                current_tool_call["arguments"] += delta["toolUse"].get("input", "")
-
-        elif "contentBlockStop" in event:
-            if current_tool_call:
-                # Parse accumulated input if it's a JSON string
-                try:
-                    input_json = json.loads(current_tool_call["arguments"])
-                    current_tool_call["arguments"] = input_json
-                except json.JSONDecodeError:
-                    # Keep as string if not valid JSON
-                    pass
-
-                tool_call = ToolCall(
-                    id=current_tool_call["id"],
-                    tool_name=current_tool_call["name"],
-                    arguments=current_tool_call["arguments"],
-                )
-                replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy()))
-            elif current_content:
-                replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
-
-        elif "messageStop" in event:
-            # Update finish reason for all replies
-            for reply in replies:
-                reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
-
-        elif "metadata" in event:
-            metadata = event["metadata"]
-            # Update usage stats for all replies
-            for reply in replies:
-                if "usage" in metadata:
-                    usage = metadata["usage"]
-                    reply.meta["usage"] = {
-                        "prompt_tokens": usage.get("inputTokens", 0),
-                        "completion_tokens": usage.get("outputTokens", 0),
-                        "total_tokens": usage.get("totalTokens", 0),
-                    }
-
+        streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model)
+        streaming_callback(streaming_chunk)
+        chunks.append(streaming_chunk)
+    replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
     return replies
 
 
 async def _parse_streaming_response_async(
     response_stream: EventStream,
-    streaming_callback: …
+    streaming_callback: AsyncStreamingCallbackT,
     model: str,
 ) -> List[ChatMessage]:
     """
@@ -225,70 +469,10 @@ async def _parse_streaming_response_async(
     :param model: The model ID used for generation
     :return: List of ChatMessage objects
     """
-    replies = []
-    current_content = ""
-    current_tool_call: Optional[Dict[str, Any]] = None
-    base_meta = {"model": model, "index": 0}
-
+    chunks: List[StreamingChunk] = []
     async for event in response_stream:
-        if "contentBlockStart" in event:
-            block_start = event["contentBlockStart"]
-            if "start" in block_start and "toolUse" in block_start["start"]:
-                tool_start = block_start["start"]["toolUse"]
-                current_tool_call = {
-                    "id": tool_start["toolUseId"],
-                    "name": tool_start["name"],
-                    "arguments": "",  # Will accumulate deltas as string
-                }
-
-        elif "contentBlockDelta" in event:
-            delta = event["contentBlockDelta"]["delta"]
-            if "text" in delta:
-                delta_text = delta["text"]
-                current_content += delta_text
-                streaming_chunk = StreamingChunk(content=delta_text, meta={})
-                await streaming_callback(streaming_chunk)
-            elif "toolUse" in delta and current_tool_call:
-                # Accumulate tool use input deltas
-                current_tool_call["arguments"] += delta["toolUse"].get("input", "")
-
-        elif "contentBlockStop" in event:
-            if current_tool_call:
-                # Parse accumulated input if it's a JSON string
-                try:
-                    input_json = json.loads(current_tool_call["arguments"])
-                    current_tool_call["arguments"] = input_json
-                except json.JSONDecodeError:
-                    # Keep as string if not valid JSON
-                    pass
-
-                tool_call = ToolCall(
-                    id=current_tool_call["id"],
-                    tool_name=current_tool_call["name"],
-                    arguments=current_tool_call["arguments"],
-                )
-                replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy()))
-            elif current_content:
-                replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
-
-        elif "messageStop" in event:
-            # Update finish reason for all replies
-            for reply in replies:
-                reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
-
-        elif "metadata" in event:
-            metadata = event["metadata"]
-            # Update usage stats for all replies
-            for reply in replies:
-                if "usage" in metadata:
-                    usage = metadata["usage"]
-                    reply.meta["usage"] = {
-                        "prompt_tokens": usage.get("inputTokens", 0),
-                        "completion_tokens": usage.get("outputTokens", 0),
-                        "total_tokens": usage.get("totalTokens", 0),
-                    }
-
+        streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model)
+        await streaming_callback(streaming_chunk)
+        chunks.append(streaming_chunk)
+    replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
    return replies
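The streaming rework is easiest to see end to end: every event becomes a `StreamingChunk`, and tool-call arguments are stitched back together by block index. A self-contained sketch of that accumulation step over invented deltas:

```python
import json

# Tool-call deltas as they might arrive across chunks for block index 1 (values invented)
deltas = [
    {"index": 1, "id": "call_1", "function": {"name": "get_weather", "arguments": ""}},
    {"index": 1, "id": None, "function": {"name": None, "arguments": '{"city": '}},
    {"index": 1, "id": None, "function": {"name": None, "arguments": '"Paris"}'}},
]

accumulated: dict = {}
for delta in deltas:
    slot = accumulated.setdefault(delta["index"], {"id": "", "name": "", "arguments": ""})
    if delta.get("id"):
        slot["id"] = delta["id"]  # the ID only appears on the first delta
    function = delta.get("function") or {}
    if function.get("name"):
        slot["name"] += function["name"]
    if function.get("arguments"):
        slot["arguments"] += function["arguments"]  # accumulate the JSON string

print(json.loads(accumulated[1]["arguments"]))  # {'city': 'Paris'}
```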
haystack_integrations/components/generators/amazon_bedrock/generator.py
CHANGED
@@ -1,7 +1,7 @@
 import json
 import re
 import warnings
-from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args
+from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, Union, get_args
 
 from botocore.config import Config
 from botocore.exceptions import ClientError
@@ -108,7 +108,7 @@ class AmazonBedrockGenerator:
         boto3_config: Optional[Dict[str, Any]] = None,
         model_family: Optional[MODEL_FAMILIES] = None,
         **kwargs,
-    ):
+    ) -> None:
         """
         Create a new `AmazonBedrockGenerator` instance.
 
@@ -189,7 +189,7 @@ class AmazonBedrockGenerator:
         prompt: str,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
-    ):
+    ) -> Dict[str, Union[List[str], Dict[str, Any]]]:
         """
         Generates a list of string response to the given prompt.
 
haystack_integrations/components/rankers/amazon_bedrock/__init__.py
CHANGED
@@ -1,3 +1,3 @@
-from .ranker import BedrockRanker
+from .ranker import AmazonBedrockRanker, BedrockRanker
 
-__all__ = ["BedrockRanker"]
+__all__ = ["AmazonBedrockRanker", "BedrockRanker"]
haystack_integrations/components/rankers/amazon_bedrock/ranker.py
CHANGED
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Dict, List, Optional
 
 from botocore.exceptions import ClientError
@@ -16,7 +17,7 @@ MAX_NUM_DOCS_FOR_BEDROCK_RANKER = 1000
 
 
 @component
-class BedrockRanker:
+class AmazonBedrockRanker:
     """
     Ranks Documents based on their similarity to the query using Amazon Bedrock's Cohere Rerank model.
 
@@ -30,9 +31,13 @@ class BedrockRanker:
     ```python
     from haystack import Document
     from haystack.utils import Secret
-    from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker
+    from haystack_integrations.components.rankers.amazon_bedrock import AmazonBedrockRanker
 
-    ranker = BedrockRanker(…)
+    ranker = AmazonBedrockRanker(
+        model="cohere.rerank-v3-5:0",
+        top_k=2,
+        aws_region_name=Secret.from_token("eu-central-1")
+    )
 
     docs = [Document(content="Paris"), Document(content="Berlin")]
     query = "What is the capital of germany?"
@@ -40,7 +45,7 @@ class BedrockRanker:
     docs = output["documents"]
     ```
 
-    BedrockRanker uses AWS for authentication. …
+    AmazonBedrockRanker uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM.
     For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation]
     (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html).
 
@@ -66,12 +71,12 @@ class BedrockRanker:
         max_chunks_per_doc: Optional[int] = None,
         meta_fields_to_embed: Optional[List[str]] = None,
         meta_data_separator: str = "\n",
-    ):
+    ) -> None:
         if not model:
             msg = "'model' cannot be None or empty string"
             raise ValueError(msg)
         """
-        Creates an instance of the 'BedrockRanker'.
+        Creates an instance of the 'AmazonBedrockRanker'.
 
        :param model: Amazon Bedrock model name for Cohere Rerank. Default is "cohere.rerank-v3-5:0".
        :param top_k: The maximum number of documents to return.
@@ -140,7 +145,7 @@ class BedrockRanker:
         )
 
     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> "BedrockRanker":
+    def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockRanker":
         """
         Deserializes the component from a dictionary.
 
@@ -173,7 +178,7 @@ class BedrockRanker:
         return concatenated_input_list
 
     @component.output_types(documents=List[Document])
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict[str, List[Document]]:
         """
         Use the Amazon Bedrock Reranker to re-rank the list of documents based on the query.
 
@@ -260,3 +265,20 @@ class BedrockRanker:
         except Exception as e:
             msg = f"Error during Amazon Bedrock API call: {e!s}"
             raise AmazonBedrockInferenceError(msg) from e
+
+
+class BedrockRanker(AmazonBedrockRanker):
+    """
+    Deprecated alias for AmazonBedrockRanker.
+    This class will be removed in a future version.
+    Please use AmazonBedrockRanker instead.
+    """
+
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "BedrockRanker is deprecated and will be removed in a future version. "
+            "Please use AmazonBedrockRanker instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
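Observable behavior of the shim, sketched below; constructing the ranker is assumed to succeed, which in practice requires a configured AWS environment:

```python
import warnings

from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ranker = BedrockRanker(model="cohere.rerank-v3-5:0", top_k=2)

# The old name still works, but now flags itself
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```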
{amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/WHEEL
RENAMED
File without changes
{amazon_bedrock_haystack-3.5.2.dist-info → amazon_bedrock_haystack-3.6.1.dist-info}/licenses/LICENSE.txt
RENAMED
File without changes