fast-agent-mcp 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/RECORD +33 -33
- mcp_agent/agents/agent.py +2 -2
- mcp_agent/agents/base_agent.py +3 -3
- mcp_agent/agents/workflow/chain_agent.py +2 -2
- mcp_agent/agents/workflow/evaluator_optimizer.py +3 -3
- mcp_agent/agents/workflow/orchestrator_agent.py +3 -3
- mcp_agent/agents/workflow/parallel_agent.py +2 -2
- mcp_agent/agents/workflow/router_agent.py +2 -2
- mcp_agent/cli/commands/check_config.py +450 -0
- mcp_agent/cli/commands/setup.py +1 -1
- mcp_agent/cli/main.py +8 -15
- mcp_agent/core/agent_types.py +8 -8
- mcp_agent/core/direct_decorators.py +10 -8
- mcp_agent/core/direct_factory.py +4 -1
- mcp_agent/core/validation.py +6 -4
- mcp_agent/event_progress.py +6 -6
- mcp_agent/llm/augmented_llm.py +10 -2
- mcp_agent/llm/augmented_llm_passthrough.py +5 -3
- mcp_agent/llm/augmented_llm_playback.py +2 -1
- mcp_agent/llm/model_factory.py +7 -27
- mcp_agent/llm/provider_key_manager.py +83 -0
- mcp_agent/llm/provider_types.py +16 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +5 -26
- mcp_agent/llm/providers/augmented_llm_deepseek.py +5 -24
- mcp_agent/llm/providers/augmented_llm_generic.py +2 -16
- mcp_agent/llm/providers/augmented_llm_openai.py +4 -26
- mcp_agent/llm/providers/augmented_llm_openrouter.py +17 -45
- mcp_agent/mcp/interfaces.py +2 -1
- mcp_agent/mcp_server/agent_server.py +335 -14
- mcp_agent/cli/commands/config.py +0 -11
- mcp_agent/executor/temporal.py +0 -383
- mcp_agent/executor/workflow.py +0 -195
- {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.12.dist-info → fast_agent_mcp-0.2.14.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/augmented_llm.py
CHANGED
@@ -27,6 +27,7 @@ from mcp_agent.core.prompt import Prompt
|
|
27
27
|
from mcp_agent.core.request_params import RequestParams
|
28
28
|
from mcp_agent.event_progress import ProgressAction
|
29
29
|
from mcp_agent.llm.memory import Memory, SimpleMemory
|
30
|
+
from mcp_agent.llm.provider_types import Provider
|
30
31
|
from mcp_agent.llm.sampling_format_converter import (
|
31
32
|
BasicFormatConverter,
|
32
33
|
ProviderFormatConverter,
|
@@ -64,10 +65,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
|
|
64
65
|
selecting appropriate tools, and determining what information to retain.
|
65
66
|
"""
|
66
67
|
|
67
|
-
provider:
|
68
|
+
provider: Provider | None = None
|
68
69
|
|
69
70
|
def __init__(
|
70
71
|
self,
|
72
|
+
provider: Provider,
|
71
73
|
agent: Optional["Agent"] = None,
|
72
74
|
server_names: List[str] | None = None,
|
73
75
|
instruction: str | None = None,
|
@@ -104,7 +106,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
|
|
104
106
|
self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
|
105
107
|
self.name = agent.name if agent else name
|
106
108
|
self.instruction = agent.instruction if agent else instruction
|
107
|
-
|
109
|
+
self.provider = provider
|
108
110
|
# memory contains provider specific API types.
|
109
111
|
self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
|
110
112
|
|
@@ -480,3 +482,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
|
|
480
482
|
List of PromptMessageMultipart objects representing the conversation history
|
481
483
|
"""
|
482
484
|
return self._message_history
|
485
|
+
|
486
|
+
def _api_key(self):
|
487
|
+
from mcp_agent.llm.provider_key_manager import ProviderKeyManager
|
488
|
+
|
489
|
+
assert self.provider
|
490
|
+
return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
|
@@ -9,6 +9,7 @@ from mcp_agent.llm.augmented_llm import (
|
|
9
9
|
MessageParamT,
|
10
10
|
RequestParams,
|
11
11
|
)
|
12
|
+
from mcp_agent.llm.provider_types import Provider
|
12
13
|
from mcp_agent.logging.logger import get_logger
|
13
14
|
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
14
15
|
|
@@ -25,9 +26,10 @@ class PassthroughLLM(AugmentedLLM):
|
|
25
26
|
parallel workflow where no fan-in aggregation is needed.
|
26
27
|
"""
|
27
28
|
|
28
|
-
def __init__(
|
29
|
-
|
30
|
-
|
29
|
+
def __init__(
|
30
|
+
self, provider=Provider.FAST_AGENT, name: str = "Passthrough", **kwargs: dict[str, Any]
|
31
|
+
) -> None:
|
32
|
+
super().__init__(name=name, provider=provider, **kwargs)
|
31
33
|
self.logger = get_logger(__name__)
|
32
34
|
self._messages = [PromptMessage]
|
33
35
|
self._fixed_response: str | None = None
|
@@ -3,6 +3,7 @@ from typing import Any, List
|
|
3
3
|
from mcp_agent.core.prompt import Prompt
|
4
4
|
from mcp_agent.llm.augmented_llm import RequestParams
|
5
5
|
from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
|
6
|
+
from mcp_agent.llm.provider_types import Provider
|
6
7
|
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
7
8
|
from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
|
8
9
|
|
@@ -21,7 +22,7 @@ class PlaybackLLM(PassthroughLLM):
|
|
21
22
|
"""
|
22
23
|
|
23
24
|
def __init__(self, name: str = "Playback", **kwargs: dict[str, Any]) -> None:
|
24
|
-
super().__init__(name=name, **kwargs)
|
25
|
+
super().__init__(name=name, provider=Provider.FAST_AGENT, **kwargs)
|
25
26
|
self._messages: List[PromptMessageMultipart] = []
|
26
27
|
self._current_index = -1
|
27
28
|
self._overage = -1
|
mcp_agent/llm/model_factory.py
CHANGED
@@ -1,12 +1,14 @@
|
|
1
|
-
from
|
2
|
-
from enum import Enum, auto
|
1
|
+
from enum import Enum
|
3
2
|
from typing import Callable, Dict, Optional, Type, Union
|
4
3
|
|
4
|
+
from pydantic import BaseModel
|
5
|
+
|
5
6
|
from mcp_agent.agents.agent import Agent
|
6
7
|
from mcp_agent.core.exceptions import ModelConfigError
|
7
8
|
from mcp_agent.core.request_params import RequestParams
|
8
9
|
from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
|
9
10
|
from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
|
11
|
+
from mcp_agent.llm.provider_types import Provider
|
10
12
|
from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
|
11
13
|
from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM
|
12
14
|
from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
|
@@ -28,17 +30,6 @@ LLMClass = Union[
|
|
28
30
|
]
|
29
31
|
|
30
32
|
|
31
|
-
class Provider(Enum):
|
32
|
-
"""Supported LLM providers"""
|
33
|
-
|
34
|
-
ANTHROPIC = auto()
|
35
|
-
OPENAI = auto()
|
36
|
-
FAST_AGENT = auto()
|
37
|
-
DEEPSEEK = auto()
|
38
|
-
GENERIC = auto()
|
39
|
-
OPENROUTER = auto()
|
40
|
-
|
41
|
-
|
42
33
|
class ReasoningEffort(Enum):
|
43
34
|
"""Optional reasoning effort levels"""
|
44
35
|
|
@@ -47,8 +38,7 @@ class ReasoningEffort(Enum):
|
|
47
38
|
HIGH = "high"
|
48
39
|
|
49
40
|
|
50
|
-
|
51
|
-
class ModelConfig:
|
41
|
+
class ModelConfig(BaseModel):
|
52
42
|
"""Configuration for a specific model"""
|
53
43
|
|
54
44
|
provider: Provider
|
@@ -59,16 +49,6 @@ class ModelConfig:
|
|
59
49
|
class ModelFactory:
|
60
50
|
"""Factory for creating LLM instances based on model specifications"""
|
61
51
|
|
62
|
-
# Mapping of provider strings to enum values
|
63
|
-
PROVIDER_MAP = {
|
64
|
-
"anthropic": Provider.ANTHROPIC,
|
65
|
-
"openai": Provider.OPENAI,
|
66
|
-
"fast-agent": Provider.FAST_AGENT,
|
67
|
-
"deepseek": Provider.DEEPSEEK,
|
68
|
-
"generic": Provider.GENERIC,
|
69
|
-
"openrouter": Provider.OPENROUTER,
|
70
|
-
}
|
71
|
-
|
72
52
|
# Mapping of effort strings to enum values
|
73
53
|
EFFORT_MAP = {
|
74
54
|
"low": ReasoningEffort.LOW,
|
@@ -156,8 +136,8 @@ class ModelFactory:
|
|
156
136
|
# Check first part for provider
|
157
137
|
if len(model_parts) > 1:
|
158
138
|
potential_provider = model_parts[0]
|
159
|
-
if potential_provider in
|
160
|
-
provider =
|
139
|
+
if any(provider.value == potential_provider for provider in Provider):
|
140
|
+
provider = Provider(potential_provider)
|
161
141
|
model_parts = model_parts[1:]
|
162
142
|
|
163
143
|
# Join remaining parts as model name
|
@@ -0,0 +1,83 @@
|
|
1
|
+
"""
|
2
|
+
Provider API key management for various LLM providers.
|
3
|
+
Centralizes API key handling logic to make provider implementations more generic.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import os
|
7
|
+
from typing import Any, Dict
|
8
|
+
|
9
|
+
from pydantic import BaseModel
|
10
|
+
|
11
|
+
from mcp_agent.core.exceptions import ProviderKeyError
|
12
|
+
|
13
|
+
PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
|
14
|
+
"anthropic": "ANTHROPIC_API_KEY",
|
15
|
+
"openai": "OPENAI_API_KEY",
|
16
|
+
"deepseek": "DEEPSEEK_API_KEY",
|
17
|
+
"openrouter": "OPENROUTER_API_KEY",
|
18
|
+
"generic": "GENERIC_API_KEY",
|
19
|
+
}
|
20
|
+
API_KEY_HINT_TEXT = "<your-api-key-here>"
|
21
|
+
|
22
|
+
|
23
|
+
class ProviderKeyManager:
|
24
|
+
"""
|
25
|
+
Manages API keys for different providers centrally.
|
26
|
+
This class abstracts away the provider-specific key access logic,
|
27
|
+
making the provider implementations more generic.
|
28
|
+
"""
|
29
|
+
|
30
|
+
@staticmethod
|
31
|
+
def get_env_var(provider_name: str) -> str | None:
|
32
|
+
return os.getenv(ProviderKeyManager.get_env_key_name(provider_name))
|
33
|
+
|
34
|
+
@staticmethod
|
35
|
+
def get_env_key_name(provider_name: str) -> str:
|
36
|
+
return PROVIDER_ENVIRONMENT_MAP.get(provider_name, f"{provider_name.upper()}_API_KEY")
|
37
|
+
|
38
|
+
@staticmethod
|
39
|
+
def get_config_file_key(provider_name: str, config: Any) -> str | None:
|
40
|
+
api_key = None
|
41
|
+
if isinstance(config, BaseModel):
|
42
|
+
config = config.model_dump()
|
43
|
+
provider_settings = config.get(provider_name)
|
44
|
+
if provider_settings:
|
45
|
+
api_key = provider_settings.get("api_key", API_KEY_HINT_TEXT)
|
46
|
+
if api_key == API_KEY_HINT_TEXT:
|
47
|
+
api_key = None
|
48
|
+
|
49
|
+
return api_key
|
50
|
+
|
51
|
+
@staticmethod
|
52
|
+
def get_api_key(provider_name: str, config: Any) -> str:
|
53
|
+
"""
|
54
|
+
Gets the API key for the specified provider.
|
55
|
+
|
56
|
+
Args:
|
57
|
+
provider_name: Name of the provider (e.g., "anthropic", "openai")
|
58
|
+
config: The application configuration object
|
59
|
+
|
60
|
+
Returns:
|
61
|
+
The API key as a string
|
62
|
+
|
63
|
+
Raises:
|
64
|
+
ProviderKeyError: If the API key is not found or is invalid
|
65
|
+
"""
|
66
|
+
|
67
|
+
provider_name = provider_name.lower()
|
68
|
+
api_key = ProviderKeyManager.get_config_file_key(provider_name, config)
|
69
|
+
if not api_key:
|
70
|
+
api_key = ProviderKeyManager.get_env_var(provider_name)
|
71
|
+
|
72
|
+
if not api_key and provider_name == "generic":
|
73
|
+
api_key = "ollama" # Default for generic provider
|
74
|
+
|
75
|
+
if not api_key:
|
76
|
+
raise ProviderKeyError(
|
77
|
+
f"{provider_name.title()} API key not configured",
|
78
|
+
f"The {provider_name.title()} API key is required but not set.\n"
|
79
|
+
f"Add it to your configuration file under {provider_name}.api_key "
|
80
|
+
f"or set the {ProviderKeyManager.get_env_key_name(provider_name)} environment variable.",
|
81
|
+
)
|
82
|
+
|
83
|
+
return api_key
|
@@ -0,0 +1,16 @@
|
|
1
|
+
"""
|
2
|
+
Type definitions for LLM providers.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from enum import Enum
|
6
|
+
|
7
|
+
|
8
|
+
class Provider(Enum):
|
9
|
+
"""Supported LLM providers"""
|
10
|
+
|
11
|
+
ANTHROPIC = "anthropic"
|
12
|
+
OPENAI = "openai"
|
13
|
+
FAST_AGENT = "fast-agent"
|
14
|
+
DEEPSEEK = "deepseek"
|
15
|
+
GENERIC = "generic"
|
16
|
+
OPENROUTER = "openrouter"
|
@@ -1,9 +1,9 @@
|
|
1
|
-
import os
|
2
1
|
from typing import TYPE_CHECKING, List
|
3
2
|
|
4
3
|
from mcp.types import EmbeddedResource, ImageContent, TextContent
|
5
4
|
|
6
5
|
from mcp_agent.core.prompt import Prompt
|
6
|
+
from mcp_agent.llm.provider_types import Provider
|
7
7
|
from mcp_agent.llm.providers.multipart_converter_anthropic import (
|
8
8
|
AnthropicConverter,
|
9
9
|
)
|
@@ -51,12 +51,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
51
51
|
"""
|
52
52
|
|
53
53
|
def __init__(self, *args, **kwargs) -> None:
|
54
|
-
self.provider = "Anthropic"
|
55
54
|
# Initialize logger - keep it simple without name reference
|
56
55
|
self.logger = get_logger(__name__)
|
57
56
|
|
58
|
-
|
59
|
-
|
57
|
+
super().__init__(
|
58
|
+
*args, provider=Provider.ANTHROPIC, type_converter=AnthropicSamplingConverter, **kwargs
|
59
|
+
)
|
60
60
|
|
61
61
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
62
62
|
"""Initialize Anthropic-specific default parameters"""
|
@@ -83,7 +83,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
83
83
|
Override this method to use a different LLM.
|
84
84
|
"""
|
85
85
|
|
86
|
-
api_key = self._api_key(
|
86
|
+
api_key = self._api_key()
|
87
87
|
base_url = self._base_url()
|
88
88
|
if base_url and base_url.endswith("/v1"):
|
89
89
|
base_url = base_url.rstrip("/v1")
|
@@ -277,27 +277,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
277
277
|
|
278
278
|
return responses
|
279
279
|
|
280
|
-
def _api_key(self, config):
|
281
|
-
api_key = None
|
282
|
-
|
283
|
-
if hasattr(config, "anthropic") and config.anthropic:
|
284
|
-
api_key = config.anthropic.api_key
|
285
|
-
if api_key == "<your-api-key-here>":
|
286
|
-
api_key = None
|
287
|
-
|
288
|
-
if api_key is None:
|
289
|
-
api_key = os.getenv("ANTHROPIC_API_KEY")
|
290
|
-
|
291
|
-
if not api_key:
|
292
|
-
raise ProviderKeyError(
|
293
|
-
"Anthropic API key not configured",
|
294
|
-
"The Anthropic API key is required but not set.\n"
|
295
|
-
"Add it to your configuration file under anthropic.api_key "
|
296
|
-
"or set the ANTHROPIC_API_KEY environment variable.",
|
297
|
-
)
|
298
|
-
|
299
|
-
return api_key
|
300
|
-
|
301
280
|
async def generate_messages(
|
302
281
|
self,
|
303
282
|
message_param,
|
@@ -1,7 +1,6 @@
|
|
1
|
-
import os
|
2
1
|
|
3
|
-
from mcp_agent.core.exceptions import ProviderKeyError
|
4
2
|
from mcp_agent.core.request_params import RequestParams
|
3
|
+
from mcp_agent.llm.provider_types import Provider
|
5
4
|
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
|
6
5
|
|
7
6
|
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
|
@@ -11,7 +10,9 @@ DEFAULT_DEEPSEEK_MODEL = "deepseekchat" # current Deepseek only has two type mo
|
|
11
10
|
class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
|
12
11
|
def __init__(self, *args, **kwargs) -> None:
|
13
12
|
kwargs["provider_name"] = "Deepseek" # Set provider name in kwargs
|
14
|
-
super().__init__(
|
13
|
+
super().__init__(
|
14
|
+
*args, provider=Provider.DEEPSEEK, **kwargs
|
15
|
+
) # Properly pass args and kwargs to parent
|
15
16
|
|
16
17
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
17
18
|
"""Initialize Deepseek-specific default parameters"""
|
@@ -25,28 +26,8 @@ class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
|
|
25
26
|
use_history=True,
|
26
27
|
)
|
27
28
|
|
28
|
-
def _api_key(self) -> str:
|
29
|
-
config = self.context.config
|
30
|
-
api_key = None
|
31
|
-
|
32
|
-
if config and config.deepseek:
|
33
|
-
api_key = config.deepseek.api_key
|
34
|
-
if api_key == "<your-api-key-here>":
|
35
|
-
api_key = None
|
36
|
-
|
37
|
-
if api_key is None:
|
38
|
-
api_key = os.getenv("DEEPSEEK_API_KEY")
|
39
|
-
|
40
|
-
if not api_key:
|
41
|
-
raise ProviderKeyError(
|
42
|
-
"DEEPSEEK API key not configured",
|
43
|
-
"The DEEKSEEK API key is required but not set.\n"
|
44
|
-
"Add it to your configuration file under deepseek.api_key\n"
|
45
|
-
"Or set the DEEPSEEK_API_KEY environment variable",
|
46
|
-
)
|
47
|
-
return api_key
|
48
|
-
|
49
29
|
def _base_url(self) -> str:
|
30
|
+
base_url = None
|
50
31
|
if self.context.config and self.context.config.deepseek:
|
51
32
|
base_url = self.context.config.deepseek.base_url
|
52
33
|
|
@@ -1,6 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
|
3
3
|
from mcp_agent.core.request_params import RequestParams
|
4
|
+
from mcp_agent.llm.provider_types import Provider
|
4
5
|
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
|
5
6
|
|
6
7
|
DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434/v1"
|
@@ -9,8 +10,7 @@ DEFAULT_OLLAMA_API_KEY = "ollama"
|
|
9
10
|
|
10
11
|
|
11
12
|
class GenericAugmentedLLM(OpenAIAugmentedLLM):
|
12
|
-
def __init__(self, *args, **kwargs) -> None:
|
13
|
-
kwargs["provider_name"] = "GenericOpenAI"
|
13
|
+
def __init__(self, *args, provider=Provider.GENERIC, **kwargs) -> None:
|
14
14
|
super().__init__(*args, **kwargs) # Properly pass args and kwargs to parent
|
15
15
|
|
16
16
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
@@ -25,20 +25,6 @@ class GenericAugmentedLLM(OpenAIAugmentedLLM):
|
|
25
25
|
use_history=True,
|
26
26
|
)
|
27
27
|
|
28
|
-
def _api_key(self) -> str:
|
29
|
-
config = self.context.config
|
30
|
-
api_key = None
|
31
|
-
|
32
|
-
if config and config.generic:
|
33
|
-
api_key = config.generic.api_key
|
34
|
-
if api_key == "<your-api-key-here>":
|
35
|
-
api_key = None
|
36
|
-
|
37
|
-
if api_key is None:
|
38
|
-
api_key = os.getenv("GENERIC_API_KEY")
|
39
|
-
|
40
|
-
return api_key or "ollama"
|
41
|
-
|
42
28
|
def _base_url(self) -> str:
|
43
29
|
base_url = os.getenv("GENERIC_BASE_URL", DEFAULT_OLLAMA_BASE_URL)
|
44
30
|
if self.context.config and self.context.config.generic:
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import os
|
2
1
|
from typing import List, Tuple, Type
|
3
2
|
|
4
3
|
from mcp.types import (
|
@@ -29,6 +28,7 @@ from mcp_agent.llm.augmented_llm import (
|
|
29
28
|
ModelT,
|
30
29
|
RequestParams,
|
31
30
|
)
|
31
|
+
from mcp_agent.llm.provider_types import Provider
|
32
32
|
from mcp_agent.llm.providers.multipart_converter_openai import OpenAIConverter
|
33
33
|
from mcp_agent.llm.providers.sampling_converter_openai import (
|
34
34
|
OpenAISamplingConverter,
|
@@ -49,14 +49,13 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
|
|
49
49
|
This implementation uses OpenAI's ChatCompletion as the LLM.
|
50
50
|
"""
|
51
51
|
|
52
|
-
def __init__(self,
|
52
|
+
def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None:
|
53
53
|
# Set type_converter before calling super().__init__
|
54
54
|
if "type_converter" not in kwargs:
|
55
55
|
kwargs["type_converter"] = OpenAISamplingConverter
|
56
56
|
|
57
|
-
super().__init__(*args, **kwargs)
|
57
|
+
super().__init__(*args, provider=provider, **kwargs)
|
58
58
|
|
59
|
-
self.provider = provider_name
|
60
59
|
# Initialize logger with name if available
|
61
60
|
self.logger = get_logger(f"{__name__}.{self.name}" if self.name else __name__)
|
62
61
|
|
@@ -90,27 +89,6 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
|
|
90
89
|
use_history=True,
|
91
90
|
)
|
92
91
|
|
93
|
-
def _api_key(self) -> str:
|
94
|
-
config = self.context.config
|
95
|
-
api_key = None
|
96
|
-
|
97
|
-
if hasattr(config, "openai") and config.openai:
|
98
|
-
api_key = config.openai.api_key
|
99
|
-
if api_key == "<your-api-key-here>":
|
100
|
-
api_key = None
|
101
|
-
|
102
|
-
if api_key is None:
|
103
|
-
api_key = os.getenv("OPENAI_API_KEY")
|
104
|
-
|
105
|
-
if not api_key:
|
106
|
-
raise ProviderKeyError(
|
107
|
-
"OpenAI API key not configured",
|
108
|
-
"The OpenAI API key is required but not set.\n"
|
109
|
-
"Add it to your configuration file under openai.api_key\n"
|
110
|
-
"Or set the OPENAI_API_KEY environment variable",
|
111
|
-
)
|
112
|
-
return api_key
|
113
|
-
|
114
92
|
def _base_url(self) -> str:
|
115
93
|
return self.context.config.openai.base_url if self.context.config.openai else None
|
116
94
|
|
@@ -371,7 +349,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
|
|
371
349
|
The parsed response as a Pydantic model, or None if parsing fails
|
372
350
|
"""
|
373
351
|
|
374
|
-
if not
|
352
|
+
if not Provider.OPENAI == self.provider:
|
375
353
|
return await super().structured(prompt, model, request_params)
|
376
354
|
|
377
355
|
logger = get_logger(__name__)
|
@@ -1,19 +1,19 @@
|
|
1
1
|
import os
|
2
2
|
|
3
|
-
from mcp_agent.core.exceptions import ProviderKeyError
|
4
3
|
from mcp_agent.core.request_params import RequestParams
|
4
|
+
from mcp_agent.llm.provider_types import Provider
|
5
5
|
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
|
6
6
|
|
7
7
|
DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
8
8
|
# No single default model for OpenRouter, users must specify full path
|
9
|
-
DEFAULT_OPENROUTER_MODEL = None
|
9
|
+
DEFAULT_OPENROUTER_MODEL = None
|
10
10
|
|
11
11
|
|
12
12
|
class OpenRouterAugmentedLLM(OpenAIAugmentedLLM):
|
13
13
|
"""Augmented LLM provider for OpenRouter, using an OpenAI-compatible API."""
|
14
|
+
|
14
15
|
def __init__(self, *args, **kwargs) -> None:
|
15
|
-
|
16
|
-
super().__init__(*args, **kwargs)
|
16
|
+
super().__init__(*args, provider=Provider.OPENROUTER, **kwargs)
|
17
17
|
|
18
18
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
19
19
|
"""Initialize OpenRouter-specific default parameters."""
|
@@ -21,58 +21,30 @@ class OpenRouterAugmentedLLM(OpenAIAugmentedLLM):
|
|
21
21
|
# The model should be passed in the 'model' kwarg during factory creation.
|
22
22
|
chosen_model = kwargs.get("model", DEFAULT_OPENROUTER_MODEL)
|
23
23
|
if not chosen_model:
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
24
|
+
# Unlike Deepseek, OpenRouter *requires* a model path in the identifier.
|
25
|
+
# The factory should extract this before calling the constructor.
|
26
|
+
# We rely on the model being passed correctly via kwargs.
|
27
|
+
# If it's still None here, it indicates an issue upstream (factory or user input).
|
28
|
+
# However, the base class _get_model handles the error if model is None.
|
29
|
+
pass
|
31
30
|
|
32
31
|
return RequestParams(
|
33
|
-
model=chosen_model,
|
32
|
+
model=chosen_model, # Will be validated by base class
|
34
33
|
systemPrompt=self.instruction,
|
35
|
-
parallel_tool_calls=True,
|
36
|
-
max_iterations=10,
|
37
|
-
use_history=True,
|
34
|
+
parallel_tool_calls=True, # Default based on OpenAI provider
|
35
|
+
max_iterations=10, # Default based on OpenAI provider
|
36
|
+
use_history=True, # Default based on OpenAI provider
|
38
37
|
)
|
39
38
|
|
40
|
-
def _api_key(self) -> str:
|
41
|
-
"""Retrieve the OpenRouter API key from config or environment variables."""
|
42
|
-
config = self.context.config
|
43
|
-
api_key = None
|
44
|
-
|
45
|
-
# Check config file first
|
46
|
-
if config and hasattr(config, 'openrouter') and config.openrouter:
|
47
|
-
api_key = getattr(config.openrouter, 'api_key', None)
|
48
|
-
if api_key == "<your-openrouter-api-key-here>" or not api_key:
|
49
|
-
api_key = None
|
50
|
-
|
51
|
-
# Fallback to environment variable
|
52
|
-
if api_key is None:
|
53
|
-
api_key = os.getenv("OPENROUTER_API_KEY")
|
54
|
-
|
55
|
-
if not api_key:
|
56
|
-
raise ProviderKeyError(
|
57
|
-
"OpenRouter API key not configured",
|
58
|
-
"The OpenRouter API key is required but not set.\n"
|
59
|
-
"Add it to your configuration file under openrouter.api_key\n"
|
60
|
-
"Or set the OPENROUTER_API_KEY environment variable.",
|
61
|
-
)
|
62
|
-
return api_key
|
63
|
-
|
64
39
|
def _base_url(self) -> str:
|
65
40
|
"""Retrieve the OpenRouter base URL from config or use the default."""
|
66
41
|
base_url = os.getenv("OPENROUTER_BASE_URL", DEFAULT_OPENROUTER_BASE_URL) # Default
|
67
42
|
config = self.context.config
|
68
|
-
|
43
|
+
|
69
44
|
# Check config file for override
|
70
|
-
if config and hasattr(config,
|
71
|
-
config_base_url = getattr(config.openrouter,
|
45
|
+
if config and hasattr(config, "openrouter") and config.openrouter:
|
46
|
+
config_base_url = getattr(config.openrouter, "base_url", None)
|
72
47
|
if config_base_url:
|
73
48
|
base_url = config_base_url
|
74
49
|
|
75
50
|
return base_url
|
76
|
-
|
77
|
-
# Other methods like _get_model, _send_request etc., are inherited from OpenAIAugmentedLLM
|
78
|
-
# We may override them later if OpenRouter deviates significantly or offers unique features.
|
mcp_agent/mcp/interfaces.py
CHANGED
@@ -26,6 +26,7 @@ from mcp import ClientSession
|
|
26
26
|
from mcp.types import GetPromptResult, Prompt, PromptMessage, ReadResourceResult
|
27
27
|
from pydantic import BaseModel
|
28
28
|
|
29
|
+
from mcp_agent.core.agent_types import AgentType
|
29
30
|
from mcp_agent.core.request_params import RequestParams
|
30
31
|
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
31
32
|
|
@@ -137,7 +138,7 @@ class AgentProtocol(AugmentedLLMProtocol, Protocol):
|
|
137
138
|
name: str
|
138
139
|
|
139
140
|
@property
|
140
|
-
def agent_type(self) ->
|
141
|
+
def agent_type(self) -> AgentType:
|
141
142
|
"""Return the type of this agent"""
|
142
143
|
...
|
143
144
|
|