fast-agent-mcp 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of fast-agent-mcp might be problematic. Click here for more details.

Files changed (100)
  1. fast_agent_mcp-0.0.7.dist-info/METADATA +322 -0
  2. fast_agent_mcp-0.0.7.dist-info/RECORD +100 -0
  3. fast_agent_mcp-0.0.7.dist-info/WHEEL +4 -0
  4. fast_agent_mcp-0.0.7.dist-info/entry_points.txt +5 -0
  5. fast_agent_mcp-0.0.7.dist-info/licenses/LICENSE +201 -0
  6. mcp_agent/__init__.py +0 -0
  7. mcp_agent/agents/__init__.py +0 -0
  8. mcp_agent/agents/agent.py +277 -0
  9. mcp_agent/app.py +303 -0
  10. mcp_agent/cli/__init__.py +0 -0
  11. mcp_agent/cli/__main__.py +4 -0
  12. mcp_agent/cli/commands/bootstrap.py +221 -0
  13. mcp_agent/cli/commands/config.py +11 -0
  14. mcp_agent/cli/commands/setup.py +229 -0
  15. mcp_agent/cli/main.py +68 -0
  16. mcp_agent/cli/terminal.py +24 -0
  17. mcp_agent/config.py +334 -0
  18. mcp_agent/console.py +28 -0
  19. mcp_agent/context.py +251 -0
  20. mcp_agent/context_dependent.py +48 -0
  21. mcp_agent/core/fastagent.py +1013 -0
  22. mcp_agent/eval/__init__.py +0 -0
  23. mcp_agent/event_progress.py +88 -0
  24. mcp_agent/executor/__init__.py +0 -0
  25. mcp_agent/executor/decorator_registry.py +120 -0
  26. mcp_agent/executor/executor.py +293 -0
  27. mcp_agent/executor/task_registry.py +34 -0
  28. mcp_agent/executor/temporal.py +405 -0
  29. mcp_agent/executor/workflow.py +197 -0
  30. mcp_agent/executor/workflow_signal.py +325 -0
  31. mcp_agent/human_input/__init__.py +0 -0
  32. mcp_agent/human_input/handler.py +49 -0
  33. mcp_agent/human_input/types.py +58 -0
  34. mcp_agent/logging/__init__.py +0 -0
  35. mcp_agent/logging/events.py +123 -0
  36. mcp_agent/logging/json_serializer.py +163 -0
  37. mcp_agent/logging/listeners.py +216 -0
  38. mcp_agent/logging/logger.py +365 -0
  39. mcp_agent/logging/rich_progress.py +120 -0
  40. mcp_agent/logging/tracing.py +140 -0
  41. mcp_agent/logging/transport.py +461 -0
  42. mcp_agent/mcp/__init__.py +0 -0
  43. mcp_agent/mcp/gen_client.py +85 -0
  44. mcp_agent/mcp/mcp_activity.py +18 -0
  45. mcp_agent/mcp/mcp_agent_client_session.py +242 -0
  46. mcp_agent/mcp/mcp_agent_server.py +56 -0
  47. mcp_agent/mcp/mcp_aggregator.py +394 -0
  48. mcp_agent/mcp/mcp_connection_manager.py +330 -0
  49. mcp_agent/mcp/stdio.py +104 -0
  50. mcp_agent/mcp_server_registry.py +275 -0
  51. mcp_agent/progress_display.py +10 -0
  52. mcp_agent/resources/examples/decorator/main.py +26 -0
  53. mcp_agent/resources/examples/decorator/optimizer.py +78 -0
  54. mcp_agent/resources/examples/decorator/orchestrator.py +68 -0
  55. mcp_agent/resources/examples/decorator/parallel.py +81 -0
  56. mcp_agent/resources/examples/decorator/router.py +56 -0
  57. mcp_agent/resources/examples/decorator/tiny.py +22 -0
  58. mcp_agent/resources/examples/mcp_researcher/main-evalopt.py +53 -0
  59. mcp_agent/resources/examples/mcp_researcher/main.py +38 -0
  60. mcp_agent/telemetry/__init__.py +0 -0
  61. mcp_agent/telemetry/usage_tracking.py +18 -0
  62. mcp_agent/workflows/__init__.py +0 -0
  63. mcp_agent/workflows/embedding/__init__.py +0 -0
  64. mcp_agent/workflows/embedding/embedding_base.py +61 -0
  65. mcp_agent/workflows/embedding/embedding_cohere.py +49 -0
  66. mcp_agent/workflows/embedding/embedding_openai.py +46 -0
  67. mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
  68. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +359 -0
  69. mcp_agent/workflows/intent_classifier/__init__.py +0 -0
  70. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +120 -0
  71. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +134 -0
  72. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +45 -0
  73. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +45 -0
  74. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +161 -0
  75. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +60 -0
  76. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +60 -0
  77. mcp_agent/workflows/llm/__init__.py +0 -0
  78. mcp_agent/workflows/llm/augmented_llm.py +645 -0
  79. mcp_agent/workflows/llm/augmented_llm_anthropic.py +539 -0
  80. mcp_agent/workflows/llm/augmented_llm_openai.py +615 -0
  81. mcp_agent/workflows/llm/llm_selector.py +345 -0
  82. mcp_agent/workflows/llm/model_factory.py +175 -0
  83. mcp_agent/workflows/orchestrator/__init__.py +0 -0
  84. mcp_agent/workflows/orchestrator/orchestrator.py +407 -0
  85. mcp_agent/workflows/orchestrator/orchestrator_models.py +154 -0
  86. mcp_agent/workflows/orchestrator/orchestrator_prompts.py +113 -0
  87. mcp_agent/workflows/parallel/__init__.py +0 -0
  88. mcp_agent/workflows/parallel/fan_in.py +350 -0
  89. mcp_agent/workflows/parallel/fan_out.py +187 -0
  90. mcp_agent/workflows/parallel/parallel_llm.py +141 -0
  91. mcp_agent/workflows/router/__init__.py +0 -0
  92. mcp_agent/workflows/router/router_base.py +276 -0
  93. mcp_agent/workflows/router/router_embedding.py +240 -0
  94. mcp_agent/workflows/router/router_embedding_cohere.py +59 -0
  95. mcp_agent/workflows/router/router_embedding_openai.py +59 -0
  96. mcp_agent/workflows/router/router_llm.py +301 -0
  97. mcp_agent/workflows/swarm/__init__.py +0 -0
  98. mcp_agent/workflows/swarm/swarm.py +320 -0
  99. mcp_agent/workflows/swarm/swarm_anthropic.py +42 -0
  100. mcp_agent/workflows/swarm/swarm_openai.py +41 -0
@@ -0,0 +1,242 @@
1
+ """
2
+ A derived client session for the MCP Agent framework.
3
+ It adds logging and supports sampling requests.
4
+ """
5
+
6
+ from typing import Optional
7
+
8
+ from mcp import ClientSession
9
+ from mcp.shared.session import (
10
+ RequestResponder,
11
+ ReceiveResultT,
12
+ ReceiveNotificationT,
13
+ RequestId,
14
+ SendNotificationT,
15
+ SendRequestT,
16
+ SendResultT,
17
+ )
18
+ from mcp.types import (
19
+ ClientResult,
20
+ CreateMessageRequest,
21
+ CreateMessageResult,
22
+ ErrorData,
23
+ JSONRPCNotification,
24
+ JSONRPCRequest,
25
+ ServerRequest,
26
+ TextContent,
27
+ ListRootsRequest,
28
+ ListRootsResult,
29
+ Root,
30
+ )
31
+
32
+ from mcp_agent.config import MCPServerSettings
33
+ from mcp_agent.context_dependent import ContextDependent
34
+ from mcp_agent.logging.logger import get_logger
35
+
36
# Module-level logger for this client-session module.
logger = get_logger(__name__)
37
+
38
+
39
class MCPAgentClientSession(ClientSession, ContextDependent):
    """
    MCP Agent framework acts as a client to the servers providing tools/resources/prompts for the agent workloads.
    This is a simple client session for those server connections, and supports
        - handling sampling requests
        - notifications
        - MCP root configuration

    Developers can extend this class to add more custom functionality as needed
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Settings for the server this session talks to. Nothing in this class
        # assigns it; None means no roots are configured.
        self.server_config: Optional[MCPServerSettings] = None

    async def _received_request(
        self, responder: RequestResponder[ServerRequest, ClientResult]
    ) -> None:
        """
        Handle a request sent by the server to this client.

        Sampling requests (CreateMessageRequest) and roots listing
        (ListRootsRequest) are answered here. Every other request is passed to
        ClientSession's implementation.
        """
        logger.debug("Received request:", data=responder.request.model_dump())
        request = responder.request.root

        if isinstance(request, CreateMessageRequest):
            return await self.handle_sampling_request(request, responder)
        elif isinstance(request, ListRootsRequest):
            # __init__ always sets server_config (to None by default), so
            # hasattr() cannot detect "no config". Check the value itself to
            # avoid dereferencing .roots on None.
            server_config = getattr(self, "server_config", None)
            roots = []
            if server_config is not None and server_config.roots:
                roots = [
                    Root(
                        uri=root.server_uri_alias or root.uri,
                        name=root.name,
                    )
                    for root in server_config.roots
                ]
            await responder.respond(ListRootsResult(roots=roots))
            return

        # Handle other requests as usual
        await super()._received_request(responder)

    async def send_request(
        self,
        request: SendRequestT,
        result_type: type[ReceiveResultT],
    ) -> ReceiveResultT:
        """Send a request to the server, logging the request and its response.

        Errors are logged and re-raised unchanged.
        """
        logger.debug("send_request: request=", data=request.model_dump())
        try:
            result = await super().send_request(request, result_type)
            logger.debug("send_request: response=", data=result.model_dump())
            return result
        except Exception as e:
            logger.error(f"send_request failed: {e}")
            raise

    async def send_notification(self, notification: SendNotificationT) -> None:
        """Send a notification to the server. Errors are logged and re-raised."""
        logger.debug("send_notification:", data=notification.model_dump())
        try:
            return await super().send_notification(notification)
        except Exception as e:
            logger.error("send_notification failed", data=e)
            raise

    async def _send_response(
        self, request_id: RequestId, response: SendResultT | ErrorData
    ) -> None:
        """Log and send a response (or error) for a server-initiated request."""
        logger.debug(
            f"send_response: request_id={request_id}, response=",
            data=response.model_dump(),
        )
        return await super()._send_response(request_id, response)

    async def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.
        """
        logger.info(
            "_received_notification: notification=",
            data=notification.model_dump(),
        )
        return await super()._received_notification(notification)

    async def send_progress_notification(
        self, progress_token: str | int, progress: float, total: float | None = None
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed.
        """
        # f-string so the placeholders are actually interpolated into the log line.
        logger.debug(
            f"send_progress_notification: progress_token={progress_token}, progress={progress}, total={total}"
        )
        return await super().send_progress_notification(
            progress_token=progress_token, progress=progress, total=total
        )

    async def _receive_loop(self) -> None:
        """
        Read messages from the transport and dispatch them.

        This overrides the base session loop, with one branch per message kind:
        - Exceptions are forwarded to the incoming message stream.
        - Requests are validated and given to _received_request. Any request it
          did not answer is forwarded to the stream.
        - Notifications go to _received_notification and are then forwarded.
        - Responses are routed to the stream of the pending request with the
          same id. Unknown ids are reported as a RuntimeError on the stream.
        """
        async with (
            self._read_stream,
            self._write_stream,
            self._incoming_message_stream_writer,
        ):
            async for message in self._read_stream:
                if isinstance(message, Exception):
                    await self._incoming_message_stream_writer.send(message)
                elif isinstance(message.root, JSONRPCRequest):
                    validated_request = self._receive_request_type.model_validate(
                        message.root.model_dump(
                            by_alias=True, mode="json", exclude_none=True
                        )
                    )
                    responder = RequestResponder(
                        request_id=message.root.id,
                        request_meta=validated_request.root.params.meta
                        if validated_request.root.params
                        else None,
                        request=validated_request,
                        session=self,
                    )

                    await self._received_request(responder)
                    # Only surface requests that were not already answered above.
                    if not responder._responded:
                        await self._incoming_message_stream_writer.send(responder)
                elif isinstance(message.root, JSONRPCNotification):
                    notification = self._receive_notification_type.model_validate(
                        message.root.model_dump(
                            by_alias=True, mode="json", exclude_none=True
                        )
                    )

                    await self._received_notification(notification)
                    await self._incoming_message_stream_writer.send(notification)
                else:  # Response or error
                    stream = self._response_streams.pop(message.root.id, None)
                    if stream:
                        await stream.send(message.root)
                    else:
                        await self._incoming_message_stream_writer.send(
                            RuntimeError(
                                "Received response with an unknown "
                                f"request ID: {message}"
                            )
                        )

    async def handle_sampling_request(
        self,
        request: CreateMessageRequest,
        responder: RequestResponder[ServerRequest, ClientResult],
    ):
        """
        Answer a sampling (createMessage) request issued by a connected server.

        If the context has an upstream session, the request is forwarded to it
        and the upstream result is sent back unchanged. Otherwise the request
        is fulfilled locally with the Anthropic Messages API. On either path,
        a failure is reported to the requesting server as ErrorData with code
        -32603.
        """
        # Pass the payload via data=, like every other logger call in this file.
        logger.info("Handling sampling request", data=request.model_dump())
        config = self.context.config
        session = self.context.upstream_session

        if session is not None:
            # Act as a pure pass-through client for the upstream session.
            try:
                result = await session.send_request(
                    request=ServerRequest(request), result_type=CreateMessageResult
                )
                # RequestResponder answers through respond(). It has no
                # send_result()/send_error() methods.
                await responder.respond(result)
            except Exception as e:
                logger.error(f"Error forwarding sampling request upstream: {e}")
                await responder.respond(ErrorData(code=-32603, message=str(e)))
            return

        # TODO: saqadri - consider whether we should be handling the sampling request here as a client
        logger.warning(
            "Error: No upstream client available for sampling requests. Request:",
            data=request,
        )
        model = "claude-3-sonnet-20240229"
        try:
            from anthropic import AsyncAnthropic

            client = AsyncAnthropic(api_key=config.anthropic.api_key)
            params = request.params

            # The MCP params model declares these fields with None defaults, so
            # getattr()'s fallback never applied. Resolve the intended 0.7
            # temperature default explicitly and omit options that are unset
            # rather than sending explicit nulls.
            temperature = getattr(params, "temperature", None)
            optional_args = {
                "system": getattr(params, "systemPrompt", None),
                "temperature": 0.7 if temperature is None else temperature,
                "stop_sequences": getattr(params, "stopSequences", None),
            }
            optional_args = {
                key: value for key, value in optional_args.items() if value is not None
            }

            messages = [
                {
                    "role": m.role,
                    "content": m.content.text
                    if hasattr(m.content, "text")
                    else m.content.data,
                }
                for m in params.messages
            ]
            response = await client.messages.create(
                model=model,
                max_tokens=params.maxTokens,
                messages=messages,
                **optional_args,
            )

            await responder.respond(
                CreateMessageResult(
                    model=model,
                    role="assistant",
                    content=TextContent(type="text", text=response.content[0].text),
                )
            )
        except Exception as e:
            logger.error(f"Error handling sampling request: {e}")
            await responder.respond(ErrorData(code=-32603, message=str(e)))
@@ -0,0 +1,56 @@
1
+ import asyncio
2
+ from mcp.server import NotificationOptions
3
+ from mcp.server.fastmcp import FastMCP
4
+ from mcp.server.stdio import stdio_server
5
+ from mcp_agent.executor.temporal import get_temporal_client
6
+ from mcp_agent.telemetry.tracing import setup_tracing
7
+
8
# FastMCP application that the workflow-control tools below are registered on.
app = FastMCP("mcp-agent-server")

# Runs at import time.
# NOTE(review): setup_tracing is imported from mcp_agent.telemetry.tracing, but
# the package file list for this release only ships telemetry/usage_tracking.py
# (tracing.py lives under mcp_agent/logging/). Confirm the import resolves.
setup_tracing("mcp-agent-server")
11
+
12
+
13
async def run():
    """
    Serve ``app`` over stdio.

    The low-level server is started with initialization options that advertise
    tools-changed and resources-changed notifications.
    """
    async with stdio_server() as (reader, writer):
        server = app._mcp_server
        change_notifications = NotificationOptions(
            tools_changed=True, resources_changed=True
        )
        init_options = server.create_initialization_options(
            notification_options=change_notifications
        )
        await server.run(reader, writer, init_options)
24
+
25
+
26
# FastMCP.tool is a decorator factory and must be called. A bare @app.tool
# passes this function in as the tool's `name` instead of registering it
# (recent mcp releases raise TypeError for that misuse).
@app.tool()
async def run_workflow(query: str):
    """Run the workflow given its name or id"""
    # TODO: not implemented yet; the tool currently does nothing and returns None.
    pass
30
+
31
+
32
# FastMCP.tool is a decorator factory and must be called. A bare @app.tool
# passes this function in as the tool's `name` instead of registering it
# (recent mcp releases raise TypeError for that misuse).
@app.tool()
async def pause_workflow(workflow_id: str):
    """
    Pause a running workflow.

    Sends the "pause" signal to the workflow handle for ``workflow_id``,
    obtained from the client returned by get_temporal_client().
    """
    temporal_client = await get_temporal_client()
    handle = temporal_client.get_workflow_handle(workflow_id)
    await handle.signal("pause")
38
+
39
+
40
# FastMCP.tool is a decorator factory and must be called. A bare @app.tool
# passes this function in as the tool's `name` instead of registering it
# (recent mcp releases raise TypeError for that misuse).
@app.tool()
async def resume_workflow(workflow_id: str):
    """
    Resume a paused workflow.

    Sends the "resume" signal to the workflow handle for ``workflow_id``,
    obtained from the client returned by get_temporal_client().
    """
    temporal_client = await get_temporal_client()
    handle = temporal_client.get_workflow_handle(workflow_id)
    await handle.signal("resume")
46
+
47
+
48
async def provide_user_input(workflow_id: str, input_data: str):
    """
    Provide user/human input to a waiting workflow step.

    Delivers ``input_data`` through the "human_input" signal of the workflow
    identified by ``workflow_id``.

    NOTE(review): unlike pause_workflow/resume_workflow this function carries
    no @app.tool decorator, so it is not registered on the MCP server. Confirm
    whether it should be exposed as a tool.
    """
    client = await get_temporal_client()
    workflow = client.get_workflow_handle(workflow_id)
    await workflow.signal("human_input", input_data)
53
+
54
+
55
if __name__ == "__main__":
    # Script entry point: serve the MCP app over stdio until the streams close.
    asyncio.run(run())
@@ -0,0 +1,394 @@
1
+ from asyncio import Lock, gather
2
+ from typing import List, Dict, Optional, TYPE_CHECKING
3
+
4
+ from pydantic import BaseModel, ConfigDict
5
+ from mcp.client.session import ClientSession
6
+ from mcp.server.lowlevel.server import Server
7
+ from mcp.server.stdio import stdio_server
8
+ from mcp.types import (
9
+ CallToolResult,
10
+ ListToolsResult,
11
+ Tool,
12
+ )
13
+
14
+ from mcp_agent.event_progress import ProgressAction
15
+ from mcp_agent.logging.logger import get_logger
16
+ from mcp_agent.mcp.gen_client import gen_client
17
+
18
+ from mcp_agent.context_dependent import ContextDependent
19
+ from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
20
+ from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager
21
+
22
+ if TYPE_CHECKING:
23
+ from mcp_agent.context import Context
24
+
25
+
26
# Module logger, created with this module's name.
logger = get_logger(__name__)

# Separator between server name and tool name in namespaced tool names,
# i.e. f"{server_name}{SEP}{tool_name}".
SEP = "-"
31
+
32
+
33
class NamespacedTool(BaseModel):
    """
    A tool that is namespaced by server name.

    Pairs a tool definition reported by an MCP server with the server's name
    and the aggregator-level name under which the tool is exposed.
    """

    # Tool definition as returned by the server's list_tools() call.
    tool: Tool
    # Name of the server that provides the tool.
    server_name: str
    # Name exposed by the aggregator, built as f"{server_name}{SEP}{tool.name}".
    namespaced_tool_name: str
41
+
42
+
43
class MCPAggregator(ContextDependent):
    """
    Aggregates multiple MCP servers. When a developer calls, e.g. call_tool(...),
    the aggregator searches all servers in its list for a server that provides that tool.

    Tools are indexed under namespaced names f"{server_name}{SEP}{tool_name}".
    Connections depend on ``connection_persistence``:
    - True: connections come from an MCPConnectionManager stored on the context
      (``context._connection_manager``).
    - False (default): a temporary connection is opened with gen_client() for
      each operation.
    """

    # Whether the aggregator has been initialized with tools from all servers.
    initialized: bool = False

    # Whether to maintain a persistent connection to the servers.
    connection_persistence: bool = False

    # A list of server names to connect to.
    server_names: List[str]

    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    async def __aenter__(self):
        """Attach the shared connection manager (if persistent) and load tools."""
        if self.initialized:
            return self

        # Keep a connection manager to manage persistent connections for this aggregator
        if self.connection_persistence:
            # Reuse the manager stored on the context, creating and entering one
            # only if none exists yet.
            if not hasattr(self.context, "_connection_manager"):
                self.context._connection_manager = MCPConnectionManager(
                    self.context.server_registry
                )
                await self.context._connection_manager.__aenter__()
            self._persistent_connection_manager = self.context._connection_manager

        await self.load_servers()

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def __init__(
        self,
        server_names: List[str],
        connection_persistence: bool = False,
        context: Optional["Context"] = None,
        name: str = None,
        **kwargs,
    ):
        """
        :param server_names: A list of server names to connect to.
        Note: The server names must be resolvable by the gen_client function, and specified in the server registry.
        :param connection_persistence: Keep server connections open via the
            context's shared MCPConnectionManager instead of per-call clients.
        :param context: Optional context, passed through to ContextDependent.
        :param name: Agent name; used in log data and in the logger namespace.
        """
        super().__init__(
            context=context,
            **kwargs,
        )

        self.server_names = server_names
        self.connection_persistence = connection_persistence
        self.agent_name = name
        self._persistent_connection_manager: MCPConnectionManager = None

        # Per-instance logger whose namespace includes the agent name. It is kept
        # on the instance instead of rebinding the module-global logger, so that
        # creating another aggregator does not redirect this one's log events.
        self._logger = get_logger(f"{__name__}.{name}" if name else __name__)

        # Maps namespaced_tool_name -> namespaced tool info
        self._namespaced_tool_map: Dict[str, NamespacedTool] = {}
        # Maps server_name -> list of tools
        self._server_to_tool_map: Dict[str, List[NamespacedTool]] = {}
        self._tool_map_lock = Lock()

        # TODO: saqadri - add resources and prompt maps as well

    async def close(self):
        """
        Close all persistent connections when the aggregator is deleted.

        Errors during cleanup are logged, not raised.
        """
        if self.connection_persistence and self._persistent_connection_manager:
            try:
                # NOTE(review): every aggregator sharing the context's manager
                # passes this check, so the first one to close shuts down the
                # connections for all of them. Confirm this is intended.
                if (
                    hasattr(self.context, "_connection_manager")
                    and self.context._connection_manager
                    == self._persistent_connection_manager
                ):
                    self._logger.info("Shutting down all persistent connections...")
                    await self._persistent_connection_manager.disconnect_all()
                    await self._persistent_connection_manager.__aexit__(
                        None, None, None
                    )
                    delattr(self.context, "_connection_manager")
                    self.initialized = False
            except Exception as e:
                self._logger.error(f"Error during connection manager cleanup: {e}")

    @classmethod
    async def create(
        cls,
        server_names: List[str],
        connection_persistence: bool = False,
    ) -> "MCPAggregator":
        """
        Factory method to create and initialize an MCPAggregator.
        Use this instead of constructor since we need async initialization.
        If connection_persistence is True, the aggregator will maintain a
        persistent connection to the servers for as long as this aggregator is around.
        By default we do not maintain a persistent connection.

        :raises Exception: Re-raises any initialization error after cleaning up.
            Previously the error was swallowed and None was returned.
        """

        logger.info(f"Creating MCPAggregator with servers: {server_names}")

        instance = cls(
            server_names=server_names,
            connection_persistence=connection_persistence,
        )

        try:
            # __aenter__ already calls load_servers(), so no second call is needed.
            await instance.__aenter__()

            instance._logger.debug("MCPAggregator created and initialized.")
            return instance
        except Exception as e:
            logger.error(f"Error creating MCPAggregator: {e}")
            await instance.__aexit__(None, None, None)
            raise

    async def load_servers(self):
        """
        Discover tools from each server in parallel and build an index of namespaced tool names.

        A server whose tools cannot be loaded is logged and left out of the index.
        """
        if self.initialized:
            self._logger.debug("MCPAggregator already initialized.")
            return

        async with self._tool_map_lock:
            self._namespaced_tool_map.clear()
            self._server_to_tool_map.clear()

        for server_name in self.server_names:
            if self.connection_persistence:
                self._logger.info(
                    f"Creating persistent connection to server: {server_name}",
                    data={
                        "progress_action": ProgressAction.STARTING,
                        "server_name": server_name,
                        "agent_name": self.agent_name,
                    },
                )
                await self._persistent_connection_manager.get_server(
                    server_name, client_session_factory=MCPAgentClientSession
                )

            self._logger.info(
                f"MCP Servers initialized for agent '{self.agent_name}'",
                data={
                    "progress_action": ProgressAction.INITIALIZED,
                    "agent_name": self.agent_name,
                },
            )

        async def fetch_tools(client: ClientSession, server_name: str) -> List[Tool]:
            # server_name is passed explicitly. The previous closure read the
            # enclosing loop variable, so errors were logged under the last
            # server's name.
            try:
                result: ListToolsResult = await client.list_tools()
                return result.tools or []
            except Exception as e:
                self._logger.error(
                    f"Error loading tools from server '{server_name}'", data=e
                )
                return []

        async def load_server_tools(server_name: str):
            tools: List[Tool] = []
            if self.connection_persistence:
                server_connection = (
                    await self._persistent_connection_manager.get_server(
                        server_name, client_session_factory=MCPAgentClientSession
                    )
                )
                tools = await fetch_tools(server_connection.session, server_name)
            else:
                async with gen_client(
                    server_name, server_registry=self.context.server_registry
                ) as client:
                    tools = await fetch_tools(client, server_name)

            return server_name, tools

        # Gather tools from all servers concurrently
        results = await gather(
            *(load_server_tools(server_name) for server_name in self.server_names),
            return_exceptions=True,
        )

        # gather() preserves input order, so results line up with server_names.
        for requested_server, result in zip(self.server_names, results):
            if isinstance(result, BaseException):
                # Previously dropped silently; keep skipping, but make it visible.
                self._logger.error(
                    f"Error loading tools from server '{requested_server}'",
                    data=result,
                )
                continue
            server_name, tools = result

            self._server_to_tool_map[server_name] = []
            for tool in tools:
                namespaced_tool_name = f"{server_name}{SEP}{tool.name}"
                namespaced_tool = NamespacedTool(
                    tool=tool,
                    server_name=server_name,
                    namespaced_tool_name=namespaced_tool_name,
                )

                self._namespaced_tool_map[namespaced_tool_name] = namespaced_tool
                self._server_to_tool_map[server_name].append(namespaced_tool)
            self._logger.debug(
                "MCP Aggregator initialized",
                data={
                    "progress_action": ProgressAction.INITIALIZED,
                    "server_name": server_name,
                    "agent_name": self.agent_name,
                },
            )
        self.initialized = True

    async def list_servers(self) -> List[str]:
        """Return the list of server names aggregated by this agent."""
        if not self.initialized:
            await self.load_servers()

        return self.server_names

    async def list_tools(self) -> ListToolsResult:
        """
        :return: Tools from all servers aggregated, renamed to
            f"{server_name}{SEP}{tool_name}".
        """
        if not self.initialized:
            await self.load_servers()

        return ListToolsResult(
            tools=[
                namespaced_tool.tool.model_copy(update={"name": namespaced_tool_name})
                for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items()
            ]
        )

    @staticmethod
    def _error_result(message: str) -> CallToolResult:
        """Build an error CallToolResult that carries ``message`` as text content."""
        # CallToolResult declares `content` as a required field. The previous
        # `message=` keyword is not a field, so that construction failed
        # validation instead of producing an error result.
        from mcp.types import TextContent

        return CallToolResult(
            isError=True, content=[TextContent(type="text", text=message)]
        )

    def _resolve_tool_name(self, name: str) -> tuple[Optional[str], Optional[str]]:
        """
        Map a requested tool name to ``(server_name, local_tool_name)``.

        Returns ``(None, None)`` when the name cannot be resolved. Lookup order:
        1. Exact match in the namespaced index. This also handles server names
           that themselves contain SEP.
        2. An un-namespaced tool name. The first server, in discovery order,
           that exposes it wins.
        3. A split on the first SEP, for names not present in the index.
        """
        namespaced_tool = self._namespaced_tool_map.get(name)
        if namespaced_tool is not None:
            return namespaced_tool.server_name, namespaced_tool.tool.name

        for tools in self._server_to_tool_map.values():
            for candidate in tools:
                if candidate.tool.name == name:
                    # Returning exits both loops. A bare `break` only left the
                    # inner loop, so later servers overrode earlier matches.
                    return candidate.server_name, name

        if SEP in name:
            server_name, local_tool_name = name.split(SEP, 1)
            return server_name, local_tool_name

        return None, None

    async def call_tool(
        self, name: str, arguments: dict | None = None
    ) -> CallToolResult:
        """
        Call a tool by namespaced name (f"{server_name}{SEP}{tool_name}") or by
        its plain name.

        Failures (unknown tool, errors while calling) are returned as a
        CallToolResult with isError=True and a text explanation.
        """
        if not self.initialized:
            await self.load_servers()

        server_name, local_tool_name = self._resolve_tool_name(name)

        if server_name is None or local_tool_name is None:
            self._logger.error(f"Error: Tool '{name}' not found")
            return self._error_result(f"Tool '{name}' not found")

        self._logger.info(
            "Requesting tool call",
            data={
                "progress_action": ProgressAction.CALLING_TOOL,
                "tool_name": local_tool_name,
                "server_name": server_name,
                "agent_name": self.agent_name,
            },
        )

        async def try_call_tool(client: ClientSession) -> CallToolResult:
            try:
                return await client.call_tool(name=local_tool_name, arguments=arguments)
            except Exception as e:
                return self._error_result(
                    f"Failed to call tool '{local_tool_name}' on server '{server_name}': {e}"
                )

        if self.connection_persistence:
            server_connection = await self._persistent_connection_manager.get_server(
                server_name, client_session_factory=MCPAgentClientSession
            )
            return await try_call_tool(server_connection.session)
        else:
            self._logger.debug(
                f"Creating temporary connection to server: {server_name}",
                data={
                    "progress_action": ProgressAction.STARTING,
                    "server_name": server_name,
                    "agent_name": self.agent_name,
                },
            )
            async with gen_client(
                server_name, server_registry=self.context.server_registry
            ) as client:
                result = await try_call_tool(client)
                self._logger.debug(
                    f"Closing temporary connection to server: {server_name}",
                    data={
                        "progress_action": ProgressAction.SHUTDOWN,
                        "server_name": server_name,
                        "agent_name": self.agent_name,
                    },
                )
                return result
356
+
357
+
358
class MCPCompoundServer(Server):
    """
    A compound server (server-of-servers) that aggregates multiple MCP servers and is itself an MCP server

    Tool listing and tool calls are delegated to an internal MCPAggregator.
    """

    def __init__(self, server_names: List[str], name: str = "MCPCompoundServer"):
        super().__init__(name)
        self.aggregator = MCPAggregator(server_names)

        # Register handlers
        # TODO: saqadri - once we support resources and prompts, add handlers for those as well
        self.list_tools()(self._list_tools)
        self.call_tool()(self._call_tool)

    async def _list_tools(self) -> List[Tool]:
        """List all tools aggregated from connected MCP servers."""
        tools_result = await self.aggregator.list_tools()
        return tools_result.tools

    async def _call_tool(self, name: str, arguments: dict | None = None) -> list:
        """
        Call a specific tool from the aggregated servers.

        :return: The tool's content items on success.
        :raises RuntimeError: If the aggregated call reports isError. The
            low-level Server.call_tool handler turns a raised exception into a
            CallToolResult with isError=True. Returning a CallToolResult object
            here, as before, did not match the content sequence that handler
            expects, and error results were passed back as normal content.
        """
        result = await self.aggregator.call_tool(name=name, arguments=arguments)
        if result.isError:
            error_text = "\n".join(
                item.text
                for item in result.content
                if getattr(item, "type", None) == "text"
            )
            raise RuntimeError(
                f"Error calling tool: {error_text or f'{name} returned an error'}"
            )
        return result.content

    async def run_stdio_async(self) -> None:
        """Run the server using stdio transport."""
        async with stdio_server() as (read_stream, write_stream):
            await self.run(
                read_stream=read_stream,
                write_stream=write_stream,
                initialization_options=self.create_initialization_options(),
            )