fast-agent-mcp 0.2.32__py3-none-any.whl → 0.2.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
+"""
+Utility module for displaying usage statistics in a consistent format.
+Consolidates the usage display logic that was duplicated between fastagent.py and interactive_prompt.py.
+"""
+
+from typing import Any, Dict, Optional
+
+from rich.console import Console
+
+
+def display_usage_report(
+    agents: Dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False
+) -> None:
+    """
+    Display a formatted table of token usage for all agents.
+
+    Args:
+        agents: Dictionary of agent name -> agent object
+        show_if_progress_disabled: If True, show even when progress display is disabled
+        subdued_colors: If True, use dim styling for a more subdued appearance
+    """
+    # Check if progress display is enabled (only relevant for fastagent context)
+    if not show_if_progress_disabled:
+        try:
+            from mcp_agent import config
+
+            settings = config.get_settings()
+            if not settings.logger.progress_display:
+                return
+        except (ImportError, AttributeError):
+            # If we can't check settings, assume we should display
+            pass
+
+    # Collect usage data from all agents
+    usage_data = []
+    total_input = 0
+    total_output = 0
+    total_tokens = 0
+
+    for agent_name, agent in agents.items():
+        if agent.usage_accumulator:
+            summary = agent.usage_accumulator.get_summary()
+            if summary["turn_count"] > 0:
+                input_tokens = summary["cumulative_input_tokens"]
+                output_tokens = summary["cumulative_output_tokens"]
+                billing_tokens = summary["cumulative_billing_tokens"]
+                turns = summary["turn_count"]
+
+                # Get context percentage for this agent
+                context_percentage = agent.usage_accumulator.context_usage_percentage
+
+                # Get model name from LLM's default_request_params
+                model = "unknown"
+                if hasattr(agent, "_llm") and agent._llm:
+                    llm = agent._llm
+                    if (
+                        hasattr(llm, "default_request_params")
+                        and llm.default_request_params
+                        and hasattr(llm.default_request_params, "model")
+                    ):
+                        model = llm.default_request_params.model or "unknown"
+
+                # Standardize model name truncation - use consistent 25 char width with 22+... truncation
+                if len(model) > 25:
+                    model = model[:22] + "..."
+
+                usage_data.append(
+                    {
+                        "name": agent_name,
+                        "model": model,
+                        "input": input_tokens,
+                        "output": output_tokens,
+                        "total": billing_tokens,
+                        "turns": turns,
+                        "context": context_percentage,
+                    }
+                )
+
+                total_input += input_tokens
+                total_output += output_tokens
+                total_tokens += billing_tokens
+
+    if not usage_data:
+        return
+
+    # Calculate dynamic agent column width (max 15)
+    max_agent_width = min(15, max(len(data["name"]) for data in usage_data) if usage_data else 8)
+    agent_width = max(max_agent_width, 5)  # Minimum of 5 for "Agent" header
+
+    # Display the table
+    console = Console()
+    console.print()
+    console.print("[dim]Usage Summary (Cumulative)[/dim]")
+
+    # Print header with proper spacing
+    console.print(
+        f"[dim]{'Agent':<{agent_width}} {'Input':>9} {'Output':>9} {'Total':>9} {'Turns':>6} {'Context%':>9} {'Model':<25}[/dim]"
+    )
+
+    # Print agent rows - use styling based on subdued_colors flag
+    for data in usage_data:
+        input_str = f"{data['input']:,}"
+        output_str = f"{data['output']:,}"
+        total_str = f"{data['total']:,}"
+        turns_str = str(data["turns"])
+        context_str = f"{data['context']:.1f}%" if data["context"] is not None else "-"
+
+        # Truncate agent name if needed
+        agent_name = data["name"]
+        if len(agent_name) > agent_width:
+            agent_name = agent_name[: agent_width - 3] + "..."
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper
+            console.print(
+                f"[dim]{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"{data['model']:<25}[/dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"[dim]{data['model']:<25}[/dim]"
+            )
+
+    # Add total row if multiple agents
+    if len(usage_data) > 1:
+        console.print()
+        total_input_str = f"{total_input:,}"
+        total_output_str = f"{total_output:,}"
+        total_tokens_str = f"{total_tokens:,}"
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper on bold
+            console.print(
+                f"[bold dim]{'TOTAL':<{agent_width}} "
+                f"{total_input_str:>9} "
+                f"{total_output_str:>9} "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}[/bold dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"[bold]{'TOTAL':<{agent_width}}[/bold] "
+                f"[bold]{total_input_str:>9}[/bold] "
+                f"[bold]{total_output_str:>9}[/bold] "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}"
+            )
+
+    console.print()
+
+
+def collect_agents_from_provider(
+    prompt_provider: Any, agent_name: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Collect agents from a prompt provider for usage display.
+
+    Args:
+        prompt_provider: Provider that has access to agents
+        agent_name: Name of the current agent (for context)
+
+    Returns:
+        Dictionary of agent name -> agent object
+    """
+    agents_to_show = {}
+
+    if hasattr(prompt_provider, "_agents"):
+        # Multi-agent app - show all agents
+        agents_to_show = prompt_provider._agents
+    elif hasattr(prompt_provider, "agent"):
+        # Single agent
+        agent = prompt_provider.agent
+        if hasattr(agent, "name"):
+            agents_to_show = {agent.name: agent}
+
+    return agents_to_show
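Note: the short sketch below is not part of the diff; it illustrates how these two new helpers are intended to compose. The import path is an assumption for illustration only, since the diff does not show where the file lives in the package.

# Illustrative sketch only - the module path below is assumed, not shown in the diff.
from mcp_agent.core.usage_display import collect_agents_from_provider, display_usage_report

def show_usage(prompt_provider, current_agent_name=None):
    # Gather whatever agents the provider exposes (multi-agent app or single agent)
    agents = collect_agents_from_provider(prompt_provider, current_agent_name)
    # Render the cumulative token-usage table even when progress display is disabled
    display_usage_report(agents, show_if_progress_disabled=True)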
@@ -30,11 +30,13 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.model_database import ModelDatabase
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.sampling_format_converter import (
     BasicFormatConverter,
     ProviderFormatConverter,
 )
+from mcp_agent.llm.usage_tracking import UsageAccumulator
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.helpers.content_helpers import get_text
 from mcp_agent.mcp.interfaces import (
@@ -155,12 +157,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)
 
-        # Initialize default parameters
-        self.default_request_params = self._initialize_default_params(kwargs)
-
-        # Apply model override if provided
+        # Initialize default parameters, passing model info
+        model_kwargs = kwargs.copy()
         if model:
-            self.default_request_params.model = model
+            model_kwargs["model"] = model
+        self.default_request_params = self._initialize_default_params(model_kwargs)
 
         # Merge with provided params if any
         if self._init_request_params:
@@ -171,13 +172,22 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")
 
+        # Initialize usage tracking
+        self.usage_accumulator = UsageAccumulator()
+
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters for the LLM.
         Should be overridden by provider implementations to set provider-specific defaults."""
+        # Get model-aware default max tokens
+        model = kwargs.get("model")
+        max_tokens = ModelDatabase.get_default_max_tokens(model)
+
         return RequestParams(
+            model=model,
+            maxTokens=max_tokens,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
-            max_iterations=10,
+            max_iterations=20,
             use_history=True,
         )
 
@@ -642,3 +652,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
 
         assert self.provider
         return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
+
+    def get_usage_summary(self) -> dict:
+        """
+        Get a summary of usage statistics for this LLM instance.
+
+        Returns:
+            Dictionary containing usage statistics including tokens, cache metrics,
+            and context window utilization.
+        """
+        return self.usage_accumulator.get_summary()
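Note: the snippet below is not part of the diff; it is a minimal sketch of reading the new `get_usage_summary()` result, assuming `llm` is an initialized `AugmentedLLM` subclass. The keys used are the ones the usage display module above reads from the same summary.

# Illustrative sketch only - assumes `llm` is an initialized AugmentedLLM subclass.
summary = llm.get_usage_summary()
if summary["turn_count"] > 0:
    print(
        f"turns={summary['turn_count']} "
        f"in={summary['cumulative_input_tokens']} "
        f"out={summary['cumulative_output_tokens']} "
        f"billed={summary['cumulative_billing_tokens']}"
    )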
@@ -10,6 +10,7 @@ from mcp_agent.llm.augmented_llm import (
     RequestParams,
 )
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
@@ -48,13 +49,34 @@ class PassthroughLLM(AugmentedLLM):
         await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
 
         # Handle PromptMessage by concatenating all parts
+        result = ""
         if isinstance(message, PromptMessage):
             parts_text = []
             for part in message.content:
                 parts_text.append(str(part))
-            return "\n".join(parts_text)
+            result = "\n".join(parts_text)
+        else:
+            result = str(message)
 
-        return str(message)
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = str(message)
+            output_content = result
+            tool_calls = 1 if input_content.startswith("***CALL_TOOL") else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     async def initialize(self) -> None:
         pass
@@ -146,6 +168,25 @@ class PassthroughLLM(AugmentedLLM):
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
+
+            # Track usage for this tool call "turn"
+            try:
+                input_content = "\n".join(message.all_text() for message in multipart_messages)
+                output_content = result.first_text()
+
+                turn_usage = create_turn_usage_from_messages(
+                    input_content=input_content,
+                    output_content=output_content,
+                    model="passthrough",
+                    model_type="passthrough",
+                    tool_calls=1,  # This is definitely a tool call
+                    delay_seconds=0.0,
+                )
+                self.usage_accumulator.add_turn(turn_usage)
+
+            except Exception as e:
+                self.logger.warning(f"Failed to track usage: {e}")
+
             return result
 
         if last_message.first_text().startswith(FIXED_RESPONSE_INDICATOR):
@@ -155,12 +196,33 @@ class PassthroughLLM(AugmentedLLM):
 
         if self._fixed_response:
             await self.show_assistant_message(self._fixed_response)
-            return Prompt.assistant(self._fixed_response)
+            result = Prompt.assistant(self._fixed_response)
         else:
             # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
-            return Prompt.assistant(concatenated)
+            result = Prompt.assistant(concatenated)
+
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = "\n".join(message.all_text() for message in multipart_messages)
+            output_content = result.first_text()
+            tool_calls = 1 if self.is_tool_call(last_message) else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     def is_tool_call(self, message: PromptMessageMultipart) -> bool:
         return message.first_text().startswith(CALL_TOOL_INDICATOR)
@@ -5,6 +5,7 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
@@ -83,6 +84,24 @@ class PlaybackLLM(PassthroughLLM):
             message_text=MessageContent.get_first_text(response), title="ASSISTANT/PLAYBACK"
         )
 
+        # Track usage for this playback "turn"
+        try:
+            input_content = str(multipart_messages) if multipart_messages else ""
+            output_content = MessageContent.get_first_text(response)
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="playback",
+                model_type="playback",
+                tool_calls=0,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
         return response
 
     async def structured(
@@ -30,7 +30,18 @@ class SlowLLM(PassthroughLLM):
     ) -> str:
         """Sleep for 3 seconds then return the input message as a string."""
         await asyncio.sleep(3)
-        return await super().generate_str(message, request_params)
+        result = await super().generate_str(message, request_params)
+
+        # Override the last turn to include the 3-second delay
+        if self.usage_accumulator.turns:
+            last_turn = self.usage_accumulator.turns[-1]
+            # Update the raw usage to include delay
+            if hasattr(last_turn.raw_usage, 'delay_seconds'):
+                last_turn.raw_usage.delay_seconds = 3.0
+                # Print updated debug info
+                print("SlowLLM: Added 3.0s delay to turn usage")
+
+        return result
 
     async def _apply_prompt_provider_specific(
         self,
@@ -0,0 +1,236 @@
+"""
+Model database for LLM parameters.
+
+This module provides a centralized lookup for model parameters including
+context windows, max output tokens, and supported tokenization types.
+"""
+
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+
+class ModelParameters(BaseModel):
+    """Configuration parameters for a specific model"""
+
+    context_window: int
+    """Maximum context window size in tokens"""
+
+    max_output_tokens: int
+    """Maximum output tokens the model can generate"""
+
+    tokenizes: List[str]
+    """List of supported content types for tokenization"""
+
+
+class ModelDatabase:
+    """Centralized model configuration database"""
+
+    # Common parameter sets
+    OPENAI_MULTIMODAL = ["text/plain", "image/jpeg", "image/png", "image/webp", "application/pdf"]
+    OPENAI_VISION = ["text/plain", "image/jpeg", "image/png", "image/webp"]
+    ANTHROPIC_MULTIMODAL = [
+        "text/plain",
+        "image/jpeg",
+        "image/png",
+        "image/webp",
+        "application/pdf",
+    ]
+    GOOGLE_MULTIMODAL = [
+        "text/plain",
+        "image/jpeg",
+        "image/png",
+        "image/webp",
+        "application/pdf",
+        "audio/wav",
+        "audio/mp3",
+        "video/mp4",
+    ]
+    QWEN_MULTIMODAL = ["text/plain", "image/jpeg", "image/png", "image/webp"]
+    TEXT_ONLY = ["text/plain"]
+
+    # Common parameter configurations
+    OPENAI_STANDARD = ModelParameters(
+        context_window=128000, max_output_tokens=16384, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_4_1_STANDARD = ModelParameters(
+        context_window=1047576, max_output_tokens=32768, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_O_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=OPENAI_VISION
+    )
+
+    ANTHROPIC_LEGACY = ModelParameters(
+        context_window=200000, max_output_tokens=4096, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    ANTHROPIC_35_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=8192, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    # TODO--- TO USE 64,000 NEED TO SUPPORT STREAMING
+    ANTHROPIC_37_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    GEMINI_FLASH = ModelParameters(
+        context_window=1048576, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    GEMINI_PRO = ModelParameters(
+        context_window=2097152, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    QWEN_STANDARD = ModelParameters(
+        context_window=32000, max_output_tokens=8192, tokenizes=QWEN_MULTIMODAL
+    )
+
+    FAST_AGENT_STANDARD = ModelParameters(
+        context_window=1000000, max_output_tokens=100000, tokenizes=TEXT_ONLY
+    )
+
+    OPENAI_4_1_SERIES = ModelParameters(
+        context_window=1047576, max_output_tokens=32768, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_4O_SERIES = ModelParameters(
+        context_window=128000, max_output_tokens=16384, tokenizes=OPENAI_VISION
+    )
+
+    OPENAI_O3_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=OPENAI_MULTIMODAL
+    )
+
+    OPENAI_O3_MINI_SERIES = ModelParameters(
+        context_window=200000, max_output_tokens=100000, tokenizes=TEXT_ONLY
+    )
+
+    # TODO update to 32000
+    ANTHROPIC_OPUS_4_VERSIONED = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+    # TODO update to 64000
+    ANTHROPIC_SONNET_4_VERSIONED = ModelParameters(
+        context_window=200000, max_output_tokens=16384, tokenizes=ANTHROPIC_MULTIMODAL
+    )
+
+    DEEPSEEK_CHAT_STANDARD = ModelParameters(
+        context_window=65536, max_output_tokens=8192, tokenizes=TEXT_ONLY
+    )
+
+    DEEPSEEK_REASONER = ModelParameters(
+        context_window=65536, max_output_tokens=32768, tokenizes=TEXT_ONLY
+    )
+
+    GEMINI_2_5_PRO = ModelParameters(
+        context_window=2097152, max_output_tokens=8192, tokenizes=GOOGLE_MULTIMODAL
+    )
+
+    # Model configuration database
+    MODELS: Dict[str, ModelParameters] = {
+        # internal models
+        "passthrough": FAST_AGENT_STANDARD,
+        "playback": FAST_AGENT_STANDARD,
+        "slow": FAST_AGENT_STANDARD,
+        # aliyun models
+        "qwen-turbo": QWEN_STANDARD,
+        "qwen-plus": QWEN_STANDARD,
+        "qwen-max": QWEN_STANDARD,
+        "qwen-long": ModelParameters(
+            context_window=10000000, max_output_tokens=8192, tokenizes=TEXT_ONLY
+        ),
+        # OpenAI Models (vanilla aliases and versioned)
+        "gpt-4.1": OPENAI_4_1_SERIES,
+        "gpt-4.1-mini": OPENAI_4_1_SERIES,
+        "gpt-4.1-nano": OPENAI_4_1_SERIES,
+        "gpt-4.1-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4.1-mini-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4.1-nano-2025-04-14": OPENAI_4_1_SERIES,
+        "gpt-4o": OPENAI_4O_SERIES,
+        "gpt-4o-2024-11-20": OPENAI_4O_SERIES,
+        "gpt-4o-mini-2024-07-18": OPENAI_4O_SERIES,
+        "o1": OPENAI_O_SERIES,
+        "o1-2024-12-17": OPENAI_O_SERIES,
+        "o3": OPENAI_O3_SERIES,
+        "o3-pro": ModelParameters(
+            context_window=200_000, max_output_tokens=100_000, tokenizes=TEXT_ONLY
+        ),
+        "o3-mini": OPENAI_O3_MINI_SERIES,
+        "o4-mini": OPENAI_O3_SERIES,
+        "o3-2025-04-16": OPENAI_O3_SERIES,
+        "o3-mini-2025-01-31": OPENAI_O3_MINI_SERIES,
+        "o4-mini-2025-04-16": OPENAI_O3_SERIES,
+        # Anthropic Models
+        "claude-3-haiku": ANTHROPIC_35_SERIES,
+        "claude-3-haiku-20240307": ANTHROPIC_LEGACY,
+        "claude-3-sonnet": ANTHROPIC_LEGACY,
+        "claude-3-opus": ANTHROPIC_LEGACY,
+        "claude-3-opus-20240229": ANTHROPIC_LEGACY,
+        "claude-3-opus-latest": ANTHROPIC_LEGACY,
+        "claude-3-5-haiku": ANTHROPIC_35_SERIES,
+        "claude-3-5-haiku-20241022": ANTHROPIC_35_SERIES,
+        "claude-3-5-haiku-latest": ANTHROPIC_35_SERIES,
+        "claude-3-sonnet-20240229": ANTHROPIC_LEGACY,
+        "claude-3-5-sonnet": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-20240620": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-20241022": ANTHROPIC_35_SERIES,
+        "claude-3-5-sonnet-latest": ANTHROPIC_35_SERIES,
+        "claude-3-7-sonnet": ANTHROPIC_37_SERIES,
+        "claude-3-7-sonnet-20250219": ANTHROPIC_37_SERIES,
+        "claude-3-7-sonnet-latest": ANTHROPIC_37_SERIES,
+        "claude-sonnet-4": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-sonnet-4-0": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-sonnet-4-20250514": ANTHROPIC_SONNET_4_VERSIONED,
+        "claude-opus-4": ANTHROPIC_OPUS_4_VERSIONED,
+        "claude-opus-4-0": ANTHROPIC_OPUS_4_VERSIONED,
+        "claude-opus-4-20250514": ANTHROPIC_OPUS_4_VERSIONED,
+        # DeepSeek Models
+        "deepseek-chat": DEEPSEEK_CHAT_STANDARD,
+        # Google Gemini Models (vanilla aliases and versioned)
+        "gemini-2.0-flash": GEMINI_FLASH,
+        "gemini-2.5-flash-preview": GEMINI_FLASH,
+        "gemini-2.5-pro-preview": GEMINI_2_5_PRO,
+        "gemini-2.5-flash-preview-05-20": GEMINI_FLASH,
+        "gemini-2.5-pro-preview-05-06": GEMINI_PRO,
+    }
+
+    @classmethod
+    def get_model_params(cls, model: str) -> Optional[ModelParameters]:
+        """Get model parameters for a given model name"""
+        return cls.MODELS.get(model)
+
+    @classmethod
+    def get_context_window(cls, model: str) -> Optional[int]:
+        """Get context window size for a model"""
+        params = cls.get_model_params(model)
+        return params.context_window if params else None
+
+    @classmethod
+    def get_max_output_tokens(cls, model: str) -> Optional[int]:
+        """Get maximum output tokens for a model"""
+        params = cls.get_model_params(model)
+        return params.max_output_tokens if params else None
+
+    @classmethod
+    def get_tokenizes(cls, model: str) -> Optional[List[str]]:
+        """Get supported tokenization types for a model"""
+        params = cls.get_model_params(model)
+        return params.tokenizes if params else None
+
+    @classmethod
+    def get_default_max_tokens(cls, model: str) -> int:
+        """Get default max_tokens for RequestParams based on model"""
+        if not model:
+            return 2048  # Fallback when no model specified
+
+        params = cls.get_model_params(model)
+        if params:
+            return params.max_output_tokens
+        return 2048  # Fallback for unknown models
+
+    @classmethod
+    def list_models(cls) -> List[str]:
+        """List all available model names"""
+        return list(cls.MODELS.keys())
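Note: the snippet below is not part of the diff; it is a minimal sketch exercising the lookups defined above, using the import path shown in the augmented_llm hunk earlier in this diff.

# Illustrative sketch only - exercises the ModelDatabase lookups defined above.
from mcp_agent.llm.model_database import ModelDatabase

params = ModelDatabase.get_model_params("claude-sonnet-4-0")
if params:
    print(params.context_window, params.max_output_tokens, params.tokenizes)

# Unknown or empty model names fall back to a 2048-token default.
print(ModelDatabase.get_default_max_tokens("gpt-4.1"))      # 32768
print(ModelDatabase.get_default_max_tokens("not-a-model"))  # 2048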
@@ -87,6 +87,7 @@ class ModelFactory:
         "o1-preview": Provider.OPENAI,
         "o3": Provider.OPENAI,
         "o3-mini": Provider.OPENAI,
+        "o4-mini": Provider.OPENAI,
         "claude-3-haiku-20240307": Provider.ANTHROPIC,
         "claude-3-5-haiku-20241022": Provider.ANTHROPIC,
         "claude-3-5-haiku-latest": Provider.ANTHROPIC,