fast-agent-mcp 0.2.40__py3-none-any.whl → 0.2.42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of fast-agent-mcp has been flagged by the registry; see the listing for details.
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/METADATA +2 -1
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/RECORD +45 -40
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/entry_points.txt +2 -2
- mcp_agent/agents/base_agent.py +111 -1
- mcp_agent/cli/__main__.py +29 -3
- mcp_agent/cli/commands/check_config.py +140 -81
- mcp_agent/cli/commands/go.py +151 -38
- mcp_agent/cli/commands/quickstart.py +6 -2
- mcp_agent/cli/commands/server_helpers.py +106 -0
- mcp_agent/cli/constants.py +25 -0
- mcp_agent/cli/main.py +1 -1
- mcp_agent/config.py +111 -44
- mcp_agent/core/agent_app.py +104 -15
- mcp_agent/core/agent_types.py +5 -1
- mcp_agent/core/direct_decorators.py +38 -0
- mcp_agent/core/direct_factory.py +18 -4
- mcp_agent/core/enhanced_prompt.py +173 -13
- mcp_agent/core/fastagent.py +4 -0
- mcp_agent/core/interactive_prompt.py +37 -37
- mcp_agent/core/usage_display.py +11 -1
- mcp_agent/core/validation.py +21 -2
- mcp_agent/human_input/elicitation_form.py +53 -21
- mcp_agent/llm/augmented_llm.py +28 -9
- mcp_agent/llm/augmented_llm_silent.py +48 -0
- mcp_agent/llm/model_database.py +20 -0
- mcp_agent/llm/model_factory.py +21 -0
- mcp_agent/llm/provider_key_manager.py +22 -8
- mcp_agent/llm/provider_types.py +20 -12
- mcp_agent/llm/providers/augmented_llm_anthropic.py +7 -2
- mcp_agent/llm/providers/augmented_llm_azure.py +7 -1
- mcp_agent/llm/providers/augmented_llm_bedrock.py +1787 -0
- mcp_agent/llm/providers/augmented_llm_google_native.py +4 -1
- mcp_agent/llm/providers/augmented_llm_openai.py +12 -3
- mcp_agent/llm/providers/augmented_llm_xai.py +38 -0
- mcp_agent/llm/usage_tracking.py +28 -3
- mcp_agent/logging/logger.py +7 -0
- mcp_agent/mcp/hf_auth.py +32 -4
- mcp_agent/mcp/mcp_agent_client_session.py +2 -0
- mcp_agent/mcp/mcp_aggregator.py +38 -44
- mcp_agent/mcp/sampling.py +15 -11
- mcp_agent/resources/examples/mcp/elicitations/forms_demo.py +0 -6
- mcp_agent/resources/examples/workflows/router.py +9 -0
- mcp_agent/ui/console_display.py +125 -13
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/providers/augmented_llm_google_native.py
CHANGED

@@ -295,7 +295,7 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
                 turn_usage = TurnUsage.from_google(
                     api_response.usage_metadata, request_params.model
                 )
-                self.
+                self._finalize_turn_usage(turn_usage)
 
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")
@@ -439,6 +439,9 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
         """
         Applies the prompt messages and potentially calls the LLM for completion.
         """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         request_params = self.get_request_params(
             request_params=request_params
         )  # Get request params
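Both hunks hang off a small per-turn lifecycle: `_reset_turn_tool_calls()` clears a counter at the start of a turn, and `_finalize_turn_usage(turn_usage)` stamps the count onto the `TurnUsage` record before it is accumulated. Neither helper body appears in this diff, only the call sites, so the sketch below is hypothetical and only illustrates the shape such a tracker could take:

```python
class TurnToolCallTracker:
    """Hypothetical sketch of the per-turn lifecycle implied by this diff."""

    def __init__(self) -> None:
        self._turn_tool_calls = 0

    def _reset_turn_tool_calls(self) -> None:
        # Called at the top of each turn, before any completions run
        self._turn_tool_calls = 0

    def _record_tool_call(self) -> None:
        # Called once per executed tool call during the turn
        self._turn_tool_calls += 1

    def _finalize_turn_usage(self, turn_usage) -> None:
        # Stamp the count onto the TurnUsage record, then accumulate it
        turn_usage.set_tool_calls(self._turn_tool_calls)
```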
mcp_agent/llm/providers/augmented_llm_openai.py
CHANGED

@@ -84,7 +84,9 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         # TODO -- move this to model capabilities, add o4.
         chosen_model = self.default_request_params.model if self.default_request_params else None
         self._reasoning = chosen_model and (
-            chosen_model.startswith("o3")
+            chosen_model.startswith("o3")
+            or chosen_model.startswith("o1")
+            or chosen_model.startswith("o4")
         )
         if self._reasoning:
             self.logger.info(
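The reasoning flag is pure prefix matching on the model name. A condensed standalone version for reference (the tuple form below is an equivalent rewrite, not the code in the diff):

```python
REASONING_PREFIXES = ("o1", "o3", "o4")

def is_reasoning_model(model: str | None) -> bool:
    # str.startswith accepts a tuple, so the or-chain collapses to one call
    return bool(model) and model.startswith(REASONING_PREFIXES)

assert is_reasoning_model("o4-mini")
assert not is_reasoning_model("gpt-4o")
```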
@@ -108,6 +110,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
     def _openai_client(self) -> AsyncOpenAI:
         try:
             return AsyncOpenAI(api_key=self._api_key(), base_url=self._base_url())
+
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
@@ -355,7 +358,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         try:
             model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
             turn_usage = TurnUsage.from_openai(response.usage, model_name)
-            self.
+            self._finalize_turn_usage(turn_usage)
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")
 
@@ -389,7 +392,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
                 messages.append(message)
 
                 message_text = message.content
-                if choice.finish_reason
+                if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls:
                     if message_text:
                         await self.show_assistant_message(
                             message_text,
@@ -477,12 +480,18 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
 
         return responses
 
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        return True
+
     async def _apply_prompt_provider_specific(
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
         is_template: bool = False,
     ) -> PromptMessageMultipart:
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
mcp_agent/llm/providers/augmented_llm_xai.py
ADDED

@@ -0,0 +1,38 @@
+import os
+
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+XAI_BASE_URL = "https://api.x.ai/v1"
+DEFAULT_XAI_MODEL = "grok-3"
+
+
+class XAIAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(
+            *args, provider=Provider.XAI, **kwargs
+        )  # Properly pass args and kwargs to parent
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize xAI parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_XAI_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL)
+        if self.context.config and self.context.config.xai:
+            base_url = self.context.config.xai.base_url
+
+        return base_url
+
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        # grok uses Null as the finish reason for tool calls?
+        return await super()._is_tool_stop_reason(finish_reason) or finish_reason is None
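The `_is_tool_stop_reason` override is the interesting hook here: the OpenAI base class accepts any finish reason, and Grok apparently reports `None` for tool calls, so the subclass adds that case explicitly. The `_base_url` override establishes a simple precedence: an explicit `xai` config section wins, then the `XAI_BASE_URL` environment variable, then the hard-coded default. A minimal standalone sketch of that resolution order (the `config_url` parameter stands in for `self.context.config.xai.base_url`):

```python
import os

XAI_BASE_URL = "https://api.x.ai/v1"

def resolve_xai_base_url(config_url: str | None) -> str:
    # Env var overrides the built-in default; explicit config overrides both
    base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL)
    if config_url:
        base_url = config_url
    return base_url

os.environ.pop("XAI_BASE_URL", None)
assert resolve_xai_base_url(None) == XAI_BASE_URL
assert resolve_xai_base_url("https://proxy.example/v1") == "https://proxy.example/v1"
```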
mcp_agent/llm/usage_tracking.py
CHANGED
@@ -78,6 +78,9 @@ class TurnUsage(BaseModel):
     tool_use_tokens: int = Field(default=0, description="Tokens used for tool calling prompts")
     reasoning_tokens: int = Field(default=0, description="Tokens used for reasoning/thinking")
 
+    # Tool call count for this turn
+    tool_calls: int = Field(default=0, description="Number of tool calls made in this turn")
+
     # Raw usage data from provider (preserves all original data)
     raw_usage: ProviderUsage
 
@@ -86,7 +89,11 @@
     def current_context_tokens(self) -> int:
         """Current context size after this turn (total input including cache + output)"""
         # For Anthropic: input_tokens + cache_read_tokens represents total input context
-        total_input =
+        total_input = (
+            self.input_tokens
+            + self.cache_usage.cache_read_tokens
+            + self.cache_usage.cache_write_tokens
+        )
         return total_input + self.output_tokens
 
     @computed_field
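The fix here is arithmetic: cache write tokens were previously dropped from the running context-size estimate. A provider-agnostic restatement of the corrected formula (a sketch using the same field names, not the diff's code):

```python
def current_context_tokens(input_tokens: int, cache_read: int,
                           cache_write: int, output_tokens: int) -> int:
    # Cached reads AND writes both occupied the context window this turn,
    # so they count toward the total alongside fresh input tokens.
    total_input = input_tokens + cache_read + cache_write
    return total_input + output_tokens

assert current_context_tokens(1000, 200, 300, 50) == 1550
```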
@@ -106,11 +113,20 @@
         """Input tokens to display for 'Last turn' (total submitted tokens)"""
         # For Anthropic: input_tokens excludes cache, so add cache tokens
         if self.provider == Provider.ANTHROPIC:
-            return
+            return (
+                self.input_tokens
+                + self.cache_usage.cache_read_tokens
+                + self.cache_usage.cache_write_tokens
+            )
         else:
             # For OpenAI/Google: input_tokens already includes cached tokens
             return self.input_tokens
 
+    def set_tool_calls(self, count: int) -> None:
+        """Set the number of tool calls made in this turn"""
+        # Use object.__setattr__ since this is a Pydantic model
+        object.__setattr__(self, "tool_calls", count)
+
     @classmethod
     def from_anthropic(cls, usage: AnthropicUsage, model: str) -> "TurnUsage":
         # Extract cache tokens with proper null handling
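`set_tool_calls` writes through `object.__setattr__` rather than plain attribute assignment. On a Pydantic model that is frozen or validates assignments, `model.field = value` routes through Pydantic's `__setattr__` and can raise; `object.__setattr__` writes the instance `__dict__` directly. A minimal sketch of the trick, assuming Pydantic v2 semantics (whether `TurnUsage` itself is frozen is not visible in this diff):

```python
from pydantic import BaseModel, ConfigDict, Field

class Turn(BaseModel):
    model_config = ConfigDict(frozen=True)
    tool_calls: int = Field(default=0)

t = Turn()
# t.tool_calls = 2   # would raise: frozen models reject normal assignment
object.__setattr__(t, "tool_calls", 2)  # bypasses Pydantic's __setattr__
assert t.tool_calls == 2
```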
@@ -219,7 +235,9 @@ class UsageAccumulator(BaseModel):
     def cumulative_input_tokens(self) -> int:
         """Total input tokens charged across all turns (including cache tokens)"""
         return sum(
-            turn.input_tokens
+            turn.input_tokens
+            + turn.cache_usage.cache_read_tokens
+            + turn.cache_usage.cache_write_tokens
             for turn in self.turns
         )
 
@@ -247,6 +265,12 @@
         """Total tokens written to cache across all turns"""
         return sum(turn.cache_usage.cache_write_tokens for turn in self.turns)
 
+    @computed_field
+    @property
+    def cumulative_tool_calls(self) -> int:
+        """Total tool calls made across all turns"""
+        return sum(turn.tool_calls for turn in self.turns)
+
     @computed_field
     @property
     def cumulative_cache_hit_tokens(self) -> int:
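`cumulative_tool_calls` follows the same `@computed_field` + `@property` pattern as the other cumulative counters, which also means it shows up in serialized summaries automatically. A self-contained sketch of the pattern (simplified models, not the real `TurnUsage`/`UsageAccumulator`):

```python
from pydantic import BaseModel, computed_field

class Turn(BaseModel):
    tool_calls: int = 0

class Accumulator(BaseModel):
    turns: list[Turn] = []

    @computed_field
    @property
    def cumulative_tool_calls(self) -> int:
        # Recomputed on access; no per-turn state to keep in sync
        return sum(t.tool_calls for t in self.turns)

acc = Accumulator(turns=[Turn(tool_calls=2), Turn(tool_calls=1)])
assert acc.cumulative_tool_calls == 3
assert acc.model_dump()["cumulative_tool_calls"] == 3  # computed fields serialize
```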
@@ -333,6 +357,7 @@
             "cumulative_billing_tokens": self.cumulative_billing_tokens,
             "cumulative_tool_use_tokens": self.cumulative_tool_use_tokens,
             "cumulative_reasoning_tokens": self.cumulative_reasoning_tokens,
+            "cumulative_tool_calls": self.cumulative_tool_calls,
             "current_context_tokens": self.current_context_tokens,
             "context_window_size": self.context_window_size,
             "context_usage_percentage": self.context_usage_percentage,
mcp_agent/logging/logger.py
CHANGED
@@ -8,6 +8,7 @@ Logger module for the MCP Agent, which provides:
 """
 
 import asyncio
+import logging
 import threading
 import time
 from contextlib import asynccontextmanager, contextmanager
@@ -206,6 +207,12 @@ class LoggingConfig:
         if cls._initialized:
             return
 
+        # Suppress boto3/botocore logging to prevent flooding
+        logging.getLogger('boto3').setLevel(logging.WARNING)
+        logging.getLogger('botocore').setLevel(logging.WARNING)
+        logging.getLogger('urllib3').setLevel(logging.WARNING)
+        logging.getLogger('s3transfer').setLevel(logging.WARNING)
+
         bus = AsyncEventBus.get(transport=transport)
 
         # Add standard listeners
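This accompanies the new Bedrock provider, which pulls in boto3; its dependency chain logs verbosely at DEBUG. Raising just those named loggers to WARNING keeps the application's own DEBUG output usable without touching the root logger. The same technique in isolation:

```python
import logging

logging.basicConfig(level=logging.DEBUG)

# Silence chatty third-party loggers without changing anything else
for noisy in ("boto3", "botocore", "urllib3", "s3transfer"):
    logging.getLogger(noisy).setLevel(logging.WARNING)

logging.getLogger("botocore").debug("suppressed: below WARNING")
logging.getLogger("myapp").debug("still visible at DEBUG")
```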
mcp_agent/mcp/hf_auth.py
CHANGED
@@ -69,14 +69,29 @@ def should_add_hf_auth(url: str, existing_headers: Optional[Dict[str, str]]) ->
     """
     # Only add HF auth if:
     # 1. URL is a HuggingFace URL
-    # 2. No existing Authorization header is set
+    # 2. No existing Authorization/X-HF-Authorization header is set
     # 3. HF_TOKEN environment variable is available
 
     if not is_huggingface_url(url):
         return False
 
-    if existing_headers
-
+    if existing_headers:
+        # Check if this is a .hf.space domain
+        try:
+            parsed = urlparse(url)
+            hostname = parsed.hostname
+            if hostname and hostname.endswith(".hf.space"):
+                # For .hf.space, check for X-HF-Authorization header
+                if "X-HF-Authorization" in existing_headers:
+                    return False
+            else:
+                # For other HF domains, check for Authorization header
+                if "Authorization" in existing_headers:
+                    return False
+        except Exception:
+            # Fallback to checking Authorization header
+            if "Authorization" in existing_headers:
+                return False
 
     return get_hf_token_from_env() is not None
 
@@ -101,6 +116,19 @@ def add_hf_auth_header(url: str, headers: Optional[Dict[str, str]]) -> Optional[
 
     # Create new headers dict or copy existing one
     result_headers = dict(headers) if headers else {}
-
+
+    # Check if this is a .hf.space domain
+    try:
+        parsed = urlparse(url)
+        hostname = parsed.hostname
+        if hostname and hostname.endswith(".hf.space"):
+            # Use X-HF-Authorization for .hf.space domains
+            result_headers["X-HF-Authorization"] = f"Bearer {hf_token}"
+        else:
+            # Use standard Authorization header for other HF domains
+            result_headers["Authorization"] = f"Bearer {hf_token}"
+    except Exception:
+        # Fallback to standard Authorization header
+        result_headers["Authorization"] = f"Bearer {hf_token}"
 
     return result_headers
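The two hunks implement one rule twice, once for detection and once for injection: Spaces hosted under `.hf.space` get the token in `X-HF-Authorization`, every other HuggingFace host gets a standard `Authorization` header, and URL-parsing failures fall back to `Authorization`. The decision distilled into one function (a restatement of the rule, not the module's API):

```python
from urllib.parse import urlparse

def hf_auth_header_name(url: str) -> str:
    # .hf.space apps expect X-HF-Authorization; other HF hosts use Authorization
    try:
        hostname = urlparse(url).hostname or ""
    except Exception:
        return "Authorization"  # fall back on parse errors
    return "X-HF-Authorization" if hostname.endswith(".hf.space") else "Authorization"

assert hf_auth_header_name("https://my-app.hf.space/mcp") == "X-HF-Authorization"
assert hf_auth_header_name("https://huggingface.co/api") == "Authorization"
```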
mcp_agent/mcp/mcp_agent_client_session.py
CHANGED

@@ -78,6 +78,8 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
         self.agent_model: str | None = kwargs.pop("agent_model", None)
         # Extract agent_name if provided
         self.agent_name: str | None = kwargs.pop("agent_name", None)
+        # Extract api_key if provided
+        self.api_key: str | None = kwargs.pop("api_key", None)
         # Extract custom elicitation handler if provided
         custom_elicitation_handler = kwargs.pop("elicitation_handler", None)
 
mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -221,6 +221,7 @@ class MCPAggregator(ContextDependent):
             agent_model: str | None = None
             agent_name: str | None = None
             elicitation_handler = None
+            api_key: str | None = None
 
             # Check if this aggregator is part of an Agent (which has config)
             # Import here to avoid circular dependency
@@ -230,6 +231,7 @@ class MCPAggregator(ContextDependent):
                 agent_model = self.config.model
                 agent_name = self.config.name
                 elicitation_handler = self.config.elicitation_handler
+                api_key = self.config.api_key
 
             return MCPAgentClientSession(
                 read_stream,
@@ -238,6 +240,7 @@ class MCPAggregator(ContextDependent):
                 server_name=server_name,
                 agent_model=agent_model,
                 agent_name=agent_name,
+                api_key=api_key,
                 elicitation_handler=elicitation_handler,
                 tool_list_changed_callback=self._handle_tool_list_changed,
                 **kwargs,  # Pass through any additional kwargs like server_config
@@ -292,6 +295,8 @@ class MCPAggregator(ContextDependent):
             # Get agent's model and name if this aggregator is part of an agent
             agent_model: str | None = None
             agent_name: str | None = None
+            elicitation_handler = None
+            api_key: str | None = None
 
             # Check if this aggregator is part of an Agent (which has config)
             # Import here to avoid circular dependency
@@ -300,6 +305,8 @@ class MCPAggregator(ContextDependent):
             if isinstance(self, BaseAgent):
                 agent_model = self.config.model
                 agent_name = self.config.name
+                elicitation_handler = self.config.elicitation_handler
+                api_key = self.config.api_key
 
             return MCPAgentClientSession(
                 read_stream,
@@ -308,6 +315,8 @@ class MCPAggregator(ContextDependent):
                 server_name=server_name,
                 agent_model=agent_model,
                 agent_name=agent_name,
+                api_key=api_key,
+                elicitation_handler=elicitation_handler,
                 tool_list_changed_callback=self._handle_tool_list_changed,
                 **kwargs,  # Pass through any additional kwargs like server_config
             )
@@ -957,58 +966,43 @@
 
         async with self._refresh_lock:
             try:
+                # Create a factory function that will include our parameters
+                def create_session(read_stream, write_stream, read_timeout):
+                    # Get agent name if available
+                    agent_model: str | None = None
+                    agent_name: str | None = None
+                    elicitation_handler = None
+                    api_key: str | None = None
+
+                    # Import here to avoid circular dependency
+                    from mcp_agent.agents.base_agent import BaseAgent
+
+                    if isinstance(self, BaseAgent):
+                        agent_model = self.config.model
+                        agent_name = self.config.name
+                        elicitation_handler = self.config.elicitation_handler
+                        api_key = self.config.api_key
+
+                    return MCPAgentClientSession(
+                        read_stream,
+                        write_stream,
+                        read_timeout,
+                        server_name=server_name,
+                        agent_model=agent_model,
+                        agent_name=agent_name,
+                        api_key=api_key,
+                        elicitation_handler=elicitation_handler,
+                        tool_list_changed_callback=self._handle_tool_list_changed,
+                    )
+
                 # Fetch new tools from the server
                 if self.connection_persistence:
-                    # Create a factory function that will include our parameters
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     server_connection = await self._persistent_connection_manager.get_server(
                         server_name, client_session_factory=create_session
                     )
                     tools_result = await server_connection.session.list_tools()
                     new_tools = tools_result.tools or []
                 else:
-                    # Create a factory function for the client session
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     async with gen_client(
                         server_name,
                         server_registry=self.context.server_registry,
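The refactor hoists two identical inline `create_session` definitions above the `if self.connection_persistence:` branch, so the persistent and one-shot paths share a single factory closure; it also stops the duplicated copies drifting apart (the old ones never forwarded `agent_model` or `api_key`). The closure pattern in miniature (hypothetical names, with a dict standing in for constructing `MCPAgentClientSession`):

```python
from typing import Any, Callable

def make_session_factory(server_name: str, agent_config: Any) -> Callable[..., dict]:
    # Capture per-agent settings once, so every connection path
    # builds sessions with the same identity and credentials.
    def create_session(read_stream, write_stream, read_timeout) -> dict:
        return {
            "read_stream": read_stream,
            "write_stream": write_stream,
            "read_timeout": read_timeout,
            "server_name": server_name,
            "agent_name": getattr(agent_config, "name", None),
            "api_key": getattr(agent_config, "api_key", None),
        }
    return create_session
```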
mcp_agent/mcp/sampling.py
CHANGED
@@ -20,7 +20,7 @@ logger = get_logger(__name__)
 
 
 def create_sampling_llm(
-    params: CreateMessageRequestParams, model_string: str
+    params: CreateMessageRequestParams, model_string: str, api_key: str | None
 ) -> AugmentedLLMProtocol:
     """
     Create an LLM instance for sampling without tools support.
@@ -52,7 +52,7 @@ def create_sampling_llm(
 
     # Create the LLM using the factory
     factory = ModelFactory.create_factory(model_string)
-    llm = factory(agent=agent)
+    llm = factory(agent=agent, api_key=api_key)
 
     # Attach the LLM to the agent
     agent._llm = llm
@@ -77,7 +77,8 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
     Returns:
         A CreateMessageResult containing the LLM's response
     """
-    model = None
+    model: str | None = None
+    api_key: str | None = None
     try:
         # Extract model from server config using type-safe helper
        server_config = get_server_config(mcp_ctx)
@@ -104,13 +105,16 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
         from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
 
         # Try agent's model first (from the session)
-        if
-
-
-
-
-
-
+        if hasattr(mcp_ctx, "session") and isinstance(mcp_ctx.session, MCPAgentClientSession):
+            if mcp_ctx.session.agent_model:
+                model = mcp_ctx.session.agent_model
+                logger.debug(f"Using agent's model for sampling: {model}")
+            if mcp_ctx.session.api_key:
+                api_key = mcp_ctx.session.api_key
+                logger.debug(f"Using agent's API KEY for sampling: {api_key}")
+
+        # Fall back to system default model
+        if model is None:
             try:
                 if app_context and app_context.config and app_context.config.default_model:
                     model = app_context.config.default_model
@@ -122,7 +126,7 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
             raise ValueError("No model configured for sampling (server config, agent model, or system default)")
 
         # Create an LLM instance
-        llm = create_sampling_llm(params, model)
+        llm = create_sampling_llm(params, model, api_key)
 
         # Extract all messages from the request params
         if not params.messages:
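The sampling path now resolves two things from the agent's session before falling back: its model and, newly, its API key (note that the second `logger.debug` call writes the raw API key to the debug log). The model fallback order reduces to: session model, then system `default_model`, else error. A compact restatement (a hedged sketch, not the function in the diff):

```python
def resolve_sampling_model(session_model: str | None,
                           default_model: str | None) -> str:
    # Prefer the agent session's model; fall back to the system default
    model = session_model or default_model
    if model is None:
        raise ValueError(
            "No model configured for sampling (server config, agent model, or system default)"
        )
    return model

assert resolve_sampling_model("gpt-4o", "haiku") == "gpt-4o"
assert resolve_sampling_model(None, "haiku") == "haiku"
```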
mcp_agent/resources/examples/mcp/elicitations/forms_demo.py
CHANGED

@@ -51,8 +51,6 @@ async def main():
         else:
             console.print("[red]No registration data received[/red]")
 
-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 2: Product Review
         console.print("[bold yellow]Example 2: Product Review Form[/bold yellow]")
         console.print(
@@ -66,8 +64,6 @@ async def main():
         )
         console.print(review_panel)
 
-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 3: Account Settings
         console.print("[bold yellow]Example 3: Account Settings Form[/bold yellow]")
         console.print(
@@ -81,8 +77,6 @@ async def main():
         )
         console.print(settings_panel)
 
-        console.print("\n" + "─" * 50 + "\n")
-
         # Example 4: Service Appointment
         console.print("[bold yellow]Example 4: Service Appointment Booking[/bold yellow]")
         console.print(
mcp_agent/resources/examples/workflows/router.py
CHANGED

@@ -7,6 +7,8 @@ Demonstrates router's ability to either:
 
 import asyncio
 
+from rich.console import Console
+
 from mcp_agent.core.fastagent import FastAgent
 
 # Create the application
@@ -45,7 +47,14 @@ SAMPLE_REQUESTS = [
     agents=["code_expert", "general_assistant", "fetcher"],
 )
 async def main() -> None:
+    console = Console()
+    console.print(
+        "\n[bright_red]Router Workflow Demo[/bright_red]\n\n"
+        "Enter a request to route it to the appropriate agent.\nEnter [bright_red]STOP[/bright_red] to run the demo, [bright_red]EXIT[/bright_red] to leave"
+    )
+
     async with fast.run() as agent:
+        await agent.interactive(agent_name="route")
         for request in SAMPLE_REQUESTS:
             await agent.route(request)
 