fast-agent-mcp 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.32.dist-info → fast_agent_mcp-0.2.34.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.2.32.dist-info → fast_agent_mcp-0.2.34.dist-info}/RECORD +23 -20
- mcp_agent/agents/base_agent.py +13 -0
- mcp_agent/config.py +40 -4
- mcp_agent/core/agent_app.py +41 -1
- mcp_agent/core/enhanced_prompt.py +9 -0
- mcp_agent/core/fastagent.py +14 -2
- mcp_agent/core/interactive_prompt.py +59 -13
- mcp_agent/core/usage_display.py +193 -0
- mcp_agent/llm/augmented_llm.py +26 -6
- mcp_agent/llm/augmented_llm_passthrough.py +66 -4
- mcp_agent/llm/augmented_llm_playback.py +19 -0
- mcp_agent/llm/augmented_llm_slow.py +12 -1
- mcp_agent/llm/model_database.py +236 -0
- mcp_agent/llm/model_factory.py +1 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +44 -8
- mcp_agent/llm/providers/augmented_llm_google_native.py +18 -1
- mcp_agent/llm/providers/augmented_llm_openai.py +20 -7
- mcp_agent/llm/usage_tracking.py +385 -0
- mcp_agent/mcp/interfaces.py +6 -0
- {fast_agent_mcp-0.2.32.dist-info → fast_agent_mcp-0.2.34.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.32.dist-info → fast_agent_mcp-0.2.34.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.32.dist-info → fast_agent_mcp-0.2.34.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/usage_display.py
ADDED
@@ -0,0 +1,193 @@
+"""
+Utility module for displaying usage statistics in a consistent format.
+Consolidates the usage display logic that was duplicated between fastagent.py and interactive_prompt.py.
+"""
+
+from typing import Any, Dict, Optional
+
+from rich.console import Console
+
+
+def display_usage_report(
+    agents: Dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False
+) -> None:
+    """
+    Display a formatted table of token usage for all agents.
+
+    Args:
+        agents: Dictionary of agent name -> agent object
+        show_if_progress_disabled: If True, show even when progress display is disabled
+        subdued_colors: If True, use dim styling for a more subdued appearance
+    """
+    # Check if progress display is enabled (only relevant for fastagent context)
+    if not show_if_progress_disabled:
+        try:
+            from mcp_agent import config
+
+            settings = config.get_settings()
+            if not settings.logger.progress_display:
+                return
+        except (ImportError, AttributeError):
+            # If we can't check settings, assume we should display
+            pass
+
+    # Collect usage data from all agents
+    usage_data = []
+    total_input = 0
+    total_output = 0
+    total_tokens = 0
+
+    for agent_name, agent in agents.items():
+        if agent.usage_accumulator:
+            summary = agent.usage_accumulator.get_summary()
+            if summary["turn_count"] > 0:
+                input_tokens = summary["cumulative_input_tokens"]
+                output_tokens = summary["cumulative_output_tokens"]
+                billing_tokens = summary["cumulative_billing_tokens"]
+                turns = summary["turn_count"]
+
+                # Get context percentage for this agent
+                context_percentage = agent.usage_accumulator.context_usage_percentage
+
+                # Get model name from LLM's default_request_params
+                model = "unknown"
+                if hasattr(agent, "_llm") and agent._llm:
+                    llm = agent._llm
+                    if (
+                        hasattr(llm, "default_request_params")
+                        and llm.default_request_params
+                        and hasattr(llm.default_request_params, "model")
+                    ):
+                        model = llm.default_request_params.model or "unknown"
+
+                # Standardize model name truncation - use consistent 25 char width with 22+... truncation
+                if len(model) > 25:
+                    model = model[:22] + "..."
+
+                usage_data.append(
+                    {
+                        "name": agent_name,
+                        "model": model,
+                        "input": input_tokens,
+                        "output": output_tokens,
+                        "total": billing_tokens,
+                        "turns": turns,
+                        "context": context_percentage,
+                    }
+                )
+
+                total_input += input_tokens
+                total_output += output_tokens
+                total_tokens += billing_tokens
+
+    if not usage_data:
+        return
+
+    # Calculate dynamic agent column width (max 15)
+    max_agent_width = min(15, max(len(data["name"]) for data in usage_data) if usage_data else 8)
+    agent_width = max(max_agent_width, 5) # Minimum of 5 for "Agent" header
+
+    # Display the table
+    console = Console()
+    console.print()
+    console.print("[dim]Usage Summary (Cumulative)[/dim]")
+
+    # Print header with proper spacing
+    console.print(
+        f"[dim]{'Agent':<{agent_width}} {'Input':>9} {'Output':>9} {'Total':>9} {'Turns':>6} {'Context%':>9} {'Model':<25}[/dim]"
+    )
+
+    # Print agent rows - use styling based on subdued_colors flag
+    for data in usage_data:
+        input_str = f"{data['input']:,}"
+        output_str = f"{data['output']:,}"
+        total_str = f"{data['total']:,}"
+        turns_str = str(data["turns"])
+        context_str = f"{data['context']:.1f}%" if data["context"] is not None else "-"
+
+        # Truncate agent name if needed
+        agent_name = data["name"]
+        if len(agent_name) > agent_width:
+            agent_name = agent_name[: agent_width - 3] + "..."
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper
+            console.print(
+                f"[dim]{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"{data['model']:<25}[/dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"[dim]{data['model']:<25}[/dim]"
+            )
+
+    # Add total row if multiple agents
+    if len(usage_data) > 1:
+        console.print()
+        total_input_str = f"{total_input:,}"
+        total_output_str = f"{total_output:,}"
+        total_tokens_str = f"{total_tokens:,}"
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper on bold
+            console.print(
+                f"[bold dim]{'TOTAL':<{agent_width}} "
+                f"{total_input_str:>9} "
+                f"{total_output_str:>9} "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}[/bold dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"[bold]{'TOTAL':<{agent_width}}[/bold] "
+                f"[bold]{total_input_str:>9}[/bold] "
+                f"[bold]{total_output_str:>9}[/bold] "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}"
+            )
+
+    console.print()
+
+
+def collect_agents_from_provider(
+    prompt_provider: Any, agent_name: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Collect agents from a prompt provider for usage display.
+
+    Args:
+        prompt_provider: Provider that has access to agents
+        agent_name: Name of the current agent (for context)
+
+    Returns:
+        Dictionary of agent name -> agent object
+    """
+    agents_to_show = {}
+
+    if hasattr(prompt_provider, "_agents"):
+        # Multi-agent app - show all agents
+        agents_to_show = prompt_provider._agents
+    elif hasattr(prompt_provider, "agent"):
+        # Single agent
+        agent = prompt_provider.agent
+        if hasattr(agent, "name"):
+            agents_to_show = {agent.name: agent}
+
+    return agents_to_show
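The helper above is what the CLI surfaces (`interactive_prompt.py`, `fastagent.py`) now delegate to. A minimal sketch of driving it directly — the stub agent and accumulator below are hypothetical; only the `display_usage_report` signature and the summary keys it reads come from the code above:

```python
from types import SimpleNamespace

from mcp_agent.core.usage_display import display_usage_report


class StubAccumulator:
    """Hypothetical stand-in exposing just the attributes usage_display reads."""

    context_usage_percentage = 12.5

    def get_summary(self) -> dict:
        return {
            "turn_count": 3,
            "cumulative_input_tokens": 1200,
            "cumulative_output_tokens": 450,
            "cumulative_billing_tokens": 1650,
        }


# display_usage_report only needs .usage_accumulator (and optionally ._llm) on each "agent".
agents = {"my-agent": SimpleNamespace(usage_accumulator=StubAccumulator(), _llm=None)}
display_usage_report(agents, show_if_progress_disabled=True)
```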
mcp_agent/llm/augmented_llm.py
CHANGED
@@ -30,11 +30,13 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.model_database import ModelDatabase
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.sampling_format_converter import (
     BasicFormatConverter,
     ProviderFormatConverter,
 )
+from mcp_agent.llm.usage_tracking import UsageAccumulator
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.helpers.content_helpers import get_text
 from mcp_agent.mcp.interfaces import (
@@ -155,12 +157,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)
 
-        # Initialize default parameters
-        self.default_request_params = self._initialize_default_params(kwargs)
-
-        # Apply model override if provided
+        # Initialize default parameters, passing model info
+        model_kwargs = kwargs.copy()
         if model:
-            self.default_request_params.model = model
+            model_kwargs["model"] = model
+        self.default_request_params = self._initialize_default_params(model_kwargs)
 
         # Merge with provided params if any
         if self._init_request_params:
@@ -171,13 +172,22 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")
 
+        # Initialize usage tracking
+        self.usage_accumulator = UsageAccumulator()
+
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters for the LLM.
         Should be overridden by provider implementations to set provider-specific defaults."""
+        # Get model-aware default max tokens
+        model = kwargs.get("model")
+        max_tokens = ModelDatabase.get_default_max_tokens(model)
+
         return RequestParams(
+            model=model,
+            maxTokens=max_tokens,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
-            max_iterations=
+            max_iterations=20,
             use_history=True,
         )
 
@@ -642,3 +652,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
 
         assert self.provider
         return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
+
+    def get_usage_summary(self) -> dict:
+        """
+        Get a summary of usage statistics for this LLM instance.
+
+        Returns:
+            Dictionary containing usage statistics including tokens, cache metrics,
+            and context window utilization.
+        """
+        return self.usage_accumulator.get_summary()
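With the accumulator attached in `__init__`, callers can read totals straight off an agent's LLM. A hypothetical end-to-end wiring (the `FastAgent` decorator and `send` call follow fast-agent's documented usage; only `usage_accumulator` / `get_usage_summary()` come from this diff):

```python
import asyncio

from mcp_agent.core.fastagent import FastAgent

fast = FastAgent("usage-demo")


@fast.agent(name="echo", model="passthrough")  # internal model, no provider key required
async def main() -> None:
    async with fast.run() as app:
        await app.echo.send("hello")
        # Same keys that usage_display.py reads.
        summary = app.echo.usage_accumulator.get_summary()
        print(summary["turn_count"], summary["cumulative_billing_tokens"])


if __name__ == "__main__":
    asyncio.run(main())
```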
mcp_agent/llm/augmented_llm_passthrough.py
CHANGED
@@ -10,6 +10,7 @@ from mcp_agent.llm.augmented_llm import (
     RequestParams,
 )
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
@@ -48,13 +49,34 @@ class PassthroughLLM(AugmentedLLM):
         await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
 
         # Handle PromptMessage by concatenating all parts
+        result = ""
         if isinstance(message, PromptMessage):
             parts_text = []
             for part in message.content:
                 parts_text.append(str(part))
-            return "\n".join(parts_text)
+            result = "\n".join(parts_text)
+        else:
+            result = str(message)
 
-        return str(message)
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = str(message)
+            output_content = result
+            tool_calls = 1 if input_content.startswith("***CALL_TOOL") else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     async def initialize(self) -> None:
         pass
@@ -146,6 +168,25 @@
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
+
+            # Track usage for this tool call "turn"
+            try:
+                input_content = "\n".join(message.all_text() for message in multipart_messages)
+                output_content = result.first_text()
+
+                turn_usage = create_turn_usage_from_messages(
+                    input_content=input_content,
+                    output_content=output_content,
+                    model="passthrough",
+                    model_type="passthrough",
+                    tool_calls=1, # This is definitely a tool call
+                    delay_seconds=0.0,
+                )
+                self.usage_accumulator.add_turn(turn_usage)
+
+            except Exception as e:
+                self.logger.warning(f"Failed to track usage: {e}")
+
             return result
 
         if last_message.first_text().startswith(FIXED_RESPONSE_INDICATOR):
@@ -155,12 +196,33 @@
 
         if self._fixed_response:
             await self.show_assistant_message(self._fixed_response)
-            return Prompt.assistant(self._fixed_response)
+            result = Prompt.assistant(self._fixed_response)
         else:
             # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
-            return Prompt.assistant(concatenated)
+            result = Prompt.assistant(concatenated)
+
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = "\n".join(message.all_text() for message in multipart_messages)
+            output_content = result.first_text()
+            tool_calls = 1 if self.is_tool_call(last_message) else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     def is_tool_call(self, message: PromptMessageMultipart) -> bool:
         return message.first_text().startswith(CALL_TOOL_INDICATOR)
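All of the internal models follow the same pattern: synthesize a turn's usage from the raw strings, then hand it to the accumulator. A small sketch using only the calls shown above (the tool-call string is an arbitrary example):

```python
from mcp_agent.llm.usage_tracking import UsageAccumulator, create_turn_usage_from_messages

accumulator = UsageAccumulator()

# Estimate a "turn" from plain strings, as the passthrough/playback/slow LLMs now do.
turn = create_turn_usage_from_messages(
    input_content="***CALL_TOOL fetch {}",
    output_content="tool result text",
    model="passthrough",
    model_type="passthrough",
    tool_calls=1,
    delay_seconds=0.0,
)
accumulator.add_turn(turn)

summary = accumulator.get_summary()
print(summary["turn_count"], summary["cumulative_billing_tokens"])
```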
mcp_agent/llm/augmented_llm_playback.py
CHANGED
@@ -5,6 +5,7 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
@@ -83,6 +84,24 @@ class PlaybackLLM(PassthroughLLM):
             message_text=MessageContent.get_first_text(response), title="ASSISTANT/PLAYBACK"
         )
 
+        # Track usage for this playback "turn"
+        try:
+            input_content = str(multipart_messages) if multipart_messages else ""
+            output_content = MessageContent.get_first_text(response)
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="playback",
+                model_type="playback",
+                tool_calls=0,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
         return response
 
     async def structured(
mcp_agent/llm/augmented_llm_slow.py
CHANGED
@@ -30,7 +30,18 @@ class SlowLLM(PassthroughLLM):
     ) -> str:
         """Sleep for 3 seconds then return the input message as a string."""
         await asyncio.sleep(3)
-        return await super().generate_str(message, request_params)
+        result = await super().generate_str(message, request_params)
+
+        # Override the last turn to include the 3-second delay
+        if self.usage_accumulator.turns:
+            last_turn = self.usage_accumulator.turns[-1]
+            # Update the raw usage to include delay
+            if hasattr(last_turn.raw_usage, 'delay_seconds'):
+                last_turn.raw_usage.delay_seconds = 3.0
+            # Print updated debug info
+            print("SlowLLM: Added 3.0s delay to turn usage")
+
+        return result
 
     async def _apply_prompt_provider_specific(
         self,
mcp_agent/llm/model_database.py
ADDED
@@ -0,0 +1,236 @@
+"""
+Model database for LLM parameters.
+
+This module provides a centralized lookup for model parameters including
+context windows, max output tokens, and supported tokenization types.
+"""
+
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+
+class ModelParameters(BaseModel):
+    """Configuration parameters for a specific model"""
+
+    context_window: int
+    """Maximum context window size in tokens"""
+
+    max_output_tokens: int
+    """Maximum output tokens the model can generate"""
+
+    tokenizes: List[str]
+    """List of supported content types for tokenization"""
+
+
+class ModelDatabase:
+    """Centralized model configuration database"""
+
+    # Common parameter sets
+    OPENAI_MULTIMODAL = ["text/plain", "image/jpeg", "image/png", "image/webp", "application/pdf"]
+    OPENAI_VISION = ["text/plain", "image/jpeg", "image/png", "image/webp"]
+    ANTHROPIC_MULTIMODAL = [
+        "text/plain",
+        "image/jpeg",
+        "image/png",
+        "image/webp",
+        "application/pdf",
+    ]
+    GOOGLE_MULTIMODAL = [
+        "text/plain",
+        "image/jpeg",
+        "image/png",
+        "image/webp",
+        "application/pdf",
+        "audio/wav",
+        "audio/mp3",
+        "video/mp4",
+    ]
+    QWEN_MULTIMODAL = ["text/plain", "image/jpeg", "image/png", "image/webp"]
+    TEXT_ONLY = ["text/plain"]
+
+    # Common parameter configurations
+    OPENAI_STANDARD = ModelParameters(
+        context_window=128000, max_output_tokens=16384, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_4_1_STANDARD = ModelParameters(
+        context_window=1047576, max_output_tokens=32768, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_O_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=OPENAI_VISION
+    )
+
+    ANTHROPIC_LEGACY = ModelParameters(
+        context_window=200000, max_output_tokens=4096, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    ANTHROPIC_35_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=8192, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    # TODO--- TO USE 64,000 NEED TO SUPPORT STREAMING
+    ANTHROPIC_37_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    GEMINI_FLASH = ModelParameters(
+        context_window=1048576, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    GEMINI_PRO = ModelParameters(
+        context_window=2097152, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    QWEN_STANDARD = ModelParameters(
+        context_window=32000, max_output_tokens=8192, tokenizes=QWEN_MULTIMODAL
+    )
+
+    FAST_AGENT_STANDARD = ModelParameters(
+        context_window=1000000, max_output_tokens=100000, tokenizes=TEXT_ONLY
+    )
+
+    OPENAI_4_1_SERIES = ModelParameters(
+        context_window=1047576, max_output_tokens=32768, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_4O_SERIES = ModelParameters(
+        context_window=128000, max_output_tokens=16384, tokenizes=OPENAI_VISION
+    )
+
+    OPENAI_O3_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_O3_MINI_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=TEXT_ONLY
+    )
+
+    # TODO update to 32000
+    ANTHROPIC_OPUS_4_VERSIONED = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+    # TODO update to 64000
+    ANTHROPIC_SONNET_4_VERSIONED = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    DEEPSEEK_CHAT_STANDARD = ModelParameters(
+        context_window=65536, max_output_tokens=8192, tokenizes=TEXT_ONLY
+    )
+
+    DEEPSEEK_REASONER = ModelParameters(
+        context_window=65536, max_output_tokens=32768, tokenizes=TEXT_ONLY
+    )
+
+    GEMINI_2_5_PRO = ModelParameters(
+        context_window=2097152, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    # Model configuration database
+    MODELS: Dict[str, ModelParameters] = {
+        # internal models
+        "passthrough": FAST_AGENT_STANDARD,
+        "playback": FAST_AGENT_STANDARD,
+        "slow": FAST_AGENT_STANDARD,
+        # aliyun models
+        "qwen-turbo": QWEN_STANDARD,
+        "qwen-plus": QWEN_STANDARD,
+        "qwen-max": QWEN_STANDARD,
+        "qwen-long": ModelParameters(
+            context_window=10000000, max_output_tokens=8192, tokenizes=TEXT_ONLY
+        ),
+        # OpenAI Models (vanilla aliases and versioned)
+        "gpt-4.1": OPENAI_4_1_SERIES,
+        "gpt-4.1-mini": OPENAI_4_1_SERIES,
+        "gpt-4.1-nano": OPENAI_4_1_SERIES,
+        "gpt-4.1-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4.1-mini-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4.1-nano-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4o": OPENAI_4O_SERIES,
+        "gpt-4o-2024-11-20": OPENAI_4O_SERIES,
+        "gpt-4o-mini-2024-07-18": OPENAI_4O_SERIES,
+        "o1": OPENAI_O_SERIES,
+        "o1-2024-12-17": OPENAI_O_SERIES,
+        "o3": OPENAI_O3_SERIES,
+        "o3-pro": ModelParameters(
+            context_window=200_000, max_output_tokens=100_000, tokenizes=TEXT_ONLY
+        ),
+        "o3-mini": OPENAI_O3_MINI_SERIES,
+        "o4-mini": OPENAI_O3_SERIES,
+        "o3-2025-04-16": OPENAI_O3_SERIES,
+        "o3-mini-2025-01-31": OPENAI_O3_MINI_SERIES,
+        "o4-mini-2025-04-16": OPENAI_O3_SERIES,
+        # Anthropic Models
+        "claude-3-haiku": ANTHROPIC_35_SERIES,
+        "claude-3-haiku-20240307": ANTHROPIC_LEGACY,
+        "claude-3-sonnet": ANTHROPIC_LEGACY,
+        "claude-3-opus": ANTHROPIC_LEGACY,
+        "claude-3-opus-20240229": ANTHROPIC_LEGACY,
+        "claude-3-opus-latest": ANTHROPIC_LEGACY,
+        "claude-3-5-haiku": ANTHROPIC_35_SERIES,
+        "claude-3-5-haiku-20241022": ANTHROPIC_35_SERIES,
+        "claude-3-5-haiku-latest": ANTHROPIC_35_SERIES,
+        "claude-3-sonnet-20240229": ANTHROPIC_LEGACY,
+        "claude-3-5-sonnet": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-20240620": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-20241022": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-latest": ANTHROPIC_35_SERIES,
+        "claude-3-7-sonnet": ANTHROPIC_37_SERIES,
+        "claude-3-7-sonnet-20250219": ANTHROPIC_37_SERIES,
+        "claude-3-7-sonnet-latest": ANTHROPIC_37_SERIES,
+        "claude-sonnet-4": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-sonnet-4-0": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-sonnet-4-20250514": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-opus-4": ANTHROPIC_OPUS_4_VERSIONED,
+        "claude-opus-4-0": ANTHROPIC_OPUS_4_VERSIONED,
+        "claude-opus-4-20250514": ANTHROPIC_OPUS_4_VERSIONED,
+        # DeepSeek Models
+        "deepseek-chat": DEEPSEEK_CHAT_STANDARD,
+        # Google Gemini Models (vanilla aliases and versioned)
+        "gemini-2.0-flash": GEMINI_FLASH,
+        "gemini-2.5-flash-preview": GEMINI_FLASH,
+        "gemini-2.5-pro-preview": GEMINI_2_5_PRO,
+        "gemini-2.5-flash-preview-05-20": GEMINI_FLASH,
+        "gemini-2.5-pro-preview-05-06": GEMINI_PRO,
+    }
+
+    @classmethod
+    def get_model_params(cls, model: str) -> Optional[ModelParameters]:
+        """Get model parameters for a given model name"""
+        return cls.MODELS.get(model)
+
+    @classmethod
+    def get_context_window(cls, model: str) -> Optional[int]:
+        """Get context window size for a model"""
+        params = cls.get_model_params(model)
+        return params.context_window if params else None
+
+    @classmethod
+    def get_max_output_tokens(cls, model: str) -> Optional[int]:
+        """Get maximum output tokens for a model"""
+        params = cls.get_model_params(model)
+        return params.max_output_tokens if params else None
+
+    @classmethod
+    def get_tokenizes(cls, model: str) -> Optional[List[str]]:
+        """Get supported tokenization types for a model"""
+        params = cls.get_model_params(model)
+        return params.tokenizes if params else None
+
+    @classmethod
+    def get_default_max_tokens(cls, model: str) -> int:
+        """Get default max_tokens for RequestParams based on model"""
+        if not model:
+            return 2048 # Fallback when no model specified
+
+        params = cls.get_model_params(model)
+        if params:
+            return params.max_output_tokens
+        return 2048 # Fallback for unknown models
+
+    @classmethod
+    def list_models(cls) -> List[str]:
+        """List all available model names"""
+        return list(cls.MODELS.keys())
mcp_agent/llm/model_factory.py
CHANGED
@@ -87,6 +87,7 @@ class ModelFactory:
         "o1-preview": Provider.OPENAI,
         "o3": Provider.OPENAI,
         "o3-mini": Provider.OPENAI,
+        "o4-mini": Provider.OPENAI,
         "claude-3-haiku-20240307": Provider.ANTHROPIC,
         "claude-3-5-haiku-20241022": Provider.ANTHROPIC,
         "claude-3-5-haiku-latest": Provider.ANTHROPIC,