fast-agent-mcp 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.28.dist-info}/METADATA +3 -1
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.28.dist-info}/RECORD +24 -19
- mcp_agent/agents/agent.py +1 -17
- mcp_agent/agents/base_agent.py +2 -0
- mcp_agent/config.py +3 -0
- mcp_agent/context.py +2 -0
- mcp_agent/core/agent_app.py +7 -2
- mcp_agent/core/interactive_prompt.py +58 -51
- mcp_agent/llm/augmented_llm_slow.py +42 -0
- mcp_agent/llm/model_factory.py +74 -37
- mcp_agent/llm/provider_types.py +4 -3
- mcp_agent/llm/providers/augmented_llm_google_native.py +459 -0
- mcp_agent/llm/providers/{augmented_llm_google.py → augmented_llm_google_oai.py} +2 -2
- mcp_agent/llm/providers/google_converter.py +361 -0
- mcp_agent/mcp/helpers/server_config_helpers.py +23 -0
- mcp_agent/mcp/mcp_agent_client_session.py +51 -24
- mcp_agent/mcp/mcp_aggregator.py +18 -3
- mcp_agent/mcp/mcp_connection_manager.py +6 -5
- mcp_agent/mcp/sampling.py +40 -10
- mcp_agent/mcp_server_registry.py +15 -4
- mcp_agent/tools/tool_definition.py +14 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.28.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.28.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.28.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,361 @@
|
|
1
|
+
import base64
|
2
|
+
from typing import Any, Dict, List, Tuple
|
3
|
+
|
4
|
+
# Import necessary types from google.genai
|
5
|
+
from google.genai import types
|
6
|
+
from mcp.types import (
|
7
|
+
BlobResourceContents,
|
8
|
+
CallToolRequest,
|
9
|
+
CallToolRequestParams,
|
10
|
+
CallToolResult,
|
11
|
+
EmbeddedResource,
|
12
|
+
ImageContent,
|
13
|
+
TextContent,
|
14
|
+
)
|
15
|
+
|
16
|
+
from mcp_agent.core.request_params import RequestParams
|
17
|
+
from mcp_agent.mcp.helpers.content_helpers import (
|
18
|
+
get_image_data,
|
19
|
+
get_text,
|
20
|
+
is_image_content,
|
21
|
+
is_resource_content,
|
22
|
+
is_text_content,
|
23
|
+
)
|
24
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
25
|
+
from mcp_agent.tools.tool_definition import ToolDefinition
|
26
|
+
|
27
|
+
|
28
|
+
class GoogleConverter:
|
29
|
+
"""
|
30
|
+
Converts between fast-agent and google.genai data structures.
|
31
|
+
"""
|
32
|
+
|
33
|
+
def _clean_schema_for_google(self, schema: Dict[str, Any]) -> Dict[str, Any]:
|
34
|
+
"""
|
35
|
+
Recursively removes unsupported JSON schema keywords for google.genai.types.Schema.
|
36
|
+
Specifically removes 'additionalProperties', '$schema', 'exclusiveMaximum', and 'exclusiveMinimum'.
|
37
|
+
"""
|
38
|
+
cleaned_schema = {}
|
39
|
+
unsupported_keys = {
|
40
|
+
"additionalProperties",
|
41
|
+
"$schema",
|
42
|
+
"exclusiveMaximum",
|
43
|
+
"exclusiveMinimum",
|
44
|
+
}
|
45
|
+
supported_string_formats = {"enum", "date-time"}
|
46
|
+
|
47
|
+
for key, value in schema.items():
|
48
|
+
if key in unsupported_keys:
|
49
|
+
continue # Skip this key
|
50
|
+
|
51
|
+
if (
|
52
|
+
key == "format"
|
53
|
+
and schema.get("type") == "string"
|
54
|
+
and value not in supported_string_formats
|
55
|
+
):
|
56
|
+
continue # Remove unsupported string formats
|
57
|
+
|
58
|
+
if isinstance(value, dict):
|
59
|
+
cleaned_schema[key] = self._clean_schema_for_google(value)
|
60
|
+
elif isinstance(value, list):
|
61
|
+
cleaned_schema[key] = [
|
62
|
+
self._clean_schema_for_google(item) if isinstance(item, dict) else item
|
63
|
+
for item in value
|
64
|
+
]
|
65
|
+
else:
|
66
|
+
cleaned_schema[key] = value
|
67
|
+
return cleaned_schema
|
68
|
+
|
69
|
+
def convert_to_google_content(
    self, messages: List[PromptMessageMultipart]
) -> List[types.Content]:
    """
    Builds google.genai types.Content entries from fast-agent PromptMessageMultipart messages.

    Text items become text parts. Images and PDF blob resources are base64-decoded
    into byte parts. Resources exposing a 'text' attribute become text parts, and any
    other resource is represented by a '[Resource: ..., MIME: ...]' placeholder.
    Content items of any other kind are ignored, and a message that produces no
    parts is left out of the result.
    """
    # fast-agent role -> google role; anything else is sent as "tool"
    role_lookup = {"user": "user", "assistant": "model"}
    converted: List[types.Content] = []

    for message in messages:
        message_parts: List[types.Part] = []

        for item in message.content:
            if is_text_content(item):
                message_parts.append(types.Part.from_text(text=get_text(item) or ""))
                continue

            if is_image_content(item):
                assert isinstance(item, ImageContent)
                raw_image = base64.b64decode(get_image_data(item) or "")
                message_parts.append(
                    types.Part.from_bytes(mime_type=item.mimeType, data=raw_image)
                )
                continue

            if not is_resource_content(item):
                continue

            assert isinstance(item, EmbeddedResource)
            resource = item.resource
            is_pdf_blob = (
                resource.mimeType == "application/pdf"
                and hasattr(resource, "blob")
                and isinstance(resource, BlobResourceContents)
            )
            if is_pdf_blob:
                raw_pdf = base64.b64decode(resource.blob)
                message_parts.append(
                    types.Part.from_bytes(
                        mime_type=resource.mimeType or "application/pdf",
                        data=raw_pdf,
                    )
                )
                continue

            # Prefer the resource's own text when it has any
            embedded_text = getattr(resource, "text", None)
            if embedded_text is not None:
                message_parts.append(types.Part.from_text(text=embedded_text))
                continue

            # Fallback for other binary types or types without direct text
            uri_str = getattr(resource, "uri", "unknown_uri")
            mime_str = getattr(resource, "mimeType", "unknown_mime")
            message_parts.append(
                types.Part.from_text(text=f"[Resource: {uri_str}, MIME: {mime_str}]")
            )

        if not message_parts:
            continue

        google_role = role_lookup.get(message.role, "tool")
        converted.append(types.Content(role=google_role, parts=message_parts))

    return converted
|
143
|
+
|
144
|
+
def convert_to_google_tools(self, tools: List[ToolDefinition]) -> List[types.Tool]:
    """
    Wraps each fast-agent ToolDefinition in its own google.genai types.Tool.

    The tool's input schema is passed through _clean_schema_for_google before being
    turned into a types.Schema; a missing description is sent as an empty string.
    """
    return [
        types.Tool(
            function_declarations=[
                types.FunctionDeclaration(
                    name=definition.name,
                    description=definition.description or "",
                    parameters=types.Schema(
                        **self._clean_schema_for_google(definition.inputSchema)
                    ),
                )
            ]
        )
        for definition in tools
    ]
|
158
|
+
|
159
|
+
def convert_from_google_content(
    self, content: types.Content
) -> List[TextContent | ImageContent | EmbeddedResource | CallToolRequestParams]:
    """
    Converts google.genai types.Content from a model response to a list of
    fast-agent content types or tool call requests.

    Parts with text become TextContent; parts carrying a function call become
    CallToolRequestParams. Other parts are skipped. A content object without
    parts yields an empty list.
    """
    fast_agent_parts: List[
        TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
    ] = []
    # google.genai declares `parts` as optional, so a response may carry None here;
    # treat that the same as "no parts" instead of failing on iteration.
    for part in content.parts or []:
        if part.text:
            fast_agent_parts.append(TextContent(type="text", text=part.text))
        elif part.function_call:
            fast_agent_parts.append(
                CallToolRequestParams(
                    name=part.function_call.name,
                    arguments=part.function_call.args,
                )
            )
    return fast_agent_parts
|
180
|
+
|
181
|
+
def convert_from_google_function_call(
    self, function_call: types.FunctionCall
) -> CallToolRequest:
    """
    Wraps a google.genai types.FunctionCall in a fast-agent "tools/call" CallToolRequest,
    carrying over the function name and its arguments.
    """
    call_params = CallToolRequestParams(
        name=function_call.name,
        arguments=function_call.args,
    )
    return CallToolRequest(method="tools/call", params=call_params)
|
194
|
+
|
195
|
+
def convert_function_results_to_google(
    self, tool_results: List[Tuple[str, CallToolResult]]
) -> List[types.Content]:
    """
    Produces one google.genai types.Content with role 'tool' per (tool name, result) pair.

    Text output from the result (including resource text, placeholders for unhandled
    resources and decode error notes) is joined with newlines into the 'text_content'
    field of the function response. Decoded images and PDF blobs are appended as
    extra byte parts after the function response part.
    """
    responses: List[types.Content] = []

    for tool_name, tool_result in tool_results:
        text_chunks: List[str] = []
        binary_parts: List[types.Part] = []

        for item in tool_result.content:
            if is_text_content(item):
                text_chunks.append(get_text(item) or "")  # never append None
                continue

            if is_image_content(item):
                assert isinstance(item, ImageContent)
                try:
                    raw_image = base64.b64decode(get_image_data(item) or "")
                    binary_parts.append(
                        types.Part.from_bytes(data=raw_image, mime_type=item.mimeType)
                    )
                except Exception as e:
                    text_chunks.append(f"[Error processing image from tool result: {e}]")
                continue

            if not is_resource_content(item):
                # Other content types are skipped for now
                continue

            assert isinstance(item, EmbeddedResource)
            resource = item.resource
            is_pdf_blob = (
                resource.mimeType == "application/pdf"
                and hasattr(resource, "blob")
                and isinstance(resource, BlobResourceContents)
            )
            if is_pdf_blob:
                try:
                    raw_pdf = base64.b64decode(resource.blob)
                    binary_parts.append(
                        types.Part.from_bytes(
                            data=raw_pdf,
                            mime_type=resource.mimeType or "application/pdf",
                        )
                    )
                except Exception as e:
                    text_chunks.append(f"[Error processing PDF from tool result: {e}]")
                continue

            # Prefer the resource's own text when it has any
            embedded_text = getattr(resource, "text", None)
            if embedded_text is not None:
                text_chunks.append(embedded_text)
                continue

            uri_str = getattr(resource, "uri", "unknown_uri")
            mime_str = getattr(resource, "mimeType", "unknown_mime")
            text_chunks.append(f"[Unhandled Resource in Tool: {uri_str}, MIME: {mime_str}]")

        payload: Dict[str, Any] = {"tool_name": tool_name}
        if text_chunks:
            payload["text_content"] = "\n".join(text_chunks)

        # The function response part always comes first; media parts (if any) follow it.
        response_part = types.Part.from_function_response(name=tool_name, response=payload)
        responses.append(types.Content(role="tool", parts=[response_part, *binary_parts]))

    return responses
|
292
|
+
|
293
|
+
def convert_request_params_to_google_config(
    self, request_params: RequestParams
) -> types.GenerateContentConfig:
    """
    Maps fast-agent RequestParams onto a google.genai types.GenerateContentConfig.

    Only values that are not None are forwarded. temperature, maxTokens and
    systemPrompt are read directly; topK, topP, stopSequences, presencePenalty and
    frequencyPenalty are forwarded only when the attribute exists on request_params.
    """
    config_args: Dict[str, Any] = {}

    def forward(option: str, value: Any) -> None:
        # Leave unset values out so GenerateContentConfig keeps its own defaults
        if value is not None:
            config_args[option] = value

    forward("temperature", request_params.temperature)
    forward("max_output_tokens", request_params.maxTokens)

    # Attributes that may be absent from RequestParams -> config keyword
    optional_fields = (
        ("topK", "top_k"),
        ("topP", "top_p"),
        ("stopSequences", "stop_sequences"),
        ("presencePenalty", "presence_penalty"),
        ("frequencyPenalty", "frequency_penalty"),
    )
    for attribute, option in optional_fields:
        forward(option, getattr(request_params, attribute, None))

    forward("system_instruction", request_params.systemPrompt)
    return types.GenerateContentConfig(**config_args)
|
323
|
+
|
324
|
+
def convert_from_google_content_list(
    self, contents: List[types.Content]
) -> List[PromptMessageMultipart]:
    """
    Converts each google.genai types.Content in order into a fast-agent
    PromptMessageMultipart and returns them as a list.
    """
    converted: List[PromptMessageMultipart] = []
    for entry in contents:
        converted.append(self._convert_from_google_content(entry))
    return converted
|
331
|
+
|
332
|
+
def _convert_from_google_content(self, content: types.Content) -> PromptMessageMultipart:
    """
    Converts a single google.genai types.Content to a fast-agent PromptMessageMultipart.

    A 'model' content containing any function call is returned as an assistant
    message with empty content. Otherwise text parts become TextContent, function
    responses become TextContent holding str() of the response, and file_data parts
    become an EmbeddedResource whose text is a '[Resource: ..., MIME: ...]' placeholder.
    The role is 'user' for google role 'user' and 'assistant' for everything else.
    """
    # Imported locally: TextResourceContents is needed only for the file_data branch.
    from mcp.types import TextResourceContents

    # google.genai declares `parts` as optional; treat None as "no parts".
    parts = content.parts or []

    if content.role == "model" and any(part.function_call for part in parts):
        return PromptMessageMultipart(role="assistant", content=[])

    fast_agent_parts: List[
        TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
    ] = []
    for part in parts:
        if part.text:
            fast_agent_parts.append(TextContent(type="text", text=part.text))
        elif part.function_response:
            response_text = str(part.function_response.response)
            fast_agent_parts.append(TextContent(type="text", text=response_text))
        elif part.file_data:
            # EmbeddedResource.resource must be resource contents (uri / mimeType / text),
            # not a TextContent message block, which has no uri and requires type="text".
            fast_agent_parts.append(
                EmbeddedResource(
                    type="resource",
                    resource=TextResourceContents(
                        uri=part.file_data.file_uri,
                        mimeType=part.file_data.mime_type,
                        text=f"[Resource: {part.file_data.file_uri}, MIME: {part.file_data.mime_type}]",
                    ),
                )
            )

    fast_agent_role = "user" if content.role == "user" else "assistant"
    return PromptMessageMultipart(role=fast_agent_role, content=fast_agent_parts)
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Helper functions for type-safe server config access."""
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Optional
|
4
|
+
|
5
|
+
from mcp import ClientSession
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from mcp_agent.config import MCPServerSettings
|
9
|
+
|
10
|
+
|
11
|
+
def get_server_config(ctx: ClientSession) -> Optional["MCPServerSettings"]:
    """Return the server config attached to ``ctx.session``, or None.

    A config is returned only when ``ctx`` has a ``session`` attribute, that session
    is an MCPAgentClientSession, and its ``server_config`` is truthy.
    """
    # Deferred import to avoid a circular import with the client session module
    from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession

    session = getattr(ctx, "session", None)
    if not isinstance(session, MCPAgentClientSession):
        return None
    return session.server_config or None
|
@@ -4,7 +4,7 @@ It adds logging and supports sampling requests.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
from datetime import timedelta
|
7
|
-
from typing import TYPE_CHECKING
|
7
|
+
from typing import TYPE_CHECKING
|
8
8
|
|
9
9
|
from mcp import ClientSession, ServerNotification
|
10
10
|
from mcp.shared.session import (
|
@@ -20,6 +20,7 @@ from pydantic import FileUrl
|
|
20
20
|
|
21
21
|
from mcp_agent.context_dependent import ContextDependent
|
22
22
|
from mcp_agent.logging.logger import get_logger
|
23
|
+
from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
|
23
24
|
from mcp_agent.mcp.sampling import sample
|
24
25
|
|
25
26
|
if TYPE_CHECKING:
|
@@ -31,24 +32,20 @@ logger = get_logger(__name__)
|
|
31
32
|
async def list_roots(ctx: ClientSession) -> ListRootsResult:
|
32
33
|
"""List roots callback that will be called by the MCP library."""
|
33
34
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
)
|
49
|
-
for root in ctx.session.server_config.roots
|
50
|
-
]
|
51
|
-
return ListRootsResult(roots=roots or [])
|
35
|
+
if server_config := get_server_config(ctx):
|
36
|
+
if server_config.roots:
|
37
|
+
roots = [
|
38
|
+
Root(
|
39
|
+
uri=FileUrl(
|
40
|
+
root.server_uri_alias or root.uri,
|
41
|
+
),
|
42
|
+
name=root.name,
|
43
|
+
)
|
44
|
+
for root in server_config.roots
|
45
|
+
]
|
46
|
+
return ListRootsResult(roots=roots)
|
47
|
+
|
48
|
+
return ListRootsResult(roots=[])
|
52
49
|
|
53
50
|
|
54
51
|
class MCPAgentClientSession(ClientSession, ContextDependent):
|
@@ -72,14 +69,43 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
|
|
72
69
|
self.session_server_name = kwargs.pop("server_name", None)
|
73
70
|
# Extract the notification callbacks if provided
|
74
71
|
self._tool_list_changed_callback = kwargs.pop("tool_list_changed_callback", None)
|
72
|
+
# Extract server_config if provided
|
73
|
+
self.server_config: MCPServerSettings | None = kwargs.pop("server_config", None)
|
74
|
+
# Extract agent_model if provided (for auto_sampling fallback)
|
75
|
+
self.agent_model: str | None = kwargs.pop("agent_model", None)
|
76
|
+
|
77
|
+
# Only register callbacks if the server_config has the relevant settings
|
78
|
+
list_roots_cb = list_roots if (self.server_config and self.server_config.roots) else None
|
79
|
+
|
80
|
+
# Register sampling callback if either:
|
81
|
+
# 1. Sampling is explicitly configured, OR
|
82
|
+
# 2. Application-level auto_sampling is enabled
|
83
|
+
sampling_cb = None
|
84
|
+
if self.server_config and hasattr(self.server_config, "sampling") and self.server_config.sampling:
|
85
|
+
# Explicit sampling configuration
|
86
|
+
sampling_cb = sample
|
87
|
+
elif self._should_enable_auto_sampling():
|
88
|
+
# Auto-sampling enabled at application level
|
89
|
+
sampling_cb = sample
|
90
|
+
|
75
91
|
super().__init__(
|
76
92
|
*args,
|
77
93
|
**kwargs,
|
78
|
-
list_roots_callback=
|
79
|
-
sampling_callback=
|
94
|
+
list_roots_callback=list_roots_cb,
|
95
|
+
sampling_callback=sampling_cb,
|
80
96
|
client_info=fast_agent,
|
81
97
|
)
|
82
|
-
|
98
|
+
|
99
|
+
def _should_enable_auto_sampling(self) -> bool:
    """Check if auto_sampling is enabled at the application level.

    Returns the ``auto_sampling`` attribute of the current application config,
    or True when the attribute is missing, when there is no context or config,
    or when the lookup raises.
    """
    try:
        # Imported inside the guarded block so an import failure also falls back to the default
        from mcp_agent.context import get_current_context

        context = get_current_context()
        if context and context.config:
            return getattr(context.config, "auto_sampling", True)
    except Exception as e:
        # Best-effort lookup: keep the permissive default, but leave a trace for debugging
        logger.debug(f"Could not read auto_sampling from application config: {e}")
    return True  # Default to True if can't access config
|
83
109
|
|
84
110
|
async def send_request(
|
85
111
|
self,
|
@@ -91,10 +117,11 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
|
|
91
117
|
logger.debug("send_request: request=", data=request.model_dump())
|
92
118
|
try:
|
93
119
|
result = await super().send_request(
|
94
|
-
request,
|
95
|
-
result_type,
|
120
|
+
request=request,
|
121
|
+
result_type=result_type,
|
96
122
|
request_read_timeout_seconds=request_read_timeout_seconds,
|
97
123
|
progress_callback=progress_callback,
|
124
|
+
metadata=None,
|
98
125
|
)
|
99
126
|
logger.debug("send_request: response=", data=result.model_dump())
|
100
127
|
return result
|
mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -215,13 +215,20 @@ class MCPAggregator(ContextDependent):
|
|
215
215
|
)
|
216
216
|
|
217
217
|
# Create a wrapper to capture the parameters for the client session
|
218
|
-
def session_factory(read_stream, write_stream, read_timeout):
|
218
|
+
def session_factory(read_stream, write_stream, read_timeout, **kwargs):
|
219
|
+
# Get agent's model if this aggregator is part of an agent
|
220
|
+
agent_model = None
|
221
|
+
if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
|
222
|
+
agent_model = self.config.model
|
223
|
+
|
219
224
|
return MCPAgentClientSession(
|
220
225
|
read_stream,
|
221
226
|
write_stream,
|
222
227
|
read_timeout,
|
223
228
|
server_name=server_name,
|
229
|
+
agent_model=agent_model,
|
224
230
|
tool_list_changed_callback=self._handle_tool_list_changed,
|
231
|
+
**kwargs # Pass through any additional kwargs like server_config
|
225
232
|
)
|
226
233
|
|
227
234
|
await self._persistent_connection_manager.get_server(
|
@@ -269,13 +276,20 @@ class MCPAggregator(ContextDependent):
|
|
269
276
|
prompts = await fetch_prompts(server_connection.session, server_name)
|
270
277
|
else:
|
271
278
|
# Create a factory function for the client session
|
272
|
-
def create_session(read_stream, write_stream, read_timeout):
|
279
|
+
def create_session(read_stream, write_stream, read_timeout, **kwargs):
|
280
|
+
# Get agent's model if this aggregator is part of an agent
|
281
|
+
agent_model = None
|
282
|
+
if hasattr(self, 'config') and self.config and hasattr(self.config, 'model'):
|
283
|
+
agent_model = self.config.model
|
284
|
+
|
273
285
|
return MCPAgentClientSession(
|
274
286
|
read_stream,
|
275
287
|
write_stream,
|
276
288
|
read_timeout,
|
277
289
|
server_name=server_name,
|
290
|
+
agent_model=agent_model,
|
278
291
|
tool_list_changed_callback=self._handle_tool_list_changed,
|
292
|
+
**kwargs # Pass through any additional kwargs like server_config
|
279
293
|
)
|
280
294
|
|
281
295
|
async with gen_client(
|
@@ -797,12 +811,13 @@ class MCPAggregator(ContextDependent):
|
|
797
811
|
messages=[],
|
798
812
|
)
|
799
813
|
|
800
|
-
async def list_prompts(self, server_name: str | None = None) -> Mapping[str, List[Prompt]]:
|
814
|
+
async def list_prompts(self, server_name: str | None = None, agent_name: str | None = None) -> Mapping[str, List[Prompt]]:
|
801
815
|
"""
|
802
816
|
List available prompts from one or all servers.
|
803
817
|
|
804
818
|
:param server_name: Optional server name to list prompts from. If not provided,
|
805
819
|
lists prompts from all servers.
|
820
|
+
:param agent_name: Optional agent name (ignored at this level, used by multi-agent apps)
|
806
821
|
:return: Dictionary mapping server names to lists of Prompt objects
|
807
822
|
"""
|
808
823
|
if not self.initialized:
|
@@ -165,11 +165,12 @@ class ServerConnection:
|
|
165
165
|
else None
|
166
166
|
)
|
167
167
|
|
168
|
-
session = self._client_session_factory(
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
168
|
+
session = self._client_session_factory(
|
169
|
+
read_stream,
|
170
|
+
send_stream,
|
171
|
+
read_timeout,
|
172
|
+
server_config=self.server_config
|
173
|
+
)
|
173
174
|
|
174
175
|
self.session = session
|
175
176
|
|
mcp_agent/mcp/sampling.py
CHANGED
@@ -10,6 +10,7 @@ from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextConte
|
|
10
10
|
from mcp_agent.core.agent_types import AgentConfig
|
11
11
|
from mcp_agent.llm.sampling_converter import SamplingConverter
|
12
12
|
from mcp_agent.logging.logger import get_logger
|
13
|
+
from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
|
13
14
|
from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
|
14
15
|
|
15
16
|
if TYPE_CHECKING:
|
@@ -78,18 +79,47 @@ async def sample(mcp_ctx: ClientSession, params: CreateMessageRequestParams) ->
|
|
78
79
|
"""
|
79
80
|
model = None
|
80
81
|
try:
|
81
|
-
# Extract model from server config
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
)
|
89
|
-
|
82
|
+
# Extract model from server config using type-safe helper
|
83
|
+
server_config = get_server_config(mcp_ctx)
|
84
|
+
|
85
|
+
# First priority: explicitly configured sampling model
|
86
|
+
if server_config and hasattr(server_config, "sampling") and server_config.sampling:
|
87
|
+
model = server_config.sampling.model
|
88
|
+
|
89
|
+
# Second priority: auto_sampling fallback (if enabled at application level)
|
90
|
+
if model is None:
|
91
|
+
# Check if auto_sampling is enabled
|
92
|
+
auto_sampling_enabled = False
|
93
|
+
try:
|
94
|
+
from mcp_agent.context import get_current_context
|
95
|
+
app_context = get_current_context()
|
96
|
+
if app_context and app_context.config:
|
97
|
+
auto_sampling_enabled = getattr(app_context.config, 'auto_sampling', True)
|
98
|
+
except Exception as e:
|
99
|
+
logger.debug(f"Could not get application config: {e}")
|
100
|
+
auto_sampling_enabled = True # Default to enabled
|
101
|
+
|
102
|
+
if auto_sampling_enabled:
|
103
|
+
# Import here to avoid circular import
|
104
|
+
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
|
105
|
+
|
106
|
+
# Try agent's model first (from the session)
|
107
|
+
if (hasattr(mcp_ctx, 'session') and
|
108
|
+
isinstance(mcp_ctx.session, MCPAgentClientSession) and
|
109
|
+
mcp_ctx.session.agent_model):
|
110
|
+
model = mcp_ctx.session.agent_model
|
111
|
+
logger.debug(f"Using agent's model for sampling: {model}")
|
112
|
+
else:
|
113
|
+
# Fall back to system default model
|
114
|
+
try:
|
115
|
+
if app_context and app_context.config and app_context.config.default_model:
|
116
|
+
model = app_context.config.default_model
|
117
|
+
logger.debug(f"Using system default model for sampling: {model}")
|
118
|
+
except Exception as e:
|
119
|
+
logger.debug(f"Could not get system default model: {e}")
|
90
120
|
|
91
121
|
if model is None:
|
92
|
-
raise ValueError("No model configured")
|
122
|
+
raise ValueError("No model configured for sampling (server config, agent model, or system default)")
|
93
123
|
|
94
124
|
# Create an LLM instance
|
95
125
|
llm = create_sampling_llm(params, model)
|