fast-agent-mcp 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.29.dist-info}/METADATA +13 -8
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.29.dist-info}/RECORD +29 -23
- mcp_agent/agents/agent.py +1 -17
- mcp_agent/agents/base_agent.py +2 -0
- mcp_agent/app.py +1 -1
- mcp_agent/cli/commands/url_parser.py +7 -1
- mcp_agent/config.py +3 -0
- mcp_agent/context.py +7 -3
- mcp_agent/core/agent_app.py +7 -2
- mcp_agent/core/interactive_prompt.py +60 -53
- mcp_agent/llm/augmented_llm_slow.py +42 -0
- mcp_agent/llm/model_factory.py +74 -37
- mcp_agent/llm/provider_types.py +4 -3
- mcp_agent/llm/providers/augmented_llm_deepseek.py +49 -0
- mcp_agent/llm/providers/augmented_llm_google_native.py +459 -0
- mcp_agent/llm/providers/{augmented_llm_google.py → augmented_llm_google_oai.py} +2 -2
- mcp_agent/llm/providers/google_converter.py +365 -0
- mcp_agent/mcp/common.py +2 -2
- mcp_agent/mcp/helpers/server_config_helpers.py +23 -0
- mcp_agent/mcp/hf_auth.py +87 -0
- mcp_agent/mcp/mcp_agent_client_session.py +63 -60
- mcp_agent/mcp/mcp_aggregator.py +18 -3
- mcp_agent/mcp/mcp_connection_manager.py +6 -5
- mcp_agent/mcp/sampling.py +40 -10
- mcp_agent/mcp_server_registry.py +25 -7
- mcp_agent/tools/tool_definition.py +14 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.29.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.29.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.27.dist-info → fast_agent_mcp-0.2.29.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/providers/google_converter.py
ADDED
@@ -0,0 +1,365 @@
+import base64
+from typing import Any, Dict, List, Tuple
+
+# Import necessary types from google.genai
+from google.genai import types
+from mcp.types import (
+    BlobResourceContents,
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+    EmbeddedResource,
+    ImageContent,
+    TextContent,
+)
+
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.mcp.helpers.content_helpers import (
+    get_image_data,
+    get_text,
+    is_image_content,
+    is_resource_content,
+    is_text_content,
+)
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+from mcp_agent.tools.tool_definition import ToolDefinition
+
+
+class GoogleConverter:
+    """
+    Converts between fast-agent and google.genai data structures.
+    """
+
+    def _clean_schema_for_google(self, schema: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Recursively removes unsupported JSON schema keywords for google.genai.types.Schema.
+        Specifically removes 'additionalProperties', '$schema', 'exclusiveMaximum', and 'exclusiveMinimum'.
+        """
+        cleaned_schema = {}
+        unsupported_keys = {
+            "additionalProperties",
+            "$schema",
+            "exclusiveMaximum",
+            "exclusiveMinimum",
+        }
+        supported_string_formats = {"enum", "date-time"}
+
+        for key, value in schema.items():
+            if key in unsupported_keys:
+                continue  # Skip this key
+
+            if (
+                key == "format"
+                and schema.get("type") == "string"
+                and value not in supported_string_formats
+            ):
+                continue  # Remove unsupported string formats
+
+            if isinstance(value, dict):
+                cleaned_schema[key] = self._clean_schema_for_google(value)
+            elif isinstance(value, list):
+                cleaned_schema[key] = [
+                    self._clean_schema_for_google(item) if isinstance(item, dict) else item
+                    for item in value
+                ]
+            else:
+                cleaned_schema[key] = value
+        return cleaned_schema
+
+    def convert_to_google_content(
+        self, messages: List[PromptMessageMultipart]
+    ) -> List[types.Content]:
+        """
+        Converts a list of fast-agent PromptMessageMultipart to google.genai types.Content.
+        Handles different roles and content types (text, images, etc.).
+        """
+        google_contents: List[types.Content] = []
+        for message in messages:
+            parts: List[types.Part] = []
+            for part_content in message.content:  # renamed part to part_content to avoid conflict
+                if is_text_content(part_content):
+                    parts.append(types.Part.from_text(text=get_text(part_content) or ""))
+                elif is_image_content(part_content):
+                    assert isinstance(part_content, ImageContent)
+                    image_bytes = base64.b64decode(get_image_data(part_content) or "")
+                    parts.append(
+                        types.Part.from_bytes(mime_type=part_content.mimeType, data=image_bytes)
+                    )
+                elif is_resource_content(part_content):
+                    assert isinstance(part_content, EmbeddedResource)
+                    if (
+                        "application/pdf" == part_content.resource.mimeType
+                        and hasattr(part_content.resource, "blob")
+                        and isinstance(part_content.resource, BlobResourceContents)
+                    ):
+                        pdf_bytes = base64.b64decode(part_content.resource.blob)
+                        parts.append(
+                            types.Part.from_bytes(
+                                mime_type=part_content.resource.mimeType or "application/pdf",
+                                data=pdf_bytes,
+                            )
+                        )
+                    else:
+                        # Check if the resource itself has text content
+                        resource_text = None
+                        if hasattr(part_content.resource, "text"):  # Direct text attribute
+                            resource_text = part_content.resource.text
+                        # Example: if EmbeddedResource wraps a TextContent-like object in its 'resource' field
+                        elif (
+                            hasattr(part_content.resource, "type")
+                            and part_content.resource.type == "text"
+                            and hasattr(part_content.resource, "text")
+                        ):
+                            resource_text = get_text(part_content.resource)
+
+                        if resource_text is not None:
+                            parts.append(types.Part.from_text(text=resource_text))
+                        else:
+                            # Fallback for other binary types or types without direct text
+                            uri_str = (
+                                part_content.resource.uri
+                                if hasattr(part_content.resource, "uri")
+                                else "unknown_uri"
+                            )
+                            mime_str = (
+                                part_content.resource.mimeType
+                                if hasattr(part_content.resource, "mimeType")
+                                else "unknown_mime"
+                            )
+                            parts.append(
+                                types.Part.from_text(
+                                    text=f"[Resource: {uri_str}, MIME: {mime_str}]"
+                                )
+                            )
+
+            if parts:
+                google_role = (
+                    "user"
+                    if message.role == "user"
+                    else ("model" if message.role == "assistant" else "tool")
+                )
+                google_contents.append(types.Content(role=google_role, parts=parts))
+        return google_contents
+
+    def convert_to_google_tools(self, tools: List[ToolDefinition]) -> List[types.Tool]:
+        """
+        Converts a list of fast-agent ToolDefinition to google.genai types.Tool.
+        """
+        google_tools: List[types.Tool] = []
+        for tool in tools:
+            cleaned_input_schema = self._clean_schema_for_google(tool.inputSchema)
+            function_declaration = types.FunctionDeclaration(
+                name=tool.name,
+                description=tool.description if tool.description else "",
+                parameters=types.Schema(**cleaned_input_schema),
+            )
+            google_tools.append(types.Tool(function_declarations=[function_declaration]))
+        return google_tools
+
+    def convert_from_google_content(
+        self, content: types.Content
+    ) -> List[TextContent | ImageContent | EmbeddedResource | CallToolRequestParams]:
+        """
+        Converts google.genai types.Content from a model response to a list of
+        fast-agent content types or tool call requests.
+        """
+        fast_agent_parts: List[
+            TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
+        ] = []
+
+        if content is None or not hasattr(content, 'parts') or content.parts is None:
+            return []  # Google API response 'content' object is None. Cannot extract parts.
+
+        for part in content.parts:
+            if part.text:
+                fast_agent_parts.append(TextContent(type="text", text=part.text))
+            elif part.function_call:
+                fast_agent_parts.append(
+                    CallToolRequestParams(
+                        name=part.function_call.name,
+                        arguments=part.function_call.args,
+                    )
+                )
+        return fast_agent_parts
+
+    def convert_from_google_function_call(
+        self, function_call: types.FunctionCall
+    ) -> CallToolRequest:
+        """
+        Converts a single google.genai types.FunctionCall to a fast-agent CallToolRequest.
+        """
+        return CallToolRequest(
+            method="tools/call",
+            params=CallToolRequestParams(
+                name=function_call.name,
+                arguments=function_call.args,
+            ),
+        )
+
+    def convert_function_results_to_google(
+        self, tool_results: List[Tuple[str, CallToolResult]]
+    ) -> List[types.Content]:
+        """
+        Converts a list of fast-agent tool results to google.genai types.Content
+        with role 'tool'. Handles multimodal content in tool results.
+        """
+        google_tool_response_contents: List[types.Content] = []
+        for tool_name, tool_result in tool_results:
+            current_content_parts: List[types.Part] = []
+            textual_outputs: List[str] = []
+            media_parts: List[types.Part] = []
+
+            for item in tool_result.content:
+                if is_text_content(item):
+                    textual_outputs.append(get_text(item) or "")  # Ensure no None is added
+                elif is_image_content(item):
+                    assert isinstance(item, ImageContent)
+                    try:
+                        image_bytes = base64.b64decode(get_image_data(item) or "")
+                        media_parts.append(
+                            types.Part.from_bytes(data=image_bytes, mime_type=item.mimeType)
+                        )
+                    except Exception as e:
+                        textual_outputs.append(f"[Error processing image from tool result: {e}]")
+                elif is_resource_content(item):
+                    assert isinstance(item, EmbeddedResource)
+                    if (
+                        "application/pdf" == item.resource.mimeType
+                        and hasattr(item.resource, "blob")
+                        and isinstance(item.resource, BlobResourceContents)
+                    ):
+                        try:
+                            pdf_bytes = base64.b64decode(item.resource.blob)
+                            media_parts.append(
+                                types.Part.from_bytes(
+                                    data=pdf_bytes,
+                                    mime_type=item.resource.mimeType or "application/pdf",
+                                )
+                            )
+                        except Exception as e:
+                            textual_outputs.append(f"[Error processing PDF from tool result: {e}]")
+                    else:
+                        # Check if the resource itself has text content
+                        resource_text = None
+                        if hasattr(item.resource, "text"):  # Direct text attribute
+                            resource_text = item.resource.text
+                        # Example: if EmbeddedResource wraps a TextContent-like object in its 'resource' field
+                        elif (
+                            hasattr(item.resource, "type")
+                            and item.resource.type == "text"
+                            and hasattr(item.resource, "text")
+                        ):
+                            resource_text = get_text(item.resource)
+
+                        if resource_text is not None:
+                            textual_outputs.append(resource_text)
+                        else:
+                            uri_str = (
+                                item.resource.uri
+                                if hasattr(item.resource, "uri")
+                                else "unknown_uri"
+                            )
+                            mime_str = (
+                                item.resource.mimeType
+                                if hasattr(item.resource, "mimeType")
+                                else "unknown_mime"
+                            )
+                            textual_outputs.append(
+                                f"[Unhandled Resource in Tool: {uri_str}, MIME: {mime_str}]"
+                            )
+                # Add handling for other content types if needed, for now they are skipped or become unhandled resource text
+
+            function_response_payload: Dict[str, Any] = {"tool_name": tool_name}
+            if textual_outputs:
+                function_response_payload["text_content"] = "\n".join(textual_outputs)
+
+            # Only add media_parts if there are some, otherwise Gemini might error on empty parts for function response
+            if media_parts:
+                # Create the main FunctionResponse part
+                fn_response_part = types.Part.from_function_response(
+                    name=tool_name, response=function_response_payload
+                )
+                current_content_parts.append(fn_response_part)
+                current_content_parts.extend(
+                    media_parts
+                )  # Add media parts after the main response part
+            else:  # If no media parts, the textual output (if any) is the sole content of the function response
+                fn_response_part = types.Part.from_function_response(
+                    name=tool_name, response=function_response_payload
+                )
+                current_content_parts.append(fn_response_part)
+
+            google_tool_response_contents.append(
+                types.Content(role="tool", parts=current_content_parts)
+            )
+        return google_tool_response_contents
+
+    def convert_request_params_to_google_config(
+        self, request_params: RequestParams
+    ) -> types.GenerateContentConfig:
+        """
+        Converts fast-agent RequestParams to google.genai types.GenerateContentConfig.
+        """
+        config_args: Dict[str, Any] = {}
+        if request_params.temperature is not None:
+            config_args["temperature"] = request_params.temperature
+        if request_params.maxTokens is not None:
+            config_args["max_output_tokens"] = request_params.maxTokens
+        if hasattr(request_params, "topK") and request_params.topK is not None:
+            config_args["top_k"] = request_params.topK
+        if hasattr(request_params, "topP") and request_params.topP is not None:
+            config_args["top_p"] = request_params.topP
+        if hasattr(request_params, "stopSequences") and request_params.stopSequences is not None:
+            config_args["stop_sequences"] = request_params.stopSequences
+        if (
+            hasattr(request_params, "presencePenalty")
+            and request_params.presencePenalty is not None
+        ):
+            config_args["presence_penalty"] = request_params.presencePenalty
+        if (
+            hasattr(request_params, "frequencyPenalty")
+            and request_params.frequencyPenalty is not None
+        ):
+            config_args["frequency_penalty"] = request_params.frequencyPenalty
+        if request_params.systemPrompt is not None:
+            config_args["system_instruction"] = request_params.systemPrompt
+        return types.GenerateContentConfig(**config_args)
+
+    def convert_from_google_content_list(
+        self, contents: List[types.Content]
+    ) -> List[PromptMessageMultipart]:
+        """
+        Converts a list of google.genai types.Content to a list of fast-agent PromptMessageMultipart.
+        """
+        return [self._convert_from_google_content(content) for content in contents]
+
+    def _convert_from_google_content(self, content: types.Content) -> PromptMessageMultipart:
+        """
+        Converts a single google.genai types.Content to a fast-agent PromptMessageMultipart.
+        """
+        if content.role == "model" and any(part.function_call for part in content.parts):
+            return PromptMessageMultipart(role="assistant", content=[])
+
+        fast_agent_parts: List[
+            TextContent | ImageContent | EmbeddedResource | CallToolRequestParams
+        ] = []
+        for part in content.parts:
+            if part.text:
+                fast_agent_parts.append(TextContent(type="text", text=part.text))
+            elif part.function_response:
+                response_text = str(part.function_response.response)
+                fast_agent_parts.append(TextContent(type="text", text=response_text))
+            elif part.file_data:
+                fast_agent_parts.append(
+                    EmbeddedResource(
+                        type="resource",
+                        resource=TextContent(
+                            uri=part.file_data.file_uri,
+                            mimeType=part.file_data.mime_type,
+                            text=f"[Resource: {part.file_data.file_uri}, MIME: {part.file_data.mime_type}]",
+                        ),
+                    )
+                )
+
+        fast_agent_role = "user" if content.role == "user" else "assistant"
+        return PromptMessageMultipart(role=fast_agent_role, content=fast_agent_parts)
mcp_agent/mcp/common.py
CHANGED
@@ -8,9 +8,9 @@ SEP = "-"
 
 def create_namespaced_name(server_name: str, resource_name: str) -> str:
     """Create a namespaced resource name from server and resource names"""
-    return f"{server_name}{SEP}{resource_name}"
+    return f"{server_name}{SEP}{resource_name}"[:64]
 
 
 def is_namespaced_name(name: str) -> bool:
     """Check if a name is already namespaced"""
-    return SEP in name
+    return SEP in name
mcp_agent/mcp/helpers/server_config_helpers.py
ADDED
@@ -0,0 +1,23 @@
+"""Helper functions for type-safe server config access."""
+
+from typing import TYPE_CHECKING, Optional
+
+from mcp import ClientSession
+
+if TYPE_CHECKING:
+    from mcp_agent.config import MCPServerSettings
+
+
+def get_server_config(ctx: ClientSession) -> Optional["MCPServerSettings"]:
+    """Extract server config from context if available.
+
+    Type guard helper that safely accesses server_config with proper type checking.
+    """
+    # Import here to avoid circular import
+    from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
+
+    if (hasattr(ctx, "session") and
+        isinstance(ctx.session, MCPAgentClientSession) and
+        ctx.session.server_config):
+        return ctx.session.server_config
+    return None
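A brief sketch of the intended call pattern, assumed from how the roots and sampling callbacks use it; the callback name here is hypothetical.

from mcp_agent.mcp.helpers.server_config_helpers import get_server_config

async def on_server_event(ctx) -> None:
    # Narrows the untyped callback context to MCPServerSettings, or None.
    server_config = get_server_config(ctx)
    if server_config is not None and server_config.roots:
        ...  # typed access to the server settings from here on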
mcp_agent/mcp/hf_auth.py
ADDED
@@ -0,0 +1,87 @@
+"""HuggingFace authentication utilities for MCP connections."""
+
+import os
+from typing import Dict, Optional
+from urllib.parse import urlparse
+
+
+def is_huggingface_url(url: str) -> bool:
+    """
+    Check if a URL is a HuggingFace URL that should receive HF_TOKEN authentication.
+
+    Args:
+        url: The URL to check
+
+    Returns:
+        True if the URL is a HuggingFace URL, False otherwise
+    """
+    try:
+        parsed = urlparse(url)
+        hostname = parsed.hostname
+        if hostname is None:
+            return False
+
+        # Check for HuggingFace domains
+        return hostname in {"hf.co", "huggingface.co"}
+    except Exception:
+        return False
+
+
+def get_hf_token_from_env() -> Optional[str]:
+    """
+    Get the HuggingFace token from the HF_TOKEN environment variable.
+
+    Returns:
+        The HF_TOKEN value if set, None otherwise
+    """
+    return os.environ.get("HF_TOKEN")
+
+
+def should_add_hf_auth(url: str, existing_headers: Optional[Dict[str, str]]) -> bool:
+    """
+    Determine if HuggingFace authentication should be added to the headers.
+
+    Args:
+        url: The URL to check
+        existing_headers: Existing headers dictionary (may be None)
+
+    Returns:
+        True if HF auth should be added, False otherwise
+    """
+    # Only add HF auth if:
+    # 1. URL is a HuggingFace URL
+    # 2. No existing Authorization header is set
+    # 3. HF_TOKEN environment variable is available
+
+    if not is_huggingface_url(url):
+        return False
+
+    if existing_headers and "Authorization" in existing_headers:
+        return False
+
+    return get_hf_token_from_env() is not None
+
+
+def add_hf_auth_header(url: str, headers: Optional[Dict[str, str]]) -> Optional[Dict[str, str]]:
+    """
+    Add HuggingFace authentication header if appropriate.
+
+    Args:
+        url: The URL to check
+        headers: Existing headers dictionary (may be None)
+
+    Returns:
+        Updated headers dictionary with HF auth if appropriate, or original headers
+    """
+    if not should_add_hf_auth(url, headers):
+        return headers
+
+    hf_token = get_hf_token_from_env()
+    if hf_token is None:
+        return headers
+
+    # Create new headers dict or copy existing one
+    result_headers = dict(headers) if headers else {}
+    result_headers["Authorization"] = f"Bearer {hf_token}"
+
+    return result_headers
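A usage sketch (not from the package) illustrating the behaviour of add_hf_auth_header; the token value is a placeholder.

import os

from mcp_agent.mcp.hf_auth import add_hf_auth_header

os.environ["HF_TOKEN"] = "hf_example_token"  # placeholder value

# HuggingFace URL with no existing auth header: Bearer token is injected
print(add_hf_auth_header("https://hf.co/mcp", None))
# {'Authorization': 'Bearer hf_example_token'}

# An existing Authorization header is never overwritten
print(add_hf_auth_header("https://hf.co/mcp", {"Authorization": "Bearer custom"}))
# {'Authorization': 'Bearer custom'}

# Non-HuggingFace URLs are left untouched
print(add_hf_auth_header("https://example.com/mcp", None))
# None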
mcp_agent/mcp/mcp_agent_client_session.py
CHANGED
@@ -4,22 +4,21 @@ It adds logging and supports sampling requests.
 """
 
 from datetime import timedelta
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING
 
 from mcp import ClientSession, ServerNotification
+from mcp.shared.message import MessageMetadata
 from mcp.shared.session import (
     ProgressFnT,
     ReceiveResultT,
-    RequestId,
-    SendNotificationT,
     SendRequestT,
-    SendResultT,
 )
-from mcp.types import
+from mcp.types import Implementation, ListRootsResult, Root, ToolListChangedNotification
 from pydantic import FileUrl
 
 from mcp_agent.context_dependent import ContextDependent
 from mcp_agent.logging.logger import get_logger
+from mcp_agent.mcp.helpers.server_config_helpers import get_server_config
 from mcp_agent.mcp.sampling import sample
 
 if TYPE_CHECKING:
@@ -31,24 +30,20 @@ logger = get_logger(__name__)
 async def list_roots(ctx: ClientSession) -> ListRootsResult:
     """List roots callback that will be called by the MCP library."""
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                )
-                for root in ctx.session.server_config.roots
-            ]
-    return ListRootsResult(roots=roots or [])
+    if server_config := get_server_config(ctx):
+        if server_config.roots:
+            roots = [
+                Root(
+                    uri=FileUrl(
+                        root.server_uri_alias or root.uri,
+                    ),
+                    name=root.name,
+                )
+                for root in server_config.roots
+            ]
+            return ListRootsResult(roots=roots)
+
+    return ListRootsResult(roots=[])
 
 
 class MCPAgentClientSession(ClientSession, ContextDependent):
@@ -72,53 +67,75 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
         self.session_server_name = kwargs.pop("server_name", None)
         # Extract the notification callbacks if provided
         self._tool_list_changed_callback = kwargs.pop("tool_list_changed_callback", None)
+        # Extract server_config if provided
+        self.server_config: MCPServerSettings | None = kwargs.pop("server_config", None)
+        # Extract agent_model if provided (for auto_sampling fallback)
+        self.agent_model: str | None = kwargs.pop("agent_model", None)
+
+        # Only register callbacks if the server_config has the relevant settings
+        list_roots_cb = list_roots if (self.server_config and self.server_config.roots) else None
+
+        # Register sampling callback if either:
+        # 1. Sampling is explicitly configured, OR
+        # 2. Application-level auto_sampling is enabled
+        sampling_cb = None
+        if (
+            self.server_config
+            and hasattr(self.server_config, "sampling")
+            and self.server_config.sampling
+        ):
+            # Explicit sampling configuration
+            sampling_cb = sample
+        elif self._should_enable_auto_sampling():
+            # Auto-sampling enabled at application level
+            sampling_cb = sample
+
         super().__init__(
             *args,
             **kwargs,
-            list_roots_callback=
-            sampling_callback=
+            list_roots_callback=list_roots_cb,
+            sampling_callback=sampling_cb,
             client_info=fast_agent,
         )
-
+
+    def _should_enable_auto_sampling(self) -> bool:
+        """Check if auto_sampling is enabled at the application level."""
+        try:
+            from mcp_agent.context import get_current_context
+
+            context = get_current_context()
+            if context and context.config:
+                return getattr(context.config, "auto_sampling", True)
+        except Exception:
+            pass
+        return True  # Default to True if can't access config
 
     async def send_request(
         self,
         request: SendRequestT,
         result_type: type[ReceiveResultT],
         request_read_timeout_seconds: timedelta | None = None,
+        metadata: MessageMetadata | None = None,
        progress_callback: ProgressFnT | None = None,
     ) -> ReceiveResultT:
         logger.debug("send_request: request=", data=request.model_dump())
         try:
             result = await super().send_request(
-                request,
-                result_type,
+                request=request,
+                result_type=result_type,
                 request_read_timeout_seconds=request_read_timeout_seconds,
+                metadata=metadata,
                 progress_callback=progress_callback,
             )
-            logger.debug(
+            logger.debug(
+                "send_request: response=",
+                data=result.model_dump() if result is not None else "no response returned",
+            )
             return result
         except Exception as e:
             logger.error(f"send_request failed: {str(e)}")
             raise
 
-    async def send_notification(self, notification: SendNotificationT) -> None:
-        logger.debug("send_notification:", data=notification.model_dump())
-        try:
-            return await super().send_notification(notification)
-        except Exception as e:
-            logger.error("send_notification failed", data=e)
-            raise
-
-    async def _send_response(
-        self, request_id: RequestId, response: SendResultT | ErrorData
-    ) -> None:
-        logger.debug(
-            f"send_response: request_id={request_id}, response=",
-            data=response.model_dump(),
-        )
-        return await super()._send_response(request_id, response)
-
     async def _received_notification(self, notification: ServerNotification) -> None:
         """
         Can be overridden by subclasses to handle a notification without needing
@@ -162,17 +179,3 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
             await self._tool_list_changed_callback(server_name)
         except Exception as e:
             logger.error(f"Error in tool list changed callback: {e}")
-
-    async def send_progress_notification(
-        self, progress_token: str | int, progress: float, total: float | None = None
-    ) -> None:
-        """
-        Sends a progress notification for a request that is currently being
-        processed.
-        """
-        logger.debug(
-            "send_progress_notification: progress_token={progress_token}, progress={progress}, total={total}"
-        )
-        return await super().send_progress_notification(
-            progress_token=progress_token, progress=progress, total=total
-        )
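The new callback selection in __init__ can be summarised with a small, self-contained illustration (illustrative only, restating the logic above rather than calling it): an explicit per-server sampling configuration always registers the sampling callback, otherwise the application-level auto_sampling flag (default True) decides.

# Illustrative re-statement of the selection rule; "sample" stands in for the
# real callback imported from mcp_agent.mcp.sampling.
def select_sampling_callback(server_config, auto_sampling_enabled: bool):
    if server_config is not None and getattr(server_config, "sampling", None):
        return "sample"  # explicit sampling configuration on the server entry
    if auto_sampling_enabled:
        return "sample"  # application-level auto_sampling fallback
    return None          # no sampling callback registered

assert select_sampling_callback(None, True) == "sample"
assert select_sampling_callback(None, False) is None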