agentrun-sdk 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agentrun-sdk might be problematic. Click here for more details.
- agentrun_operation_sdk/cli/__init__.py +1 -0
- agentrun_operation_sdk/cli/cli.py +19 -0
- agentrun_operation_sdk/cli/common.py +21 -0
- agentrun_operation_sdk/cli/runtime/__init__.py +1 -0
- agentrun_operation_sdk/cli/runtime/commands.py +203 -0
- agentrun_operation_sdk/client/client.py +75 -0
- agentrun_operation_sdk/operations/runtime/__init__.py +8 -0
- agentrun_operation_sdk/operations/runtime/configure.py +101 -0
- agentrun_operation_sdk/operations/runtime/launch.py +82 -0
- agentrun_operation_sdk/operations/runtime/models.py +31 -0
- agentrun_operation_sdk/services/runtime.py +152 -0
- agentrun_operation_sdk/utils/logging_config.py +72 -0
- agentrun_operation_sdk/utils/runtime/config.py +94 -0
- agentrun_operation_sdk/utils/runtime/container.py +280 -0
- agentrun_operation_sdk/utils/runtime/entrypoint.py +203 -0
- agentrun_operation_sdk/utils/runtime/schema.py +56 -0
- agentrun_sdk/__init__.py +7 -0
- agentrun_sdk/agent/__init__.py +25 -0
- agentrun_sdk/agent/agent.py +696 -0
- agentrun_sdk/agent/agent_result.py +46 -0
- agentrun_sdk/agent/conversation_manager/__init__.py +26 -0
- agentrun_sdk/agent/conversation_manager/conversation_manager.py +88 -0
- agentrun_sdk/agent/conversation_manager/null_conversation_manager.py +46 -0
- agentrun_sdk/agent/conversation_manager/sliding_window_conversation_manager.py +179 -0
- agentrun_sdk/agent/conversation_manager/summarizing_conversation_manager.py +252 -0
- agentrun_sdk/agent/state.py +97 -0
- agentrun_sdk/event_loop/__init__.py +9 -0
- agentrun_sdk/event_loop/event_loop.py +499 -0
- agentrun_sdk/event_loop/streaming.py +319 -0
- agentrun_sdk/experimental/__init__.py +4 -0
- agentrun_sdk/experimental/hooks/__init__.py +15 -0
- agentrun_sdk/experimental/hooks/events.py +123 -0
- agentrun_sdk/handlers/__init__.py +10 -0
- agentrun_sdk/handlers/callback_handler.py +70 -0
- agentrun_sdk/hooks/__init__.py +49 -0
- agentrun_sdk/hooks/events.py +80 -0
- agentrun_sdk/hooks/registry.py +247 -0
- agentrun_sdk/models/__init__.py +10 -0
- agentrun_sdk/models/anthropic.py +432 -0
- agentrun_sdk/models/bedrock.py +649 -0
- agentrun_sdk/models/litellm.py +225 -0
- agentrun_sdk/models/llamaapi.py +438 -0
- agentrun_sdk/models/mistral.py +539 -0
- agentrun_sdk/models/model.py +95 -0
- agentrun_sdk/models/ollama.py +357 -0
- agentrun_sdk/models/openai.py +436 -0
- agentrun_sdk/models/sagemaker.py +598 -0
- agentrun_sdk/models/writer.py +449 -0
- agentrun_sdk/multiagent/__init__.py +22 -0
- agentrun_sdk/multiagent/a2a/__init__.py +15 -0
- agentrun_sdk/multiagent/a2a/executor.py +148 -0
- agentrun_sdk/multiagent/a2a/server.py +252 -0
- agentrun_sdk/multiagent/base.py +92 -0
- agentrun_sdk/multiagent/graph.py +555 -0
- agentrun_sdk/multiagent/swarm.py +656 -0
- agentrun_sdk/py.typed +1 -0
- agentrun_sdk/session/__init__.py +18 -0
- agentrun_sdk/session/file_session_manager.py +216 -0
- agentrun_sdk/session/repository_session_manager.py +152 -0
- agentrun_sdk/session/s3_session_manager.py +272 -0
- agentrun_sdk/session/session_manager.py +73 -0
- agentrun_sdk/session/session_repository.py +51 -0
- agentrun_sdk/telemetry/__init__.py +21 -0
- agentrun_sdk/telemetry/config.py +194 -0
- agentrun_sdk/telemetry/metrics.py +476 -0
- agentrun_sdk/telemetry/metrics_constants.py +15 -0
- agentrun_sdk/telemetry/tracer.py +563 -0
- agentrun_sdk/tools/__init__.py +17 -0
- agentrun_sdk/tools/decorator.py +569 -0
- agentrun_sdk/tools/executor.py +137 -0
- agentrun_sdk/tools/loader.py +152 -0
- agentrun_sdk/tools/mcp/__init__.py +13 -0
- agentrun_sdk/tools/mcp/mcp_agent_tool.py +99 -0
- agentrun_sdk/tools/mcp/mcp_client.py +423 -0
- agentrun_sdk/tools/mcp/mcp_instrumentation.py +322 -0
- agentrun_sdk/tools/mcp/mcp_types.py +63 -0
- agentrun_sdk/tools/registry.py +607 -0
- agentrun_sdk/tools/structured_output.py +421 -0
- agentrun_sdk/tools/tools.py +217 -0
- agentrun_sdk/tools/watcher.py +136 -0
- agentrun_sdk/types/__init__.py +5 -0
- agentrun_sdk/types/collections.py +23 -0
- agentrun_sdk/types/content.py +188 -0
- agentrun_sdk/types/event_loop.py +48 -0
- agentrun_sdk/types/exceptions.py +81 -0
- agentrun_sdk/types/guardrails.py +254 -0
- agentrun_sdk/types/media.py +89 -0
- agentrun_sdk/types/session.py +152 -0
- agentrun_sdk/types/streaming.py +201 -0
- agentrun_sdk/types/tools.py +258 -0
- agentrun_sdk/types/traces.py +5 -0
- agentrun_sdk-0.1.2.dist-info/METADATA +51 -0
- agentrun_sdk-0.1.2.dist-info/RECORD +115 -0
- agentrun_sdk-0.1.2.dist-info/WHEEL +5 -0
- agentrun_sdk-0.1.2.dist-info/entry_points.txt +2 -0
- agentrun_sdk-0.1.2.dist-info/top_level.txt +3 -0
- agentrun_wrapper/__init__.py +11 -0
- agentrun_wrapper/_utils/__init__.py +6 -0
- agentrun_wrapper/_utils/endpoints.py +16 -0
- agentrun_wrapper/identity/__init__.py +5 -0
- agentrun_wrapper/identity/auth.py +211 -0
- agentrun_wrapper/memory/__init__.py +6 -0
- agentrun_wrapper/memory/client.py +1697 -0
- agentrun_wrapper/memory/constants.py +103 -0
- agentrun_wrapper/memory/controlplane.py +626 -0
- agentrun_wrapper/py.typed +1 -0
- agentrun_wrapper/runtime/__init__.py +13 -0
- agentrun_wrapper/runtime/app.py +473 -0
- agentrun_wrapper/runtime/context.py +34 -0
- agentrun_wrapper/runtime/models.py +25 -0
- agentrun_wrapper/services/__init__.py +1 -0
- agentrun_wrapper/services/identity.py +192 -0
- agentrun_wrapper/tools/__init__.py +6 -0
- agentrun_wrapper/tools/browser_client.py +325 -0
- agentrun_wrapper/tools/code_interpreter_client.py +186 -0
|
@@ -0,0 +1,649 @@
|
|
|
1
|
+
"""AWS Bedrock model provider.
|
|
2
|
+
|
|
3
|
+
- Docs: https://aws.amazon.com/bedrock/
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any, AsyncGenerator, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
|
11
|
+
|
|
12
|
+
import boto3
|
|
13
|
+
from botocore.config import Config as BotocoreConfig
|
|
14
|
+
from botocore.exceptions import ClientError
|
|
15
|
+
from pydantic import BaseModel
|
|
16
|
+
from typing_extensions import TypedDict, Unpack, override
|
|
17
|
+
|
|
18
|
+
from ..event_loop import streaming
|
|
19
|
+
from ..tools import convert_pydantic_to_tool_spec
|
|
20
|
+
from ..types.content import ContentBlock, Message, Messages
|
|
21
|
+
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
|
|
22
|
+
from ..types.streaming import StreamEvent
|
|
23
|
+
from ..types.tools import ToolResult, ToolSpec
|
|
24
|
+
from .model import Model
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)

# Defaults applied when the caller does not configure a model or region.
DEFAULT_BEDROCK_MODEL_ID = "us.anthropic.claude-sonnet-4-20250514-v1:0"
DEFAULT_BEDROCK_REGION = "us-west-2"

# Substrings of Bedrock error messages that indicate the request exceeded the
# model's context window; matched against ClientError text in _stream to raise
# ContextWindowOverflowException instead of a raw botocore error.
BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
    "Input is too long for requested model",
    "input length and `max_tokens` exceed context limit",
    "too many total text bytes",
]
|
|
36
|
+
|
|
37
|
+
T = TypeVar("T", bound=BaseModel)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BedrockModel(Model):
    """AWS Bedrock model provider implementation.

    Layers Bedrock-specific behavior on top of the generic ``Model`` interface:

    - Tool configuration for function calling
    - Guardrails integration
    - Caching points for system prompts and tools
    - Streaming responses
    - Context window overflow detection
    """

    class BedrockConfig(TypedDict, total=False):
        """Configuration options for Bedrock models.

        Attributes:
            additional_args: Any additional arguments to include in the request
            additional_request_fields: Additional fields to include in the Bedrock request
            additional_response_field_paths: Additional response field paths to extract
            cache_prompt: Cache point type for the system prompt
            cache_tools: Cache point type for tools
            guardrail_id: ID of the guardrail to apply
            guardrail_trace: Guardrail trace mode. Defaults to enabled.
            guardrail_version: Version of the guardrail to apply
            guardrail_stream_processing_mode: The guardrail processing mode
            guardrail_redact_input: Flag to redact input if a guardrail is triggered. Defaults to True.
            guardrail_redact_input_message: If a Bedrock Input guardrail triggers, replace the input with this message.
            guardrail_redact_output: Flag to redact output if guardrail is triggered. Defaults to False.
            guardrail_redact_output_message: If a Bedrock Output guardrail triggers, replace output with this message.
            max_tokens: Maximum number of tokens to generate in the response
            model_id: The Bedrock model ID (e.g., "us.anthropic.claude-sonnet-4-20250514-v1:0")
            stop_sequences: List of sequences that will stop generation when encountered
            streaming: Flag to enable/disable streaming. Defaults to True.
            temperature: Controls randomness in generation (higher = more random)
            top_p: Controls diversity via nucleus sampling (alternative to temperature)
        """

        additional_args: Optional[dict[str, Any]]
        additional_request_fields: Optional[dict[str, Any]]
        additional_response_field_paths: Optional[list[str]]
        cache_prompt: Optional[str]
        cache_tools: Optional[str]
        guardrail_id: Optional[str]
        guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
        guardrail_stream_processing_mode: Optional[Literal["sync", "async"]]
        guardrail_version: Optional[str]
        guardrail_redact_input: Optional[bool]
        guardrail_redact_input_message: Optional[str]
        guardrail_redact_output: Optional[bool]
        guardrail_redact_output_message: Optional[str]
        max_tokens: Optional[int]
        model_id: str
        stop_sequences: Optional[list[str]]
        streaming: Optional[bool]
        temperature: Optional[float]
        top_p: Optional[float]
|
|
96
|
+
|
|
97
|
+
def __init__(
|
|
98
|
+
self,
|
|
99
|
+
*,
|
|
100
|
+
boto_session: Optional[boto3.Session] = None,
|
|
101
|
+
boto_client_config: Optional[BotocoreConfig] = None,
|
|
102
|
+
region_name: Optional[str] = None,
|
|
103
|
+
**model_config: Unpack[BedrockConfig],
|
|
104
|
+
):
|
|
105
|
+
"""Initialize provider instance.
|
|
106
|
+
|
|
107
|
+
Args:
|
|
108
|
+
boto_session: Boto Session to use when calling the Bedrock Model.
|
|
109
|
+
boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client.
|
|
110
|
+
region_name: AWS region to use for the Bedrock service.
|
|
111
|
+
Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set.
|
|
112
|
+
**model_config: Configuration options for the Bedrock model.
|
|
113
|
+
"""
|
|
114
|
+
if region_name and boto_session:
|
|
115
|
+
raise ValueError("Cannot specify both `region_name` and `boto_session`.")
|
|
116
|
+
|
|
117
|
+
self.config = BedrockModel.BedrockConfig(model_id=DEFAULT_BEDROCK_MODEL_ID)
|
|
118
|
+
self.update_config(**model_config)
|
|
119
|
+
|
|
120
|
+
logger.debug("config=<%s> | initializing", self.config)
|
|
121
|
+
|
|
122
|
+
session = boto_session or boto3.Session()
|
|
123
|
+
|
|
124
|
+
# Add strands-agents to the request user agent
|
|
125
|
+
if boto_client_config:
|
|
126
|
+
existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
|
|
127
|
+
|
|
128
|
+
# Append 'strands-agents' to existing user_agent_extra or set it if not present
|
|
129
|
+
if existing_user_agent:
|
|
130
|
+
new_user_agent = f"{existing_user_agent} strands-agents"
|
|
131
|
+
else:
|
|
132
|
+
new_user_agent = "strands-agents"
|
|
133
|
+
|
|
134
|
+
client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
|
|
135
|
+
else:
|
|
136
|
+
client_config = BotocoreConfig(user_agent_extra="strands-agents")
|
|
137
|
+
|
|
138
|
+
resolved_region = region_name or session.region_name or os.environ.get("AWS_REGION") or DEFAULT_BEDROCK_REGION
|
|
139
|
+
|
|
140
|
+
self.client = session.client(
|
|
141
|
+
service_name="bedrock-runtime",
|
|
142
|
+
config=client_config,
|
|
143
|
+
region_name=resolved_region,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
logger.debug("region=<%s> | bedrock client created", self.client.meta.region_name)
|
|
147
|
+
|
|
148
|
+
@override
|
|
149
|
+
def update_config(self, **model_config: Unpack[BedrockConfig]) -> None: # type: ignore
|
|
150
|
+
"""Update the Bedrock Model configuration with the provided arguments.
|
|
151
|
+
|
|
152
|
+
Args:
|
|
153
|
+
**model_config: Configuration overrides.
|
|
154
|
+
"""
|
|
155
|
+
self.config.update(model_config)
|
|
156
|
+
|
|
157
|
+
@override
|
|
158
|
+
def get_config(self) -> BedrockConfig:
|
|
159
|
+
"""Get the current Bedrock Model configuration.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
The Bedrock model configuration.
|
|
163
|
+
"""
|
|
164
|
+
return self.config
|
|
165
|
+
|
|
166
|
+
def format_request(
|
|
167
|
+
self,
|
|
168
|
+
messages: Messages,
|
|
169
|
+
tool_specs: Optional[list[ToolSpec]] = None,
|
|
170
|
+
system_prompt: Optional[str] = None,
|
|
171
|
+
) -> dict[str, Any]:
|
|
172
|
+
"""Format a Bedrock converse stream request.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
messages: List of message objects to be processed by the model.
|
|
176
|
+
tool_specs: List of tool specifications to make available to the model.
|
|
177
|
+
system_prompt: System prompt to provide context to the model.
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
A Bedrock converse stream request.
|
|
181
|
+
"""
|
|
182
|
+
return {
|
|
183
|
+
"modelId": self.config["model_id"],
|
|
184
|
+
"messages": self._format_bedrock_messages(messages),
|
|
185
|
+
"system": [
|
|
186
|
+
*([{"text": system_prompt}] if system_prompt else []),
|
|
187
|
+
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
|
|
188
|
+
],
|
|
189
|
+
**(
|
|
190
|
+
{
|
|
191
|
+
"toolConfig": {
|
|
192
|
+
"tools": [
|
|
193
|
+
*[{"toolSpec": tool_spec} for tool_spec in tool_specs],
|
|
194
|
+
*(
|
|
195
|
+
[{"cachePoint": {"type": self.config["cache_tools"]}}]
|
|
196
|
+
if self.config.get("cache_tools")
|
|
197
|
+
else []
|
|
198
|
+
),
|
|
199
|
+
],
|
|
200
|
+
"toolChoice": {"auto": {}},
|
|
201
|
+
}
|
|
202
|
+
}
|
|
203
|
+
if tool_specs
|
|
204
|
+
else {}
|
|
205
|
+
),
|
|
206
|
+
**(
|
|
207
|
+
{"additionalModelRequestFields": self.config["additional_request_fields"]}
|
|
208
|
+
if self.config.get("additional_request_fields")
|
|
209
|
+
else {}
|
|
210
|
+
),
|
|
211
|
+
**(
|
|
212
|
+
{"additionalModelResponseFieldPaths": self.config["additional_response_field_paths"]}
|
|
213
|
+
if self.config.get("additional_response_field_paths")
|
|
214
|
+
else {}
|
|
215
|
+
),
|
|
216
|
+
**(
|
|
217
|
+
{
|
|
218
|
+
"guardrailConfig": {
|
|
219
|
+
"guardrailIdentifier": self.config["guardrail_id"],
|
|
220
|
+
"guardrailVersion": self.config["guardrail_version"],
|
|
221
|
+
"trace": self.config.get("guardrail_trace", "enabled"),
|
|
222
|
+
**(
|
|
223
|
+
{"streamProcessingMode": self.config.get("guardrail_stream_processing_mode")}
|
|
224
|
+
if self.config.get("guardrail_stream_processing_mode")
|
|
225
|
+
else {}
|
|
226
|
+
),
|
|
227
|
+
}
|
|
228
|
+
}
|
|
229
|
+
if self.config.get("guardrail_id") and self.config.get("guardrail_version")
|
|
230
|
+
else {}
|
|
231
|
+
),
|
|
232
|
+
"inferenceConfig": {
|
|
233
|
+
key: value
|
|
234
|
+
for key, value in [
|
|
235
|
+
("maxTokens", self.config.get("max_tokens")),
|
|
236
|
+
("temperature", self.config.get("temperature")),
|
|
237
|
+
("topP", self.config.get("top_p")),
|
|
238
|
+
("stopSequences", self.config.get("stop_sequences")),
|
|
239
|
+
]
|
|
240
|
+
if value is not None
|
|
241
|
+
},
|
|
242
|
+
**(
|
|
243
|
+
self.config["additional_args"]
|
|
244
|
+
if "additional_args" in self.config and self.config["additional_args"] is not None
|
|
245
|
+
else {}
|
|
246
|
+
),
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
def _format_bedrock_messages(self, messages: Messages) -> Messages:
|
|
250
|
+
"""Format messages for Bedrock API compatibility.
|
|
251
|
+
|
|
252
|
+
This function ensures messages conform to Bedrock's expected format by:
|
|
253
|
+
- Cleaning tool result content blocks by removing additional fields that may be
|
|
254
|
+
useful for retaining information in hooks but would cause Bedrock validation
|
|
255
|
+
exceptions when presented with unexpected fields
|
|
256
|
+
- Ensuring all message content blocks are properly formatted for the Bedrock API
|
|
257
|
+
|
|
258
|
+
Args:
|
|
259
|
+
messages: List of messages to format
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
Messages formatted for Bedrock API compatibility
|
|
263
|
+
|
|
264
|
+
Note:
|
|
265
|
+
Bedrock will throw validation exceptions when presented with additional
|
|
266
|
+
unexpected fields in tool result blocks.
|
|
267
|
+
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
|
|
268
|
+
"""
|
|
269
|
+
cleaned_messages = []
|
|
270
|
+
|
|
271
|
+
for message in messages:
|
|
272
|
+
cleaned_content: list[ContentBlock] = []
|
|
273
|
+
|
|
274
|
+
for content_block in message["content"]:
|
|
275
|
+
if "toolResult" in content_block:
|
|
276
|
+
# Create a new content block with only the cleaned toolResult
|
|
277
|
+
tool_result: ToolResult = content_block["toolResult"]
|
|
278
|
+
|
|
279
|
+
# Keep only the required fields for Bedrock
|
|
280
|
+
cleaned_tool_result = ToolResult(
|
|
281
|
+
content=tool_result["content"], toolUseId=tool_result["toolUseId"], status=tool_result["status"]
|
|
282
|
+
)
|
|
283
|
+
|
|
284
|
+
cleaned_block: ContentBlock = {"toolResult": cleaned_tool_result}
|
|
285
|
+
cleaned_content.append(cleaned_block)
|
|
286
|
+
else:
|
|
287
|
+
# Keep other content blocks as-is
|
|
288
|
+
cleaned_content.append(content_block)
|
|
289
|
+
|
|
290
|
+
# Create new message with cleaned content
|
|
291
|
+
cleaned_message: Message = Message(content=cleaned_content, role=message["role"])
|
|
292
|
+
cleaned_messages.append(cleaned_message)
|
|
293
|
+
|
|
294
|
+
return cleaned_messages
|
|
295
|
+
|
|
296
|
+
def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
|
|
297
|
+
"""Check if guardrail data contains any blocked policies.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
guardrail_data: Guardrail data from trace information.
|
|
301
|
+
|
|
302
|
+
Returns:
|
|
303
|
+
True if any blocked guardrail is detected, False otherwise.
|
|
304
|
+
"""
|
|
305
|
+
input_assessment = guardrail_data.get("inputAssessment", {})
|
|
306
|
+
output_assessments = guardrail_data.get("outputAssessments", {})
|
|
307
|
+
|
|
308
|
+
# Check input assessments
|
|
309
|
+
if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
|
|
310
|
+
return True
|
|
311
|
+
|
|
312
|
+
# Check output assessments
|
|
313
|
+
if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
|
|
314
|
+
return True
|
|
315
|
+
|
|
316
|
+
return False
|
|
317
|
+
|
|
318
|
+
def _generate_redaction_events(self) -> list[StreamEvent]:
|
|
319
|
+
"""Generate redaction events based on configuration.
|
|
320
|
+
|
|
321
|
+
Returns:
|
|
322
|
+
List of redaction events to yield.
|
|
323
|
+
"""
|
|
324
|
+
events: list[StreamEvent] = []
|
|
325
|
+
|
|
326
|
+
if self.config.get("guardrail_redact_input", True):
|
|
327
|
+
logger.debug("Redacting user input due to guardrail.")
|
|
328
|
+
events.append(
|
|
329
|
+
{
|
|
330
|
+
"redactContent": {
|
|
331
|
+
"redactUserContentMessage": self.config.get(
|
|
332
|
+
"guardrail_redact_input_message", "[User input redacted.]"
|
|
333
|
+
)
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
if self.config.get("guardrail_redact_output", False):
|
|
339
|
+
logger.debug("Redacting assistant output due to guardrail.")
|
|
340
|
+
events.append(
|
|
341
|
+
{
|
|
342
|
+
"redactContent": {
|
|
343
|
+
"redactAssistantContentMessage": self.config.get(
|
|
344
|
+
"guardrail_redact_output_message", "[Assistant output redacted.]"
|
|
345
|
+
)
|
|
346
|
+
}
|
|
347
|
+
}
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
return events
|
|
351
|
+
|
|
352
|
+
@override
|
|
353
|
+
async def stream(
|
|
354
|
+
self,
|
|
355
|
+
messages: Messages,
|
|
356
|
+
tool_specs: Optional[list[ToolSpec]] = None,
|
|
357
|
+
system_prompt: Optional[str] = None,
|
|
358
|
+
**kwargs: Any,
|
|
359
|
+
) -> AsyncGenerator[StreamEvent, None]:
|
|
360
|
+
"""Stream conversation with the Bedrock model.
|
|
361
|
+
|
|
362
|
+
This method calls either the Bedrock converse_stream API or the converse API
|
|
363
|
+
based on the streaming parameter in the configuration.
|
|
364
|
+
|
|
365
|
+
Args:
|
|
366
|
+
messages: List of message objects to be processed by the model.
|
|
367
|
+
tool_specs: List of tool specifications to make available to the model.
|
|
368
|
+
system_prompt: System prompt to provide context to the model.
|
|
369
|
+
**kwargs: Additional keyword arguments for future extensibility.
|
|
370
|
+
|
|
371
|
+
Yields:
|
|
372
|
+
Model events.
|
|
373
|
+
|
|
374
|
+
Raises:
|
|
375
|
+
ContextWindowOverflowException: If the input exceeds the model's context window.
|
|
376
|
+
ModelThrottledException: If the model service is throttling requests.
|
|
377
|
+
"""
|
|
378
|
+
|
|
379
|
+
def callback(event: Optional[StreamEvent] = None) -> None:
|
|
380
|
+
loop.call_soon_threadsafe(queue.put_nowait, event)
|
|
381
|
+
if event is None:
|
|
382
|
+
return
|
|
383
|
+
|
|
384
|
+
loop = asyncio.get_event_loop()
|
|
385
|
+
queue: asyncio.Queue[Optional[StreamEvent]] = asyncio.Queue()
|
|
386
|
+
|
|
387
|
+
thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt)
|
|
388
|
+
task = asyncio.create_task(thread)
|
|
389
|
+
|
|
390
|
+
while True:
|
|
391
|
+
event = await queue.get()
|
|
392
|
+
if event is None:
|
|
393
|
+
break
|
|
394
|
+
|
|
395
|
+
yield event
|
|
396
|
+
|
|
397
|
+
await task
|
|
398
|
+
|
|
399
|
+
def _stream(
|
|
400
|
+
self,
|
|
401
|
+
callback: Callable[..., None],
|
|
402
|
+
messages: Messages,
|
|
403
|
+
tool_specs: Optional[list[ToolSpec]] = None,
|
|
404
|
+
system_prompt: Optional[str] = None,
|
|
405
|
+
) -> None:
|
|
406
|
+
"""Stream conversation with the Bedrock model.
|
|
407
|
+
|
|
408
|
+
This method operates in a separate thread to avoid blocking the async event loop with the call to
|
|
409
|
+
Bedrock's converse_stream.
|
|
410
|
+
|
|
411
|
+
Args:
|
|
412
|
+
callback: Function to send events to the main thread.
|
|
413
|
+
messages: List of message objects to be processed by the model.
|
|
414
|
+
tool_specs: List of tool specifications to make available to the model.
|
|
415
|
+
system_prompt: System prompt to provide context to the model.
|
|
416
|
+
|
|
417
|
+
Raises:
|
|
418
|
+
ContextWindowOverflowException: If the input exceeds the model's context window.
|
|
419
|
+
ModelThrottledException: If the model service is throttling requests.
|
|
420
|
+
"""
|
|
421
|
+
logger.debug("formatting request")
|
|
422
|
+
request = self.format_request(messages, tool_specs, system_prompt)
|
|
423
|
+
logger.debug("request=<%s>", request)
|
|
424
|
+
|
|
425
|
+
logger.debug("invoking model")
|
|
426
|
+
streaming = self.config.get("streaming", True)
|
|
427
|
+
|
|
428
|
+
try:
|
|
429
|
+
logger.debug("got response from model")
|
|
430
|
+
if streaming:
|
|
431
|
+
response = self.client.converse_stream(**request)
|
|
432
|
+
for chunk in response["stream"]:
|
|
433
|
+
if (
|
|
434
|
+
"metadata" in chunk
|
|
435
|
+
and "trace" in chunk["metadata"]
|
|
436
|
+
and "guardrail" in chunk["metadata"]["trace"]
|
|
437
|
+
):
|
|
438
|
+
guardrail_data = chunk["metadata"]["trace"]["guardrail"]
|
|
439
|
+
if self._has_blocked_guardrail(guardrail_data):
|
|
440
|
+
for event in self._generate_redaction_events():
|
|
441
|
+
callback(event)
|
|
442
|
+
|
|
443
|
+
callback(chunk)
|
|
444
|
+
|
|
445
|
+
else:
|
|
446
|
+
response = self.client.converse(**request)
|
|
447
|
+
for event in self._convert_non_streaming_to_streaming(response):
|
|
448
|
+
callback(event)
|
|
449
|
+
|
|
450
|
+
if (
|
|
451
|
+
"trace" in response
|
|
452
|
+
and "guardrail" in response["trace"]
|
|
453
|
+
and self._has_blocked_guardrail(response["trace"]["guardrail"])
|
|
454
|
+
):
|
|
455
|
+
for event in self._generate_redaction_events():
|
|
456
|
+
callback(event)
|
|
457
|
+
|
|
458
|
+
except ClientError as e:
|
|
459
|
+
error_message = str(e)
|
|
460
|
+
|
|
461
|
+
if e.response["Error"]["Code"] == "ThrottlingException":
|
|
462
|
+
raise ModelThrottledException(error_message) from e
|
|
463
|
+
|
|
464
|
+
if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
|
|
465
|
+
logger.warning("bedrock threw context window overflow error")
|
|
466
|
+
raise ContextWindowOverflowException(e) from e
|
|
467
|
+
|
|
468
|
+
region = self.client.meta.region_name
|
|
469
|
+
|
|
470
|
+
# add_note added in Python 3.11
|
|
471
|
+
if hasattr(e, "add_note"):
|
|
472
|
+
# Aid in debugging by adding more information
|
|
473
|
+
e.add_note(f"└ Bedrock region: {region}")
|
|
474
|
+
e.add_note(f"└ Model id: {self.config.get('model_id')}")
|
|
475
|
+
|
|
476
|
+
if (
|
|
477
|
+
e.response["Error"]["Code"] == "AccessDeniedException"
|
|
478
|
+
and "You don't have access to the model" in error_message
|
|
479
|
+
):
|
|
480
|
+
e.add_note(
|
|
481
|
+
"└ For more information see "
|
|
482
|
+
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#model-access-issue"
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
if (
|
|
486
|
+
e.response["Error"]["Code"] == "ValidationException"
|
|
487
|
+
and "with on-demand throughput isn’t supported" in error_message
|
|
488
|
+
):
|
|
489
|
+
e.add_note(
|
|
490
|
+
"└ For more information see "
|
|
491
|
+
"https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/#on-demand-throughput-isnt-supported"
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
raise e
|
|
495
|
+
|
|
496
|
+
finally:
|
|
497
|
+
callback()
|
|
498
|
+
logger.debug("finished streaming response from model")
|
|
499
|
+
|
|
500
|
+
def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
|
|
501
|
+
"""Convert a non-streaming response to the streaming format.
|
|
502
|
+
|
|
503
|
+
Args:
|
|
504
|
+
response: The non-streaming response from the Bedrock model.
|
|
505
|
+
|
|
506
|
+
Returns:
|
|
507
|
+
An iterable of response events in the streaming format.
|
|
508
|
+
"""
|
|
509
|
+
# Yield messageStart event
|
|
510
|
+
yield {"messageStart": {"role": response["output"]["message"]["role"]}}
|
|
511
|
+
|
|
512
|
+
# Process content blocks
|
|
513
|
+
for content in response["output"]["message"]["content"]:
|
|
514
|
+
# Yield contentBlockStart event if needed
|
|
515
|
+
if "toolUse" in content:
|
|
516
|
+
yield {
|
|
517
|
+
"contentBlockStart": {
|
|
518
|
+
"start": {
|
|
519
|
+
"toolUse": {
|
|
520
|
+
"toolUseId": content["toolUse"]["toolUseId"],
|
|
521
|
+
"name": content["toolUse"]["name"],
|
|
522
|
+
}
|
|
523
|
+
},
|
|
524
|
+
}
|
|
525
|
+
}
|
|
526
|
+
|
|
527
|
+
# For tool use, we need to yield the input as a delta
|
|
528
|
+
input_value = json.dumps(content["toolUse"]["input"])
|
|
529
|
+
|
|
530
|
+
yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
|
|
531
|
+
elif "text" in content:
|
|
532
|
+
# Then yield the text as a delta
|
|
533
|
+
yield {
|
|
534
|
+
"contentBlockDelta": {
|
|
535
|
+
"delta": {"text": content["text"]},
|
|
536
|
+
}
|
|
537
|
+
}
|
|
538
|
+
elif "reasoningContent" in content:
|
|
539
|
+
# Then yield the reasoning content as a delta
|
|
540
|
+
yield {
|
|
541
|
+
"contentBlockDelta": {
|
|
542
|
+
"delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
|
|
543
|
+
}
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
if "signature" in content["reasoningContent"]["reasoningText"]:
|
|
547
|
+
yield {
|
|
548
|
+
"contentBlockDelta": {
|
|
549
|
+
"delta": {
|
|
550
|
+
"reasoningContent": {
|
|
551
|
+
"signature": content["reasoningContent"]["reasoningText"]["signature"]
|
|
552
|
+
}
|
|
553
|
+
}
|
|
554
|
+
}
|
|
555
|
+
}
|
|
556
|
+
|
|
557
|
+
# Yield contentBlockStop event
|
|
558
|
+
yield {"contentBlockStop": {}}
|
|
559
|
+
|
|
560
|
+
# Yield messageStop event
|
|
561
|
+
yield {
|
|
562
|
+
"messageStop": {
|
|
563
|
+
"stopReason": response["stopReason"],
|
|
564
|
+
"additionalModelResponseFields": response.get("additionalModelResponseFields"),
|
|
565
|
+
}
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
# Yield metadata event
|
|
569
|
+
if "usage" in response or "metrics" in response or "trace" in response:
|
|
570
|
+
metadata: StreamEvent = {"metadata": {}}
|
|
571
|
+
if "usage" in response:
|
|
572
|
+
metadata["metadata"]["usage"] = response["usage"]
|
|
573
|
+
if "metrics" in response:
|
|
574
|
+
metadata["metadata"]["metrics"] = response["metrics"]
|
|
575
|
+
if "trace" in response:
|
|
576
|
+
metadata["metadata"]["trace"] = response["trace"]
|
|
577
|
+
yield metadata
|
|
578
|
+
|
|
579
|
+
def _find_detected_and_blocked_policy(self, input: Any) -> bool:
|
|
580
|
+
"""Recursively checks if the assessment contains a detected and blocked guardrail.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
input: The assessment to check.
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
True if the input contains a detected and blocked guardrail, False otherwise.
|
|
587
|
+
|
|
588
|
+
"""
|
|
589
|
+
# Check if input is a dictionary
|
|
590
|
+
if isinstance(input, dict):
|
|
591
|
+
# Check if current dictionary has action: BLOCKED and detected: true
|
|
592
|
+
if input.get("action") == "BLOCKED" and input.get("detected") and isinstance(input.get("detected"), bool):
|
|
593
|
+
return True
|
|
594
|
+
|
|
595
|
+
# Recursively check all values in the dictionary
|
|
596
|
+
for value in input.values():
|
|
597
|
+
if isinstance(value, dict):
|
|
598
|
+
return self._find_detected_and_blocked_policy(value)
|
|
599
|
+
# Handle case where value is a list of dictionaries
|
|
600
|
+
elif isinstance(value, list):
|
|
601
|
+
for item in value:
|
|
602
|
+
return self._find_detected_and_blocked_policy(item)
|
|
603
|
+
elif isinstance(input, list):
|
|
604
|
+
# Handle case where input is a list of dictionaries
|
|
605
|
+
for item in input:
|
|
606
|
+
return self._find_detected_and_blocked_policy(item)
|
|
607
|
+
# Otherwise return False
|
|
608
|
+
return False
|
|
609
|
+
|
|
610
|
+
@override
|
|
611
|
+
async def structured_output(
|
|
612
|
+
self, output_model: Type[T], prompt: Messages, system_prompt: Optional[str] = None, **kwargs: Any
|
|
613
|
+
) -> AsyncGenerator[dict[str, Union[T, Any]], None]:
|
|
614
|
+
"""Get structured output from the model.
|
|
615
|
+
|
|
616
|
+
Args:
|
|
617
|
+
output_model: The output model to use for the agent.
|
|
618
|
+
prompt: The prompt messages to use for the agent.
|
|
619
|
+
system_prompt: System prompt to provide context to the model.
|
|
620
|
+
**kwargs: Additional keyword arguments for future extensibility.
|
|
621
|
+
|
|
622
|
+
Yields:
|
|
623
|
+
Model events with the last being the structured output.
|
|
624
|
+
"""
|
|
625
|
+
tool_spec = convert_pydantic_to_tool_spec(output_model)
|
|
626
|
+
|
|
627
|
+
response = self.stream(messages=prompt, tool_specs=[tool_spec], system_prompt=system_prompt, **kwargs)
|
|
628
|
+
async for event in streaming.process_stream(response):
|
|
629
|
+
yield event
|
|
630
|
+
|
|
631
|
+
stop_reason, messages, _, _ = event["stop"]
|
|
632
|
+
|
|
633
|
+
if stop_reason != "tool_use":
|
|
634
|
+
raise ValueError(f'Model returned stop_reason: {stop_reason} instead of "tool_use".')
|
|
635
|
+
|
|
636
|
+
content = messages["content"]
|
|
637
|
+
output_response: dict[str, Any] | None = None
|
|
638
|
+
for block in content:
|
|
639
|
+
# if the tool use name doesn't match the tool spec name, skip, and if the block is not a tool use, skip.
|
|
640
|
+
# if the tool use name never matches, raise an error.
|
|
641
|
+
if block.get("toolUse") and block["toolUse"]["name"] == tool_spec["name"]:
|
|
642
|
+
output_response = block["toolUse"]["input"]
|
|
643
|
+
else:
|
|
644
|
+
continue
|
|
645
|
+
|
|
646
|
+
if output_response is None:
|
|
647
|
+
raise ValueError("No valid tool use or tool use input was found in the Bedrock response.")
|
|
648
|
+
|
|
649
|
+
yield {"output": output_model(**output_response)}
|