swarms 7.6.4__py3-none-any.whl → 7.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- swarms/structs/__init__.py +1 -3
- swarms/structs/agent.py +77 -0
- swarms/tools/mcp_integration.py +321 -483
- swarms/utils/vllm_wrapper.py +146 -0
- {swarms-7.6.4.dist-info → swarms-7.6.5.dist-info}/METADATA +1 -1
- {swarms-7.6.4.dist-info → swarms-7.6.5.dist-info}/RECORD +9 -9
- swarms/structs/auto_swarm.py +0 -229
- {swarms-7.6.4.dist-info → swarms-7.6.5.dist-info}/LICENSE +0 -0
- {swarms-7.6.4.dist-info → swarms-7.6.5.dist-info}/WHEEL +0 -0
- {swarms-7.6.4.dist-info → swarms-7.6.5.dist-info}/entry_points.txt +0 -0
swarms/tools/mcp_integration.py
CHANGED
@@ -1,554 +1,392 @@
|
|
1
|
-
from
|
2
|
-
from types import TracebackType
|
3
|
-
from typing import (
|
4
|
-
Any,
|
5
|
-
Callable,
|
6
|
-
Coroutine,
|
7
|
-
List,
|
8
|
-
Literal,
|
9
|
-
Optional,
|
10
|
-
TypedDict,
|
11
|
-
cast,
|
12
|
-
)
|
1
|
+
from __future__ import annotations
|
13
2
|
|
14
|
-
from
|
15
|
-
|
16
|
-
|
17
|
-
from
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
3
|
+
from typing import Any, List
|
4
|
+
|
5
|
+
|
6
|
+
from loguru import logger
|
7
|
+
|
8
|
+
import abc
|
9
|
+
import asyncio
|
10
|
+
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
11
|
+
from pathlib import Path
|
12
|
+
from typing import Literal
|
13
|
+
|
14
|
+
from anyio.streams.memory import (
|
15
|
+
MemoryObjectReceiveStream,
|
16
|
+
MemoryObjectSendStream,
|
23
17
|
)
|
24
|
-
from mcp
|
18
|
+
from mcp import (
|
19
|
+
ClientSession,
|
20
|
+
StdioServerParameters,
|
25
21
|
Tool as MCPTool,
|
22
|
+
stdio_client,
|
26
23
|
)
|
24
|
+
from mcp.client.sse import sse_client
|
25
|
+
from mcp.types import CallToolResult, JSONRPCMessage
|
26
|
+
from typing_extensions import NotRequired, TypedDict
|
27
27
|
|
28
|
+
from swarms.utils.any_to_str import any_to_str
|
28
29
|
|
29
|
-
def convert_mcp_prompt_message_to_message(
|
30
|
-
message: PromptMessage,
|
31
|
-
) -> str:
|
32
|
-
"""Convert an MCP prompt message to a string message.
|
33
30
|
|
34
|
-
|
35
|
-
|
31
|
+
class MCPServer(abc.ABC):
|
32
|
+
"""Base class for Model Context Protocol servers."""
|
36
33
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
34
|
+
@abc.abstractmethod
|
35
|
+
async def connect(self):
|
36
|
+
"""Connect to the server. For example, this might mean spawning a subprocess or
|
37
|
+
opening a network connection. The server is expected to remain connected until
|
38
|
+
`cleanup()` is called.
|
39
|
+
"""
|
40
|
+
pass
|
41
|
+
|
42
|
+
@property
|
43
|
+
@abc.abstractmethod
|
44
|
+
def name(self) -> str:
|
45
|
+
"""A readable name for the server."""
|
46
|
+
pass
|
47
|
+
|
48
|
+
@abc.abstractmethod
|
49
|
+
async def cleanup(self):
|
50
|
+
"""Cleanup the server. For example, this might mean closing a subprocess or
|
51
|
+
closing a network connection.
|
52
|
+
"""
|
53
|
+
pass
|
51
54
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
+
@abc.abstractmethod
|
56
|
+
async def list_tools(self) -> list[MCPTool]:
|
57
|
+
"""List the tools available on the server."""
|
58
|
+
pass
|
55
59
|
|
60
|
+
@abc.abstractmethod
|
61
|
+
async def call_tool(
|
62
|
+
self, tool_name: str, arguments: dict[str, Any] | None
|
63
|
+
) -> CallToolResult:
|
64
|
+
"""Invoke a tool on the server."""
|
65
|
+
pass
|
56
66
|
|
57
|
-
async def load_mcp_prompt(
|
58
|
-
session: ClientSession,
|
59
|
-
name: str,
|
60
|
-
arguments: Optional[dict[str, Any]] = None,
|
61
|
-
) -> List[str]:
|
62
|
-
"""Load MCP prompt and convert to messages."""
|
63
|
-
response = await session.get_prompt(name, arguments)
|
64
67
|
|
65
|
-
|
66
|
-
|
67
|
-
for message in response.messages
|
68
|
-
]
|
68
|
+
class _MCPServerWithClientSession(MCPServer, abc.ABC):
|
69
|
+
"""Base class for MCP servers that use a `ClientSession` to communicate with the server."""
|
69
70
|
|
71
|
+
def __init__(self, cache_tools_list: bool):
|
72
|
+
"""
|
73
|
+
Args:
|
74
|
+
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
75
|
+
cached and only fetched from the server once. If `False`, the tools list will be
|
76
|
+
fetched from the server on each call to `list_tools()`. The cache can be invalidated
|
77
|
+
by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
|
78
|
+
server will not change its tools list, because it can drastically improve latency
|
79
|
+
(by avoiding a round-trip to the server every time).
|
80
|
+
"""
|
81
|
+
self.session: ClientSession | None = None
|
82
|
+
self.exit_stack: AsyncExitStack = AsyncExitStack()
|
83
|
+
self._cleanup_lock: asyncio.Lock = asyncio.Lock()
|
84
|
+
self.cache_tools_list = cache_tools_list
|
70
85
|
|
71
|
-
|
72
|
-
|
86
|
+
# The cache is always dirty at startup, so that we fetch tools at least once
|
87
|
+
self._cache_dirty = True
|
88
|
+
self._tools_list: list[MCPTool] | None = None
|
73
89
|
|
74
|
-
|
75
|
-
|
90
|
+
@abc.abstractmethod
|
91
|
+
def create_streams(
|
92
|
+
self,
|
93
|
+
) -> AbstractAsyncContextManager[
|
94
|
+
tuple[
|
95
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
96
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
97
|
+
]
|
98
|
+
]:
|
99
|
+
"""Create the streams for the server."""
|
100
|
+
pass
|
101
|
+
|
102
|
+
async def __aenter__(self):
|
103
|
+
await self.connect()
|
104
|
+
return self
|
105
|
+
|
106
|
+
async def __aexit__(self, exc_type, exc_value, traceback):
|
107
|
+
await self.cleanup()
|
108
|
+
|
109
|
+
def invalidate_tools_cache(self):
|
110
|
+
"""Invalidate the tools cache."""
|
111
|
+
self._cache_dirty = True
|
112
|
+
|
113
|
+
async def connect(self):
|
114
|
+
"""Connect to the server."""
|
115
|
+
try:
|
116
|
+
transport = await self.exit_stack.enter_async_context(
|
117
|
+
self.create_streams()
|
118
|
+
)
|
119
|
+
read, write = transport
|
120
|
+
session = await self.exit_stack.enter_async_context(
|
121
|
+
ClientSession(read, write)
|
122
|
+
)
|
123
|
+
await session.initialize()
|
124
|
+
self.session = session
|
125
|
+
except Exception as e:
|
126
|
+
logger.error(f"Error initializing MCP server: {e}")
|
127
|
+
await self.cleanup()
|
128
|
+
raise
|
76
129
|
|
130
|
+
async def list_tools(self) -> list[MCPTool]:
|
131
|
+
"""List the tools available on the server."""
|
132
|
+
if not self.session:
|
133
|
+
raise Exception(
|
134
|
+
"Server not initialized. Make sure you call `connect()` first."
|
135
|
+
)
|
77
136
|
|
78
|
-
|
79
|
-
|
137
|
+
# Return from cache if caching is enabled, we have tools, and the cache is not dirty
|
138
|
+
if (
|
139
|
+
self.cache_tools_list
|
140
|
+
and not self._cache_dirty
|
141
|
+
and self._tools_list
|
142
|
+
):
|
143
|
+
return self._tools_list
|
80
144
|
|
81
|
-
|
82
|
-
|
145
|
+
# Reset the cache dirty to False
|
146
|
+
self._cache_dirty = False
|
83
147
|
|
84
|
-
|
85
|
-
|
148
|
+
# Fetch the tools from the server
|
149
|
+
self._tools_list = (await self.session.list_tools()).tools
|
150
|
+
return self._tools_list
|
86
151
|
|
87
|
-
|
88
|
-
|
152
|
+
async def call_tool(
|
153
|
+
self, arguments: dict[str, Any] | None
|
154
|
+
) -> CallToolResult:
|
155
|
+
"""Invoke a tool on the server."""
|
156
|
+
tool_name = arguments.get("tool_name") or arguments.get(
|
157
|
+
"name"
|
158
|
+
)
|
89
159
|
|
90
|
-
|
91
|
-
|
160
|
+
if not tool_name:
|
161
|
+
raise Exception("No tool name found in arguments")
|
92
162
|
|
93
|
-
|
94
|
-
|
95
|
-
|
163
|
+
if not self.session:
|
164
|
+
raise Exception(
|
165
|
+
"Server not initialized. Make sure you call `connect()` first."
|
166
|
+
)
|
96
167
|
|
97
|
-
|
98
|
-
explanations of possible values
|
99
|
-
"""
|
168
|
+
return await self.session.call_tool(tool_name, arguments)
|
100
169
|
|
170
|
+
async def cleanup(self):
|
171
|
+
"""Cleanup the server."""
|
172
|
+
async with self._cleanup_lock:
|
173
|
+
try:
|
174
|
+
await self.exit_stack.aclose()
|
175
|
+
self.session = None
|
176
|
+
except Exception as e:
|
177
|
+
logger.error(f"Error cleaning up server: {e}")
|
101
178
|
|
102
|
-
class SSEConnection(TypedDict):
|
103
|
-
transport: Literal["sse"]
|
104
179
|
|
105
|
-
|
106
|
-
"""
|
180
|
+
class MCPServerStdioParams(TypedDict):
|
181
|
+
"""Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
|
182
|
+
import.
|
183
|
+
"""
|
107
184
|
|
108
|
-
|
109
|
-
"""
|
185
|
+
command: str
|
186
|
+
"""The executable to run to start the server. For example, `python` or `node`."""
|
110
187
|
|
111
|
-
|
112
|
-
"""
|
188
|
+
args: NotRequired[list[str]]
|
189
|
+
"""Command line args to pass to the `command` executable. For example, `['foo.py']` or
|
190
|
+
`['server.js', '--port', '8080']`."""
|
113
191
|
|
114
|
-
|
115
|
-
"""
|
192
|
+
env: NotRequired[dict[str, str]]
|
193
|
+
"""The environment variables to set for the server. ."""
|
116
194
|
|
195
|
+
cwd: NotRequired[str | Path]
|
196
|
+
"""The working directory to use when spawning the process."""
|
117
197
|
|
118
|
-
|
198
|
+
encoding: NotRequired[str]
|
199
|
+
"""The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
|
119
200
|
|
201
|
+
encoding_error_handler: NotRequired[
|
202
|
+
Literal["strict", "ignore", "replace"]
|
203
|
+
]
|
204
|
+
"""The text encoding error handler. Defaults to `strict`.
|
120
205
|
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
text_contents: list[TextContent] = []
|
125
|
-
non_text_contents = []
|
126
|
-
for content in call_tool_result.content:
|
127
|
-
if isinstance(content, TextContent):
|
128
|
-
text_contents.append(content)
|
129
|
-
else:
|
130
|
-
non_text_contents.append(content)
|
206
|
+
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
207
|
+
explanations of possible values.
|
208
|
+
"""
|
131
209
|
|
132
|
-
tool_content: str | list[str] = [
|
133
|
-
content.text for content in text_contents
|
134
|
-
]
|
135
|
-
if len(text_contents) == 1:
|
136
|
-
tool_content = tool_content[0]
|
137
210
|
|
138
|
-
|
139
|
-
|
211
|
+
class MCPServerStdio(_MCPServerWithClientSession):
|
212
|
+
"""MCP server implementation that uses the stdio transport. See the [spec]
|
213
|
+
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
|
214
|
+
details.
|
215
|
+
"""
|
140
216
|
|
141
|
-
|
217
|
+
def __init__(
|
218
|
+
self,
|
219
|
+
params: MCPServerStdioParams,
|
220
|
+
cache_tools_list: bool = False,
|
221
|
+
name: str | None = None,
|
222
|
+
):
|
223
|
+
"""Create a new MCP server based on the stdio transport.
|
142
224
|
|
225
|
+
Args:
|
226
|
+
params: The params that configure the server. This includes the command to run to
|
227
|
+
start the server, the args to pass to the command, the environment variables to
|
228
|
+
set for the server, the working directory to use when spawning the process, and
|
229
|
+
the text encoding used when sending/receiving messages to the server.
|
230
|
+
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
231
|
+
cached and only fetched from the server once. If `False`, the tools list will be
|
232
|
+
fetched from the server on each call to `list_tools()`. The cache can be
|
233
|
+
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
234
|
+
if you know the server will not change its tools list, because it can drastically
|
235
|
+
improve latency (by avoiding a round-trip to the server every time).
|
236
|
+
name: A readable name for the server. If not provided, we'll create one from the
|
237
|
+
command.
|
238
|
+
"""
|
239
|
+
super().__init__(cache_tools_list)
|
240
|
+
|
241
|
+
self.params = StdioServerParameters(
|
242
|
+
command=params["command"],
|
243
|
+
args=params.get("args", []),
|
244
|
+
env=params.get("env"),
|
245
|
+
cwd=params.get("cwd"),
|
246
|
+
encoding=params.get("encoding", "utf-8"),
|
247
|
+
encoding_error_handler=params.get(
|
248
|
+
"encoding_error_handler", "strict"
|
249
|
+
),
|
250
|
+
)
|
143
251
|
|
144
|
-
|
145
|
-
session: ClientSession,
|
146
|
-
tool: MCPTool,
|
147
|
-
) -> Callable[
|
148
|
-
...,
|
149
|
-
Coroutine[
|
150
|
-
Any, Any, tuple[str | list[str], list[NonTextContent] | None]
|
151
|
-
],
|
152
|
-
]:
|
153
|
-
"""Convert an MCP tool to a callable function.
|
252
|
+
self._name = name or f"stdio: {self.params.command}"
|
154
253
|
|
155
|
-
|
254
|
+
def create_streams(
|
255
|
+
self,
|
256
|
+
) -> AbstractAsyncContextManager[
|
257
|
+
tuple[
|
258
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
259
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
260
|
+
]
|
261
|
+
]:
|
262
|
+
"""Create the streams for the server."""
|
263
|
+
return stdio_client(self.params)
|
156
264
|
|
157
|
-
|
158
|
-
|
159
|
-
|
265
|
+
@property
|
266
|
+
def name(self) -> str:
|
267
|
+
"""A readable name for the server."""
|
268
|
+
return self._name
|
160
269
|
|
161
|
-
Returns:
|
162
|
-
a callable function
|
163
|
-
"""
|
164
270
|
|
165
|
-
|
166
|
-
|
167
|
-
) -> tuple[str | list[str], list[NonTextContent] | None]:
|
168
|
-
"""Execute the tool with the given arguments."""
|
169
|
-
call_tool_result = await session.call_tool(
|
170
|
-
tool.name, arguments
|
171
|
-
)
|
172
|
-
return _convert_call_tool_result(call_tool_result)
|
271
|
+
class MCPServerSseParams(TypedDict):
|
272
|
+
"""Mirrors the params in`mcp.client.sse.sse_client`."""
|
173
273
|
|
174
|
-
|
175
|
-
|
176
|
-
call_tool.__doc__ = tool.description or ""
|
177
|
-
call_tool.schema = tool.inputSchema
|
274
|
+
url: str
|
275
|
+
"""The URL of the server."""
|
178
276
|
|
179
|
-
|
277
|
+
headers: NotRequired[dict[str, str]]
|
278
|
+
"""The headers to send to the server."""
|
180
279
|
|
280
|
+
timeout: NotRequired[float]
|
281
|
+
"""The timeout for the HTTP request. Defaults to 5 seconds."""
|
181
282
|
|
182
|
-
|
183
|
-
"""
|
184
|
-
tools = await session.list_tools()
|
185
|
-
return [
|
186
|
-
convert_mcp_tool_to_function(session, tool)
|
187
|
-
for tool in tools.tools
|
188
|
-
]
|
283
|
+
sse_read_timeout: NotRequired[float]
|
284
|
+
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
|
189
285
|
|
190
286
|
|
191
|
-
class
|
192
|
-
"""
|
287
|
+
class MCPServerSse(_MCPServerWithClientSession):
|
288
|
+
"""MCP server implementation that uses the HTTP with SSE transport. See the [spec]
|
289
|
+
(https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
|
290
|
+
for details.
|
291
|
+
"""
|
193
292
|
|
194
293
|
def __init__(
|
195
294
|
self,
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
)
|
200
|
-
"""
|
295
|
+
params: MCPServerSseParams,
|
296
|
+
cache_tools_list: bool = False,
|
297
|
+
name: str | None = None,
|
298
|
+
):
|
299
|
+
"""Create a new MCP server based on the HTTP with SSE transport.
|
201
300
|
|
202
301
|
Args:
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
"transport": "stdio",
|
217
|
-
},
|
218
|
-
"weather": {
|
219
|
-
# make sure you start your weather server on port 8000
|
220
|
-
"url": "http://localhost:8000/sse",
|
221
|
-
"transport": "sse",
|
222
|
-
}
|
223
|
-
}
|
224
|
-
) as client:
|
225
|
-
all_tools = client.get_tools()
|
226
|
-
...
|
227
|
-
```
|
302
|
+
params: The params that configure the server. This includes the URL of the server,
|
303
|
+
the headers to send to the server, the timeout for the HTTP request, and the
|
304
|
+
timeout for the SSE connection.
|
305
|
+
|
306
|
+
cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
|
307
|
+
cached and only fetched from the server once. If `False`, the tools list will be
|
308
|
+
fetched from the server on each call to `list_tools()`. The cache can be
|
309
|
+
invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
|
310
|
+
if you know the server will not change its tools list, because it can drastically
|
311
|
+
improve latency (by avoiding a round-trip to the server every time).
|
312
|
+
|
313
|
+
name: A readable name for the server. If not provided, we'll create one from the
|
314
|
+
URL.
|
228
315
|
"""
|
229
|
-
|
230
|
-
self.exit_stack = AsyncExitStack()
|
231
|
-
self.sessions: dict[str, ClientSession] = {}
|
232
|
-
self.server_name_to_tools: dict[str, list[Callable]] = {}
|
316
|
+
super().__init__(cache_tools_list)
|
233
317
|
|
234
|
-
|
235
|
-
self
|
236
|
-
) -> None:
|
237
|
-
"""Initialize a session and load tools from it.
|
318
|
+
self.params = params
|
319
|
+
self._name = name or f"sse: {self.params['url']}"
|
238
320
|
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
321
|
+
def create_streams(
|
322
|
+
self,
|
323
|
+
) -> AbstractAsyncContextManager[
|
324
|
+
tuple[
|
325
|
+
MemoryObjectReceiveStream[JSONRPCMessage | Exception],
|
326
|
+
MemoryObjectSendStream[JSONRPCMessage],
|
327
|
+
]
|
328
|
+
]:
|
329
|
+
"""Create the streams for the server."""
|
330
|
+
return sse_client(
|
331
|
+
url=self.params["url"],
|
332
|
+
headers=self.params.get("headers", None),
|
333
|
+
timeout=self.params.get("timeout", 5),
|
334
|
+
sse_read_timeout=self.params.get(
|
335
|
+
"sse_read_timeout", 60 * 5
|
336
|
+
),
|
337
|
+
)
|
246
338
|
|
247
|
-
|
248
|
-
|
249
|
-
|
339
|
+
@property
|
340
|
+
def name(self) -> str:
|
341
|
+
"""A readable name for the server."""
|
342
|
+
return self._name
|
250
343
|
|
251
|
-
async def connect_to_server(
|
252
|
-
self,
|
253
|
-
server_name: str,
|
254
|
-
*,
|
255
|
-
transport: Literal["stdio", "sse"] = "stdio",
|
256
|
-
**kwargs,
|
257
|
-
) -> None:
|
258
|
-
"""Connect to an MCP server using either stdio or SSE.
|
259
344
|
|
260
|
-
|
261
|
-
|
345
|
+
def mcp_flow_get_tool_schema(
|
346
|
+
params: MCPServerSseParams,
|
347
|
+
) -> MCPServer:
|
348
|
+
server = MCPServerSse(params, cache_tools_list=True)
|
262
349
|
|
263
|
-
|
264
|
-
|
265
|
-
transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
|
266
|
-
**kwargs: Additional arguments to pass to the specific connection method
|
350
|
+
# Connect the server
|
351
|
+
asyncio.run(server.connect())
|
267
352
|
|
268
|
-
|
269
|
-
|
270
|
-
ValueError: If required parameters for the specified transport are missing
|
271
|
-
"""
|
272
|
-
if transport == "sse":
|
273
|
-
if "url" not in kwargs:
|
274
|
-
raise ValueError(
|
275
|
-
"'url' parameter is required for SSE connection"
|
276
|
-
)
|
277
|
-
await self.connect_to_server_via_sse(
|
278
|
-
server_name,
|
279
|
-
url=kwargs["url"],
|
280
|
-
headers=kwargs.get("headers"),
|
281
|
-
timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
|
282
|
-
sse_read_timeout=kwargs.get(
|
283
|
-
"sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
|
284
|
-
),
|
285
|
-
)
|
286
|
-
elif transport == "stdio":
|
287
|
-
if "command" not in kwargs:
|
288
|
-
raise ValueError(
|
289
|
-
"'command' parameter is required for stdio connection"
|
290
|
-
)
|
291
|
-
if "args" not in kwargs:
|
292
|
-
raise ValueError(
|
293
|
-
"'args' parameter is required for stdio connection"
|
294
|
-
)
|
295
|
-
await self.connect_to_server_via_stdio(
|
296
|
-
server_name,
|
297
|
-
command=kwargs["command"],
|
298
|
-
args=kwargs["args"],
|
299
|
-
env=kwargs.get("env"),
|
300
|
-
encoding=kwargs.get("encoding", DEFAULT_ENCODING),
|
301
|
-
encoding_error_handler=kwargs.get(
|
302
|
-
"encoding_error_handler",
|
303
|
-
DEFAULT_ENCODING_ERROR_HANDLER,
|
304
|
-
),
|
305
|
-
)
|
306
|
-
else:
|
307
|
-
raise ValueError(
|
308
|
-
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
|
309
|
-
)
|
353
|
+
# Return the server
|
354
|
+
output = asyncio.run(server.list_tools())
|
310
355
|
|
311
|
-
|
312
|
-
|
313
|
-
server_name: str,
|
314
|
-
*,
|
315
|
-
command: str,
|
316
|
-
args: list[str],
|
317
|
-
env: dict[str, str] | None = None,
|
318
|
-
encoding: str = DEFAULT_ENCODING,
|
319
|
-
encoding_error_handler: Literal[
|
320
|
-
"strict", "ignore", "replace"
|
321
|
-
] = DEFAULT_ENCODING_ERROR_HANDLER,
|
322
|
-
) -> None:
|
323
|
-
"""Connect to a specific MCP server using stdio
|
356
|
+
# Cleanup the server
|
357
|
+
asyncio.run(server.cleanup())
|
324
358
|
|
325
|
-
|
326
|
-
server_name: Name to identify this server connection
|
327
|
-
command: Command to execute
|
328
|
-
args: Arguments for the command
|
329
|
-
env: Environment variables for the command
|
330
|
-
encoding: Character encoding
|
331
|
-
encoding_error_handler: How to handle encoding errors
|
332
|
-
"""
|
333
|
-
server_params = StdioServerParameters(
|
334
|
-
command=command,
|
335
|
-
args=args,
|
336
|
-
env=env,
|
337
|
-
encoding=encoding,
|
338
|
-
encoding_error_handler=encoding_error_handler,
|
339
|
-
)
|
359
|
+
return output.model_dump()
|
340
360
|
|
341
|
-
# Create and store the connection
|
342
|
-
stdio_transport = await self.exit_stack.enter_async_context(
|
343
|
-
stdio_client(server_params)
|
344
|
-
)
|
345
|
-
read, write = stdio_transport
|
346
|
-
session = cast(
|
347
|
-
ClientSession,
|
348
|
-
await self.exit_stack.enter_async_context(
|
349
|
-
ClientSession(read, write)
|
350
|
-
),
|
351
|
-
)
|
352
361
|
|
353
|
-
|
354
|
-
|
355
|
-
|
362
|
+
def mcp_flow(
|
363
|
+
params: MCPServerSseParams,
|
364
|
+
function_call: dict[str, Any],
|
365
|
+
) -> MCPServer:
|
366
|
+
server = MCPServerSse(params, cache_tools_list=True)
|
356
367
|
|
357
|
-
|
358
|
-
|
359
|
-
server_name: str,
|
360
|
-
*,
|
361
|
-
url: str,
|
362
|
-
headers: dict[str, Any] | None = None,
|
363
|
-
timeout: float = DEFAULT_HTTP_TIMEOUT,
|
364
|
-
sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
|
365
|
-
) -> None:
|
366
|
-
"""Connect to a specific MCP server using SSE
|
368
|
+
# Connect the server
|
369
|
+
asyncio.run(server.connect())
|
367
370
|
|
368
|
-
|
369
|
-
|
370
|
-
url: URL of the SSE server
|
371
|
-
headers: HTTP headers to send to the SSE endpoint
|
372
|
-
timeout: HTTP timeout
|
373
|
-
sse_read_timeout: SSE read timeout
|
374
|
-
"""
|
375
|
-
# Create and store the connection
|
376
|
-
sse_transport = await self.exit_stack.enter_async_context(
|
377
|
-
sse_client(url, headers, timeout, sse_read_timeout)
|
378
|
-
)
|
379
|
-
read, write = sse_transport
|
380
|
-
session = cast(
|
381
|
-
ClientSession,
|
382
|
-
await self.exit_stack.enter_async_context(
|
383
|
-
ClientSession(read, write)
|
384
|
-
),
|
385
|
-
)
|
371
|
+
# Return the server
|
372
|
+
output = asyncio.run(server.call_tool(function_call))
|
386
373
|
|
387
|
-
|
388
|
-
server_name, session
|
389
|
-
)
|
374
|
+
output = output.model_dump()
|
390
375
|
|
391
|
-
|
392
|
-
|
393
|
-
all_tools: list[Callable] = []
|
394
|
-
for server_tools in self.server_name_to_tools.values():
|
395
|
-
all_tools.extend(server_tools)
|
396
|
-
return all_tools
|
376
|
+
# Cleanup the server
|
377
|
+
asyncio.run(server.cleanup())
|
397
378
|
|
398
|
-
|
399
|
-
self,
|
400
|
-
server_name: str,
|
401
|
-
prompt_name: str,
|
402
|
-
arguments: Optional[dict[str, Any]] = None,
|
403
|
-
) -> List[str]:
|
404
|
-
"""Get a prompt from a given MCP server."""
|
405
|
-
session = self.sessions[server_name]
|
406
|
-
return await load_mcp_prompt(session, prompt_name, arguments)
|
407
|
-
|
408
|
-
async def __aenter__(self) -> "MultiServerMCPClient":
|
409
|
-
try:
|
410
|
-
connections = self.connections or {}
|
411
|
-
for server_name, connection in connections.items():
|
412
|
-
connection_dict = connection.copy()
|
413
|
-
transport = connection_dict.pop("transport")
|
414
|
-
if transport == "stdio":
|
415
|
-
await self.connect_to_server_via_stdio(
|
416
|
-
server_name, **connection_dict
|
417
|
-
)
|
418
|
-
elif transport == "sse":
|
419
|
-
await self.connect_to_server_via_sse(
|
420
|
-
server_name, **connection_dict
|
421
|
-
)
|
422
|
-
else:
|
423
|
-
raise ValueError(
|
424
|
-
f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
|
425
|
-
)
|
426
|
-
return self
|
427
|
-
except Exception:
|
428
|
-
await self.exit_stack.aclose()
|
429
|
-
raise
|
379
|
+
return any_to_str(output)
|
430
380
|
|
431
|
-
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
# import json
|
444
|
-
# from typing import List, Any, Callable
|
445
|
-
|
446
|
-
# # # Import our MCP client module
|
447
|
-
# # from mcp_client import MultiServerMCPClient
|
448
|
-
|
449
|
-
|
450
|
-
# async def main():
|
451
|
-
# """Test script for demonstrating MCP client usage."""
|
452
|
-
# print("Starting MCP Client test...")
|
453
|
-
|
454
|
-
# # Create a connection to multiple MCP servers
|
455
|
-
# # You'll need to update these paths to match your setup
|
456
|
-
# async with MultiServerMCPClient(
|
457
|
-
# {
|
458
|
-
# "math": {
|
459
|
-
# "transport": "stdio",
|
460
|
-
# "command": "python",
|
461
|
-
# "args": ["/path/to/math_server.py"],
|
462
|
-
# "env": {"DEBUG": "1"},
|
463
|
-
# },
|
464
|
-
# "search": {
|
465
|
-
# "transport": "sse",
|
466
|
-
# "url": "http://localhost:8000/sse",
|
467
|
-
# "headers": {
|
468
|
-
# "Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
|
469
|
-
# },
|
470
|
-
# },
|
471
|
-
# }
|
472
|
-
# ) as client:
|
473
|
-
# # Get all available tools
|
474
|
-
# tools = client.get_tools()
|
475
|
-
# print(f"Found {len(tools)} tools across all servers")
|
476
|
-
|
477
|
-
# # Print tool information
|
478
|
-
# for i, tool in enumerate(tools):
|
479
|
-
# print(f"\nTool {i+1}: {tool.__name__}")
|
480
|
-
# print(f" Description: {tool.__doc__}")
|
481
|
-
# if hasattr(tool, "schema") and tool.schema:
|
482
|
-
# print(
|
483
|
-
# f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
|
484
|
-
# )
|
485
|
-
|
486
|
-
# # Example: Use a specific tool if available
|
487
|
-
# calculator_tool = next(
|
488
|
-
# (t for t in tools if t.__name__ == "calculator"), None
|
489
|
-
# )
|
490
|
-
# if calculator_tool:
|
491
|
-
# print("\n\nTesting calculator tool:")
|
492
|
-
# try:
|
493
|
-
# # Call the tool as an async function
|
494
|
-
# result, artifacts = await calculator_tool(
|
495
|
-
# expression="2 + 2 * 3"
|
496
|
-
# )
|
497
|
-
# print(f" Calculator result: {result}")
|
498
|
-
# if artifacts:
|
499
|
-
# print(
|
500
|
-
# f" With {len(artifacts)} additional artifacts"
|
501
|
-
# )
|
502
|
-
# except Exception as e:
|
503
|
-
# print(f" Error using calculator: {e}")
|
504
|
-
|
505
|
-
# # Example: Load a prompt from a server
|
506
|
-
# try:
|
507
|
-
# print("\n\nTesting prompt loading:")
|
508
|
-
# prompt_messages = await client.get_prompt(
|
509
|
-
# "math",
|
510
|
-
# "calculation_introduction",
|
511
|
-
# {"user_name": "Test User"},
|
512
|
-
# )
|
513
|
-
# print(
|
514
|
-
# f" Loaded prompt with {len(prompt_messages)} messages:"
|
515
|
-
# )
|
516
|
-
# for i, msg in enumerate(prompt_messages):
|
517
|
-
# print(f" Message {i+1}: {msg[:50]}...")
|
518
|
-
# except Exception as e:
|
519
|
-
# print(f" Error loading prompt: {e}")
|
520
|
-
|
521
|
-
|
522
|
-
# async def create_custom_tool():
|
523
|
-
# """Example of creating a custom tool function."""
|
524
|
-
|
525
|
-
# # Define a tool function with metadata
|
526
|
-
# async def add_numbers(a: float, b: float) -> tuple[str, None]:
|
527
|
-
# """Add two numbers together."""
|
528
|
-
# result = a + b
|
529
|
-
# return f"The sum of {a} and {b} is {result}", None
|
530
|
-
|
531
|
-
# # Add metadata to the function
|
532
|
-
# add_numbers.__name__ = "add_numbers"
|
533
|
-
# add_numbers.__doc__ = (
|
534
|
-
# "Add two numbers together and return the result."
|
535
|
-
# )
|
536
|
-
# add_numbers.schema = {
|
537
|
-
# "type": "object",
|
538
|
-
# "properties": {
|
539
|
-
# "a": {"type": "number", "description": "First number"},
|
540
|
-
# "b": {"type": "number", "description": "Second number"},
|
541
|
-
# },
|
542
|
-
# "required": ["a", "b"],
|
543
|
-
# }
|
544
|
-
|
545
|
-
# # Use the tool
|
546
|
-
# result, _ = await add_numbers(a=5, b=7)
|
547
|
-
# print(f"\nCustom tool result: {result}")
|
548
|
-
|
549
|
-
|
550
|
-
# if __name__ == "__main__":
|
551
|
-
# # Run both examples
|
552
|
-
# loop = asyncio.get_event_loop()
|
553
|
-
# loop.run_until_complete(main())
|
554
|
-
# loop.run_until_complete(create_custom_tool())
|
381
|
+
|
382
|
+
def batch_mcp_flow(
|
383
|
+
params: List[MCPServerSseParams],
|
384
|
+
function_call: List[dict[str, Any]] = [],
|
385
|
+
) -> MCPServer:
|
386
|
+
output_list = []
|
387
|
+
|
388
|
+
for param in params:
|
389
|
+
output = mcp_flow(param, function_call)
|
390
|
+
output_list.append(output)
|
391
|
+
|
392
|
+
return output_list
|