amazon-bedrock-haystack 3.11.0__py3-none-any.whl → 4.1.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: amazon-bedrock-haystack
-Version: 3.11.0
+Version: 4.1.0
 Summary: An integration of Amazon Bedrock as an AmazonBedrockGenerator component.
 Project-URL: Documentation, https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme
 Project-URL: Issues, https://github.com/deepset-ai/haystack-core-integrations/issues
@@ -21,7 +21,7 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
 Requires-Python: >=3.9
 Requires-Dist: aioboto3>=14.0.0
 Requires-Dist: boto3>=1.28.57
-Requires-Dist: haystack-ai>=2.16.0
+Requires-Dist: haystack-ai>=2.17.1
 Description-Content-Type: text/markdown
 
 # amazon-bedrock-haystack
@@ -31,7 +31,7 @@ Description-Content-Type: text/markdown
 
 - [Integration page](https://haystack.deepset.ai/integrations/amazon-bedrock)
 - [Changelog](https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/amazon_bedrock/CHANGELOG.md)
------
+---
 
 ## Contributing
 
@@ -12,12 +12,12 @@ haystack_integrations/components/generators/amazon_bedrock/__init__.py,sha256=lv
 haystack_integrations/components/generators/amazon_bedrock/adapters.py,sha256=yBC-3YwV6qAwSXMtdZiLSYh2lUpPQIDy7Efl7w-Cu-k,19640
 haystack_integrations/components/generators/amazon_bedrock/generator.py,sha256=Brzw0XvtPJhz2kR2I3liAqWHRmDR6p5HzJerEAPhoJU,14743
 haystack_integrations/components/generators/amazon_bedrock/chat/__init__.py,sha256=6GZ8Y3Lw0rLOsOAqi6Tu5mZC977UzQvgDxKpOWr8IQw,110
-haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=_0dpBoZGY9kgK9zQOTskcjElcTifwhyBAixXDliK-vY,24918
-haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=g2SZV8LdLobaCZpwWCreBJn1BtS1V3-wQkpisStJrcY,29015
+haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py,sha256=qArwfXcforWnPzLXrAW-1hkPFpMy3NSdDyJ5GOta25w,26068
+haystack_integrations/components/generators/amazon_bedrock/chat/utils.py,sha256=1M_k8CG2WH23Yz-sB7a1kiIqVh2QB8Pqi0zbWXyMUL8,27255
 haystack_integrations/components/rankers/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 haystack_integrations/components/rankers/amazon_bedrock/__init__.py,sha256=Zrc3BSVkEaXYpliEi6hKG9bqW4J7DNk93p50SuoyT1Q,107
 haystack_integrations/components/rankers/amazon_bedrock/ranker.py,sha256=enAjf2QyDwfpidKkFCdLz954cx-Tjh9emrOS3vINJDg,12344
-amazon_bedrock_haystack-3.11.0.dist-info/METADATA,sha256=5nA_v2Ze5xk1p-RQxbshQ0XGa3LYFljVGvNi2VvKU7o,2225
-amazon_bedrock_haystack-3.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-amazon_bedrock_haystack-3.11.0.dist-info/licenses/LICENSE.txt,sha256=B05uMshqTA74s-0ltyHKI6yoPfJ3zYgQbvcXfDVGFf8,10280
-amazon_bedrock_haystack-3.11.0.dist-info/RECORD,,
+amazon_bedrock_haystack-4.1.0.dist-info/METADATA,sha256=P6e8VfoRQ0hZrA6gkaTwjCCxZez2NqKjHNAJuFdwK6c,2222
+amazon_bedrock_haystack-4.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+amazon_bedrock_haystack-4.1.0.dist-info/licenses/LICENSE.txt,sha256=B05uMshqTA74s-0ltyHKI6yoPfJ3zYgQbvcXfDVGFf8,10280
+amazon_bedrock_haystack-4.1.0.dist-info/RECORD,,
@@ -27,6 +27,7 @@ from haystack_integrations.components.generators.amazon_bedrock.chat.utils impor
     _parse_completion_response,
     _parse_streaming_response,
     _parse_streaming_response_async,
+    _validate_guardrail_config,
 )
 
 logger = logging.getLogger(__name__)
@@ -154,10 +155,11 @@ class AmazonBedrockChatGenerator:
         aws_region_name: Optional[Secret] = Secret.from_env_var(["AWS_DEFAULT_REGION"], strict=False),  # noqa: B008
         aws_profile_name: Optional[Secret] = Secret.from_env_var(["AWS_PROFILE"], strict=False),  # noqa: B008
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        stop_words: Optional[List[str]] = None,
         streaming_callback: Optional[StreamingCallbackT] = None,
         boto3_config: Optional[Dict[str, Any]] = None,
         tools: Optional[Union[List[Tool], Toolset]] = None,
+        *,
+        guardrail_config: Optional[Dict[str, str]] = None,
     ) -> None:
         """
         Initializes the `AmazonBedrockChatGenerator` with the provided parameters. The parameters are passed to the
@@ -179,10 +181,6 @@ class AmazonBedrockChatGenerator:
         :param generation_kwargs: Keyword arguments sent to the model. These parameters are specific to a model.
             You can find the model specific arguments in the AWS Bedrock API
             [documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html).
-        :param stop_words: A list of stop words that stop the model from generating more text
-            when encountered. You can provide them using this parameter or using the model's `generation_kwargs`
-            under a model's specific key for stop words.
-            For example, you can provide stop words for Anthropic Claude in the `stop_sequences` key.
         :param streaming_callback: A callback function called when a new token is received from the stream.
             By default, the model is not set up for streaming. To enable streaming, set this parameter to a callback
             function that handles the streaming chunks. The callback function receives a
@@ -190,6 +188,19 @@ class AmazonBedrockChatGenerator:
             the streaming mode on.
         :param boto3_config: The configuration for the boto3 client.
         :param tools: A list of Tool objects or a Toolset that the model can use. Each tool should have a unique name.
+        :param guardrail_config: Optional configuration for a guardrail that has been created in Amazon Bedrock.
+            This must be provided as a dictionary matching either
+            [GuardrailConfiguration](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_GuardrailConfiguration.html)
+            or, in streaming mode (when `streaming_callback` is set),
+            [GuardrailStreamConfiguration](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_GuardrailStreamConfiguration.html).
+            If `trace` is set to `enabled`, the guardrail trace will be included under the `trace` key in the `meta`
+            attribute of the resulting `ChatMessage`.
+            Note: Enabling guardrails in streaming mode may introduce additional latency.
+            To manage this, you can adjust the `streamProcessingMode` parameter.
+            See the
+            [Guardrails Streaming documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-streaming.html)
+            for more information.
+
 
         :raises ValueError: If the model name is empty or None.
         :raises AmazonBedrockConfigurationError: If the AWS environment is not configured correctly or the model is
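
For orientation, a minimal usage sketch of the new `guardrail_config` parameter (not part of the diff; the model ID, guardrail identifier, and version are placeholders for values from your own Bedrock account):

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
    guardrail_config={
        "guardrailIdentifier": "my-guardrail-id",  # placeholder
        "guardrailVersion": "1",  # placeholder
        "trace": "enabled",  # surfaces the guardrail trace under reply.meta["trace"]
    },
)

reply = generator.run(messages=[ChatMessage.from_user("Tell me about our refund policy.")])["replies"][0]
print(reply.meta.get("trace"))
```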
@@ -204,12 +215,15 @@ class AmazonBedrockChatGenerator:
         self.aws_session_token = aws_session_token
         self.aws_region_name = aws_region_name
         self.aws_profile_name = aws_profile_name
-        self.stop_words = stop_words or []
         self.streaming_callback = streaming_callback
         self.boto3_config = boto3_config
+
         _check_duplicate_tool_names(list(tools or []))  # handles Toolset as well
         self.tools = tools
 
+        _validate_guardrail_config(guardrail_config=guardrail_config, streaming=streaming_callback is not None)
+        self.guardrail_config = guardrail_config
+
         def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
             return secret.resolve_value() if secret else None
 
@@ -237,7 +251,6 @@ class AmazonBedrockChatGenerator:
             raise AmazonBedrockConfigurationError(msg) from exception
 
         self.generation_kwargs = generation_kwargs or {}
-        self.stop_words = stop_words or []
         self.async_session: Optional[aioboto3.Session] = None
 
     def _get_async_session(self) -> aioboto3.Session:
@@ -291,11 +304,11 @@ class AmazonBedrockChatGenerator:
             aws_region_name=self.aws_region_name.to_dict() if self.aws_region_name else None,
             aws_profile_name=self.aws_profile_name.to_dict() if self.aws_profile_name else None,
             model=self.model,
-            stop_words=self.stop_words,
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             boto3_config=self.boto3_config,
             tools=serialize_tools_or_toolset(self.tools),
+            guardrail_config=self.guardrail_config,
         )
 
     @classmethod
@@ -308,6 +321,12 @@ class AmazonBedrockChatGenerator:
             Instance of `AmazonBedrockChatGenerator`.
         """
         init_params = data.get("init_parameters", {})
+
+        stop_words = init_params.pop("stop_words", None)
+        msg = "stop_words parameter will be ignored. Use the `stopSequences` key in `generation_kwargs` instead."
+        if stop_words:
+            logger.warning(msg)
+
         serialized_callback_handler = init_params.get("streaming_callback")
         if serialized_callback_handler:
             data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
@@ -387,6 +406,8 @@ class AmazonBedrockChatGenerator:
             params["toolConfig"] = tool_config
         if additional_fields:
             params["additionalModelRequestFields"] = additional_fields
+        if self.guardrail_config:
+            params["guardrailConfig"] = self.guardrail_config
 
         # overloads that exhaust finite Literals(bool) not treated as exhaustive
         # see https://github.com/python/mypy/issues/14764
@@ -5,16 +5,20 @@ from typing import Any, Dict, List, Optional, Tuple
 
 from botocore.eventstream import EventStream
 from haystack import logging
+from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
 from haystack.dataclasses import (
     AsyncStreamingCallbackT,
     ChatMessage,
     ChatRole,
     ComponentInfo,
+    FinishReason,
     ImageContent,
+    ReasoningContent,
     StreamingChunk,
     SyncStreamingCallbackT,
     TextContent,
     ToolCall,
+    ToolCallDelta,
 )
 from haystack.tools import Tool
 
@@ -24,6 +28,16 @@ logger = logging.getLogger(__name__)
 # see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageBlock.html for supported formats
 IMAGE_SUPPORTED_FORMATS = ["png", "jpeg", "gif", "webp"]
 
+# see https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_MessageStopEvent.html
+FINISH_REASON_MAPPING: Dict[str, FinishReason] = {
+    "end_turn": "stop",
+    "stop_sequence": "stop",
+    "max_tokens": "length",
+    "guardrail_intervened": "content_filter",
+    "content_filtered": "content_filter",
+    "tool_use": "tool_calls",
+}
+
 
 # Haystack to Bedrock util methods
 def _format_tools(tools: Optional[List[Tool]] = None) -> Optional[Dict[str, Any]]:
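
The new mapping normalizes raw Bedrock stop reasons to Haystack's cross-provider `FinishReason` literals; anything unmapped falls through to `None`:

```python
assert FINISH_REASON_MAPPING.get("end_turn") == "stop"
assert FINISH_REASON_MAPPING.get("guardrail_intervened") == "content_filter"
assert FINISH_REASON_MAPPING.get("tool_use") == "tool_calls"
assert FINISH_REASON_MAPPING.get("some_future_reason") is None  # hypothetical unmapped value
```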
@@ -57,8 +71,8 @@ def _format_tool_call_message(tool_call_message: ChatMessage) -> Dict[str, Any]:
     content: List[Dict[str, Any]] = []
 
     # tool call messages can contain reasoning content
-    if reasoning_contents := tool_call_message.meta.get("reasoning_contents"):
-        content.extend(_format_reasoning_contents(reasoning_contents=reasoning_contents))
+    if reasoning_content := tool_call_message.reasoning:
+        content.extend(_format_reasoning_content(reasoning_content=reasoning_content))
 
     # Tool call message can contain text
     if tool_call_message.text:
@@ -162,16 +176,16 @@ def _repair_tool_result_messages(bedrock_formatted_messages: List[Dict[str, Any]
     return [msg for _, msg in repaired_bedrock_formatted_messages]
 
 
-def _format_reasoning_contents(reasoning_contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+def _format_reasoning_content(reasoning_content: ReasoningContent) -> List[Dict[str, Any]]:
     """
-    Format reasoning contents to match Bedrock's expected structure.
+    Format ReasoningContent to match Bedrock's expected structure.
 
-    :param reasoning_contents: List of reasoning content dictionaries from Haystack ChatMessage metadata.
+    :param reasoning_content: ReasoningContent object containing reasoning contents to format.
     :returns: List of formatted reasoning content dictionaries for Bedrock.
     """
     formatted_contents = []
-    for reasoning_content in reasoning_contents:
-        formatted_content = {"reasoningContent": reasoning_content["reasoning_content"]}
+    for content in reasoning_content.extra.get("reasoning_contents", []):
+        formatted_content = {"reasoningContent": content["reasoning_content"]}
         if reasoning_text := formatted_content["reasoningContent"].pop("reasoning_text", None):
             formatted_content["reasoningContent"]["reasoningText"] = reasoning_text
         if redacted_content := formatted_content["reasoningContent"].pop("redacted_content", None):
@@ -192,8 +206,8 @@ def _format_text_image_message(message: ChatMessage) -> Dict[str, Any]:
 
     bedrock_content_blocks: List[Dict[str, Any]] = []
     # Add reasoning content if available as the first content block
-    if message.meta.get("reasoning_contents"):
-        bedrock_content_blocks.extend(_format_reasoning_contents(reasoning_contents=message.meta["reasoning_contents"]))
+    if message.reasoning:
+        bedrock_content_blocks.extend(_format_reasoning_content(reasoning_content=message.reasoning))
 
     for part in content_parts:
         if isinstance(part, TextContent):
@@ -259,6 +273,7 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
     :param model: The model ID used for generation, included in message metadata.
     :returns: List of ChatMessage objects containing the assistant's response(s) with appropriate metadata.
     """
+
     replies = []
     if "output" in response_body and "message" in response_body["output"]:
         message = response_body["output"]["message"]
@@ -266,10 +281,10 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
         content_blocks = message["content"]
 
         # Common meta information
-        base_meta = {
+        meta = {
             "model": model,
             "index": 0,
-            "finish_reason": response_body.get("stopReason"),
+            "finish_reason": FINISH_REASON_MAPPING.get(response_body.get("stopReason", "")),
             "usage": {
                 # OpenAI's format for usage for cross ChatGenerator compatibility
                 "prompt_tokens": response_body.get("usage", {}).get("inputTokens", 0),
@@ -277,6 +292,9 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
                 "total_tokens": response_body.get("usage", {}).get("totalTokens", 0),
             },
         }
+        # guardrail trace
+        if "trace" in response_body:
+            meta["trace"] = response_body["trace"]
 
         # Process all content blocks and combine them into a single message
         text_content = []
@@ -303,11 +321,26 @@ def _parse_completion_response(response_body: Dict[str, Any], model: str) -> Lis
                     reasoning_content["redacted_content"] = reasoning_content.pop("redactedContent")
                 reasoning_contents.append({"reasoning_content": reasoning_content})
 
-        # If reasoning contents were found, add them to the base meta
-        base_meta.update({"reasoning_contents": reasoning_contents})
+        reasoning_text = ""
+        for content in reasoning_contents:
+            if "reasoning_text" in content["reasoning_content"]:
+                reasoning_text += content["reasoning_content"]["reasoning_text"]["text"]
+            elif "redacted_content" in content["reasoning_content"]:
+                reasoning_text += "[REDACTED]"
 
         # Create a single ChatMessage with combined text and tool calls
-        replies.append(ChatMessage.from_assistant(" ".join(text_content), tool_calls=tool_calls, meta=base_meta))
+        replies.append(
+            ChatMessage.from_assistant(
+                " ".join(text_content),
+                tool_calls=tool_calls,
+                meta=meta,
+                reasoning=ReasoningContent(
+                    reasoning_text=reasoning_text, extra={"reasoning_contents": reasoning_contents}
+                )
+                if reasoning_contents
+                else None,
+            )
+        )
 
     return replies
 
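For reference, a sketch (illustrative, not from the package) of the reply that `_parse_completion_response` now builds when the response contains reasoning blocks and a guardrail trace:

```python
reply = _parse_completion_response(response_body, model=model_id)[0]  # response_body/model_id assumed

reply.text                      # all text blocks joined with spaces
reply.tool_calls                # ToolCall objects parsed from toolUse blocks
reply.reasoning.reasoning_text  # concatenated reasoning text; redacted blocks appear as "[REDACTED]"
reply.meta["finish_reason"]     # normalized via FINISH_REASON_MAPPING, e.g. "stop"
reply.meta["usage"]             # OpenAI-style prompt/completion/total token counts
reply.meta.get("trace")         # guardrail trace, present only if Bedrock returned one
```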
@@ -326,11 +359,11 @@ def _convert_event_to_streaming_chunk(
     :param component_info: ComponentInfo object
     :returns: StreamingChunk object containing the content and metadata extracted from the event.
     """
+
     # Initialize an empty StreamingChunk to return if no relevant event is found
     # (e.g. for messageStart and contentBlockStop)
-    streaming_chunk = StreamingChunk(
-        content="", meta={"model": model, "received_at": datetime.now(timezone.utc).isoformat()}
-    )
+    base_meta = {"model": model, "received_at": datetime.now(timezone.utc).isoformat()}
+    streaming_chunk = StreamingChunk(content="", meta=base_meta)
 
     if "contentBlockStart" in event:
         # contentBlockStart always has the key "contentBlockIndex"
@@ -340,26 +373,15 @@ def _convert_event_to_streaming_chunk(
             tool_start = block_start["start"]["toolUse"]
             streaming_chunk = StreamingChunk(
                 content="",
-                meta={
-                    "model": model,
-                    # This is always 0 b/c it represents the choice index
-                    "index": 0,
-                    # We follow the same format used in the OpenAIChatGenerator
-                    "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
-                        {
-                            "index": block_idx,  # int
-                            "id": tool_start["toolUseId"],  # Optional[str]
-                            "function": {  # Optional[ChoiceDeltaToolCallFunction]
-                                # Will accumulate deltas as string
-                                "arguments": "",  # Optional[str]
-                                "name": tool_start["name"],  # Optional[str]
-                            },
-                            "type": "function",  # Optional[Literal["function"]]
-                        }
-                    ],
-                    "finish_reason": None,
-                    "received_at": datetime.now(timezone.utc).isoformat(),
-                },
+                index=block_idx,
+                tool_calls=[
+                    ToolCallDelta(
+                        index=block_idx,
+                        id=tool_start["toolUseId"],
+                        tool_name=tool_start["name"],
+                    )
+                ],
+                meta=base_meta,
             )
 
     elif "contentBlockDelta" in event:
@@ -370,39 +392,22 @@ def _convert_event_to_streaming_chunk(
         if "text" in delta:
             streaming_chunk = StreamingChunk(
                 content=delta["text"],
-                meta={
-                    "model": model,
-                    # This is always 0 b/c it represents the choice index
-                    "index": 0,
-                    "tool_calls": None,
-                    "finish_reason": None,
-                    "received_at": datetime.now(timezone.utc).isoformat(),
-                },
+                index=block_idx,
+                meta=base_meta,
             )
         # This only occurs when accumulating the arguments for a toolUse
         # The content_block for this tool should already exist at this point
         elif "toolUse" in delta:
             streaming_chunk = StreamingChunk(
                 content="",
-                meta={
-                    "model": model,
-                    # This is always 0 b/c it represents the choice index
-                    "index": 0,
-                    "tool_calls": [  # Optional[List[ChoiceDeltaToolCall]]
-                        {
-                            "index": block_idx,  # int
-                            "id": None,  # Optional[str]
-                            "function": {  # Optional[ChoiceDeltaToolCallFunction]
-                                # Will accumulate deltas as string
-                                "arguments": delta["toolUse"].get("input", ""),  # Optional[str]
-                                "name": None,  # Optional[str]
-                            },
-                            "type": "function",  # Optional[Literal["function"]]
-                        }
-                    ],
-                    "finish_reason": None,
-                    "received_at": datetime.now(timezone.utc).isoformat(),
-                },
+                index=block_idx,
+                tool_calls=[
+                    ToolCallDelta(
+                        index=block_idx,
+                        arguments=delta["toolUse"].get("input", ""),
+                    )
+                ],
+                meta=base_meta,
             )
         # This is for accumulating reasoning content deltas
         elif "reasoningContent" in delta:
@@ -411,55 +416,45 @@ def _convert_event_to_streaming_chunk(
                 reasoning_content["redacted_content"] = reasoning_content.pop("redactedContent")
             streaming_chunk = StreamingChunk(
                 content="",
+                index=block_idx,
                 meta={
-                    "model": model,
-                    "index": 0,
-                    "tool_calls": None,
-                    "finish_reason": None,
-                    "received_at": datetime.now(timezone.utc).isoformat(),
+                    **base_meta,
                     "reasoning_contents": [{"index": block_idx, "reasoning_content": reasoning_content}],
                 },
             )
 
     elif "messageStop" in event:
-        finish_reason = event["messageStop"].get("stopReason")
+        finish_reason = FINISH_REASON_MAPPING.get(event["messageStop"].get("stopReason"))
         streaming_chunk = StreamingChunk(
             content="",
-            meta={
-                "model": model,
-                # This is always 0 b/c it represents the choice index
-                "index": 0,
-                "tool_calls": None,
-                "finish_reason": finish_reason,
-                "received_at": datetime.now(timezone.utc).isoformat(),
-            },
+            finish_reason=finish_reason,
+            meta=base_meta,
         )
 
-    elif "metadata" in event and "usage" in event["metadata"]:
-        metadata = event["metadata"]
-        streaming_chunk = StreamingChunk(
-            content="",
-            meta={
-                "model": model,
-                # This is always 0 b/c it represents the choice index
-                "index": 0,
-                "tool_calls": None,
-                "finish_reason": None,
-                "received_at": datetime.now(timezone.utc).isoformat(),
-                "usage": {
-                    "prompt_tokens": metadata["usage"].get("inputTokens", 0),
-                    "completion_tokens": metadata["usage"].get("outputTokens", 0),
-                    "total_tokens": metadata["usage"].get("totalTokens", 0),
-                },
-            },
-        )
+    elif "metadata" in event:
+        event_meta = event["metadata"]
+        chunk_meta: Dict[str, Any] = {**base_meta}
+
+        if "usage" in event_meta:
+            usage = event_meta["usage"]
+            chunk_meta["usage"] = {
+                "prompt_tokens": usage.get("inputTokens", 0),
+                "completion_tokens": usage.get("outputTokens", 0),
+                "total_tokens": usage.get("totalTokens", 0),
+            }
+        if "trace" in event_meta:
+            chunk_meta["trace"] = event_meta["trace"]
+
+        # Only create chunk if we added usage or trace data
+        if len(chunk_meta) > len(base_meta):
+            streaming_chunk = StreamingChunk(content="", meta=chunk_meta)
 
     streaming_chunk.component_info = component_info
 
     return streaming_chunk
 
 
-def _process_reasoning_contents(chunks: List[StreamingChunk]) -> List[Dict[str, Any]]:
+def _process_reasoning_contents(chunks: List[StreamingChunk]) -> Optional[ReasoningContent]:
     """
     Process reasoning contents from a list of StreamingChunk objects into the Bedrock expected format.
 
@@ -491,6 +486,8 @@ def _process_reasoning_contents(chunks: List[StreamingChunk]) -> List[Dict[str,
                 )
             if redacted_content:
                 formatted_reasoning_contents.append({"reasoning_content": {"redacted_content": redacted_content}})
+
+            # Reset accumulators for new group
             reasoning_text = ""
             reasoning_signature = None
             redacted_content = None
@@ -516,85 +513,22 @@ def _process_reasoning_contents(chunks: List[StreamingChunk]) -> List[Dict[str,
     if redacted_content:
         formatted_reasoning_contents.append({"reasoning_content": {"redacted_content": redacted_content}})
 
-    return formatted_reasoning_contents
-
-
-def _convert_streaming_chunks_to_chat_message(chunks: List[StreamingChunk]) -> ChatMessage:
-    """
-    Converts a list of streaming chunks into a ChatMessage object.
-
-    The function processes streaming chunks to build a ChatMessage object, including extracting and constructing
-    tool calls, managing metadata such as model type, finish reason, and usage information.
-    The tool call processing handles accumulating data across the chunks and attempts to parse JSON-formatted
-    arguments for tool calls.
-
-    :param chunks: A list of StreamingChunk objects representing parts of the assistant's response.
-
-    :returns:
-        A ChatMessage object constructed from the streaming chunks, containing the aggregated text, processed tool
-        calls, and metadata.
-    """
-    # Join all text content from the chunks
-    text = "".join([chunk.content for chunk in chunks])
-
-    # If reasoning content is present in any chunk, accumulate it
-    reasoning_contents = _process_reasoning_contents(chunks=chunks)
-
-    # Process tool calls if present in any chunk
-    tool_calls = []
-    tool_call_data: Dict[int, Dict[str, str]] = {}  # Track tool calls by index
-    for chunk_payload in chunks:
-        tool_calls_meta = chunk_payload.meta.get("tool_calls")
-        if tool_calls_meta is not None:
-            for delta in tool_calls_meta:
-                # We use the index of the tool call to track it across chunks since the ID is not always provided
-                if delta["index"] not in tool_call_data:
-                    tool_call_data[delta["index"]] = {"id": "", "name": "", "arguments": ""}
-
-                # Save the ID if present
-                if delta.get("id"):
-                    tool_call_data[delta["index"]]["id"] = delta["id"]
-
-                if delta.get("function"):
-                    if delta["function"].get("name"):
-                        tool_call_data[delta["index"]]["name"] += delta["function"]["name"]
-                    if delta["function"].get("arguments"):
-                        tool_call_data[delta["index"]]["arguments"] += delta["function"]["arguments"]
-
-    # Convert accumulated tool call data into ToolCall objects
-    for call_data in tool_call_data.values():
-        try:
-            arguments = json.loads(call_data.get("arguments", "{}")) if call_data.get("arguments") else {}
-            tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
-        except json.JSONDecodeError:
-            logger.warning(
-                "Amazon Bedrock returned a malformed JSON string for tool call arguments. This tool call will be "
-                "skipped. Tool call ID: {tool_id}, Tool name: {tool_name}, Arguments: {tool_arguments}",
-                tool_id=call_data["id"],
-                tool_name=call_data["name"],
-                tool_arguments=call_data["arguments"],
-            )
-
-    # finish_reason can appear in different places so we look for the last one
-    finish_reasons = [
-        chunk.meta.get("finish_reason") for chunk in chunks if chunk.meta.get("finish_reason") is not None
-    ]
-    finish_reason = finish_reasons[-1] if finish_reasons else None
-
-    # usage is usually last but we look for it as well
-    usages = [chunk.meta.get("usage") for chunk in chunks if chunk.meta.get("usage") is not None]
-    usage = usages[-1] if usages else None
-
-    meta = {
-        "model": chunks[-1].meta["model"],
-        "index": 0,
-        "finish_reason": finish_reason,
-        "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-        "usage": usage,
-        "reasoning_contents": reasoning_contents,
-    }
-
-    return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)
+    # Combine all reasoning texts into a single string for the main reasoning_text field
+    final_reasoning_text = ""
+    for content in formatted_reasoning_contents:
+        if "reasoning_text" in content["reasoning_content"]:
+            # mypy somehow thinks that content["reasoning_content"]["reasoning_text"]["text"] can be of type None
+            final_reasoning_text += content["reasoning_content"]["reasoning_text"]["text"]  # type: ignore[operator]
+        elif "redacted_content" in content["reasoning_content"]:
+            final_reasoning_text += "[REDACTED]"
+
+    return (
+        ReasoningContent(
+            reasoning_text=final_reasoning_text, extra={"reasoning_contents": formatted_reasoning_contents}
+        )
+        if formatted_reasoning_contents
+        else None
+    )
 
 
 def _parse_streaming_response(
@@ -612,13 +546,34 @@ def _parse_streaming_response(
     :param component_info: ComponentInfo object
     :return: List of ChatMessage objects
     """
+    content_block_idxs = set()
     chunks: List[StreamingChunk] = []
     for event in response_stream:
         streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model, component_info=component_info)
+        content_block_idx = streaming_chunk.index
+        if content_block_idx is not None and content_block_idx not in content_block_idxs:
+            streaming_chunk.start = True
+            content_block_idxs.add(content_block_idx)
         streaming_callback(streaming_chunk)
         chunks.append(streaming_chunk)
-    replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
-    return replies
+
+    reply = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+
+    # both the reasoning content and the trace are ignored in _convert_streaming_chunks_to_chat_message
+    # so we need to process them separately
+    reasoning_content = _process_reasoning_contents(chunks=chunks)
+    if chunks[-1].meta and "trace" in chunks[-1].meta:
+        reply.meta["trace"] = chunks[-1].meta["trace"]
+
+    reply = ChatMessage.from_assistant(
+        text=reply.text,
+        meta=reply.meta,
+        name=reply.name,
+        tool_calls=reply.tool_calls,
+        reasoning=reasoning_content,
+    )
+
+    return [reply]
 
 
 async def _parse_streaming_response_async(
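
With the per-block `index` now set on each chunk, `_parse_streaming_response` marks `chunk.start` the first time a content-block index appears. A sketch of a callback that uses this (not from the package; the model ID is a placeholder):

```python
from haystack.dataclasses import StreamingChunk

def print_chunk(chunk: StreamingChunk) -> None:
    if chunk.start:  # first chunk of a new Bedrock content block
        print(f"\n--- content block {chunk.index} ---")
    print(chunk.content, end="", flush=True)

generator = AmazonBedrockChatGenerator(
    model="anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder model ID
    streaming_callback=print_chunk,
)
```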
@@ -636,10 +591,44 @@ async def _parse_streaming_response_async(
     :param component_info: ComponentInfo object
     :return: List of ChatMessage objects
     """
+    content_block_idxs = set()
     chunks: List[StreamingChunk] = []
     async for event in response_stream:
         streaming_chunk = _convert_event_to_streaming_chunk(event=event, model=model, component_info=component_info)
+        content_block_idx = streaming_chunk.index
+        if content_block_idx is not None and content_block_idx not in content_block_idxs:
+            streaming_chunk.start = True
+            content_block_idxs.add(content_block_idx)
         await streaming_callback(streaming_chunk)
         chunks.append(streaming_chunk)
-    replies = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
-    return replies
+    reply = _convert_streaming_chunks_to_chat_message(chunks=chunks)
+    reasoning_content = _process_reasoning_contents(chunks=chunks)
+    reply = ChatMessage.from_assistant(
+        text=reply.text,
+        meta=reply.meta,
+        name=reply.name,
+        tool_calls=reply.tool_calls,
+        reasoning=reasoning_content,
+    )
+    return [reply]
+
+
+def _validate_guardrail_config(guardrail_config: Optional[Dict[str, str]] = None, streaming: bool = False) -> None:
+    """
+    Validate the guardrail configuration.
+
+    :param guardrail_config: The guardrail configuration.
+    :param streaming: Whether streaming is enabled.
+
+    :raises ValueError: If the guardrail configuration is invalid.
+    """
+    if guardrail_config is None:
+        return
+
+    required_fields = {"guardrailIdentifier", "guardrailVersion"}
+    if not required_fields.issubset(guardrail_config):
+        msg = "`guardrailIdentifier` and `guardrailVersion` fields are required in guardrail configuration."
+        raise ValueError(msg)
+    if not streaming and "streamProcessingMode" in guardrail_config:
+        msg = "`streamProcessingMode` field is only supported for streaming (when `streaming_callback` is not None)."
+        raise ValueError(msg)
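
A sketch of the validation behavior, using placeholder identifiers (`"sync"` and `"async"` are the `streamProcessingMode` values documented by AWS):

```python
ok = {"guardrailIdentifier": "my-guardrail-id", "guardrailVersion": "1"}  # placeholders
_validate_guardrail_config(ok)  # passes: both required fields present

try:
    _validate_guardrail_config({"guardrailIdentifier": "my-guardrail-id"})
except ValueError as err:
    print(err)  # guardrailVersion is missing

try:
    _validate_guardrail_config({**ok, "streamProcessingMode": "async"}, streaming=False)
except ValueError as err:
    print(err)  # streamProcessingMode is only valid when streaming_callback is set
```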