fast-agent-mcp 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_agent/__init__.py +183 -0
- fast_agent/acp/__init__.py +19 -0
- fast_agent/acp/acp_aware_mixin.py +304 -0
- fast_agent/acp/acp_context.py +437 -0
- fast_agent/acp/content_conversion.py +136 -0
- fast_agent/acp/filesystem_runtime.py +427 -0
- fast_agent/acp/permission_store.py +269 -0
- fast_agent/acp/server/__init__.py +5 -0
- fast_agent/acp/server/agent_acp_server.py +1472 -0
- fast_agent/acp/slash_commands.py +1050 -0
- fast_agent/acp/terminal_runtime.py +408 -0
- fast_agent/acp/tool_permission_adapter.py +125 -0
- fast_agent/acp/tool_permissions.py +474 -0
- fast_agent/acp/tool_progress.py +814 -0
- fast_agent/agents/__init__.py +85 -0
- fast_agent/agents/agent_types.py +64 -0
- fast_agent/agents/llm_agent.py +350 -0
- fast_agent/agents/llm_decorator.py +1139 -0
- fast_agent/agents/mcp_agent.py +1337 -0
- fast_agent/agents/tool_agent.py +271 -0
- fast_agent/agents/workflow/agents_as_tools_agent.py +849 -0
- fast_agent/agents/workflow/chain_agent.py +212 -0
- fast_agent/agents/workflow/evaluator_optimizer.py +380 -0
- fast_agent/agents/workflow/iterative_planner.py +652 -0
- fast_agent/agents/workflow/maker_agent.py +379 -0
- fast_agent/agents/workflow/orchestrator_models.py +218 -0
- fast_agent/agents/workflow/orchestrator_prompts.py +248 -0
- fast_agent/agents/workflow/parallel_agent.py +250 -0
- fast_agent/agents/workflow/router_agent.py +353 -0
- fast_agent/cli/__init__.py +0 -0
- fast_agent/cli/__main__.py +73 -0
- fast_agent/cli/commands/acp.py +159 -0
- fast_agent/cli/commands/auth.py +404 -0
- fast_agent/cli/commands/check_config.py +783 -0
- fast_agent/cli/commands/go.py +514 -0
- fast_agent/cli/commands/quickstart.py +557 -0
- fast_agent/cli/commands/serve.py +143 -0
- fast_agent/cli/commands/server_helpers.py +114 -0
- fast_agent/cli/commands/setup.py +174 -0
- fast_agent/cli/commands/url_parser.py +190 -0
- fast_agent/cli/constants.py +40 -0
- fast_agent/cli/main.py +115 -0
- fast_agent/cli/terminal.py +24 -0
- fast_agent/config.py +798 -0
- fast_agent/constants.py +41 -0
- fast_agent/context.py +279 -0
- fast_agent/context_dependent.py +50 -0
- fast_agent/core/__init__.py +92 -0
- fast_agent/core/agent_app.py +448 -0
- fast_agent/core/core_app.py +137 -0
- fast_agent/core/direct_decorators.py +784 -0
- fast_agent/core/direct_factory.py +620 -0
- fast_agent/core/error_handling.py +27 -0
- fast_agent/core/exceptions.py +90 -0
- fast_agent/core/executor/__init__.py +0 -0
- fast_agent/core/executor/executor.py +280 -0
- fast_agent/core/executor/task_registry.py +32 -0
- fast_agent/core/executor/workflow_signal.py +324 -0
- fast_agent/core/fastagent.py +1186 -0
- fast_agent/core/logging/__init__.py +5 -0
- fast_agent/core/logging/events.py +138 -0
- fast_agent/core/logging/json_serializer.py +164 -0
- fast_agent/core/logging/listeners.py +309 -0
- fast_agent/core/logging/logger.py +278 -0
- fast_agent/core/logging/transport.py +481 -0
- fast_agent/core/prompt.py +9 -0
- fast_agent/core/prompt_templates.py +183 -0
- fast_agent/core/validation.py +326 -0
- fast_agent/event_progress.py +62 -0
- fast_agent/history/history_exporter.py +49 -0
- fast_agent/human_input/__init__.py +47 -0
- fast_agent/human_input/elicitation_handler.py +123 -0
- fast_agent/human_input/elicitation_state.py +33 -0
- fast_agent/human_input/form_elements.py +59 -0
- fast_agent/human_input/form_fields.py +256 -0
- fast_agent/human_input/simple_form.py +113 -0
- fast_agent/human_input/types.py +40 -0
- fast_agent/interfaces.py +310 -0
- fast_agent/llm/__init__.py +9 -0
- fast_agent/llm/cancellation.py +22 -0
- fast_agent/llm/fastagent_llm.py +931 -0
- fast_agent/llm/internal/passthrough.py +161 -0
- fast_agent/llm/internal/playback.py +129 -0
- fast_agent/llm/internal/silent.py +41 -0
- fast_agent/llm/internal/slow.py +38 -0
- fast_agent/llm/memory.py +275 -0
- fast_agent/llm/model_database.py +490 -0
- fast_agent/llm/model_factory.py +388 -0
- fast_agent/llm/model_info.py +102 -0
- fast_agent/llm/prompt_utils.py +155 -0
- fast_agent/llm/provider/anthropic/anthropic_utils.py +84 -0
- fast_agent/llm/provider/anthropic/cache_planner.py +56 -0
- fast_agent/llm/provider/anthropic/llm_anthropic.py +796 -0
- fast_agent/llm/provider/anthropic/multipart_converter_anthropic.py +462 -0
- fast_agent/llm/provider/bedrock/bedrock_utils.py +218 -0
- fast_agent/llm/provider/bedrock/llm_bedrock.py +2207 -0
- fast_agent/llm/provider/bedrock/multipart_converter_bedrock.py +84 -0
- fast_agent/llm/provider/google/google_converter.py +466 -0
- fast_agent/llm/provider/google/llm_google_native.py +681 -0
- fast_agent/llm/provider/openai/llm_aliyun.py +31 -0
- fast_agent/llm/provider/openai/llm_azure.py +143 -0
- fast_agent/llm/provider/openai/llm_deepseek.py +76 -0
- fast_agent/llm/provider/openai/llm_generic.py +35 -0
- fast_agent/llm/provider/openai/llm_google_oai.py +32 -0
- fast_agent/llm/provider/openai/llm_groq.py +42 -0
- fast_agent/llm/provider/openai/llm_huggingface.py +85 -0
- fast_agent/llm/provider/openai/llm_openai.py +1195 -0
- fast_agent/llm/provider/openai/llm_openai_compatible.py +138 -0
- fast_agent/llm/provider/openai/llm_openrouter.py +45 -0
- fast_agent/llm/provider/openai/llm_tensorzero_openai.py +128 -0
- fast_agent/llm/provider/openai/llm_xai.py +38 -0
- fast_agent/llm/provider/openai/multipart_converter_openai.py +561 -0
- fast_agent/llm/provider/openai/openai_multipart.py +169 -0
- fast_agent/llm/provider/openai/openai_utils.py +67 -0
- fast_agent/llm/provider/openai/responses.py +133 -0
- fast_agent/llm/provider_key_manager.py +139 -0
- fast_agent/llm/provider_types.py +34 -0
- fast_agent/llm/request_params.py +61 -0
- fast_agent/llm/sampling_converter.py +98 -0
- fast_agent/llm/stream_types.py +9 -0
- fast_agent/llm/usage_tracking.py +445 -0
- fast_agent/mcp/__init__.py +56 -0
- fast_agent/mcp/common.py +26 -0
- fast_agent/mcp/elicitation_factory.py +84 -0
- fast_agent/mcp/elicitation_handlers.py +164 -0
- fast_agent/mcp/gen_client.py +83 -0
- fast_agent/mcp/helpers/__init__.py +36 -0
- fast_agent/mcp/helpers/content_helpers.py +352 -0
- fast_agent/mcp/helpers/server_config_helpers.py +25 -0
- fast_agent/mcp/hf_auth.py +147 -0
- fast_agent/mcp/interfaces.py +92 -0
- fast_agent/mcp/logger_textio.py +108 -0
- fast_agent/mcp/mcp_agent_client_session.py +411 -0
- fast_agent/mcp/mcp_aggregator.py +2175 -0
- fast_agent/mcp/mcp_connection_manager.py +723 -0
- fast_agent/mcp/mcp_content.py +262 -0
- fast_agent/mcp/mime_utils.py +108 -0
- fast_agent/mcp/oauth_client.py +509 -0
- fast_agent/mcp/prompt.py +159 -0
- fast_agent/mcp/prompt_message_extended.py +155 -0
- fast_agent/mcp/prompt_render.py +84 -0
- fast_agent/mcp/prompt_serialization.py +580 -0
- fast_agent/mcp/prompts/__init__.py +0 -0
- fast_agent/mcp/prompts/__main__.py +7 -0
- fast_agent/mcp/prompts/prompt_constants.py +18 -0
- fast_agent/mcp/prompts/prompt_helpers.py +238 -0
- fast_agent/mcp/prompts/prompt_load.py +186 -0
- fast_agent/mcp/prompts/prompt_server.py +552 -0
- fast_agent/mcp/prompts/prompt_template.py +438 -0
- fast_agent/mcp/resource_utils.py +215 -0
- fast_agent/mcp/sampling.py +200 -0
- fast_agent/mcp/server/__init__.py +4 -0
- fast_agent/mcp/server/agent_server.py +613 -0
- fast_agent/mcp/skybridge.py +44 -0
- fast_agent/mcp/sse_tracking.py +287 -0
- fast_agent/mcp/stdio_tracking_simple.py +59 -0
- fast_agent/mcp/streamable_http_tracking.py +309 -0
- fast_agent/mcp/tool_execution_handler.py +137 -0
- fast_agent/mcp/tool_permission_handler.py +88 -0
- fast_agent/mcp/transport_tracking.py +634 -0
- fast_agent/mcp/types.py +24 -0
- fast_agent/mcp/ui_agent.py +48 -0
- fast_agent/mcp/ui_mixin.py +209 -0
- fast_agent/mcp_server_registry.py +89 -0
- fast_agent/py.typed +0 -0
- fast_agent/resources/examples/data-analysis/analysis-campaign.py +189 -0
- fast_agent/resources/examples/data-analysis/analysis.py +68 -0
- fast_agent/resources/examples/data-analysis/fastagent.config.yaml +41 -0
- fast_agent/resources/examples/data-analysis/mount-point/WA_Fn-UseC_-HR-Employee-Attrition.csv +1471 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_account_server.py +88 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_forms_server.py +297 -0
- fast_agent/resources/examples/mcp/elicitations/elicitation_game_server.py +164 -0
- fast_agent/resources/examples/mcp/elicitations/fastagent.config.yaml +35 -0
- fast_agent/resources/examples/mcp/elicitations/fastagent.secrets.yaml.example +17 -0
- fast_agent/resources/examples/mcp/elicitations/forms_demo.py +107 -0
- fast_agent/resources/examples/mcp/elicitations/game_character.py +65 -0
- fast_agent/resources/examples/mcp/elicitations/game_character_handler.py +256 -0
- fast_agent/resources/examples/mcp/elicitations/tool_call.py +21 -0
- fast_agent/resources/examples/mcp/state-transfer/agent_one.py +18 -0
- fast_agent/resources/examples/mcp/state-transfer/agent_two.py +18 -0
- fast_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +27 -0
- fast_agent/resources/examples/mcp/state-transfer/fastagent.secrets.yaml.example +15 -0
- fast_agent/resources/examples/researcher/fastagent.config.yaml +61 -0
- fast_agent/resources/examples/researcher/researcher-eval.py +53 -0
- fast_agent/resources/examples/researcher/researcher-imp.py +189 -0
- fast_agent/resources/examples/researcher/researcher.py +36 -0
- fast_agent/resources/examples/tensorzero/.env.sample +2 -0
- fast_agent/resources/examples/tensorzero/Makefile +31 -0
- fast_agent/resources/examples/tensorzero/README.md +56 -0
- fast_agent/resources/examples/tensorzero/agent.py +35 -0
- fast_agent/resources/examples/tensorzero/demo_images/clam.jpg +0 -0
- fast_agent/resources/examples/tensorzero/demo_images/crab.png +0 -0
- fast_agent/resources/examples/tensorzero/demo_images/shrimp.png +0 -0
- fast_agent/resources/examples/tensorzero/docker-compose.yml +105 -0
- fast_agent/resources/examples/tensorzero/fastagent.config.yaml +19 -0
- fast_agent/resources/examples/tensorzero/image_demo.py +67 -0
- fast_agent/resources/examples/tensorzero/mcp_server/Dockerfile +25 -0
- fast_agent/resources/examples/tensorzero/mcp_server/entrypoint.sh +35 -0
- fast_agent/resources/examples/tensorzero/mcp_server/mcp_server.py +31 -0
- fast_agent/resources/examples/tensorzero/mcp_server/pyproject.toml +11 -0
- fast_agent/resources/examples/tensorzero/simple_agent.py +25 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/system_schema.json +29 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/system_template.minijinja +11 -0
- fast_agent/resources/examples/tensorzero/tensorzero_config/tensorzero.toml +35 -0
- fast_agent/resources/examples/workflows/agents_as_tools_extended.py +73 -0
- fast_agent/resources/examples/workflows/agents_as_tools_simple.py +50 -0
- fast_agent/resources/examples/workflows/chaining.py +37 -0
- fast_agent/resources/examples/workflows/evaluator.py +77 -0
- fast_agent/resources/examples/workflows/fastagent.config.yaml +26 -0
- fast_agent/resources/examples/workflows/graded_report.md +89 -0
- fast_agent/resources/examples/workflows/human_input.py +28 -0
- fast_agent/resources/examples/workflows/maker.py +156 -0
- fast_agent/resources/examples/workflows/orchestrator.py +70 -0
- fast_agent/resources/examples/workflows/parallel.py +56 -0
- fast_agent/resources/examples/workflows/router.py +69 -0
- fast_agent/resources/examples/workflows/short_story.md +13 -0
- fast_agent/resources/examples/workflows/short_story.txt +19 -0
- fast_agent/resources/setup/.gitignore +30 -0
- fast_agent/resources/setup/agent.py +28 -0
- fast_agent/resources/setup/fastagent.config.yaml +65 -0
- fast_agent/resources/setup/fastagent.secrets.yaml.example +38 -0
- fast_agent/resources/setup/pyproject.toml.tmpl +23 -0
- fast_agent/skills/__init__.py +9 -0
- fast_agent/skills/registry.py +235 -0
- fast_agent/tools/elicitation.py +369 -0
- fast_agent/tools/shell_runtime.py +402 -0
- fast_agent/types/__init__.py +59 -0
- fast_agent/types/conversation_summary.py +294 -0
- fast_agent/types/llm_stop_reason.py +78 -0
- fast_agent/types/message_search.py +249 -0
- fast_agent/ui/__init__.py +38 -0
- fast_agent/ui/console.py +59 -0
- fast_agent/ui/console_display.py +1080 -0
- fast_agent/ui/elicitation_form.py +946 -0
- fast_agent/ui/elicitation_style.py +59 -0
- fast_agent/ui/enhanced_prompt.py +1400 -0
- fast_agent/ui/history_display.py +734 -0
- fast_agent/ui/interactive_prompt.py +1199 -0
- fast_agent/ui/markdown_helpers.py +104 -0
- fast_agent/ui/markdown_truncator.py +1004 -0
- fast_agent/ui/mcp_display.py +857 -0
- fast_agent/ui/mcp_ui_utils.py +235 -0
- fast_agent/ui/mermaid_utils.py +169 -0
- fast_agent/ui/message_primitives.py +50 -0
- fast_agent/ui/notification_tracker.py +205 -0
- fast_agent/ui/plain_text_truncator.py +68 -0
- fast_agent/ui/progress_display.py +10 -0
- fast_agent/ui/rich_progress.py +195 -0
- fast_agent/ui/streaming.py +774 -0
- fast_agent/ui/streaming_buffer.py +449 -0
- fast_agent/ui/tool_display.py +422 -0
- fast_agent/ui/usage_display.py +204 -0
- fast_agent/utils/__init__.py +5 -0
- fast_agent/utils/reasoning_stream_parser.py +77 -0
- fast_agent/utils/time.py +22 -0
- fast_agent/workflow_telemetry.py +261 -0
- fast_agent_mcp-0.4.7.dist-info/METADATA +788 -0
- fast_agent_mcp-0.4.7.dist-info/RECORD +261 -0
- fast_agent_mcp-0.4.7.dist-info/WHEEL +4 -0
- fast_agent_mcp-0.4.7.dist-info/entry_points.txt +7 -0
- fast_agent_mcp-0.4.7.dist-info/licenses/LICENSE +201 -0
fast_agent/llm/provider/bedrock/llm_bedrock.py
@@ -0,0 +1,2207 @@
+import json
+import os
+import re
+import sys
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import TYPE_CHECKING, Any, Type, Union
+
+from mcp import Tool
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    ContentBlock,
+    TextContent,
+)
+
+from fast_agent.core.exceptions import ProviderKeyError
+from fast_agent.core.logging.logger import get_logger
+from fast_agent.event_progress import ProgressAction
+from fast_agent.interfaces import ModelT
+from fast_agent.llm.fastagent_llm import FastAgentLLM
+from fast_agent.llm.provider.bedrock.multipart_converter_bedrock import BedrockConverter
+from fast_agent.llm.provider_types import Provider
+from fast_agent.llm.usage_tracking import TurnUsage
+from fast_agent.types import PromptMessageExtended, RequestParams
+from fast_agent.types.llm_stop_reason import LlmStopReason
+
+# Mapping from Bedrock's snake_case stop reasons to MCP's camelCase
+BEDROCK_TO_MCP_STOP_REASON = {
+    "end_turn": LlmStopReason.END_TURN.value,
+    "stop_sequence": LlmStopReason.STOP_SEQUENCE.value,
+    "max_tokens": LlmStopReason.MAX_TOKENS.value,
+}
+
+if TYPE_CHECKING:
+    from mcp import ListToolsResult
+
+try:
+    import boto3
+    from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
+except ImportError:
+    boto3 = None
+    BotoCoreError = Exception
+    ClientError = Exception
+    NoCredentialsError = Exception
+
+
+DEFAULT_BEDROCK_MODEL = "amazon.nova-lite-v1:0"
+
+
+# Local ReasoningEffort enum to avoid circular imports
+class ReasoningEffort(Enum):
+    """Reasoning effort levels for Bedrock models"""
+
+    MINIMAL = "minimal"
+    LOW = "low"
+    MEDIUM = "medium"
+    HIGH = "high"
+
+
+# Reasoning effort to token budget mapping
+# Based on AWS recommendations: start with 1024 minimum, increment reasonably
+REASONING_EFFORT_BUDGETS = {
+    ReasoningEffort.MINIMAL: 0,  # Disabled
+    ReasoningEffort.LOW: 512,  # Light reasoning
+    ReasoningEffort.MEDIUM: 1024,  # AWS minimum recommendation
+    ReasoningEffort.HIGH: 2048,  # Higher reasoning
+}
+
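The budget table above is the knob that a configured reasoning_effort turns; the request wiring that presumably consumes it lives later in this file, outside this excerpt. A minimal sketch of the lookup, assuming a plain string setting (the helper name is hypothetical and not part of the package):

    # Hypothetical helper: map a configured effort string to a thinking-token budget.
    def resolve_reasoning_budget(effort: str | None) -> int | None:
        if effort is None:
            return None
        budget = REASONING_EFFORT_BUDGETS[ReasoningEffort(effort)]  # "medium" -> 1024
        return budget or None  # MINIMAL maps to 0, i.e. reasoning stays disabled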
+# Bedrock message format types
+BedrockMessage = dict[str, Any]  # Bedrock message format
+BedrockMessageParam = dict[str, Any]  # Bedrock message parameter format
+
+
+class ToolSchemaType(Enum):
+    """Enum for different tool schema formats used by different model families."""
+
+    DEFAULT = auto()  # Default toolSpec format used by most models (formerly Nova)
+    SYSTEM_PROMPT = auto()  # System prompt-based tool calling format
+    ANTHROPIC = auto()  # Native Anthropic tool calling format
+    NONE = auto()  # Schema fallback failed, avoid retries
+
+
+class SystemMode(Enum):
+    """System message handling modes."""
+
+    SYSTEM = auto()  # Use native system parameter
+    INJECT = auto()  # Inject into user message
+
+
+class StreamPreference(Enum):
+    """Streaming preference with tools."""
+
+    STREAM_OK = auto()  # Model can stream with tools
+    NON_STREAM = auto()  # Model requires non-streaming for tools
+
+
+class ToolNamePolicy(Enum):
+    """Tool name transformation policy."""
+
+    PRESERVE = auto()  # Keep original tool names
+    UNDERSCORES = auto()  # Convert to underscore format
+
+
+class StructuredStrategy(Enum):
+    """Structured output generation strategy."""
+
+    STRICT_SCHEMA = auto()  # Use full JSON schema
+    SIMPLIFIED_SCHEMA = auto()  # Use simplified schema
+
+
+@dataclass
+class ModelCapabilities:
+    """Unified per-model capability cache to avoid scattered caches.
+
+    Uses proper enums and types to prevent typos and improve type safety.
+    """
+
+    schema: ToolSchemaType | None = None
+    system_mode: SystemMode | None = None
+    stream_with_tools: StreamPreference | None = None
+    tool_name_policy: ToolNamePolicy | None = None
+    structured_strategy: StructuredStrategy | None = None
+    reasoning_support: bool | None = None  # True=supported, False=unsupported, None=unknown
+    supports_tools: bool | None = None  # True=yes, False=no, None=unknown
+
+
+class BedrockLLM(FastAgentLLM[BedrockMessageParam, BedrockMessage]):
+    """
+    AWS Bedrock implementation of FastAgentLLM using the Converse API.
+    Supports all Bedrock models including Nova, Claude, Meta, etc.
+    """
+
+    # Class-level capabilities cache shared across all instances
+    capabilities: dict[str, ModelCapabilities] = {}
+
+    @classmethod
+    def debug_cache(cls) -> None:
+        """Print human-readable JSON representation of the capabilities cache.
+
+        Useful for debugging and understanding what capabilities have been
+        discovered and cached for each model. Uses sys.stdout to bypass
+        any logging hijacking.
+        """
+        if not cls.capabilities:
+            sys.stdout.write("{}\n")
+            sys.stdout.flush()
+            return
+
+        cache_dict = {}
+        for model, caps in cls.capabilities.items():
+            cache_dict[model] = {
+                "schema": caps.schema.name if caps.schema else None,
+                "system_mode": caps.system_mode.name if caps.system_mode else None,
+                "stream_with_tools": caps.stream_with_tools.name
+                if caps.stream_with_tools
+                else None,
+                "tool_name_policy": caps.tool_name_policy.name if caps.tool_name_policy else None,
+                "structured_strategy": caps.structured_strategy.name
+                if caps.structured_strategy
+                else None,
+                "reasoning_support": caps.reasoning_support,
+                "supports_tools": caps.supports_tools,
+            }
+
+        output = json.dumps(cache_dict, indent=2, sort_keys=True)
+        sys.stdout.write(f"{output}\n")
+        sys.stdout.flush()
+
+    @classmethod
+    def matches_model_pattern(cls, model_name: str) -> bool:
+        """Return True if model_name exists in the Bedrock model list loaded at init.
+
+        Uses the centralized discovery in bedrock_utils; no regex, no fallbacks.
+        Gracefully handles environments without AWS access by returning False.
+        """
+        from fast_agent.llm.provider.bedrock.bedrock_utils import all_bedrock_models
+
+        try:
+            available = set(all_bedrock_models(prefix=""))
+            return model_name in available
+        except Exception:
+            # If AWS calls fail (no credentials, region not configured, etc.),
+            # assume this is not a Bedrock model
+            return False
+
+    def __init__(self, *args, **kwargs) -> None:
+        """Initialize the Bedrock LLM with AWS credentials and region."""
+        if boto3 is None:
+            raise ImportError(
+                "boto3 is required for Bedrock support. Install with: pip install boto3"
+            )
+
+        # Initialize logger
+        self.logger = get_logger(__name__)
+
+        # Extract AWS configuration from kwargs first
+        self.aws_region = kwargs.pop("region", None)
+        self.aws_profile = kwargs.pop("profile", None)
+
+        super().__init__(*args, provider=Provider.BEDROCK, **kwargs)
+
+        # Use config values if not provided in kwargs (after super().__init__)
+        if self.context.config and self.context.config.bedrock:
+            if not self.aws_region:
+                self.aws_region = self.context.config.bedrock.region
+            if not self.aws_profile:
+                self.aws_profile = self.context.config.bedrock.profile
+
+        # Final fallback to environment variables
+        if not self.aws_region:
+            # Support both AWS_REGION and AWS_DEFAULT_REGION
+            self.aws_region = os.environ.get("AWS_REGION") or os.environ.get(
+                "AWS_DEFAULT_REGION", "us-east-1"
+            )
+
+        if not self.aws_profile:
+            # Support AWS_PROFILE environment variable
+            self.aws_profile = os.environ.get("AWS_PROFILE")
+
+        # Initialize AWS clients
+        self._bedrock_client = None
+        self._bedrock_runtime_client = None
+
+        # One-shot hint to force non-streaming on next completion (used by structured outputs)
+        self._force_non_streaming_once: bool = False
+
+        # Set up reasoning-related attributes
+        self._reasoning_effort = kwargs.get("reasoning_effort", None)
+        if (
+            self._reasoning_effort is None
+            and self.context
+            and self.context.config
+            and self.context.config.bedrock
+        ):
+            if hasattr(self.context.config.bedrock, "reasoning_effort"):
+                self._reasoning_effort = self.context.config.bedrock.reasoning_effort
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize Bedrock-specific default parameters"""
+        # Get base defaults from parent (includes ModelDatabase lookup)
+        base_params = super()._initialize_default_params(kwargs)
+
+        # Override with Bedrock-specific settings - ensure we always have a model
+        chosen_model = kwargs.get("model", DEFAULT_BEDROCK_MODEL)
+        base_params.model = chosen_model
+
+        return base_params
+
+    @property
+    def model(self) -> str:
+        """Get the model name, guaranteed to be set."""
+        return self.default_request_params.model
+
+    def _get_bedrock_client(self):
+        """Get or create Bedrock client."""
+        if self._bedrock_client is None:
+            try:
+                session = boto3.Session(profile_name=self.aws_profile)  # type: ignore[union-attr]
+                self._bedrock_client = session.client("bedrock", region_name=self.aws_region)
+            except NoCredentialsError as e:
+                raise ProviderKeyError(
+                    "AWS credentials not found",
+                    "Please configure AWS credentials using AWS CLI, environment variables, or IAM roles.",
+                ) from e
+        return self._bedrock_client
+
+    def _get_bedrock_runtime_client(self):
+        """Get or create Bedrock Runtime client."""
+        if self._bedrock_runtime_client is None:
+            try:
+                session = boto3.Session(profile_name=self.aws_profile)  # type: ignore[union-attr]
+                self._bedrock_runtime_client = session.client(
+                    "bedrock-runtime", region_name=self.aws_region
+                )
+            except NoCredentialsError as e:
+                raise ProviderKeyError(
+                    "AWS credentials not found",
+                    "Please configure AWS credentials using AWS CLI, environment variables, or IAM roles.",
+                ) from e
+        return self._bedrock_runtime_client
+
+    def _convert_extended_messages_to_provider(
+        self, messages: list[PromptMessageExtended]
+    ) -> list[BedrockMessageParam]:
+        """
+        Convert PromptMessageExtended list to Bedrock BedrockMessageParam format.
+        This is called fresh on every API call from _convert_to_provider_format().
+
+        Args:
+            messages: List of PromptMessageExtended objects
+
+        Returns:
+            List of Bedrock BedrockMessageParam objects
+        """
+        converted: list[BedrockMessageParam] = []
+        for msg in messages:
+            bedrock_msg = BedrockConverter.convert_to_bedrock(msg)
+            converted.append(bedrock_msg)
+        return converted
+
+    def _build_tool_name_mapping(
+        self, tools: "ListToolsResult", name_policy: ToolNamePolicy
+    ) -> dict[str, str]:
+        """Build tool name mapping based on schema type and name policy.
+
+        Returns dict mapping from converted_name -> original_name for tool execution.
+        """
+        mapping = {}
+
+        if name_policy == ToolNamePolicy.PRESERVE:
+            # Identity mapping for preserve policy
+            for tool in tools.tools:
+                mapping[tool.name] = tool.name
+        else:
+            # Nova-style cleaning for underscores policy
+            for tool in tools.tools:
+                clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", tool.name)
+                clean_name = re.sub(r"_+", "_", clean_name).strip("_")
+                if not clean_name:
+                    clean_name = f"tool_{hash(tool.name) % 10000}"
+                mapping[clean_name] = tool.name
+
+        return mapping
+
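Under the UNDERSCORES policy the two substitutions above turn every non-alphanumeric character into an underscore and then squeeze repeated underscores; a small sketch with a made-up tool name:

    # Illustrative only: what the UNDERSCORES branch produces for one hyphenated name.
    name = "my-server.fetch-page"
    clean = re.sub(r"[^a-zA-Z0-9_]", "_", name)   # "my_server_fetch_page"
    clean = re.sub(r"_+", "_", clean).strip("_")  # collapse runs, trim edge underscores
    # mapping entry: {"my_server_fetch_page": "my-server.fetch-page"}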
+    def _convert_tools_nova_format(
+        self, tools: "ListToolsResult", tool_name_mapping: dict[str, str]
+    ) -> list[dict[str, Any]]:
+        """Convert MCP tools to Nova-specific toolSpec format.
+
+        Note: Nova models have VERY strict JSON schema requirements:
+        - Top level schema must be of type Object
+        - ONLY three fields are supported: type, properties, required
+        - NO other fields like $schema, description, title, additionalProperties
+        - Properties can only have type and description
+        - Tools with no parameters should have empty properties object
+        """
+        bedrock_tools = []
+
+        self.logger.debug(f"Converting {len(tools.tools)} MCP tools to Nova format")
+
+        for tool in tools.tools:
+            self.logger.debug(f"Converting MCP tool: {tool.name}")
+
+            # Extract and validate the input schema
+            input_schema = tool.inputSchema or {}
+
+            # Create Nova-compliant schema with ONLY the three allowed fields
+            # Always include type and properties (even if empty)
+            nova_schema: dict[str, Any] = {"type": "object", "properties": {}}
+
+            # Properties - clean them strictly
+            properties: dict[str, Any] = {}
+            if "properties" in input_schema and isinstance(input_schema["properties"], dict):
+                for prop_name, prop_def in input_schema["properties"].items():
+                    # Only include type and description for each property
+                    clean_prop: dict[str, Any] = {}
+
+                    if isinstance(prop_def, dict):
+                        # Only include type (required) and description (optional)
+                        clean_prop["type"] = prop_def.get("type", "string")
+                        # Nova allows description in properties
+                        if "description" in prop_def:
+                            clean_prop["description"] = prop_def["description"]
+                    else:
+                        # Handle simple property definitions
+                        clean_prop["type"] = "string"
+
+                    properties[prop_name] = clean_prop
+
+            # Always set properties (even if empty for parameterless tools)
+            nova_schema["properties"] = properties
+
+            # Required fields - only add if present and not empty
+            if (
+                "required" in input_schema
+                and isinstance(input_schema["required"], list)
+                and input_schema["required"]
+            ):
+                nova_schema["required"] = input_schema["required"]
+
+            # Use the tool name mapping that was already built in _bedrock_completion
+            # This ensures consistent transformation logic across the codebase
+            clean_name = None
+            for mapped_name, original_name in tool_name_mapping.items():
+                if original_name == tool.name:
+                    clean_name = mapped_name
+                    break
+
+            if clean_name is None:
+                # Fallback if mapping not found (shouldn't happen)
+                clean_name = tool.name
+                self.logger.warning(
+                    f"Tool name mapping not found for {tool.name}, using original name"
+                )
+
+            bedrock_tool = {
+                "toolSpec": {
+                    "name": clean_name,
+                    "description": tool.description or f"Tool: {tool.name}",
+                    "inputSchema": {"json": nova_schema},
+                }
+            }
+
+            bedrock_tools.append(bedrock_tool)
+
+        self.logger.debug(f"Converted {len(bedrock_tools)} tools for Nova format")
+        return bedrock_tools
+
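To make the stripping concrete, here is a sketch of the before and after shapes for a single made-up tool (only type, properties and required survive; property keys other than type and description are dropped):

    # Illustrative only: a raw MCP schema and the Nova-compliant schema built above.
    raw_schema = {
        "type": "object",
        "$schema": "http://json-schema.org/draft-07/schema#",
        "additionalProperties": False,
        "properties": {"query": {"type": "string", "description": "Search text", "minLength": 1}},
        "required": ["query"],
    }
    nova_schema = {
        "type": "object",
        "properties": {"query": {"type": "string", "description": "Search text"}},
        "required": ["query"],
    }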
+    def _convert_tools_system_prompt_format(
+        self, tools: "ListToolsResult", tool_name_mapping: dict[str, str]
+    ) -> str:
+        """Convert MCP tools to system prompt format."""
+        if not tools.tools:
+            return ""
+
+        self.logger.debug(f"Converting {len(tools.tools)} MCP tools to system prompt format")
+
+        prompt_parts = [
+            "You have the following tools available to help answer the user's request. You can call one or more functions at a time. The functions are described here in JSON-schema format:",
+            "",
+        ]
+
+        # Add each tool definition in JSON format
+        for tool in tools.tools:
+            self.logger.debug(f"Converting MCP tool: {tool.name}")
+
+            # Use original tool name (no hyphen replacement)
+            tool_name = tool.name
+
+            # Create tool definition
+            tool_def = {
+                "type": "function",
+                "function": {
+                    "name": tool_name,
+                    "description": tool.description or f"Tool: {tool.name}",
+                    "parameters": tool.inputSchema or {"type": "object", "properties": {}},
+                },
+            }
+
+            prompt_parts.append(json.dumps(tool_def))
+
+        # Add the response format instructions
+        prompt_parts.extend(
+            [
+                "",
+                "To call one or more tools, provide the tool calls on a new line as a JSON-formatted array. Explain your steps in a neutral tone. Then, only call the tools you can for the first step, then end your turn. If you previously received an error, you can try to call the tool again. Give up after 3 errors.",
+                "",
+                "Conform precisely to the single-line format of this example:",
+                "Tool Call:",
+                '[{"name": "SampleTool", "arguments": {"foo": "bar"}},{"name": "SampleTool", "arguments": {"foo": "other"}}]',
+                "",
+                "When calling a tool you must supply valid JSON with both 'name' and 'arguments' keys with the function name and function arguments respectively. Do not add any preamble, labels or extra text, just the single JSON string in one of the specified formats",
+            ]
+        )
+
+        system_prompt = "\n".join(prompt_parts)
+        self.logger.debug(f"Generated Llama native system prompt: {system_prompt}")
+
+        return system_prompt
+
+    def _convert_tools_anthropic_format(
+        self, tools: "ListToolsResult", tool_name_mapping: dict[str, str]
+    ) -> list[dict[str, Any]]:
+        """Convert MCP tools to Anthropic format wrapped in Bedrock toolSpec - preserves raw schema."""
+
+        self.logger.debug(
+            f"Converting {len(tools.tools)} MCP tools to Anthropic format with toolSpec wrapper"
+        )
+
+        bedrock_tools = []
+        for tool in tools.tools:
+            self.logger.debug(f"Converting MCP tool: {tool.name}")
+
+            # Use raw MCP schema (like native Anthropic provider) - no cleaning
+            input_schema = tool.inputSchema or {"type": "object", "properties": {}}
+
+            # Wrap in Bedrock toolSpec format but preserve raw Anthropic schema
+            bedrock_tool = {
+                "toolSpec": {
+                    "name": tool.name,  # Original name, no cleaning
+                    "description": tool.description or f"Tool: {tool.name}",
+                    "inputSchema": {
+                        "json": input_schema  # Raw MCP schema, not cleaned
+                    },
+                }
+            }
+            bedrock_tools.append(bedrock_tool)
+
+        self.logger.debug(
+            f"Converted {len(bedrock_tools)} tools to Anthropic format with toolSpec wrapper"
+        )
+        return bedrock_tools
+
+    def _parse_system_prompt_tool_response(
+        self, processed_response: dict[str, Any], model: str
+    ) -> list[dict[str, Any]]:
+        """Parse system prompt tool response format: function calls in text."""
+        # Extract text content from the response
+        text_content = ""
+        for content_item in processed_response.get("content", []):
+            if isinstance(content_item, dict) and "text" in content_item:
+                text_content += content_item["text"]
+
+        if not text_content:
+            return []
+
+        # Look for different tool call formats
+        tool_calls = []
+
+        # First try Scout format: [function_name(arguments)]
+        scout_pattern = r"\[([^(]+)\(([^)]*)\)\]"
+        scout_matches = re.findall(scout_pattern, text_content)
+        if scout_matches:
+            for i, (func_name, args_str) in enumerate(scout_matches):
+                func_name = func_name.strip()
+                args_str = args_str.strip()
+
+                # Parse arguments - could be empty, JSON object, or simple values
+                arguments = {}
+                if args_str:
+                    try:
+                        # Try to parse as JSON object first
+                        if args_str.startswith("{") and args_str.endswith("}"):
+                            arguments = json.loads(args_str)
+                        else:
+                            # For simple values, create a basic structure
+                            arguments = {"value": args_str}
+                    except json.JSONDecodeError:
+                        # If JSON parsing fails, treat as string
+                        arguments = {"value": args_str}
+
+                tool_calls.append(
+                    {
+                        "type": "system_prompt_tool",
+                        "name": func_name,
+                        "arguments": arguments,
+                        "id": f"system_prompt_{func_name}_{i}",
+                    }
+                )
+
+            if tool_calls:
+                return tool_calls
+
+        # Second try: find the "Action:" format (commonly used by Nova models)
+        action_pattern = r"Action:\s*([^(]+)\(([^)]*)\)"
+        action_matches = re.findall(action_pattern, text_content)
+        if action_matches:
+            for i, (func_name, args_str) in enumerate(action_matches):
+                func_name = func_name.strip()
+                args_str = args_str.strip()
+
+                # Parse arguments - handle quoted strings and key=value pairs
+                arguments = {}
+                if args_str:
+                    try:
+                        # Handle key=value format like location="London"
+                        if "=" in args_str:
+                            # Split by comma, then by = for each part
+                            for arg_part in args_str.split(","):
+                                if "=" in arg_part:
+                                    key, value = arg_part.split("=", 1)
+                                    key = key.strip()
+                                    value = value.strip().strip("\"'")  # Remove quotes
+                                    arguments[key] = value
+                        else:
+                            # Single value argument - try to map to appropriate parameter name
+                            value = args_str.strip("\"'") if args_str else ""
+                            # Handle common single-parameter functions
+                            if func_name == "check_weather":
+                                arguments = {"location": value}
+                            else:
+                                # Generic fallback
+                                arguments = {"value": value}
+                    except Exception as e:
+                        self.logger.warning(f"Failed to parse Action arguments: {args_str} - {e}")
+                        arguments = {"value": args_str}
+
+                tool_calls.append(
+                    {
+                        "type": "system_prompt_tool",
+                        "name": func_name,
+                        "arguments": arguments,
+                        "id": f"system_prompt_{func_name}_{i}",
+                    }
+                )
+
+            if tool_calls:
+                return tool_calls
+
+        # Third try: find the "Tool Call:" format
+        tool_call_match = re.search(r"Tool Call:\s*(\[.*?\])", text_content, re.DOTALL)
+        if tool_call_match:
+            json_str = tool_call_match.group(1)
+            try:
+                parsed_calls = json.loads(json_str)
+                if isinstance(parsed_calls, list):
+                    for i, call in enumerate(parsed_calls):
+                        if isinstance(call, dict) and "name" in call:
+                            tool_calls.append(
+                                {
+                                    "type": "system_prompt_tool",
+                                    "name": call["name"],
+                                    "arguments": call.get("arguments", {}),
+                                    "id": f"system_prompt_{call['name']}_{i}",
+                                }
+                            )
+                    return tool_calls
+            except json.JSONDecodeError as e:
+                self.logger.warning(f"Failed to parse Tool Call JSON array: {json_str} - {e}")
+
+        # Fallback: try to parse JSON arrays that look like tool calls
+        # Look for arrays containing objects with "name" fields - avoid simple citations
+        array_match = re.search(r'\[.*?\{.*?"name".*?\}.*?\]', text_content, re.DOTALL)
+        if array_match:
+            json_str = array_match.group(0)
+            try:
+                parsed_calls = json.loads(json_str)
+                if isinstance(parsed_calls, list):
+                    for i, call in enumerate(parsed_calls):
+                        if isinstance(call, dict) and "name" in call:
+                            tool_calls.append(
+                                {
+                                    "type": "system_prompt_tool",
+                                    "name": call["name"],
+                                    "arguments": call.get("arguments", {}),
+                                    "id": f"system_prompt_{call['name']}_{i}",
+                                }
+                            )
+                    return tool_calls
+            except json.JSONDecodeError as e:
+                self.logger.debug(f"Failed to parse JSON array: {json_str} - {e}")
+
+        # Fallback: try to parse as single JSON object (backward compatibility)
+        try:
+            json_match = re.search(r'\{[^}]*"name"[^}]*"arguments"[^}]*\}', text_content, re.DOTALL)
+            if json_match:
+                json_str = json_match.group(0)
+                function_call = json.loads(json_str)
+
+                if "name" in function_call:
+                    return [
+                        {
+                            "type": "system_prompt_tool",
+                            "name": function_call["name"],
+                            "arguments": function_call.get("arguments", {}),
+                            "id": f"system_prompt_{function_call['name']}",
+                        }
+                    ]
+
+        except json.JSONDecodeError as e:
+            self.logger.warning(
+                f"Failed to parse system prompt tool response as JSON: {text_content} - {e}"
+            )
+
+        # Fallback to old custom tag format in case some models still use it
+        function_regex = r"<function=([^>]+)>(.*?)</function>"
+        match = re.search(function_regex, text_content)
+
+        if match:
+            function_name = match.group(1)
+            function_args_json = match.group(2)
+
+            try:
+                function_args = json.loads(function_args_json)
+                return [
+                    {
+                        "type": "system_prompt_tool",
+                        "name": function_name,
+                        "arguments": function_args,
+                        "id": f"system_prompt_{function_name}",
+                    }
+                ]
+            except json.JSONDecodeError:
+                self.logger.warning(
+                    f"Failed to parse fallback custom tag format: {function_args_json}"
+                )
+
+        # Third try: find direct function call format like "function_name(args)"
+        direct_call_pattern = r"^([a-zA-Z_][a-zA-Z0-9_]*)\(([^)]*)\)$"
+        direct_call_match = re.search(direct_call_pattern, text_content.strip())
+        if direct_call_match:
+            func_name, args_str = direct_call_match.groups()
+            func_name = func_name.strip()
+            args_str = args_str.strip()
+
+            # Parse arguments
+            arguments = {}
+            if args_str:
+                try:
+                    # Handle key=value format like location="London"
+                    if "=" in args_str:
+                        # Split by comma, then by = for each part
+                        for arg_part in args_str.split(","):
+                            if "=" in arg_part:
+                                key, value = arg_part.split("=", 1)
+                                key = key.strip()
+                                value = value.strip().strip("\"'")  # Remove quotes
+                                arguments[key] = value
+                    else:
+                        # Single value argument - try to map to appropriate parameter name
+                        value = args_str.strip("\"'") if args_str else ""
+                        # Handle common single-parameter functions
+                        if func_name == "check_weather":
+                            arguments = {"location": value}
+                        else:
+                            # Generic fallback
+                            arguments = {"value": value}
+                except Exception as e:
+                    self.logger.warning(f"Failed to parse direct call arguments: {args_str} - {e}")
+                    arguments = {"value": args_str}
+
+            return [
+                {
+                    "type": "system_prompt_tool",
+                    "name": func_name,
+                    "arguments": arguments,
+                    "id": f"system_prompt_{func_name}_0",
+                }
+            ]
+
+        return []
+
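A worked example of the intended path through this parser, assuming a reply that follows the "Tool Call:" instructions from the system prompt above (the tool name is illustrative only):

    # Illustrative only: the "Tool Call:" branch above parses this response text...
    response = {
        "content": [
            {"text": 'Tool Call:\n[{"name": "check_weather", "arguments": {"location": "London"}}]'}
        ]
    }
    # ...into a single parsed call:
    # [{"type": "system_prompt_tool", "name": "check_weather",
    #   "arguments": {"location": "London"}, "id": "system_prompt_check_weather_0"}]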
+    def _parse_anthropic_tool_response(
+        self, processed_response: dict[str, Any]
+    ) -> list[dict[str, Any]]:
+        """Parse Anthropic tool response format (same as native provider)."""
+        tool_uses = []
+
+        # Look for toolUse in content items (Bedrock format for Anthropic models)
+        for content_item in processed_response.get("content", []):
+            if "toolUse" in content_item:
+                tool_use = content_item["toolUse"]
+                tool_uses.append(
+                    {
+                        "type": "anthropic_tool",
+                        "name": tool_use["name"],
+                        "arguments": tool_use["input"],
+                        "id": tool_use["toolUseId"],
+                    }
+                )
+
+        return tool_uses
+
+    def _parse_tool_response(
+        self, processed_response: dict[str, Any], model: str
+    ) -> list[dict[str, Any]]:
+        """Parse tool responses using cached schema, without model/family heuristics."""
+        caps = self.capabilities.get(model) or ModelCapabilities()
+        schema = caps.schema
+
+        # Choose parser strictly by cached schema
+        if schema == ToolSchemaType.SYSTEM_PROMPT:
+            return self._parse_system_prompt_tool_response(processed_response, model)
+        if schema == ToolSchemaType.ANTHROPIC:
+            return self._parse_anthropic_tool_response(processed_response)
+
+        # Default/Nova: detect toolUse objects
+        tool_uses = [
+            c
+            for c in processed_response.get("content", [])
+            if isinstance(c, dict) and "toolUse" in c
+        ]
+        if tool_uses:
+            parsed_tools: list[dict[str, Any]] = []
+            for item in tool_uses:
+                tu = item.get("toolUse", {})
+                if not isinstance(tu, dict):
+                    continue
+                parsed_tools.append(
+                    {
+                        "type": "nova_tool",
+                        "name": tu.get("name"),
+                        "arguments": tu.get("input", {}),
+                        "id": tu.get("toolUseId"),
+                    }
+                )
+            if parsed_tools:
+                return parsed_tools
+
+        # Family-agnostic fallback: parse JSON array embedded in text
+        try:
+            text_content = ""
+            for content_item in processed_response.get("content", []):
+                if isinstance(content_item, dict) and "text" in content_item:
+                    text_content += content_item["text"]
+            if text_content:
+                import json as _json
+                import re as _re
+
+                match = _re.search(r"\[(?:.|\n)*?\]", text_content)
+                if match:
+                    arr = _json.loads(match.group(0))
+                    if isinstance(arr, list) and arr and isinstance(arr[0], dict):
+                        parsed_calls = []
+                        for i, call in enumerate(arr):
+                            name = call.get("name")
+                            args = call.get("arguments", {})
+                            if name:
+                                parsed_calls.append(
+                                    {
+                                        "type": "system_prompt_tool",
+                                        "name": name,
+                                        "arguments": args,
+                                        "id": f"system_prompt_{name}_{i}",
+                                    }
+                                )
+                        if parsed_calls:
+                            return parsed_calls
+        except Exception:
+            pass
+
+        # Final fallback: try system prompt parsing regardless of cached schema
+        # This handles cases where native tool calling failed but model generated system prompt format
+        try:
+            return self._parse_system_prompt_tool_response(processed_response, model)
+        except Exception:
+            pass
+
+        return []
+
+    def _build_tool_calls_dict(
+        self, parsed_tools: list[dict[str, Any]]
+    ) -> dict[str, CallToolRequest]:
+        """
+        Convert parsed tools to CallToolRequest dict for external execution.
+
+        Args:
+            parsed_tools: List of parsed tool dictionaries from _parse_tool_response()
+
+        Returns:
+            Dictionary mapping tool_use_id to CallToolRequest objects
+        """
+        tool_calls = {}
+        for parsed_tool in parsed_tools:
+            # Use tool name directly, but map back to original if a mapping is available
+            tool_name = parsed_tool["name"]
+            try:
+                mapping = getattr(self, "tool_name_mapping", None)
+                if isinstance(mapping, dict):
+                    tool_name = mapping.get(tool_name, tool_name)
+            except Exception:
+                pass
+
+            # Create CallToolRequest
+            tool_call = CallToolRequest(
+                method="tools/call",
+                params=CallToolRequestParams(
+                    name=tool_name, arguments=parsed_tool.get("arguments", {})
+                ),
+            )
+            tool_calls[parsed_tool["id"]] = tool_call
+        return tool_calls
+
+    def _map_bedrock_stop_reason(self, bedrock_stop_reason: str) -> LlmStopReason:
+        """
+        Map Bedrock stop reasons to LlmStopReason enum.
+
+        Args:
+            bedrock_stop_reason: Stop reason from Bedrock API
+
+        Returns:
+            Corresponding LlmStopReason enum value
+        """
+        if bedrock_stop_reason == "tool_use":
+            return LlmStopReason.TOOL_USE
+        elif bedrock_stop_reason == "end_turn":
+            return LlmStopReason.END_TURN
+        elif bedrock_stop_reason == "stop_sequence":
+            return LlmStopReason.STOP_SEQUENCE
+        elif bedrock_stop_reason == "max_tokens":
+            return LlmStopReason.MAX_TOKENS
+        else:
+            # Default to END_TURN for unknown stop reasons, but log for debugging
+            self.logger.warning(
+                f"Unknown Bedrock stop reason: {bedrock_stop_reason}, defaulting to END_TURN"
+            )
+            return LlmStopReason.END_TURN
+
+    def _convert_multipart_to_bedrock_message(
+        self, msg: PromptMessageExtended
+    ) -> BedrockMessageParam:
+        """
+        Convert a PromptMessageExtended to Bedrock message parameter format.
+        Handles tool results and regular content.
+
+        Args:
+            msg: PromptMessageExtended message to convert
+
+        Returns:
+            Bedrock message parameter dictionary
+        """
+        bedrock_msg = {"role": msg.role, "content": []}
+
+        # Handle tool results first (if present)
+        if msg.tool_results:
+            # Get the cached schema type to determine result formatting
+            caps = self.capabilities.get(self.model) or ModelCapabilities()
+            # Check if any tool ID indicates system prompt format
+            has_system_prompt_tools = any(
+                tool_id.startswith("system_prompt_") for tool_id in msg.tool_results.keys()
+            )
+            is_system_prompt_schema = (
+                caps.schema == ToolSchemaType.SYSTEM_PROMPT or has_system_prompt_tools
+            )
+
+            if is_system_prompt_schema:
+                # For system prompt models: format as human-readable text
+                tool_result_parts = []
+                for tool_id, tool_result in msg.tool_results.items():
+                    result_text = "".join(
+                        part.text for part in tool_result.content if isinstance(part, TextContent)
+                    )
+                    result_payload = {
+                        "tool_name": tool_id,  # Use tool_id as name for system prompt
+                        "status": "error" if tool_result.isError else "success",
+                        "result": result_text,
+                    }
+                    tool_result_parts.append(json.dumps(result_payload))
+
+                if tool_result_parts:
+                    full_result_text = f"Tool Results:\n{', '.join(tool_result_parts)}"
+                    bedrock_msg["content"].append({"type": "text", "text": full_result_text})
+            else:
+                # For Nova/Anthropic models: use structured tool_result format
+                for tool_id, tool_result in msg.tool_results.items():
+                    result_content_blocks = []
+                    if tool_result.content:
+                        for part in tool_result.content:
+                            if isinstance(part, TextContent):
+                                result_content_blocks.append({"text": part.text})
+
+                    if not result_content_blocks:
+                        result_content_blocks.append({"text": "[No content in tool result]"})
+
+                    bedrock_msg["content"].append(
+                        {
+                            "type": "tool_result",
+                            "tool_use_id": tool_id,
+                            "content": result_content_blocks,
+                            "status": "error" if tool_result.isError else "success",
+                        }
+                    )
+
+        # Handle regular content
+        for content_item in msg.content:
+            if isinstance(content_item, TextContent):
+                bedrock_msg["content"].append({"type": "text", "text": content_item.text})
+
+        return bedrock_msg
+
+    def _convert_messages_to_bedrock(
+        self, messages: list[BedrockMessageParam]
+    ) -> list[dict[str, Any]]:
+        """Convert message parameters to Bedrock format."""
+        bedrock_messages = []
+        for message in messages:
+            bedrock_message = {"role": message.get("role", "user"), "content": []}
+
+            content = message.get("content", [])
+
+            if isinstance(content, str):
+                bedrock_message["content"].append({"text": content})
+            elif isinstance(content, list):
+                for item in content:
+                    item_type = item.get("type")
+                    if item_type == "text":
+                        bedrock_message["content"].append({"text": item.get("text", "")})
+                    elif item_type == "tool_use":
+                        bedrock_message["content"].append(
+                            {
+                                "toolUse": {
+                                    "toolUseId": item.get("id", ""),
+                                    "name": item.get("name", ""),
+                                    "input": item.get("input", {}),
+                                }
+                            }
+                        )
+                    elif item_type == "tool_result":
+                        tool_use_id = item.get("tool_use_id")
+                        raw_content = item.get("content", [])
+                        status = item.get("status", "success")
+
+                        bedrock_content_list = []
+                        if raw_content:
+                            for part in raw_content:
+                                # FIX: The content parts are dicts, not TextContent objects.
+                                if isinstance(part, dict) and "text" in part:
+                                    bedrock_content_list.append({"text": part.get("text", "")})
+
+                        # Bedrock requires content for error statuses.
+                        if not bedrock_content_list and status == "error":
+                            bedrock_content_list.append({"text": "Tool call failed with an error."})
+
+                        bedrock_message["content"].append(
+                            {
+                                "toolResult": {
+                                    "toolUseId": tool_use_id,
+                                    "content": bedrock_content_list,
+                                    "status": status,
+                                }
+                            }
+                        )
+
+            # Only add the message if it has content
+            if bedrock_message["content"]:
+                bedrock_messages.append(bedrock_message)
+
+        return bedrock_messages
+
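As a rough illustration of the conversion above (the ID and text are made up), an intermediate tool_result item becomes a Converse-style toolResult block:

    # Illustrative only: one message param and the Bedrock-shaped message built from it.
    message_param = {
        "role": "user",
        "content": [
            {"type": "tool_result", "tool_use_id": "tooluse_abc123",
             "content": [{"text": "42"}], "status": "success"},
        ],
    }
    # _convert_messages_to_bedrock([message_param]) returns:
    # [{"role": "user",
    #   "content": [{"toolResult": {"toolUseId": "tooluse_abc123",
    #                               "content": [{"text": "42"}],
    #                               "status": "success"}}]}]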
+    async def _process_stream(
+        self,
+        stream_response,
+        model: str,
+    ) -> BedrockMessage:
+        """Process streaming response from Bedrock."""
+        estimated_tokens = 0
+        response_content = []
+        tool_uses = []
+        stop_reason = None
+        usage = {"input_tokens": 0, "output_tokens": 0}
+
+        try:
+            # Cancellation is handled via asyncio.Task.cancel() which raises CancelledError
+            for event in stream_response["stream"]:
+
+                if "messageStart" in event:
+                    # Message started
+                    continue
+                elif "contentBlockStart" in event:
+                    # Content block started
+                    content_block = event["contentBlockStart"]
+                    if "start" in content_block and "toolUse" in content_block["start"]:
+                        # Tool use block started
+                        tool_use_start = content_block["start"]["toolUse"]
+                        self.logger.debug(f"Tool use block started: {tool_use_start}")
+                        tool_uses.append(
+                            {
+                                "toolUse": {
+                                    "toolUseId": tool_use_start.get("toolUseId"),
+                                    "name": tool_use_start.get("name"),
+                                    "input": tool_use_start.get("input", {}),
+                                    "_input_accumulator": "",  # For accumulating streamed input
+                                }
+                            }
+                        )
+                elif "contentBlockDelta" in event:
+                    # Content delta received
+                    delta = event["contentBlockDelta"]["delta"]
+                    if "text" in delta:
+                        text = delta["text"]
+                        response_content.append(text)
+                        # Update streaming progress
+                        estimated_tokens = self._update_streaming_progress(
+                            text, model, estimated_tokens
+                        )
+                    elif "toolUse" in delta:
+                        # Tool use delta - handle tool call
+                        tool_use = delta["toolUse"]
+                        self.logger.debug(f"Tool use delta: {tool_use}")
+                        if tool_use and tool_uses:
+                            # Handle input accumulation for streaming tool arguments
+                            if "input" in tool_use:
+                                input_data = tool_use["input"]
+
+                                # If input is a dict, merge it directly
+                                if isinstance(input_data, dict):
+                                    tool_uses[-1]["toolUse"]["input"].update(input_data)
+                                # If input is a string, accumulate it for later JSON parsing
+                                elif isinstance(input_data, str):
+                                    tool_uses[-1]["toolUse"]["_input_accumulator"] += input_data
+                                    self.logger.debug(
+                                        f"Accumulated input: {tool_uses[-1]['toolUse']['_input_accumulator']}"
+                                    )
+                                else:
+                                    self.logger.debug(
+                                        f"Tool use input is unexpected type: {type(input_data)}: {input_data}"
+                                    )
+                                    # Set the input directly if it's not a dict or string
+                                    tool_uses[-1]["toolUse"]["input"] = input_data
+                elif "contentBlockStop" in event:
+                    # Content block stopped - finalize any accumulated tool input
+                    if tool_uses:
+                        for tool_use in tool_uses:
+                            if "_input_accumulator" in tool_use["toolUse"]:
+                                accumulated_input = tool_use["toolUse"]["_input_accumulator"]
+                                if accumulated_input:
+                                    self.logger.debug(
+                                        f"Processing accumulated input: {accumulated_input}"
+                                    )
+                                    try:
+                                        # Try to parse the accumulated input as JSON
+                                        parsed_input = json.loads(accumulated_input)
+                                        if isinstance(parsed_input, dict):
+                                            tool_use["toolUse"]["input"].update(parsed_input)
+                                        else:
+                                            tool_use["toolUse"]["input"] = parsed_input
+                                        self.logger.debug(
+                                            f"Successfully parsed accumulated input: {parsed_input}"
+                                        )
+                                    except json.JSONDecodeError as e:
+                                        self.logger.warning(
+                                            f"Failed to parse accumulated input as JSON: {accumulated_input} - {e}"
+                                        )
+                                        # If it's not valid JSON, wrap it as a dict to avoid downstream errors
+                                        tool_use["toolUse"]["input"] = {"value": accumulated_input}
+                                # Clean up the accumulator
+                                del tool_use["toolUse"]["_input_accumulator"]
+                    continue
+                elif "messageStop" in event:
+                    # Message stopped
+                    if "stopReason" in event["messageStop"]:
+                        stop_reason = event["messageStop"]["stopReason"]
+                elif "metadata" in event:
+                    # Usage metadata
|
|
1116
|
+
metadata = event["metadata"]
|
|
1117
|
+
if "usage" in metadata:
|
|
1118
|
+
usage = metadata["usage"]
|
|
1119
|
+
actual_tokens = usage.get("outputTokens", 0)
|
|
1120
|
+
if actual_tokens > 0:
|
|
1121
|
+
# Emit final progress with actual token count
|
|
1122
|
+
token_str = str(actual_tokens).rjust(5)
|
|
1123
|
+
data = {
|
|
1124
|
+
"progress_action": ProgressAction.STREAMING,
|
|
1125
|
+
"model": model,
|
|
1126
|
+
"agent_name": self.name,
|
|
1127
|
+
"chat_turn": self.chat_turn(),
|
|
1128
|
+
"details": token_str.strip(),
|
|
1129
|
+
}
|
|
1130
|
+
self.logger.info("Streaming progress", data=data)
|
|
1131
|
+
except Exception as e:
|
|
1132
|
+
self.logger.error(f"Error processing stream: {e}")
|
|
1133
|
+
raise
|
|
1134
|
+
|
|
1135
|
+
# Construct the response message
|
|
1136
|
+
full_text = "".join(response_content)
|
|
1137
|
+
response = {
|
|
1138
|
+
"content": [{"text": full_text}] if full_text else [],
|
|
1139
|
+
"stop_reason": stop_reason or "end_turn",
|
|
1140
|
+
"usage": {
|
|
1141
|
+
"input_tokens": usage.get("inputTokens", 0),
|
|
1142
|
+
"output_tokens": usage.get("outputTokens", 0),
|
|
1143
|
+
},
|
|
1144
|
+
"model": model,
|
|
1145
|
+
"role": "assistant",
|
|
1146
|
+
}
|
|
1147
|
+
|
|
1148
|
+
# Add tool uses if any
|
|
1149
|
+
if tool_uses:
|
|
1150
|
+
# Clean up any remaining accumulators before adding to response
|
|
1151
|
+
for tool_use in tool_uses:
|
|
1152
|
+
if "_input_accumulator" in tool_use["toolUse"]:
|
|
1153
|
+
accumulated_input = tool_use["toolUse"]["_input_accumulator"]
|
|
1154
|
+
if accumulated_input:
|
|
1155
|
+
self.logger.debug(
|
|
1156
|
+
f"Final processing of accumulated input: {accumulated_input}"
|
|
1157
|
+
)
|
|
1158
|
+
try:
|
|
1159
|
+
# Try to parse the accumulated input as JSON
|
|
1160
|
+
parsed_input = json.loads(accumulated_input)
|
|
1161
|
+
if isinstance(parsed_input, dict):
|
|
1162
|
+
tool_use["toolUse"]["input"].update(parsed_input)
|
|
1163
|
+
else:
|
|
1164
|
+
tool_use["toolUse"]["input"] = parsed_input
|
|
1165
|
+
self.logger.debug(
|
|
1166
|
+
f"Successfully parsed final accumulated input: {parsed_input}"
|
|
1167
|
+
)
|
|
1168
|
+
except json.JSONDecodeError as e:
|
|
1169
|
+
self.logger.warning(
|
|
1170
|
+
f"Failed to parse final accumulated input as JSON: {accumulated_input} - {e}"
|
|
1171
|
+
)
|
|
1172
|
+
# If it's not valid JSON, wrap it as a dict to avoid downstream errors
|
|
1173
|
+
tool_use["toolUse"]["input"] = {"value": accumulated_input}
|
|
1174
|
+
# Clean up the accumulator
|
|
1175
|
+
del tool_use["toolUse"]["_input_accumulator"]
|
|
1176
|
+
|
|
1177
|
+
response["content"].extend(tool_uses)
|
|
1178
|
+
|
|
1179
|
+
return response
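
The streamed tool-argument handling above boils down to: accumulate string fragments per toolUse block, then parse them as JSON once the block stops. A self-contained sketch of that pattern; the names and sample fragments are hypothetical.

```python
import json

# Simulated Converse stream deltas for one toolUse block: the JSON arguments
# arrive as string fragments and are only parseable once the block stops.
deltas = ['{"city": "Lon', 'don", "units": "metric"}']

tool_use = {"toolUseId": "t-1", "name": "get_weather", "input": {}, "_acc": ""}
for fragment in deltas:
    tool_use["_acc"] += fragment  # contentBlockDelta: accumulate raw text

raw = tool_use.pop("_acc")        # contentBlockStop: finalize
try:
    parsed = json.loads(raw)
    tool_use["input"] = parsed if isinstance(parsed, dict) else {"value": parsed}
except json.JSONDecodeError:
    tool_use["input"] = {"value": raw}

assert tool_use["input"] == {"city": "London", "units": "metric"}
```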

    def _process_non_streaming_response(self, response, model: str) -> BedrockMessage:
        """Process non-streaming response from Bedrock."""
        self.logger.debug(f"Processing non-streaming response: {response}")

        # Extract response content
        content = response.get("output", {}).get("message", {}).get("content", [])
        usage = response.get("usage", {})
        stop_reason = response.get("stopReason", "end_turn")

        # Show progress for non-streaming (single update)
        if usage.get("outputTokens", 0) > 0:
            token_str = str(usage.get("outputTokens", 0)).rjust(5)
            data = {
                "progress_action": ProgressAction.STREAMING,
                "model": model,
                "agent_name": self.name,
                "chat_turn": self.chat_turn(),
                "details": token_str.strip(),
            }
            self.logger.info("Non-streaming progress", data=data)

        # Convert to the same format as streaming response
        processed_response = {
            "content": content,
            "stop_reason": stop_reason,
            "usage": {
                "input_tokens": usage.get("inputTokens", 0),
                "output_tokens": usage.get("outputTokens", 0),
            },
            "model": model,
            "role": "assistant",
        }

        return processed_response
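
A reduced sketch of the normalization performed here, run against a hypothetical Converse-style payload; the model id and numbers are made up.

```python
def normalize(response: dict, model: str) -> dict:
    """Collapse a Converse-style response into a provider-neutral shape."""
    usage = response.get("usage", {})
    return {
        "content": response.get("output", {}).get("message", {}).get("content", []),
        "stop_reason": response.get("stopReason", "end_turn"),
        "usage": {
            "input_tokens": usage.get("inputTokens", 0),
            "output_tokens": usage.get("outputTokens", 0),
        },
        "model": model,
        "role": "assistant",
    }


sample = {
    "output": {"message": {"role": "assistant", "content": [{"text": "Hello"}]}},
    "stopReason": "end_turn",
    "usage": {"inputTokens": 12, "outputTokens": 3},
}
assert normalize(sample, "example.model-v1")["usage"]["output_tokens"] == 3
```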

    async def _bedrock_completion(
        self,
        message_param: BedrockMessageParam,
        request_params: RequestParams | None = None,
        tools: list[Tool] | None = None,
        pre_messages: list[BedrockMessageParam] | None = None,
        history: list[PromptMessageExtended] | None = None,
    ) -> PromptMessageExtended:
        """
        Process a query using Bedrock and available tools.
        Returns PromptMessageExtended with tool calls for external execution.
        """
        client = self._get_bedrock_runtime_client()

        try:
            messages: list[BedrockMessageParam] = list(pre_messages) if pre_messages else []
            params = self.get_request_params(request_params)
        except (ClientError, BotoCoreError) as e:
            error_msg = str(e)
            if "UnauthorizedOperation" in error_msg or "AccessDenied" in error_msg:
                raise ProviderKeyError(
                    "AWS Bedrock access denied",
                    "Please check your AWS credentials and IAM permissions for Bedrock.",
                ) from e
            else:
                raise ProviderKeyError(
                    "AWS Bedrock error",
                    f"Error accessing Bedrock: {error_msg}",
                ) from e

        # Convert supplied history/messages directly
        if history:
            messages.extend(self._convert_to_provider_format(history))
        else:
            messages.append(message_param)

        # Get available tools (no resolver gating; fallback logic will decide wiring)
        tool_list = None

        try:
            tool_list = await self.aggregator.list_tools()
            self.logger.debug(f"Found {len(tool_list.tools)} MCP tools")
        except Exception as e:
            self.logger.error(f"Error fetching MCP tools: {e}")
            import traceback

            self.logger.debug(f"Traceback: {traceback.format_exc()}")
            tool_list = None

        # Use tools parameter if provided, otherwise get from aggregator
        if tools is None:
            tools = tool_list.tools if tool_list else []
        elif tool_list is None and tools:
            # Create a ListToolsResult from the provided tools for conversion
            from mcp.types import ListToolsResult

            tool_list = ListToolsResult(tools=tools)

        response_content_blocks: list[ContentBlock] = []
        model = self.default_request_params.model

        # Single API call - no tool execution loop
        self._log_chat_progress(self.chat_turn(), model=model)

        # Convert messages to Bedrock format
        bedrock_messages = self._convert_messages_to_bedrock(messages)

        # Base system text
        base_system_text = self.instruction or params.systemPrompt

        # Determine tool schema fallback order and caches
        caps = self.capabilities.get(model) or ModelCapabilities()
        if caps.schema and caps.schema != ToolSchemaType.NONE:
            # Special case: Force Mistral 7B to try SYSTEM_PROMPT instead of cached DEFAULT
            if (
                model == "mistral.mistral-7b-instruct-v0:2"
                and caps.schema == ToolSchemaType.DEFAULT
            ):
                print(
                    f"🔧 FORCING SYSTEM_PROMPT for {model} (was cached as DEFAULT)",
                    file=sys.stderr,
                    flush=True,
                )
                schema_order = [ToolSchemaType.SYSTEM_PROMPT, ToolSchemaType.DEFAULT]
            else:
                schema_order = [caps.schema]
        else:
            # Restore original fallback order: Anthropic models try anthropic first, others skip it
            if model.startswith("anthropic."):
                schema_order = [
                    ToolSchemaType.ANTHROPIC,
                    ToolSchemaType.DEFAULT,
                    ToolSchemaType.SYSTEM_PROMPT,
                ]
            elif model == "mistral.mistral-7b-instruct-v0:2":
                # Force Mistral 7B to try SYSTEM_PROMPT first (it doesn't work well with DEFAULT)
                schema_order = [
                    ToolSchemaType.SYSTEM_PROMPT,
                    ToolSchemaType.DEFAULT,
                ]
            else:
                schema_order = [
                    ToolSchemaType.DEFAULT,
                    ToolSchemaType.SYSTEM_PROMPT,
                ]

        # Track whether we changed system mode cache this turn
        tried_system_fallback = False

        processed_response = None  # type: ignore[assignment]
        last_error_msg = None

        for schema_choice in schema_order:
            # Fresh messages per attempt
            converse_args = {"modelId": model, "messages": [dict(m) for m in bedrock_messages]}

            # Build tools representation for this schema
            tools_payload: Union[list[dict[str, Any]], str, None] = None
            # Get tool name policy (needed even when no tools for cache logic)
            name_policy = (
                self.capabilities.get(model) or ModelCapabilities()
            ).tool_name_policy or ToolNamePolicy.PRESERVE

            if tool_list and tool_list.tools:
                # Build tool name mapping once per schema attempt
                tool_name_mapping = self._build_tool_name_mapping(tool_list, name_policy)

                # Store mapping for tool execution
                self.tool_name_mapping = tool_name_mapping

                if schema_choice == ToolSchemaType.ANTHROPIC:
                    tools_payload = self._convert_tools_anthropic_format(
                        tool_list, tool_name_mapping
                    )
                elif schema_choice == ToolSchemaType.DEFAULT:
                    tools_payload = self._convert_tools_nova_format(tool_list, tool_name_mapping)
                elif schema_choice == ToolSchemaType.SYSTEM_PROMPT:
                    tools_payload = self._convert_tools_system_prompt_format(
                        tool_list, tool_name_mapping
                    )

            # System prompt handling with cache
            system_mode = (
                self.capabilities.get(model) or ModelCapabilities()
            ).system_mode or SystemMode.SYSTEM
            system_text = base_system_text

            if (
                schema_choice == ToolSchemaType.SYSTEM_PROMPT
                and isinstance(tools_payload, str)
                and tools_payload
            ):
                system_text = f"{system_text}\n\n{tools_payload}" if system_text else tools_payload

            # Cohere-specific nudge: force exact echo of tool result text on final answer
            if (
                schema_choice == ToolSchemaType.SYSTEM_PROMPT
                and isinstance(model, str)
                and model.startswith("cohere.")
            ):
                cohere_nudge = (
                    "FINAL ANSWER RULES (STRICT):\n"
                    "- When a tool result is provided, your final answer MUST be exactly the raw tool result text.\n"
                    "- Do not add any extra words, punctuation, qualifiers, or phrases (e.g., 'according to the tool').\n"
                    "- Example: If tool result text is 'It's sunny in London', your final answer must be exactly: It's sunny in London\n"
                )
                system_text = f"{system_text}\n\n{cohere_nudge}" if system_text else cohere_nudge

            # Llama3-specific nudge: prevent paraphrasing and extra tool calls
            if (
                schema_choice == ToolSchemaType.SYSTEM_PROMPT
                and isinstance(model, str)
                and model.startswith("meta.llama3")
            ):
                llama_nudge = (
                    "TOOL RESPONSE RULES:\n"
                    "- After receiving a tool result, immediately output ONLY the exact tool result text.\n"
                    "- Do not call additional tools or add commentary.\n"
                    "- Do not paraphrase or modify the tool result in any way."
                )
                system_text = f"{system_text}\n\n{llama_nudge}" if system_text else llama_nudge

            # Mistral-specific nudge: prevent tool calling loops and accept tool results
            if (
                schema_choice == ToolSchemaType.SYSTEM_PROMPT
                and isinstance(model, str)
                and model.startswith("mistral.")
            ):
                mistral_nudge = (
                    "TOOL EXECUTION RULES:\n"
                    "- Call each tool only ONCE per conversation turn.\n"
                    "- Accept and trust all tool results - do not question or retry them.\n"
                    "- After receiving a tool result, provide a direct answer based on that result.\n"
                    "- Do not call the same tool multiple times or call additional tools unless specifically requested.\n"
                    "- Tool results are always valid - do not attempt to validate or correct them."
                )
                system_text = f"{system_text}\n\n{mistral_nudge}" if system_text else mistral_nudge

            if system_text:
                if system_mode == SystemMode.SYSTEM:
                    converse_args["system"] = [{"text": system_text}]
                    self.logger.debug(
                        f"Attempting with system param for {model} and schema={schema_choice}"
                    )
                else:
                    # inject
                    if (
                        converse_args["messages"]
                        and converse_args["messages"][0].get("role") == "user"
                    ):
                        first_message = converse_args["messages"][0]
                        if first_message.get("content") and len(first_message["content"]) > 0:
                            original_text = first_message["content"][0].get("text", "")
                            first_message["content"][0]["text"] = (
                                f"System: {system_text}\n\nUser: {original_text}"
                            )
                            self.logger.debug(
                                "Injected system prompt into first user message (cached mode)"
                            )

            # Tools wiring
            # Always include toolConfig if we have tools OR if there are tool results in the conversation
            has_tool_results = False
            for msg in bedrock_messages:
                if isinstance(msg, dict) and msg.get("content"):
                    for content in msg["content"]:
                        if isinstance(content, dict) and "toolResult" in content:
                            has_tool_results = True
                            break
                if has_tool_results:
                    break

            if (
                schema_choice in (ToolSchemaType.ANTHROPIC, ToolSchemaType.DEFAULT)
                and isinstance(tools_payload, list)
                and tools_payload
            ):
                # Include tools only when we have actual tools to provide
                converse_args["toolConfig"] = {"tools": tools_payload}

            # Inference configuration and overrides
            inference_config: dict[str, Any] = {}
            if params.maxTokens is not None:
                inference_config["maxTokens"] = params.maxTokens
            if params.stopSequences:
                inference_config["stopSequences"] = params.stopSequences

            # Check if reasoning should be enabled
            reasoning_budget = 0
            if self._reasoning_effort and self._reasoning_effort != ReasoningEffort.MINIMAL:
                # Convert string to enum if needed
                if isinstance(self._reasoning_effort, str):
                    try:
                        effort_enum = ReasoningEffort(self._reasoning_effort)
                    except ValueError:
                        effort_enum = ReasoningEffort.MINIMAL
                else:
                    effort_enum = self._reasoning_effort

                if effort_enum != ReasoningEffort.MINIMAL:
                    reasoning_budget = REASONING_EFFORT_BUDGETS.get(effort_enum, 0)

            # Handle temperature and reasoning configuration
            # AWS docs: "Thinking isn't compatible with temperature, top_p, or top_k modifications"
            reasoning_enabled = False
            if reasoning_budget > 0:
                # Check if this model supports reasoning (with caching)
                # Note: the cache stores booleans (True/False/None), so compare against
                # those rather than string sentinels.
                cached_reasoning = (
                    self.capabilities.get(model) or ModelCapabilities()
                ).reasoning_support
                if cached_reasoning is True:
                    # We know this model supports reasoning
                    converse_args["performanceConfig"] = {
                        "reasoning": {"maxReasoningTokens": reasoning_budget}
                    }
                    reasoning_enabled = True
                elif cached_reasoning is not False:
                    # Unknown - we'll try reasoning and fallback if needed
                    converse_args["performanceConfig"] = {
                        "reasoning": {"maxReasoningTokens": reasoning_budget}
                    }
                    reasoning_enabled = True

            if not reasoning_enabled:
                # No reasoning - apply temperature if provided
                if params.temperature is not None:
                    inference_config["temperature"] = params.temperature

            # Nova-specific recommendations (when not using reasoning)
            if model and "nova" in (model or "").lower() and reasoning_budget == 0:
                inference_config.setdefault("topP", 1.0)
                # Merge/attach additionalModelRequestFields for topK
                existing_amrf = converse_args.get("additionalModelRequestFields", {})
                merged_amrf = {**existing_amrf, **{"inferenceConfig": {"topK": 1}}}
                converse_args["additionalModelRequestFields"] = merged_amrf

            if inference_config:
                converse_args["inferenceConfig"] = inference_config

            # Decide streaming vs non-streaming (resolver-free with runtime detection + cache)
            has_tools: bool = False
            try:
                has_tools = bool(tools_payload) and bool(
                    (isinstance(tools_payload, list) and len(tools_payload) > 0)
                    or (isinstance(tools_payload, str) and tools_payload.strip())
                )

                # Force non-streaming for structured-output flows (one-shot)
                force_non_streaming = False
                if self._force_non_streaming_once:
                    force_non_streaming = True
                    self._force_non_streaming_once = False

                # Evaluate cache for streaming-with-tools
                cache_pref = (self.capabilities.get(model) or ModelCapabilities()).stream_with_tools
                use_streaming = True
                attempted_streaming = False

                if force_non_streaming:
                    use_streaming = False
                elif has_tools:
                    if cache_pref == StreamPreference.NON_STREAM:
                        use_streaming = False
                    elif cache_pref == StreamPreference.STREAM_OK:
                        use_streaming = True
                    else:
                        # Unknown: try streaming first, fallback on error
                        use_streaming = True

                # NEW: For Anthropic schema, when tool results are present in the conversation,
                # force non-streaming on this second turn to avoid empty streamed replies.
                if schema_choice == ToolSchemaType.ANTHROPIC and has_tool_results:
                    use_streaming = False
                    self.logger.debug(
                        "Forcing non-streaming for Anthropic second turn with tool results"
                    )
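
                # Decision recap: force_non_streaming wins; with tools, a cached
                # NON_STREAM preference disables streaming while STREAM_OK (or an
                # undecided cache) keeps it on; the Anthropic schema with prior
                # toolResult blocks always drops to non-streaming.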

                # Try API call with reasoning fallback
                try:
                    if not use_streaming:
                        self.logger.debug(
                            f"Using non-streaming API for {model} (schema={schema_choice})"
                        )
                        response = client.converse(**converse_args)
                        processed_response = self._process_non_streaming_response(response, model)
                    else:
                        self.logger.debug(
                            f"Using streaming API for {model} (schema={schema_choice})"
                        )
                        attempted_streaming = True
                        response = client.converse_stream(**converse_args)
                        processed_response = await self._process_stream(
                            response, model
                        )
                except (ClientError, BotoCoreError) as e:
                    # Check if this is a reasoning-related error
                    if reasoning_budget > 0 and (
                        "reasoning" in str(e).lower() or "performance" in str(e).lower()
                    ):
                        self.logger.debug(
                            f"Model {model} doesn't support reasoning, retrying without: {e}"
                        )
                        caps.reasoning_support = False
                        self.capabilities[model] = caps

                        # Remove reasoning and retry
                        if "performanceConfig" in converse_args:
                            del converse_args["performanceConfig"]

                        # Apply temperature now that reasoning is disabled
                        if params.temperature is not None:
                            if "inferenceConfig" not in converse_args:
                                converse_args["inferenceConfig"] = {}
                            converse_args["inferenceConfig"]["temperature"] = params.temperature

                        # Retry the API call
                        if not use_streaming:
                            response = client.converse(**converse_args)
                            processed_response = self._process_non_streaming_response(
                                response, model
                            )
                        else:
                            response = client.converse_stream(**converse_args)
                            processed_response = await self._process_stream(
                                response, model
                            )
                    else:
                        # Not a reasoning error, re-raise
                        raise

                # Success: cache the working schema choice if not already cached
                # Only cache schema when tools are present - no tools doesn't predict tool behavior
                if not caps.schema and has_tools:
                    caps.schema = ToolSchemaType(schema_choice)

                # Cache successful reasoning if we tried it
                if reasoning_budget > 0 and caps.reasoning_support is not True:
                    caps.reasoning_support = True

                # If Nova/default worked and we used preserve but server complains, flip cache for next time
                if (
                    schema_choice == ToolSchemaType.DEFAULT
                    and name_policy == ToolNamePolicy.PRESERVE
                ):
                    # Heuristic: if tool names include '-', prefer underscores next time
                    try:
                        if any("-" in t.name for t in (tool_list.tools if tool_list else [])):
                            caps.tool_name_policy = ToolNamePolicy.UNDERSCORES
                    except Exception:
                        pass
                # Cache streaming-with-tools behavior on success
                if has_tools and attempted_streaming:
                    caps.stream_with_tools = StreamPreference.STREAM_OK
                self.capabilities[model] = caps
                break
            except (ClientError, BotoCoreError) as e:
                error_msg = str(e)
                last_error_msg = error_msg
                self.logger.debug(f"Bedrock API error (schema={schema_choice}): {error_msg}")

                # If streaming with tools failed and cache undecided, fallback to non-streaming and cache
                if has_tools and (caps.stream_with_tools is None):
                    try:
                        self.logger.debug(
                            f"Falling back to non-streaming API for {model} after streaming error"
                        )
                        response = client.converse(**converse_args)
                        processed_response = self._process_non_streaming_response(response, model)
                        caps.stream_with_tools = StreamPreference.NON_STREAM
                        if not caps.schema:
                            caps.schema = ToolSchemaType(schema_choice)
                        self.capabilities[model] = caps
                        break
                    except (ClientError, BotoCoreError) as e_fallback:
                        last_error_msg = str(e_fallback)
                        self.logger.debug(
                            f"Bedrock API error after non-streaming fallback: {last_error_msg}"
                        )
                        # continue to other fallbacks (e.g., system inject or next schema)

                # System parameter fallback once per call if system message unsupported
                if (
                    not tried_system_fallback
                    and system_text
                    and system_mode == SystemMode.SYSTEM
                    and (
                        "system message" in error_msg.lower()
                        or "system messages" in error_msg.lower()
                    )
                ):
                    tried_system_fallback = True
                    caps.system_mode = SystemMode.INJECT
                    self.capabilities[model] = caps
                    self.logger.info(
                        f"Switching system mode to inject for {model} and retrying same schema"
                    )
                    # Retry the same schema immediately in inject mode
                    try:
                        # Rebuild messages for inject
                        converse_args = {
                            "modelId": model,
                            "messages": [dict(m) for m in bedrock_messages],
                        }
                        # inject system into first user
                        if (
                            converse_args["messages"]
                            and converse_args["messages"][0].get("role") == "user"
                        ):
                            fm = converse_args["messages"][0]
                            if fm.get("content") and len(fm["content"]) > 0:
                                original_text = fm["content"][0].get("text", "")
                                fm["content"][0]["text"] = (
                                    f"System: {system_text}\n\nUser: {original_text}"
                                )

                        # Re-add tools
                        if (
                            schema_choice
                            in (ToolSchemaType.ANTHROPIC.value, ToolSchemaType.DEFAULT.value)
                            and isinstance(tools_payload, list)
                            and tools_payload
                        ):
                            converse_args["toolConfig"] = {"tools": tools_payload}

                        # Same streaming decision using cache
                        has_tools = bool(tools_payload) and bool(
                            (isinstance(tools_payload, list) and len(tools_payload) > 0)
                            or (isinstance(tools_payload, str) and tools_payload.strip())
                        )
                        cache_pref = (
                            self.capabilities.get(model) or ModelCapabilities()
                        ).stream_with_tools
                        if cache_pref == StreamPreference.NON_STREAM or not has_tools:
                            response = client.converse(**converse_args)
                            processed_response = self._process_non_streaming_response(
                                response, model
                            )
                        else:
                            response = client.converse_stream(**converse_args)
                            processed_response = await self._process_stream(
                                response, model
                            )
                        if not caps.schema and has_tools:
                            caps.schema = ToolSchemaType(schema_choice)
                        self.capabilities[model] = caps
                        break
                    except (ClientError, BotoCoreError) as e2:
                        last_error_msg = str(e2)
                        self.logger.debug(
                            f"Bedrock API error after system inject fallback: {last_error_msg}"
                        )
                        # Fall through to next schema
                        continue

                # For any other error (including tool format errors), continue to next schema
                self.logger.debug(
                    f"Continuing to next schema after error with {schema_choice}: {error_msg}"
                )
                continue

        if processed_response is None:
            # All attempts failed; mark schema as none to avoid repeated retries this process
            caps.schema = ToolSchemaType.NONE
            self.capabilities[model] = caps
            processed_response = {
                "content": [
                    {"text": f"Error during generation: {last_error_msg or 'Unknown error'}"}
                ],
                "stop_reason": "error",
                "usage": {"input_tokens": 0, "output_tokens": 0},
                "model": model,
                "role": "assistant",
            }

        # Track usage
        if processed_response.get("usage"):
            try:
                usage = processed_response["usage"]
                turn_usage = TurnUsage(
                    provider=Provider.BEDROCK.value,
                    model=model,
                    input_tokens=usage.get("input_tokens", 0),
                    output_tokens=usage.get("output_tokens", 0),
                    total_tokens=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
                    raw_usage=usage,
                )
                self.usage_accumulator.add_turn(turn_usage)
            except Exception as e:
                self.logger.warning(f"Failed to track usage: {e}")

        self.logger.debug(f"{model} response:", data=processed_response)

        # Convert response to message param and add to messages
        response_message_param = self.convert_message_to_message_param(processed_response)
        messages.append(response_message_param)

        # Extract text content for responses
        if processed_response.get("content"):
            for content_item in processed_response["content"]:
                if content_item.get("text"):
                    response_content_blocks.append(
                        TextContent(type="text", text=content_item["text"])
                    )

        # Fallback: if no content returned and the last input contained tool results,
        # synthesize the assistant reply using the tool result text to preserve behavior.
        if not response_content_blocks:
            try:
                # messages currently includes the appended assistant response; inspect the prior user message
                last_index = len(messages) - 2 if len(messages) >= 2 else (len(messages) - 1)
                last_input = messages[last_index] if last_index >= 0 else None
                if isinstance(last_input, dict):
                    contents = last_input.get("content", []) or []
                    for c in contents:
                        # Handle parameter-level representation
                        if isinstance(c, dict) and c.get("type") == "tool_result":
                            tr_content = c.get("content", []) or []
                            fallback_text = " ".join(
                                part.get("text", "")
                                for part in tr_content
                                if isinstance(part, dict)
                            ).strip()
                            if fallback_text:
                                response_content_blocks.append(
                                    TextContent(type="text", text=fallback_text)
                                )
                                break
                        # Handle bedrock-level representation
                        if isinstance(c, dict) and "toolResult" in c:
                            tr = c["toolResult"]
                            tr_content = tr.get("content", []) or []
                            fallback_text = " ".join(
                                part.get("text", "")
                                for part in tr_content
                                if isinstance(part, dict)
                            ).strip()
                            if fallback_text:
                                response_content_blocks.append(
                                    TextContent(type="text", text=fallback_text)
                                )
                                break
            except Exception:
                pass

        # Handle different stop reasons
        stop_reason = processed_response.get("stop_reason", "end_turn")

        # Determine if we should parse for system-prompt tool calls (unified capabilities)
        caps_tmp = self.capabilities.get(model) or ModelCapabilities()

        # Try to parse system prompt tool calls if we have an end_turn with tools available
        # This handles cases where native tool calling failed but model generates system prompt format
        if stop_reason == "end_turn" and tools:
            # Only parse for tools if text contains actual function call structure
            message_text = ""
            for content_item in processed_response.get("content", []):
                if isinstance(content_item, dict) and "text" in content_item:
                    message_text += content_item.get("text", "")

            # Check if there's a tool call in the response
            parsed_tools = self._parse_tool_response(processed_response, model)
            if parsed_tools:
                # Override stop_reason to handle as tool_use
                stop_reason = "tool_use"
                # Update capabilities cache to reflect successful system prompt tool calling
                if not caps_tmp.schema:
                    caps_tmp.schema = ToolSchemaType.SYSTEM_PROMPT
                    self.capabilities[model] = caps_tmp

        # NEW: Handle tool calls without execution - return them for external handling
        tool_calls: dict[str, CallToolRequest] | None = None
        if stop_reason in ["tool_use", "tool_calls"]:
            parsed_tools = self._parse_tool_response(processed_response, model)
            if parsed_tools:
                tool_calls = self._build_tool_calls_dict(parsed_tools)

        # Map stop reason to LlmStopReason
        mapped_stop_reason = self._map_bedrock_stop_reason(stop_reason)

        # Update diagnostic snapshot (never read again)
        # This provides a snapshot of what was sent to the provider for debugging
        self.history.set(messages)

        self._log_chat_finished(model=model)

        # Return PromptMessageExtended with tool calls for external execution
        from fast_agent.core.prompt import Prompt

        return Prompt.assistant(
            *response_content_blocks, stop_reason=mapped_stop_reason, tool_calls=tool_calls
        )
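
The heart of `_bedrock_completion` is a try-in-order fallback over tool-schema formats with a per-model capability cache. A reduced, hypothetical sketch of that pattern; the `Schema` enum, `call_converse` stub, and cache are invented for illustration and are not fast-agent APIs.

```python
from enum import Enum


class Schema(Enum):
    ANTHROPIC = "anthropic"
    DEFAULT = "default"
    SYSTEM_PROMPT = "system_prompt"


capability_cache: dict[str, Schema] = {}  # model id -> first schema that worked


def call_converse(model: str, schema: Schema) -> str:
    """Stand-in for the real Bedrock call; pretend only SYSTEM_PROMPT works here."""
    if schema is not Schema.SYSTEM_PROMPT:
        raise RuntimeError(f"{schema.value} rejected by {model}")
    return "ok"


def complete(model: str) -> str:
    # Use the cached schema when we already know what works; otherwise try them all.
    order = [capability_cache[model]] if model in capability_cache else list(Schema)
    last_error = None
    for schema in order:
        try:
            result = call_converse(model, schema)
            capability_cache.setdefault(model, schema)  # remember what worked
            return result
        except RuntimeError as exc:
            last_error = exc
            continue
    return f"Error during generation: {last_error}"


assert complete("example.model-v1") == "ok"
assert capability_cache["example.model-v1"] is Schema.SYSTEM_PROMPT
```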

    async def _apply_prompt_provider_specific(
        self,
        multipart_messages: list[PromptMessageExtended],
        request_params: RequestParams | None = None,
        tools: list[Tool] | None = None,
        is_template: bool = False,
    ) -> PromptMessageExtended:
        """
        Provider-specific prompt application.
        Templates are handled by the agent; messages already include them.
        """
        if not multipart_messages:
            return PromptMessageExtended(role="user", content=[])

        # Check the last message role
        last_message = multipart_messages[-1]

        if last_message.role == "assistant":
            # For assistant messages: Return the last message (no completion needed)
            return last_message

        # Convert the last user message to Bedrock message parameter format
        message_param = BedrockConverter.convert_to_bedrock(last_message)

        # Call the completion method
        # No need to pass pre_messages - conversion happens in _bedrock_completion
        # via _convert_to_provider_format()
        return await self._bedrock_completion(
            message_param,
            request_params,
            tools,
            pre_messages=None,
            history=multipart_messages,
        )

    def _generate_simplified_schema(self, model: Type[ModelT]) -> str:
        """Generates a simplified, human-readable schema with inline enum constraints."""

        def get_field_type_representation(field_type: Any) -> Any:
            """Get a string representation for a field type."""
            # Handle Optional types
            if hasattr(field_type, "__origin__") and field_type.__origin__ is Union:
                non_none_types = [t for t in field_type.__args__ if t is not type(None)]
                if non_none_types:
                    field_type = non_none_types[0]

            # Handle basic types
            if field_type is str:
                return "string"
            elif field_type is int:
                return "integer"
            elif field_type is float:
                return "float"
            elif field_type is bool:
                return "boolean"

            # Handle Enum types
            elif hasattr(field_type, "__bases__") and any(
                issubclass(base, Enum) for base in field_type.__bases__ if isinstance(base, type)
            ):
                enum_values = [f'"{e.value}"' for e in field_type]
                return f"string (must be one of: {', '.join(enum_values)})"

            # Handle List types
            elif (
                hasattr(field_type, "__origin__")
                and hasattr(field_type, "__args__")
                and field_type.__origin__ is list
            ):
                item_type_repr = "any"
                if field_type.__args__:
                    item_type_repr = get_field_type_representation(field_type.__args__[0])
                return [item_type_repr]

            # Handle nested Pydantic models
            elif hasattr(field_type, "__bases__") and any(
                hasattr(base, "model_fields") for base in field_type.__bases__
            ):
                nested_schema = _generate_schema_dict(field_type)
                return nested_schema

            # Default fallback
            else:
                return "any"

        def _generate_schema_dict(model_class: Type) -> dict[str, Any]:
            """Recursively generate the schema as a dictionary."""
            schema_dict = {}
            if hasattr(model_class, "model_fields"):
                for field_name, field_info in model_class.model_fields.items():
                    schema_dict[field_name] = get_field_type_representation(field_info.annotation)
            return schema_dict

        schema = _generate_schema_dict(model)
        return json.dumps(schema, indent=2)
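
A hedged, standalone illustration of the simplified-schema idea, assuming Pydantic v2 is available; the `Reply`/`Mood` models and the `simplified` helper are invented for this example and cover far fewer cases than the method above.

```python
from enum import Enum
from typing import Any

from pydantic import BaseModel  # assumed dependency


class Mood(str, Enum):
    HAPPY = "happy"
    SAD = "sad"


class Reply(BaseModel):
    message: str
    mood: Mood


def simplified(model_cls: type[BaseModel]) -> dict[str, Any]:
    """Tiny restatement of the idea: map each field to a readable type hint."""
    out: dict[str, Any] = {}
    for name, field in model_cls.model_fields.items():
        ann = field.annotation
        if ann is str:
            out[name] = "string"
        elif isinstance(ann, type) and issubclass(ann, Enum):
            values = ", ".join(f'"{e.value}"' for e in ann)
            out[name] = f"string (must be one of: {values})"
        else:
            out[name] = "any"
    return out


print(simplified(Reply))
# e.g. {'message': 'string', 'mood': 'string (must be one of: "happy", "sad")'}
```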

    async def _apply_prompt_provider_specific_structured(
        self,
        multipart_messages: list[PromptMessageExtended],
        model: Type[ModelT],
        request_params: RequestParams | None = None,
    ) -> tuple[ModelT | None, PromptMessageExtended]:
        """Apply structured output for Bedrock using prompt engineering with a simplified schema."""
        # Short-circuit: if the last message is already an assistant JSON payload,
        # parse it directly without invoking the model. This restores pre-regression behavior
        # for tests that seed assistant JSON as the last turn.
        try:
            if multipart_messages and multipart_messages[-1].role == "assistant":
                parsed_model, parsed_mp = self._structured_from_multipart(
                    multipart_messages[-1], model
                )
                if parsed_model is not None:
                    return parsed_model, parsed_mp
        except Exception:
            # Fall through to normal generation path
            pass

        request_params = self.get_request_params(request_params)

        # For structured outputs: disable reasoning entirely and set temperature=0 for deterministic JSON
        # This avoids conflicts between reasoning (requires temperature=1) and structured output (wants temperature=0)
        original_reasoning_effort = self._reasoning_effort
        self._reasoning_effort = ReasoningEffort.MINIMAL  # Temporarily disable reasoning

        # Override temperature for structured outputs
        if request_params:
            request_params = request_params.model_copy(update={"temperature": 0.0})
        else:
            request_params = RequestParams(temperature=0.0)

        # Select schema strategy, prefer runtime cache over resolver
        caps_struct = self.capabilities.get(self.model) or ModelCapabilities()
        strategy = caps_struct.structured_strategy or StructuredStrategy.STRICT_SCHEMA

        if strategy == StructuredStrategy.SIMPLIFIED_SCHEMA:
            schema_text = self._generate_simplified_schema(model)
        else:
            schema_text = FastAgentLLM.model_to_schema_str(model)

        # Build the new simplified prompt
        prompt_parts = [
            "You are a JSON generator. Respond with JSON that strictly follows the provided schema. Do not add any commentary or explanation.",
            "",
            "JSON Schema:",
            schema_text,
            "",
            "IMPORTANT RULES:",
            "- You MUST respond with only raw JSON data. No other text, commentary, or markdown is allowed.",
            "- All field names and enum values are case-sensitive and must match the schema exactly.",
            "- Do not add any extra fields to the JSON response. Only include the fields specified in the schema.",
            "- Do not use code fences or backticks (no ```json and no ```).",
            "- Your output must start with '{' and end with '}'.",
            "- Valid JSON requires double quotes for all field names and string values. Other types (int, float, boolean, etc.) should not be quoted.",
            "",
            "Now, generate the valid JSON response for the following request:",
        ]

        # IMPORTANT: Do NOT mutate the caller's messages. Create a deep copy of the last
        # user message, append the schema to the copy only, and pass just that copy into
        # the provider-specific path. This prevents contamination of routed messages.
        try:
            temp_last = multipart_messages[-1].model_copy(deep=True)
        except Exception:
            # Fallback: construct a minimal copy if model_copy is unavailable
            temp_last = PromptMessageExtended(
                role=multipart_messages[-1].role, content=list(multipart_messages[-1].content)
            )

        temp_last.add_text("\n".join(prompt_parts))

        self.logger.debug(
            "DEBUG: Using copied last message for structured schema; original left untouched"
        )

        try:
            result: PromptMessageExtended = await self._apply_prompt_provider_specific(
                [temp_last], request_params
            )
            try:
                parsed_model, _ = self._structured_from_multipart(result, model)
                # If parsing returned None (no model instance) we should trigger the retry path
                if parsed_model is None:
                    raise ValueError("structured parse returned None; triggering retry")
                return parsed_model, result
            except Exception:
                # One retry with stricter JSON-only guidance and simplified schema
                strict_parts = [
                    "STRICT MODE:",
                    "Return ONLY a single JSON object that matches the schema.",
                    "Do not include any prose, explanations, code fences, or extra characters.",
                    "Start with '{' and end with '}'.",
                    "",
                    "JSON Schema (simplified):",
                ]
                try:
                    simplified_schema_text = self._generate_simplified_schema(model)
                except Exception:
                    simplified_schema_text = FastAgentLLM.model_to_schema_str(model)
                try:
                    temp_last_retry = multipart_messages[-1].model_copy(deep=True)
                except Exception:
                    temp_last_retry = PromptMessageExtended(
                        role=multipart_messages[-1].role,
                        content=list(multipart_messages[-1].content),
                    )
                temp_last_retry.add_text("\n".join(strict_parts + [simplified_schema_text]))

                retry_result: PromptMessageExtended = await self._apply_prompt_provider_specific(
                    [temp_last_retry], request_params
                )
                return self._structured_from_multipart(retry_result, model)
        finally:
            # Restore original reasoning effort
            self._reasoning_effort = original_reasoning_effort
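
The overall structured-output flow above is prompt engineering plus one retry. A simplified, self-contained sketch of that loop with a stand-in model function; it skips the brace-slicing cleanup the real code applies, and all names are hypothetical.

```python
import json


def build_structured_prompt(question: str, schema: dict) -> str:
    return "\n".join([
        "You are a JSON generator. Respond with JSON that strictly follows the provided schema.",
        "JSON Schema:",
        json.dumps(schema, indent=2),
        "Your output must start with '{' and end with '}'.",
        "",
        question,
    ])


def parse_with_retry(call_llm, question: str, schema: dict) -> dict:
    """Ask once, then retry with a stricter instruction if the reply is not valid JSON."""
    reply = call_llm(build_structured_prompt(question, schema))
    try:
        return json.loads(reply)
    except json.JSONDecodeError:
        stricter = "STRICT MODE: return ONLY one JSON object.\n" + build_structured_prompt(question, schema)
        return json.loads(call_llm(stricter))


# A stand-in model that misbehaves on the first call and complies on the second.
replies = iter(['Sure! Here you go: {"answer": 4}', '{"answer": 4}'])


def fake_llm(prompt: str) -> str:
    return next(replies)


assert parse_with_retry(fake_llm, "What is 2 + 2?", {"answer": "integer"}) == {"answer": 4}
```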

    def _clean_json_response(self, text: str) -> str:
        """Clean up JSON response by removing text before first { and after last }.

        Also handles cases where models wrap the response in an extra layer like:
        {"FormattedResponse": {"thinking": "...", "message": "..."}}
        """
        if not text:
            return text

        # Strip common code fences (```json ... ``` or ``` ... ```), anywhere in the text
        try:
            import re as _re

            fence_match = _re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
            if fence_match:
                text = fence_match.group(1)
        except Exception:
            pass

        # Find the first { and last }
        first_brace = text.find("{")
        last_brace = text.rfind("}")

        # If we found both braces, extract just the JSON part
        if first_brace != -1 and last_brace != -1 and first_brace < last_brace:
            json_part = text[first_brace : last_brace + 1]

            # Check if the JSON is wrapped in an extra layer (common model behavior)
            try:
                import json

                parsed = json.loads(json_part)

                # If it's a dict with a single key that matches the model class name,
                # unwrap it (e.g., {"FormattedResponse": {...}} -> {...})
                if isinstance(parsed, dict) and len(parsed) == 1:
                    key = list(parsed.keys())[0]
                    # Common wrapper patterns: class name, "response", "result", etc.
                    if key in [
                        "FormattedResponse",
                        "WeatherResponse",
                        "SimpleResponse",
                    ] or key.endswith("Response"):
                        inner_value = parsed[key]
                        if isinstance(inner_value, dict):
                            return json.dumps(inner_value)

                return json_part
            except json.JSONDecodeError:
                # If parsing fails, return the original JSON part
                return json_part

        # Otherwise return the original text
        return text
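
A standalone mirror of the cleanup steps, handy for seeing the behavior on two typical replies; `clean_json_response` below is a compressed restatement, not the method itself.

```python
import json
import re


def clean_json_response(text: str) -> str:
    """Standalone mirror of the cleaning steps: fence strip, brace slice, unwrap."""
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)```", text)
    if fence:
        text = fence.group(1)
    first, last = text.find("{"), text.rfind("}")
    if first == -1 or last <= first:
        return text
    json_part = text[first : last + 1]
    try:
        parsed = json.loads(json_part)
    except json.JSONDecodeError:
        return json_part
    if isinstance(parsed, dict) and len(parsed) == 1:
        key = next(iter(parsed))
        if key.endswith("Response") and isinstance(parsed[key], dict):
            return json.dumps(parsed[key])
    return json_part


assert clean_json_response('Here it is:\n```json\n{"a": 1}\n```') == '{"a": 1}'
assert clean_json_response('{"FormattedResponse": {"a": 1}}') == '{"a": 1}'
```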

    def _structured_from_multipart(
        self, message: PromptMessageExtended, model: Type[ModelT]
    ) -> tuple[ModelT | None, PromptMessageExtended]:
        """Override to apply JSON cleaning before parsing."""
        # Get the text from the multipart message
        text = message.all_text()

        # Clean the JSON response to remove extra text
        cleaned_text = self._clean_json_response(text)

        # If we cleaned the text, create a new multipart with the cleaned text
        if cleaned_text != text:
            from mcp.types import TextContent

            cleaned_multipart = PromptMessageExtended(
                role=message.role, content=[TextContent(type="text", text=cleaned_text)]
            )
        else:
            cleaned_multipart = message

        # Parse using cleaned multipart first
        model_instance, parsed_multipart = super()._structured_from_multipart(
            cleaned_multipart, model
        )
        if model_instance is not None:
            return model_instance, parsed_multipart
        # Fallback: if parsing failed (e.g., assistant-provided JSON already valid), try original
        return super()._structured_from_multipart(message, model)

    @classmethod
    def convert_message_to_message_param(
        cls, message: BedrockMessage, **kwargs
    ) -> BedrockMessageParam:
        """Convert a Bedrock message to message parameter format."""
        message_param = {"role": message.get("role", "assistant"), "content": []}

        for content_item in message.get("content", []):
            if isinstance(content_item, dict):
                if "text" in content_item:
                    message_param["content"].append({"type": "text", "text": content_item["text"]})
                elif "toolUse" in content_item:
                    tool_use = content_item["toolUse"]
                    tool_input = tool_use.get("input", {})

                    # Ensure tool_input is a dictionary
                    if not isinstance(tool_input, dict):
                        if isinstance(tool_input, str):
                            try:
                                tool_input = json.loads(tool_input) if tool_input else {}
                            except json.JSONDecodeError:
                                tool_input = {}
                        else:
                            tool_input = {}

                    message_param["content"].append(
                        {
                            "type": "tool_use",
                            "id": tool_use.get("toolUseId", ""),
                            "name": tool_use.get("name", ""),
                            "input": tool_input,
                        }
                    )

        return message_param
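
A compact standalone version of this conversion, showing the repair of a toolUse whose input arrived as a JSON string; the sample message and helper name are hypothetical.

```python
import json


def to_message_param(message: dict) -> dict:
    """Sketch of the assistant-message conversion, including string-input repair."""
    param = {"role": message.get("role", "assistant"), "content": []}
    for item in message.get("content", []):
        if "text" in item:
            param["content"].append({"type": "text", "text": item["text"]})
        elif "toolUse" in item:
            tool_use = item["toolUse"]
            tool_input = tool_use.get("input", {})
            if isinstance(tool_input, str):
                try:
                    tool_input = json.loads(tool_input) if tool_input else {}
                except json.JSONDecodeError:
                    tool_input = {}
            param["content"].append(
                {"type": "tool_use", "id": tool_use.get("toolUseId", ""),
                 "name": tool_use.get("name", ""), "input": tool_input}
            )
    return param


bedrock_msg = {"role": "assistant",
               "content": [{"toolUse": {"toolUseId": "t-9", "name": "search",
                                        "input": '{"query": "fast-agent"}'}}]}
assert to_message_param(bedrock_msg)["content"][0]["input"] == {"query": "fast-agent"}
```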

    def _api_key(self) -> str:
        """Bedrock doesn't use API keys; return an empty string."""
        return ""