fast-agent-mcp 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/METADATA +1 -1
  2. {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/RECORD +33 -33
  3. mcp_agent/agents/agent.py +2 -2
  4. mcp_agent/agents/base_agent.py +3 -3
  5. mcp_agent/agents/workflow/chain_agent.py +2 -2
  6. mcp_agent/agents/workflow/evaluator_optimizer.py +3 -3
  7. mcp_agent/agents/workflow/orchestrator_agent.py +3 -3
  8. mcp_agent/agents/workflow/parallel_agent.py +2 -2
  9. mcp_agent/agents/workflow/router_agent.py +2 -2
  10. mcp_agent/cli/commands/check_config.py +450 -0
  11. mcp_agent/cli/commands/setup.py +1 -1
  12. mcp_agent/cli/main.py +8 -15
  13. mcp_agent/core/agent_types.py +8 -8
  14. mcp_agent/core/direct_decorators.py +10 -8
  15. mcp_agent/core/direct_factory.py +4 -1
  16. mcp_agent/core/validation.py +6 -4
  17. mcp_agent/event_progress.py +6 -6
  18. mcp_agent/llm/augmented_llm.py +10 -2
  19. mcp_agent/llm/augmented_llm_passthrough.py +5 -3
  20. mcp_agent/llm/augmented_llm_playback.py +2 -1
  21. mcp_agent/llm/model_factory.py +7 -27
  22. mcp_agent/llm/provider_key_manager.py +83 -0
  23. mcp_agent/llm/provider_types.py +16 -0
  24. mcp_agent/llm/providers/augmented_llm_anthropic.py +5 -26
  25. mcp_agent/llm/providers/augmented_llm_deepseek.py +5 -24
  26. mcp_agent/llm/providers/augmented_llm_generic.py +2 -16
  27. mcp_agent/llm/providers/augmented_llm_openai.py +4 -26
  28. mcp_agent/llm/providers/augmented_llm_openrouter.py +17 -45
  29. mcp_agent/mcp/interfaces.py +2 -1
  30. mcp_agent/mcp_server/agent_server.py +335 -14
  31. mcp_agent/cli/commands/config.py +0 -11
  32. mcp_agent/executor/temporal.py +0 -383
  33. mcp_agent/executor/workflow.py +0 -195
  34. {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/WHEEL +0 -0
  35. {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/entry_points.txt +0 -0
  36. {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/augmented_llm.py

@@ -27,6 +27,7 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.sampling_format_converter import (
     BasicFormatConverter,
     ProviderFormatConverter,
@@ -64,10 +65,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     selecting appropriate tools, and determining what information to retain.
     """

-    provider: str | None = None
+    provider: Provider | None = None

     def __init__(
         self,
+        provider: Provider,
         agent: Optional["Agent"] = None,
         server_names: List[str] | None = None,
         instruction: str | None = None,
@@ -104,7 +106,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
         self.name = agent.name if agent else name
         self.instruction = agent.instruction if agent else instruction
-
+        self.provider = provider
         # memory contains provider specific API types.
         self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()

@@ -480,3 +482,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
            List of PromptMessageMultipart objects representing the conversation history
        """
        return self._message_history
+
+    def _api_key(self):
+        from mcp_agent.llm.provider_key_manager import ProviderKeyManager
+
+        assert self.provider
+        return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
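
Taken together, these changes replace per-provider key plumbing with a single resolution chain: each subclass declares its Provider, AugmentedLLM stores it, and _api_key() hands the enum's string value to ProviderKeyManager. A minimal sketch of that chain, assuming the package is installed; the resolve_key helper and fake_config dict are illustrative, not part of the package:

    from mcp_agent.llm.provider_key_manager import ProviderKeyManager
    from mcp_agent.llm.provider_types import Provider

    def resolve_key(provider: Provider, config) -> str:
        # Mirrors AugmentedLLM._api_key(): the enum's string value names both
        # the config section and, inside ProviderKeyManager, the env var fallback.
        return ProviderKeyManager.get_api_key(provider.value, config)

    fake_config = {"anthropic": {"api_key": "sk-ant-example"}}
    print(resolve_key(Provider.ANTHROPIC, fake_config))  # -> sk-ant-example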

mcp_agent/llm/augmented_llm_passthrough.py

@@ -9,6 +9,7 @@ from mcp_agent.llm.augmented_llm import (
     MessageParamT,
     RequestParams,
 )
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart

@@ -25,9 +26,10 @@ class PassthroughLLM(AugmentedLLM):
     parallel workflow where no fan-in aggregation is needed.
     """

-    def __init__(self, name: str = "Passthrough", **kwargs: dict[str, Any]) -> None:
-        super().__init__(name=name, **kwargs)
-        self.provider = "fast-agent"
+    def __init__(
+        self, provider=Provider.FAST_AGENT, name: str = "Passthrough", **kwargs: dict[str, Any]
+    ) -> None:
+        super().__init__(name=name, provider=provider, **kwargs)
         self.logger = get_logger(__name__)
         self._messages = [PromptMessage]
         self._fixed_response: str | None = None

mcp_agent/llm/augmented_llm_playback.py

@@ -3,6 +3,7 @@ from typing import Any, List
 from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent

@@ -21,7 +22,7 @@ class PlaybackLLM(PassthroughLLM):
     """

     def __init__(self, name: str = "Playback", **kwargs: dict[str, Any]) -> None:
-        super().__init__(name=name, **kwargs)
+        super().__init__(name=name, provider=Provider.FAST_AGENT, **kwargs)
         self._messages: List[PromptMessageMultipart] = []
         self._current_index = -1
         self._overage = -1

mcp_agent/llm/model_factory.py

@@ -1,12 +1,14 @@
-from dataclasses import dataclass
-from enum import Enum, auto
+from enum import Enum
 from typing import Callable, Dict, Optional, Type, Union

+from pydantic import BaseModel
+
 from mcp_agent.agents.agent import Agent
 from mcp_agent.core.exceptions import ModelConfigError
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
@@ -28,17 +30,6 @@ LLMClass = Union[
 ]


-class Provider(Enum):
-    """Supported LLM providers"""
-
-    ANTHROPIC = auto()
-    OPENAI = auto()
-    FAST_AGENT = auto()
-    DEEPSEEK = auto()
-    GENERIC = auto()
-    OPENROUTER = auto()
-
-
 class ReasoningEffort(Enum):
     """Optional reasoning effort levels"""

@@ -47,8 +38,7 @@ class ReasoningEffort(Enum):
     HIGH = "high"


-@dataclass
-class ModelConfig:
+class ModelConfig(BaseModel):
     """Configuration for a specific model"""

     provider: Provider
@@ -59,16 +49,6 @@
 class ModelFactory:
     """Factory for creating LLM instances based on model specifications"""

-    # Mapping of provider strings to enum values
-    PROVIDER_MAP = {
-        "anthropic": Provider.ANTHROPIC,
-        "openai": Provider.OPENAI,
-        "fast-agent": Provider.FAST_AGENT,
-        "deepseek": Provider.DEEPSEEK,
-        "generic": Provider.GENERIC,
-        "openrouter": Provider.OPENROUTER,
-    }
-
     # Mapping of effort strings to enum values
     EFFORT_MAP = {
         "low": ReasoningEffort.LOW,
@@ -156,8 +136,8 @@
         # Check first part for provider
         if len(model_parts) > 1:
             potential_provider = model_parts[0]
-            if potential_provider in cls.PROVIDER_MAP:
-                provider = cls.PROVIDER_MAP[potential_provider]
+            if any(provider.value == potential_provider for provider in Provider):
+                provider = Provider(potential_provider)
                 model_parts = model_parts[1:]

         # Join remaining parts as model name
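
Because Provider values are now the canonical spec strings, the factory can test membership and convert with Provider(...) directly instead of consulting PROVIDER_MAP. A hedged sketch of the parse step, assuming dot-separated provider.model specs; split_provider is an illustrative helper, not the package's API:

    from mcp_agent.llm.provider_types import Provider

    def split_provider(model_spec: str) -> tuple[Provider | None, str]:
        # The first dot-separated part counts as a provider only if it matches
        # some Provider member's value, as in the rewritten factory check.
        parts = model_spec.split(".")
        if len(parts) > 1 and any(p.value == parts[0] for p in Provider):
            return Provider(parts[0]), ".".join(parts[1:])
        return None, model_spec

    print(split_provider("openai.gpt-4o"))  # (<Provider.OPENAI: 'openai'>, 'gpt-4o')
    print(split_provider("gpt-4o"))         # (None, 'gpt-4o')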

mcp_agent/llm/provider_key_manager.py (new file)

@@ -0,0 +1,83 @@
+"""
+Provider API key management for various LLM providers.
+Centralizes API key handling logic to make provider implementations more generic.
+"""
+
+import os
+from typing import Any, Dict
+
+from pydantic import BaseModel
+
+from mcp_agent.core.exceptions import ProviderKeyError
+
+PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
+    "anthropic": "ANTHROPIC_API_KEY",
+    "openai": "OPENAI_API_KEY",
+    "deepseek": "DEEPSEEK_API_KEY",
+    "openrouter": "OPENROUTER_API_KEY",
+    "generic": "GENERIC_API_KEY",
+}
+API_KEY_HINT_TEXT = "<your-api-key-here>"
+
+
+class ProviderKeyManager:
+    """
+    Manages API keys for different providers centrally.
+    This class abstracts away the provider-specific key access logic,
+    making the provider implementations more generic.
+    """
+
+    @staticmethod
+    def get_env_var(provider_name: str) -> str | None:
+        return os.getenv(ProviderKeyManager.get_env_key_name(provider_name))
+
+    @staticmethod
+    def get_env_key_name(provider_name: str) -> str:
+        return PROVIDER_ENVIRONMENT_MAP.get(provider_name, f"{provider_name.upper()}_API_KEY")
+
+    @staticmethod
+    def get_config_file_key(provider_name: str, config: Any) -> str | None:
+        api_key = None
+        if isinstance(config, BaseModel):
+            config = config.model_dump()
+        provider_settings = config.get(provider_name)
+        if provider_settings:
+            api_key = provider_settings.get("api_key", API_KEY_HINT_TEXT)
+            if api_key == API_KEY_HINT_TEXT:
+                api_key = None
+
+        return api_key
+
+    @staticmethod
+    def get_api_key(provider_name: str, config: Any) -> str:
+        """
+        Gets the API key for the specified provider.
+
+        Args:
+            provider_name: Name of the provider (e.g., "anthropic", "openai")
+            config: The application configuration object
+
+        Returns:
+            The API key as a string
+
+        Raises:
+            ProviderKeyError: If the API key is not found or is invalid
+        """
+
+        provider_name = provider_name.lower()
+        api_key = ProviderKeyManager.get_config_file_key(provider_name, config)
+        if not api_key:
+            api_key = ProviderKeyManager.get_env_var(provider_name)
+
+        if not api_key and provider_name == "generic":
+            api_key = "ollama"  # Default for generic provider
+
+        if not api_key:
+            raise ProviderKeyError(
+                f"{provider_name.title()} API key not configured",
+                f"The {provider_name.title()} API key is required but not set.\n"
+                f"Add it to your configuration file under {provider_name}.api_key "
+                f"or set the {ProviderKeyManager.get_env_key_name(provider_name)} environment variable.",
+            )
+
+        return api_key
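
A hedged usage sketch of the new class: lookup order is config file first, then the provider's environment variable; "generic" additionally defaults to "ollama", and anything still missing raises ProviderKeyError. The key values below are illustrative, and the last call assumes DEEPSEEK_API_KEY is unset:

    import os

    from mcp_agent.core.exceptions import ProviderKeyError
    from mcp_agent.llm.provider_key_manager import ProviderKeyManager

    os.environ["OPENAI_API_KEY"] = "sk-example"
    assert ProviderKeyManager.get_api_key("openai", {}) == "sk-example"  # env fallback
    assert ProviderKeyManager.get_api_key("generic", {}) == "ollama"     # built-in default
    try:
        ProviderKeyManager.get_api_key("deepseek", {})  # neither config nor env set
    except ProviderKeyError:
        print("raises with guidance naming deepseek.api_key / DEEPSEEK_API_KEY")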

mcp_agent/llm/provider_types.py (new file)

@@ -0,0 +1,16 @@
+"""
+Type definitions for LLM providers.
+"""
+
+from enum import Enum
+
+
+class Provider(Enum):
+    """Supported LLM providers"""
+
+    ANTHROPIC = "anthropic"
+    OPENAI = "openai"
+    FAST_AGENT = "fast-agent"
+    DEEPSEEK = "deepseek"
+    GENERIC = "generic"
+    OPENROUTER = "openrouter"
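
Replacing the auto() members removed from model_factory.py, the string values make the enum round-trip with the names used in model specs, config sections, and environment-variable derivation, e.g.:

    from mcp_agent.llm.provider_types import Provider

    assert Provider("openrouter") is Provider.OPENROUTER  # value -> member lookup
    assert Provider.FAST_AGENT.value == "fast-agent"      # member -> config/spec name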
@@ -1,9 +1,9 @@
1
- import os
2
1
  from typing import TYPE_CHECKING, List
3
2
 
4
3
  from mcp.types import EmbeddedResource, ImageContent, TextContent
5
4
 
6
5
  from mcp_agent.core.prompt import Prompt
6
+ from mcp_agent.llm.provider_types import Provider
7
7
  from mcp_agent.llm.providers.multipart_converter_anthropic import (
8
8
  AnthropicConverter,
9
9
  )
@@ -51,12 +51,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
51
51
  """
52
52
 
53
53
  def __init__(self, *args, **kwargs) -> None:
54
- self.provider = "Anthropic"
55
54
  # Initialize logger - keep it simple without name reference
56
55
  self.logger = get_logger(__name__)
57
56
 
58
- # Now call super().__init__
59
- super().__init__(*args, type_converter=AnthropicSamplingConverter, **kwargs)
57
+ super().__init__(
58
+ *args, provider=Provider.ANTHROPIC, type_converter=AnthropicSamplingConverter, **kwargs
59
+ )
60
60
 
61
61
  def _initialize_default_params(self, kwargs: dict) -> RequestParams:
62
62
  """Initialize Anthropic-specific default parameters"""
@@ -83,7 +83,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
83
83
  Override this method to use a different LLM.
84
84
  """
85
85
 
86
- api_key = self._api_key(self.context.config)
86
+ api_key = self._api_key()
87
87
  base_url = self._base_url()
88
88
  if base_url and base_url.endswith("/v1"):
89
89
  base_url = base_url.rstrip("/v1")
@@ -277,27 +277,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
277
277
 
278
278
  return responses
279
279
 
280
- def _api_key(self, config):
281
- api_key = None
282
-
283
- if hasattr(config, "anthropic") and config.anthropic:
284
- api_key = config.anthropic.api_key
285
- if api_key == "<your-api-key-here>":
286
- api_key = None
287
-
288
- if api_key is None:
289
- api_key = os.getenv("ANTHROPIC_API_KEY")
290
-
291
- if not api_key:
292
- raise ProviderKeyError(
293
- "Anthropic API key not configured",
294
- "The Anthropic API key is required but not set.\n"
295
- "Add it to your configuration file under anthropic.api_key "
296
- "or set the ANTHROPIC_API_KEY environment variable.",
297
- )
298
-
299
- return api_key
300
-
301
280
  async def generate_messages(
302
281
  self,
303
282
  message_param,

mcp_agent/llm/providers/augmented_llm_deepseek.py

@@ -1,7 +1,6 @@
-import os

-from mcp_agent.core.exceptions import ProviderKeyError
 from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM

 DEEPSEEK_BASE_URL = "https://api.deepseek.com"
@@ -11,7 +10,9 @@ DEFAULT_DEEPSEEK_MODEL = "deepseekchat" # current Deepseek only has two type mo
 class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
     def __init__(self, *args, **kwargs) -> None:
         kwargs["provider_name"] = "Deepseek"  # Set provider name in kwargs
-        super().__init__(*args, **kwargs)  # Properly pass args and kwargs to parent
+        super().__init__(
+            *args, provider=Provider.DEEPSEEK, **kwargs
+        )  # Properly pass args and kwargs to parent

     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize Deepseek-specific default parameters"""
@@ -25,28 +26,8 @@ class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
             use_history=True,
         )

-    def _api_key(self) -> str:
-        config = self.context.config
-        api_key = None
-
-        if config and config.deepseek:
-            api_key = config.deepseek.api_key
-            if api_key == "<your-api-key-here>":
-                api_key = None
-
-        if api_key is None:
-            api_key = os.getenv("DEEPSEEK_API_KEY")
-
-        if not api_key:
-            raise ProviderKeyError(
-                "DEEPSEEK API key not configured",
-                "The DEEKSEEK API key is required but not set.\n"
-                "Add it to your configuration file under deepseek.api_key\n"
-                "Or set the DEEPSEEK_API_KEY environment variable",
-            )
-        return api_key
-
     def _base_url(self) -> str:
+        base_url = None
         if self.context.config and self.context.config.deepseek:
             base_url = self.context.config.deepseek.base_url

mcp_agent/llm/providers/augmented_llm_generic.py

@@ -1,6 +1,7 @@
 import os

 from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM

 DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434/v1"
@@ -9,8 +10,7 @@ DEFAULT_OLLAMA_API_KEY = "ollama"


 class GenericAugmentedLLM(OpenAIAugmentedLLM):
-    def __init__(self, *args, **kwargs) -> None:
-        kwargs["provider_name"] = "GenericOpenAI"
+    def __init__(self, *args, provider=Provider.GENERIC, **kwargs) -> None:
         super().__init__(*args, **kwargs)  # Properly pass args and kwargs to parent

     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
@@ -25,20 +25,6 @@ class GenericAugmentedLLM(OpenAIAugmentedLLM):
             use_history=True,
         )

-    def _api_key(self) -> str:
-        config = self.context.config
-        api_key = None
-
-        if config and config.generic:
-            api_key = config.generic.api_key
-            if api_key == "<your-api-key-here>":
-                api_key = None
-
-        if api_key is None:
-            api_key = os.getenv("GENERIC_API_KEY")
-
-        return api_key or "ollama"
-
     def _base_url(self) -> str:
         base_url = os.getenv("GENERIC_BASE_URL", DEFAULT_OLLAMA_BASE_URL)
         if self.context.config and self.context.config.generic:
@@ -1,4 +1,3 @@
1
- import os
2
1
  from typing import List, Tuple, Type
3
2
 
4
3
  from mcp.types import (
@@ -29,6 +28,7 @@ from mcp_agent.llm.augmented_llm import (
29
28
  ModelT,
30
29
  RequestParams,
31
30
  )
31
+ from mcp_agent.llm.provider_types import Provider
32
32
  from mcp_agent.llm.providers.multipart_converter_openai import OpenAIConverter
33
33
  from mcp_agent.llm.providers.sampling_converter_openai import (
34
34
  OpenAISamplingConverter,
@@ -49,14 +49,13 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
49
49
  This implementation uses OpenAI's ChatCompletion as the LLM.
50
50
  """
51
51
 
52
- def __init__(self, provider_name: str = "OpenAI", *args, **kwargs) -> None:
52
+ def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None:
53
53
  # Set type_converter before calling super().__init__
54
54
  if "type_converter" not in kwargs:
55
55
  kwargs["type_converter"] = OpenAISamplingConverter
56
56
 
57
- super().__init__(*args, **kwargs)
57
+ super().__init__(*args, provider=provider, **kwargs)
58
58
 
59
- self.provider = provider_name
60
59
  # Initialize logger with name if available
61
60
  self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__)
62
61
 
@@ -90,27 +89,6 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
90
89
  use_history=True,
91
90
  )
92
91
 
93
- def _api_key(self) -> str:
94
- config = self.context.config
95
- api_key = None
96
-
97
- if hasattr(config, "openai") and config.openai:
98
- api_key = config.openai.api_key
99
- if api_key == "<your-api-key-here>":
100
- api_key = None
101
-
102
- if api_key is None:
103
- api_key = os.getenv("OPENAI_API_KEY")
104
-
105
- if not api_key:
106
- raise ProviderKeyError(
107
- "OpenAI API key not configured",
108
- "The OpenAI API key is required but not set.\n"
109
- "Add it to your configuration file under openai.api_key\n"
110
- "Or set the OPENAI_API_KEY environment variable",
111
- )
112
- return api_key
113
-
114
92
  def _base_url(self) -> str:
115
93
  return self.context.config.openai.base_url if self.context.config.openai else None
116
94
 
@@ -371,7 +349,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
371
349
  The parsed response as a Pydantic model, or None if parsing fails
372
350
  """
373
351
 
374
- if not "OpenAI" == self.provider:
352
+ if not Provider.OPENAI == self.provider:
375
353
  return await super().structured(prompt, model, request_params)
376
354
 
377
355
  logger = get_logger(__name__)
@@ -1,19 +1,19 @@
1
1
  import os
2
2
 
3
- from mcp_agent.core.exceptions import ProviderKeyError
4
3
  from mcp_agent.core.request_params import RequestParams
4
+ from mcp_agent.llm.provider_types import Provider
5
5
  from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
6
6
 
7
7
  DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
8
8
  # No single default model for OpenRouter, users must specify full path
9
- DEFAULT_OPENROUTER_MODEL = None
9
+ DEFAULT_OPENROUTER_MODEL = None
10
10
 
11
11
 
12
12
  class OpenRouterAugmentedLLM(OpenAIAugmentedLLM):
13
13
  """Augmented LLM provider for OpenRouter, using an OpenAI-compatible API."""
14
+
14
15
  def __init__(self, *args, **kwargs) -> None:
15
- kwargs["provider_name"] = "OpenRouter" # Set provider name
16
- super().__init__(*args, **kwargs)
16
+ super().__init__(*args, provider=Provider.OPENROUTER, **kwargs)
17
17
 
18
18
  def _initialize_default_params(self, kwargs: dict) -> RequestParams:
19
19
  """Initialize OpenRouter-specific default parameters."""
@@ -21,58 +21,30 @@ class OpenRouterAugmentedLLM(OpenAIAugmentedLLM):
21
21
  # The model should be passed in the 'model' kwarg during factory creation.
22
22
  chosen_model = kwargs.get("model", DEFAULT_OPENROUTER_MODEL)
23
23
  if not chosen_model:
24
- # Unlike Deepseek, OpenRouter *requires* a model path in the identifier.
25
- # The factory should extract this before calling the constructor.
26
- # We rely on the model being passed correctly via kwargs.
27
- # If it's still None here, it indicates an issue upstream (factory or user input).
28
- # However, the base class _get_model handles the error if model is None.
29
- pass
30
-
24
+ # Unlike Deepseek, OpenRouter *requires* a model path in the identifier.
25
+ # The factory should extract this before calling the constructor.
26
+ # We rely on the model being passed correctly via kwargs.
27
+ # If it's still None here, it indicates an issue upstream (factory or user input).
28
+ # However, the base class _get_model handles the error if model is None.
29
+ pass
31
30
 
32
31
  return RequestParams(
33
- model=chosen_model, # Will be validated by base class
32
+ model=chosen_model, # Will be validated by base class
34
33
  systemPrompt=self.instruction,
35
- parallel_tool_calls=True, # Default based on OpenAI provider
36
- max_iterations=10, # Default based on OpenAI provider
37
- use_history=True, # Default based on OpenAI provider
34
+ parallel_tool_calls=True, # Default based on OpenAI provider
35
+ max_iterations=10, # Default based on OpenAI provider
36
+ use_history=True, # Default based on OpenAI provider
38
37
  )
39
38
 
40
- def _api_key(self) -> str:
41
- """Retrieve the OpenRouter API key from config or environment variables."""
42
- config = self.context.config
43
- api_key = None
44
-
45
- # Check config file first
46
- if config and hasattr(config, 'openrouter') and config.openrouter:
47
- api_key = getattr(config.openrouter, 'api_key', None)
48
- if api_key == "<your-openrouter-api-key-here>" or not api_key:
49
- api_key = None
50
-
51
- # Fallback to environment variable
52
- if api_key is None:
53
- api_key = os.getenv("OPENROUTER_API_KEY")
54
-
55
- if not api_key:
56
- raise ProviderKeyError(
57
- "OpenRouter API key not configured",
58
- "The OpenRouter API key is required but not set.\n"
59
- "Add it to your configuration file under openrouter.api_key\n"
60
- "Or set the OPENROUTER_API_KEY environment variable.",
61
- )
62
- return api_key
63
-
64
39
  def _base_url(self) -> str:
65
40
  """Retrieve the OpenRouter base URL from config or use the default."""
66
41
  base_url = os.getenv("OPENROUTER_BASE_URL", DEFAULT_OPENROUTER_BASE_URL) # Default
67
42
  config = self.context.config
68
-
43
+
69
44
  # Check config file for override
70
- if config and hasattr(config, 'openrouter') and config.openrouter:
71
- config_base_url = getattr(config.openrouter, 'base_url', None)
45
+ if config and hasattr(config, "openrouter") and config.openrouter:
46
+ config_base_url = getattr(config.openrouter, "base_url", None)
72
47
  if config_base_url:
73
48
  base_url = config_base_url
74
49
 
75
50
  return base_url
76
-
77
- # Other methods like _get_model, _send_request etc., are inherited from OpenAIAugmentedLLM
78
- # We may override them later if OpenRouter deviates significantly or offers unique features.

mcp_agent/mcp/interfaces.py

@@ -26,6 +26,7 @@ from mcp import ClientSession
 from mcp.types import GetPromptResult, Prompt, PromptMessage, ReadResourceResult
 from pydantic import BaseModel

+from mcp_agent.core.agent_types import AgentType
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart

@@ -137,7 +138,7 @@ class AgentProtocol(AugmentedLLMProtocol, Protocol):
     name: str

     @property
-    def agent_type(self) -> str:
+    def agent_type(self) -> AgentType:
        """Return the type of this agent"""
        ...