fast-agent-mcp 0.2.56__py3-none-any.whl → 0.2.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fast-agent-mcp has been flagged as potentially problematic; see the registry's advisory page for details.
- {fast_agent_mcp-0.2.56.dist-info → fast_agent_mcp-0.2.58.dist-info}/METADATA +2 -2
- {fast_agent_mcp-0.2.56.dist-info → fast_agent_mcp-0.2.58.dist-info}/RECORD +25 -24
- mcp_agent/agents/agent.py +10 -3
- mcp_agent/agents/base_agent.py +7 -2
- mcp_agent/config.py +3 -0
- mcp_agent/core/agent_app.py +18 -6
- mcp_agent/core/enhanced_prompt.py +10 -2
- mcp_agent/core/fastagent.py +2 -0
- mcp_agent/core/request_params.py +5 -0
- mcp_agent/event_progress.py +3 -0
- mcp_agent/human_input/elicitation_form.py +45 -33
- mcp_agent/llm/augmented_llm.py +16 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +1 -0
- mcp_agent/llm/providers/augmented_llm_bedrock.py +890 -602
- mcp_agent/llm/providers/augmented_llm_google_native.py +1 -0
- mcp_agent/llm/providers/augmented_llm_openai.py +1 -0
- mcp_agent/llm/providers/bedrock_utils.py +216 -0
- mcp_agent/mcp/mcp_agent_client_session.py +105 -2
- mcp_agent/mcp/mcp_aggregator.py +92 -29
- mcp_agent/mcp/mcp_connection_manager.py +19 -0
- mcp_agent/resources/examples/mcp/elicitations/elicitation_forms_server.py +25 -3
- mcp_agent/ui/console_display.py +105 -15
- {fast_agent_mcp-0.2.56.dist-info → fast_agent_mcp-0.2.58.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.56.dist-info → fast_agent_mcp-0.2.58.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.56.dist-info → fast_agent_mcp-0.2.58.dist-info}/licenses/LICENSE +0 -0
|
@@ -160,6 +160,7 @@ class GoogleNativeAugmentedLLM(AugmentedLLM[types.Content, types.Content]):
|
|
|
160
160
|
AugmentedLLM.PARAM_USE_HISTORY, # Handled by AugmentedLLM base / this class's logic
|
|
161
161
|
AugmentedLLM.PARAM_MAX_ITERATIONS, # Handled by this class's loop
|
|
162
162
|
# Add any other OpenAI-specific params not applicable to google.genai
|
|
163
|
+
AugmentedLLM.PARAM_MCP_METADATA,
|
|
163
164
|
}.union(AugmentedLLM.BASE_EXCLUDE_FIELDS)
|
|
164
165
|
|
|
165
166
|
def __init__(self, *args, **kwargs) -> None:
|
|
@@ -59,6 +59,7 @@ class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletion
|
|
|
59
59
|
AugmentedLLM.PARAM_USE_HISTORY,
|
|
60
60
|
AugmentedLLM.PARAM_MAX_ITERATIONS,
|
|
61
61
|
AugmentedLLM.PARAM_TEMPLATE_VARS,
|
|
62
|
+
AugmentedLLM.PARAM_MCP_METADATA,
|
|
62
63
|
}
|
|
63
64
|
|
|
64
65
|
def __init__(self, provider: Provider = Provider.OPENAI, *args, **kwargs) -> None:
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Collection, Dict, List, Literal, Optional, Set, TypedDict, cast
|
|
4
|
+
|
|
5
|
+
# Lightweight, runtime-only loader for AWS Bedrock models.
|
|
6
|
+
# - Fetches once per process via boto3 (region from session; env override supported)
|
|
7
|
+
# - Memory cache only; no disk persistence
|
|
8
|
+
# - Provides filtering and optional prefixing (default 'bedrock.') for model IDs
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import boto3
|
|
12
|
+
except Exception: # pragma: no cover - import error path
|
|
13
|
+
boto3 = None # type: ignore[assignment]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Literal value sets used to type the Bedrock model summaries below.
Modality = Literal["TEXT", "IMAGE", "VIDEO", "SPEECH", "EMBEDDING"]
Lifecycle = Literal["ACTIVE", "LEGACY"]
InferenceType = Literal["ON_DEMAND", "PROVISIONED", "INFERENCE_PROFILE"]


class ModelSummary(TypedDict, total=False):
    """One entry of the ``modelSummaries`` list returned by ``list_foundation_models``.

    Every key is optional (``total=False``); readers in this module always use
    ``.get`` rather than indexing.
    """

    modelId: str
    modelName: str
    providerName: str
    inputModalities: List[Modality]
    outputModalities: List[Modality]
    responseStreamingSupported: bool
    customizationsSupported: List[str]
    inferenceTypesSupported: List[InferenceType]
    # Filtering reads the "status" key of this mapping to drop LEGACY models.
    modelLifecycle: Dict[str, Lifecycle]


# Resolved region name -> {model ID -> summary}; filled lazily, kept in memory only.
_MODELS_CACHE_BY_REGION: Dict[str, Dict[str, ModelSummary]] = {}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _resolve_region(region: Optional[str]) -> str:
    """Work out which AWS region to query for Bedrock models.

    Sources are tried in order and the first non-empty one wins:

    1. the explicit ``region`` argument;
    2. the ``BEDROCK_REGION`` environment variable;
    3. the default region of a new ``boto3.Session``.

    Raises:
        RuntimeError: if boto3 is needed (steps 1-2 gave nothing) but is not
            installed, or if the session has no region configured either.
    """
    import os

    # `or` short-circuits, so the environment is only consulted when no
    # (non-empty) region was passed in.
    explicit_or_env = region or os.getenv("BEDROCK_REGION")
    if explicit_or_env:
        return explicit_or_env

    if boto3 is None:
        raise RuntimeError(
            "boto3 is required to load Bedrock models. Install boto3 or provide a static list."
        )

    session_region = boto3.Session().region_name
    if not session_region:
        raise RuntimeError(
            "AWS region could not be resolved. Configure your AWS SSO/profile or set BEDROCK_REGION."
        )
    return session_region
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _strip_prefix(model_id: str, prefix: str) -> str:
    """Remove ``prefix`` from the start of ``model_id`` if it is there.

    An empty ``prefix`` leaves ``model_id`` unchanged.
    """
    return model_id.removeprefix(prefix)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _ensure_loaded(region: Optional[str] = None) -> Dict[str, ModelSummary]:
    """Return the model cache for a region, fetching it from Bedrock on first use.

    The cache maps model ID -> summary as returned by ``list_foundation_models``;
    entries without a ``modelId`` are skipped. The mapping is memoized per
    resolved region in ``_MODELS_CACHE_BY_REGION`` and reused until that entry
    is removed (``refresh_bedrock_models`` does this).

    Args:
        region: Region to query; ``None`` falls back to ``_resolve_region``'s rules.

    Returns:
        The (shared, cached) ``{modelId: ModelSummary}`` mapping for the region.

    Raises:
        RuntimeError: if the region cannot be resolved, boto3 is not installed,
            or the Bedrock API call fails. In the last case the underlying
            error is chained as ``__cause__``.
    """
    resolved_region = _resolve_region(region)
    cached = _MODELS_CACHE_BY_REGION.get(resolved_region)
    if cached is not None:
        return cached

    if boto3 is None:
        raise RuntimeError("boto3 is required to load Bedrock models. Install boto3.")

    try:
        client = boto3.client("bedrock", region_name=resolved_region)
        resp = client.list_foundation_models()
        summaries: List[ModelSummary] = resp.get("modelSummaries", [])  # type: ignore[assignment]
    except Exception as exc:  # keep error simple and actionable
        # Chain explicitly so the boto/botocore failure stays inspectable.
        raise RuntimeError(
            f"Failed to list Bedrock foundation models in region '{resolved_region}'. "
            "Ensure AWS credentials (SSO) and permissions (bedrock:ListFoundationModels) are configured. "
            f"Original error: {exc}"
        ) from exc

    cache = {s["modelId"]: s for s in summaries if s.get("modelId")}
    _MODELS_CACHE_BY_REGION[resolved_region] = cache
    return cache
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def refresh_bedrock_models(region: Optional[str] = None) -> None:
    """Re-fetch the Bedrock model list for a region, replacing any cached copy.

    The cached entry is discarded first and the list is reloaded straight away,
    so loader errors (``RuntimeError``) surface from this call. If the reload
    fails, the region is left uncached and the next lookup fetches again.
    """
    target_region = _resolve_region(region)
    if target_region in _MODELS_CACHE_BY_REGION:
        del _MODELS_CACHE_BY_REGION[target_region]
    _ensure_loaded(target_region)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _matches_modalities(model_modalities: List[Modality], requested: Collection[Modality]) -> bool:
    """Return True when every requested modality is listed for the model.

    An empty ``requested`` collection always matches.
    """
    wanted = set(requested)
    return wanted <= set(model_modalities)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def all_model_summaries(
    input_modalities: Optional[Collection[Modality]] = None,
    output_modalities: Optional[Collection[Modality]] = None,
    include_legacy: bool = False,
    providers: Optional[Collection[str]] = None,
    inference_types: Optional[Collection[InferenceType]] = None,
    direct_invocation_only: bool = True,
    region: Optional[str] = None,
) -> List[ModelSummary]:
    """Return filtered Bedrock model summaries.

    A summary is kept only if it passes every filter:

    * lifecycle: ``LEGACY`` models are dropped unless ``include_legacy`` is true;
    * providers: when given, ``providerName`` must be one of ``providers``;
    * direct invocation: when true, IDs containing more than one ``:`` are dropped;
    * modalities: every requested input/output modality must be listed by the model;
    * inference types: every requested inference type must be supported.

    Defaults: input_modalities={"TEXT"}, output_modalities={"TEXT"}, include_legacy=False,
    inference_types={"ON_DEMAND"}, direct_invocation_only=True. An empty collection for
    ``input_modalities``, ``output_modalities`` or ``inference_types`` disables that filter.

    Raises:
        RuntimeError: propagated from the loader (region resolution or the API call).
    """
    cache = _ensure_loaded(region)

    # Normalise the filter arguments once, before scanning the models.
    wanted_outputs: Set[Modality] = (
        {cast("Modality", "TEXT")} if output_modalities is None else set(output_modalities)
    )
    wanted_inputs: Set[Modality] = (
        {cast("Modality", "TEXT")} if input_modalities is None else set(input_modalities)
    )
    allowed_providers: Optional[Set[str]] = None if providers is None else set(providers)
    wanted_inference: Set[InferenceType] = (
        {cast("InferenceType", "ON_DEMAND")} if inference_types is None else set(inference_types)
    )

    def keep(summary: ModelSummary) -> bool:
        status = (summary.get("modelLifecycle") or {}).get("status")
        if status == "LEGACY" and not include_legacy:
            return False

        if allowed_providers is not None and summary.get("providerName") not in allowed_providers:
            return False

        # direct invocation only: exclude profile variants like :0:24k or :mm
        if direct_invocation_only and (summary.get("modelId") or "").count(":") > 1:
            return False

        inputs: List[Modality] = summary.get("inputModalities", [])  # type: ignore[assignment]
        outputs: List[Modality] = summary.get("outputModalities", [])  # type: ignore[assignment]
        if not _matches_modalities(inputs, wanted_inputs):
            return False
        if wanted_outputs and not _matches_modalities(outputs, wanted_outputs):
            return False

        supported: List[InferenceType] = summary.get("inferenceTypesSupported", [])  # type: ignore[assignment]
        return not wanted_inference or wanted_inference.issubset(set(supported))

    return [summary for summary in cache.values() if keep(summary)]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def all_bedrock_models(
    input_modalities: Optional[Collection[Modality]] = None,
    output_modalities: Optional[Collection[Modality]] = None,
    include_legacy: bool = False,
    providers: Optional[Collection[str]] = None,
    prefix: str = "bedrock.",
    inference_types: Optional[Collection[InferenceType]] = None,
    direct_invocation_only: bool = True,
    region: Optional[str] = None,
) -> List[str]:
    """Return model IDs (optionally prefixed) filtered by the given criteria.

    Filtering is delegated to ``all_model_summaries``, so the same defaults apply:
    input_modalities={"TEXT"}, output_modalities={"TEXT"}, exclude LEGACY,
    inference_types={"ON_DEMAND"}, direct_invocation_only=True.

    Args:
        prefix: String prepended to every returned ID (default ``"bedrock."``);
            pass ``""`` to get the plain Bedrock model IDs.

    Returns:
        Model IDs in the order the filtered summaries were produced. Summaries
        without a ``modelId`` are skipped.

    Raises:
        RuntimeError: propagated from the model-list loader.
    """
    summaries = all_model_summaries(
        input_modalities=input_modalities,
        output_modalities=output_modalities,
        include_legacy=include_legacy,
        providers=providers,
        inference_types=inference_types,
        direct_invocation_only=direct_invocation_only,
        region=region,
    )
    ids: List[str] = [model_id for model_id in (s.get("modelId") for s in summaries) if model_id]
    if prefix:
        return [f"{prefix}{model_id}" for model_id in ids]
    return ids
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def get_model_metadata(
    model_id: str, region: Optional[str] = None, prefix: str = "bedrock."
) -> Optional[ModelSummary]:
    """Look up the cached summary for a single Bedrock model.

    Args:
        model_id: Plain Bedrock model ID, or one carrying ``prefix`` (as returned
            by ``all_bedrock_models``).
        region: Region whose model list is consulted; ``None`` uses the default
            region resolution.
        prefix: Prefix stripped from ``model_id`` before the lookup. Defaults to
            ``"bedrock."``; pass the same prefix given to ``all_bedrock_models``
            when a custom one was used there.

    Returns:
        The model's summary, or ``None`` if it is not listed for the region.

    Raises:
        RuntimeError: propagated from the model-list loader.
    """
    cache = _ensure_loaded(region)
    # Accept either prefixed or plain model IDs.
    return cache.get(_strip_prefix(model_id, prefix))
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def list_providers(region: Optional[str] = None) -> List[str]:
    """Return the distinct provider names for a region's models, sorted alphabetically.

    Summaries with a missing or empty ``providerName`` are ignored.
    """
    names = set()
    for summary in _ensure_loaded(region).values():
        provider = summary.get("providerName")
        if provider:
            names.add(provider)
    return sorted(names)  # type: ignore[arg-type]
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
__all__ = [
    # Type aliases and the summary TypedDict used in the public signatures.
    # InferenceType is exported alongside Modality/Lifecycle because the public
    # filter functions accept it.
    "Modality",
    "Lifecycle",
    "InferenceType",
    "ModelSummary",
    # Query and cache-management functions.
    "all_bedrock_models",
    "all_model_summaries",
    "get_model_metadata",
    "list_providers",
    "refresh_bedrock_models",
]
|
|
@@ -14,8 +14,17 @@ from mcp.shared.session import (
|
|
|
14
14
|
SendRequestT,
|
|
15
15
|
)
|
|
16
16
|
from mcp.types import (
|
|
17
|
+
CallToolRequest,
|
|
18
|
+
CallToolRequestParams,
|
|
19
|
+
CallToolResult,
|
|
20
|
+
GetPromptRequest,
|
|
21
|
+
GetPromptRequestParams,
|
|
22
|
+
GetPromptResult,
|
|
17
23
|
Implementation,
|
|
18
24
|
ListRootsResult,
|
|
25
|
+
ReadResourceRequest,
|
|
26
|
+
ReadResourceRequestParams,
|
|
27
|
+
ReadResourceResult,
|
|
19
28
|
Root,
|
|
20
29
|
ToolListChangedNotification,
|
|
21
30
|
)
|
|
@@ -180,8 +189,17 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
|
|
|
180
189
|
)
|
|
181
190
|
return result
|
|
182
191
|
except Exception as e:
|
|
183
|
-
|
|
184
|
-
|
|
192
|
+
# Handle connection errors cleanly
|
|
193
|
+
from anyio import ClosedResourceError
|
|
194
|
+
|
|
195
|
+
if isinstance(e, ClosedResourceError):
|
|
196
|
+
# Show clean offline message and convert to ConnectionError
|
|
197
|
+
from mcp_agent import console
|
|
198
|
+
console.console.print(f"[dim red]MCP server {self.session_server_name} offline[/dim red]")
|
|
199
|
+
raise ConnectionError(f"MCP server {self.session_server_name} offline") from e
|
|
200
|
+
else:
|
|
201
|
+
logger.error(f"send_request failed: {str(e)}")
|
|
202
|
+
raise
|
|
185
203
|
|
|
186
204
|
async def _received_notification(self, notification: ServerNotification) -> None:
|
|
187
205
|
"""
|
|
@@ -226,3 +244,88 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
|
|
|
226
244
|
await self._tool_list_changed_callback(server_name)
|
|
227
245
|
except Exception as e:
|
|
228
246
|
logger.error(f"Error in tool list changed callback: {e}")
|
|
247
|
+
|
|
248
|
+
def _merge_request_meta(self, _meta: dict, kwargs: dict) -> dict:
    """Build the ``_meta`` payload for an outgoing request.

    Starts from any ``meta`` already present in ``kwargs`` (a pydantic model or a
    plain dict) so existing fields such as ``progressToken`` are preserved. It
    then overlays ``_meta``, whose keys win on conflict. The result is validated
    through ``RequestParams.Meta`` and returned as a plain dict.
    """
    from mcp.types import RequestParams

    existing_meta = kwargs.get("meta")
    if hasattr(existing_meta, "model_dump"):
        merged = existing_meta.model_dump()
    elif isinstance(existing_meta, dict):
        merged = dict(existing_meta)
    else:
        merged = {}
    merged.update(_meta)
    return RequestParams.Meta(**merged).model_dump()

async def call_tool(
    self,
    name: str,
    arguments: dict | None = None,
    _meta: dict | None = None,
    **kwargs,
) -> CallToolResult:
    """Call a tool, optionally attaching MCP request metadata (``params._meta``).

    Without ``_meta`` this defers to ``ClientSession.call_tool`` unchanged. With
    ``_meta`` the request is built here so the metadata can be embedded. The
    ``read_timeout_seconds`` and ``progress_callback`` keyword arguments are
    still honoured in that case (previously they were silently dropped).
    """
    if not _meta:
        return await super().call_tool(name, arguments, **kwargs)

    # Validate through the "_meta" alias so the metadata lands in params._meta
    # rather than as an unrelated extra field.
    params = CallToolRequestParams.model_validate(
        {"name": name, "arguments": arguments, "_meta": self._merge_request_meta(_meta, kwargs)}
    )
    request = CallToolRequest(method="tools/call", params=params)
    # NOTE(review): assumes send_request takes the upstream BaseSession keywords
    # request_read_timeout_seconds / progress_callback — confirm against the override.
    return await self.send_request(
        request,
        CallToolResult,
        request_read_timeout_seconds=kwargs.get("read_timeout_seconds"),
        progress_callback=kwargs.get("progress_callback"),
    )

async def read_resource(self, uri: str, _meta: dict | None = None, **kwargs) -> ReadResourceResult:
    """Read a resource, optionally attaching MCP request metadata (``params._meta``).

    Without ``_meta`` this defers to ``ClientSession.read_resource`` unchanged.
    With ``_meta``, only a ``meta`` entry in ``kwargs`` is used (merged into the
    metadata); other keyword arguments are not forwarded.
    """
    if not _meta:
        return await super().read_resource(uri, **kwargs)

    # Same alias-based construction as call_tool so the metadata is sent as _meta.
    params = ReadResourceRequestParams.model_validate(
        {"uri": uri, "_meta": self._merge_request_meta(_meta, kwargs)}
    )
    request = ReadResourceRequest(method="resources/read", params=params)
    return await self.send_request(request, ReadResourceResult)

async def get_prompt(
    self,
    name: str,
    arguments: dict | None = None,
    _meta: dict | None = None,
    **kwargs,
) -> GetPromptResult:
    """Get a prompt, optionally attaching MCP request metadata (``params._meta``).

    Without ``_meta`` this defers to ``ClientSession.get_prompt`` unchanged.
    With ``_meta``, only a ``meta`` entry in ``kwargs`` is used (merged into the
    metadata); other keyword arguments are not forwarded.
    """
    if not _meta:
        return await super().get_prompt(name, arguments, **kwargs)

    # Same alias-based construction as call_tool so the metadata is sent as _meta.
    params = GetPromptRequestParams.model_validate(
        {"name": name, "arguments": arguments, "_meta": self._merge_request_meta(_meta, kwargs)}
    )
    request = GetPromptRequest(method="prompts/get", params=params)
    return await self.send_request(request, GetPromptResult)
|
mcp_agent/mcp/mcp_aggregator.py
CHANGED
|
@@ -139,7 +139,10 @@ class MCPAggregator(ContextDependent):
|
|
|
139
139
|
|
|
140
140
|
def _create_progress_callback(self, server_name: str, tool_name: str) -> "ProgressFnT":
|
|
141
141
|
"""Create a progress callback function for tool execution."""
|
|
142
|
-
|
|
142
|
+
|
|
143
|
+
async def progress_callback(
|
|
144
|
+
progress: float, total: float | None, message: str | None
|
|
145
|
+
) -> None:
|
|
143
146
|
"""Handle progress notifications from MCP tool execution."""
|
|
144
147
|
logger.info(
|
|
145
148
|
"Tool progress update",
|
|
@@ -153,6 +156,7 @@ class MCPAggregator(ContextDependent):
|
|
|
153
156
|
"details": message or "", # Put the message in details column
|
|
154
157
|
},
|
|
155
158
|
)
|
|
159
|
+
|
|
156
160
|
return progress_callback
|
|
157
161
|
|
|
158
162
|
async def close(self) -> None:
|
|
@@ -508,12 +512,28 @@ class MCPAggregator(ContextDependent):
|
|
|
508
512
|
async def try_execute(client: ClientSession):
|
|
509
513
|
try:
|
|
510
514
|
method = getattr(client, method_name)
|
|
515
|
+
|
|
516
|
+
# Get metadata from context for tool, resource, and prompt calls
|
|
517
|
+
metadata = None
|
|
518
|
+
if method_name in ["call_tool", "read_resource", "get_prompt"]:
|
|
519
|
+
from mcp_agent.llm.augmented_llm import _mcp_metadata_var
|
|
520
|
+
|
|
521
|
+
metadata = _mcp_metadata_var.get()
|
|
522
|
+
|
|
523
|
+
# Prepare kwargs
|
|
524
|
+
kwargs = method_args or {}
|
|
525
|
+
if metadata:
|
|
526
|
+
kwargs["_meta"] = metadata
|
|
527
|
+
|
|
511
528
|
# For call_tool method, check if we need to add progress_callback
|
|
512
529
|
if method_name == "call_tool" and progress_callback:
|
|
513
530
|
# The call_tool method signature includes progress_callback parameter
|
|
514
|
-
return await method(
|
|
531
|
+
return await method(progress_callback=progress_callback, **kwargs)
|
|
515
532
|
else:
|
|
516
|
-
return await method(**
|
|
533
|
+
return await method(**(kwargs or {}))
|
|
534
|
+
except ConnectionError:
|
|
535
|
+
# Let ConnectionError pass through for reconnection logic
|
|
536
|
+
raise
|
|
517
537
|
except Exception as e:
|
|
518
538
|
error_msg = (
|
|
519
539
|
f"Failed to {method_name} '{operation_name}' on server '{server_name}': {e}"
|
|
@@ -525,34 +545,77 @@ class MCPAggregator(ContextDependent):
|
|
|
525
545
|
# Re-raise the original exception to propagate it
|
|
526
546
|
raise e
|
|
527
547
|
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
data={
|
|
537
|
-
"progress_action": ProgressAction.STARTING,
|
|
538
|
-
"server_name": server_name,
|
|
539
|
-
"agent_name": self.agent_name,
|
|
540
|
-
},
|
|
541
|
-
)
|
|
542
|
-
async with gen_client(
|
|
543
|
-
server_name, server_registry=self.context.server_registry
|
|
544
|
-
) as client:
|
|
545
|
-
result = await try_execute(client)
|
|
548
|
+
# Try initial execution
|
|
549
|
+
try:
|
|
550
|
+
if self.connection_persistence:
|
|
551
|
+
server_connection = await self._persistent_connection_manager.get_server(
|
|
552
|
+
server_name, client_session_factory=MCPAgentClientSession
|
|
553
|
+
)
|
|
554
|
+
return await try_execute(server_connection.session)
|
|
555
|
+
else:
|
|
546
556
|
logger.debug(
|
|
547
|
-
f"
|
|
557
|
+
f"Creating temporary connection to server: {server_name}",
|
|
548
558
|
data={
|
|
549
|
-
"progress_action": ProgressAction.
|
|
559
|
+
"progress_action": ProgressAction.STARTING,
|
|
550
560
|
"server_name": server_name,
|
|
551
561
|
"agent_name": self.agent_name,
|
|
552
562
|
},
|
|
553
563
|
)
|
|
564
|
+
async with gen_client(
|
|
565
|
+
server_name, server_registry=self.context.server_registry
|
|
566
|
+
) as client:
|
|
567
|
+
result = await try_execute(client)
|
|
568
|
+
logger.debug(
|
|
569
|
+
f"Closing temporary connection to server: {server_name}",
|
|
570
|
+
data={
|
|
571
|
+
"progress_action": ProgressAction.SHUTDOWN,
|
|
572
|
+
"server_name": server_name,
|
|
573
|
+
"agent_name": self.agent_name,
|
|
574
|
+
},
|
|
575
|
+
)
|
|
576
|
+
return result
|
|
577
|
+
except ConnectionError:
|
|
578
|
+
# Server offline - attempt reconnection
|
|
579
|
+
from mcp_agent import console
|
|
580
|
+
|
|
581
|
+
console.console.print(
|
|
582
|
+
f"[dim yellow]MCP server {server_name} reconnecting...[/dim yellow]"
|
|
583
|
+
)
|
|
584
|
+
|
|
585
|
+
try:
|
|
586
|
+
if self.connection_persistence:
|
|
587
|
+
# Force disconnect and create fresh connection
|
|
588
|
+
await self._persistent_connection_manager.disconnect_server(server_name)
|
|
589
|
+
import asyncio
|
|
590
|
+
|
|
591
|
+
await asyncio.sleep(0.1)
|
|
592
|
+
|
|
593
|
+
server_connection = await self._persistent_connection_manager.get_server(
|
|
594
|
+
server_name, client_session_factory=MCPAgentClientSession
|
|
595
|
+
)
|
|
596
|
+
result = await try_execute(server_connection.session)
|
|
597
|
+
else:
|
|
598
|
+
# For non-persistent connections, just try again
|
|
599
|
+
async with gen_client(
|
|
600
|
+
server_name, server_registry=self.context.server_registry
|
|
601
|
+
) as client:
|
|
602
|
+
result = await try_execute(client)
|
|
603
|
+
|
|
604
|
+
# Success!
|
|
605
|
+
console.console.print(f"[dim green]MCP server {server_name} online[/dim green]")
|
|
554
606
|
return result
|
|
555
607
|
|
|
608
|
+
except Exception:
|
|
609
|
+
# Reconnection failed
|
|
610
|
+
console.console.print(
|
|
611
|
+
f"[dim red]MCP server {server_name} offline - failed to reconnect[/dim red]"
|
|
612
|
+
)
|
|
613
|
+
error_msg = f"MCP server {server_name} offline - failed to reconnect"
|
|
614
|
+
if error_factory:
|
|
615
|
+
return error_factory(error_msg)
|
|
616
|
+
else:
|
|
617
|
+
raise Exception(error_msg)
|
|
618
|
+
|
|
556
619
|
async def _parse_resource_name(self, name: str, resource_type: str) -> tuple[str, str]:
|
|
557
620
|
"""
|
|
558
621
|
Parse a possibly namespaced resource name into server name and local resource name.
|
|
@@ -575,7 +638,7 @@ class MCPAggregator(ContextDependent):
|
|
|
575
638
|
# Try to match against known server names, handling server names with hyphens
|
|
576
639
|
for server_name in self.server_names:
|
|
577
640
|
if name.startswith(f"{server_name}{SEP}"):
|
|
578
|
-
local_name = name[len(server_name) + len(SEP):]
|
|
641
|
+
local_name = name[len(server_name) + len(SEP) :]
|
|
579
642
|
return server_name, local_name
|
|
580
643
|
|
|
581
644
|
# If no server name matched, it might be a tool with a hyphen in its name
|
|
@@ -622,10 +685,10 @@ class MCPAggregator(ContextDependent):
|
|
|
622
685
|
with tracer.start_as_current_span(f"MCP Tool: {server_name}/{local_tool_name}"):
|
|
623
686
|
trace.get_current_span().set_attribute("tool_name", local_tool_name)
|
|
624
687
|
trace.get_current_span().set_attribute("server_name", server_name)
|
|
625
|
-
|
|
688
|
+
|
|
626
689
|
# Create progress callback for this tool execution
|
|
627
690
|
progress_callback = self._create_progress_callback(server_name, local_tool_name)
|
|
628
|
-
|
|
691
|
+
|
|
629
692
|
return await self._execute_on_server(
|
|
630
693
|
server_name=server_name,
|
|
631
694
|
operation_type="tool",
|
|
@@ -1235,11 +1298,11 @@ class MCPAggregator(ContextDependent):
|
|
|
1235
1298
|
async def list_mcp_tools(self, server_name: str | None = None) -> Dict[str, List[Tool]]:
|
|
1236
1299
|
"""
|
|
1237
1300
|
List available tools from one or all servers, grouped by server name.
|
|
1238
|
-
|
|
1301
|
+
|
|
1239
1302
|
Args:
|
|
1240
1303
|
server_name: Optional server name to list tools from. If not provided,
|
|
1241
1304
|
lists tools from all servers.
|
|
1242
|
-
|
|
1305
|
+
|
|
1243
1306
|
Returns:
|
|
1244
1307
|
Dictionary mapping server names to lists of Tool objects (with original names, not namespaced)
|
|
1245
1308
|
"""
|
|
@@ -1248,7 +1311,7 @@ class MCPAggregator(ContextDependent):
|
|
|
1248
1311
|
|
|
1249
1312
|
results: Dict[str, List[Tool]] = {}
|
|
1250
1313
|
|
|
1251
|
-
# Get the list of servers to check
|
|
1314
|
+
# Get the list of servers to check
|
|
1252
1315
|
servers_to_check = [server_name] if server_name else self.server_names
|
|
1253
1316
|
|
|
1254
1317
|
# For each server, try to list its tools
|
|
@@ -272,6 +272,7 @@ class MCPConnectionManager(ContextDependent):
|
|
|
272
272
|
# Manage our own task group - independent of task context
|
|
273
273
|
self._task_group = None
|
|
274
274
|
self._task_group_active = False
|
|
275
|
+
self._mcp_sse_filter_added = False
|
|
275
276
|
|
|
276
277
|
async def __aenter__(self):
|
|
277
278
|
# Create a task group that isn't tied to a specific task
|
|
@@ -300,6 +301,21 @@ class MCPConnectionManager(ContextDependent):
|
|
|
300
301
|
except Exception as e:
|
|
301
302
|
logger.error(f"Error during connection manager shutdown: {e}")
|
|
302
303
|
|
|
304
|
+
def _suppress_mcp_sse_errors(self) -> None:
    """Hide the MCP library's 'Error in sse_reader' log records.

    The filter is attached to the process-wide ``mcp.client.sse`` logger.
    ``self._mcp_sse_filter_added`` short-circuits repeat calls on this manager.
    A marker attribute on the filter stops other manager instances from stacking
    duplicate filters on the same logger, which the per-instance flag alone
    could not prevent.
    """
    if self._mcp_sse_filter_added:
        return

    import logging

    mcp_sse_logger = logging.getLogger("mcp.client.sse")

    already_installed = any(
        getattr(existing, "_suppresses_mcp_sse_reader_errors", False)
        for existing in mcp_sse_logger.filters
    )
    if not already_installed:

        class MCPSSEErrorFilter(logging.Filter):
            # Marker used above to detect a previously installed instance.
            _suppresses_mcp_sse_reader_errors = True

            def filter(self, record: logging.LogRecord) -> bool:
                # Drop only the noisy reader error; let every other record through.
                return not (
                    record.name == "mcp.client.sse" and "Error in sse_reader" in record.getMessage()
                )

        mcp_sse_logger.addFilter(MCPSSEErrorFilter())

    self._mcp_sse_filter_added = True
|
|
318
|
+
|
|
303
319
|
async def launch_server(
|
|
304
320
|
self,
|
|
305
321
|
server_name: str,
|
|
@@ -341,6 +357,9 @@ class MCPConnectionManager(ContextDependent):
|
|
|
341
357
|
logger.debug(f"{server_name}: Creating stdio client with custom error handler")
|
|
342
358
|
return _add_none_to_context(stdio_client(server_params, errlog=error_handler))
|
|
343
359
|
elif config.transport == "sse":
|
|
360
|
+
# Suppress MCP library error spam
|
|
361
|
+
self._suppress_mcp_sse_errors()
|
|
362
|
+
|
|
344
363
|
return _add_none_to_context(
|
|
345
364
|
sse_client(
|
|
346
365
|
config.url,
|
|
@@ -100,11 +100,25 @@ async def product_review() -> ReadResourceResult:
|
|
|
100
100
|
},
|
|
101
101
|
)
|
|
102
102
|
review_text: str = Field(
|
|
103
|
-
description="Tell us about your experience",
|
|
103
|
+
description="Tell us about your experience",
|
|
104
|
+
default="""Great product!
|
|
105
|
+
Here's what I loved:
|
|
106
|
+
|
|
107
|
+
- Excellent build quality
|
|
108
|
+
- Fast shipping
|
|
109
|
+
- Works as advertised
|
|
110
|
+
|
|
111
|
+
One minor issue:
|
|
112
|
+
- Instructions could be clearer
|
|
113
|
+
|
|
114
|
+
Overall, highly recommended!""",
|
|
115
|
+
min_length=10,
|
|
116
|
+
max_length=1000
|
|
104
117
|
)
|
|
105
118
|
|
|
106
119
|
result = await mcp.get_context().elicit(
|
|
107
|
-
"Share your product review - Help others make informed decisions!",
|
|
120
|
+
"Share your product review - Help others make informed decisions!",
|
|
121
|
+
schema=ProductReview,
|
|
108
122
|
)
|
|
109
123
|
|
|
110
124
|
match result:
|
|
@@ -140,6 +154,7 @@ async def account_settings() -> ReadResourceResult:
|
|
|
140
154
|
email_notifications: bool = Field(True, description="Receive email notifications?")
|
|
141
155
|
marketing_emails: bool = Field(False, description="Subscribe to marketing emails?")
|
|
142
156
|
theme: str = Field(
|
|
157
|
+
"dark",
|
|
143
158
|
description="Choose your preferred theme",
|
|
144
159
|
json_schema_extra={
|
|
145
160
|
"enum": ["light", "dark", "auto"],
|
|
@@ -147,7 +162,9 @@ async def account_settings() -> ReadResourceResult:
|
|
|
147
162
|
},
|
|
148
163
|
)
|
|
149
164
|
privacy_public: bool = Field(False, description="Make your profile public?")
|
|
150
|
-
items_per_page: int = Field(
|
|
165
|
+
items_per_page: int = Field(
|
|
166
|
+
25, description="Items to show per page (10-100)", ge=10, le=100
|
|
167
|
+
)
|
|
151
168
|
|
|
152
169
|
result = await mcp.get_context().elicit("Update your account settings", schema=AccountSettings)
|
|
153
170
|
|
|
@@ -182,7 +199,11 @@ async def service_appointment() -> ReadResourceResult:
|
|
|
182
199
|
|
|
183
200
|
class ServiceAppointment(BaseModel):
|
|
184
201
|
customer_name: str = Field(description="Your full name", min_length=2, max_length=50)
|
|
202
|
+
phone_number: str = Field(
|
|
203
|
+
"555-", description="Contact phone number", min_length=10, max_length=20
|
|
204
|
+
)
|
|
185
205
|
vehicle_type: str = Field(
|
|
206
|
+
default="sedan",
|
|
186
207
|
description="What type of vehicle do you have?",
|
|
187
208
|
json_schema_extra={
|
|
188
209
|
"enum": ["sedan", "suv", "truck", "motorcycle", "other"],
|
|
@@ -205,6 +226,7 @@ async def service_appointment() -> ReadResourceResult:
|
|
|
205
226
|
lines = [
|
|
206
227
|
"🔧 Service Appointment Scheduled!",
|
|
207
228
|
f"👤 Customer: {data.customer_name}",
|
|
229
|
+
f"📞 Phone: {data.phone_number}",
|
|
208
230
|
f"🚗 Vehicle: {data.vehicle_type.title()}",
|
|
209
231
|
f"🚙 Loaner needed: {'Yes' if data.needs_loaner else 'No'}",
|
|
210
232
|
f"📅 Appointment: {data.appointment_time}",
|