fast-agent-mcp 0.2.40__py3-none-any.whl → 0.2.41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (41)
  1. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/METADATA +1 -1
  2. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/RECORD +41 -37
  3. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/entry_points.txt +2 -2
  4. mcp_agent/cli/__main__.py +29 -3
  5. mcp_agent/cli/commands/check_config.py +140 -81
  6. mcp_agent/cli/commands/go.py +151 -38
  7. mcp_agent/cli/commands/quickstart.py +6 -2
  8. mcp_agent/cli/commands/server_helpers.py +106 -0
  9. mcp_agent/cli/constants.py +25 -0
  10. mcp_agent/cli/main.py +1 -1
  11. mcp_agent/config.py +94 -44
  12. mcp_agent/core/agent_app.py +104 -15
  13. mcp_agent/core/agent_types.py +1 -0
  14. mcp_agent/core/direct_decorators.py +9 -0
  15. mcp_agent/core/direct_factory.py +18 -4
  16. mcp_agent/core/enhanced_prompt.py +165 -13
  17. mcp_agent/core/fastagent.py +4 -0
  18. mcp_agent/core/interactive_prompt.py +37 -37
  19. mcp_agent/core/usage_display.py +11 -1
  20. mcp_agent/core/validation.py +21 -2
  21. mcp_agent/human_input/elicitation_form.py +53 -21
  22. mcp_agent/llm/augmented_llm.py +28 -9
  23. mcp_agent/llm/augmented_llm_silent.py +48 -0
  24. mcp_agent/llm/model_database.py +20 -0
  25. mcp_agent/llm/model_factory.py +12 -0
  26. mcp_agent/llm/provider_key_manager.py +22 -8
  27. mcp_agent/llm/provider_types.py +19 -12
  28. mcp_agent/llm/providers/augmented_llm_anthropic.py +7 -2
  29. mcp_agent/llm/providers/augmented_llm_azure.py +7 -1
  30. mcp_agent/llm/providers/augmented_llm_google_native.py +4 -1
  31. mcp_agent/llm/providers/augmented_llm_openai.py +9 -2
  32. mcp_agent/llm/providers/augmented_llm_xai.py +38 -0
  33. mcp_agent/llm/usage_tracking.py +28 -3
  34. mcp_agent/mcp/mcp_agent_client_session.py +2 -0
  35. mcp_agent/mcp/mcp_aggregator.py +38 -44
  36. mcp_agent/mcp/sampling.py +15 -11
  37. mcp_agent/resources/examples/mcp/elicitations/forms_demo.py +0 -6
  38. mcp_agent/resources/examples/workflows/router.py +9 -0
  39. mcp_agent/ui/console_display.py +125 -13
  40. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/WHEEL +0 -0
  41. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.41.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/model_factory.py CHANGED
@@ -8,6 +8,7 @@ from mcp_agent.core.exceptions import ModelConfigError
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.augmented_llm_playback import PlaybackLLM
+from mcp_agent.llm.augmented_llm_silent import SilentLLM
 from mcp_agent.llm.augmented_llm_slow import SlowLLM
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_aliyun import AliyunAugmentedLLM
@@ -20,6 +21,7 @@ from mcp_agent.llm.providers.augmented_llm_google_oai import GoogleOaiAugmentedL
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_tensorzero import TensorZeroAugmentedLLM
+from mcp_agent.llm.providers.augmented_llm_xai import XAIAugmentedLLM
 from mcp_agent.mcp.interfaces import AugmentedLLMProtocol

 # from mcp_agent.workflows.llm.augmented_llm_deepseek import DeekSeekAugmentedLLM
@@ -31,6 +33,7 @@ LLMClass = Union[
     Type[OpenAIAugmentedLLM],
     Type[PassthroughLLM],
     Type[PlaybackLLM],
+    Type[SilentLLM],
     Type[SlowLLM],
     Type[DeepSeekAugmentedLLM],
     Type[OpenRouterAugmentedLLM],
@@ -75,6 +78,7 @@ class ModelFactory:
     """
     DEFAULT_PROVIDERS = {
         "passthrough": Provider.FAST_AGENT,
+        "silent": Provider.FAST_AGENT,
         "playback": Provider.FAST_AGENT,
         "slow": Provider.FAST_AGENT,
         "gpt-4o": Provider.OPENAI,
@@ -106,6 +110,12 @@ class ModelFactory:
         "gemini-2.0-flash": Provider.GOOGLE,
         "gemini-2.5-flash-preview-05-20": Provider.GOOGLE,
         "gemini-2.5-pro-preview-05-06": Provider.GOOGLE,
+        "grok-4": Provider.XAI,
+        "grok-4-0709": Provider.XAI,
+        "grok-3": Provider.XAI,
+        "grok-3-mini": Provider.XAI,
+        "grok-3-fast": Provider.XAI,
+        "grok-3-mini-fast": Provider.XAI,
         "qwen-turbo": Provider.ALIYUN,
         "qwen-plus": Provider.ALIYUN,
         "qwen-max": Provider.ALIYUN,
@@ -140,6 +150,7 @@ class ModelFactory:
         Provider.GENERIC: GenericAugmentedLLM,
         Provider.GOOGLE_OAI: GoogleOaiAugmentedLLM,
         Provider.GOOGLE: GoogleNativeAugmentedLLM,
+        Provider.XAI: XAIAugmentedLLM,
         Provider.OPENROUTER: OpenRouterAugmentedLLM,
         Provider.TENSORZERO: TensorZeroAugmentedLLM,
         Provider.AZURE: AzureOpenAIAugmentedLLM,
@@ -150,6 +161,7 @@ class ModelFactory:
     # This overrides the provider-based class selection
     MODEL_SPECIFIC_CLASSES: Dict[str, LLMClass] = {
         "playback": PlaybackLLM,
+        "silent": SilentLLM,
         "slow": SlowLLM,
     }
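Note: the new entries land in two different lookup tables. DEFAULT_PROVIDERS maps a model name to its provider (whose class is then taken from the provider table), while MODEL_SPECIFIC_CLASSES short-circuits that lookup for built-in models. A minimal sketch of the resolution this implies, using the create_factory call that appears in the sampling.py diff below (the comments are inferences from these hunks, not tested output):

    from mcp_agent.llm.model_factory import ModelFactory

    # "grok-4" -> DEFAULT_PROVIDERS -> Provider.XAI -> XAIAugmentedLLM
    grok_factory = ModelFactory.create_factory("grok-4")

    # "silent" matches MODEL_SPECIFIC_CLASSES first, so SilentLLM wins even
    # though DEFAULT_PROVIDERS also maps it to Provider.FAST_AGENT.
    silent_factory = ModelFactory.create_factory("silent")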
mcp_agent/llm/provider_key_manager.py CHANGED
@@ -11,12 +11,8 @@ from pydantic import BaseModel
 from mcp_agent.core.exceptions import ProviderKeyError

 PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
-    "anthropic": "ANTHROPIC_API_KEY",
-    "openai": "OPENAI_API_KEY",
-    "deepseek": "DEEPSEEK_API_KEY",
-    "google": "GOOGLE_API_KEY",
-    "openrouter": "OPENROUTER_API_KEY",
-    "generic": "GENERIC_API_KEY",
+    # default behaviour in _get_env_key_name is to capitalize the
+    # provider name and suffix "_API_KEY" - so no specific mapping needed unless overriding
     "huggingface": "HF_TOKEN",
 }
 API_KEY_HINT_TEXT = "<your-api-key-here>"
@@ -66,7 +62,14 @@
             ProviderKeyError: If the API key is not found or is invalid
         """

+        from mcp_agent.llm.provider_types import Provider
+
         provider_name = provider_name.lower()
+
+        # Fast-agent provider doesn't need external API keys
+        if provider_name == "fast-agent":
+            return ""
+
         api_key = ProviderKeyManager.get_config_file_key(provider_name, config)
         if not api_key:
             api_key = ProviderKeyManager.get_env_var(provider_name)
@@ -75,9 +78,20 @@
             api_key = "ollama"  # Default for generic provider

         if not api_key:
+            # Get proper display name for error message
+            try:
+                provider_enum = Provider(provider_name)
+                display_name = provider_enum.display_name
+            except ValueError:
+                # Invalid provider name
+                raise ProviderKeyError(
+                    f"Invalid provider: {provider_name}",
+                    f"'{provider_name}' is not a valid provider name.",
+                )
+
             raise ProviderKeyError(
-                f"{provider_name.title()} API key not configured",
-                f"The {provider_name.title()} API key is required but not set.\n"
+                f"{display_name} API key not configured",
+                f"The {display_name} API key is required but not set.\n"
                 f"Add it to your configuration file under {provider_name}.api_key "
                 f"or set the {ProviderKeyManager.get_env_key_name(provider_name)} environment variable.",
            )
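Note: with the explicit entries removed, environment variable names are now derived by convention (uppercase the provider name, append _API_KEY), and PROVIDER_ENVIRONMENT_MAP only holds exceptions. A short sketch of what the convention yields; the expected values are inferred from the comment in the hunk above:

    from mcp_agent.llm.provider_key_manager import ProviderKeyManager

    ProviderKeyManager.get_env_key_name("xai")          # "XAI_API_KEY" (derived)
    ProviderKeyManager.get_env_key_name("anthropic")    # "ANTHROPIC_API_KEY" (derived)
    ProviderKeyManager.get_env_key_name("huggingface")  # "HF_TOKEN" (explicit override)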
mcp_agent/llm/provider_types.py CHANGED
@@ -8,15 +8,22 @@ from enum import Enum
 class Provider(Enum):
     """Supported LLM providers"""

-    ANTHROPIC = "anthropic"
-    DEEPSEEK = "deepseek"
-    FAST_AGENT = "fast-agent"
-    GENERIC = "generic"
-    GOOGLE_OAI = "googleoai"  # For Google through OpenAI libraries
-    GOOGLE = "google"  # For Google GenAI native library
-    OPENAI = "openai"
-    OPENROUTER = "openrouter"
-    TENSORZERO = "tensorzero"  # For TensorZero Gateway
-    AZURE = "azure"  # Azure OpenAI Service
-    ALIYUN = "aliyun"  # Aliyun Bailian OpenAI Service
-    HUGGINGFACE = "huggingface"  # For HuggingFace MCP connections
+    def __new__(cls, config_name, display_name=None):
+        obj = object.__new__(cls)
+        obj._value_ = config_name
+        obj.display_name = display_name or config_name.title()
+        return obj
+
+    ANTHROPIC = ("anthropic", "Anthropic")
+    DEEPSEEK = ("deepseek", "Deepseek")
+    FAST_AGENT = ("fast-agent", "FastAgent")
+    GENERIC = ("generic", "Generic")
+    GOOGLE_OAI = ("googleoai", "GoogleOAI")  # For Google through OpenAI libraries
+    GOOGLE = ("google", "Google")  # For Google GenAI native library
+    OPENAI = ("openai", "OpenAI")
+    OPENROUTER = ("openrouter", "OpenRouter")
+    TENSORZERO = ("tensorzero", "TensorZero")  # For TensorZero Gateway
+    AZURE = ("azure", "Azure")  # Azure OpenAI Service
+    ALIYUN = ("aliyun", "Aliyun")  # Aliyun Bailian OpenAI Service
+    HUGGINGFACE = ("huggingface", "HuggingFace")  # For HuggingFace MCP connections
+    XAI = ("xai", "XAI")  # For xAI Grok models
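Note: because __new__ stores the first tuple element as the member's _value_, value-based lookups keep working while each member gains the display_name used by the reworked ProviderKeyManager error messages. A quick illustration:

    from mcp_agent.llm.provider_types import Provider

    assert Provider("xai") is Provider.XAI                   # config-name lookup unchanged
    assert Provider.XAI.display_name == "XAI"
    assert Provider.OPENROUTER.display_name == "OpenRouter"  # not "Openrouter" from .title()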
mcp_agent/llm/providers/augmented_llm_anthropic.py CHANGED
@@ -112,7 +112,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 and event.delta.type == "text_delta"
             ):
                 # Use base class method for token estimation and progress emission
-                estimated_tokens = self._update_streaming_progress(event.delta.text, model, estimated_tokens)
+                estimated_tokens = self._update_streaming_progress(
+                    event.delta.text, model, estimated_tokens
+                )

             # Also check for final message_delta events with actual usage info
             elif (
@@ -285,7 +287,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     turn_usage = TurnUsage.from_anthropic(
                         response.usage, model or DEFAULT_ANTHROPIC_MODEL
                     )
-                    self.usage_accumulator.add_turn(turn_usage)
+                    self._finalize_turn_usage(turn_usage)
                     # self._show_usage(response.usage, turn_usage)
                 except Exception as e:
                     self.logger.warning(f"Failed to track usage: {e}")
@@ -435,6 +437,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         Override this method to use a different LLM.

        """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
        res = await self._anthropic_completion(
            message_param=message_param,
            request_params=request_params,
mcp_agent/llm/providers/augmented_llm_azure.py CHANGED
@@ -69,7 +69,7 @@ class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM):

             self.get_azure_token = get_azure_token
         else:
-            self.api_key = getattr(azure_cfg, "api_key", None)
+            self.api_key = self._api_key()
             self.resource_name = getattr(azure_cfg, "resource_name", None)
             self.base_url = getattr(azure_cfg, "base_url", None) or (
                 f"https://{self.resource_name}.openai.azure.com/" if self.resource_name else None
@@ -93,6 +93,12 @@ class AzureOpenAIAugmentedLLM(OpenAIAugmentedLLM):
         if not self.resource_name and self.base_url:
             self.resource_name = _extract_resource_name(self.base_url)

+    def _api_key(self):
+        """Override to return 'AzureCredential' when using DefaultAzureCredential"""
+        if self.use_default_cred:
+            return "AzureCredential"
+        return super()._api_key()
+
     def _openai_client(self) -> AsyncOpenAI:
         """
         Returns an AzureOpenAI client, handling both API Key and DefaultAzureCredential.
mcp_agent/llm/providers/augmented_llm_google_native.py CHANGED
@@ -295,7 +295,7 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
                     turn_usage = TurnUsage.from_google(
                         api_response.usage_metadata, request_params.model
                     )
-                    self.usage_accumulator.add_turn(turn_usage)
+                    self._finalize_turn_usage(turn_usage)

             except Exception as e:
                 self.logger.warning(f"Failed to track usage: {e}")
@@ -439,6 +439,9 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
        """
        Applies the prompt messages and potentially calls the LLM for completion.
        """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
        request_params = self.get_request_params(
            request_params=request_params
        )  # Get request params
mcp_agent/llm/providers/augmented_llm_openai.py CHANGED
@@ -108,6 +108,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
     def _openai_client(self) -> AsyncOpenAI:
         try:
             return AsyncOpenAI(api_key=self._api_key(), base_url=self._base_url())
+
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
@@ -355,7 +356,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         try:
             model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
             turn_usage = TurnUsage.from_openai(response.usage, model_name)
-            self.usage_accumulator.add_turn(turn_usage)
+            self._finalize_turn_usage(turn_usage)
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")

@@ -389,7 +390,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
                 messages.append(message)

                 message_text = message.content
-                if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls:
+                if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls:
                     if message_text:
                         await self.show_assistant_message(
                             message_text,
@@ -477,12 +478,18 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion

         return responses

+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        return True
+
     async def _apply_prompt_provider_specific(
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
         is_template: bool = False,
     ) -> PromptMessageMultipart:
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         last_message = multipart_messages[-1]

         # Add all previous messages to history (or all messages if last is from assistant)
mcp_agent/llm/providers/augmented_llm_xai.py ADDED
@@ -0,0 +1,38 @@
+import os
+
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+XAI_BASE_URL = "https://api.x.ai/v1"
+DEFAULT_XAI_MODEL = "grok-3"
+
+
+class XAIAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(
+            *args, provider=Provider.XAI, **kwargs
+        )  # Properly pass args and kwargs to parent
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize xAI parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_XAI_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL)
+        if self.context.config and self.context.config.xai:
+            base_url = self.context.config.xai.base_url
+
+        return base_url
+
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        # grok uses Null as the finish reason for tool calls?
+        return await super()._is_tool_stop_reason(finish_reason) or finish_reason is None
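Note: since XAIAugmentedLLM only overrides the base URL, default parameters, and the tool-stop heuristic, using it looks like any other provider. A hedged usage sketch, assuming the decorator API shown in the router.py example below and an XAI_API_KEY environment variable (per the key-derivation convention above):

    import asyncio

    from mcp_agent.core.fastagent import FastAgent

    fast = FastAgent("grok-demo")


    @fast.agent(instruction="You are a helpful assistant.", model="grok-3-mini")
    async def main() -> None:
        async with fast.run() as agent:
            # Requests flow through the inherited OpenAI machinery
            # against https://api.x.ai/v1.
            await agent.send("In one sentence, what does MCP sampling do?")


    if __name__ == "__main__":
        asyncio.run(main())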
mcp_agent/llm/usage_tracking.py CHANGED
@@ -78,6 +78,9 @@ class TurnUsage(BaseModel):
     tool_use_tokens: int = Field(default=0, description="Tokens used for tool calling prompts")
     reasoning_tokens: int = Field(default=0, description="Tokens used for reasoning/thinking")

+    # Tool call count for this turn
+    tool_calls: int = Field(default=0, description="Number of tool calls made in this turn")
+
     # Raw usage data from provider (preserves all original data)
     raw_usage: ProviderUsage

@@ -86,7 +89,11 @@ class TurnUsage(BaseModel):
     def current_context_tokens(self) -> int:
         """Current context size after this turn (total input including cache + output)"""
         # For Anthropic: input_tokens + cache_read_tokens represents total input context
-        total_input = self.input_tokens + self.cache_usage.cache_read_tokens + self.cache_usage.cache_write_tokens
+        total_input = (
+            self.input_tokens
+            + self.cache_usage.cache_read_tokens
+            + self.cache_usage.cache_write_tokens
+        )
         return total_input + self.output_tokens

     @computed_field
@@ -106,11 +113,20 @@ class TurnUsage(BaseModel):
         """Input tokens to display for 'Last turn' (total submitted tokens)"""
         # For Anthropic: input_tokens excludes cache, so add cache tokens
         if self.provider == Provider.ANTHROPIC:
-            return self.input_tokens + self.cache_usage.cache_read_tokens + self.cache_usage.cache_write_tokens
+            return (
+                self.input_tokens
+                + self.cache_usage.cache_read_tokens
+                + self.cache_usage.cache_write_tokens
+            )
         else:
             # For OpenAI/Google: input_tokens already includes cached tokens
             return self.input_tokens

+    def set_tool_calls(self, count: int) -> None:
+        """Set the number of tool calls made in this turn"""
+        # Use object.__setattr__ since this is a Pydantic model
+        object.__setattr__(self, "tool_calls", count)
+
     @classmethod
     def from_anthropic(cls, usage: AnthropicUsage, model: str) -> "TurnUsage":
         # Extract cache tokens with proper null handling
@@ -219,7 +235,9 @@ class UsageAccumulator(BaseModel):
     def cumulative_input_tokens(self) -> int:
         """Total input tokens charged across all turns (including cache tokens)"""
         return sum(
-            turn.input_tokens + turn.cache_usage.cache_read_tokens + turn.cache_usage.cache_write_tokens
+            turn.input_tokens
+            + turn.cache_usage.cache_read_tokens
+            + turn.cache_usage.cache_write_tokens
             for turn in self.turns
         )

@@ -247,6 +265,12 @@ class UsageAccumulator(BaseModel):
         """Total tokens written to cache across all turns"""
         return sum(turn.cache_usage.cache_write_tokens for turn in self.turns)

+    @computed_field
+    @property
+    def cumulative_tool_calls(self) -> int:
+        """Total tool calls made across all turns"""
+        return sum(turn.tool_calls for turn in self.turns)
+
     @computed_field
     @property
     def cumulative_cache_hit_tokens(self) -> int:
@@ -333,6 +357,7 @@ class UsageAccumulator(BaseModel):
             "cumulative_billing_tokens": self.cumulative_billing_tokens,
             "cumulative_tool_use_tokens": self.cumulative_tool_use_tokens,
             "cumulative_reasoning_tokens": self.cumulative_reasoning_tokens,
+            "cumulative_tool_calls": self.cumulative_tool_calls,
             "current_context_tokens": self.current_context_tokens,
             "context_window_size": self.context_window_size,
             "context_usage_percentage": self.context_usage_percentage,
mcp_agent/mcp/mcp_agent_client_session.py CHANGED
@@ -78,6 +78,8 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
         self.agent_model: str | None = kwargs.pop("agent_model", None)
         # Extract agent_name if provided
         self.agent_name: str | None = kwargs.pop("agent_name", None)
+        # Extract api_key if provided
+        self.api_key: str | None = kwargs.pop("api_key", None)
         # Extract custom elicitation handler if provided
         custom_elicitation_handler = kwargs.pop("elicitation_handler", None)

mcp_agent/mcp/mcp_aggregator.py CHANGED
@@ -221,6 +221,7 @@ class MCPAggregator(ContextDependent):
             agent_model: str | None = None
             agent_name: str | None = None
             elicitation_handler = None
+            api_key: str | None = None

             # Check if this aggregator is part of an Agent (which has config)
             # Import here to avoid circular dependency
@@ -230,6 +231,7 @@ class MCPAggregator(ContextDependent):
                 agent_model = self.config.model
                 agent_name = self.config.name
                 elicitation_handler = self.config.elicitation_handler
+                api_key = self.config.api_key

             return MCPAgentClientSession(
                 read_stream,
@@ -238,6 +240,7 @@ class MCPAggregator(ContextDependent):
                 server_name=server_name,
                 agent_model=agent_model,
                 agent_name=agent_name,
+                api_key=api_key,
                 elicitation_handler=elicitation_handler,
                 tool_list_changed_callback=self._handle_tool_list_changed,
                 **kwargs,  # Pass through any additional kwargs like server_config
@@ -292,6 +295,8 @@ class MCPAggregator(ContextDependent):
         # Get agent's model and name if this aggregator is part of an agent
         agent_model: str | None = None
         agent_name: str | None = None
+        elicitation_handler = None
+        api_key: str | None = None

         # Check if this aggregator is part of an Agent (which has config)
         # Import here to avoid circular dependency
@@ -300,6 +305,8 @@ class MCPAggregator(ContextDependent):
         if isinstance(self, BaseAgent):
             agent_model = self.config.model
             agent_name = self.config.name
+            elicitation_handler = self.config.elicitation_handler
+            api_key = self.config.api_key

         return MCPAgentClientSession(
             read_stream,
@@ -308,6 +315,8 @@ class MCPAggregator(ContextDependent):
             server_name=server_name,
             agent_model=agent_model,
             agent_name=agent_name,
+            api_key=api_key,
+            elicitation_handler=elicitation_handler,
             tool_list_changed_callback=self._handle_tool_list_changed,
             **kwargs,  # Pass through any additional kwargs like server_config
         )
@@ -957,58 +966,43 @@ class MCPAggregator(ContextDependent):

         async with self._refresh_lock:
             try:
+                # Create a factory function that will include our parameters
+                def create_session(read_stream, write_stream, read_timeout):
+                    # Get agent name if available
+                    agent_model: str | None = None
+                    agent_name: str | None = None
+                    elicitation_handler = None
+                    api_key: str | None = None
+
+                    # Import here to avoid circular dependency
+                    from mcp_agent.agents.base_agent import BaseAgent
+
+                    if isinstance(self, BaseAgent):
+                        agent_model = self.config.model
+                        agent_name = self.config.name
+                        elicitation_handler = self.config.elicitation_handler
+                        api_key = self.config.api_key
+
+                    return MCPAgentClientSession(
+                        read_stream,
+                        write_stream,
+                        read_timeout,
+                        server_name=server_name,
+                        agent_model=agent_model,
+                        agent_name=agent_name,
+                        api_key=api_key,
+                        elicitation_handler=elicitation_handler,
+                        tool_list_changed_callback=self._handle_tool_list_changed,
+                    )
+
                 # Fetch new tools from the server
                 if self.connection_persistence:
-                    # Create a factory function that will include our parameters
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     server_connection = await self._persistent_connection_manager.get_server(
                         server_name, client_session_factory=create_session
                     )
                     tools_result = await server_connection.session.list_tools()
                     new_tools = tools_result.tools or []
                 else:
-                    # Create a factory function for the client session
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     async with gen_client(
                         server_name,
                         server_registry=self.context.server_registry,
mcp_agent/mcp/sampling.py CHANGED
@@ -20,7 +20,7 @@ logger = get_logger(__name__)


 def create_sampling_llm(
-    params: CreateMessageRequestParams, model_string: str
+    params: CreateMessageRequestParams, model_string: str, api_key: str | None
 ) -> AugmentedLLMProtocol:
     """
     Create an LLM instance for sampling without tools support.
@@ -52,7 +52,7 @@ def create_sampling_llm(

     # Create the LLM using the factory
     factory = ModelFactory.create_factory(model_string)
-    llm = factory(agent=agent)
+    llm = factory(agent=agent, api_key=api_key)

     # Attach the LLM to the agent
     agent._llm = llm
@@ -77,7 +77,8 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
     Returns:
         A CreateMessageResult containing the LLM's response
     """
-    model = None
+    model: str | None = None
+    api_key: str | None = None
     try:
         # Extract model from server config using type-safe helper
         server_config = get_server_config(mcp_ctx)
@@ -104,13 +105,16 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
             from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession

             # Try agent's model first (from the session)
-            if (hasattr(mcp_ctx, 'session') and
-                isinstance(mcp_ctx.session, MCPAgentClientSession) and
-                mcp_ctx.session.agent_model):
-                model = mcp_ctx.session.agent_model
-                logger.debug(f"Using agent's model for sampling: {model}")
-            else:
-                # Fall back to system default model
+            if hasattr(mcp_ctx, "session") and isinstance(mcp_ctx.session, MCPAgentClientSession):
+                if mcp_ctx.session.agent_model:
+                    model = mcp_ctx.session.agent_model
+                    logger.debug(f"Using agent's model for sampling: {model}")
+                if mcp_ctx.session.api_key:
+                    api_key = mcp_ctx.session.api_key
+                    logger.debug(f"Using agent's API KEY for sampling: {api_key}")
+
+            # Fall back to system default model
+            if model is None:
                 try:
                     if app_context and app_context.config and app_context.config.default_model:
                         model = app_context.config.default_model
@@ -122,7 +126,7 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
         raise ValueError("No model configured for sampling (server config, agent model, or system default)")

     # Create an LLM instance
-    llm = create_sampling_llm(params, model)
+    llm = create_sampling_llm(params, model, api_key)

     # Extract all messages from the request params
     if not params.messages:
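Note: callers now pass the key explicitly; when the session carries no key, sample() threads None through, and key resolution presumably falls back to the provider's usual config/env lookup. The updated call shape, with params being the incoming CreateMessageRequestParams:

    from mcp_agent.mcp.sampling import create_sampling_llm

    # api_key=None leaves key resolution to the provider's normal path.
    llm = create_sampling_llm(params, "grok-3", api_key=None)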
mcp_agent/resources/examples/mcp/elicitations/forms_demo.py CHANGED
@@ -51,8 +51,6 @@ async def main():
     else:
         console.print("[red]No registration data received[/red]")

-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 2: Product Review
     console.print("[bold yellow]Example 2: Product Review Form[/bold yellow]")
     console.print(
@@ -66,8 +64,6 @@ async def main():
     )
     console.print(review_panel)

-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 3: Account Settings
     console.print("[bold yellow]Example 3: Account Settings Form[/bold yellow]")
     console.print(
@@ -81,8 +77,6 @@ async def main():
     )
     console.print(settings_panel)

-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 4: Service Appointment
     console.print("[bold yellow]Example 4: Service Appointment Booking[/bold yellow]")
     console.print(
mcp_agent/resources/examples/workflows/router.py CHANGED
@@ -7,6 +7,8 @@ Demonstrates router's ability to either:

 import asyncio

+from rich.console import Console
+
 from mcp_agent.core.fastagent import FastAgent

 # Create the application
@@ -45,7 +47,14 @@ SAMPLE_REQUESTS = [
     agents=["code_expert", "general_assistant", "fetcher"],
 )
 async def main() -> None:
+    console = Console()
+    console.print(
+        "\n[bright_red]Router Workflow Demo[/bright_red]\n\n"
+        "Enter a request to route it to the appropriate agent.\nEnter [bright_red]STOP[/bright_red] to run the demo, [bright_red]EXIT[/bright_red] to leave"
+    )
+
     async with fast.run() as agent:
+        await agent.interactive(agent_name="route")
         for request in SAMPLE_REQUESTS:
             await agent.route(request)