fast-agent-mcp 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.3.dist-info → fast_agent_mcp-0.1.5.dist-info}/METADATA +5 -1
- {fast_agent_mcp-0.1.3.dist-info → fast_agent_mcp-0.1.5.dist-info}/RECORD +28 -17
- mcp_agent/agents/agent.py +46 -0
- mcp_agent/core/agent_app.py +373 -9
- mcp_agent/core/decorators.py +455 -0
- mcp_agent/core/enhanced_prompt.py +70 -4
- mcp_agent/core/factory.py +501 -0
- mcp_agent/core/fastagent.py +140 -1059
- mcp_agent/core/proxies.py +83 -47
- mcp_agent/core/validation.py +221 -0
- mcp_agent/human_input/handler.py +5 -2
- mcp_agent/mcp/mcp_aggregator.py +537 -47
- mcp_agent/mcp/mcp_connection_manager.py +13 -2
- mcp_agent/mcp_server/__init__.py +4 -0
- mcp_agent/mcp_server/agent_server.py +121 -0
- mcp_agent/resources/examples/internal/fastagent.config.yaml +52 -0
- mcp_agent/resources/examples/internal/prompt_category.py +21 -0
- mcp_agent/resources/examples/internal/prompt_sizing.py +53 -0
- mcp_agent/resources/examples/internal/sizer.py +24 -0
- mcp_agent/resources/examples/researcher/fastagent.config.yaml +14 -1
- mcp_agent/resources/examples/workflows/sse.py +23 -0
- mcp_agent/ui/console_display.py +278 -0
- mcp_agent/workflows/llm/augmented_llm.py +245 -179
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +49 -3
- mcp_agent/workflows/llm/augmented_llm_openai.py +52 -4
- {fast_agent_mcp-0.1.3.dist-info → fast_agent_mcp-0.1.5.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.3.dist-info → fast_agent_mcp-0.1.5.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.3.dist-info → fast_agent_mcp-0.1.5.dist-info}/licenses/LICENSE +0 -0
mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -1,6 +1,14 @@
|
|
1
1
|
from asyncio import Lock, gather
|
2
|
-
from typing import
|
3
|
-
|
2
|
+
from typing import (
|
3
|
+
List,
|
4
|
+
Dict,
|
5
|
+
Optional,
|
6
|
+
TYPE_CHECKING,
|
7
|
+
Any,
|
8
|
+
Callable,
|
9
|
+
TypeVar,
|
10
|
+
)
|
11
|
+
from mcp import GetPromptResult
|
4
12
|
from pydantic import BaseModel, ConfigDict
|
5
13
|
from mcp.client.session import ClientSession
|
6
14
|
from mcp.server.lowlevel.server import Server
|
@@ -9,6 +17,7 @@ from mcp.types import (
|
|
9
17
|
CallToolResult,
|
10
18
|
ListToolsResult,
|
11
19
|
Tool,
|
20
|
+
Prompt,
|
12
21
|
)
|
13
22
|
|
14
23
|
from mcp_agent.event_progress import ProgressAction
|
@@ -29,6 +38,10 @@ logger = get_logger(
|
|
29
38
|
|
30
39
|
SEP = "-"
|
31
40
|
|
41
|
+
# Define type variables for the generalized method
|
42
|
+
T = TypeVar("T")
|
43
|
+
R = TypeVar("R")
|
44
|
+
|
32
45
|
|
33
46
|
class NamespacedTool(BaseModel):
|
34
47
|
"""
|
@@ -112,7 +125,9 @@ class MCPAggregator(ContextDependent):
|
|
112
125
|
self._server_to_tool_map: Dict[str, List[NamespacedTool]] = {}
|
113
126
|
self._tool_map_lock = Lock()
|
114
127
|
|
115
|
-
#
|
128
|
+
# Cache for prompt objects, maps server_name -> list of prompt objects
|
129
|
+
self._prompt_cache: Dict[str, List[Prompt]] = {}
|
130
|
+
self._prompt_cache_lock = Lock()
|
116
131
|
|
117
132
|
async def close(self):
|
118
133
|
"""
|
@@ -172,6 +187,7 @@ class MCPAggregator(ContextDependent):
|
|
172
187
|
async def load_servers(self):
|
173
188
|
"""
|
174
189
|
Discover tools from each server in parallel and build an index of namespaced tool names.
|
190
|
+
Also populate the prompt cache.
|
175
191
|
"""
|
176
192
|
if self.initialized:
|
177
193
|
logger.debug("MCPAggregator already initialized.")
|
@@ -181,6 +197,9 @@ class MCPAggregator(ContextDependent):
|
|
181
197
|
self._namespaced_tool_map.clear()
|
182
198
|
self._server_to_tool_map.clear()
|
183
199
|
|
200
|
+
async with self._prompt_cache_lock:
|
201
|
+
self._prompt_cache.clear()
|
202
|
+
|
184
203
|
for server_name in self.server_names:
|
185
204
|
if self.connection_persistence:
|
186
205
|
logger.info(
|
@@ -211,8 +230,26 @@ class MCPAggregator(ContextDependent):
|
|
211
230
|
logger.error(f"Error loading tools from server '{server_name}'", data=e)
|
212
231
|
return []
|
213
232
|
|
214
|
-
async def
|
233
|
+
async def fetch_prompts(
    client: ClientSession, server_name: str
) -> List[Prompt]:
    """
    Return the prompts that ``server_name`` lists through ``client``.

    An empty list is returned when the server's capabilities are unavailable
    or do not include prompts, and when ``client.list_prompts()`` raises.

    NOTE(review): get_capabilities() returns None for non-persistent
    connections, so in that mode prompts are never fetched here — confirm
    this is intended.
    """
    caps = await self.get_capabilities(server_name)
    supports_prompts = bool(caps) and bool(caps.prompts)
    if not supports_prompts:
        logger.debug(f"Server '{server_name}' does not support prompts")
        return []

    try:
        listing = await client.list_prompts()
        # Tolerate result objects that lack a `prompts` attribute
        return getattr(listing, "prompts", [])
    except Exception as e:
        logger.debug(f"Error loading prompts from server '{server_name}': {e}")
        return []
|
248
|
+
|
249
|
+
async def load_server_data(server_name: str):
|
215
250
|
tools: List[Tool] = []
|
251
|
+
prompts: List[Prompt] = []
|
252
|
+
|
216
253
|
if self.connection_persistence:
|
217
254
|
server_connection = (
|
218
255
|
await self._persistent_connection_manager.get_server(
|
@@ -220,25 +257,30 @@ class MCPAggregator(ContextDependent):
|
|
220
257
|
)
|
221
258
|
)
|
222
259
|
tools = await fetch_tools(server_connection.session)
|
260
|
+
prompts = await fetch_prompts(server_connection.session, server_name)
|
223
261
|
else:
|
224
262
|
async with gen_client(
|
225
263
|
server_name, server_registry=self.context.server_registry
|
226
264
|
) as client:
|
227
265
|
tools = await fetch_tools(client)
|
266
|
+
prompts = await fetch_prompts(client, server_name)
|
228
267
|
|
229
|
-
return server_name, tools
|
268
|
+
return server_name, tools, prompts
|
230
269
|
|
231
|
-
# Gather
|
270
|
+
# Gather data from all servers concurrently
|
232
271
|
results = await gather(
|
233
|
-
*(
|
272
|
+
*(load_server_data(server_name) for server_name in self.server_names),
|
234
273
|
return_exceptions=True,
|
235
274
|
)
|
236
275
|
|
237
276
|
for result in results:
|
238
277
|
if isinstance(result, BaseException):
|
278
|
+
logger.error(f"Error loading server data: {result}")
|
239
279
|
continue
|
240
|
-
server_name, tools = result
|
241
280
|
|
281
|
+
server_name, tools, prompts = result
|
282
|
+
|
283
|
+
# Process tools
|
242
284
|
self._server_to_tool_map[server_name] = []
|
243
285
|
for tool in tools:
|
244
286
|
namespaced_tool_name = f"{server_name}{SEP}{tool.name}"
|
@@ -250,16 +292,40 @@ class MCPAggregator(ContextDependent):
|
|
250
292
|
|
251
293
|
self._namespaced_tool_map[namespaced_tool_name] = namespaced_tool
|
252
294
|
self._server_to_tool_map[server_name].append(namespaced_tool)
|
295
|
+
|
296
|
+
# Process prompts
|
297
|
+
async with self._prompt_cache_lock:
|
298
|
+
self._prompt_cache[server_name] = prompts
|
299
|
+
|
253
300
|
logger.debug(
|
254
|
-
"MCP Aggregator initialized",
|
301
|
+
f"MCP Aggregator initialized for server '{server_name}'",
|
255
302
|
data={
|
256
303
|
"progress_action": ProgressAction.INITIALIZED,
|
257
304
|
"server_name": server_name,
|
258
305
|
"agent_name": self.agent_name,
|
306
|
+
"tool_count": len(tools),
|
307
|
+
"prompt_count": len(prompts),
|
259
308
|
},
|
260
309
|
)
|
310
|
+
|
261
311
|
self.initialized = True
|
262
312
|
|
313
|
+
async def get_capabilities(self, server_name: str):
    """
    Look up the capabilities reported by ``server_name``.

    Only persistent connections are consulted. For temporary connections
    there is no live session to inspect here, so ``None`` is returned.
    Any failure while obtaining the persistent connection is logged at
    debug level and also yields ``None``.

    :param server_name: Name of the server to inspect.
    :return: The connection's ``server_capabilities`` value, or ``None``.
    """
    capabilities = None
    if self.connection_persistence:
        try:
            connection = await self._persistent_connection_manager.get_server(
                server_name, client_session_factory=MCPAgentClientSession
            )
            # `server_capabilities` is read as an attribute, not awaited
            capabilities = connection.server_capabilities
        except Exception as e:
            logger.debug(f"Error getting capabilities for server '{server_name}': {e}")
    return capabilities
|
328
|
+
|
263
329
|
async def list_servers(self) -> List[str]:
|
264
330
|
"""Return the list of server names aggregated by this agent."""
|
265
331
|
if not self.initialized:
|
@@ -281,57 +347,53 @@ class MCPAggregator(ContextDependent):
|
|
281
347
|
]
|
282
348
|
)
|
283
349
|
|
284
|
-
async def
|
285
|
-
self,
|
286
|
-
|
350
|
+
async def _execute_on_server(
|
351
|
+
self,
|
352
|
+
server_name: str,
|
353
|
+
operation_type: str,
|
354
|
+
operation_name: str,
|
355
|
+
method_name: str,
|
356
|
+
method_args: Dict[str, Any] = None,
|
357
|
+
error_factory: Callable[[str], R] = None,
|
358
|
+
) -> R:
|
287
359
|
"""
|
288
|
-
|
360
|
+
Generic method to execute operations on a specific server.
|
361
|
+
|
362
|
+
Args:
|
363
|
+
server_name: Name of the server to execute the operation on
|
364
|
+
operation_type: Type of operation (for logging) e.g., "tool", "prompt"
|
365
|
+
operation_name: Name of the specific operation being called (for logging)
|
366
|
+
method_name: Name of the method to call on the client session
|
367
|
+
method_args: Arguments to pass to the method
|
368
|
+
error_factory: Function to create an error return value if the operation fails
|
369
|
+
|
370
|
+
Returns:
|
371
|
+
Result from the operation or an error result
|
289
372
|
"""
|
290
|
-
if not self.initialized:
|
291
|
-
await self.load_servers()
|
292
|
-
|
293
|
-
server_name: str = None
|
294
|
-
local_tool_name: str = None
|
295
|
-
|
296
|
-
if SEP in name: # Namespaced tool name
|
297
|
-
server_name, local_tool_name = name.split(SEP, 1)
|
298
|
-
else:
|
299
|
-
# Assume un-namespaced, loop through all servers to find the tool. First match wins.
|
300
|
-
for _, tools in self._server_to_tool_map.items():
|
301
|
-
for namespaced_tool in tools:
|
302
|
-
if namespaced_tool.tool.name == name:
|
303
|
-
server_name = namespaced_tool.server_name
|
304
|
-
local_tool_name = name
|
305
|
-
break
|
306
|
-
|
307
|
-
if server_name is None or local_tool_name is None:
|
308
|
-
logger.error(f"Error: Tool '{name}' not found")
|
309
|
-
return CallToolResult(isError=True, message=f"Tool '{name}' not found")
|
310
|
-
|
311
373
|
logger.info(
|
312
|
-
"Requesting
|
374
|
+
f"Requesting {operation_type}",
|
313
375
|
data={
|
314
|
-
"progress_action": ProgressAction.
|
315
|
-
"
|
376
|
+
"progress_action": ProgressAction.STARTING,
|
377
|
+
f"{operation_type}_name": operation_name,
|
316
378
|
"server_name": server_name,
|
317
379
|
"agent_name": self.agent_name,
|
318
380
|
},
|
319
381
|
)
|
320
382
|
|
321
|
-
async def
|
383
|
+
async def try_execute(client: ClientSession):
|
322
384
|
try:
|
323
|
-
|
385
|
+
method = getattr(client, method_name)
|
386
|
+
return await method(**method_args) if method_args else await method()
|
324
387
|
except Exception as e:
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
)
|
388
|
+
error_msg = f"Failed to {method_name} '{operation_name}' on server '{server_name}': {e}"
|
389
|
+
logger.error(error_msg)
|
390
|
+
return error_factory(error_msg) if error_factory else None
|
329
391
|
|
330
392
|
if self.connection_persistence:
|
331
393
|
server_connection = await self._persistent_connection_manager.get_server(
|
332
394
|
server_name, client_session_factory=MCPAgentClientSession
|
333
395
|
)
|
334
|
-
return await
|
396
|
+
return await try_execute(server_connection.session)
|
335
397
|
else:
|
336
398
|
logger.debug(
|
337
399
|
f"Creating temporary connection to server: {server_name}",
|
@@ -344,7 +406,7 @@ class MCPAggregator(ContextDependent):
|
|
344
406
|
async with gen_client(
|
345
407
|
server_name, server_registry=self.context.server_registry
|
346
408
|
) as client:
|
347
|
-
result = await
|
409
|
+
result = await try_execute(client)
|
348
410
|
logger.debug(
|
349
411
|
f"Closing temporary connection to server: {server_name}",
|
350
412
|
data={
|
@@ -355,6 +417,405 @@ class MCPAggregator(ContextDependent):
|
|
355
417
|
)
|
356
418
|
return result
|
357
419
|
|
420
|
+
async def _parse_resource_name(
    self, name: str, resource_type: str
) -> tuple[str, str]:
    """
    Parse a possibly namespaced resource name into server name and local resource name.

    Resolution order:
      1. For tools, an exact hit in the namespaced tool index wins. This
         resolves server or tool names that themselves contain SEP.
      2. A name containing SEP is split at its first SEP into
         (server_name, local_name). For tools the split is accepted only
         when the prefix is a configured server. Otherwise the whole name is
         treated as an un-namespaced tool name (e.g. "my-tool").
      3. An un-namespaced tool name is searched across all servers, and the
         first match wins.
      4. Any other un-namespaced resource is attributed to the first server
         (prompts are resolved separately in get_prompt).

    Args:
        name: The resource name, possibly namespaced
        resource_type: Type of resource, e.g. "tool", "prompt"

    Returns:
        Tuple of (server_name, local_resource_name). For tools, both are
        None when no matching tool is found. For other types, server_name
        is None when no servers are configured.
    """
    is_tool = resource_type == "tool"

    if is_tool:
        # Authoritative: keys are f"{server_name}{SEP}{tool.name}"
        namespaced_tool = self._namespaced_tool_map.get(name)
        if namespaced_tool is not None:
            return namespaced_tool.server_name, namespaced_tool.tool.name

    if SEP in name:
        server_name, local_name = name.split(SEP, 1)
        if not is_tool or server_name in self.server_names:
            return server_name, local_name
        # Prefix is not a known server: fall through and treat the full
        # name as a bare tool name that happens to contain SEP.

    if is_tool:
        for tools in self._server_to_tool_map.values():
            for namespaced_tool in tools:
                if namespaced_tool.tool.name == name:
                    return namespaced_tool.server_name, name
        return None, None

    return (self.server_names[0] if self.server_names else None), name
|
456
|
+
|
457
|
+
async def call_tool(
    self, name: str, arguments: dict | None = None
) -> CallToolResult:
    """
    Call a tool, e.g. 'server_name-tool_name'.

    The name may be namespaced with SEP ('server_name-tool_name') or be a
    bare tool name. A bare name is resolved to the first server exposing a
    tool of that name.

    :param name: Namespaced or bare tool name.
    :param arguments: Optional tool arguments forwarded to the server.
    :return: The server's CallToolResult, or a CallToolResult with
        ``isError=True`` and the error text as TextContent when the tool
        cannot be resolved or the call fails.
    """
    # Local import: only needed to build error results.
    from mcp.types import TextContent

    def error_result(message: str) -> CallToolResult:
        # CallToolResult requires `content` and has no `message` field, so the
        # previous CallToolResult(isError=True, message=...) raised a
        # ValidationError instead of returning an error result.
        return CallToolResult(
            isError=True, content=[TextContent(type="text", text=message)]
        )

    if not self.initialized:
        await self.load_servers()

    server_name, local_tool_name = await self._parse_resource_name(name, "tool")

    if server_name is None or local_tool_name is None:
        logger.error(f"Error: Tool '{name}' not found")
        return error_result(f"Tool '{name}' not found")

    return await self._execute_on_server(
        server_name=server_name,
        operation_type="tool",
        operation_name=local_tool_name,
        method_name="call_tool",
        method_args={"name": local_tool_name, "arguments": arguments},
        error_factory=error_result,
    )
|
480
|
+
|
481
|
+
async def get_prompt(
    self, prompt_name: str = None, arguments: dict[str, str] = None
) -> GetPromptResult:
    """
    Get a prompt from a server.

    Resolution of ``prompt_name``:

    * empty/None: the first configured server is queried with no prompt
      name argument.
    * namespaced ('server_name-prompt_name'): only that server is queried.
      Beforehand, the method checks that the server is configured, that it
      reports prompt support, and that its cached prompt list (if one
      exists) contains the prompt.
    * plain name: servers whose cached prompt list contains the prompt are
      tried first. Every remaining prompt-capable server is tried next,
      because the cache may be stale; a prompt found that way is added to
      that server's cache entry. Each server is queried at most once.

    :param prompt_name: Name of the prompt, optionally namespaced with server name
        using the format 'server_name-prompt_name'
    :param arguments: Optional dictionary of string arguments to pass to the prompt template
        for templating
    :return: GetPromptResult containing the prompt description and messages
        with a namespaced_name property for display purposes (plus an
        ``arguments`` attribute when arguments were supplied). Resolution
        failures are reported as a GetPromptResult with an explanatory
        description and no messages.
    """
    if not self.initialized:
        await self.load_servers()

    # Split the request into (server_name, local_prompt_name)
    if not prompt_name:
        server_name = self.server_names[0] if self.server_names else None
        local_prompt_name = None
    elif SEP in prompt_name:
        server_name, local_prompt_name = prompt_name.split(SEP, 1)
    else:
        server_name = None  # resolved below via the cache / server search
        local_prompt_name = prompt_name

    def build_method_args() -> Dict[str, Any]:
        # Keyword arguments for ClientSession.get_prompt
        method_args: Dict[str, Any] = (
            {"name": local_prompt_name} if local_prompt_name else {}
        )
        if arguments:
            method_args["arguments"] = arguments
        return method_args

    def annotate(result: GetPromptResult, source_server: str) -> GetPromptResult:
        # Attach display metadata. For a namespaced request this reproduces
        # prompt_name exactly, since it was split at its first SEP.
        result.namespaced_name = f"{source_server}{SEP}{local_prompt_name}"
        if arguments:
            result.arguments = arguments
        return result

    async def query_quietly(s_name: str) -> Optional[GetPromptResult]:
        # One search attempt: failures and empty message lists become None
        result = await self._execute_on_server(
            server_name=s_name,
            operation_type="prompt",
            operation_name=local_prompt_name,
            method_name="get_prompt",
            method_args=build_method_args(),
            error_factory=lambda _: None,
        )
        return result if result and result.messages else None

    async def remember_prompt(s_name: str) -> None:
        # Best effort: add a prompt found outside the cache to s_name's entry
        try:
            listing = await self._execute_on_server(
                server_name=s_name,
                operation_type="prompts-list",
                operation_name="",
                method_name="list_prompts",
                error_factory=lambda _: None,
            )
            matching = [
                p
                for p in getattr(listing, "prompts", [])
                if p.name == local_prompt_name
            ]
            if not matching:
                return
            async with self._prompt_cache_lock:
                cached = self._prompt_cache.setdefault(s_name, [])
                if all(p.name != local_prompt_name for p in cached):
                    cached.append(matching[0])
        except Exception as e:
            logger.debug(f"Error updating prompt cache for server '{s_name}': {e}")

    # A specific server was requested (or implied by an empty name)
    if server_name:
        if server_name not in self.server_names:
            logger.error(f"Error: Server '{server_name}' not found")
            return GetPromptResult(
                description=f"Error: Server '{server_name}' not found",
                messages=[],
            )

        capabilities = await self.get_capabilities(server_name)
        if not capabilities or not capabilities.prompts:
            logger.debug(f"Server '{server_name}' does not support prompts")
            return GetPromptResult(
                description=f"Server '{server_name}' does not support prompts",
                messages=[],
            )

        # Consult the cache to avoid a request that is known to fail
        if local_prompt_name:
            async with self._prompt_cache_lock:
                cached_prompts = self._prompt_cache.get(server_name)
                missing = cached_prompts is not None and all(
                    prompt.name != local_prompt_name for prompt in cached_prompts
                )
            if missing:
                logger.debug(
                    f"Prompt '{local_prompt_name}' not found in cache for server '{server_name}'"
                )
                return GetPromptResult(
                    description=f"Prompt '{local_prompt_name}' not found on server '{server_name}'",
                    messages=[],
                )

        result = await self._execute_on_server(
            server_name=server_name,
            operation_type="prompt",
            operation_name=local_prompt_name or "default",
            method_name="get_prompt",
            method_args=build_method_args(),
            error_factory=lambda msg: GetPromptResult(description=msg, messages=[]),
        )
        if result and result.messages:
            annotate(result, server_name)
        return result

    # No specific server: search, starting with servers the cache points to
    logger.debug(f"Searching for prompt '{local_prompt_name}' using cache")

    async with self._prompt_cache_lock:
        cached_servers = [
            s_name
            for s_name, prompt_list in self._prompt_cache.items()
            if any(prompt.name == local_prompt_name for prompt in prompt_list)
        ]

    if cached_servers:
        logger.debug(
            f"Found prompt '{local_prompt_name}' in cache for servers: {cached_servers}"
        )
    else:
        logger.debug(
            f"Prompt '{local_prompt_name}' not found in any server's cache"
        )

    attempted: set[str] = set()
    for s_name in cached_servers:
        capabilities = await self.get_capabilities(s_name)
        if not capabilities or not capabilities.prompts:
            logger.debug(
                f"Server '{s_name}' does not support prompts, skipping"
            )
            continue
        attempted.add(s_name)
        try:
            result = await query_quietly(s_name)
            if result is not None:
                logger.debug(
                    f"Successfully retrieved prompt '{local_prompt_name}' from server '{s_name}'"
                )
                return annotate(result, s_name)
        except Exception as e:
            logger.debug(f"Error retrieving prompt from server '{s_name}': {e}")

    # Fallback: the cache might be outdated, so try the remaining servers.
    # Servers already attempted above are skipped rather than re-queried.
    for s_name in self.server_names:
        if s_name in attempted:
            continue
        capabilities = await self.get_capabilities(s_name)
        if not capabilities or not capabilities.prompts:
            logger.debug(
                f"Server '{s_name}' does not support prompts, skipping from fallback search"
            )
            continue
        try:
            result = await query_quietly(s_name)
            if result is not None:
                logger.debug(
                    f"Found prompt '{local_prompt_name}' on server '{s_name}' (not in cache)"
                )
                annotate(result, s_name)
                await remember_prompt(s_name)
                return result
        except Exception as e:
            # Quiet: a miss on one server is expected during the fallback
            logger.debug(f"Fallback prompt lookup failed on server '{s_name}': {e}")

    logger.info(f"Prompt '{local_prompt_name}' not found on any server")
    return GetPromptResult(
        description=f"Prompt '{local_prompt_name}' not found on any server",
        messages=[],
    )
|
717
|
+
|
718
|
+
async def list_prompts(self, server_name: str = None):
    """
    List available prompts from one or all servers.

    Cached prompt lists are returned when available. Otherwise prompts are
    fetched from servers that report prompt support, and the cache is
    refreshed with the fetched lists.

    :param server_name: Optional server name to list prompts from. If not provided,
        lists prompts from all servers.
    :return: Dictionary mapping server names to lists of available prompts.
        Every value is a list of Prompt objects, whether it came from the
        cache or from a fresh fetch. Each list is a copy, so callers may
        modify it without affecting the cache. An unknown ``server_name``
        yields an empty dictionary.
    """
    if not self.initialized:
        await self.load_servers()

    async def fetch_and_cache(s_name: str) -> List[Prompt]:
        # Fetch one server's prompts and refresh its cache entry. The raw
        # ListPromptsResult is unwrapped so fresh and cached values share one
        # shape (previously fresh entries leaked the result object). A failed
        # listing (error_factory -> []) caches an empty list.
        listing = await self._execute_on_server(
            server_name=s_name,
            operation_type="prompts-list",
            operation_name="",
            method_name="list_prompts",
            error_factory=lambda _: [],
        )
        prompts = getattr(listing, "prompts", [])
        async with self._prompt_cache_lock:
            self._prompt_cache[s_name] = prompts
        return list(prompts)

    results: Dict[str, List[Prompt]] = {}

    if server_name:
        if server_name not in self.server_names:
            logger.error(f"Server '{server_name}' not found")
        else:
            async with self._prompt_cache_lock:
                cached = self._prompt_cache.get(server_name)
            if cached is not None:
                logger.debug(
                    f"Returning cached prompts for server '{server_name}'"
                )
                results[server_name] = list(cached)
                return results

            capabilities = await self.get_capabilities(server_name)
            if not capabilities or not capabilities.prompts:
                logger.debug(f"Server '{server_name}' does not support prompts")
                results[server_name] = []
                return results

            results[server_name] = await fetch_and_cache(server_name)
    else:
        # Serve everything from the cache when every server has an entry
        async with self._prompt_cache_lock:
            fully_cached = all(
                s_name in self._prompt_cache for s_name in self.server_names
            )
            if fully_cached:
                results = {
                    s_name: list(prompt_list)
                    for s_name, prompt_list in self._prompt_cache.items()
                }
        if fully_cached:
            logger.debug("Returning cached prompts for all servers")
            return results

        supported_servers = []
        for s_name in self.server_names:
            capabilities = await self.get_capabilities(s_name)
            if capabilities and capabilities.prompts:
                supported_servers.append(s_name)
            else:
                logger.debug(
                    f"Server '{s_name}' does not support prompts, skipping"
                )
                results[s_name] = []

        # Fetch from supported servers concurrently
        if supported_servers:
            fetched = await gather(
                *(fetch_and_cache(s_name) for s_name in supported_servers),
                return_exceptions=True,
            )
            for s_name, prompts in zip(supported_servers, fetched):
                if isinstance(prompts, BaseException):
                    logger.debug(
                        f"Error listing prompts from server '{s_name}': {prompts}"
                    )
                    continue
                results[s_name] = prompts

    logger.debug(f"Available prompts across servers: {results}")
    return results
|
818
|
+
|
358
819
|
|
359
820
|
class MCPCompoundServer(Server):
|
360
821
|
"""
|
@@ -365,10 +826,11 @@ class MCPCompoundServer(Server):
|
|
365
826
|
super().__init__(name)
|
366
827
|
self.aggregator = MCPAggregator(server_names)
|
367
828
|
|
368
|
-
# Register handlers
|
369
|
-
# TODO: saqadri - once we support resources and prompts, add handlers for those as well
|
829
|
+
# Register handlers for tools, prompts, and resources
|
370
830
|
self.list_tools()(self._list_tools)
|
371
831
|
self.call_tool()(self._call_tool)
|
832
|
+
self.get_prompt()(self._get_prompt)
|
833
|
+
self.list_prompts()(self._list_prompts)
|
372
834
|
|
373
835
|
async def _list_tools(self) -> List[Tool]:
|
374
836
|
"""List all tools aggregated from connected MCP servers."""
|
@@ -385,6 +847,34 @@ class MCPCompoundServer(Server):
|
|
385
847
|
except Exception as e:
|
386
848
|
return CallToolResult(isError=True, message=f"Error calling tool: {e}")
|
387
849
|
|
850
|
+
async def _get_prompt(
    self, name: str = None, arguments: dict[str, str] = None
) -> GetPromptResult:
    """
    Resolve a prompt through the aggregator (registered as the get_prompt handler).

    Args:
        name: Prompt name, optionally namespaced as 'server_name-prompt_name'
        arguments: Optional string arguments used to fill the prompt template

    Returns:
        The aggregator's GetPromptResult. If the aggregator raises, this is
        instead a GetPromptResult whose description carries the error and
        whose message list is empty.
    """
    try:
        prompt_result = await self.aggregator.get_prompt(
            prompt_name=name, arguments=arguments
        )
    except Exception as e:
        failure_text = f"Error getting prompt: {e}"
        return GetPromptResult(description=failure_text, messages=[])
    return prompt_result
|
869
|
+
|
870
|
+
async def _list_prompts(self, server_name: str = None) -> List[Prompt]:
    """
    List prompts aggregated from connected servers (registered as the list_prompts handler).

    The low-level ``Server.list_prompts`` decorator calls this handler with no
    arguments and wraps its return value in ``ListPromptsResult(prompts=...)``.
    It must therefore return a flat list of Prompt objects; returning the
    per-server dictionary from ``MCPAggregator.list_prompts`` fails that
    validation. Names are namespaced as 'server_name-prompt_name' so a
    client can pass them straight back to ``_get_prompt``, which routes
    namespaced names to the owning server.

    Args:
        server_name: Optional server to restrict the listing to.

    Returns:
        Namespaced Prompt objects, or an empty list if listing fails.
    """
    try:
        per_server = await self.aggregator.list_prompts(server_name=server_name)
    except Exception as e:
        logger.error(f"Error listing prompts: {e}")
        return []

    prompts: List[Prompt] = []
    for s_name, entries in per_server.items():
        # Accept either a list of prompts or a ListPromptsResult-like object
        for prompt in getattr(entries, "prompts", entries) or []:
            prompts.append(
                prompt.model_copy(update={"name": f"{s_name}{SEP}{prompt.name}"})
            )
    return prompts
|
877
|
+
|
388
878
|
async def run_stdio_async(self) -> None:
|
389
879
|
"""Run the server using stdio transport."""
|
390
880
|
async with stdio_server() as (read_stream, write_stream):
|