fast-agent-mcp 0.2.40__py3-none-any.whl → 0.2.41__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fast-agent-mcp might be problematic.
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/RECORD +41 -37
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/entry_points.txt +2 -2
- mcp_agent/cli/__main__.py +29 -3
- mcp_agent/cli/commands/check_config.py +140 -81
- mcp_agent/cli/commands/go.py +151 -38
- mcp_agent/cli/commands/quickstart.py +6 -2
- mcp_agent/cli/commands/server_helpers.py +106 -0
- mcp_agent/cli/constants.py +25 -0
- mcp_agent/cli/main.py +1 -1
- mcp_agent/config.py +94 -44
- mcp_agent/core/agent_app.py +104 -15
- mcp_agent/core/agent_types.py +1 -0
- mcp_agent/core/direct_decorators.py +9 -0
- mcp_agent/core/direct_factory.py +18 -4
- mcp_agent/core/enhanced_prompt.py +165 -13
- mcp_agent/core/fastagent.py +4 -0
- mcp_agent/core/interactive_prompt.py +37 -37
- mcp_agent/core/usage_display.py +11 -1
- mcp_agent/core/validation.py +21 -2
- mcp_agent/human_input/elicitation_form.py +53 -21
- mcp_agent/llm/augmented_llm.py +28 -9
- mcp_agent/llm/augmented_llm_silent.py +48 -0
- mcp_agent/llm/model_database.py +20 -0
- mcp_agent/llm/model_factory.py +12 -0
- mcp_agent/llm/provider_key_manager.py +22 -8
- mcp_agent/llm/provider_types.py +19 -12
- mcp_agent/llm/providers/augmented_llm_anthropic.py +7 -2
- mcp_agent/llm/providers/augmented_llm_azure.py +7 -1
- mcp_agent/llm/providers/augmented_llm_google_native.py +4 -1
- mcp_agent/llm/providers/augmented_llm_openai.py +9 -2
- mcp_agent/llm/providers/augmented_llm_xai.py +38 -0
- mcp_agent/llm/usage_tracking.py +28 -3
- mcp_agent/mcp/mcp_agent_client_session.py +2 -0
- mcp_agent/mcp/mcp_aggregator.py +38 -44
- mcp_agent/mcp/sampling.py +15 -11
- mcp_agent/resources/examples/mcp/elicitations/forms_demo.py +0 -6
- mcp_agent/resources/examples/workflows/router.py +9 -0
- mcp_agent/ui/console_display.py +125 -13
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/model_factory.py
CHANGED
@@ -8,6 +8,7 @@ from mcp_agent.core.exceptions import ModelConfigError
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.augmented_llm_silent import SilentLLM
 from mcp_agent.llm.augmented_llm_slow import SlowLLM
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_aliyun import AliyunAugmentedLLM
@@ -20,6 +21,7 @@ from mcp_agent.llm.providers.augmented_llm_google_oai import GoogleOaiAugmentedL
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_tensorzero import TensorZeroAugmentedLLM
+from mcp_agent.llm.providers.augmented_llm_xai import XAIAugmentedLLM
 from mcp_agent.mcp.interfaces import AugmentedLLMProtocol

 # from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM
@@ -31,6 +33,7 @@ LLMClass = Union[
     Type[OpenAIAugmentedLLM],
     Type[PassthroughLLM],
     Type[PlaybackLLM],
+    Type[SilentLLM],
     Type[SlowLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
@@ -75,6 +78,7 @@ class ModelFactory:
     """
     DEFAULT_PROVIDERS = {
         "passthrough": Provider.FAST_AGENT,
+        "silent": Provider.FAST_AGENT,
         "playback": Provider.FAST_AGENT,
         "slow": Provider.FAST_AGENT,
         "gpt-4o": Provider.OPENAI,
@@ -106,6 +110,12 @@ class ModelFactory:
         "gemini-2.0-flash": Provider.GOOGLE,
         "gemini-2.5-flash-preview-05-20": Provider.GOOGLE,
         "gemini-2.5-pro-preview-05-06": Provider.GOOGLE,
+        "grok-4": Provider.XAI,
+        "grok-4-0709": Provider.XAI,
+        "grok-3": Provider.XAI,
+        "grok-3-mini": Provider.XAI,
+        "grok-3-fast": Provider.XAI,
+        "grok-3-mini-fast": Provider.XAI,
         "qwen-turbo": Provider.ALIYUN,
         "qwen-plus": Provider.ALIYUN,
         "qwen-max": Provider.ALIYUN,
@@ -140,6 +150,7 @@ class ModelFactory:
         Provider.GENERIC: GenericAugmentedLLM,
         Provider.GOOGLE_OAI: GoogleOaiAugmentedLLM,
         Provider.GOOGLE: GoogleNativeAugmentedLLM,
+        Provider.XAI: XAIAugmentedLLM,
         Provider.OPENROUTER: OpenRouterAugmentedLLM,
         Provider.TENSORZERO: TensorZeroAugmentedLLM,
         Provider.AZURE: AzureOpenAIAugmentedLLM,
@@ -150,6 +161,7 @@ class ModelFactory:
     # This overrides the provider-based class selection
     MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
         "playback": PlaybackLLM,
+        "silent": SilentLLM,
         "slow": SlowLLM,
     }
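Taken together, these changes route the new "silent" model through MODEL_SPECIFIC_CLASSES and the Grok aliases through DEFAULT_PROVIDERS. A minimal sketch of how a caller would exercise them; the create_factory/factory-call pattern is the one sampling.py uses further down, and the agent argument is assumed to be an existing Agent instance:

from mcp_agent.llm.model_factory import ModelFactory

# "silent" matches MODEL_SPECIFIC_CLASSES, so SilentLLM is returned directly,
# overriding the provider-based class selection.
silent_factory = ModelFactory.create_factory("silent")

# "grok-3" resolves via DEFAULT_PROVIDERS to Provider.XAI -> XAIAugmentedLLM.
grok_factory = ModelFactory.create_factory("grok-3")

# Factories are then invoked as sampling.py does:
# llm = grok_factory(agent=agent, api_key=api_key)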
mcp_agent/llm/provider_key_manager.py
CHANGED
@@ -11,12 +11,8 @@ from pydantic import BaseModel
 from mcp_agent.core.exceptions import ProviderKeyError

 PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
-
-    "
-    "deepseek": "DEEPSEEK_API_KEY",
-    "google": "GOOGLE_API_KEY",
-    "openrouter": "OPENROUTER_API_KEY",
-    "generic": "GENERIC_API_KEY",
+    # default behaviour in _get_env_key_name is to capitalize the
+    # provider name and suffix "_API_KEY" - so no specific mapping needed unless overriding
     "huggingface": "HF_TOKEN",
 }
 API_KEY_HINT_TEXT = "<your-api-key-here>"
@@ -66,7 +62,14 @@ class ProviderKeyManager:
             ProviderKeyError: If the API key is not found or is invalid
         """

+        from mcp_agent.llm.provider_types import Provider
+
         provider_name = provider_name.lower()
+
+        # Fast-agent provider doesn't need external API keys
+        if provider_name == "fast-agent":
+            return ""
+
         api_key = ProviderKeyManager.get_config_file_key(provider_name, config)
         if not api_key:
             api_key = ProviderKeyManager.get_env_var(provider_name)
@@ -75,9 +78,20 @@ class ProviderKeyManager:
                 api_key = "ollama"  # Default for generic provider

         if not api_key:
+            # Get proper display name for error message
+            try:
+                provider_enum = Provider(provider_name)
+                display_name = provider_enum.display_name
+            except ValueError:
+                # Invalid provider name
+                raise ProviderKeyError(
+                    f"Invalid provider: {provider_name}",
+                    f"'{provider_name}' is not a valid provider name.",
+                )
+
             raise ProviderKeyError(
-                f"{
-                f"The {
+                f"{display_name} API key not configured",
+                f"The {display_name} API key is required but not set.\n"
                 f"Add it to your configuration file under {provider_name}.api_key "
                 f"or set the {ProviderKeyManager.get_env_key_name(provider_name)} environment variable.",
             )
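The slimmed-down map relies on the default derivation the new comment describes: uppercase the provider name and append "_API_KEY", unless an explicit override such as huggingface -> HF_TOKEN exists. An illustrative sketch of that convention; the function name here is a stand-in, the real logic lives inside ProviderKeyManager:

# Illustrative reimplementation of the naming convention described above.
PROVIDER_ENVIRONMENT_MAP = {"huggingface": "HF_TOKEN"}

def env_key_name(provider_name: str) -> str:
    # Explicit override wins; otherwise uppercase the provider name + "_API_KEY".
    return PROVIDER_ENVIRONMENT_MAP.get(provider_name, f"{provider_name.upper()}_API_KEY")

assert env_key_name("anthropic") == "ANTHROPIC_API_KEY"
assert env_key_name("xai") == "XAI_API_KEY"
assert env_key_name("huggingface") == "HF_TOKEN"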
mcp_agent/llm/provider_types.py
CHANGED
@@ -8,15 +8,22 @@ from enum import Enum
 class Provider(Enum):
     """Supported LLM providers"""

-
-
-
-
-
-
-
-
-
-
-
-
+    def __new__(cls, config_name, display_name=None):
+        obj = object.__new__(cls)
+        obj._value_ = config_name
+        obj.display_name = display_name or config_name.title()
+        return obj
+
+    ANTHROPIC = ("anthropic", "Anthropic")
+    DEEPSEEK = ("deepseek", "Deepseek")
+    FAST_AGENT = ("fast-agent", "FastAgent")
+    GENERIC = ("generic", "Generic")
+    GOOGLE_OAI = ("googleoai", "GoogleOAI")  # For Google through OpenAI libraries
+    GOOGLE = ("google", "Google")  # For Google GenAI native library
+    OPENAI = ("openai", "OpenAI")
+    OPENROUTER = ("openrouter", "OpenRouter")
+    TENSORZERO = ("tensorzero", "TensorZero")  # For TensorZero Gateway
+    AZURE = ("azure", "Azure")  # Azure OpenAI Service
+    ALIYUN = ("aliyun", "Aliyun")  # Aliyun Bailian OpenAI Service
+    HUGGINGFACE = ("huggingface", "HuggingFace")  # For HuggingFace MCP connections
+    XAI = ("xai", "XAI")  # For xAI Grok models
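The __new__ override lets each member carry a config-facing value and a human-facing display name while remaining a string-valued enum for lookups. The behaviour follows directly from the code above:

from mcp_agent.llm.provider_types import Provider

# The enum's value is the config name, so lookup by string still works:
assert Provider("xai") is Provider.XAI
assert Provider.XAI.value == "xai"

# display_name is the second tuple element (or config_name.title() if omitted):
assert Provider.XAI.display_name == "XAI"
assert Provider.FAST_AGENT.display_name == "FastAgent"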
mcp_agent/llm/providers/augmented_llm_anthropic.py
CHANGED
@@ -112,7 +112,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 and event.delta.type == "text_delta"
             ):
                 # Use base class method for token estimation and progress emission
-                estimated_tokens = self._update_streaming_progress(
+                estimated_tokens = self._update_streaming_progress(
+                    event.delta.text, model, estimated_tokens
+                )

             # Also check for final message_delta events with actual usage info
             elif (
@@ -285,7 +287,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             turn_usage = TurnUsage.from_anthropic(
                 response.usage, model or DEFAULT_ANTHROPIC_MODEL
             )
-            self.
+            self._finalize_turn_usage(turn_usage)
             # self._show_usage(response.usage, turn_usage)
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")
@@ -435,6 +437,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         Override this method to use a different LLM.

         """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         res = await self._anthropic_completion(
             message_param=message_param,
             request_params=request_params,
mcp_agent/llm/providers/augmented_llm_azure.py
CHANGED
@@ -69,7 +69,7 @@ class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM):

             self.get_azure_token = get_azure_token
         else:
-            self.api_key =
+            self.api_key = self._api_key()
         self.resource_name = getattr(azure_cfg, "resource_name", None)
         self.base_url = getattr(azure_cfg, "base_url", None) or (
             f"https://{self.resource_name}.openai.azure.com/" if self.resource_name else None
@@ -93,6 +93,12 @@ class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM):
         if not self.resource_name and self.base_url:
             self.resource_name = _extract_resource_name(self.base_url)

+    def _api_key(self):
+        """Override to return 'AzureCredential' when using DefaultAzureCredential"""
+        if self.use_default_cred:
+            return "AzureCredential"
+        return super()._api_key()
+
     def _openai_client(self) -> AsyncOpenAI:
         """
         Returns an AzureOpenAI client, handling both API Key and DefaultAzureCredential.
mcp_agent/llm/providers/augmented_llm_google_native.py
CHANGED
@@ -295,7 +295,7 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
                 turn_usage = TurnUsage.from_google(
                     api_response.usage_metadata, request_params.model
                 )
-                self.
+                self._finalize_turn_usage(turn_usage)

             except Exception as e:
                 self.logger.warning(f"Failed to track usage: {e}")
@@ -439,6 +439,9 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
         """
         Applies the prompt messages and potentially calls the LLM for completion.
         """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         request_params = self.get_request_params(
             request_params=request_params
         )  # Get request params
mcp_agent/llm/providers/augmented_llm_openai.py
CHANGED
@@ -108,6 +108,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
     def _openai_client(self) -> AsyncOpenAI:
         try:
             return AsyncOpenAI(api_key=self._api_key(), base_url=self._base_url())
+
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
@@ -355,7 +356,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         try:
             model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
             turn_usage = TurnUsage.from_openai(response.usage, model_name)
-            self.
+            self._finalize_turn_usage(turn_usage)
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")

@@ -389,7 +390,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
             messages.append(message)

             message_text = message.content
-            if choice.finish_reason
+            if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls:
                 if message_text:
                     await self.show_assistant_message(
                         message_text,
@@ -477,12 +478,18 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion

         return responses

+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        return True
+
     async def _apply_prompt_provider_specific(
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
         is_template: bool = False,
     ) -> PromptMessageMultipart:
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         last_message = multipart_messages[-1]

         # Add all previous messages to history (or all messages if last is from assistant)
mcp_agent/llm/providers/augmented_llm_xai.py
ADDED
@@ -0,0 +1,38 @@
+import os
+
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+XAI_BASE_URL = "https://api.x.ai/v1"
+DEFAULT_XAI_MODEL = "grok-3"
+
+
+class XAIAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(
+            *args, provider=Provider.XAI, **kwargs
+        )  # Properly pass args and kwargs to parent
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize xAI parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_XAI_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL)
+        if self.context.config and self.context.config.xai:
+            base_url = self.context.config.xai.base_url
+
+        return base_url
+
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        # grok uses Null as the finish reason for tool calls?
+        return await super()._is_tool_stop_reason(finish_reason) or finish_reason is None
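Because XAIAugmentedLLM subclasses the OpenAI implementation and only swaps the base URL, default parameters, and stop-reason check, using it should amount to naming a Grok model. A hedged sketch using fast-agent's usual decorator pattern; the exact instruction text is illustrative, and the key is expected via XAI_API_KEY per the provider-key convention above:

import asyncio
from mcp_agent.core.fastagent import FastAgent

fast = FastAgent("grok-demo")

# "grok-3" maps to Provider.XAI in ModelFactory.DEFAULT_PROVIDERS,
# so this agent is served by XAIAugmentedLLM.
@fast.agent(instruction="You are a helpful assistant", model="grok-3")
async def main():
    async with fast.run() as agent:
        await agent.send("Hello, Grok!")

if __name__ == "__main__":
    asyncio.run(main())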
mcp_agent/llm/usage_tracking.py
CHANGED
@@ -78,6 +78,9 @@ class TurnUsage(BaseModel):
     tool_use_tokens: int = Field(default=0, description="Tokens used for tool calling prompts")
     reasoning_tokens: int = Field(default=0, description="Tokens used for reasoning/thinking")

+    # Tool call count for this turn
+    tool_calls: int = Field(default=0, description="Number of tool calls made in this turn")
+
     # Raw usage data from provider (preserves all original data)
     raw_usage: ProviderUsage

@@ -86,7 +89,11 @@ class TurnUsage(BaseModel):
     def current_context_tokens(self) -> int:
         """Current context size after this turn (total input including cache + output)"""
         # For Anthropic: input_tokens + cache_read_tokens represents total input context
-        total_input =
+        total_input = (
+            self.input_tokens
+            + self.cache_usage.cache_read_tokens
+            + self.cache_usage.cache_write_tokens
+        )
         return total_input + self.output_tokens

     @computed_field
@@ -106,11 +113,20 @@ class TurnUsage(BaseModel):
         """Input tokens to display for 'Last turn' (total submitted tokens)"""
         # For Anthropic: input_tokens excludes cache, so add cache tokens
         if self.provider == Provider.ANTHROPIC:
-            return
+            return (
+                self.input_tokens
+                + self.cache_usage.cache_read_tokens
+                + self.cache_usage.cache_write_tokens
+            )
         else:
             # For OpenAI/Google: input_tokens already includes cached tokens
             return self.input_tokens

+    def set_tool_calls(self, count: int) -> None:
+        """Set the number of tool calls made in this turn"""
+        # Use object.__setattr__ since this is a Pydantic model
+        object.__setattr__(self, "tool_calls", count)
+
     @classmethod
     def from_anthropic(cls, usage: AnthropicUsage, model: str) -> "TurnUsage":
         # Extract cache tokens with proper null handling
@@ -219,7 +235,9 @@ class UsageAccumulator(BaseModel):
     def cumulative_input_tokens(self) -> int:
         """Total input tokens charged across all turns (including cache tokens)"""
         return sum(
-            turn.input_tokens
+            turn.input_tokens
+            + turn.cache_usage.cache_read_tokens
+            + turn.cache_usage.cache_write_tokens
             for turn in self.turns
         )

@@ -247,6 +265,12 @@ class UsageAccumulator(BaseModel):
         """Total tokens written to cache across all turns"""
         return sum(turn.cache_usage.cache_write_tokens for turn in self.turns)

+    @computed_field
+    @property
+    def cumulative_tool_calls(self) -> int:
+        """Total tool calls made across all turns"""
+        return sum(turn.tool_calls for turn in self.turns)
+
     @computed_field
     @property
     def cumulative_cache_hit_tokens(self) -> int:
@@ -333,6 +357,7 @@ class UsageAccumulator(BaseModel):
             "cumulative_billing_tokens": self.cumulative_billing_tokens,
             "cumulative_tool_use_tokens": self.cumulative_tool_use_tokens,
             "cumulative_reasoning_tokens": self.cumulative_reasoning_tokens,
+            "cumulative_tool_calls": self.cumulative_tool_calls,
             "current_context_tokens": self.current_context_tokens,
             "context_window_size": self.context_window_size,
             "context_usage_percentage": self.context_usage_percentage,
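The new counter rides alongside the token fields, so totals fall out of the same summation pattern. A small illustration of the intended flow; turn and accumulator are assumed, pre-existing instances, since building a real TurnUsage requires provider usage data:

# A provider stamps the count once a turn completes:
turn.set_tool_calls(3)  # writes via object.__setattr__ to satisfy Pydantic

# cumulative_tool_calls is exactly this sum over recorded turns:
total_tool_calls = sum(t.tool_calls for t in accumulator.turns)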
mcp_agent/mcp/mcp_agent_client_session.py
CHANGED
@@ -78,6 +78,8 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
         self.agent_model: str | None = kwargs.pop("agent_model", None)
         # Extract agent_name if provided
         self.agent_name: str | None = kwargs.pop("agent_name", None)
+        # Extract api_key if provided
+        self.api_key: str | None = kwargs.pop("api_key", None)
         # Extract custom elicitation handler if provided
         custom_elicitation_handler = kwargs.pop("elicitation_handler", None)

mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -221,6 +221,7 @@ class MCPAggregator(ContextDependent):
             agent_model: str | None = None
             agent_name: str | None = None
             elicitation_handler = None
+            api_key: str | None = None

             # Check if this aggregator is part of an Agent (which has config)
             # Import here to avoid circular dependency
@@ -230,6 +231,7 @@ class MCPAggregator(ContextDependent):
                 agent_model = self.config.model
                 agent_name = self.config.name
                 elicitation_handler = self.config.elicitation_handler
+                api_key = self.config.api_key

             return MCPAgentClientSession(
                 read_stream,
@@ -238,6 +240,7 @@ class MCPAggregator(ContextDependent):
                 server_name=server_name,
                 agent_model=agent_model,
                 agent_name=agent_name,
+                api_key=api_key,
                 elicitation_handler=elicitation_handler,
                 tool_list_changed_callback=self._handle_tool_list_changed,
                 **kwargs,  # Pass through any additional kwargs like server_config
@@ -292,6 +295,8 @@ class MCPAggregator(ContextDependent):
             # Get agent's model and name if this aggregator is part of an agent
             agent_model: str | None = None
             agent_name: str | None = None
+            elicitation_handler = None
+            api_key: str | None = None

             # Check if this aggregator is part of an Agent (which has config)
             # Import here to avoid circular dependency
@@ -300,6 +305,8 @@ class MCPAggregator(ContextDependent):
             if isinstance(self, BaseAgent):
                 agent_model = self.config.model
                 agent_name = self.config.name
+                elicitation_handler = self.config.elicitation_handler
+                api_key = self.config.api_key

             return MCPAgentClientSession(
                 read_stream,
@@ -308,6 +315,8 @@ class MCPAggregator(ContextDependent):
                 server_name=server_name,
                 agent_model=agent_model,
                 agent_name=agent_name,
+                api_key=api_key,
+                elicitation_handler=elicitation_handler,
                 tool_list_changed_callback=self._handle_tool_list_changed,
                 **kwargs,  # Pass through any additional kwargs like server_config
             )
@@ -957,58 +966,43 @@ class MCPAggregator(ContextDependent):

         async with self._refresh_lock:
             try:
+                # Create a factory function that will include our parameters
+                def create_session(read_stream, write_stream, read_timeout):
+                    # Get agent name if available
+                    agent_model: str | None = None
+                    agent_name: str | None = None
+                    elicitation_handler = None
+                    api_key: str | None = None
+
+                    # Import here to avoid circular dependency
+                    from mcp_agent.agents.base_agent import BaseAgent
+
+                    if isinstance(self, BaseAgent):
+                        agent_model = self.config.model
+                        agent_name = self.config.name
+                        elicitation_handler = self.config.elicitation_handler
+                        api_key = self.config.api_key
+
+                    return MCPAgentClientSession(
+                        read_stream,
+                        write_stream,
+                        read_timeout,
+                        server_name=server_name,
+                        agent_model=agent_model,
+                        agent_name=agent_name,
+                        api_key=api_key,
+                        elicitation_handler=elicitation_handler,
+                        tool_list_changed_callback=self._handle_tool_list_changed,
+                    )
+
                 # Fetch new tools from the server
                 if self.connection_persistence:
-                    # Create a factory function that will include our parameters
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     server_connection = await self._persistent_connection_manager.get_server(
                         server_name, client_session_factory=create_session
                     )
                     tools_result = await server_connection.session.list_tools()
                     new_tools = tools_result.tools or []
                 else:
-                    # Create a factory function for the client session
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     async with gen_client(
                         server_name,
                         server_registry=self.context.server_registry,
mcp_agent/mcp/sampling.py
CHANGED
@@ -20,7 +20,7 @@ logger = get_logger(__name__)


 def create_sampling_llm(
-    params: CreateMessageRequestParams, model_string: str
+    params: CreateMessageRequestParams, model_string: str, api_key: str | None
 ) -> AugmentedLLMProtocol:
     """
     Create an LLM instance for sampling without tools support.
@@ -52,7 +52,7 @@ def create_sampling_llm(

     # Create the LLM using the factory
     factory = ModelFactory.create_factory(model_string)
-    llm = factory(agent=agent)
+    llm = factory(agent=agent, api_key=api_key)

     # Attach the LLM to the agent
     agent._llm = llm
@@ -77,7 +77,8 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
     Returns:
         A CreateMessageResult containing the LLM's response
     """
-    model = None
+    model: str | None = None
+    api_key: str | None = None
     try:
         # Extract model from server config using type-safe helper
         server_config = get_server_config(mcp_ctx)
@@ -104,13 +105,16 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
             from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession

             # Try agent's model first (from the session)
-            if
-
-
-
-
-
-
+            if hasattr(mcp_ctx, "session") and isinstance(mcp_ctx.session, MCPAgentClientSession):
+                if mcp_ctx.session.agent_model:
+                    model = mcp_ctx.session.agent_model
+                    logger.debug(f"Using agent's model for sampling: {model}")
+                if mcp_ctx.session.api_key:
+                    api_key = mcp_ctx.session.api_key
+                    logger.debug(f"Using agent's API KEY for sampling: {api_key}")
+
+            # Fall back to system default model
+            if model is None:
                 try:
                     if app_context and app_context.config and app_context.config.default_model:
                         model = app_context.config.default_model
@@ -122,7 +126,7 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
         raise ValueError("No model configured for sampling (server config, agent model, or system default)")

         # Create an LLM instance
-        llm = create_sampling_llm(params, model)
+        llm = create_sampling_llm(params, model, api_key)

         # Extract all messages from the request params
         if not params.messages:
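End to end, the agent's configured key now travels with sampling requests: the aggregator supplies self.config.api_key when building the session, MCPAgentClientSession pops it from its kwargs, and sample() reads it back and forwards it to the factory. A condensed, illustrative trace; session and params here are stand-ins for live objects:

# session was built by MCPAggregator with api_key=self.config.api_key
api_key = session.api_key                           # read back inside sample()
llm = create_sampling_llm(params, model, api_key)   # new third parameter
# ...which in turn calls factory(agent=agent, api_key=api_key)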
mcp_agent/resources/examples/mcp/elicitations/forms_demo.py
CHANGED
@@ -51,8 +51,6 @@ async def main():
         else:
             console.print("[red]No registration data received[/red]")

-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 2: Product Review
         console.print("[bold yellow]Example 2: Product Review Form[/bold yellow]")
         console.print(
@@ -66,8 +64,6 @@ async def main():
         )
         console.print(review_panel)

-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 3: Account Settings
         console.print("[bold yellow]Example 3: Account Settings Form[/bold yellow]")
         console.print(
@@ -81,8 +77,6 @@ async def main():
         )
         console.print(settings_panel)

-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 4: Service Appointment
         console.print("[bold yellow]Example 4: Service Appointment Booking[/bold yellow]")
         console.print(
mcp_agent/resources/examples/workflows/router.py
CHANGED
@@ -7,6 +7,8 @@ Demonstrates router's ability to either:

 import asyncio

+from rich.console import Console
+
 from mcp_agent.core.fastagent import FastAgent

 # Create the application
@@ -45,7 +47,14 @@ SAMPLE_REQUESTS = [
     agents=["code_expert", "general_assistant", "fetcher"],
 )
 async def main() -> None:
+    console = Console()
+    console.print(
+        "\n[bright_red]Router Workflow Demo[/bright_red]\n\n"
+        "Enter a request to route it to the appropriate agent.\nEnter [bright_red]STOP[/bright_red] to run the demo, [bright_red]EXIT[/bright_red] to leave"
+    )
+
     async with fast.run() as agent:
+        await agent.interactive(agent_name="route")
         for request in SAMPLE_REQUESTS:
             await agent.route(request)
