fast-agent-mcp 0.2.13__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.16.dist-info}/METADATA +1 -1
  2. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.16.dist-info}/RECORD +36 -36
  3. mcp_agent/agents/agent.py +2 -2
  4. mcp_agent/agents/base_agent.py +3 -3
  5. mcp_agent/agents/workflow/chain_agent.py +2 -2
  6. mcp_agent/agents/workflow/evaluator_optimizer.py +3 -3
  7. mcp_agent/agents/workflow/orchestrator_agent.py +3 -3
  8. mcp_agent/agents/workflow/parallel_agent.py +2 -2
  9. mcp_agent/agents/workflow/router_agent.py +2 -2
  10. mcp_agent/cli/commands/check_config.py +450 -0
  11. mcp_agent/cli/commands/setup.py +1 -1
  12. mcp_agent/cli/main.py +8 -15
  13. mcp_agent/config.py +4 -7
  14. mcp_agent/core/agent_types.py +8 -8
  15. mcp_agent/core/direct_decorators.py +10 -8
  16. mcp_agent/core/direct_factory.py +4 -1
  17. mcp_agent/core/enhanced_prompt.py +6 -5
  18. mcp_agent/core/interactive_prompt.py +70 -50
  19. mcp_agent/core/validation.py +6 -4
  20. mcp_agent/event_progress.py +6 -6
  21. mcp_agent/llm/augmented_llm.py +10 -2
  22. mcp_agent/llm/augmented_llm_passthrough.py +5 -3
  23. mcp_agent/llm/augmented_llm_playback.py +2 -1
  24. mcp_agent/llm/model_factory.py +7 -27
  25. mcp_agent/llm/provider_key_manager.py +83 -0
  26. mcp_agent/llm/provider_types.py +16 -0
  27. mcp_agent/llm/providers/augmented_llm_anthropic.py +5 -26
  28. mcp_agent/llm/providers/augmented_llm_deepseek.py +5 -24
  29. mcp_agent/llm/providers/augmented_llm_generic.py +4 -16
  30. mcp_agent/llm/providers/augmented_llm_openai.py +4 -26
  31. mcp_agent/llm/providers/augmented_llm_openrouter.py +17 -45
  32. mcp_agent/mcp/interfaces.py +2 -1
  33. mcp_agent/mcp_server/agent_server.py +120 -38
  34. mcp_agent/cli/commands/config.py +0 -11
  35. mcp_agent/executor/temporal.py +0 -383
  36. mcp_agent/executor/workflow.py +0 -195
  37. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.16.dist-info}/WHEEL +0 -0
  38. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.16.dist-info}/entry_points.txt +0 -0
  39. {fast_agent_mcp-0.2.13.dist-info → fast_agent_mcp-0.2.16.dist-info}/licenses/LICENSE +0 -0
@@ -14,6 +14,7 @@ from prompt_toolkit.key_binding import KeyBindings
14
14
  from prompt_toolkit.styles import Style
15
15
  from rich import print as rich_print
16
16
 
17
+ from mcp_agent.core.agent_types import AgentType
17
18
  from mcp_agent.core.exceptions import PromptExitError
18
19
 
19
20
  # Get the application version
@@ -86,7 +87,7 @@ class AgentCompleter(Completer):
86
87
  for agent in self.agents:
87
88
  if agent.lower().startswith(agent_name.lower()):
88
89
  # Get agent type or default to "Agent"
89
- agent_type = self.agent_types.get(agent, "Agent")
90
+ agent_type = self.agent_types.get(agent, AgentType.BASIC).value
90
91
  yield Completion(
91
92
  agent,
92
93
  start_position=-len(agent_name),
@@ -149,7 +150,7 @@ async def get_enhanced_input(
149
150
  show_stop_hint: bool = False,
150
151
  multiline: bool = False,
151
152
  available_agent_names: List[str] = None,
152
- agent_types: dict = None,
153
+ agent_types: dict[str, AgentType] = None,
153
154
  is_human_input: bool = False,
154
155
  toolbar_color: str = "ansiblue",
155
156
  ) -> str:
@@ -430,18 +431,18 @@ async def get_argument_input(
430
431
  async def handle_special_commands(command, agent_app=None):
431
432
  """
432
433
  Handle special input commands.
433
-
434
+
434
435
  Args:
435
436
  command: The command to handle, can be string or dictionary
436
437
  agent_app: Optional agent app reference
437
-
438
+
438
439
  Returns:
439
440
  True if command was handled, False if not, or a dict with action info
440
441
  """
441
442
  # Quick guard for empty or None commands
442
443
  if not command:
443
444
  return False
444
-
445
+
445
446
  # If command is already a dictionary, it has been pre-processed
446
447
  # Just return it directly (like when /prompts converts to select_prompt dict)
447
448
  if isinstance(command, dict):
@@ -20,6 +20,7 @@ from rich import print as rich_print
20
20
  from rich.console import Console
21
21
  from rich.table import Table
22
22
 
23
+ from mcp_agent.core.agent_types import AgentType
23
24
  from mcp_agent.core.enhanced_prompt import (
24
25
  get_argument_input,
25
26
  get_enhanced_input,
@@ -36,14 +37,14 @@ class InteractivePrompt:
36
37
  This is extracted from the original AgentApp implementation to support DirectAgentApp.
37
38
  """
38
39
 
39
- def __init__(self, agent_types: Optional[Dict[str, str]] = None) -> None:
40
+ def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None:
40
41
  """
41
42
  Initialize the interactive prompt.
42
43
 
43
44
  Args:
44
45
  agent_types: Dictionary mapping agent names to their types for display
45
46
  """
46
- self.agent_types = agent_types or {}
47
+ self.agent_types: Dict[str, AgentType] = agent_types or {}
47
48
 
48
49
  async def prompt_loop(
49
50
  self,
@@ -97,7 +98,7 @@ class InteractivePrompt:
97
98
 
98
99
  # Handle special commands - pass "True" to enable agent switching
99
100
  command_result = await handle_special_commands(user_input, True)
100
-
101
+
101
102
  # Check if we should switch agents
102
103
  if isinstance(command_result, dict):
103
104
  if "switch_agent" in command_result:
@@ -113,11 +114,13 @@ class InteractivePrompt:
113
114
  # Use the list_prompts_func directly
114
115
  await self._list_prompts(list_prompts_func, agent)
115
116
  continue
116
- elif "select_prompt" in command_result and (list_prompts_func and apply_prompt_func):
117
+ elif "select_prompt" in command_result and (
118
+ list_prompts_func and apply_prompt_func
119
+ ):
117
120
  # Handle prompt selection, using both list_prompts and apply_prompt
118
121
  prompt_name = command_result.get("prompt_name")
119
122
  prompt_index = command_result.get("prompt_index")
120
-
123
+
121
124
  # If a specific index was provided (from /prompt <number>)
122
125
  if prompt_index is not None:
123
126
  # First get a list of all prompts to look up the index
@@ -125,20 +128,29 @@ class InteractivePrompt:
125
128
  if not all_prompts:
126
129
  rich_print("[yellow]No prompts available[/yellow]")
127
130
  continue
128
-
131
+
129
132
  # Check if the index is valid
130
133
  if 1 <= prompt_index <= len(all_prompts):
131
134
  # Get the prompt at the specified index (1-based to 0-based)
132
135
  selected_prompt = all_prompts[prompt_index - 1]
133
136
  # Use the already created namespaced_name to ensure consistency
134
- await self._select_prompt(list_prompts_func, apply_prompt_func, agent, selected_prompt["namespaced_name"])
137
+ await self._select_prompt(
138
+ list_prompts_func,
139
+ apply_prompt_func,
140
+ agent,
141
+ selected_prompt["namespaced_name"],
142
+ )
135
143
  else:
136
- rich_print(f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]")
144
+ rich_print(
145
+ f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]"
146
+ )
137
147
  # Show the prompt list for convenience
138
148
  await self._list_prompts(list_prompts_func, agent)
139
149
  else:
140
150
  # Use the name-based selection
141
- await self._select_prompt(list_prompts_func, apply_prompt_func, agent, prompt_name)
151
+ await self._select_prompt(
152
+ list_prompts_func, apply_prompt_func, agent, prompt_name
153
+ )
142
154
  continue
143
155
 
144
156
  # Skip further processing if command was handled
@@ -158,11 +170,11 @@ class InteractivePrompt:
158
170
  async def _get_all_prompts(self, list_prompts_func, agent_name):
159
171
  """
160
172
  Get a list of all available prompts.
161
-
173
+
162
174
  Args:
163
175
  list_prompts_func: Function to get available prompts
164
176
  agent_name: Name of the agent
165
-
177
+
166
178
  Returns:
167
179
  List of prompt info dictionaries, sorted by server and name
168
180
  """
@@ -171,7 +183,7 @@ class InteractivePrompt:
171
183
  # the agent_name parameter should never be used as a server name
172
184
  prompt_servers = await list_prompts_func(None)
173
185
  all_prompts = []
174
-
186
+
175
187
  # Process the returned prompt servers
176
188
  if prompt_servers:
177
189
  # First collect all prompts
@@ -179,44 +191,51 @@ class InteractivePrompt:
179
191
  if prompts_info and hasattr(prompts_info, "prompts") and prompts_info.prompts:
180
192
  for prompt in prompts_info.prompts:
181
193
  # Use the SEP constant for proper namespacing
182
- all_prompts.append({
183
- "server": server_name,
184
- "name": prompt.name,
185
- "namespaced_name": f"{server_name}{SEP}{prompt.name}",
186
- "description": getattr(prompt, "description", "No description"),
187
- "arg_count": len(getattr(prompt, "arguments", [])),
188
- "arguments": getattr(prompt, "arguments", [])
189
- })
194
+ all_prompts.append(
195
+ {
196
+ "server": server_name,
197
+ "name": prompt.name,
198
+ "namespaced_name": f"{server_name}{SEP}{prompt.name}",
199
+ "description": getattr(prompt, "description", "No description"),
200
+ "arg_count": len(getattr(prompt, "arguments", [])),
201
+ "arguments": getattr(prompt, "arguments", []),
202
+ }
203
+ )
190
204
  elif isinstance(prompts_info, list) and prompts_info:
191
205
  for prompt in prompts_info:
192
206
  if isinstance(prompt, dict) and "name" in prompt:
193
- all_prompts.append({
194
- "server": server_name,
195
- "name": prompt["name"],
196
- "namespaced_name": f"{server_name}{SEP}{prompt['name']}",
197
- "description": prompt.get("description", "No description"),
198
- "arg_count": len(prompt.get("arguments", [])),
199
- "arguments": prompt.get("arguments", [])
200
- })
207
+ all_prompts.append(
208
+ {
209
+ "server": server_name,
210
+ "name": prompt["name"],
211
+ "namespaced_name": f"{server_name}{SEP}{prompt['name']}",
212
+ "description": prompt.get("description", "No description"),
213
+ "arg_count": len(prompt.get("arguments", [])),
214
+ "arguments": prompt.get("arguments", []),
215
+ }
216
+ )
201
217
  else:
202
- all_prompts.append({
203
- "server": server_name,
204
- "name": str(prompt),
205
- "namespaced_name": f"{server_name}{SEP}{str(prompt)}",
206
- "description": "No description",
207
- "arg_count": 0,
208
- "arguments": []
209
- })
210
-
218
+ all_prompts.append(
219
+ {
220
+ "server": server_name,
221
+ "name": str(prompt),
222
+ "namespaced_name": f"{server_name}{SEP}{str(prompt)}",
223
+ "description": "No description",
224
+ "arg_count": 0,
225
+ "arguments": [],
226
+ }
227
+ )
228
+
211
229
  # Sort prompts by server and name for consistent ordering
212
230
  all_prompts.sort(key=lambda p: (p["server"], p["name"]))
213
-
231
+
214
232
  return all_prompts
215
-
233
+
216
234
  except Exception as e:
217
235
  import traceback
218
236
 
219
237
  from rich import print as rich_print
238
+
220
239
  rich_print(f"[red]Error getting prompts: {e}[/red]")
221
240
  rich_print(f"[dim]{traceback.format_exc()}[/dim]")
222
241
  return []
@@ -238,11 +257,11 @@ class InteractivePrompt:
238
257
  try:
239
258
  # Directly call the list_prompts function for this agent
240
259
  rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
241
-
260
+
242
261
  # Get all prompts using the helper function - pass None as server name
243
262
  # to get prompts from all available servers
244
263
  all_prompts = await self._get_all_prompts(list_prompts_func, None)
245
-
264
+
246
265
  if all_prompts:
247
266
  # Create a table for better display
248
267
  table = Table(title="Available MCP Prompts")
@@ -251,7 +270,7 @@ class InteractivePrompt:
251
270
  table.add_column("Prompt Name", style="bright_blue")
252
271
  table.add_column("Description")
253
272
  table.add_column("Args", justify="center")
254
-
273
+
255
274
  # Add prompts to table
256
275
  for i, prompt in enumerate(all_prompts):
257
276
  table.add_row(
@@ -259,11 +278,11 @@ class InteractivePrompt:
259
278
  prompt["server"],
260
279
  prompt["name"],
261
280
  prompt["description"],
262
- str(prompt["arg_count"])
281
+ str(prompt["arg_count"]),
263
282
  )
264
-
283
+
265
284
  console.print(table)
266
-
285
+
267
286
  # Add usage instructions
268
287
  rich_print("\n[bold]Usage:[/bold]")
269
288
  rich_print(" • Use [cyan]/prompt <number>[/cyan] to select a prompt by number")
@@ -272,10 +291,13 @@ class InteractivePrompt:
272
291
  rich_print("[yellow]No prompts available[/yellow]")
273
292
  except Exception as e:
274
293
  import traceback
294
+
275
295
  rich_print(f"[red]Error listing prompts: {e}[/red]")
276
296
  rich_print(f"[dim]{traceback.format_exc()}[/dim]")
277
297
 
278
- async def _select_prompt(self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None) -> None:
298
+ async def _select_prompt(
299
+ self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None
300
+ ) -> None:
279
301
  """
280
302
  Select and apply a prompt.
281
303
 
@@ -293,7 +315,7 @@ class InteractivePrompt:
293
315
  try:
294
316
  # Get all available prompts directly from the list_prompts function
295
317
  rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
296
- # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
318
+ # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
297
319
  # So we pass None to get prompts from all servers, not using agent_name as server name
298
320
  prompt_servers = await list_prompts_func(None)
299
321
 
@@ -514,9 +536,7 @@ class InteractivePrompt:
514
536
 
515
537
  # Apply the prompt
516
538
  namespaced_name = selected_prompt["namespaced_name"]
517
- rich_print(
518
- f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]"
519
- )
539
+ rich_print(f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]")
520
540
 
521
541
  # Call apply_prompt function with the prompt name and arguments
522
542
  await apply_prompt_func(namespaced_name, arg_values, agent_name)
@@ -51,8 +51,9 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
51
51
  available_components = set(agents.keys())
52
52
 
53
53
  for name, agent_data in agents.items():
54
- agent_type = agent_data["type"]
55
-
54
+ agent_type = agent_data["type"] # This is a string from config
55
+
56
+ # Note: Compare string values from config with the Enum's string value
56
57
  if agent_type == AgentType.PARALLEL.value:
57
58
  # Check fan_in exists
58
59
  fan_in = agent_data["fan_in"]
@@ -224,8 +225,9 @@ def get_dependencies_groups(
224
225
 
225
226
  # Build the dependency graph
226
227
  for name, agent_data in agents_dict.items():
227
- agent_type = agent_data["type"]
228
-
228
+ agent_type = agent_data["type"] # This is a string from config
229
+
230
+ # Note: Compare string values from config with the Enum's string value
229
231
  if agent_type == AgentType.PARALLEL.value:
230
232
  # Parallel agents depend on their fan-out and fan-in agents
231
233
  dependencies[name].update(agent_data.get("parallel_agents", []))
@@ -1,9 +1,10 @@
1
1
  """Module for converting log events to progress events."""
2
2
 
3
- from dataclasses import dataclass
4
3
  from enum import Enum
5
4
  from typing import Optional
6
5
 
6
+ from pydantic import BaseModel
7
+
7
8
  from mcp_agent.logging.events import Event
8
9
 
9
10
 
@@ -24,8 +25,7 @@ class ProgressAction(str, Enum):
24
25
  FATAL_ERROR = "Error"
25
26
 
26
27
 
27
- @dataclass
28
- class ProgressEvent:
28
+ class ProgressEvent(BaseModel):
29
29
  """Represents a progress event converted from a log event."""
30
30
 
31
31
  action: ProgressAction
@@ -87,8 +87,8 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:
87
87
  target = event_data.get("target", "unknown")
88
88
 
89
89
  return ProgressEvent(
90
- ProgressAction(progress_action),
91
- target,
92
- details,
90
+ action=ProgressAction(progress_action),
91
+ target=target,
92
+ details=details,
93
93
  agent_name=event_data.get("agent_name"),
94
94
  )
@@ -27,6 +27,7 @@ from mcp_agent.core.prompt import Prompt
27
27
  from mcp_agent.core.request_params import RequestParams
28
28
  from mcp_agent.event_progress import ProgressAction
29
29
  from mcp_agent.llm.memory import Memory, SimpleMemory
30
+ from mcp_agent.llm.provider_types import Provider
30
31
  from mcp_agent.llm.sampling_format_converter import (
31
32
  BasicFormatConverter,
32
33
  ProviderFormatConverter,
@@ -64,10 +65,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
64
65
  selecting appropriate tools, and determining what information to retain.
65
66
  """
66
67
 
67
- provider: str | None = None
68
+ provider: Provider | None = None
68
69
 
69
70
  def __init__(
70
71
  self,
72
+ provider: Provider,
71
73
  agent: Optional["Agent"] = None,
72
74
  server_names: List[str] | None = None,
73
75
  instruction: str | None = None,
@@ -104,7 +106,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
104
106
  self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
105
107
  self.name = agent.name if agent else name
106
108
  self.instruction = agent.instruction if agent else instruction
107
-
109
+ self.provider = provider
108
110
  # memory contains provider specific API types.
109
111
  self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
110
112
 
@@ -480,3 +482,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
480
482
  List of PromptMessageMultipart objects representing the conversation history
481
483
  """
482
484
  return self._message_history
485
+
486
+ def _api_key(self):
487
+ from mcp_agent.llm.provider_key_manager import ProviderKeyManager
488
+
489
+ assert self.provider
490
+ return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
@@ -9,6 +9,7 @@ from mcp_agent.llm.augmented_llm import (
9
9
  MessageParamT,
10
10
  RequestParams,
11
11
  )
12
+ from mcp_agent.llm.provider_types import Provider
12
13
  from mcp_agent.logging.logger import get_logger
13
14
  from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
14
15
 
@@ -25,9 +26,10 @@ class PassthroughLLM(AugmentedLLM):
25
26
  parallel workflow where no fan-in aggregation is needed.
26
27
  """
27
28
 
28
- def __init__(self, name: str = "Passthrough", **kwargs: dict[str, Any]) -> None:
29
- super().__init__(name=name, **kwargs)
30
- self.provider = "fast-agent"
29
+ def __init__(
30
+ self, provider=Provider.FAST_AGENT, name: str = "Passthrough", **kwargs: dict[str, Any]
31
+ ) -> None:
32
+ super().__init__(name=name, provider=provider, **kwargs)
31
33
  self.logger = get_logger(__name__)
32
34
  self._messages = [PromptMessage]
33
35
  self._fixed_response: str | None = None
@@ -3,6 +3,7 @@ from typing import Any, List
3
3
  from mcp_agent.core.prompt import Prompt
4
4
  from mcp_agent.llm.augmented_llm import RequestParams
5
5
  from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
6
+ from mcp_agent.llm.provider_types import Provider
6
7
  from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
7
8
  from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
8
9
 
@@ -21,7 +22,7 @@ class PlaybackLLM(PassthroughLLM):
21
22
  """
22
23
 
23
24
  def __init__(self, name: str = "Playback", **kwargs: dict[str, Any]) -> None:
24
- super().__init__(name=name, **kwargs)
25
+ super().__init__(name=name, provider=Provider.FAST_AGENT, **kwargs)
25
26
  self._messages: List[PromptMessageMultipart] = []
26
27
  self._current_index = -1
27
28
  self._overage = -1
@@ -1,12 +1,14 @@
1
- from dataclasses import dataclass
2
- from enum import Enum, auto
1
+ from enum import Enum
3
2
  from typing import Callable, Dict, Optional, Type, Union
4
3
 
4
+ from pydantic import BaseModel
5
+
5
6
  from mcp_agent.agents.agent import Agent
6
7
  from mcp_agent.core.exceptions import ModelConfigError
7
8
  from mcp_agent.core.request_params import RequestParams
8
9
  from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
9
10
  from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
11
+ from mcp_agent.llm.provider_types import Provider
10
12
  from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
11
13
  from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM
12
14
  from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
@@ -28,17 +30,6 @@ LLMClass = Union[
28
30
  ]
29
31
 
30
32
 
31
- class Provider(Enum):
32
- """Supported LLM providers"""
33
-
34
- ANTHROPIC = auto()
35
- OPENAI = auto()
36
- FAST_AGENT = auto()
37
- DEEPSEEK = auto()
38
- GENERIC = auto()
39
- OPENROUTER = auto()
40
-
41
-
42
33
  class ReasoningEffort(Enum):
43
34
  """Optional reasoning effort levels"""
44
35
 
@@ -47,8 +38,7 @@ class ReasoningEffort(Enum):
47
38
  HIGH = "high"
48
39
 
49
40
 
50
- @dataclass
51
- class ModelConfig:
41
+ class ModelConfig(BaseModel):
52
42
  """Configuration for a specific model"""
53
43
 
54
44
  provider: Provider
@@ -59,16 +49,6 @@ class ModelConfig:
59
49
  class ModelFactory:
60
50
  """Factory for creating LLM instances based on model specifications"""
61
51
 
62
- # Mapping of provider strings to enum values
63
- PROVIDER_MAP = {
64
- "anthropic": Provider.ANTHROPIC,
65
- "openai": Provider.OPENAI,
66
- "fast-agent": Provider.FAST_AGENT,
67
- "deepseek": Provider.DEEPSEEK,
68
- "generic": Provider.GENERIC,
69
- "openrouter": Provider.OPENROUTER,
70
- }
71
-
72
52
  # Mapping of effort strings to enum values
73
53
  EFFORT_MAP = {
74
54
  "low": ReasoningEffort.LOW,
@@ -156,8 +136,8 @@ class ModelFactory:
156
136
  # Check first part for provider
157
137
  if len(model_parts) > 1:
158
138
  potential_provider = model_parts[0]
159
- if potential_provider in cls.PROVIDER_MAP:
160
- provider = cls.PROVIDER_MAP[potential_provider]
139
+ if any(provider.value == potential_provider for provider in Provider):
140
+ provider = Provider(potential_provider)
161
141
  model_parts = model_parts[1:]
162
142
 
163
143
  # Join remaining parts as model name
@@ -0,0 +1,83 @@
"""
Provider API key management for various LLM providers.
Centralizes API key handling logic to make provider implementations more generic.
"""

import os
from collections.abc import Mapping
from typing import Any, Dict

from pydantic import BaseModel

from mcp_agent.core.exceptions import ProviderKeyError

# Environment variable consulted for each known provider when no key is found
# in the configuration. Unknown providers fall back to "<NAME>_API_KEY".
PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
    "deepseek": "DEEPSEEK_API_KEY",
    "openrouter": "OPENROUTER_API_KEY",
    "generic": "GENERIC_API_KEY",
}
# A config value equal to this placeholder is treated as "no key configured".
# (Presumably the template text shown in sample config files; verify.)
API_KEY_HINT_TEXT = "<your-api-key-here>"


class ProviderKeyManager:
    """
    Manages API keys for different providers centrally.
    This class abstracts away the provider-specific key access logic,
    making the provider implementations more generic.
    """

    @staticmethod
    def get_env_var(provider_name: str) -> str | None:
        """Return the provider's API key from the environment, or None if unset."""
        return os.getenv(ProviderKeyManager.get_env_key_name(provider_name))

    @staticmethod
    def get_env_key_name(provider_name: str) -> str:
        """Return the name of the environment variable holding the provider's key."""
        return PROVIDER_ENVIRONMENT_MAP.get(provider_name, f"{provider_name.upper()}_API_KEY")

    @staticmethod
    def get_config_file_key(provider_name: str, config: Any) -> str | None:
        """
        Return the API key configured for ``provider_name`` in ``config``.

        Args:
            provider_name: Provider section name to look up (e.g. "openai").
            config: The application configuration, either a pydantic model or a
                mapping. ``None`` and other non-mapping values are tolerated.

        Returns:
            The configured key, or None when the config is absent or not a
            mapping, the provider section is missing or not a mapping, the key
            is missing, or the key equals the placeholder hint text. A None
            result lets the caller fall back to the environment variable.
        """
        if isinstance(config, BaseModel):
            config = config.model_dump()
        # Previously a None/non-dict config raised AttributeError here.
        if not isinstance(config, Mapping):
            return None

        provider_settings = config.get(provider_name)
        if isinstance(provider_settings, BaseModel):
            provider_settings = provider_settings.model_dump()
        if not provider_settings or not isinstance(provider_settings, Mapping):
            return None

        api_key = provider_settings.get("api_key")
        if api_key == API_KEY_HINT_TEXT:
            return None
        return api_key

    @staticmethod
    def get_api_key(provider_name: str, config: Any) -> str:
        """
        Gets the API key for the specified provider.

        The configuration file takes precedence; the provider's environment
        variable is used when the config has no usable key.

        Args:
            provider_name: Name of the provider (e.g., "anthropic", "openai")
            config: The application configuration object

        Returns:
            The API key as a string

        Raises:
            ProviderKeyError: If the API key is not found or is invalid
        """

        provider_name = provider_name.lower()
        api_key = ProviderKeyManager.get_config_file_key(provider_name, config)
        if not api_key:
            api_key = ProviderKeyManager.get_env_var(provider_name)

        if not api_key and provider_name == "generic":
            api_key = "ollama"  # Default for generic provider

        if not api_key:
            raise ProviderKeyError(
                f"{provider_name.title()} API key not configured",
                f"The {provider_name.title()} API key is required but not set.\n"
                f"Add it to your configuration file under {provider_name}.api_key "
                f"or set the {ProviderKeyManager.get_env_key_name(provider_name)} environment variable.",
            )

        return api_key
@@ -0,0 +1,16 @@
"""
Provider type definitions used across the LLM layer.
"""

from enum import Enum


class Provider(Enum):
    """
    The set of supported LLM providers.

    Each member's value is the provider's lowercase string identifier, so a
    provider can be looked up from its name, e.g. ``Provider("openai")``.
    """

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    # Used by the built-in passthrough and playback LLMs.
    FAST_AGENT = "fast-agent"
    DEEPSEEK = "deepseek"
    GENERIC = "generic"
    OPENROUTER = "openrouter"