fast-agent-mcp 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -215,13 +215,20 @@ class MCPAggregator(ContextDependent):
215
215
  )
216
216
 
217
217
  # Create a wrapper to capture the parameters for the client session
218
- def session_factory(read_stream, write_stream, read_timeout):
218
+ def session_factory(read_stream, write_stream, read_timeout, **kwargs):
219
+ # Get agent's model if this aggregator is part of an agent
220
+ agent_model = None
221
+ if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
222
+ agent_model = self.config.model
223
+
219
224
  return MCPAgentClientSession(
220
225
  read_stream,
221
226
  write_stream,
222
227
  read_timeout,
223
228
  server_name=server_name,
229
+ agent_model=agent_model,
224
230
  tool_list_changed_callback=self._handle_tool_list_changed,
231
+ **kwargs # Pass through any additional kwargs like server_config
225
232
  )
226
233
 
227
234
  await self._persistent_connection_manager.get_server(
@@ -269,13 +276,20 @@ class MCPAggregator(ContextDependent):
269
276
  prompts = await fetch_prompts(server_connection.session, server_name)
270
277
  else:
271
278
  # Create a factory function for the client session
272
- def create_session(read_stream, write_stream, read_timeout):
279
+ def create_session(read_stream, write_stream, read_timeout, **kwargs):
280
+ # Get agent's model if this aggregator is part of an agent
281
+ agent_model = None
282
+ if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
283
+ agent_model = self.config.model
284
+
273
285
  return MCPAgentClientSession(
274
286
  read_stream,
275
287
  write_stream,
276
288
  read_timeout,
277
289
  server_name=server_name,
290
+ agent_model=agent_model,
278
291
  tool_list_changed_callback=self._handle_tool_list_changed,
292
+ **kwargs # Pass through any additional kwargs like server_config
279
293
  )
280
294
 
281
295
  async with gen_client(
@@ -797,12 +811,13 @@ class MCPAggregator(ContextDependent):
797
811
  messages=[],
798
812
  )
799
813
 
800
- async def list_prompts(self, server_name: str | None = None) -> Mapping[str, List[Prompt]]:
814
+ async def list_prompts(self, server_name: str | None = None, agent_name: str | None = None) -> Mapping[str, List[Prompt]]:
801
815
  """
802
816
  List available prompts from one or all servers.
803
817
 
804
818
  :param server_name: Optional server name to list prompts from. If not provided,
805
819
  lists prompts from all servers.
820
+ :param agent_name: Optional agent name (ignored at this level, used by multi-agent apps)
806
821
  :return: Dictionary mapping server names to lists of Prompt objects
807
822
  """
808
823
  if not self.initialized:
@@ -165,11 +165,12 @@ class ServerConnection:
165
165
  else None
166
166
  )
167
167
 
168
- session = self._client_session_factory(read_stream, send_stream, read_timeout)
169
-
170
- # Make the server config available to the session for initialization
171
- if hasattr(session, "server_config"):
172
- session.server_config = self.server_config
168
+ session = self._client_session_factory(
169
+ read_stream,
170
+ send_stream,
171
+ read_timeout,
172
+ server_config=self.server_config
173
+ )
173
174
 
174
175
  self.session = session
175
176
 
mcp_agent/mcp/sampling.py CHANGED
@@ -10,6 +10,7 @@ from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextConte
10
10
  from mcp_agent.core.agent_types import AgentConfig
11
11
  from mcp_agent.llm.sampling_converter import SamplingConverter
12
12
  from mcp_agent.logging.logger import get_logger
13
+ from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
13
14
  from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
14
15
 
15
16
  if TYPE_CHECKING:
@@ -78,18 +79,47 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
78
79
  """
79
80
  model = None
80
81
  try:
81
- # Extract model from server config
82
- if (
83
- hasattr(mcp_ctx, "session")
84
- and hasattr(mcp_ctx.session, "server_config")
85
- and mcp_ctx.session.server_config
86
- and hasattr(mcp_ctx.session.server_config, "sampling")
87
- and mcp_ctx.session.server_config.sampling.model
88
- ):
89
- model = mcp_ctx.session.server_config.sampling.model
82
+ # Extract model from server config using type-safe helper
83
+ server_config = get_server_config(mcp_ctx)
84
+
85
+ # First priority: explicitly configured sampling model
86
+ if server_config and hasattr(server_config, "sampling") and server_config.sampling:
87
+ model = server_config.sampling.model
88
+
89
+ # Second priority: auto_sampling fallback (if enabled at application level)
90
+ if model is None:
91
+ # Check if auto_sampling is enabled
92
+ auto_sampling_enabled = False
93
+ try:
94
+ from mcp_agent.context import get_current_context
95
+ app_context = get_current_context()
96
+ if app_context and app_context.config:
97
+ auto_sampling_enabled = getattr(app_context.config, 'auto_sampling', True)
98
+ except Exception as e:
99
+ logger.debug(f"Could not get application config: {e}")
100
+ auto_sampling_enabled = True # Default to enabled
101
+
102
+ if auto_sampling_enabled:
103
+ # Import here to avoid circular import
104
+ from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
105
+
106
+ # Try agent's model first (from the session)
107
+ if (hasattr(mcp_ctx, 'session') and
108
+ isinstance(mcp_ctx.session, MCPAgentClientSession) and
109
+ mcp_ctx.session.agent_model):
110
+ model = mcp_ctx.session.agent_model
111
+ logger.debug(f"Using agent's model for sampling: {model}")
112
+ else:
113
+ # Fall back to system default model
114
+ try:
115
+ if app_context and app_context.config and app_context.config.default_model:
116
+ model = app_context.config.default_model
117
+ logger.debug(f"Using system default model for sampling: {model}")
118
+ except Exception as e:
119
+ logger.debug(f"Could not get system default model: {e}")
90
120
 
91
121
  if model is None:
92
- raise ValueError("No model configured")
122
+ raise ValueError("No model configured for sampling (server config, agent model, or system default)")
93
123
 
94
124
  # Create an LLM instance
95
125
  llm = create_sampling_llm(params, model)
@@ -27,6 +27,7 @@ from mcp_agent.config import (
27
27
  get_settings,
28
28
  )
29
29
  from mcp_agent.logging.logger import get_logger
30
+ from mcp_agent.mcp.hf_auth import add_hf_auth_header
30
31
  from mcp_agent.mcp.logger_textio import get_stderr_handler
31
32
  from mcp_agent.mcp.mcp_connection_manager import (
32
33
  MCPConnectionManager,
@@ -70,9 +71,15 @@ class ServerRegistry:
70
71
  config (Settings): The Settings object containing the server configurations.
71
72
  config_path (str): Path to the YAML configuration file.
72
73
  """
73
- self.registry = (
74
- self.load_registry_from_file(config_path) if config is None else config.mcp.servers
75
- )
74
+ if config is None:
75
+ self.registry = self.load_registry_from_file(config_path)
76
+ elif config.mcp is not None and hasattr(config.mcp, 'servers') and config.mcp.servers is not None:
77
+ # Ensure config.mcp exists, has a 'servers' attribute, and it's not None
78
+ self.registry = config.mcp.servers
79
+ else:
80
+ # Default to an empty dictionary if config.mcp is None or has no 'servers'
81
+ self.registry = {}
82
+
76
83
  self.init_hooks: Dict[str, InitHookCallable] = {}
77
84
  self.connection_manager = MCPConnectionManager(self)
78
85
 
@@ -88,8 +95,13 @@ class ServerRegistry:
88
95
  Raises:
89
96
  ValueError: If the configuration is invalid.
90
97
  """
98
+ servers = {}
91
99
 
92
- servers = get_settings(config_path).mcp.servers or {}
100
+ settings = get_settings(config_path)
101
+
102
+ if settings.mcp is not None and hasattr(settings.mcp, 'servers') and settings.mcp.servers is not None:
103
+ return settings.mcp.servers
104
+
93
105
  return servers
94
106
 
95
107
  @asynccontextmanager
@@ -165,11 +177,14 @@ class ServerRegistry:
165
177
  if not config.url:
166
178
  raise ValueError(f"URL is required for SSE transport: {server_name}")
167
179
 
180
+ # Apply HuggingFace authentication if appropriate
181
+ headers = add_hf_auth_header(config.url, config.headers)
182
+
168
183
  # Use sse_client to get the read and write streams
169
184
  async with _add_none_to_context(
170
185
  sse_client(
171
186
  config.url,
172
- config.headers,
187
+ headers,
173
188
  sse_read_timeout=config.read_transport_sse_timeout_seconds,
174
189
  )
175
190
  ) as (read_stream, write_stream, _):
@@ -187,9 +202,12 @@ class ServerRegistry:
187
202
  logger.debug(f"{server_name}: Closed session to server")
188
203
  elif config.transport == "http":
189
204
  if not config.url:
190
- raise ValueError(f"URL is required for SSE transport: {server_name}")
205
+ raise ValueError(f"URL is required for HTTP transport: {server_name}")
206
+
207
+ # Apply HuggingFace authentication if appropriate
208
+ headers = add_hf_auth_header(config.url, config.headers)
191
209
 
192
- async with streamablehttp_client(config.url, config.headers) as (
210
+ async with streamablehttp_client(config.url, headers) as (
193
211
  read_stream,
194
212
  write_stream,
195
213
  _,
@@ -0,0 +1,14 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Dict, Optional
3
+
4
+
5
+ @dataclass
6
+ class ToolDefinition:
7
+ """
8
+ Represents a definition of a tool available to the agent.
9
+ """
10
+
11
+ name: str
12
+ description: Optional[str] = None
13
+ inputSchema: Dict[str, Any] = field(default_factory=dict)
14
+ # Add other relevant fields if necessary based on how tools are defined in fast-agent