swarms 7.6.4__py3-none-any.whl → 7.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,554 +1,392 @@
1
- from contextlib import AsyncExitStack
2
- from types import TracebackType
3
- from typing import (
4
- Any,
5
- Callable,
6
- Coroutine,
7
- List,
8
- Literal,
9
- Optional,
10
- TypedDict,
11
- cast,
12
- )
1
+ from __future__ import annotations
13
2
 
14
- from mcp import ClientSession, StdioServerParameters
15
- from mcp.client.sse import sse_client
16
- from mcp.client.stdio import stdio_client
17
- from mcp.types import (
18
- CallToolResult,
19
- EmbeddedResource,
20
- ImageContent,
21
- PromptMessage,
22
- TextContent,
3
+ from typing import Any, List
4
+
5
+
6
+ from loguru import logger
7
+
8
+ import abc
9
+ import asyncio
10
+ from contextlib import AbstractAsyncContextManager, AsyncExitStack
11
+ from pathlib import Path
12
+ from typing import Literal
13
+
14
+ from anyio.streams.memory import (
15
+ MemoryObjectReceiveStream,
16
+ MemoryObjectSendStream,
23
17
  )
24
- from mcp.types import (
18
+ from mcp import (
19
+ ClientSession,
20
+ StdioServerParameters,
25
21
  Tool as MCPTool,
22
+ stdio_client,
26
23
  )
24
+ from mcp.client.sse import sse_client
25
+ from mcp.types import CallToolResult, JSONRPCMessage
26
+ from typing_extensions import NotRequired, TypedDict
27
27
 
28
+ from swarms.utils.any_to_str import any_to_str
28
29
 
29
- def convert_mcp_prompt_message_to_message(
30
- message: PromptMessage,
31
- ) -> str:
32
- """Convert an MCP prompt message to a string message.
33
30
 
34
- Args:
35
- message: MCP prompt message to convert
31
+ class MCPServer(abc.ABC):
32
+ """Base class for Model Context Protocol servers."""
36
33
 
37
- Returns:
38
- a string message
39
- """
40
- if message.content.type == "text":
41
- if message.role == "user":
42
- return str(message.content.text)
43
- elif message.role == "assistant":
44
- return str(
45
- message.content.text
46
- ) # Fixed attribute name from str to text
47
- else:
48
- raise ValueError(
49
- f"Unsupported prompt message role: {message.role}"
50
- )
34
+ @abc.abstractmethod
35
+ async def connect(self):
36
+ """Connect to the server. For example, this might mean spawning a subprocess or
37
+ opening a network connection. The server is expected to remain connected until
38
+ `cleanup()` is called.
39
+ """
40
+ pass
41
+
42
+ @property
43
+ @abc.abstractmethod
44
+ def name(self) -> str:
45
+ """A readable name for the server."""
46
+ pass
47
+
48
+ @abc.abstractmethod
49
+ async def cleanup(self):
50
+ """Cleanup the server. For example, this might mean closing a subprocess or
51
+ closing a network connection.
52
+ """
53
+ pass
51
54
 
52
- raise ValueError(
53
- f"Unsupported prompt message content type: {message.content.type}"
54
- )
55
+ @abc.abstractmethod
56
+ async def list_tools(self) -> list[MCPTool]:
57
+ """List the tools available on the server."""
58
+ pass
55
59
 
60
+ @abc.abstractmethod
61
+ async def call_tool(
62
+ self, tool_name: str, arguments: dict[str, Any] | None
63
+ ) -> CallToolResult:
64
+ """Invoke a tool on the server."""
65
+ pass
56
66
 
57
- async def load_mcp_prompt(
58
- session: ClientSession,
59
- name: str,
60
- arguments: Optional[dict[str, Any]] = None,
61
- ) -> List[str]:
62
- """Load MCP prompt and convert to messages."""
63
- response = await session.get_prompt(name, arguments)
64
67
 
65
- return [
66
- convert_mcp_prompt_message_to_message(message)
67
- for message in response.messages
68
- ]
68
+ class _MCPServerWithClientSession(MCPServer, abc.ABC):
69
+ """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
69
70
 
71
+ def __init__(self, cache_tools_list: bool):
72
+ """
73
+ Args:
74
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
75
+ cached and only fetched from the server once. If `False`, the tools list will be
76
+ fetched from the server on each call to `list_tools()`. The cache can be invalidated
77
+ by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
78
+ server will not change its tools list, because it can drastically improve latency
79
+ (by avoiding a round-trip to the server every time).
80
+ """
81
+ self.session: ClientSession | None = None
82
+ self.exit_stack: AsyncExitStack = AsyncExitStack()
83
+ self._cleanup_lock: asyncio.Lock = asyncio.Lock()
84
+ self.cache_tools_list = cache_tools_list
70
85
 
71
- DEFAULT_ENCODING = "utf-8"
72
- DEFAULT_ENCODING_ERROR_HANDLER = "strict"
86
+ # The cache is always dirty at startup, so that we fetch tools at least once
87
+ self._cache_dirty = True
88
+ self._tools_list: list[MCPTool] | None = None
73
89
 
74
- DEFAULT_HTTP_TIMEOUT = 5
75
- DEFAULT_SSE_READ_TIMEOUT = 60 * 5
90
+ @abc.abstractmethod
91
+ def create_streams(
92
+ self,
93
+ ) -> AbstractAsyncContextManager[
94
+ tuple[
95
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
96
+ MemoryObjectSendStream[JSONRPCMessage],
97
+ ]
98
+ ]:
99
+ """Create the streams for the server."""
100
+ pass
101
+
102
+ async def __aenter__(self):
103
+ await self.connect()
104
+ return self
105
+
106
+ async def __aexit__(self, exc_type, exc_value, traceback):
107
+ await self.cleanup()
108
+
109
+ def invalidate_tools_cache(self):
110
+ """Invalidate the tools cache."""
111
+ self._cache_dirty = True
112
+
113
+ async def connect(self):
114
+ """Connect to the server."""
115
+ try:
116
+ transport = await self.exit_stack.enter_async_context(
117
+ self.create_streams()
118
+ )
119
+ read, write = transport
120
+ session = await self.exit_stack.enter_async_context(
121
+ ClientSession(read, write)
122
+ )
123
+ await session.initialize()
124
+ self.session = session
125
+ except Exception as e:
126
+ logger.error(f"Error initializing MCP server: {e}")
127
+ await self.cleanup()
128
+ raise
76
129
 
130
+ async def list_tools(self) -> list[MCPTool]:
131
+ """List the tools available on the server."""
132
+ if not self.session:
133
+ raise Exception(
134
+ "Server not initialized. Make sure you call `connect()` first."
135
+ )
77
136
 
78
- class StdioConnection(TypedDict):
79
- transport: Literal["stdio"]
137
+ # Return from cache if caching is enabled, we have tools, and the cache is not dirty
138
+ if (
139
+ self.cache_tools_list
140
+ and not self._cache_dirty
141
+ and self._tools_list
142
+ ):
143
+ return self._tools_list
80
144
 
81
- command: str
82
- """The executable to run to start the server."""
145
+ # Reset the cache dirty to False
146
+ self._cache_dirty = False
83
147
 
84
- args: list[str]
85
- """Command line arguments to pass to the executable."""
148
+ # Fetch the tools from the server
149
+ self._tools_list = (await self.session.list_tools()).tools
150
+ return self._tools_list
86
151
 
87
- env: dict[str, str] | None
88
- """The environment to use when spawning the process."""
152
+ async def call_tool(
153
+ self, arguments: dict[str, Any] | None
154
+ ) -> CallToolResult:
155
+ """Invoke a tool on the server."""
156
+ tool_name = arguments.get("tool_name") or arguments.get(
157
+ "name"
158
+ )
89
159
 
90
- encoding: str
91
- """The text encoding used when sending/receiving messages to the server."""
160
+ if not tool_name:
161
+ raise Exception("No tool name found in arguments")
92
162
 
93
- encoding_error_handler: Literal["strict", "ignore", "replace"]
94
- """
95
- The text encoding error handler.
163
+ if not self.session:
164
+ raise Exception(
165
+ "Server not initialized. Make sure you call `connect()` first."
166
+ )
96
167
 
97
- See https://docs.python.org/3/library/codecs.html#codec-base-classes for
98
- explanations of possible values
99
- """
168
+ return await self.session.call_tool(tool_name, arguments)
100
169
 
170
+ async def cleanup(self):
171
+ """Cleanup the server."""
172
+ async with self._cleanup_lock:
173
+ try:
174
+ await self.exit_stack.aclose()
175
+ self.session = None
176
+ except Exception as e:
177
+ logger.error(f"Error cleaning up server: {e}")
101
178
 
102
- class SSEConnection(TypedDict):
103
- transport: Literal["sse"]
104
179
 
105
- url: str
106
- """The URL of the SSE endpoint to connect to."""
180
+ class MCPServerStdioParams(TypedDict):
181
+ """Mirrors `mcp.client.stdio.StdioServerParameters`, but lets you pass params without another
182
+ import.
183
+ """
107
184
 
108
- headers: dict[str, Any] | None
109
- """HTTP headers to send to the SSE endpoint"""
185
+ command: str
186
+ """The executable to run to start the server. For example, `python` or `node`."""
110
187
 
111
- timeout: float
112
- """HTTP timeout"""
188
+ args: NotRequired[list[str]]
189
+ """Command line args to pass to the `command` executable. For example, `['foo.py']` or
190
+ `['server.js', '--port', '8080']`."""
113
191
 
114
- sse_read_timeout: float
115
- """SSE read timeout"""
192
+ env: NotRequired[dict[str, str]]
193
+ """The environment variables to set for the server. ."""
116
194
 
195
+ cwd: NotRequired[str | Path]
196
+ """The working directory to use when spawning the process."""
117
197
 
118
- NonTextContent = ImageContent | EmbeddedResource
198
+ encoding: NotRequired[str]
199
+ """The text encoding used when sending/receiving messages to the server. Defaults to `utf-8`."""
119
200
 
201
+ encoding_error_handler: NotRequired[
202
+ Literal["strict", "ignore", "replace"]
203
+ ]
204
+ """The text encoding error handler. Defaults to `strict`.
120
205
 
121
- def _convert_call_tool_result(
122
- call_tool_result: CallToolResult,
123
- ) -> tuple[str | list[str], list[NonTextContent] | None]:
124
- text_contents: list[TextContent] = []
125
- non_text_contents = []
126
- for content in call_tool_result.content:
127
- if isinstance(content, TextContent):
128
- text_contents.append(content)
129
- else:
130
- non_text_contents.append(content)
206
+ See https://docs.python.org/3/library/codecs.html#codec-base-classes for
207
+ explanations of possible values.
208
+ """
131
209
 
132
- tool_content: str | list[str] = [
133
- content.text for content in text_contents
134
- ]
135
- if len(text_contents) == 1:
136
- tool_content = tool_content[0]
137
210
 
138
- if call_tool_result.isError:
139
- raise ValueError("Error calling tool")
211
+ class MCPServerStdio(_MCPServerWithClientSession):
212
+ """MCP server implementation that uses the stdio transport. See the [spec]
213
+ (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) for
214
+ details.
215
+ """
140
216
 
141
- return tool_content, non_text_contents or None
217
+ def __init__(
218
+ self,
219
+ params: MCPServerStdioParams,
220
+ cache_tools_list: bool = False,
221
+ name: str | None = None,
222
+ ):
223
+ """Create a new MCP server based on the stdio transport.
142
224
 
225
+ Args:
226
+ params: The params that configure the server. This includes the command to run to
227
+ start the server, the args to pass to the command, the environment variables to
228
+ set for the server, the working directory to use when spawning the process, and
229
+ the text encoding used when sending/receiving messages to the server.
230
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
231
+ cached and only fetched from the server once. If `False`, the tools list will be
232
+ fetched from the server on each call to `list_tools()`. The cache can be
233
+ invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
234
+ if you know the server will not change its tools list, because it can drastically
235
+ improve latency (by avoiding a round-trip to the server every time).
236
+ name: A readable name for the server. If not provided, we'll create one from the
237
+ command.
238
+ """
239
+ super().__init__(cache_tools_list)
240
+
241
+ self.params = StdioServerParameters(
242
+ command=params["command"],
243
+ args=params.get("args", []),
244
+ env=params.get("env"),
245
+ cwd=params.get("cwd"),
246
+ encoding=params.get("encoding", "utf-8"),
247
+ encoding_error_handler=params.get(
248
+ "encoding_error_handler", "strict"
249
+ ),
250
+ )
143
251
 
144
- def convert_mcp_tool_to_function(
145
- session: ClientSession,
146
- tool: MCPTool,
147
- ) -> Callable[
148
- ...,
149
- Coroutine[
150
- Any, Any, tuple[str | list[str], list[NonTextContent] | None]
151
- ],
152
- ]:
153
- """Convert an MCP tool to a callable function.
252
+ self._name = name or f"stdio: {self.params.command}"
154
253
 
155
- NOTE: this tool can be executed only in a context of an active MCP client session.
254
+ def create_streams(
255
+ self,
256
+ ) -> AbstractAsyncContextManager[
257
+ tuple[
258
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
259
+ MemoryObjectSendStream[JSONRPCMessage],
260
+ ]
261
+ ]:
262
+ """Create the streams for the server."""
263
+ return stdio_client(self.params)
156
264
 
157
- Args:
158
- session: MCP client session
159
- tool: MCP tool to convert
265
+ @property
266
+ def name(self) -> str:
267
+ """A readable name for the server."""
268
+ return self._name
160
269
 
161
- Returns:
162
- a callable function
163
- """
164
270
 
165
- async def call_tool(
166
- **arguments: dict[str, Any],
167
- ) -> tuple[str | list[str], list[NonTextContent] | None]:
168
- """Execute the tool with the given arguments."""
169
- call_tool_result = await session.call_tool(
170
- tool.name, arguments
171
- )
172
- return _convert_call_tool_result(call_tool_result)
271
+ class MCPServerSseParams(TypedDict):
272
+ """Mirrors the params in`mcp.client.sse.sse_client`."""
173
273
 
174
- # Add metadata as attributes to the function
175
- call_tool.__name__ = tool.name
176
- call_tool.__doc__ = tool.description or ""
177
- call_tool.schema = tool.inputSchema
274
+ url: str
275
+ """The URL of the server."""
178
276
 
179
- return call_tool
277
+ headers: NotRequired[dict[str, str]]
278
+ """The headers to send to the server."""
180
279
 
280
+ timeout: NotRequired[float]
281
+ """The timeout for the HTTP request. Defaults to 5 seconds."""
181
282
 
182
- async def load_mcp_tools(session: ClientSession) -> list[Callable]:
183
- """Load all available MCP tools and convert them to callable functions."""
184
- tools = await session.list_tools()
185
- return [
186
- convert_mcp_tool_to_function(session, tool)
187
- for tool in tools.tools
188
- ]
283
+ sse_read_timeout: NotRequired[float]
284
+ """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
189
285
 
190
286
 
191
- class MultiServerMCPClient:
192
- """Client for connecting to multiple MCP servers and loading tools from them."""
287
+ class MCPServerSse(_MCPServerWithClientSession):
288
+ """MCP server implementation that uses the HTTP with SSE transport. See the [spec]
289
+ (https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse)
290
+ for details.
291
+ """
193
292
 
194
293
  def __init__(
195
294
  self,
196
- connections: dict[
197
- str, StdioConnection | SSEConnection
198
- ] = None,
199
- ) -> None:
200
- """Initialize a MultiServerMCPClient with MCP servers connections.
295
+ params: MCPServerSseParams,
296
+ cache_tools_list: bool = False,
297
+ name: str | None = None,
298
+ ):
299
+ """Create a new MCP server based on the HTTP with SSE transport.
201
300
 
202
301
  Args:
203
- connections: A dictionary mapping server names to connection configurations.
204
- Each configuration can be either a StdioConnection or SSEConnection.
205
- If None, no initial connections are established.
206
-
207
- Example:
208
-
209
- ```python
210
- async with MultiServerMCPClient(
211
- {
212
- "math": {
213
- "command": "python",
214
- # Make sure to update to the full absolute path to your math_server.py file
215
- "args": ["/path/to/math_server.py"],
216
- "transport": "stdio",
217
- },
218
- "weather": {
219
- # make sure you start your weather server on port 8000
220
- "url": "http://localhost:8000/sse",
221
- "transport": "sse",
222
- }
223
- }
224
- ) as client:
225
- all_tools = client.get_tools()
226
- ...
227
- ```
302
+ params: The params that configure the server. This includes the URL of the server,
303
+ the headers to send to the server, the timeout for the HTTP request, and the
304
+ timeout for the SSE connection.
305
+
306
+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
307
+ cached and only fetched from the server once. If `False`, the tools list will be
308
+ fetched from the server on each call to `list_tools()`. The cache can be
309
+ invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
310
+ if you know the server will not change its tools list, because it can drastically
311
+ improve latency (by avoiding a round-trip to the server every time).
312
+
313
+ name: A readable name for the server. If not provided, we'll create one from the
314
+ URL.
228
315
  """
229
- self.connections = connections
230
- self.exit_stack = AsyncExitStack()
231
- self.sessions: dict[str, ClientSession] = {}
232
- self.server_name_to_tools: dict[str, list[Callable]] = {}
316
+ super().__init__(cache_tools_list)
233
317
 
234
- async def _initialize_session_and_load_tools(
235
- self, server_name: str, session: ClientSession
236
- ) -> None:
237
- """Initialize a session and load tools from it.
318
+ self.params = params
319
+ self._name = name or f"sse: {self.params['url']}"
238
320
 
239
- Args:
240
- server_name: Name to identify this server connection
241
- session: The ClientSession to initialize
242
- """
243
- # Initialize the session
244
- await session.initialize()
245
- self.sessions[server_name] = session
321
+ def create_streams(
322
+ self,
323
+ ) -> AbstractAsyncContextManager[
324
+ tuple[
325
+ MemoryObjectReceiveStream[JSONRPCMessage | Exception],
326
+ MemoryObjectSendStream[JSONRPCMessage],
327
+ ]
328
+ ]:
329
+ """Create the streams for the server."""
330
+ return sse_client(
331
+ url=self.params["url"],
332
+ headers=self.params.get("headers", None),
333
+ timeout=self.params.get("timeout", 5),
334
+ sse_read_timeout=self.params.get(
335
+ "sse_read_timeout", 60 * 5
336
+ ),
337
+ )
246
338
 
247
- # Load tools from this server
248
- server_tools = await load_mcp_tools(session)
249
- self.server_name_to_tools[server_name] = server_tools
339
+ @property
340
+ def name(self) -> str:
341
+ """A readable name for the server."""
342
+ return self._name
250
343
 
251
- async def connect_to_server(
252
- self,
253
- server_name: str,
254
- *,
255
- transport: Literal["stdio", "sse"] = "stdio",
256
- **kwargs,
257
- ) -> None:
258
- """Connect to an MCP server using either stdio or SSE.
259
344
 
260
- This is a generic method that calls either connect_to_server_via_stdio or connect_to_server_via_sse
261
- based on the provided transport parameter.
345
+ def mcp_flow_get_tool_schema(
346
+ params: MCPServerSseParams,
347
+ ) -> MCPServer:
348
+ server = MCPServerSse(params, cache_tools_list=True)
262
349
 
263
- Args:
264
- server_name: Name to identify this server connection
265
- transport: Type of transport to use ("stdio" or "sse"), defaults to "stdio"
266
- **kwargs: Additional arguments to pass to the specific connection method
350
+ # Connect the server
351
+ asyncio.run(server.connect())
267
352
 
268
- Raises:
269
- ValueError: If transport is not recognized
270
- ValueError: If required parameters for the specified transport are missing
271
- """
272
- if transport == "sse":
273
- if "url" not in kwargs:
274
- raise ValueError(
275
- "'url' parameter is required for SSE connection"
276
- )
277
- await self.connect_to_server_via_sse(
278
- server_name,
279
- url=kwargs["url"],
280
- headers=kwargs.get("headers"),
281
- timeout=kwargs.get("timeout", DEFAULT_HTTP_TIMEOUT),
282
- sse_read_timeout=kwargs.get(
283
- "sse_read_timeout", DEFAULT_SSE_READ_TIMEOUT
284
- ),
285
- )
286
- elif transport == "stdio":
287
- if "command" not in kwargs:
288
- raise ValueError(
289
- "'command' parameter is required for stdio connection"
290
- )
291
- if "args" not in kwargs:
292
- raise ValueError(
293
- "'args' parameter is required for stdio connection"
294
- )
295
- await self.connect_to_server_via_stdio(
296
- server_name,
297
- command=kwargs["command"],
298
- args=kwargs["args"],
299
- env=kwargs.get("env"),
300
- encoding=kwargs.get("encoding", DEFAULT_ENCODING),
301
- encoding_error_handler=kwargs.get(
302
- "encoding_error_handler",
303
- DEFAULT_ENCODING_ERROR_HANDLER,
304
- ),
305
- )
306
- else:
307
- raise ValueError(
308
- f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
309
- )
353
+ # Return the server
354
+ output = asyncio.run(server.list_tools())
310
355
 
311
- async def connect_to_server_via_stdio(
312
- self,
313
- server_name: str,
314
- *,
315
- command: str,
316
- args: list[str],
317
- env: dict[str, str] | None = None,
318
- encoding: str = DEFAULT_ENCODING,
319
- encoding_error_handler: Literal[
320
- "strict", "ignore", "replace"
321
- ] = DEFAULT_ENCODING_ERROR_HANDLER,
322
- ) -> None:
323
- """Connect to a specific MCP server using stdio
356
+ # Cleanup the server
357
+ asyncio.run(server.cleanup())
324
358
 
325
- Args:
326
- server_name: Name to identify this server connection
327
- command: Command to execute
328
- args: Arguments for the command
329
- env: Environment variables for the command
330
- encoding: Character encoding
331
- encoding_error_handler: How to handle encoding errors
332
- """
333
- server_params = StdioServerParameters(
334
- command=command,
335
- args=args,
336
- env=env,
337
- encoding=encoding,
338
- encoding_error_handler=encoding_error_handler,
339
- )
359
+ return output.model_dump()
340
360
 
341
- # Create and store the connection
342
- stdio_transport = await self.exit_stack.enter_async_context(
343
- stdio_client(server_params)
344
- )
345
- read, write = stdio_transport
346
- session = cast(
347
- ClientSession,
348
- await self.exit_stack.enter_async_context(
349
- ClientSession(read, write)
350
- ),
351
- )
352
361
 
353
- await self._initialize_session_and_load_tools(
354
- server_name, session
355
- )
362
+ def mcp_flow(
363
+ params: MCPServerSseParams,
364
+ function_call: dict[str, Any],
365
+ ) -> MCPServer:
366
+ server = MCPServerSse(params, cache_tools_list=True)
356
367
 
357
- async def connect_to_server_via_sse(
358
- self,
359
- server_name: str,
360
- *,
361
- url: str,
362
- headers: dict[str, Any] | None = None,
363
- timeout: float = DEFAULT_HTTP_TIMEOUT,
364
- sse_read_timeout: float = DEFAULT_SSE_READ_TIMEOUT,
365
- ) -> None:
366
- """Connect to a specific MCP server using SSE
368
+ # Connect the server
369
+ asyncio.run(server.connect())
367
370
 
368
- Args:
369
- server_name: Name to identify this server connection
370
- url: URL of the SSE server
371
- headers: HTTP headers to send to the SSE endpoint
372
- timeout: HTTP timeout
373
- sse_read_timeout: SSE read timeout
374
- """
375
- # Create and store the connection
376
- sse_transport = await self.exit_stack.enter_async_context(
377
- sse_client(url, headers, timeout, sse_read_timeout)
378
- )
379
- read, write = sse_transport
380
- session = cast(
381
- ClientSession,
382
- await self.exit_stack.enter_async_context(
383
- ClientSession(read, write)
384
- ),
385
- )
371
+ # Return the server
372
+ output = asyncio.run(server.call_tool(function_call))
386
373
 
387
- await self._initialize_session_and_load_tools(
388
- server_name, session
389
- )
374
+ output = output.model_dump()
390
375
 
391
- def get_tools(self) -> list[Callable]:
392
- """Get a list of all tools from all connected servers."""
393
- all_tools: list[Callable] = []
394
- for server_tools in self.server_name_to_tools.values():
395
- all_tools.extend(server_tools)
396
- return all_tools
376
+ # Cleanup the server
377
+ asyncio.run(server.cleanup())
397
378
 
398
- async def get_prompt(
399
- self,
400
- server_name: str,
401
- prompt_name: str,
402
- arguments: Optional[dict[str, Any]] = None,
403
- ) -> List[str]:
404
- """Get a prompt from a given MCP server."""
405
- session = self.sessions[server_name]
406
- return await load_mcp_prompt(session, prompt_name, arguments)
407
-
408
- async def __aenter__(self) -> "MultiServerMCPClient":
409
- try:
410
- connections = self.connections or {}
411
- for server_name, connection in connections.items():
412
- connection_dict = connection.copy()
413
- transport = connection_dict.pop("transport")
414
- if transport == "stdio":
415
- await self.connect_to_server_via_stdio(
416
- server_name, **connection_dict
417
- )
418
- elif transport == "sse":
419
- await self.connect_to_server_via_sse(
420
- server_name, **connection_dict
421
- )
422
- else:
423
- raise ValueError(
424
- f"Unsupported transport: {transport}. Must be 'stdio' or 'sse'"
425
- )
426
- return self
427
- except Exception:
428
- await self.exit_stack.aclose()
429
- raise
379
+ return any_to_str(output)
430
380
 
431
- async def __aexit__(
432
- self,
433
- exc_type: type[BaseException] | None,
434
- exc_val: BaseException | None,
435
- exc_tb: TracebackType | None,
436
- ) -> None:
437
- await self.exit_stack.aclose()
438
-
439
-
440
- # #!/usr/bin/env python3
441
- # import asyncio
442
- # import os
443
- # import json
444
- # from typing import List, Any, Callable
445
-
446
- # # # Import our MCP client module
447
- # # from mcp_client import MultiServerMCPClient
448
-
449
-
450
- # async def main():
451
- # """Test script for demonstrating MCP client usage."""
452
- # print("Starting MCP Client test...")
453
-
454
- # # Create a connection to multiple MCP servers
455
- # # You'll need to update these paths to match your setup
456
- # async with MultiServerMCPClient(
457
- # {
458
- # "math": {
459
- # "transport": "stdio",
460
- # "command": "python",
461
- # "args": ["/path/to/math_server.py"],
462
- # "env": {"DEBUG": "1"},
463
- # },
464
- # "search": {
465
- # "transport": "sse",
466
- # "url": "http://localhost:8000/sse",
467
- # "headers": {
468
- # "Authorization": f"Bearer {os.environ.get('API_KEY', '')}"
469
- # },
470
- # },
471
- # }
472
- # ) as client:
473
- # # Get all available tools
474
- # tools = client.get_tools()
475
- # print(f"Found {len(tools)} tools across all servers")
476
-
477
- # # Print tool information
478
- # for i, tool in enumerate(tools):
479
- # print(f"\nTool {i+1}: {tool.__name__}")
480
- # print(f" Description: {tool.__doc__}")
481
- # if hasattr(tool, "schema") and tool.schema:
482
- # print(
483
- # f" Schema: {json.dumps(tool.schema, indent=2)[:100]}..."
484
- # )
485
-
486
- # # Example: Use a specific tool if available
487
- # calculator_tool = next(
488
- # (t for t in tools if t.__name__ == "calculator"), None
489
- # )
490
- # if calculator_tool:
491
- # print("\n\nTesting calculator tool:")
492
- # try:
493
- # # Call the tool as an async function
494
- # result, artifacts = await calculator_tool(
495
- # expression="2 + 2 * 3"
496
- # )
497
- # print(f" Calculator result: {result}")
498
- # if artifacts:
499
- # print(
500
- # f" With {len(artifacts)} additional artifacts"
501
- # )
502
- # except Exception as e:
503
- # print(f" Error using calculator: {e}")
504
-
505
- # # Example: Load a prompt from a server
506
- # try:
507
- # print("\n\nTesting prompt loading:")
508
- # prompt_messages = await client.get_prompt(
509
- # "math",
510
- # "calculation_introduction",
511
- # {"user_name": "Test User"},
512
- # )
513
- # print(
514
- # f" Loaded prompt with {len(prompt_messages)} messages:"
515
- # )
516
- # for i, msg in enumerate(prompt_messages):
517
- # print(f" Message {i+1}: {msg[:50]}...")
518
- # except Exception as e:
519
- # print(f" Error loading prompt: {e}")
520
-
521
-
522
- # async def create_custom_tool():
523
- # """Example of creating a custom tool function."""
524
-
525
- # # Define a tool function with metadata
526
- # async def add_numbers(a: float, b: float) -> tuple[str, None]:
527
- # """Add two numbers together."""
528
- # result = a + b
529
- # return f"The sum of {a} and {b} is {result}", None
530
-
531
- # # Add metadata to the function
532
- # add_numbers.__name__ = "add_numbers"
533
- # add_numbers.__doc__ = (
534
- # "Add two numbers together and return the result."
535
- # )
536
- # add_numbers.schema = {
537
- # "type": "object",
538
- # "properties": {
539
- # "a": {"type": "number", "description": "First number"},
540
- # "b": {"type": "number", "description": "Second number"},
541
- # },
542
- # "required": ["a", "b"],
543
- # }
544
-
545
- # # Use the tool
546
- # result, _ = await add_numbers(a=5, b=7)
547
- # print(f"\nCustom tool result: {result}")
548
-
549
-
550
- # if __name__ == "__main__":
551
- # # Run both examples
552
- # loop = asyncio.get_event_loop()
553
- # loop.run_until_complete(main())
554
- # loop.run_until_complete(create_custom_tool())
381
+
382
+ def batch_mcp_flow(
383
+ params: List[MCPServerSseParams],
384
+ function_call: List[dict[str, Any]] = [],
385
+ ) -> MCPServer:
386
+ output_list = []
387
+
388
+ for param in params:
389
+ output = mcp_flow(param, function_call)
390
+ output_list.append(output)
391
+
392
+ return output_list