fast-agent-mcp 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fast-agent-mcp might be problematic. Click here for more details.
- fast_agent_mcp-0.0.7.dist-info/METADATA +322 -0
- fast_agent_mcp-0.0.7.dist-info/RECORD +100 -0
- fast_agent_mcp-0.0.7.dist-info/WHEEL +4 -0
- fast_agent_mcp-0.0.7.dist-info/entry_points.txt +5 -0
- fast_agent_mcp-0.0.7.dist-info/licenses/LICENSE +201 -0
- mcp_agent/__init__.py +0 -0
- mcp_agent/agents/__init__.py +0 -0
- mcp_agent/agents/agent.py +277 -0
- mcp_agent/app.py +303 -0
- mcp_agent/cli/__init__.py +0 -0
- mcp_agent/cli/__main__.py +4 -0
- mcp_agent/cli/commands/bootstrap.py +221 -0
- mcp_agent/cli/commands/config.py +11 -0
- mcp_agent/cli/commands/setup.py +229 -0
- mcp_agent/cli/main.py +68 -0
- mcp_agent/cli/terminal.py +24 -0
- mcp_agent/config.py +334 -0
- mcp_agent/console.py +28 -0
- mcp_agent/context.py +251 -0
- mcp_agent/context_dependent.py +48 -0
- mcp_agent/core/fastagent.py +1013 -0
- mcp_agent/eval/__init__.py +0 -0
- mcp_agent/event_progress.py +88 -0
- mcp_agent/executor/__init__.py +0 -0
- mcp_agent/executor/decorator_registry.py +120 -0
- mcp_agent/executor/executor.py +293 -0
- mcp_agent/executor/task_registry.py +34 -0
- mcp_agent/executor/temporal.py +405 -0
- mcp_agent/executor/workflow.py +197 -0
- mcp_agent/executor/workflow_signal.py +325 -0
- mcp_agent/human_input/__init__.py +0 -0
- mcp_agent/human_input/handler.py +49 -0
- mcp_agent/human_input/types.py +58 -0
- mcp_agent/logging/__init__.py +0 -0
- mcp_agent/logging/events.py +123 -0
- mcp_agent/logging/json_serializer.py +163 -0
- mcp_agent/logging/listeners.py +216 -0
- mcp_agent/logging/logger.py +365 -0
- mcp_agent/logging/rich_progress.py +120 -0
- mcp_agent/logging/tracing.py +140 -0
- mcp_agent/logging/transport.py +461 -0
- mcp_agent/mcp/__init__.py +0 -0
- mcp_agent/mcp/gen_client.py +85 -0
- mcp_agent/mcp/mcp_activity.py +18 -0
- mcp_agent/mcp/mcp_agent_client_session.py +242 -0
- mcp_agent/mcp/mcp_agent_server.py +56 -0
- mcp_agent/mcp/mcp_aggregator.py +394 -0
- mcp_agent/mcp/mcp_connection_manager.py +330 -0
- mcp_agent/mcp/stdio.py +104 -0
- mcp_agent/mcp_server_registry.py +275 -0
- mcp_agent/progress_display.py +10 -0
- mcp_agent/resources/examples/decorator/main.py +26 -0
- mcp_agent/resources/examples/decorator/optimizer.py +78 -0
- mcp_agent/resources/examples/decorator/orchestrator.py +68 -0
- mcp_agent/resources/examples/decorator/parallel.py +81 -0
- mcp_agent/resources/examples/decorator/router.py +56 -0
- mcp_agent/resources/examples/decorator/tiny.py +22 -0
- mcp_agent/resources/examples/mcp_researcher/main-evalopt.py +53 -0
- mcp_agent/resources/examples/mcp_researcher/main.py +38 -0
- mcp_agent/telemetry/__init__.py +0 -0
- mcp_agent/telemetry/usage_tracking.py +18 -0
- mcp_agent/workflows/__init__.py +0 -0
- mcp_agent/workflows/embedding/__init__.py +0 -0
- mcp_agent/workflows/embedding/embedding_base.py +61 -0
- mcp_agent/workflows/embedding/embedding_cohere.py +49 -0
- mcp_agent/workflows/embedding/embedding_openai.py +46 -0
- mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +359 -0
- mcp_agent/workflows/intent_classifier/__init__.py +0 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +120 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +134 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +45 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +45 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +161 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +60 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +60 -0
- mcp_agent/workflows/llm/__init__.py +0 -0
- mcp_agent/workflows/llm/augmented_llm.py +645 -0
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +539 -0
- mcp_agent/workflows/llm/augmented_llm_openai.py +615 -0
- mcp_agent/workflows/llm/llm_selector.py +345 -0
- mcp_agent/workflows/llm/model_factory.py +175 -0
- mcp_agent/workflows/orchestrator/__init__.py +0 -0
- mcp_agent/workflows/orchestrator/orchestrator.py +407 -0
- mcp_agent/workflows/orchestrator/orchestrator_models.py +154 -0
- mcp_agent/workflows/orchestrator/orchestrator_prompts.py +113 -0
- mcp_agent/workflows/parallel/__init__.py +0 -0
- mcp_agent/workflows/parallel/fan_in.py +350 -0
- mcp_agent/workflows/parallel/fan_out.py +187 -0
- mcp_agent/workflows/parallel/parallel_llm.py +141 -0
- mcp_agent/workflows/router/__init__.py +0 -0
- mcp_agent/workflows/router/router_base.py +276 -0
- mcp_agent/workflows/router/router_embedding.py +240 -0
- mcp_agent/workflows/router/router_embedding_cohere.py +59 -0
- mcp_agent/workflows/router/router_embedding_openai.py +59 -0
- mcp_agent/workflows/router/router_llm.py +301 -0
- mcp_agent/workflows/swarm/__init__.py +0 -0
- mcp_agent/workflows/swarm/swarm.py +320 -0
- mcp_agent/workflows/swarm/swarm_anthropic.py +42 -0
- mcp_agent/workflows/swarm/swarm_openai.py +41 -0
|
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""
|
|
2
|
+
A derived client session for the MCP Agent framework.
|
|
3
|
+
It adds logging and supports sampling requests.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from mcp import ClientSession
|
|
9
|
+
from mcp.shared.session import (
|
|
10
|
+
RequestResponder,
|
|
11
|
+
ReceiveResultT,
|
|
12
|
+
ReceiveNotificationT,
|
|
13
|
+
RequestId,
|
|
14
|
+
SendNotificationT,
|
|
15
|
+
SendRequestT,
|
|
16
|
+
SendResultT,
|
|
17
|
+
)
|
|
18
|
+
from mcp.types import (
|
|
19
|
+
ClientResult,
|
|
20
|
+
CreateMessageRequest,
|
|
21
|
+
CreateMessageResult,
|
|
22
|
+
ErrorData,
|
|
23
|
+
JSONRPCNotification,
|
|
24
|
+
JSONRPCRequest,
|
|
25
|
+
ServerRequest,
|
|
26
|
+
TextContent,
|
|
27
|
+
ListRootsRequest,
|
|
28
|
+
ListRootsResult,
|
|
29
|
+
Root,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
from mcp_agent.config import MCPServerSettings
|
|
33
|
+
from mcp_agent.context_dependent import ContextDependent
|
|
34
|
+
from mcp_agent.logging.logger import get_logger
|
|
35
|
+
|
|
36
|
+
# Module-level logger from the framework's logging subsystem.
logger = get_logger(__name__)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MCPAgentClientSession(ClientSession, ContextDependent):
    """
    MCP Agent framework acts as a client to the servers providing tools/resources/prompts for the agent workloads.
    This is a simple client session for those server connections, and supports
        - handling sampling requests
        - notifications
        - MCP root configuration

    Developers can extend this class to add more custom functionality as needed
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Populated externally (e.g. by the connection manager) once the
        # server this session talks to is known; stays None otherwise.
        self.server_config: Optional[MCPServerSettings] = None

    async def _received_request(
        self, responder: RequestResponder[ServerRequest, ClientResult]
    ) -> None:
        """Intercept server-initiated requests.

        Handles sampling (CreateMessageRequest) and roots listing
        (ListRootsRequest) locally; everything else is delegated to the
        base ClientSession implementation.
        """
        logger.debug("Received request:", data=responder.request.model_dump())
        request = responder.request.root

        if isinstance(request, CreateMessageRequest):
            return await self.handle_sampling_request(request, responder)
        elif isinstance(request, ListRootsRequest):
            # Handle list_roots request by returning configured roots.
            # Fix: server_config is initialized to None in __init__, so a bare
            # hasattr() check is not enough — a None value would raise
            # AttributeError on `.roots`. Guard against None explicitly.
            if getattr(self, "server_config", None) and self.server_config.roots:
                roots = [
                    Root(
                        uri=root.server_uri_alias or root.uri,
                        name=root.name,
                    )
                    for root in self.server_config.roots
                ]

                await responder.respond(ListRootsResult(roots=roots))
            else:
                # No roots configured — respond with an empty list rather
                # than leaving the request unanswered.
                await responder.respond(ListRootsResult(roots=[]))
            return

        # Handle other requests as usual
        await super()._received_request(responder)

    async def send_request(
        self,
        request: SendRequestT,
        result_type: type[ReceiveResultT],
    ) -> ReceiveResultT:
        """Send a request and log both the outgoing payload and the response.

        Re-raises any failure after logging it.
        """
        logger.debug("send_request: request=", data=request.model_dump())
        try:
            result = await super().send_request(request, result_type)
            logger.debug("send_request: response=", data=result.model_dump())
            return result
        except Exception as e:
            logger.error(f"send_request failed: {e}")
            raise

    async def send_notification(self, notification: SendNotificationT) -> None:
        """Send a notification, logging it first; failures are logged and re-raised."""
        logger.debug("send_notification:", data=notification.model_dump())
        try:
            return await super().send_notification(notification)
        except Exception as e:
            logger.error("send_notification failed", data=e)
            raise

    async def _send_response(
        self, request_id: RequestId, response: SendResultT | ErrorData
    ) -> None:
        """Log and forward a response (or error) for a previously received request."""
        logger.debug(
            f"send_response: request_id={request_id}, response=",
            data=response.model_dump(),
        )
        return await super()._send_response(request_id, response)

    async def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.
        """
        logger.info(
            "_received_notification: notification=",
            data=notification.model_dump(),
        )
        return await super()._received_notification(notification)

    async def send_progress_notification(
        self, progress_token: str | int, progress: float, total: float | None = None
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed.
        """
        # Fix: the log message was a plain string with literal {placeholders};
        # it needs the f-prefix to interpolate the values.
        logger.debug(
            f"send_progress_notification: progress_token={progress_token}, progress={progress}, total={total}"
        )
        return await super().send_progress_notification(
            progress_token=progress_token, progress=progress, total=total
        )

    async def _receive_loop(self) -> None:
        """Read messages off the transport and dispatch them.

        Requests go through _received_request (and, if unanswered, onto the
        incoming-message stream); notifications go through
        _received_notification; responses are routed to the stream that is
        awaiting the matching request id.
        """
        async with (
            self._read_stream,
            self._write_stream,
            self._incoming_message_stream_writer,
        ):
            async for message in self._read_stream:
                if isinstance(message, Exception):
                    # Transport-level error: surface it to consumers.
                    await self._incoming_message_stream_writer.send(message)
                elif isinstance(message.root, JSONRPCRequest):
                    validated_request = self._receive_request_type.model_validate(
                        message.root.model_dump(
                            by_alias=True, mode="json", exclude_none=True
                        )
                    )
                    responder = RequestResponder(
                        request_id=message.root.id,
                        request_meta=validated_request.root.params.meta
                        if validated_request.root.params
                        else None,
                        request=validated_request,
                        session=self,
                    )

                    await self._received_request(responder)
                    if not responder._responded:
                        # Nothing handled it here — hand it to listeners.
                        await self._incoming_message_stream_writer.send(responder)
                elif isinstance(message.root, JSONRPCNotification):
                    notification = self._receive_notification_type.model_validate(
                        message.root.model_dump(
                            by_alias=True, mode="json", exclude_none=True
                        )
                    )

                    await self._received_notification(notification)
                    await self._incoming_message_stream_writer.send(notification)
                else:  # Response or error
                    stream = self._response_streams.pop(message.root.id, None)
                    if stream:
                        await stream.send(message.root)
                    else:
                        await self._incoming_message_stream_writer.send(
                            RuntimeError(
                                "Received response with an unknown "
                                f"request ID: {message}"
                            )
                        )

    async def handle_sampling_request(
        self,
        request: CreateMessageRequest,
        responder: RequestResponder[ServerRequest, ClientResult],
    ):
        """Answer a server's sampling request.

        If an upstream session exists, the request is passed through to it
        unchanged. Otherwise the request is fulfilled locally with the
        Anthropic API using the configured API key.
        """
        logger.info("Handling sampling request: %s", request)
        config = self.context.config
        session = self.context.upstream_session
        if session is None:
            # TODO: saqadri - consider whether we should be handling the sampling request here as a client
            logger.warning(
                "Error: No upstream client available for sampling requests. Request:",
                data=request,
            )
            try:
                from anthropic import AsyncAnthropic

                client = AsyncAnthropic(api_key=config.anthropic.api_key)

                params = request.params
                # NOTE(review): model name is hard-coded here and below —
                # consider sourcing it from params.modelPreferences or config.
                response = await client.messages.create(
                    model="claude-3-sonnet-20240229",
                    max_tokens=params.maxTokens,
                    messages=[
                        {
                            "role": m.role,
                            "content": m.content.text
                            if hasattr(m.content, "text")
                            else m.content.data,
                        }
                        for m in params.messages
                    ],
                    system=getattr(params, "systemPrompt", None),
                    temperature=getattr(params, "temperature", 0.7),
                    stop_sequences=getattr(params, "stopSequences", None),
                )

                await responder.respond(
                    CreateMessageResult(
                        model="claude-3-sonnet-20240229",
                        role="assistant",
                        content=TextContent(type="text", text=response.content[0].text),
                    )
                )
            except Exception as e:
                logger.error(f"Error handling sampling request: {e}")
                await responder.respond(ErrorData(code=-32603, message=str(e)))
        else:
            try:
                # If a session is available, we'll pass-through the sampling request to the upstream client
                result = await session.send_request(
                    request=ServerRequest(request), result_type=CreateMessageResult
                )

                # Pass the result from the upstream client back to the server. We just act as a pass-through client here.
                await responder.send_result(result)
            except Exception as e:
                await responder.send_error(code=-32603, message=str(e))
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from mcp.server import NotificationOptions
|
|
3
|
+
from mcp.server.fastmcp import FastMCP
|
|
4
|
+
from mcp.server.stdio import stdio_server
|
|
5
|
+
from mcp_agent.executor.temporal import get_temporal_client
|
|
6
|
+
from mcp_agent.telemetry.tracing import setup_tracing
|
|
7
|
+
|
|
8
|
+
# Module-level FastMCP server instance that the tool decorators below attach to.
app = FastMCP("mcp-agent-server")

# Initialize tracing for this server process.
# NOTE(review): this is imported from mcp_agent.telemetry.tracing, but the
# package listing ships mcp_agent/logging/tracing.py and only usage_tracking
# under telemetry — verify the import path.
setup_tracing("mcp-agent-server")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
async def run():
    """Serve the FastMCP app over a stdio transport until the streams close."""
    async with stdio_server() as (read_stream, write_stream):
        # Advertise tool/resource change notifications in the init handshake.
        init_options = app._mcp_server.create_initialization_options(
            notification_options=NotificationOptions(
                tools_changed=True, resources_changed=True
            )
        )
        await app._mcp_server.run(read_stream, write_stream, init_options)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Fix: FastMCP's `tool` is a decorator *factory* and must be called —
# `@app.tool` (uncalled) passes the function as the `name` argument and the
# SDK raises a TypeError at import time.
@app.tool()
async def run_workflow(query: str):
    """Run the workflow given its name or id"""
    # TODO: not implemented yet — intentionally a no-op.
    pass
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Fix: decorator factory must be called — `@app.tool` (uncalled) raises a
# TypeError in the MCP SDK; use `@app.tool()`.
@app.tool()
async def pause_workflow(workflow_id: str):
    """Pause a running workflow.

    Sends a "pause" signal to the Temporal workflow identified by workflow_id.
    """
    temporal_client = await get_temporal_client()
    handle = temporal_client.get_workflow_handle(workflow_id)
    await handle.signal("pause")
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Fix: decorator factory must be called — `@app.tool` (uncalled) raises a
# TypeError in the MCP SDK; use `@app.tool()`.
@app.tool()
async def resume_workflow(workflow_id: str):
    """Resume a paused workflow.

    Sends a "resume" signal to the Temporal workflow identified by workflow_id.
    """
    temporal_client = await get_temporal_client()
    handle = temporal_client.get_workflow_handle(workflow_id)
    await handle.signal("resume")
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
async def provide_user_input(workflow_id: str, input_data: str):
    """Provide user/human input to a waiting workflow step.

    Delivers the input via a "human_input" signal to the Temporal workflow.
    NOTE(review): unlike its siblings, this is not registered as an app tool —
    confirm whether that is intentional.
    """
    client = await get_temporal_client()
    workflow_handle = client.get_workflow_handle(workflow_id)
    await workflow_handle.signal("human_input", input_data)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
if __name__ == "__main__":
    # Script entry point: run the stdio MCP server until the stream closes.
    asyncio.run(run())
|
|
@@ -0,0 +1,394 @@
|
|
|
1
|
+
from asyncio import Lock, gather
|
|
2
|
+
from typing import List, Dict, Optional, TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, ConfigDict
|
|
5
|
+
from mcp.client.session import ClientSession
|
|
6
|
+
from mcp.server.lowlevel.server import Server
|
|
7
|
+
from mcp.server.stdio import stdio_server
|
|
8
|
+
from mcp.types import (
|
|
9
|
+
CallToolResult,
|
|
10
|
+
ListToolsResult,
|
|
11
|
+
Tool,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from mcp_agent.event_progress import ProgressAction
|
|
15
|
+
from mcp_agent.logging.logger import get_logger
|
|
16
|
+
from mcp_agent.mcp.gen_client import gen_client
|
|
17
|
+
|
|
18
|
+
from mcp_agent.context_dependent import ContextDependent
|
|
19
|
+
from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession
|
|
20
|
+
from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from mcp_agent.context import Context
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Module-level logger; MCPAggregator.__init__ rebinds this global to include
# the agent name in the namespace.
logger = get_logger(
    __name__
)  # This will be replaced per-instance when agent_name is available

# Separator between server name and tool name in namespaced tool names.
SEP = "-"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class NamespacedTool(BaseModel):
    """
    A tool that is namespaced by server name.
    """

    # The original tool definition as reported by its server.
    tool: Tool
    # Name of the server that provides this tool.
    server_name: str
    # "<server_name>-<tool_name>" — unique across all aggregated servers.
    namespaced_tool_name: str
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class MCPAggregator(ContextDependent):
    """
    Aggregates multiple MCP servers. When a developer calls, e.g. call_tool(...),
    the aggregator searches all servers in its list for a server that provides that tool.
    """

    # Whether the aggregator has been initialized with tools and resources from all servers.
    initialized: bool = False

    # Whether to maintain a persistent connection to the server.
    connection_persistence: bool = False

    # A list of server names to connect to.
    server_names: List[str]

    # NOTE(review): ConfigDict is a pydantic construct but this class does not
    # inherit from BaseModel — verify it has any effect here.
    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

    async def __aenter__(self):
        if self.initialized:
            return self

        # Keep a connection manager to manage persistent connections for this aggregator
        if self.connection_persistence:
            # Try to get existing connection manager from context; create and
            # enter one lazily so aggregators on the same context share it.
            if not hasattr(self.context, "_connection_manager"):
                self.context._connection_manager = MCPConnectionManager(
                    self.context.server_registry
                )
                await self.context._connection_manager.__aenter__()
            self._persistent_connection_manager = self.context._connection_manager

        await self.load_servers()

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    def __init__(
        self,
        server_names: List[str],
        connection_persistence: bool = False,
        context: Optional["Context"] = None,
        name: str = None,
        **kwargs,
    ):
        """
        :param server_names: A list of server names to connect to.
        :param connection_persistence: Keep server connections open across calls.
        :param context: Optional framework context; resolved lazily otherwise.
        :param name: Optional agent name, used to namespace the logger.
        Note: The server names must be resolvable by the gen_client function, and specified in the server registry.
        """
        super().__init__(
            context=context,
            **kwargs,
        )

        self.server_names = server_names
        self.connection_persistence = connection_persistence
        self.agent_name = name
        self._persistent_connection_manager: MCPConnectionManager = None

        # Set up logger with agent name in namespace if available.
        # NOTE(review): this rebinds the *module-level* logger, so the last
        # constructed aggregator wins for all instances — kept as-is to
        # preserve behavior, but worth revisiting.
        global logger
        logger_name = f"{__name__}.{name}" if name else __name__
        logger = get_logger(logger_name)

        # Maps namespaced_tool_name -> namespaced tool info
        self._namespaced_tool_map: Dict[str, NamespacedTool] = {}
        # Maps server_name -> list of tools
        self._server_to_tool_map: Dict[str, List[NamespacedTool]] = {}
        self._tool_map_lock = Lock()

        # TODO: saqadri - add resources and prompt maps as well

    async def close(self):
        """
        Close all persistent connections when the aggregator is deleted.
        """
        if self.connection_persistence and self._persistent_connection_manager:
            try:
                # Only attempt cleanup if we own the connection manager
                if (
                    hasattr(self.context, "_connection_manager")
                    and self.context._connection_manager
                    == self._persistent_connection_manager
                ):
                    logger.info("Shutting down all persistent connections...")
                    await self._persistent_connection_manager.disconnect_all()
                    await self._persistent_connection_manager.__aexit__(
                        None, None, None
                    )
                    delattr(self.context, "_connection_manager")
                self.initialized = False
            except Exception as e:
                # Best-effort cleanup: log but do not propagate.
                logger.error(f"Error during connection manager cleanup: {e}")

    @classmethod
    async def create(
        cls,
        server_names: List[str],
        connection_persistence: bool = False,
    ) -> "MCPAggregator":
        """
        Factory method to create and initialize an MCPAggregator.
        Use this instead of constructor since we need async initialization.
        If connection_persistence is True, the aggregator will maintain a
        persistent connection to the servers for as long as this aggregator is around.
        By default we do not maintain a persistent connection.

        :raises Exception: re-raises any initialization failure after cleanup.
        """

        logger.info(f"Creating MCPAggregator with servers: {server_names}")

        instance = cls(
            server_names=server_names,
            connection_persistence=connection_persistence,
        )

        try:
            await instance.__aenter__()

            logger.debug("Loading servers...")
            await instance.load_servers()

            logger.debug("MCPAggregator created and initialized.")
            return instance
        except Exception as e:
            logger.error(f"Error creating MCPAggregator: {e}")
            await instance.__aexit__(None, None, None)
            # Fix: previously the error was swallowed and the method fell
            # through, implicitly returning None to the caller. Re-raise so
            # callers never receive a half-initialized (None) aggregator.
            raise

    async def load_servers(self):
        """
        Discover tools from each server in parallel and build an index of namespaced tool names.
        """
        if self.initialized:
            logger.debug("MCPAggregator already initialized.")
            return

        async with self._tool_map_lock:
            self._namespaced_tool_map.clear()
            self._server_to_tool_map.clear()

        for server_name in self.server_names:
            if self.connection_persistence:
                logger.info(
                    f"Creating persistent connection to server: {server_name}",
                    data={
                        "progress_action": ProgressAction.STARTING,
                        "server_name": server_name,
                        "agent_name": self.agent_name,
                    },
                )
                await self._persistent_connection_manager.get_server(
                    server_name, client_session_factory=MCPAgentClientSession
                )

        logger.info(
            f"MCP Servers initialized for agent '{self.agent_name}'",
            data={
                "progress_action": ProgressAction.INITIALIZED,
                "agent_name": self.agent_name,
            },
        )

        async def fetch_tools(client: ClientSession, server_name: str):
            # Fix: server_name is now an explicit parameter. The previous
            # closure captured the loop variable from the earlier for-loop,
            # so error logs named the *last* server instead of the one that
            # actually failed.
            try:
                result: ListToolsResult = await client.list_tools()
                return result.tools or []
            except Exception as e:
                logger.error(f"Error loading tools from server '{server_name}'", data=e)
                return []

        async def load_server_tools(server_name: str):
            # Fetch the tool list for one server, using either the shared
            # persistent connection or a temporary one.
            tools: List[Tool] = []
            if self.connection_persistence:
                server_connection = (
                    await self._persistent_connection_manager.get_server(
                        server_name, client_session_factory=MCPAgentClientSession
                    )
                )
                tools = await fetch_tools(server_connection.session, server_name)
            else:
                async with gen_client(
                    server_name, server_registry=self.context.server_registry
                ) as client:
                    tools = await fetch_tools(client, server_name)

            return server_name, tools

        # Gather tools from all servers concurrently
        results = await gather(
            *(load_server_tools(server_name) for server_name in self.server_names),
            return_exceptions=True,
        )

        for result in results:
            if isinstance(result, BaseException):
                continue
            server_name, tools = result

            self._server_to_tool_map[server_name] = []
            for tool in tools:
                namespaced_tool_name = f"{server_name}{SEP}{tool.name}"
                namespaced_tool = NamespacedTool(
                    tool=tool,
                    server_name=server_name,
                    namespaced_tool_name=namespaced_tool_name,
                )

                self._namespaced_tool_map[namespaced_tool_name] = namespaced_tool
                self._server_to_tool_map[server_name].append(namespaced_tool)
            logger.debug(
                "MCP Aggregator initialized",
                data={
                    "progress_action": ProgressAction.INITIALIZED,
                    "server_name": server_name,
                    "agent_name": self.agent_name,
                },
            )
        self.initialized = True

    async def list_servers(self) -> List[str]:
        """Return the list of server names aggregated by this agent."""
        if not self.initialized:
            await self.load_servers()

        return self.server_names

    async def list_tools(self) -> ListToolsResult:
        """
        :return: Tools from all servers aggregated, and renamed to be dot-namespaced by server name.
        """
        if not self.initialized:
            await self.load_servers()

        return ListToolsResult(
            tools=[
                namespaced_tool.tool.model_copy(update={"name": namespaced_tool_name})
                for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items()
            ]
        )

    async def call_tool(
        self, name: str, arguments: dict | None = None
    ) -> CallToolResult:
        """
        Call a namespaced tool, e.g., 'server_name.tool_name'.

        :param name: Namespaced ("server-tool") or bare tool name.
        :param arguments: Optional arguments to forward to the tool.
        :return: The tool result, or an error result if the tool is unknown.
        """
        if not self.initialized:
            await self.load_servers()

        server_name: str = None
        local_tool_name: str = None

        if SEP in name:  # Namespaced tool name
            server_name, local_tool_name = name.split(SEP, 1)
        else:
            # Assume un-namespaced, loop through all servers to find the tool. First match wins.
            for _, tools in self._server_to_tool_map.items():
                for namespaced_tool in tools:
                    if namespaced_tool.tool.name == name:
                        server_name = namespaced_tool.server_name
                        local_tool_name = name
                        break
                # Fix: the inner `break` only exited the inner loop; the outer
                # loop kept scanning remaining servers. Stop once matched.
                if server_name is not None:
                    break

        if server_name is None or local_tool_name is None:
            logger.error(f"Error: Tool '{name}' not found")
            return CallToolResult(isError=True, message=f"Tool '{name}' not found")

        logger.info(
            "Requesting tool call",
            data={
                "progress_action": ProgressAction.CALLING_TOOL,
                "tool_name": local_tool_name,
                "server_name": server_name,
                "agent_name": self.agent_name,
            },
        )

        async def try_call_tool(client: ClientSession):
            # Wrap the call so transport failures surface as error results
            # instead of exceptions.
            try:
                return await client.call_tool(name=local_tool_name, arguments=arguments)
            except Exception as e:
                return CallToolResult(
                    isError=True,
                    message=f"Failed to call tool '{local_tool_name}' on server '{server_name}': {e}",
                )

        if self.connection_persistence:
            server_connection = await self._persistent_connection_manager.get_server(
                server_name, client_session_factory=MCPAgentClientSession
            )
            return await try_call_tool(server_connection.session)
        else:
            logger.debug(
                f"Creating temporary connection to server: {server_name}",
                data={
                    "progress_action": ProgressAction.STARTING,
                    "server_name": server_name,
                    "agent_name": self.agent_name,
                },
            )
            async with gen_client(
                server_name, server_registry=self.context.server_registry
            ) as client:
                result = await try_call_tool(client)
                logger.debug(
                    f"Closing temporary connection to server: {server_name}",
                    data={
                        "progress_action": ProgressAction.SHUTDOWN,
                        "server_name": server_name,
                        "agent_name": self.agent_name,
                    },
                )
                return result
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
class MCPCompoundServer(Server):
    """
    A compound server (server-of-servers) that aggregates multiple MCP servers and is itself an MCP server
    """

    def __init__(self, server_names: List[str], name: str = "MCPCompoundServer"):
        super().__init__(name)
        # Aggregator that fans tool calls out to the underlying servers.
        self.aggregator = MCPAggregator(server_names)

        # Register handlers
        # TODO: saqadri - once we support resources and prompts, add handlers for those as well
        self.list_tools()(self._list_tools)
        self.call_tool()(self._call_tool)

    async def _list_tools(self) -> List[Tool]:
        """List all tools aggregated from connected MCP servers."""
        tools_result = await self.aggregator.list_tools()
        return tools_result.tools

    async def _call_tool(
        self, name: str, arguments: dict | None = None
    ) -> CallToolResult:
        """Call a specific tool from the aggregated servers.

        NOTE(review): the success path returns `result.content` (a content
        list) while the error path returns a full CallToolResult, and the
        annotation says CallToolResult — verify which shape the lowlevel
        Server handler actually expects.
        """
        try:
            result = await self.aggregator.call_tool(name=name, arguments=arguments)
            return result.content
        except Exception as e:
            return CallToolResult(isError=True, message=f"Error calling tool: {e}")

    async def run_stdio_async(self) -> None:
        """Run the server using stdio transport."""
        async with stdio_server() as (read_stream, write_stream):
            await self.run(
                read_stream=read_stream,
                write_stream=write_stream,
                initialization_options=self.create_initialization_options(),
            )
|