fast-agent-mcp 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,361 @@
+ import base64
+ from typing import Any, Dict, List, Tuple
+
+ # Import necessary types from google.genai
+ from google.genai import types
+ from mcp.types import (
+     BlobResourceContents,
+     CallToolRequest,
+     CallToolRequestParams,
+     CallToolResult,
+     EmbeddedResource,
+     ImageContent,
+     TextContent,
+ )
+
+ from mcp_agent.core.request_params import RequestParams
+ from mcp_agent.mcp.helpers.content_helpers import (
+     get_image_data,
+     get_text,
+     is_image_content,
+     is_resource_content,
+     is_text_content,
+ )
+ from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+ from mcp_agent.tools.tool_definition import ToolDefinition
+
+
+ class GoogleConverter:
+     """
+     Converts between fast-agent and google.genai data structures.
+     """
+
+     def _clean_schema_for_google(self, schema: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Recursively removes unsupported JSON schema keywords for google.genai.types.Schema.
+         Specifically removes 'additionalProperties', '$schema', 'exclusiveMaximum', and 'exclusiveMinimum'.
+         """
+         cleaned_schema = {}
+         unsupported_keys = {
+             "additionalProperties",
+             "$schema",
+             "exclusiveMaximum",
+             "exclusiveMinimum",
+         }
+         supported_string_formats = {"enum", "date-time"}
+
+         for key, value in schema.items():
+             if key in unsupported_keys:
+                 continue  # Skip this key
+
+             if (
+                 key == "format"
+                 and schema.get("type") == "string"
+                 and value not in supported_string_formats
+             ):
+                 continue  # Remove unsupported string formats
+
+             if isinstance(value, dict):
+                 cleaned_schema[key] = self._clean_schema_for_google(value)
+             elif isinstance(value, list):
+                 cleaned_schema[key] = [
+                     self._clean_schema_for_google(item) if isinstance(item, dict) else item
+                     for item in value
+                 ]
+             else:
+                 cleaned_schema[key] = value
+         return cleaned_schema
+
+     def convert_to_google_content(
+         self, messages: List[PromptMessageMultipart]
+     ) -> List[types.Content]:
+         """
+         Converts a list of fast-agent PromptMessageMultipart to google.genai types.Content.
+         Handles different roles and content types (text, images, etc.).
+         """
+         google_contents: List[types.Content] = []
+         for message in messages:
+             parts: List[types.Part] = []
+             for part_content in message.content:  # renamed part to part_content to avoid conflict
+                 if is_text_content(part_content):
+                     parts.append(types.Part.from_text(text=get_text(part_content) or ""))
+                 elif is_image_content(part_content):
+                     assert isinstance(part_content, ImageContent)
+                     image_bytes = base64.b64decode(get_image_data(part_content) or "")
+                     parts.append(
+                         types.Part.from_bytes(mime_type=part_content.mimeType, data=image_bytes)
+                     )
+                 elif is_resource_content(part_content):
+                     assert isinstance(part_content, EmbeddedResource)
+                     if (
+                         "application/pdf" == part_content.resource.mimeType
+                         and hasattr(part_content.resource, "blob")
+                         and isinstance(part_content.resource, BlobResourceContents)
+                     ):
+                         pdf_bytes = base64.b64decode(part_content.resource.blob)
+                         parts.append(
+                             types.Part.from_bytes(
+                                 mime_type=part_content.resource.mimeType or "application/pdf",
+                                 data=pdf_bytes,
+                             )
+                         )
+                     else:
+                         # Check if the resource itself has text content
+                         resource_text = None
+                         if hasattr(part_content.resource, "text"):  # Direct text attribute
+                             resource_text = part_content.resource.text
+                         # Example: if EmbeddedResource wraps a TextContent-like object in its 'resource' field
+                         elif (
+                             hasattr(part_content.resource, "type")
+                             and part_content.resource.type == "text"
+                             and hasattr(part_content.resource, "text")
+                         ):
+                             resource_text = get_text(part_content.resource)
+
+                         if resource_text is not None:
+                             parts.append(types.Part.from_text(text=resource_text))
+                         else:
+                             # Fallback for other binary types or types without direct text
+                             uri_str = (
+                                 part_content.resource.uri
+                                 if hasattr(part_content.resource, "uri")
+                                 else "unknown_uri"
+                             )
+                             mime_str = (
+                                 part_content.resource.mimeType
+                                 if hasattr(part_content.resource, "mimeType")
+                                 else "unknown_mime"
+                             )
+                             parts.append(
+                                 types.Part.from_text(
+                                     text=f"[Resource: {uri_str}, MIME: {mime_str}]"
+                                 )
+                             )
+
+             if parts:
+                 google_role = (
+                     "user"
+                     if message.role == "user"
+                     else ("model" if message.role == "assistant" else "tool")
+                 )
+                 google_contents.append(types.Content(role=google_role, parts=parts))
+         return google_contents
+
+     def convert_to_google_tools(self, tools: List[ToolDefinition]) -> List[types.Tool]:
+         """
+         Converts a list of fast-agent ToolDefinition to google.genai types.Tool.
+         """
+         google_tools: List[types.Tool] = []
+         for tool in tools:
+             cleaned_input_schema = self._clean_schema_for_google(tool.inputSchema)
+             function_declaration = types.FunctionDeclaration(
+                 name=tool.name,
+                 description=tool.description if tool.description else "",
+                 parameters=types.Schema(**cleaned_input_schema),
+             )
+             google_tools.append(types.Tool(function_declarations=[function_declaration]))
+         return google_tools
+
+     def convert_from_google_content(
+         self, content: types.Content
+     ) -> List[TextContent | ImageContent | EmbeddedResource | CallToolRequestParams]:
+         """
+         Converts google.genai types.Content from a model response to a list of
+         fast-agent content types or tool call requests.
+         """
+         fast_agent_parts: List[
+             TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
+         ] = []
+         for part in content.parts:
+             if part.text:
+                 fast_agent_parts.append(TextContent(type="text", text=part.text))
+             elif part.function_call:
+                 fast_agent_parts.append(
+                     CallToolRequestParams(
+                         name=part.function_call.name,
+                         arguments=part.function_call.args,
+                     )
+                 )
+         return fast_agent_parts
+
+     def convert_from_google_function_call(
+         self, function_call: types.FunctionCall
+     ) -> CallToolRequest:
+         """
+         Converts a single google.genai types.FunctionCall to a fast-agent CallToolRequest.
+         """
+         return CallToolRequest(
+             method="tools/call",
+             params=CallToolRequestParams(
+                 name=function_call.name,
+                 arguments=function_call.args,
+             ),
+         )
+
+     def convert_function_results_to_google(
+         self, tool_results: List[Tuple[str, CallToolResult]]
+     ) -> List[types.Content]:
+         """
+         Converts a list of fast-agent tool results to google.genai types.Content
+         with role 'tool'. Handles multimodal content in tool results.
+         """
+         google_tool_response_contents: List[types.Content] = []
+         for tool_name, tool_result in tool_results:
+             current_content_parts: List[types.Part] = []
+             textual_outputs: List[str] = []
+             media_parts: List[types.Part] = []
+
+             for item in tool_result.content:
+                 if is_text_content(item):
+                     textual_outputs.append(get_text(item) or "")  # Ensure no None is added
+                 elif is_image_content(item):
+                     assert isinstance(item, ImageContent)
+                     try:
+                         image_bytes = base64.b64decode(get_image_data(item) or "")
+                         media_parts.append(
+                             types.Part.from_bytes(data=image_bytes, mime_type=item.mimeType)
+                         )
+                     except Exception as e:
+                         textual_outputs.append(f"[Error processing image from tool result: {e}]")
+                 elif is_resource_content(item):
+                     assert isinstance(item, EmbeddedResource)
+                     if (
+                         "application/pdf" == item.resource.mimeType
+                         and hasattr(item.resource, "blob")
+                         and isinstance(item.resource, BlobResourceContents)
+                     ):
+                         try:
+                             pdf_bytes = base64.b64decode(item.resource.blob)
+                             media_parts.append(
+                                 types.Part.from_bytes(
+                                     data=pdf_bytes,
+                                     mime_type=item.resource.mimeType or "application/pdf",
+                                 )
+                             )
+                         except Exception as e:
+                             textual_outputs.append(f"[Error processing PDF from tool result: {e}]")
+                     else:
+                         # Check if the resource itself has text content
+                         resource_text = None
+                         if hasattr(item.resource, "text"):  # Direct text attribute
+                             resource_text = item.resource.text
+                         # Example: if EmbeddedResource wraps a TextContent-like object in its 'resource' field
+                         elif (
+                             hasattr(item.resource, "type")
+                             and item.resource.type == "text"
+                             and hasattr(item.resource, "text")
+                         ):
+                             resource_text = get_text(item.resource)
+
+                         if resource_text is not None:
+                             textual_outputs.append(resource_text)
+                         else:
+                             uri_str = (
+                                 item.resource.uri
+                                 if hasattr(item.resource, "uri")
+                                 else "unknown_uri"
+                             )
+                             mime_str = (
+                                 item.resource.mimeType
+                                 if hasattr(item.resource, "mimeType")
+                                 else "unknown_mime"
+                             )
+                             textual_outputs.append(
+                                 f"[Unhandled Resource in Tool: {uri_str}, MIME: {mime_str}]"
+                             )
+                 # Add handling for other content types if needed, for now they are skipped or become unhandled resource text
+
+             function_response_payload: Dict[str, Any] = {"tool_name": tool_name}
+             if textual_outputs:
+                 function_response_payload["text_content"] = "\n".join(textual_outputs)
+
+             # Only add media_parts if there are some, otherwise Gemini might error on empty parts for function response
+             if media_parts:
+                 # Create the main FunctionResponse part
+                 fn_response_part = types.Part.from_function_response(
+                     name=tool_name, response=function_response_payload
+                 )
+                 current_content_parts.append(fn_response_part)
+                 current_content_parts.extend(
+                     media_parts
+                 )  # Add media parts after the main response part
+             else:  # If no media parts, the textual output (if any) is the sole content of the function response
+                 fn_response_part = types.Part.from_function_response(
+                     name=tool_name, response=function_response_payload
+                 )
+                 current_content_parts.append(fn_response_part)
+
+             google_tool_response_contents.append(
+                 types.Content(role="tool", parts=current_content_parts)
+             )
+         return google_tool_response_contents
+
+     def convert_request_params_to_google_config(
+         self, request_params: RequestParams
+     ) -> types.GenerateContentConfig:
+         """
+         Converts fast-agent RequestParams to google.genai types.GenerateContentConfig.
+         """
+         config_args: Dict[str, Any] = {}
+         if request_params.temperature is not None:
+             config_args["temperature"] = request_params.temperature
+         if request_params.maxTokens is not None:
+             config_args["max_output_tokens"] = request_params.maxTokens
+         if hasattr(request_params, "topK") and request_params.topK is not None:
+             config_args["top_k"] = request_params.topK
+         if hasattr(request_params, "topP") and request_params.topP is not None:
+             config_args["top_p"] = request_params.topP
+         if hasattr(request_params, "stopSequences") and request_params.stopSequences is not None:
+             config_args["stop_sequences"] = request_params.stopSequences
+         if (
+             hasattr(request_params, "presencePenalty")
+             and request_params.presencePenalty is not None
+         ):
+             config_args["presence_penalty"] = request_params.presencePenalty
+         if (
+             hasattr(request_params, "frequencyPenalty")
+             and request_params.frequencyPenalty is not None
+         ):
+             config_args["frequency_penalty"] = request_params.frequencyPenalty
+         if request_params.systemPrompt is not None:
+             config_args["system_instruction"] = request_params.systemPrompt
+         return types.GenerateContentConfig(**config_args)
+
+     def convert_from_google_content_list(
+         self, contents: List[types.Content]
+     ) -> List[PromptMessageMultipart]:
+         """
+         Converts a list of google.genai types.Content to a list of fast-agent PromptMessageMultipart.
+         """
+         return [self._convert_from_google_content(content) for content in contents]
+
+     def _convert_from_google_content(self, content: types.Content) -> PromptMessageMultipart:
+         """
+         Converts a single google.genai types.Content to a fast-agent PromptMessageMultipart.
+         """
+         if content.role == "model" and any(part.function_call for part in content.parts):
+             return PromptMessageMultipart(role="assistant", content=[])
+
+         fast_agent_parts: List[
+             TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
+         ] = []
+         for part in content.parts:
+             if part.text:
+                 fast_agent_parts.append(TextContent(type="text", text=part.text))
+             elif part.function_response:
+                 response_text = str(part.function_response.response)
+                 fast_agent_parts.append(TextContent(type="text", text=response_text))
+             elif part.file_data:
+                 fast_agent_parts.append(
+                     EmbeddedResource(
+                         type="resource",
+                         resource=TextContent(
+                             uri=part.file_data.file_uri,
+                             mimeType=part.file_data.mime_type,
+                             text=f"[Resource: {part.file_data.file_uri}, MIME: {part.file_data.mime_type}]",
+                         ),
+                     )
+                 )
+
+         fast_agent_role = "user" if content.role == "user" else "assistant"
+         return PromptMessageMultipart(role=fast_agent_role, content=fast_agent_parts)
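For orientation, here is a minimal usage sketch (not part of the diff) showing how this converter could be driven against the google-genai client. The module path `mcp_agent.llm.providers.google_converter`, the model name, and the surrounding plumbing are assumptions for illustration only.

```python
# Hypothetical usage sketch; module path, model name and wiring are assumptions.
from google import genai
from google.genai import types

from mcp_agent.llm.providers.google_converter import GoogleConverter  # assumed path
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart

converter = GoogleConverter()


def generate(client: genai.Client, messages: list[PromptMessageMultipart], tools):
    # fast-agent messages -> genai Content; MCP tool schemas -> FunctionDeclarations
    contents = converter.convert_to_google_content(messages)
    google_tools = converter.convert_to_google_tools(tools)
    response = client.models.generate_content(
        model="gemini-2.0-flash",
        contents=contents,
        config=types.GenerateContentConfig(tools=google_tools),
    )
    # Map the first candidate back into fast-agent content / tool-call requests
    return converter.convert_from_google_content(response.candidates[0].content)
```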
@@ -0,0 +1,23 @@
+ """Helper functions for type-safe server config access."""
+
+ from typing import TYPE_CHECKING, Optional
+
+ from mcp import ClientSession
+
+ if TYPE_CHECKING:
+     from mcp_agent.config import MCPServerSettings
+
+
+ def get_server_config(ctx: ClientSession) -> Optional["MCPServerSettings"]:
+     """Extract server config from context if available.
+
+     Type guard helper that safely accesses server_config with proper type checking.
+     """
+     # Import here to avoid circular import
+     from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
+
+     if (hasattr(ctx, "session") and
+         isinstance(ctx.session, MCPAgentClientSession) and
+         ctx.session.server_config):
+         return ctx.session.server_config
+     return None
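A short illustrative sketch of what the helper buys callers: a typed `MCPServerSettings` instead of chained `hasattr()` checks on `ctx.session`. The function below is hypothetical; the real call sites follow later in this diff.

```python
# Illustrative only; sampling_model_for is not part of the package.
from typing import Optional

from mcp import ClientSession


def sampling_model_for(ctx: ClientSession) -> Optional[str]:
    server_config = get_server_config(ctx)
    if server_config and server_config.sampling:
        return server_config.sampling.model
    return None
```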
@@ -4,7 +4,7 @@ It adds logging and supports sampling requests.
  """
 
  from datetime import timedelta
- from typing import TYPE_CHECKING, Optional
+ from typing import TYPE_CHECKING
 
  from mcp import ClientSession, ServerNotification
  from mcp.shared.session import (
@@ -20,6 +20,7 @@ from pydantic import FileUrl
 
  from mcp_agent.context_dependent import ContextDependent
  from mcp_agent.logging.logger import get_logger
+ from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
  from mcp_agent.mcp.sampling import sample
 
  if TYPE_CHECKING:
@@ -31,24 +32,20 @@ logger = get_logger(__name__)
  async def list_roots(ctx: ClientSession) -> ListRootsResult:
      """List roots callback that will be called by the MCP library."""
 
-     roots = []
-     if (
-         hasattr(ctx, "session")
-         and hasattr(ctx.session, "server_config")
-         and ctx.session.server_config
-         and hasattr(ctx.session.server_config, "roots")
-         and ctx.session.server_config.roots
-     ):
-         roots = [
-             Root(
-                 uri=FileUrl(
-                     root.server_uri_alias or root.uri,
-                 ),
-                 name=root.name,
-             )
-             for root in ctx.session.server_config.roots
-         ]
-     return ListRootsResult(roots=roots or [])
+     if server_config := get_server_config(ctx):
+         if server_config.roots:
+             roots = [
+                 Root(
+                     uri=FileUrl(
+                         root.server_uri_alias or root.uri,
+                     ),
+                     name=root.name,
+                 )
+                 for root in server_config.roots
+             ]
+             return ListRootsResult(roots=roots)
+
+     return ListRootsResult(roots=[])
 
 
  class MCPAgentClientSession(ClientSession, ContextDependent):
@@ -72,14 +69,43 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
          self.session_server_name = kwargs.pop("server_name", None)
          # Extract the notification callbacks if provided
          self._tool_list_changed_callback = kwargs.pop("tool_list_changed_callback", None)
+         # Extract server_config if provided
+         self.server_config: MCPServerSettings | None = kwargs.pop("server_config", None)
+         # Extract agent_model if provided (for auto_sampling fallback)
+         self.agent_model: str | None = kwargs.pop("agent_model", None)
+
+         # Only register callbacks if the server_config has the relevant settings
+         list_roots_cb = list_roots if (self.server_config and self.server_config.roots) else None
+
+         # Register sampling callback if either:
+         # 1. Sampling is explicitly configured, OR
+         # 2. Application-level auto_sampling is enabled
+         sampling_cb = None
+         if self.server_config and hasattr(self.server_config, "sampling") and self.server_config.sampling:
+             # Explicit sampling configuration
+             sampling_cb = sample
+         elif self._should_enable_auto_sampling():
+             # Auto-sampling enabled at application level
+             sampling_cb = sample
+
          super().__init__(
              *args,
              **kwargs,
-             list_roots_callback=list_roots,
-             sampling_callback=sample,
+             list_roots_callback=list_roots_cb,
+             sampling_callback=sampling_cb,
              client_info=fast_agent,
          )
-         self.server_config: Optional[MCPServerSettings] = None
+
+     def _should_enable_auto_sampling(self) -> bool:
+         """Check if auto_sampling is enabled at the application level."""
+         try:
+             from mcp_agent.context import get_current_context
+             context = get_current_context()
+             if context and context.config:
+                 return getattr(context.config, 'auto_sampling', True)
+         except Exception:
+             pass
+         return True  # Default to True if can't access config
 
      async def send_request(
          self,
@@ -91,10 +117,11 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
          logger.debug("send_request: request=", data=request.model_dump())
          try:
              result = await super().send_request(
-                 request,
-                 result_type,
+                 request=request,
+                 result_type=result_type,
                  request_read_timeout_seconds=request_read_timeout_seconds,
                  progress_callback=progress_callback,
+                 metadata=None,
              )
              logger.debug("send_request: response=", data=result.model_dump())
              return result
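Taken together, the session now receives its configuration through keyword arguments instead of having `server_config` patched on afterwards. A hedged sketch of constructing it with the new keywords follows; the stream objects, timeout, server name and model alias are placeholders, not values from this diff.

```python
# Sketch only: the new server_config / agent_model keywords added in this release.
from datetime import timedelta

from mcp_agent.config import MCPServerSettings
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession


def make_session(read_stream, write_stream, server_config: MCPServerSettings):
    return MCPAgentClientSession(
        read_stream,
        write_stream,
        timedelta(seconds=30),            # placeholder read timeout
        server_name="my-server",          # placeholder name
        server_config=server_config,      # drives list_roots / explicit sampling callbacks
        agent_model="haiku",              # placeholder; used as the auto_sampling fallback
    )
```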
@@ -215,13 +215,20 @@ class MCPAggregator(ContextDependent):
              )
 
              # Create a wrapper to capture the parameters for the client session
-             def session_factory(read_stream, write_stream, read_timeout):
+             def session_factory(read_stream, write_stream, read_timeout, **kwargs):
+                 # Get agent's model if this aggregator is part of an agent
+                 agent_model = None
+                 if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
+                     agent_model = self.config.model
+
                  return MCPAgentClientSession(
                      read_stream,
                      write_stream,
                      read_timeout,
                      server_name=server_name,
+                     agent_model=agent_model,
                      tool_list_changed_callback=self._handle_tool_list_changed,
+                     **kwargs  # Pass through any additional kwargs like server_config
                  )
 
              await self._persistent_connection_manager.get_server(
@@ -269,13 +276,20 @@ class MCPAggregator(ContextDependent):
              prompts = await fetch_prompts(server_connection.session, server_name)
          else:
              # Create a factory function for the client session
-             def create_session(read_stream, write_stream, read_timeout):
+             def create_session(read_stream, write_stream, read_timeout, **kwargs):
+                 # Get agent's model if this aggregator is part of an agent
+                 agent_model = None
+                 if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
+                     agent_model = self.config.model
+
                  return MCPAgentClientSession(
                      read_stream,
                      write_stream,
                      read_timeout,
                      server_name=server_name,
+                     agent_model=agent_model,
                      tool_list_changed_callback=self._handle_tool_list_changed,
+                     **kwargs  # Pass through any additional kwargs like server_config
                  )
 
              async with gen_client(
@@ -797,12 +811,13 @@ class MCPAggregator(ContextDependent):
                  messages=[],
              )
 
-     async def list_prompts(self, server_name: str | None = None) -> Mapping[str, List[Prompt]]:
+     async def list_prompts(self, server_name: str | None = None, agent_name: str | None = None) -> Mapping[str, List[Prompt]]:
          """
          List available prompts from one or all servers.
 
          :param server_name: Optional server name to list prompts from. If not provided,
                              lists prompts from all servers.
+         :param agent_name: Optional agent name (ignored at this level, used by multi-agent apps)
          :return: Dictionary mapping server names to lists of Prompt objects
          """
          if not self.initialized:
@@ -165,11 +165,12 @@ class ServerConnection:
              else None
          )
 
-         session = self._client_session_factory(read_stream, send_stream, read_timeout)
-
-         # Make the server config available to the session for initialization
-         if hasattr(session, "server_config"):
-             session.server_config = self.server_config
+         session = self._client_session_factory(
+             read_stream,
+             send_stream,
+             read_timeout,
+             server_config=self.server_config
+         )
 
          self.session = session
 
mcp_agent/mcp/sampling.py CHANGED
@@ -10,6 +10,7 @@ from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextConte
  from mcp_agent.core.agent_types import AgentConfig
  from mcp_agent.llm.sampling_converter import SamplingConverter
  from mcp_agent.logging.logger import get_logger
+ from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
  from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
 
  if TYPE_CHECKING:
@@ -78,18 +79,47 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
      """
      model = None
      try:
-         # Extract model from server config
-         if (
-             hasattr(mcp_ctx, "session")
-             and hasattr(mcp_ctx.session, "server_config")
-             and mcp_ctx.session.server_config
-             and hasattr(mcp_ctx.session.server_config, "sampling")
-             and mcp_ctx.session.server_config.sampling.model
-         ):
-             model = mcp_ctx.session.server_config.sampling.model
+         # Extract model from server config using type-safe helper
+         server_config = get_server_config(mcp_ctx)
+
+         # First priority: explicitly configured sampling model
+         if server_config and hasattr(server_config, "sampling") and server_config.sampling:
+             model = server_config.sampling.model
+
+         # Second priority: auto_sampling fallback (if enabled at application level)
+         if model is None:
+             # Check if auto_sampling is enabled
+             auto_sampling_enabled = False
+             try:
+                 from mcp_agent.context import get_current_context
+                 app_context = get_current_context()
+                 if app_context and app_context.config:
+                     auto_sampling_enabled = getattr(app_context.config, 'auto_sampling', True)
+             except Exception as e:
+                 logger.debug(f"Could not get application config: {e}")
+                 auto_sampling_enabled = True  # Default to enabled
+
+             if auto_sampling_enabled:
+                 # Import here to avoid circular import
+                 from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
+
+                 # Try agent's model first (from the session)
+                 if (hasattr(mcp_ctx, 'session') and
+                     isinstance(mcp_ctx.session, MCPAgentClientSession) and
+                     mcp_ctx.session.agent_model):
+                     model = mcp_ctx.session.agent_model
+                     logger.debug(f"Using agent's model for sampling: {model}")
+                 else:
+                     # Fall back to system default model
+                     try:
+                         if app_context and app_context.config and app_context.config.default_model:
+                             model = app_context.config.default_model
+                             logger.debug(f"Using system default model for sampling: {model}")
+                     except Exception as e:
+                         logger.debug(f"Could not get system default model: {e}")
 
          if model is None:
-             raise ValueError("No model configured")
+             raise ValueError("No model configured for sampling (server config, agent model, or system default)")
 
          # Create an LLM instance
          llm = create_sampling_llm(params, model)
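The refactored sampling path resolves a model in three steps: an explicitly configured per-server sampling model, then the owning agent's model when application-level auto_sampling is enabled (the default), and finally the application's default_model. A condensed sketch of that priority chain follows; the function name and its parameters are hypothetical.

```python
# Condensed, hypothetical restatement of the resolution order implemented above.
def resolve_sampling_model(server_config, session_agent_model, app_config):
    # 1. Explicit per-server sampling configuration wins.
    if server_config and getattr(server_config, "sampling", None):
        return server_config.sampling.model
    # 2. With auto_sampling enabled (defaults to True), prefer the agent's own model.
    if getattr(app_config, "auto_sampling", True):
        if session_agent_model:
            return session_agent_model
        # 3. Otherwise fall back to the application's default_model.
        return getattr(app_config, "default_model", None)
    return None
```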