fast-agent-mcp 0.2.40__py3-none-any.whl → 0.2.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of fast-agent-mcp might be problematic.
Files changed (45)
  1. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/METADATA +2 -1
  2. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/RECORD +45 -40
  3. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/entry_points.txt +2 -2
  4. mcp_agent/agents/base_agent.py +111 -1
  5. mcp_agent/cli/__main__.py +29 -3
  6. mcp_agent/cli/commands/check_config.py +140 -81
  7. mcp_agent/cli/commands/go.py +151 -38
  8. mcp_agent/cli/commands/quickstart.py +6 -2
  9. mcp_agent/cli/commands/server_helpers.py +106 -0
  10. mcp_agent/cli/constants.py +25 -0
  11. mcp_agent/cli/main.py +1 -1
  12. mcp_agent/config.py +111 -44
  13. mcp_agent/core/agent_app.py +104 -15
  14. mcp_agent/core/agent_types.py +5 -1
  15. mcp_agent/core/direct_decorators.py +38 -0
  16. mcp_agent/core/direct_factory.py +18 -4
  17. mcp_agent/core/enhanced_prompt.py +173 -13
  18. mcp_agent/core/fastagent.py +4 -0
  19. mcp_agent/core/interactive_prompt.py +37 -37
  20. mcp_agent/core/usage_display.py +11 -1
  21. mcp_agent/core/validation.py +21 -2
  22. mcp_agent/human_input/elicitation_form.py +53 -21
  23. mcp_agent/llm/augmented_llm.py +28 -9
  24. mcp_agent/llm/augmented_llm_silent.py +48 -0
  25. mcp_agent/llm/model_database.py +20 -0
  26. mcp_agent/llm/model_factory.py +21 -0
  27. mcp_agent/llm/provider_key_manager.py +22 -8
  28. mcp_agent/llm/provider_types.py +20 -12
  29. mcp_agent/llm/providers/augmented_llm_anthropic.py +7 -2
  30. mcp_agent/llm/providers/augmented_llm_azure.py +7 -1
  31. mcp_agent/llm/providers/augmented_llm_bedrock.py +1787 -0
  32. mcp_agent/llm/providers/augmented_llm_google_native.py +4 -1
  33. mcp_agent/llm/providers/augmented_llm_openai.py +12 -3
  34. mcp_agent/llm/providers/augmented_llm_xai.py +38 -0
  35. mcp_agent/llm/usage_tracking.py +28 -3
  36. mcp_agent/logging/logger.py +7 -0
  37. mcp_agent/mcp/hf_auth.py +32 -4
  38. mcp_agent/mcp/mcp_agent_client_session.py +2 -0
  39. mcp_agent/mcp/mcp_aggregator.py +38 -44
  40. mcp_agent/mcp/sampling.py +15 -11
  41. mcp_agent/resources/examples/mcp/elicitations/forms_demo.py +0 -6
  42. mcp_agent/resources/examples/workflows/router.py +9 -0
  43. mcp_agent/ui/console_display.py +125 -13
  44. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/WHEEL +0 -0
  45. {fast_agent_mcp-0.2.40.dist-info → fast_agent_mcp-0.2.42.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/providers/augmented_llm_google_native.py CHANGED
@@ -295,7 +295,7 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
                 turn_usage = TurnUsage.from_google(
                     api_response.usage_metadata, request_params.model
                 )
-                self.usage_accumulator.add_turn(turn_usage)
+                self._finalize_turn_usage(turn_usage)
 
             except Exception as e:
                 self.logger.warning(f"Failed to track usage: {e}")
@@ -439,6 +439,9 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
         """
         Applies the prompt messages and potentially calls the LLM for completion.
         """
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         request_params = self.get_request_params(
             request_params=request_params
         )  # Get request params
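
Note: both this file and the OpenAI provider below now call `_finalize_turn_usage` instead of `usage_accumulator.add_turn` directly. The helper lives in mcp_agent/llm/augmented_llm.py (changed in this release but not shown on this page); a plausible sketch of its shape, inferred from the `set_tool_calls` and `_reset_turn_tool_calls` calls visible elsewhere in the diff:

    # Hypothetical reconstruction -- the real body is in augmented_llm.py,
    # which this page does not display; the counter attribute name is assumed.
    def _finalize_turn_usage(self, turn_usage: TurnUsage) -> None:
        # Attach the number of tool calls counted during this turn,
        # then record the turn exactly as the old add_turn call did.
        turn_usage.set_tool_calls(self._turn_tool_calls)
        self.usage_accumulator.add_turn(turn_usage)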
mcp_agent/llm/providers/augmented_llm_openai.py CHANGED
@@ -84,7 +84,9 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         # TODO -- move this to model capabilities, add o4.
         chosen_model = self.default_request_params.model if self.default_request_params else None
         self._reasoning = chosen_model and (
-            chosen_model.startswith("o3") or chosen_model.startswith("o1")
+            chosen_model.startswith("o3")
+            or chosen_model.startswith("o1")
+            or chosen_model.startswith("o4")
         )
         if self._reasoning:
             self.logger.info(
@@ -108,6 +110,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
     def _openai_client(self) -> AsyncOpenAI:
         try:
             return AsyncOpenAI(api_key=self._api_key(), base_url=self._base_url())
+
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
@@ -355,7 +358,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
         try:
             model_name = self.default_request_params.model or DEFAULT_OPENAI_MODEL
             turn_usage = TurnUsage.from_openai(response.usage, model_name)
-            self.usage_accumulator.add_turn(turn_usage)
+            self._finalize_turn_usage(turn_usage)
         except Exception as e:
             self.logger.warning(f"Failed to track usage: {e}")
 
@@ -389,7 +392,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
                 messages.append(message)
 
                 message_text = message.content
-                if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls:
+                if await self._is_tool_stop_reason(choice.finish_reason) and message.tool_calls:
                     if message_text:
                         await self.show_assistant_message(
                             message_text,
@@ -477,12 +480,18 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
 
         return responses
 
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        return True
+
     async def _apply_prompt_provider_specific(
         self,
        multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
         is_template: bool = False,
     ) -> PromptMessageMultipart:
+        # Reset tool call counter for new turn
+        self._reset_turn_tool_calls()
+
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
mcp_agent/llm/providers/augmented_llm_xai.py ADDED
@@ -0,0 +1,38 @@
+import os
+
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+XAI_BASE_URL = "https://api.x.ai/v1"
+DEFAULT_XAI_MODEL = "grok-3"
+
+
+class XAIAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(
+            *args, provider=Provider.XAI, **kwargs
+        )  # Properly pass args and kwargs to parent
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize xAI parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_XAI_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = os.getenv("XAI_BASE_URL", XAI_BASE_URL)
+        if self.context.config and self.context.config.xai:
+            base_url = self.context.config.xai.base_url
+
+        return base_url
+
+    async def _is_tool_stop_reason(self, finish_reason: str) -> bool:
+        # grok uses Null as the finish reason for tool calls?
+        return await super()._is_tool_stop_reason(finish_reason) or finish_reason is None
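
The override at the bottom is why the `_is_tool_stop_reason` hook exists: the base `OpenAIAugmentedLLM` version accepts any finish reason, while grok can apparently report a null finish reason on tool calls. A standalone sketch of how the two layers combine (`llm`, `choice`, and `message` are stand-ins for the real objects, not part of the package):

    # Illustrative only: mirrors the gate in _openai_completion above.
    async def should_run_tools(llm, choice, message) -> bool:
        # The hook widens which finish reasons may signal a tool stop
        # (base class: all of them; xAI: additionally a literal None);
        # the presence of message.tool_calls remains the deciding check.
        return await llm._is_tool_stop_reason(choice.finish_reason) and bool(message.tool_calls)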
mcp_agent/llm/usage_tracking.py CHANGED
@@ -78,6 +78,9 @@ class TurnUsage(BaseModel):
     tool_use_tokens: int = Field(default=0, description="Tokens used for tool calling prompts")
     reasoning_tokens: int = Field(default=0, description="Tokens used for reasoning/thinking")
 
+    # Tool call count for this turn
+    tool_calls: int = Field(default=0, description="Number of tool calls made in this turn")
+
     # Raw usage data from provider (preserves all original data)
     raw_usage: ProviderUsage
 
@@ -86,7 +89,11 @@ class TurnUsage(BaseModel):
     def current_context_tokens(self) -> int:
         """Current context size after this turn (total input including cache + output)"""
         # For Anthropic: input_tokens + cache_read_tokens represents total input context
-        total_input = self.input_tokens + self.cache_usage.cache_read_tokens + self.cache_usage.cache_write_tokens
+        total_input = (
+            self.input_tokens
+            + self.cache_usage.cache_read_tokens
+            + self.cache_usage.cache_write_tokens
+        )
         return total_input + self.output_tokens
 
     @computed_field
@@ -106,11 +113,20 @@ class TurnUsage(BaseModel):
         """Input tokens to display for 'Last turn' (total submitted tokens)"""
         # For Anthropic: input_tokens excludes cache, so add cache tokens
         if self.provider == Provider.ANTHROPIC:
-            return self.input_tokens + self.cache_usage.cache_read_tokens + self.cache_usage.cache_write_tokens
+            return (
+                self.input_tokens
+                + self.cache_usage.cache_read_tokens
+                + self.cache_usage.cache_write_tokens
+            )
         else:
             # For OpenAI/Google: input_tokens already includes cached tokens
             return self.input_tokens
 
+    def set_tool_calls(self, count: int) -> None:
+        """Set the number of tool calls made in this turn"""
+        # Use object.__setattr__ since this is a Pydantic model
+        object.__setattr__(self, "tool_calls", count)
+
     @classmethod
     def from_anthropic(cls, usage: AnthropicUsage, model: str) -> "TurnUsage":
         # Extract cache tokens with proper null handling
@@ -219,7 +235,9 @@ class UsageAccumulator(BaseModel):
     def cumulative_input_tokens(self) -> int:
         """Total input tokens charged across all turns (including cache tokens)"""
         return sum(
-            turn.input_tokens + turn.cache_usage.cache_read_tokens + turn.cache_usage.cache_write_tokens
+            turn.input_tokens
+            + turn.cache_usage.cache_read_tokens
+            + turn.cache_usage.cache_write_tokens
             for turn in self.turns
         )
 
@@ -247,6 +265,12 @@ class UsageAccumulator(BaseModel):
         """Total tokens written to cache across all turns"""
         return sum(turn.cache_usage.cache_write_tokens for turn in self.turns)
 
+    @computed_field
+    @property
+    def cumulative_tool_calls(self) -> int:
+        """Total tool calls made across all turns"""
+        return sum(turn.tool_calls for turn in self.turns)
+
     @computed_field
     @property
     def cumulative_cache_hit_tokens(self) -> int:
@@ -333,6 +357,7 @@ class UsageAccumulator(BaseModel):
             "cumulative_billing_tokens": self.cumulative_billing_tokens,
             "cumulative_tool_use_tokens": self.cumulative_tool_use_tokens,
             "cumulative_reasoning_tokens": self.cumulative_reasoning_tokens,
+            "cumulative_tool_calls": self.cumulative_tool_calls,
             "current_context_tokens": self.current_context_tokens,
             "context_window_size": self.context_window_size,
             "context_usage_percentage": self.context_usage_percentage,
mcp_agent/logging/logger.py CHANGED
@@ -8,6 +8,7 @@ Logger module for the MCP Agent, which provides:
 """
 
 import asyncio
+import logging
 import threading
 import time
 from contextlib import asynccontextmanager, contextmanager
@@ -206,6 +207,12 @@ class LoggingConfig:
         if cls._initialized:
             return
 
+        # Suppress boto3/botocore logging to prevent flooding
+        logging.getLogger('boto3').setLevel(logging.WARNING)
+        logging.getLogger('botocore').setLevel(logging.WARNING)
+        logging.getLogger('urllib3').setLevel(logging.WARNING)
+        logging.getLogger('s3transfer').setLevel(logging.WARNING)
+
         bus = AsyncEventBus.get(transport=transport)
 
         # Add standard listeners
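
Setting the level on a package's root logger also raises the effective level of every child logger that has no level of its own, so one `setLevel` call per library is enough to mute boto3's DEBUG/INFO chatter (relevant now that the Bedrock provider pulls in boto3). A self-contained illustration of the pattern:

    import logging

    logging.basicConfig(level=logging.DEBUG)

    # Same pattern as the diff: quiet an entire package via its root logger.
    logging.getLogger("botocore").setLevel(logging.WARNING)

    logging.getLogger("botocore.credentials").debug("suppressed")    # inherits WARNING
    logging.getLogger("botocore.credentials").warning("still shown")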
mcp_agent/mcp/hf_auth.py CHANGED
@@ -69,14 +69,29 @@ def should_add_hf_auth(url: str, existing_headers: Optional[Dict[str, str]]) ->
     """
     # Only add HF auth if:
     # 1. URL is a HuggingFace URL
-    # 2. No existing Authorization header is set
+    # 2. No existing Authorization/X-HF-Authorization header is set
     # 3. HF_TOKEN environment variable is available
 
     if not is_huggingface_url(url):
         return False
 
-    if existing_headers and "Authorization" in existing_headers:
-        return False
+    if existing_headers:
+        # Check if this is a .hf.space domain
+        try:
+            parsed = urlparse(url)
+            hostname = parsed.hostname
+            if hostname and hostname.endswith(".hf.space"):
+                # For .hf.space, check for X-HF-Authorization header
+                if "X-HF-Authorization" in existing_headers:
+                    return False
+            else:
+                # For other HF domains, check for Authorization header
+                if "Authorization" in existing_headers:
+                    return False
+        except Exception:
+            # Fallback to checking Authorization header
+            if "Authorization" in existing_headers:
+                return False
 
     return get_hf_token_from_env() is not None
 
@@ -101,6 +116,19 @@ def add_hf_auth_header(url: str, headers: Optional[Dict[str, str]]) -> Optional[
 
     # Create new headers dict or copy existing one
     result_headers = dict(headers) if headers else {}
-    result_headers["Authorization"] = f"Bearer {hf_token}"
+
+    # Check if this is a .hf.space domain
+    try:
+        parsed = urlparse(url)
+        hostname = parsed.hostname
+        if hostname and hostname.endswith(".hf.space"):
+            # Use X-HF-Authorization for .hf.space domains
+            result_headers["X-HF-Authorization"] = f"Bearer {hf_token}"
+        else:
+            # Use standard Authorization header for other HF domains
+            result_headers["Authorization"] = f"Bearer {hf_token}"
+    except Exception:
+        # Fallback to standard Authorization header
+        result_headers["Authorization"] = f"Bearer {hf_token}"
 
     return result_headers
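
Net effect: Spaces endpoints (hosts ending in .hf.space) are now authenticated with `X-HF-Authorization`, leaving the standard `Authorization` header free for the Space's own auth, while other HuggingFace hosts keep the standard header. A hedged sketch, assuming `HF_TOKEN` is set and the unshown early-return guards in `add_hf_auth_header` pass:

    import os
    from mcp_agent.mcp.hf_auth import add_hf_auth_header

    os.environ["HF_TOKEN"] = "hf_xxx"  # placeholder token

    # *.hf.space host -> expect {"X-HF-Authorization": "Bearer hf_xxx"}
    print(add_hf_auth_header("https://user-demo.hf.space/mcp", None))

    # other HF hosts -> expect {"Authorization": "Bearer hf_xxx"}
    print(add_hf_auth_header("https://huggingface.co/mcp", None))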
mcp_agent/mcp/mcp_agent_client_session.py CHANGED
@@ -78,6 +78,8 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
         self.agent_model: str | None = kwargs.pop("agent_model", None)
         # Extract agent_name if provided
         self.agent_name: str | None = kwargs.pop("agent_name", None)
+        # Extract api_key if provided
+        self.api_key: str | None = kwargs.pop("api_key", None)
         # Extract custom elicitation handler if provided
         custom_elicitation_handler = kwargs.pop("elicitation_handler", None)
 
mcp_agent/mcp/mcp_aggregator.py CHANGED
@@ -221,6 +221,7 @@ class MCPAggregator(ContextDependent):
         agent_model: str | None = None
         agent_name: str | None = None
         elicitation_handler = None
+        api_key: str | None = None
 
         # Check if this aggregator is part of an Agent (which has config)
         # Import here to avoid circular dependency
@@ -230,6 +231,7 @@ class MCPAggregator(ContextDependent):
             agent_model = self.config.model
             agent_name = self.config.name
             elicitation_handler = self.config.elicitation_handler
+            api_key = self.config.api_key
 
         return MCPAgentClientSession(
             read_stream,
@@ -238,6 +240,7 @@ class MCPAggregator(ContextDependent):
             server_name=server_name,
             agent_model=agent_model,
             agent_name=agent_name,
+            api_key=api_key,
             elicitation_handler=elicitation_handler,
             tool_list_changed_callback=self._handle_tool_list_changed,
             **kwargs,  # Pass through any additional kwargs like server_config
@@ -292,6 +295,8 @@ class MCPAggregator(ContextDependent):
         # Get agent's model and name if this aggregator is part of an agent
         agent_model: str | None = None
         agent_name: str | None = None
+        elicitation_handler = None
+        api_key: str | None = None
 
         # Check if this aggregator is part of an Agent (which has config)
         # Import here to avoid circular dependency
@@ -300,6 +305,8 @@ class MCPAggregator(ContextDependent):
         if isinstance(self, BaseAgent):
             agent_model = self.config.model
             agent_name = self.config.name
+            elicitation_handler = self.config.elicitation_handler
+            api_key = self.config.api_key
 
         return MCPAgentClientSession(
             read_stream,
@@ -308,6 +315,8 @@ class MCPAggregator(ContextDependent):
             server_name=server_name,
             agent_model=agent_model,
             agent_name=agent_name,
+            api_key=api_key,
+            elicitation_handler=elicitation_handler,
             tool_list_changed_callback=self._handle_tool_list_changed,
             **kwargs,  # Pass through any additional kwargs like server_config
         )
@@ -957,58 +966,43 @@ class MCPAggregator(ContextDependent):
 
         async with self._refresh_lock:
             try:
+                # Create a factory function that will include our parameters
+                def create_session(read_stream, write_stream, read_timeout):
+                    # Get agent name if available
+                    agent_model: str | None = None
+                    agent_name: str | None = None
+                    elicitation_handler = None
+                    api_key: str | None = None
+
+                    # Import here to avoid circular dependency
+                    from mcp_agent.agents.base_agent import BaseAgent
+
+                    if isinstance(self, BaseAgent):
+                        agent_model = self.config.model
+                        agent_name = self.config.name
+                        elicitation_handler = self.config.elicitation_handler
+                        api_key = self.config.api_key
+
+                    return MCPAgentClientSession(
+                        read_stream,
+                        write_stream,
+                        read_timeout,
+                        server_name=server_name,
+                        agent_model=agent_model,
+                        agent_name=agent_name,
+                        api_key=api_key,
+                        elicitation_handler=elicitation_handler,
+                        tool_list_changed_callback=self._handle_tool_list_changed,
+                    )
+
                 # Fetch new tools from the server
                 if self.connection_persistence:
-                    # Create a factory function that will include our parameters
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     server_connection = await self._persistent_connection_manager.get_server(
                         server_name, client_session_factory=create_session
                     )
                     tools_result = await server_connection.session.list_tools()
                     new_tools = tools_result.tools or []
                 else:
-                    # Create a factory function for the client session
-                    def create_session(read_stream, write_stream, read_timeout):
-                        # Get agent name if available
-                        agent_name: str | None = None
-
-                        # Import here to avoid circular dependency
-                        from mcp_agent.agents.base_agent import BaseAgent
-
-                        if isinstance(self, BaseAgent):
-                            agent_name = self.config.name
-                            elicitation_handler = self.config.elicitation_handler
-
-                        return MCPAgentClientSession(
-                            read_stream,
-                            write_stream,
-                            read_timeout,
-                            server_name=server_name,
-                            agent_name=agent_name,
-                            elicitation_handler=elicitation_handler,
-                            tool_list_changed_callback=self._handle_tool_list_changed,
-                        )
-
                     async with gen_client(
                         server_name,
                         server_registry=self.context.server_registry,
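
The refresh path previously defined an identical `create_session` factory inside each branch, and the copies had drifted: they passed only `agent_name` and `elicitation_handler`, dropping `agent_model`, and neither knew about `api_key`. Hoisting one closure above the `if self.connection_persistence:` split removes the drift; both paths now hand the same factory to their connection machinery. The shape of the pattern, with stand-in types rather than the real session classes:

    from typing import Any, Callable

    # Sketch of the hoisted-factory pattern: capture per-agent parameters
    # once, then share the same callable across both connection paths.
    def make_session_factory(server_name: str, api_key: str | None) -> Callable[..., dict[str, Any]]:
        def create_session(read_stream, write_stream, read_timeout) -> dict[str, Any]:
            # One place to thread server_name/api_key into every session.
            return {
                "server_name": server_name,
                "api_key": api_key,
                "streams": (read_stream, write_stream, read_timeout),
            }
        return create_session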
mcp_agent/mcp/sampling.py CHANGED
@@ -20,7 +20,7 @@ logger = get_logger(__name__)
 
 
 def create_sampling_llm(
-    params: CreateMessageRequestParams, model_string: str
+    params: CreateMessageRequestParams, model_string: str, api_key: str | None
 ) -> AugmentedLLMProtocol:
     """
     Create an LLM instance for sampling without tools support.
@@ -52,7 +52,7 @@ def create_sampling_llm(
 
     # Create the LLM using the factory
     factory = ModelFactory.create_factory(model_string)
-    llm = factory(agent=agent)
+    llm = factory(agent=agent, api_key=api_key)
 
     # Attach the LLM to the agent
     agent._llm = llm
@@ -77,7 +77,8 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
     Returns:
         A CreateMessageResult containing the LLM's response
     """
-    model = None
+    model: str | None = None
+    api_key: str | None = None
     try:
         # Extract model from server config using type-safe helper
         server_config = get_server_config(mcp_ctx)
@@ -104,13 +105,16 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
         from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
 
         # Try agent's model first (from the session)
-        if (hasattr(mcp_ctx, 'session') and
-            isinstance(mcp_ctx.session, MCPAgentClientSession) and
-            mcp_ctx.session.agent_model):
-            model = mcp_ctx.session.agent_model
-            logger.debug(f"Using agent's model for sampling: {model}")
-        else:
-            # Fall back to system default model
+        if hasattr(mcp_ctx, "session") and isinstance(mcp_ctx.session, MCPAgentClientSession):
+            if mcp_ctx.session.agent_model:
+                model = mcp_ctx.session.agent_model
+                logger.debug(f"Using agent's model for sampling: {model}")
+            if mcp_ctx.session.api_key:
+                api_key = mcp_ctx.session.api_key
+                logger.debug(f"Using agent's API KEY for sampling: {api_key}")
+
+        # Fall back to system default model
+        if model is None:
             try:
                 if app_context and app_context.config and app_context.config.default_model:
                     model = app_context.config.default_model
@@ -122,7 +126,7 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
             raise ValueError("No model configured for sampling (server config, agent model, or system default)")
 
         # Create an LLM instance
-        llm = create_sampling_llm(params, model)
+        llm = create_sampling_llm(params, model, api_key)
 
         # Extract all messages from the request params
         if not params.messages:
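
Resolution order for sampling after this change: a model attached to the agent's session wins; otherwise the system `default_model` is used; and the session's `api_key`, when present, is threaded through to the model factory. A condensed sketch of that precedence (stand-in objects, not the real signatures):

    # Illustrative only: session/config stand in for MCPAgentClientSession
    # and the app config consulted inside sample() above.
    def resolve_sampling(session, config) -> tuple[str, str | None]:
        model = getattr(session, "agent_model", None)       # session-provided model first
        api_key = getattr(session, "api_key", None)         # optional per-agent key
        if model is None and config is not None:
            model = getattr(config, "default_model", None)  # system default fallback
        if model is None:
            raise ValueError("No model configured for sampling")
        return model, api_key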
mcp_agent/resources/examples/mcp/elicitations/forms_demo.py CHANGED
@@ -51,8 +51,6 @@ async def main():
     else:
         console.print("[red]No registration data received[/red]")
 
-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 2: Product Review
     console.print("[bold yellow]Example 2: Product Review Form[/bold yellow]")
     console.print(
@@ -66,8 +64,6 @@ async def main():
     )
     console.print(review_panel)
 
-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 3: Account Settings
     console.print("[bold yellow]Example 3: Account Settings Form[/bold yellow]")
     console.print(
@@ -81,8 +77,6 @@ async def main():
     )
     console.print(settings_panel)
 
-    console.print("\n" + "─" * 50 + "\n")
-
     # Example 4: Service Appointment
     console.print("[bold yellow]Example 4: Service Appointment Booking[/bold yellow]")
     console.print(
mcp_agent/resources/examples/workflows/router.py CHANGED
@@ -7,6 +7,8 @@ Demonstrates router's ability to either:
 
 import asyncio
 
+from rich.console import Console
+
 from mcp_agent.core.fastagent import FastAgent
 
 # Create the application
@@ -45,7 +47,14 @@ SAMPLE_REQUESTS = [
     agents=["code_expert", "general_assistant", "fetcher"],
 )
 async def main() -> None:
+    console = Console()
+    console.print(
+        "\n[bright_red]Router Workflow Demo[/bright_red]\n\n"
+        "Enter a request to route it to the appropriate agent.\nEnter [bright_red]STOP[/bright_red] to run the demo, [bright_red]EXIT[/bright_red] to leave"
+    )
+
     async with fast.run() as agent:
+        await agent.interactive(agent_name="route")
         for request in SAMPLE_REQUESTS:
             await agent.route(request)