amazon-bedrock-haystack 3.5.2__py3-none-any.whl → 3.6.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: amazon-bedrock-haystack
- Version: 3.5.2
+ Version: 3.6.1
  Summary: An integration of Amazon Bedrock as an AmazonBedrockGenerator component.
  Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme
  Project-URL: Issues, https://github.com/deepset-ai/haystack-core-integrations/issues
@@ -21,7 +21,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
  Requires-Python: >=3.9
  Requires-Dist: aioboto3>=14.0.0
  Requires-Dist: boto3>=1.28.57
- Requires-Dist: haystack-ai
+ Requires-Dist: haystack-ai>=2.13.1
  Description-Content-Type: text/markdown

  # amazon-bedrock-haystack
@@ -1,18 +1,18 @@
  haystack_integrations/common/amazon_bedrock/__init__.py,sha256=6GZ8Y3Lw0rLOsOAqi6Tu5mZC977UzQvgDxKpOWr8IQw,110
  haystack_integrations/common/amazon_bedrock/errors.py,sha256=ReheDbY7L3EJkWcUoih6lWHjbPHg2TlUs9SnXIKK7Gg,744
- haystack_integrations/common/amazon_bedrock/utils.py,sha256=jVXNjzPvYrvWtEJi8FjaW0aBqoTyFp6btp5mXo7F5Go,2687
+ haystack_integrations/common/amazon_bedrock/utils.py,sha256=dHUWzHYT0A8_eLDpVkwDhmDpprYbFlWsGg0FOS0uF0I,2720
  haystack_integrations/components/embedders/amazon_bedrock/__init__.py,sha256=CFqYmAVq2aavlMkZHYScKHOTwwETdRzRZITMqGhJ9Kw,298
- haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py,sha256=8Kj3_JiND2IQGdAc8n2GnI244QzbOjNmEHuH_W2T610,13115
- haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py,sha256=KRreCYplvxT6Yvg5TMSb4aDz5pqW1zMeV216cR8C6FY,9005
+ haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py,sha256=ZdtM6HpHQwbKfjmfOK6gkIQPPbI0n8_pWRrR6lyXmr8,13321
+ haystack_integrations/components/embedders/amazon_bedrock/text_embedder.py,sha256=gpvu6IMoycUXrn4r1OH5yEIheiDxHf2T5fdJJO4DfW0,9202
  haystack_integrations/components/generators/amazon_bedrock/__init__.py,sha256=lv4NouIVm78YavUssWQrHHP_81u-7j21qW8v1kZMJPQ,284
  haystack_integrations/components/generators/amazon_bedrock/adapters.py,sha256=cnlfmie4HfEX4nipSXSDk_3koy7HYZ-ezimGN6BozQ0,19543
- haystack_integrations/components/generators/amazon_bedrock/generator.py,sha256=2L9-ZKJkMI5JxN0xZ8Gn_5MfdjhT4UlFjlrTd9Rdhsg,14514
+ haystack_integrations/components/generators/amazon_bedrock/generator.py,sha256=NgywyiKYazEbsLAcGcOPUT4blWhYYOJ9WjO-HWDvu7I,14576
  haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py,sha256=6GZ8Y3Lw0rLOsOAqi6Tu5mZC977UzQvgDxKpOWr8IQw,110
- haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=f60FRraGvPeMBiosclNN3hLv9uQihQcQFETC-gwGieM,20784
- haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=Dch2YkrDFu-Dvi9i1bE97taAA1Qf3wMQb4wh2T3Z-lw,11789
- haystack_integrations/components/rankers/amazon_bedrock/__init__.py,sha256=DWsCu-dav2wzr13U2H1jlFjdzdjkV0fnJw7DVDRY8RQ,63
- haystack_integrations/components/rankers/amazon_bedrock/ranker.py,sha256=kq9h_ApGXBDs_GW66QQpo8JwW9grLpqpFT9uXiwT8_M,11643
- amazon_bedrock_haystack-3.5.2.dist-info/METADATA,sha256=gumDQJXyNSXhdWQFhVuVY1OSmhOGJiXP9re3y2uTVE0,2217
- amazon_bedrock_haystack-3.5.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- amazon_bedrock_haystack-3.5.2.dist-info/licenses/LICENSE.txt,sha256=B05uMshqTA74s-0ltyHKI6yoPfJ3zYgQbvcXfDVGFf8,10280
- amazon_bedrock_haystack-3.5.2.dist-info/RECORD,,
+ haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=M9I0sB8LFrXCgoyr5ik2ZPPHyB0b4nJFKX7GARKsk8Y,23384
+ haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=kBSaU_ZqzL-7a7nplezjb4XRBy51pt-4VULoX5lq21A,21148
+ haystack_integrations/components/rankers/amazon_bedrock/__init__.py,sha256=Zrc3BSVkEaXYpliEi6hKG9bqW4J7DNk93p50SuoyT1Q,107
+ haystack_integrations/components/rankers/amazon_bedrock/ranker.py,sha256=x4QEVkbFM-jMFHx-xmk571wtrohnPLtkIWMhCyg4_II,12278
+ amazon_bedrock_haystack-3.6.1.dist-info/METADATA,sha256=oRIb-2Nv642N0wmapQgykocVVjxzwM7rG1_TNOoF6Vs,2225
+ amazon_bedrock_haystack-3.6.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ amazon_bedrock_haystack-3.6.1.dist-info/licenses/LICENSE.txt,sha256=B05uMshqTA74s-0ltyHKI6yoPfJ3zYgQbvcXfDVGFf8,10280
+ amazon_bedrock_haystack-3.6.1.dist-info/RECORD,,
@@ -1,6 +1,6 @@
- from typing import Optional
+ from typing import Optional, Union

- import aioboto3 # type: ignore
+ import aioboto3
  import boto3
  from botocore.exceptions import BotoCoreError

@@ -23,7 +23,7 @@ def get_aws_session(
  aws_profile_name: Optional[str] = None,
  async_mode: bool = False,
  **kwargs,
- ):
+ ) -> Union[boto3.Session, aioboto3.Session]:
  """
  Creates an AWS Session with the given parameters.
  Checks if the provided AWS credentials are valid and can be used to connect to AWS.
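
The new `Union` return annotation makes the helper's dual behavior explicit: with `async_mode=True` it hands back an `aioboto3.Session` instead of a `boto3.Session`. A minimal sketch of both call modes, assuming AWS credentials are already configured (the profile name is illustrative; the remaining session parameters are forwarded via `**kwargs`):

```python
# Sketch only: assumes a configured AWS environment; the helper validates the
# credentials and raises on misconfiguration.
from haystack_integrations.common.amazon_bedrock.utils import get_aws_session

sync_session = get_aws_session(aws_profile_name="default")                    # -> boto3.Session
async_session = get_aws_session(aws_profile_name="default", async_mode=True)  # -> aioboto3.Session
```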
@@ -75,7 +75,7 @@ class AmazonBedrockDocumentEmbedder:
  embedding_separator: str = "\n",
  boto3_config: Optional[Dict[str, Any]] = None,
  **kwargs,
- ):
+ ) -> None:
  """
  Initializes the AmazonBedrockDocumentEmbedder with the provided parameters. The parameters are passed to the
  Amazon Bedrock client.
@@ -234,7 +234,7 @@ class AmazonBedrockDocumentEmbedder:
  return documents

  @component.output_types(documents=List[Document])
- def run(self, documents: List[Document]):
+ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
  """Embed the provided `Document`s using the specified model.

  :param documents: The `Document`s to embed.
@@ -253,6 +253,9 @@ class AmazonBedrockDocumentEmbedder:
  documents_with_embeddings = self._embed_cohere(documents=documents)
  elif "titan" in self.model:
  documents_with_embeddings = self._embed_titan(documents=documents)
+ else:
+ msg = f"Model {self.model} is not supported. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}."
+ raise ValueError(msg)

  return {"documents": documents_with_embeddings}

@@ -64,7 +64,7 @@ class AmazonBedrockTextEmbedder:
  aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
  boto3_config: Optional[Dict[str, Any]] = None,
  **kwargs,
- ):
+ ) -> None:
  """
  Initializes the AmazonBedrockTextEmbedder with the provided parameters. The parameters are passed to the
  Amazon Bedrock client.
@@ -126,7 +126,7 @@ class AmazonBedrockTextEmbedder:
  raise AmazonBedrockConfigurationError(msg) from exception

  @component.output_types(embedding=List[float])
- def run(self, text: str):
+ def run(self, text: str) -> Dict[str, List[float]]:
  """Embeds the input text using the Amazon Bedrock model.

  :param text: The input text to embed.
@@ -173,6 +173,9 @@ class AmazonBedrockTextEmbedder:
  embedding = response_body["embeddings"][0]
  elif "titan" in self.model:
  embedding = response_body["embedding"]
+ else:
+ msg = f"Unsupported model {self.model}. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}"
+ raise ValueError(msg)

  return {"embedding": embedding}
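
Both embedders gain the same guard: every branch of the model dispatch now either assigns a result or raises, so the final return can never reference an unbound variable (and static type checkers are satisfied). A self-contained sketch of the pattern with stand-in names, not the wheel's actual Bedrock calls:

```python
from typing import Dict, List

SUPPORTED_EMBEDDING_MODELS = ["cohere.embed-english-v3", "amazon.titan-embed-text-v2:0"]  # illustrative

def embed(model: str, text: str) -> Dict[str, List[float]]:
    if "cohere" in model:
        embedding = [0.1, 0.2]  # stand-in for the Cohere invocation
    elif "titan" in model:
        embedding = [0.3, 0.4]  # stand-in for the Titan invocation
    else:
        # Exhaustive else-branch: fail fast on unrecognized model ids
        msg = f"Unsupported model {model}. Supported models are: {', '.join(SUPPORTED_EMBEDDING_MODELS)}"
        raise ValueError(msg)
    return {"embedding": embedding}
```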
@@ -1,20 +1,21 @@
- from typing import Any, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple, Union

+ import aioboto3
  from botocore.config import Config
  from botocore.eventstream import EventStream
  from botocore.exceptions import ClientError
  from haystack import component, default_from_dict, default_to_dict, logging
  from haystack.dataclasses import ChatMessage, StreamingCallbackT, select_streaming_callback
- from haystack.tools import Tool, _check_duplicate_tool_names
+ from haystack.tools import (
+ Tool,
+ Toolset,
+ _check_duplicate_tool_names,
+ deserialize_tools_or_toolset_inplace,
+ serialize_tools_or_toolset,
+ )
  from haystack.utils.auth import Secret, deserialize_secrets_inplace
  from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

- # Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
- try:
- from haystack.tools import deserialize_tools_or_toolset_inplace
- except ImportError:
- from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
-
  from haystack_integrations.common.amazon_bedrock.errors import (
  AmazonBedrockConfigurationError,
  AmazonBedrockInferenceError,
@@ -140,8 +141,8 @@ class AmazonBedrockChatGenerator:
  stop_words: Optional[List[str]] = None,
  streaming_callback: Optional[StreamingCallbackT] = None,
  boto3_config: Optional[Dict[str, Any]] = None,
- tools: Optional[List[Tool]] = None,
- ):
+ tools: Optional[Union[List[Tool], Toolset]] = None,
+ ) -> None:
  """
  Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
  Amazon Bedrock client.
@@ -172,7 +173,7 @@ class AmazonBedrockChatGenerator:
  [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk) object and switches
  the streaming mode on.
  :param boto3_config: The configuration for the boto3 client.
- :param tools: A list of Tool objects that the model can use. Each tool should have a unique name.
+ :param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.

  :raises ValueError: If the model name is empty or None.
  :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
@@ -190,7 +191,7 @@ class AmazonBedrockChatGenerator:
  self.stop_words = stop_words or []
  self.streaming_callback = streaming_callback
  self.boto3_config = boto3_config
- _check_duplicate_tool_names(tools)
+ _check_duplicate_tool_names(list(tools or [])) # handles Toolset as well
  self.tools = tools

  def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
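
`tools` now accepts either a list of `Tool` objects or a `Toolset`, and the duplicate-name check is applied uniformly by materializing the iterable with `list(tools or [])`. A construction sketch (the model id and tool are illustrative; actually running against Bedrock assumes a configured AWS environment):

```python
from haystack.tools import Tool, Toolset
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

def add(a: int, b: int) -> int:
    return a + b

toolset = Toolset(
    [
        Tool(
            name="add",
            description="Add two integers.",
            parameters={
                "type": "object",
                "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
                "required": ["a", "b"],
            },
            function=add,
        )
    ]
)

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # illustrative model id
    tools=toolset,
)
```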
@@ -228,7 +229,18 @@ class AmazonBedrockChatGenerator:
  self.stop_words = stop_words or []
  self.async_session = None

- def _get_async_session(self):
+ def _get_async_session(self) -> aioboto3.Session:
+ """
+ Initializes and returns an asynchronous AWS session for accessing Amazon Bedrock.
+
+ If the session is already created, it is reused. Otherwise, a new session is created using the provided AWS
+ credentials and configuration.
+
+ :returns:
+ An async-compatible boto3 session configured for use with Amazon Bedrock.
+ :raises AmazonBedrockConfigurationError:
+ If unable to establish an async session due to misconfiguration.
+ """
  if self.async_session:
  return self.async_session

@@ -260,7 +272,6 @@ class AmazonBedrockChatGenerator:
  Dictionary with serialized data.
  """
  callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
- serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
  return default_to_dict(
  self,
  aws_access_key_id=self.aws_access_key_id.to_dict() if self.aws_access_key_id else None,
@@ -273,7 +284,7 @@ class AmazonBedrockChatGenerator:
  generation_kwargs=self.generation_kwargs,
  streaming_callback=callback_name,
  boto3_config=self.boto3_config,
- tools=serialized_tools,
+ tools=serialize_tools_or_toolset(self.tools),
  )

  @classmethod
@@ -301,19 +312,27 @@ class AmazonBedrockChatGenerator:
  messages: List[ChatMessage],
  streaming_callback: Optional[StreamingCallbackT] = None,
  generation_kwargs: Optional[Dict[str, Any]] = None,
- tools: Optional[List[Tool]] = None,
+ tools: Optional[Union[List[Tool], Toolset]] = None,
  requires_async: bool = False,
  ) -> Tuple[Dict[str, Any], Optional[StreamingCallbackT]]:
  """
- Prepares the request parameters for both sync and async run methods.
-
- :param messages: List of ChatMessage objects representing the conversation history.
- :param streaming_callback: Optional callback function for handling streaming responses.
- :param generation_kwargs: Optional dictionary of generation parameters.
- :param tools: Optional list of Tool objects that the model can use.
- :param requires_async: Boolean indicating whether the request is for async execution.
- This affects how the streaming callback is selected.
- :return: Tuple of (request parameters dict, callback function)
+ Prepares and formats parameters required to call the Amazon Bedrock Converse API.
+
+ This includes merging default and runtime generation parameters, formatting messages and tools, and
+ selecting the appropriate streaming callback.
+
+ :param messages: List of `ChatMessage` objects representing the conversation history.
+ :param streaming_callback: Optional streaming callback provided at runtime.
+ :param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
+ - `maxTokens`: Maximum number of tokens to generate.
+ - `stopSequences`: List of stop sequences to stop generation.
+ - `temperature`: Sampling temperature.
+ - `topP`: Nucleus sampling parameter.
+ :param tools: Optional list of Tool objects or a Toolset that the model can use.
+ :param requires_async: Boolean flag to indicate if an async-compatible streaming callback function is needed.
+
+ :returns:
+ A tuple of (API-ready parameter dictionary, streaming callback function).
  """
  generation_kwargs = generation_kwargs or {}

@@ -331,9 +350,12 @@ class AmazonBedrockChatGenerator:

  # Handle tools - either toolConfig or Haystack Tool objects but not both
  tools = tools or self.tools
- _check_duplicate_tool_names(tools)
+ _check_duplicate_tool_names(list(tools or []))
  tool_config = merged_kwargs.pop("toolConfig", None)
  if tools:
+ # Convert Toolset to list if needed
+ if isinstance(tools, Toolset):
+ tools = list(tools)
  # Format Haystack tools to Bedrock format
  tool_config = _format_tools(tools)

@@ -369,8 +391,27 @@ class AmazonBedrockChatGenerator:
  messages: List[ChatMessage],
  streaming_callback: Optional[StreamingCallbackT] = None,
  generation_kwargs: Optional[Dict[str, Any]] = None,
- tools: Optional[List[Tool]] = None,
- ):
+ tools: Optional[Union[List[Tool], Toolset]] = None,
+ ) -> Dict[str, List[ChatMessage]]:
+ """
+ Executes a synchronous inference call to the Amazon Bedrock model using the Converse API.
+
+ Supports both standard and streaming responses depending on whether a streaming callback is provided.
+
+ :param messages: A list of `ChatMessage` objects forming the chat history.
+ :param streaming_callback: Optional callback for handling streaming outputs.
+ :param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
+ - `maxTokens`: Maximum number of tokens to generate.
+ - `stopSequences`: List of stop sequences to stop generation.
+ - `temperature`: Sampling temperature.
+ - `topP`: Nucleus sampling parameter.
+ :param tools: Optional list of Tools that the model may call during execution.
+
+ :returns:
+ A dictionary containing the model-generated replies under the `"replies"` key.
+ :raises AmazonBedrockInferenceError:
+ If the Bedrock inference API call fails.
+ """
  params, callback = self._prepare_request_params(
  messages=messages,
  streaming_callback=streaming_callback,
@@ -402,16 +443,26 @@ class AmazonBedrockChatGenerator:
  messages: List[ChatMessage],
  streaming_callback: Optional[StreamingCallbackT] = None,
  generation_kwargs: Optional[Dict[str, Any]] = None,
- tools: Optional[List[Tool]] = None,
- ):
+ tools: Optional[Union[List[Tool], Toolset]] = None,
+ ) -> Dict[str, List[ChatMessage]]:
  """
- Async version of the run method. Completes chats using LLMs hosted on Amazon Bedrock.
+ Executes an asynchronous inference call to the Amazon Bedrock model using the Converse API.
+
+ Designed for use cases where non-blocking or concurrent execution is desired.

- :param messages: List of ChatMessage objects representing the conversation history.
- :param streaming_callback: Optional callback function for handling streaming responses.
- :param generation_kwargs: Optional dictionary of generation parameters.
- :param tools: Optional list of Tool objects that the model can use.
- :return: Dictionary containing the model's replies as a list of ChatMessage objects.
+ :param messages: A list of `ChatMessage` objects forming the chat history.
+ :param streaming_callback: Optional async-compatible callback for handling streaming outputs.
+ :param generation_kwargs: Optional dictionary of generation parameters. Some common parameters are:
+ - `maxTokens`: Maximum number of tokens to generate.
+ - `stopSequences`: List of stop sequences to stop generation.
+ - `temperature`: Sampling temperature.
+ - `topP`: Nucleus sampling parameter.
+ :param tools: Optional list of Tool objects or a Toolset that the model can use.
+
+ :returns:
+ A dictionary containing the model-generated replies under the `"replies"` key.
+ :raises AmazonBedrockInferenceError:
+ If the Bedrock inference API call fails.
  """
  params, callback = self._prepare_request_params(
  messages=messages,
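
Both entry points now share `_prepare_request_params` and return the same `{"replies": [ChatMessage, ...]}` shape. A usage sketch, assuming a configured AWS environment and that the async entry point is named `run_async` (its `def` line falls outside this hunk, so the name is an assumption):

```python
import asyncio

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(model="anthropic.claude-3-5-sonnet-20240620-v1:0")  # illustrative id

result = generator.run(messages=[ChatMessage.from_user("What is Amazon Bedrock?")])
print(result["replies"][0].text)

# Assumption: the async counterpart shown in this hunk is exposed as `run_async`.
result = asyncio.run(generator.run_async(messages=[ChatMessage.from_user("And in one sentence?")]))
print(result["replies"][0].text)
```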
@@ -1,17 +1,30 @@
  import json
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from datetime import datetime, timezone
+ from typing import Any, Dict, List, Optional, Tuple

  from botocore.eventstream import EventStream
- from haystack.dataclasses import ChatMessage, ChatRole, StreamingCallbackT, StreamingChunk, ToolCall
+ from haystack import logging
+ from haystack.dataclasses import (
+ AsyncStreamingCallbackT,
+ ChatMessage,
+ ChatRole,
+ StreamingChunk,
+ SyncStreamingCallbackT,
+ ToolCall,
+ )
  from haystack.tools import Tool

+ logger = logging.getLogger(__name__)

+
+ # Haystack to Bedrock util methods
  def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]:
  """
  Format Haystack Tool(s) to Amazon Bedrock toolConfig format.

  :param tools: List of Tool objects to format
- :return: Dictionary in Bedrock toolConfig format or None if no tools
+ :returns:
+ Dictionary in Bedrock toolConfig format or None if no tools are provided
  """
  if not tools:
  return None
@@ -25,69 +38,160 @@ def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]
  return {"tools": tool_specs} if tool_specs else None


+ def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]:
+ """
+ Format a Haystack ChatMessage containing tool calls into Bedrock format.
+
+ :param tool_call_message: ChatMessage object containing tool calls to be formatted.
+ :returns:
+ Dictionary representing the tool call message in Bedrock's expected format
+ """
+ content = []
+ # Tool call message can contain text
+ if tool_call_message.text:
+ content.append({"text": tool_call_message.text})
+
+ for tool_call in tool_call_message.tool_calls:
+ content.append(
+ {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
+ )
+ return {"role": tool_call_message.role.value, "content": content}
+
+
+ def _format_tool_result_message(tool_call_result_message: ChatMessage) -> Dict[str, Any]:
+ """
+ Format a Haystack ChatMessage containing tool call results into Bedrock format.
+
+ :param tool_call_result_message: ChatMessage object containing tool call results to be formatted.
+ :returns: Dictionary representing the tool result message in Bedrock's expected format
+ """
+ # Assuming tool call result messages will only contain tool results
+ tool_results = []
+ for result in tool_call_result_message.tool_call_results:
+ try:
+ json_result = json.loads(result.result)
+ content = [{"json": json_result}]
+ except json.JSONDecodeError:
+ content = [{"text": result.result}]
+
+ tool_results.append(
+ {
+ "toolResult": {
+ "toolUseId": result.origin.id,
+ "content": content,
+ **({"status": "error"} if result.error else {}),
+ }
+ }
+ )
+ # role must be user
+ return {"role": "user", "content": tool_results}
+
+
+ def _repair_tool_result_messages(bedrock_formatted_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Repair and reorganize tool result messages to maintain proper ordering and grouping.
+
+ Ensures tool result messages are properly grouped in the same way as their corresponding tool call messages
+ and maintains the original message ordering.
+
+ :param bedrock_formatted_messages: List of Bedrock-formatted messages that may need repair.
+ :returns: List of properly organized Bedrock-formatted messages with correctly grouped tool results.
+ """
+ tool_call_messages = []
+ tool_result_messages = []
+ for idx, msg in enumerate(bedrock_formatted_messages):
+ content = msg.get("content", [])
+ if content:
+ if any("toolUse" in c for c in content):
+ tool_call_messages.append((idx, msg))
+ elif any("toolResult" in c for c in content):
+ tool_result_messages.append((idx, msg))
+
+ # Determine the tool call IDs for each tool call message
+ group_to_tool_call_ids: Dict[int, Any] = {idx: [] for idx, _ in tool_call_messages}
+ for idx, tool_call in tool_call_messages:
+ tool_use_contents = [c for c in tool_call["content"] if "toolUse" in c]
+ for content in tool_use_contents:
+ group_to_tool_call_ids[idx].append(content["toolUse"]["toolUseId"])
+
+ # Regroups the tool_result_prompts based on the tool_call_prompts
+ # Makes sure to:
+ # - Within the new group the tool call IDs of the tool result messages are in the same order as the tool call
+ # messages
+ # - The tool result messages are in the same order as the original message list
+ repaired_tool_result_prompts = []
+ for tool_call_ids in group_to_tool_call_ids.values():
+ regrouped_tool_result = []
+ original_idx = None
+ for tool_call_id in tool_call_ids:
+ for idx, tool_result in tool_result_messages:
+ tool_result_contents = [c for c in tool_result["content"] if "toolResult" in c]
+ for content in tool_result_contents:
+ if content["toolResult"]["toolUseId"] == tool_call_id:
+ regrouped_tool_result.append(content)
+ # Keep track of the original index of the last tool result message
+ original_idx = idx
+ if regrouped_tool_result and original_idx is not None:
+ repaired_tool_result_prompts.append((original_idx, {"role": "user", "content": regrouped_tool_result}))
+
+ # Remove the tool result messages from bedrock_formatted_messages
+ bedrock_formatted_messages_minus_tool_results: List[Tuple[int, Any]] = []
+ for idx, msg in enumerate(bedrock_formatted_messages):
+ # Assumes the content of tool result messages only contains 'toolResult': {...} objects (e.g. no 'text')
+ if msg.get("content") and "toolResult" not in msg["content"][0]:
+ bedrock_formatted_messages_minus_tool_results.append((idx, msg))
+
+ # Add the repaired tool result messages and sort to maintain the correct order
+ repaired_bedrock_formatted_messages = bedrock_formatted_messages_minus_tool_results + repaired_tool_result_prompts
+ repaired_bedrock_formatted_messages.sort(key=lambda x: x[0])
+
+ # Drop the index and return only the messages
+ return [msg for _, msg in repaired_bedrock_formatted_messages]
+
+
  def _format_messages(messages: List[ChatMessage]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
  """
- Format a list of ChatMessages to the format expected by Bedrock API.
- Separates system messages and handles tool results and tool calls.
+ Format a list of Haystack ChatMessages to the format expected by Bedrock API.

- :param messages: List of ChatMessages to format
- :return: Tuple of (system_prompts, non_system_messages) in Bedrock format
+ Processes and separates system messages from other message types and handles special formatting for tool calls
+ and tool results.
+
+ :param messages: List of ChatMessage objects to format for Bedrock API.
+ :returns: Tuple containing (system_prompts, non_system_messages) in Bedrock format,
+ where system_prompts is a list of system message dictionaries and
+ non_system_messages is a list of properly formatted message dictionaries.
  """
+ # Separate system messages, tool calls, and tool results
  system_prompts = []
- non_system_messages = []
-
+ bedrock_formatted_messages = []
  for msg in messages:
  if msg.is_from(ChatRole.SYSTEM):
+ # Assuming system messages can only contain text
+ # Don't need to track idx since system_messages are handled separately
  system_prompts.append({"text": msg.text})
- continue
-
- # Handle tool results - must role these as user messages
- if msg.tool_call_results:
- tool_results = []
- for result in msg.tool_call_results:
- try:
- json_result = json.loads(result.result)
- content = [{"json": json_result}]
- except json.JSONDecodeError:
- content = [{"text": result.result}]
-
- tool_results.append(
- {
- "toolResult": {
- "toolUseId": result.origin.id,
- "content": content,
- **({"status": "error"} if result.error else {}),
- }
- }
- )
- non_system_messages.append({"role": "user", "content": tool_results})
- continue
-
- content = []
- # Handle text content
- if msg.text:
- content.append({"text": msg.text})
-
- # Handle tool calls
- if msg.tool_calls:
- for tool_call in msg.tool_calls:
- content.append(
- {"toolUse": {"toolUseId": tool_call.id, "name": tool_call.tool_name, "input": tool_call.arguments}}
- )
-
- if content: # Only add message if it has content
- non_system_messages.append({"role": msg.role.value, "content": content})
+ elif msg.tool_calls:
+ bedrock_formatted_messages.append(_format_tool_call_message(msg))
+ elif msg.tool_call_results:
+ bedrock_formatted_messages.append(_format_tool_result_message(msg))
+ else:
+ # regular user or assistant messages with only text content
+ bedrock_formatted_messages.append({"role": msg.role.value, "content": [{"text": msg.text}]})

- return system_prompts, non_system_messages
+ repaired_bedrock_formatted_messages = _repair_tool_result_messages(bedrock_formatted_messages)
+ return system_prompts, repaired_bedrock_formatted_messages


+ # Bedrock to Haystack util method
  def _parse_completion_response(response_body: Dict[str, Any], model: str) -> List[ChatMessage]:
  """
- Parse a Bedrock response to a list of ChatMessage objects.
+ Parse a Bedrock API response into Haystack ChatMessage objects.

- :param response_body: Raw response from Bedrock API
- :param model: The model ID used for generation
- :return: List of ChatMessage objects
+ Extracts text content, tool calls, and metadata from the Bedrock response and converts them into the appropriate
+ Haystack format.
+
+ :param response_body: Raw JSON response from Bedrock API.
+ :param model: The model ID used for generation, included in message metadata.
+ :returns: List of ChatMessage objects containing the assistant's response(s) with appropriate metadata.
  """
  replies = []
  if "output" in response_body and "message" in response_body["output"]:
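
The formatting pipeline is now: classify each `ChatMessage`, convert tool calls and tool results with the two dedicated helpers, then let `_repair_tool_result_messages` regroup results to mirror their originating tool-call message. A sketch tracing one round trip through the private helpers shipped in this wheel (the helper names are real, the conversation is made up):

```python
from haystack.dataclasses import ChatMessage, ToolCall
from haystack_integrations.components.generators.amazon_bedrock.chat.utils import _format_messages

tool_call = ToolCall(id="call-1", tool_name="add", arguments={"a": 1, "b": 2})
messages = [
    ChatMessage.from_system("You are a calculator."),
    ChatMessage.from_user("What is 1 + 2?"),
    ChatMessage.from_assistant(tool_calls=[tool_call]),
    ChatMessage.from_tool(tool_result="3", origin=tool_call),
]

system_prompts, bedrock_messages = _format_messages(messages)
# system_prompts   -> [{"text": "You are a calculator."}]
# bedrock_messages -> the user text, an assistant message with "toolUse" content, and a
#                     "user" message whose "toolResult" entries were regrouped by
#                     _repair_tool_result_messages to match the tool-call grouping
```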
@@ -130,9 +234,209 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
  return replies


+ # Bedrock streaming to Haystack util methods
+ def _convert_event_to_streaming_chunk(event: Dict[str, Any], model: str) -> StreamingChunk:
+ """
+ Convert a Bedrock streaming event to a Haystack StreamingChunk.
+
+ Handles different event types (contentBlockStart, contentBlockDelta, messageStop, metadata) and extracts relevant
+ information to create StreamingChunk objects in the same format used by Haystack's OpenAIChatGenerator.
+
+ :param event: Dictionary containing a Bedrock streaming event.
+ :param model: The model ID used for generation, included in chunk metadata.
+ :returns: StreamingChunk object containing the content and metadata extracted from the event.
+ """
+ # Initialize an empty StreamingChunk to return if no relevant event is found
+ # (e.g. for messageStart and contentBlockStop)
+ streaming_chunk = StreamingChunk(
+ content="", meta={"model": model, "received_at": datetime.now(timezone.utc).isoformat()}
+ )
+
+ if "contentBlockStart" in event:
+ # contentBlockStart always has the key "contentBlockIndex"
+ block_start = event["contentBlockStart"]
+ block_idx = block_start["contentBlockIndex"]
+ if "start" in block_start and "toolUse" in block_start["start"]:
+ tool_start = block_start["start"]["toolUse"]
+ streaming_chunk = StreamingChunk(
+ content="",
+ meta={
+ "model": model,
+ # This is always 0 b/c it represents the choice index
+ "index": 0,
+ # We follow the same format used in the OpenAIChatGenerator
+ "tool_calls": [ # Optional[List[ChoiceDeltaToolCall]]
+ {
+ "index": block_idx, # int
+ "id": tool_start["toolUseId"], # Optional[str]
+ "function": { # Optional[ChoiceDeltaToolCallFunction]
+ # Will accumulate deltas as string
+ "arguments": "", # Optional[str]
+ "name": tool_start["name"], # Optional[str]
+ },
+ "type": "function", # Optional[Literal["function"]]
+ }
+ ],
+ "finish_reason": None,
+ "received_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ elif "contentBlockDelta" in event:
+ # contentBlockDelta always has the key "contentBlockIndex" and "delta"
+ block_idx = event["contentBlockDelta"]["contentBlockIndex"]
+ delta = event["contentBlockDelta"]["delta"]
+ # This is for accumulating text deltas
+ if "text" in delta:
+ streaming_chunk = StreamingChunk(
+ content=delta["text"],
+ meta={
+ "model": model,
+ # This is always 0 b/c it represents the choice index
+ "index": 0,
+ "tool_calls": None,
+ "finish_reason": None,
+ "received_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+ # This only occurs when accumulating the arguments for a toolUse
+ # The content_block for this tool should already exist at this point
+ elif "toolUse" in delta:
+ streaming_chunk = StreamingChunk(
+ content="",
+ meta={
+ "model": model,
+ # This is always 0 b/c it represents the choice index
+ "index": 0,
+ "tool_calls": [ # Optional[List[ChoiceDeltaToolCall]]
+ {
+ "index": block_idx, # int
+ "id": None, # Optional[str]
+ "function": { # Optional[ChoiceDeltaToolCallFunction]
+ # Will accumulate deltas as string
+ "arguments": delta["toolUse"].get("input", ""), # Optional[str]
+ "name": None, # Optional[str]
+ },
+ "type": "function", # Optional[Literal["function"]]
+ }
+ ],
+ "finish_reason": None,
+ "received_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ elif "messageStop" in event:
+ finish_reason = event["messageStop"].get("stopReason")
+ streaming_chunk = StreamingChunk(
+ content="",
+ meta={
+ "model": model,
+ # This is always 0 b/c it represents the choice index
+ "index": 0,
+ "tool_calls": None,
+ "finish_reason": finish_reason,
+ "received_at": datetime.now(timezone.utc).isoformat(),
+ },
+ )
+
+ elif "metadata" in event and "usage" in event["metadata"]:
+ metadata = event["metadata"]
+ streaming_chunk = StreamingChunk(
+ content="",
+ meta={
+ "model": model,
+ # This is always 0 b/c it represents the choice index
+ "index": 0,
+ "tool_calls": None,
+ "finish_reason": None,
+ "received_at": datetime.now(timezone.utc).isoformat(),
+ "usage": {
+ "prompt_tokens": metadata["usage"].get("inputTokens", 0),
+ "completion_tokens": metadata["usage"].get("outputTokens", 0),
+ "total_tokens": metadata["usage"].get("totalTokens", 0),
+ },
+ },
+ )
+
+ return streaming_chunk
+
+
+ def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
+ """
+ Converts a list of streaming chunks into a ChatMessage object.
+
+ The function processes streaming chunks to build a ChatMessage object, including extracting and constructing
+ tool calls, managing metadata such as model type, finish reason, and usage information.
+ The tool call processing handles accumulating data across the chunks and attempts to parse JSON-formatted
+ arguments for tool calls.
+
+ :param chunks: A list of StreamingChunk objects representing parts of the assistant's response.
+
+ :returns:
+ A ChatMessage object constructed from the streaming chunks, containing the aggregated text, processed tool
+ calls, and metadata.
+ """
+ text = "".join([chunk.content for chunk in chunks])
+
+ # Process tool calls if present in any chunk
+ tool_calls = []
+ tool_call_data: Dict[int, Dict[str, str]] = {} # Track tool calls by index
+ for chunk_payload in chunks:
+ tool_calls_meta = chunk_payload.meta.get("tool_calls")
+ if tool_calls_meta is not None:
+ for delta in tool_calls_meta:
+ # We use the index of the tool call to track it across chunks since the ID is not always provided
+ if delta["index"] not in tool_call_data:
+ tool_call_data[delta["index"]] = {"id": "", "name": "", "arguments": ""}
+
+ # Save the ID if present
+ if delta.get("id"):
+ tool_call_data[delta["index"]]["id"] = delta["id"]
+
+ if delta.get("function"):
+ if delta["function"].get("name"):
+ tool_call_data[delta["index"]]["name"] += delta["function"]["name"]
+ if delta["function"].get("arguments"):
+ tool_call_data[delta["index"]]["arguments"] += delta["function"]["arguments"]
+
+ # Convert accumulated tool call data into ToolCall objects
+ for call_data in tool_call_data.values():
+ try:
+ arguments = json.loads(call_data["arguments"])
+ tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
+ except json.JSONDecodeError:
+ logger.warning(
+ "Amazon Bedrock returned a malformed JSON string for tool call arguments. This tool call will be "
+ "skipped. Tool call ID: {tool_id}, Tool name: {tool_name}, Arguments: {tool_arguments}",
+ tool_id=call_data["id"],
+ tool_name=call_data["name"],
+ tool_arguments=call_data["arguments"],
+ )
+
+ # finish_reason can appear in different places so we look for the last one
+ finish_reasons = [
+ chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
+ ]
+ finish_reason = finish_reasons[-1] if finish_reasons else None
+
+ # usage is usually last but we look for it as well
+ usages = [chunk.meta.get("usage") for chunk in chunks if chunk.meta.get("usage") is not None]
+ usage = usages[-1] if usages else None
+
+ meta = {
+ "model": chunks[-1].meta["model"],
+ "index": 0,
+ "finish_reason": finish_reason,
+ "completion_start_time": chunks[0].meta.get("received_at"), # first chunk received
+ "usage": usage,
+ }
+
+ return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
+
+
  def _parse_streaming_response(
  response_stream: EventStream,
- streaming_callback: Callable[[StreamingChunk], None],
+ streaming_callback: SyncStreamingCallbackT,
  model: str,
  ) -> List[ChatMessage]:
  """
@@ -143,78 +447,18 @@ def _parse_streaming_response(
  :param model: The model ID used for generation
  :return: List of ChatMessage objects
  """
- replies = []
- current_content = ""
- current_tool_call: Optional[Dict[str, Any]] = None
- base_meta = {"model": model, "index": 0}
-
+ chunks: List[StreamingChunk] = []
  for event in response_stream:
- if "contentBlockStart" in event:
- # Reset accumulators for new message
- current_content = ""
- current_tool_call = None
- block_start = event["contentBlockStart"]
- if "start" in block_start and "toolUse" in block_start["start"]:
- tool_start = block_start["start"]["toolUse"]
- current_tool_call = {
- "id": tool_start["toolUseId"],
- "name": tool_start["name"],
- "arguments": "", # Will accumulate deltas as string
- }
-
- elif "contentBlockDelta" in event:
- delta = event["contentBlockDelta"]["delta"]
- if "text" in delta:
- delta_text = delta["text"]
- current_content += delta_text
- streaming_chunk = StreamingChunk(content=delta_text, meta={})
- streaming_callback(streaming_chunk)
- elif "toolUse" in delta and current_tool_call:
- # Accumulate tool use input deltas
- current_tool_call["arguments"] += delta["toolUse"].get("input", "")
-
- elif "contentBlockStop" in event:
- if current_tool_call:
- # Parse accumulated input if it's a JSON string
- try:
- input_json = json.loads(current_tool_call["arguments"])
- current_tool_call["arguments"] = input_json
- except json.JSONDecodeError:
- # Keep as string if not valid JSON
- pass
-
- tool_call = ToolCall(
- id=current_tool_call["id"],
- tool_name=current_tool_call["name"],
- arguments=current_tool_call["arguments"],
- )
- replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy()))
- elif current_content:
- replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
-
- elif "messageStop" in event:
- # Update finish reason for all replies
- for reply in replies:
- reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
-
- elif "metadata" in event:
- metadata = event["metadata"]
- # Update usage stats for all replies
- for reply in replies:
- if "usage" in metadata:
- usage = metadata["usage"]
- reply.meta["usage"] = {
- "prompt_tokens": usage.get("inputTokens", 0),
- "completion_tokens": usage.get("outputTokens", 0),
- "total_tokens": usage.get("totalTokens", 0),
- }
-
+ streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model)
+ streaming_callback(streaming_chunk)
+ chunks.append(streaming_chunk)
+ replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
  return replies


  async def _parse_streaming_response_async(
  response_stream: EventStream,
- streaming_callback: StreamingCallbackT,
+ streaming_callback: AsyncStreamingCallbackT,
  model: str,
  ) -> List[ChatMessage]:
  """
@@ -225,70 +469,10 @@ async def _parse_streaming_response_async(
  :param model: The model ID used for generation
  :return: List of ChatMessage objects
  """
- replies = []
- current_content = ""
- current_tool_call: Optional[Dict[str, Any]] = None
- base_meta = {"model": model, "index": 0}
-
+ chunks: List[StreamingChunk] = []
  async for event in response_stream:
- if "contentBlockStart" in event:
- # Reset accumulators for new message
- current_content = ""
- current_tool_call = None
- block_start = event["contentBlockStart"]
- if "start" in block_start and "toolUse" in block_start["start"]:
- tool_start = block_start["start"]["toolUse"]
- current_tool_call = {
- "id": tool_start["toolUseId"],
- "name": tool_start["name"],
- "arguments": "", # Will accumulate deltas as string
- }
-
- elif "contentBlockDelta" in event:
- delta = event["contentBlockDelta"]["delta"]
- if "text" in delta:
- delta_text = delta["text"]
- current_content += delta_text
- streaming_chunk = StreamingChunk(content=delta_text, meta={})
- await streaming_callback(streaming_chunk)
- elif "toolUse" in delta and current_tool_call:
- # Accumulate tool use input deltas
- current_tool_call["arguments"] += delta["toolUse"].get("input", "")
-
- elif "contentBlockStop" in event:
- if current_tool_call:
- # Parse accumulated input if it's a JSON string
- try:
- input_json = json.loads(current_tool_call["arguments"])
- current_tool_call["arguments"] = input_json
- except json.JSONDecodeError:
- # Keep as string if not valid JSON
- pass
-
- tool_call = ToolCall(
- id=current_tool_call["id"],
- tool_name=current_tool_call["name"],
- arguments=current_tool_call["arguments"],
- )
- replies.append(ChatMessage.from_assistant("", tool_calls=[tool_call], meta=base_meta.copy()))
- elif current_content:
- replies.append(ChatMessage.from_assistant(current_content, meta=base_meta.copy()))
-
- elif "messageStop" in event:
- # Update finish reason for all replies
- for reply in replies:
- reply.meta["finish_reason"] = event["messageStop"].get("stopReason")
-
- elif "metadata" in event:
- metadata = event["metadata"]
- # Update usage stats for all replies
- for reply in replies:
- if "usage" in metadata:
- usage = metadata["usage"]
- reply.meta["usage"] = {
- "prompt_tokens": usage.get("inputTokens", 0),
- "completion_tokens": usage.get("outputTokens", 0),
- "total_tokens": usage.get("totalTokens", 0),
- }
-
+ streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model)
+ await streaming_callback(streaming_chunk)
+ chunks.append(streaming_chunk)
+ replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
  return replies
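
The streaming rewrite replaces the ad-hoc accumulators with two composable steps: every raw Converse event becomes a `StreamingChunk` carrying OpenAI-style metadata, and the chunk list is folded into a single `ChatMessage` at the end. A sketch over synthetic events in the shapes handled above (the model id is a placeholder; the helpers are private to this wheel and used here only for illustration):

```python
from haystack_integrations.components.generators.amazon_bedrock.chat.utils import (
    _convert_event_to_streaming_chunk,
    _convert_streaming_chunks_to_chat_message,
)

# Synthetic Converse stream events in the shapes the converter handles.
events = [
    {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "Hello"}}},
    {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": " world"}}},
    {"messageStop": {"stopReason": "end_turn"}},
    {"metadata": {"usage": {"inputTokens": 3, "outputTokens": 2, "totalTokens": 5}}},
]

chunks = [_convert_event_to_streaming_chunk(event=e, model="placeholder-model") for e in events]
message = _convert_streaming_chunks_to_chat_message(chunks=chunks)
print(message.text)                   # Hello world
print(message.meta["finish_reason"])  # end_turn
print(message.meta["usage"])          # {'prompt_tokens': 3, 'completion_tokens': 2, 'total_tokens': 5}
```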
@@ -1,7 +1,7 @@
  import json
  import re
  import warnings
- from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args
+ from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, Union, get_args

  from botocore.config import Config
  from botocore.exceptions import ClientError
@@ -108,7 +108,7 @@ class AmazonBedrockGenerator:
  boto3_config: Optional[Dict[str, Any]] = None,
  model_family: Optional[MODEL_FAMILIES] = None,
  **kwargs,
- ):
+ ) -> None:
  """
  Create a new `AmazonBedrockGenerator` instance.

@@ -189,7 +189,7 @@ class AmazonBedrockGenerator:
  prompt: str,
  streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
  generation_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Dict[str, Union[List[str], Dict[str, Any]]]:
  """
  Generates a list of string response to the given prompt.

@@ -1,3 +1,3 @@
- from .ranker import BedrockRanker
+ from .ranker import AmazonBedrockRanker, BedrockRanker

- __all__ = ["BedrockRanker"]
+ __all__ = ["AmazonBedrockRanker", "BedrockRanker"]
@@ -1,3 +1,4 @@
+ import warnings
  from typing import Any, Dict, List, Optional

  from botocore.exceptions import ClientError
@@ -16,7 +17,7 @@ MAX_NUM_DOCS_FOR_BEDROCK_RANKER = 1000


  @component
- class BedrockRanker:
+ class AmazonBedrockRanker:
  """
  Ranks Documents based on their similarity to the query using Amazon Bedrock's Cohere Rerank model.

@@ -30,9 +31,13 @@ class BedrockRanker:
  ```python
  from haystack import Document
  from haystack.utils import Secret
- from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker
+ from haystack_integrations.components.rankers.amazon_bedrock import AmazonBedrockRanker

- ranker = BedrockRanker(model="cohere.rerank-v3-5:0", top_k=2, aws_region_name=Secret.from_token("eu-central-1"))
+ ranker = AmazonBedrockRanker(
+ model="cohere.rerank-v3-5:0",
+ top_k=2,
+ aws_region_name=Secret.from_token("eu-central-1")
+ )

  docs = [Document(content="Paris"), Document(content="Berlin")]
  query = "What is the capital of germany?"
@@ -40,7 +45,7 @@ class BedrockRanker:
  docs = output["documents"]
  ```

- BedrockRanker uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM.
+ AmazonBedrockRanker uses AWS for authentication. You can use the AWS CLI to authenticate through your IAM.
  For more information on setting up an IAM identity-based policy, see [Amazon Bedrock documentation]
  (https://docs.aws.amazon.com/bedrock/latest/userguide/security_iam_id-based-policy-examples.html).

@@ -66,12 +71,12 @@ class BedrockRanker:
  max_chunks_per_doc: Optional[int] = None,
  meta_fields_to_embed: Optional[List[str]] = None,
  meta_data_separator: str = "\n",
- ):
+ ) -> None:
  if not model:
  msg = "'model' cannot be None or empty string"
  raise ValueError(msg)
  """
- Creates an instance of the 'BedrockRanker'.
+ Creates an instance of the 'AmazonBedrockRanker'.

  :param model: Amazon Bedrock model name for Cohere Rerank. Default is "cohere.rerank-v3-5:0".
  :param top_k: The maximum number of documents to return.
@@ -140,7 +145,7 @@ class BedrockRanker:
  )

  @classmethod
- def from_dict(cls, data: Dict[str, Any]) -> "BedrockRanker":
+ def from_dict(cls, data: Dict[str, Any]) -> "AmazonBedrockRanker":
  """
  Deserializes the component from a dictionary.

@@ -173,7 +178,7 @@ class BedrockRanker:
  return concatenated_input_list

  @component.output_types(documents=List[Document])
- def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
+ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict[str, List[Document]]:
  """
  Use the Amazon Bedrock Reranker to re-rank the list of documents based on the query.

@@ -260,3 +265,20 @@ class BedrockRanker:
  except Exception as e:
  msg = f"Error during Amazon Bedrock API call: {e!s}"
  raise AmazonBedrockInferenceError(msg) from e
+
+
+ class BedrockRanker(AmazonBedrockRanker):
+ """
+ Deprecated alias for AmazonBedrockRanker.
+ This class will be removed in a future version.
+ Please use AmazonBedrockRanker instead.
+ """
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "BedrockRanker is deprecated and will be removed in a future version. "
+ "Please use AmazonBedrockRanker instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ super().__init__(*args, **kwargs)
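
The rename is backwards compatible: `BedrockRanker` remains importable and functionally identical, but constructing it now emits a `DeprecationWarning`. A sketch of surfacing the warning (instantiation assumes a configured AWS environment, since the ranker opens an AWS session in its constructor):

```python
import warnings

from haystack_integrations.components.rankers.amazon_bedrock import BedrockRanker

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ranker = BedrockRanker(model="cohere.rerank-v3-5:0")  # assumes AWS credentials are available
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```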