universal-mcp 0.1.23rc2__py3-none-any.whl → 0.1.24rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- universal_mcp/agentr/__init__.py +6 -0
- universal_mcp/agentr/agentr.py +30 -0
- universal_mcp/{utils/agentr.py → agentr/client.py} +22 -7
- universal_mcp/agentr/integration.py +104 -0
- universal_mcp/agentr/registry.py +91 -0
- universal_mcp/agentr/server.py +51 -0
- universal_mcp/agents/__init__.py +6 -0
- universal_mcp/agents/auto.py +576 -0
- universal_mcp/agents/base.py +88 -0
- universal_mcp/agents/cli.py +27 -0
- universal_mcp/agents/codeact/__init__.py +243 -0
- universal_mcp/agents/codeact/sandbox.py +27 -0
- universal_mcp/agents/codeact/test.py +15 -0
- universal_mcp/agents/codeact/utils.py +61 -0
- universal_mcp/agents/hil.py +104 -0
- universal_mcp/agents/llm.py +10 -0
- universal_mcp/agents/react.py +58 -0
- universal_mcp/agents/simple.py +40 -0
- universal_mcp/agents/utils.py +111 -0
- universal_mcp/analytics.py +44 -14
- universal_mcp/applications/__init__.py +42 -75
- universal_mcp/applications/application.py +187 -133
- universal_mcp/applications/sample/app.py +245 -0
- universal_mcp/cli.py +14 -231
- universal_mcp/client/oauth.py +122 -18
- universal_mcp/client/token_store.py +62 -3
- universal_mcp/client/{client.py → transport.py} +127 -48
- universal_mcp/config.py +189 -49
- universal_mcp/exceptions.py +54 -6
- universal_mcp/integrations/__init__.py +0 -18
- universal_mcp/integrations/integration.py +185 -168
- universal_mcp/servers/__init__.py +2 -14
- universal_mcp/servers/server.py +84 -258
- universal_mcp/stores/store.py +126 -93
- universal_mcp/tools/__init__.py +3 -0
- universal_mcp/tools/adapters.py +20 -11
- universal_mcp/tools/func_metadata.py +1 -1
- universal_mcp/tools/manager.py +38 -53
- universal_mcp/tools/registry.py +41 -0
- universal_mcp/tools/tools.py +24 -3
- universal_mcp/types.py +10 -0
- universal_mcp/utils/common.py +245 -0
- universal_mcp/utils/installation.py +3 -4
- universal_mcp/utils/openapi/api_generator.py +71 -17
- universal_mcp/utils/openapi/api_splitter.py +0 -1
- universal_mcp/utils/openapi/cli.py +669 -0
- universal_mcp/utils/openapi/filters.py +114 -0
- universal_mcp/utils/openapi/openapi.py +315 -23
- universal_mcp/utils/openapi/postprocessor.py +275 -0
- universal_mcp/utils/openapi/preprocessor.py +63 -8
- universal_mcp/utils/openapi/test_generator.py +287 -0
- universal_mcp/utils/prompts.py +634 -0
- universal_mcp/utils/singleton.py +4 -1
- universal_mcp/utils/testing.py +196 -8
- universal_mcp-0.1.24rc3.dist-info/METADATA +68 -0
- universal_mcp-0.1.24rc3.dist-info/RECORD +70 -0
- universal_mcp/applications/README.md +0 -122
- universal_mcp/client/__main__.py +0 -30
- universal_mcp/client/agent.py +0 -96
- universal_mcp/integrations/README.md +0 -25
- universal_mcp/servers/README.md +0 -79
- universal_mcp/stores/README.md +0 -74
- universal_mcp/tools/README.md +0 -86
- universal_mcp-0.1.23rc2.dist-info/METADATA +0 -283
- universal_mcp-0.1.23rc2.dist-info/RECORD +0 -51
- /universal_mcp/{utils → tools}/docstring_parser.py +0 -0
- {universal_mcp-0.1.23rc2.dist-info → universal_mcp-0.1.24rc3.dist-info}/WHEEL +0 -0
- {universal_mcp-0.1.23rc2.dist-info → universal_mcp-0.1.24rc3.dist-info}/entry_points.txt +0 -0
- {universal_mcp-0.1.23rc2.dist-info → universal_mcp-0.1.24rc3.dist-info}/licenses/LICENSE +0 -0
universal_mcp/client/oauth.py
CHANGED
@@ -1,21 +1,45 @@
|
|
1
1
|
import threading
|
2
2
|
import time
|
3
3
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
4
|
+
from typing import Any
|
4
5
|
from urllib.parse import parse_qs, urlparse
|
5
6
|
|
6
7
|
from universal_mcp.utils.singleton import Singleton
|
7
8
|
|
8
9
|
|
9
10
|
class CallbackHandler(BaseHTTPRequestHandler):
|
10
|
-
"""
|
11
|
-
|
12
|
-
|
13
|
-
|
11
|
+
"""Handles the HTTP GET request for an OAuth 2.0 callback.
|
12
|
+
|
13
|
+
This handler is designed to capture the authorization code and state
|
14
|
+
(or an error) returned by an OAuth 2.0 authorization server as query
|
15
|
+
parameters in the redirect URI. It stores these values in a shared
|
16
|
+
`callback_data` dictionary.
|
17
|
+
|
18
|
+
It sends a simple HTML response to the user's browser indicating
|
19
|
+
success or failure of the authorization attempt.
|
20
|
+
"""
|
21
|
+
|
22
|
+
def __init__(self, request, client_address, server, callback_data: dict):
|
23
|
+
"""Initializes the CallbackHandler.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
request: The HTTP request.
|
27
|
+
client_address: The client's address.
|
28
|
+
server: The server instance.
|
29
|
+
callback_data (dict): A dictionary shared with the `CallbackServer`
|
30
|
+
to store the captured OAuth parameters (e.g.,
|
31
|
+
`authorization_code`, `state`, `error`).
|
32
|
+
"""
|
14
33
|
self.callback_data = callback_data
|
15
34
|
super().__init__(request, client_address, server)
|
16
35
|
|
17
36
|
def do_GET(self):
|
18
|
-
"""
|
37
|
+
"""Handles the GET request from the OAuth authorization server's redirect.
|
38
|
+
|
39
|
+
Parses the URL query parameters to find 'code' and 'state', or 'error'.
|
40
|
+
Stores these values into the `self.callback_data` dictionary.
|
41
|
+
Responds to the browser with a success or failure HTML page.
|
42
|
+
"""
|
19
43
|
parsed = urlparse(self.path)
|
20
44
|
query_params = parse_qs(parsed.query)
|
21
45
|
|
@@ -44,7 +68,7 @@ class CallbackHandler(BaseHTTPRequestHandler):
|
|
44
68
|
<html>
|
45
69
|
<body>
|
46
70
|
<h1>Authorization Failed</h1>
|
47
|
-
<p>Error: {query_params[
|
71
|
+
<p>Error: {query_params["error"][0]}</p>
|
48
72
|
<p>You can close this window and return to the terminal.</p>
|
49
73
|
</body>
|
50
74
|
</html>
|
@@ -54,23 +78,68 @@ class CallbackHandler(BaseHTTPRequestHandler):
|
|
54
78
|
self.send_response(404)
|
55
79
|
self.end_headers()
|
56
80
|
|
57
|
-
def log_message(self, format, *args):
|
58
|
-
"""
|
81
|
+
def log_message(self, format: str, *args: Any):
|
82
|
+
"""Suppresses the default logging of HTTP requests.
|
83
|
+
|
84
|
+
Overrides the base class method to prevent request logs from being
|
85
|
+
printed to stderr, keeping the console cleaner during the OAuth flow.
|
86
|
+
"""
|
59
87
|
pass
|
60
88
|
|
61
89
|
|
62
90
|
class CallbackServer(metaclass=Singleton):
|
63
|
-
"""
|
64
|
-
|
65
|
-
|
91
|
+
"""A singleton HTTP server to manage OAuth 2.0 redirect callbacks.
|
92
|
+
|
93
|
+
This server runs in a background thread, listening on a specified
|
94
|
+
localhost port. It uses the `CallbackHandler` to capture the
|
95
|
+
authorization code or error returned by an OAuth 2.0 provider
|
96
|
+
after user authentication.
|
97
|
+
|
98
|
+
Being a Singleton, only one instance of this server will run per
|
99
|
+
application, even if instantiated multiple times.
|
100
|
+
|
101
|
+
Attributes:
|
102
|
+
port (int): The port number on localhost where the server listens.
|
103
|
+
server (HTTPServer | None): The underlying `HTTPServer` instance.
|
104
|
+
None if the server is not running.
|
105
|
+
thread (threading.Thread | None): The background thread in which
|
106
|
+
the server runs. None if the server is not running.
|
107
|
+
callback_data (dict): A dictionary to store data received from the
|
108
|
+
OAuth callback (e.g., `authorization_code`, `state`, `error`).
|
109
|
+
This is shared with the `CallbackHandler`.
|
110
|
+
_running (bool): A flag indicating whether the server is currently
|
111
|
+
started and listening.
|
112
|
+
"""
|
113
|
+
|
114
|
+
def __init__(self, port: int = 3000):
|
115
|
+
"""Initializes the CallbackServer.
|
116
|
+
|
117
|
+
Args:
|
118
|
+
port (int, optional): The port number on localhost for the server
|
119
|
+
to listen on. Defaults to 3000.
|
120
|
+
"""
|
66
121
|
self.port = port
|
67
122
|
self.server = None
|
68
123
|
self.thread = None
|
69
124
|
self.callback_data = {"authorization_code": None, "state": None, "error": None}
|
70
125
|
self._running = False
|
71
126
|
|
127
|
+
@property
|
128
|
+
def is_running(self) -> bool:
|
129
|
+
return self._running
|
130
|
+
|
72
131
|
def _create_handler_with_data(self):
|
73
|
-
"""
|
132
|
+
"""Creates a `CallbackHandler` subclass with shared `callback_data`.
|
133
|
+
|
134
|
+
This method dynamically defines a new handler class that inherits from
|
135
|
+
`CallbackHandler`. The purpose is to allow the handler instances
|
136
|
+
to access and modify the `self.callback_data` dictionary of this
|
137
|
+
`CallbackServer` instance, enabling communication of OAuth parameters
|
138
|
+
from the handler back to the server logic.
|
139
|
+
|
140
|
+
Returns:
|
141
|
+
type: A new class, subclass of `CallbackHandler`.
|
142
|
+
"""
|
74
143
|
callback_data = self.callback_data
|
75
144
|
|
76
145
|
class DataCallbackHandler(CallbackHandler):
|
@@ -80,7 +149,13 @@ class CallbackServer(metaclass=Singleton):
|
|
80
149
|
return DataCallbackHandler
|
81
150
|
|
82
151
|
def start(self):
|
83
|
-
"""
|
152
|
+
"""Starts the HTTP callback server in a background daemon thread.
|
153
|
+
|
154
|
+
If the server is not already running, it initializes an `HTTPServer`
|
155
|
+
with a specialized `CallbackHandler` and starts it in a new
|
156
|
+
daemon thread. This allows the main application flow to continue
|
157
|
+
while waiting for the OAuth callback.
|
158
|
+
"""
|
84
159
|
if self._running:
|
85
160
|
return
|
86
161
|
handler_class = self._create_handler_with_data()
|
@@ -91,15 +166,37 @@ class CallbackServer(metaclass=Singleton):
|
|
91
166
|
self._running = True
|
92
167
|
|
93
168
|
def stop(self):
|
94
|
-
"""
|
169
|
+
"""Stops the HTTP callback server and cleans up resources.
|
170
|
+
|
171
|
+
Shuts down the `HTTPServer` and waits for its background thread
|
172
|
+
to complete.
|
173
|
+
"""
|
95
174
|
if self.server:
|
96
175
|
self.server.shutdown()
|
97
176
|
self.server.server_close()
|
98
177
|
if self.thread:
|
99
178
|
self.thread.join(timeout=1)
|
100
179
|
|
101
|
-
def wait_for_callback(self, timeout=300):
|
102
|
-
"""
|
180
|
+
def wait_for_callback(self, timeout: int = 300) -> str:
|
181
|
+
"""Waits for the OAuth callback to provide an authorization code.
|
182
|
+
|
183
|
+
This method polls the `self.callback_data` dictionary until an
|
184
|
+
authorization code is received or an error is reported by the
|
185
|
+
`CallbackHandler`, or until the timeout is reached.
|
186
|
+
|
187
|
+
Args:
|
188
|
+
timeout (int, optional): The maximum time in seconds to wait
|
189
|
+
for the callback. Defaults to 300 seconds (5 minutes).
|
190
|
+
|
191
|
+
Returns:
|
192
|
+
str: The received authorization code.
|
193
|
+
|
194
|
+
Raises:
|
195
|
+
Exception: If an error is reported in the callback
|
196
|
+
(e.g., "OAuth error: <error_message>").
|
197
|
+
Exception: If the timeout is reached before a code or error
|
198
|
+
is received (e.g., "Timeout waiting for OAuth callback").
|
199
|
+
"""
|
103
200
|
start_time = time.time()
|
104
201
|
while time.time() - start_time < timeout:
|
105
202
|
if self.callback_data["authorization_code"]:
|
@@ -109,6 +206,13 @@ class CallbackServer(metaclass=Singleton):
|
|
109
206
|
time.sleep(0.1)
|
110
207
|
raise Exception("Timeout waiting for OAuth callback")
|
111
208
|
|
112
|
-
def get_state(self):
|
113
|
-
"""
|
209
|
+
def get_state(self) -> str | None:
|
210
|
+
"""Retrieves the 'state' parameter received during the OAuth callback.
|
211
|
+
|
212
|
+
The state parameter is often used to prevent cross-site request forgery (CSRF)
|
213
|
+
attacks by matching its value with one sent in the initial authorization request.
|
214
|
+
|
215
|
+
Returns:
|
216
|
+
str | None: The 'state' parameter value if received, otherwise None.
|
217
|
+
"""
|
114
218
|
return self.callback_data["state"]
|
@@ -6,27 +6,86 @@ from universal_mcp.stores.store import KeyringStore
|
|
6
6
|
|
7
7
|
|
8
8
|
class TokenStore(MCPTokenStorage):
|
9
|
-
"""
|
9
|
+
"""Persistent storage for OAuth tokens and client information using KeyringStore.
|
10
|
+
|
11
|
+
This class implements the `mcp.client.auth.TokenStorage` interface,
|
12
|
+
providing a mechanism to securely store and retrieve OAuth 2.0 tokens
|
13
|
+
(as `OAuthToken` objects) and OAuth client registration details
|
14
|
+
(as `OAuthClientInformationFull` objects).
|
15
|
+
|
16
|
+
It utilizes an underlying `KeyringStore` instance, which typically
|
17
|
+
delegates to the operating system's secure credential management
|
18
|
+
system (e.g., macOS Keychain, Windows Credential Manager, Linux KWallet).
|
19
|
+
This ensures that sensitive token data is stored securely and persistently.
|
20
|
+
|
21
|
+
Attributes:
|
22
|
+
store (KeyringStore): The `KeyringStore` instance used for actually
|
23
|
+
storing and retrieving the serialized token and client info data.
|
24
|
+
"""
|
10
25
|
|
11
26
|
def __init__(self, store: KeyringStore):
|
27
|
+
"""Initializes the TokenStore.
|
28
|
+
|
29
|
+
Args:
|
30
|
+
store (KeyringStore): An instance of `KeyringStore` that will be
|
31
|
+
used for the actual persistence of tokens and client information.
|
32
|
+
"""
|
12
33
|
self.store = store
|
13
|
-
|
14
|
-
self.
|
34
|
+
# These are not meant to be persistent caches in this implementation
|
35
|
+
# self._tokens: OAuthToken | None = None
|
36
|
+
# self._client_info: OAuthClientInformationFull | None = None
|
15
37
|
|
16
38
|
async def get_tokens(self) -> OAuthToken | None:
|
39
|
+
"""Retrieves OAuth tokens from the persistent KeyringStore.
|
40
|
+
|
41
|
+
Fetches the JSON string representation of tokens from the store using
|
42
|
+
the key "tokens" and deserializes it into an `OAuthToken` object.
|
43
|
+
|
44
|
+
Returns:
|
45
|
+
OAuthToken | None: The deserialized `OAuthToken` object if found
|
46
|
+
and successfully parsed, otherwise None.
|
47
|
+
"""
|
17
48
|
try:
|
18
49
|
return OAuthToken.model_validate_json(self.store.get("tokens"))
|
19
50
|
except KeyNotFoundError:
|
20
51
|
return None
|
21
52
|
|
22
53
|
async def set_tokens(self, tokens: OAuthToken) -> None:
|
54
|
+
"""Serializes OAuth tokens to JSON and saves them to the KeyringStore.
|
55
|
+
|
56
|
+
The provided `OAuthToken` object is converted to its JSON string
|
57
|
+
representation and stored in the `KeyringStore` under the key "tokens".
|
58
|
+
|
59
|
+
Args:
|
60
|
+
tokens (OAuthToken): The `OAuthToken` object to store.
|
61
|
+
"""
|
23
62
|
self.store.set("tokens", tokens.model_dump_json())
|
24
63
|
|
25
64
|
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
65
|
+
"""Retrieves OAuth client information from the persistent KeyringStore.
|
66
|
+
|
67
|
+
Fetches the JSON string representation of client information from the
|
68
|
+
store using the key "client_info" and deserializes it into an
|
69
|
+
`OAuthClientInformationFull` object.
|
70
|
+
|
71
|
+
Returns:
|
72
|
+
OAuthClientInformationFull | None: The deserialized object if found
|
73
|
+
and successfully parsed, otherwise None.
|
74
|
+
"""
|
26
75
|
try:
|
27
76
|
return OAuthClientInformationFull.model_validate_json(self.store.get("client_info"))
|
28
77
|
except KeyNotFoundError:
|
29
78
|
return None
|
30
79
|
|
31
80
|
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
81
|
+
"""Serializes OAuth client information to JSON and saves it to KeyringStore.
|
82
|
+
|
83
|
+
The provided `OAuthClientInformationFull` object is converted to its
|
84
|
+
JSON string representation and stored in the `KeyringStore` under the
|
85
|
+
key "client_info".
|
86
|
+
|
87
|
+
Args:
|
88
|
+
client_info (OAuthClientInformationFull): The client information object
|
89
|
+
to store.
|
90
|
+
"""
|
32
91
|
self.store.set("client_info", client_info.model_dump_json())
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import os
|
2
2
|
import webbrowser
|
3
3
|
from contextlib import AsyncExitStack
|
4
|
-
from typing import Any, Literal
|
4
|
+
from typing import Any, Literal, Self
|
5
5
|
|
6
6
|
from loguru import logger
|
7
7
|
from mcp import ClientSession, StdioServerParameters
|
@@ -9,7 +9,6 @@ from mcp.client.auth import OAuthClientProvider
|
|
9
9
|
from mcp.client.sse import sse_client
|
10
10
|
from mcp.client.stdio import stdio_client
|
11
11
|
from mcp.client.streamable_http import streamablehttp_client
|
12
|
-
from mcp.server import Server
|
13
12
|
from mcp.shared.auth import OAuthClientMetadata
|
14
13
|
from mcp.types import (
|
15
14
|
CallToolResult as MCPCallToolResult,
|
@@ -21,13 +20,20 @@ from openai.types.chat import ChatCompletionToolParam
|
|
21
20
|
|
22
21
|
from universal_mcp.client.oauth import CallbackServer
|
23
22
|
from universal_mcp.client.token_store import TokenStore
|
24
|
-
from universal_mcp.config import ClientTransportConfig
|
23
|
+
from universal_mcp.config import ClientConfig, ClientTransportConfig
|
25
24
|
from universal_mcp.stores.store import KeyringStore
|
26
25
|
from universal_mcp.tools.adapters import transform_mcp_tool_to_openai_tool
|
27
26
|
|
28
27
|
|
29
|
-
class
|
30
|
-
"""
|
28
|
+
class ClientTransport:
|
29
|
+
"""
|
30
|
+
Client for connecting to and interacting with a single MCP server.
|
31
|
+
|
32
|
+
Manages the lifecycle of a connection to an MCP server, handles various
|
33
|
+
transport mechanisms (stdio, sse, streamable_http), and facilitates
|
34
|
+
authentication, including OAuth 2.0 client flows. Allows listing tools
|
35
|
+
available on the server and calling them.
|
36
|
+
"""
|
31
37
|
|
32
38
|
def __init__(self, name: str, config: ClientTransportConfig) -> None:
|
33
39
|
self.name: str = name
|
@@ -35,14 +41,12 @@ class MCPClient:
|
|
35
41
|
self.session: ClientSession | None = None
|
36
42
|
self.server_url: str = config.url
|
37
43
|
|
38
|
-
#
|
39
|
-
self.
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
self.store = KeyringStore(self.name)
|
45
|
-
self.auth = OAuthClientProvider(
|
44
|
+
# Create OAuth authentication handler if needed
|
45
|
+
if self.server_url and not getattr(self.config, "headers", None):
|
46
|
+
# Set up callback server
|
47
|
+
self._callback_server = CallbackServer(port=3000)
|
48
|
+
self.store: KeyringStore | None = KeyringStore(self.name)
|
49
|
+
self.auth: OAuthClientProvider | None = OAuthClientProvider(
|
46
50
|
server_url="/".join(self.server_url.split("/")[:-1]),
|
47
51
|
client_metadata=OAuthClientMetadata.model_validate(self.client_metadata_dict),
|
48
52
|
storage=TokenStore(self.store),
|
@@ -50,10 +54,18 @@ class MCPClient:
|
|
50
54
|
callback_handler=self._callback_handler,
|
51
55
|
)
|
52
56
|
else:
|
57
|
+
self._callback_server = None
|
58
|
+
self.store = None
|
53
59
|
self.auth = None
|
54
60
|
|
61
|
+
@property
|
62
|
+
def callback_server(self) -> CallbackServer:
|
63
|
+
if self._callback_server and not self._callback_server.is_running:
|
64
|
+
self._callback_server.start()
|
65
|
+
return self._callback_server
|
66
|
+
|
55
67
|
async def _callback_handler(self) -> tuple[str, str | None]:
|
56
|
-
"""
|
68
|
+
"""Handles the OAuth callback by waiting for and returning auth details."""
|
57
69
|
print("⏳ Waiting for authorization callback...")
|
58
70
|
try:
|
59
71
|
auth_code = self.callback_server.wait_for_callback(timeout=300)
|
@@ -63,38 +75,45 @@ class MCPClient:
|
|
63
75
|
|
64
76
|
@property
|
65
77
|
def client_metadata_dict(self) -> dict[str, Any]:
|
78
|
+
"""Provides OAuth 2.0 client metadata for registration or authentication."""
|
66
79
|
return {
|
67
|
-
"client_name":
|
68
|
-
"redirect_uris": [
|
80
|
+
"client_name": self.name,
|
81
|
+
"redirect_uris": [self.callback_server.redirect_uri], # type: ignore
|
69
82
|
"grant_types": ["authorization_code", "refresh_token"],
|
70
83
|
"response_types": ["code"],
|
71
84
|
"token_endpoint_auth_method": "client_secret_post",
|
72
85
|
}
|
73
86
|
|
74
87
|
async def _default_redirect_handler(self, authorization_url: str) -> None:
|
75
|
-
"""Default
|
88
|
+
"""Default handler for OAuth redirects; opens URL in a web browser."""
|
76
89
|
print(f"Opening browser for authorization: {authorization_url}")
|
77
90
|
webbrowser.open(authorization_url)
|
78
91
|
|
79
|
-
async def initialize(self, exit_stack: AsyncExitStack):
|
80
|
-
"""
|
81
|
-
|
92
|
+
async def initialize(self, exit_stack: AsyncExitStack) -> None:
|
93
|
+
"""
|
94
|
+
Establishes and initializes the connection to the MCP server.
|
95
|
+
|
96
|
+
Raises:
|
97
|
+
ValueError: If the transport type is unknown or if required
|
98
|
+
configuration for a transport is missing.
|
99
|
+
"""
|
100
|
+
transport = getattr(self.config, "transport", None)
|
101
|
+
session = None
|
82
102
|
try:
|
83
103
|
if transport == "stdio":
|
84
|
-
command = self.config
|
85
|
-
if command
|
104
|
+
command = self.config.get("command")
|
105
|
+
if not command:
|
86
106
|
raise ValueError("The command must be a valid string and cannot be None.")
|
87
107
|
|
88
108
|
server_params = StdioServerParameters(
|
89
109
|
command=command,
|
90
|
-
args=self.config
|
91
|
-
env={**os.environ, **self.config
|
110
|
+
args=self.config.get("args", []),
|
111
|
+
env={**os.environ, **self.config.get("env", {})} if self.config.get("env") else None,
|
92
112
|
)
|
93
113
|
stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
|
94
114
|
read, write = stdio_transport
|
95
115
|
session = await exit_stack.enter_async_context(ClientSession(read, write))
|
96
116
|
await session.initialize()
|
97
|
-
self.session = session
|
98
117
|
elif transport == "streamable_http":
|
99
118
|
url = self.config.get("url")
|
100
119
|
headers = self.config.get("headers", {})
|
@@ -106,10 +125,9 @@ class MCPClient:
|
|
106
125
|
read, write, _ = streamable_http_transport
|
107
126
|
session = await exit_stack.enter_async_context(ClientSession(read, write))
|
108
127
|
await session.initialize()
|
109
|
-
self.session = session
|
110
128
|
elif transport == "sse":
|
111
|
-
url = self.config.url
|
112
|
-
headers = self.config.headers
|
129
|
+
url = self.config.get("url")
|
130
|
+
headers = self.config.get("headers", {})
|
113
131
|
if not url:
|
114
132
|
raise ValueError("'url' must be provided for sse transport.")
|
115
133
|
sse_transport = await exit_stack.enter_async_context(
|
@@ -118,73 +136,126 @@ class MCPClient:
|
|
118
136
|
read, write = sse_transport
|
119
137
|
session = await exit_stack.enter_async_context(ClientSession(read, write))
|
120
138
|
await session.initialize()
|
121
|
-
self.session = session
|
122
139
|
else:
|
123
140
|
raise ValueError(f"Unknown transport: {transport}")
|
141
|
+
self.session = session
|
124
142
|
except Exception as e:
|
143
|
+
if session:
|
144
|
+
await session.aclose()
|
125
145
|
logger.error(f"Error initializing server {self.name}: {e}")
|
126
146
|
raise
|
127
147
|
|
128
148
|
async def list_tools(self) -> list[MCPTool]:
|
129
|
-
"""
|
149
|
+
"""Lists all tools available on the connected MCP server."""
|
130
150
|
if self.session:
|
131
|
-
|
132
|
-
|
151
|
+
try:
|
152
|
+
tools = await self.session.list_tools()
|
153
|
+
return list(tools.tools)
|
154
|
+
except Exception as e:
|
155
|
+
logger.warning(f"Failed to list tools for client {self.name}: {e}")
|
133
156
|
return []
|
134
157
|
|
135
158
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> MCPCallToolResult:
|
136
|
-
"""
|
159
|
+
"""Calls a specified tool on the connected MCP server with given arguments."""
|
137
160
|
if self.session:
|
138
|
-
|
161
|
+
try:
|
162
|
+
return await self.session.call_tool(tool_name, arguments)
|
163
|
+
except Exception as e:
|
164
|
+
logger.error(f"Error calling tool '{tool_name}' on client {self.name}: {e}")
|
139
165
|
return MCPCallToolResult(
|
140
166
|
content=[],
|
141
167
|
isError=True,
|
142
168
|
)
|
143
169
|
|
144
170
|
|
145
|
-
class
|
171
|
+
class MultiClientTransport:
|
146
172
|
"""
|
147
|
-
|
173
|
+
Aggregates multiple ClientTransport instances to act as a single MCP Server.
|
174
|
+
|
175
|
+
Provides a unified Server interface for a collection of ClientTransport
|
176
|
+
instances, each potentially connected to a different MCP server.
|
177
|
+
Maintains a mapping of tool names to the specific ClientTransport that
|
178
|
+
provides that tool.
|
148
179
|
"""
|
149
180
|
|
150
181
|
def __init__(self, clients: dict[str, ClientTransportConfig]):
|
151
|
-
self.clients: list[
|
152
|
-
self.tool_to_client: dict[str,
|
182
|
+
self.clients: list[ClientTransport] = [ClientTransport(name, config) for name, config in clients.items()]
|
183
|
+
self.tool_to_client: dict[str, ClientTransport] = {}
|
153
184
|
self._mcp_tools: list[MCPTool] = []
|
154
185
|
self._exit_stack: AsyncExitStack = AsyncExitStack()
|
155
186
|
|
187
|
+
@classmethod
|
188
|
+
def from_file(cls, path: str) -> Self:
|
189
|
+
mcp_config = ClientConfig.load_json_config(path)
|
190
|
+
return cls(mcp_config.mcpServers)
|
191
|
+
|
192
|
+
def save_to_file(self, path: str) -> None:
|
193
|
+
mcp_config = ClientConfig(mcpServers={name: config.model_dump() for name, config in self.clients.items()})
|
194
|
+
mcp_config.save_json_config(path)
|
195
|
+
|
196
|
+
async def add_client(self, name: str, config: ClientTransportConfig) -> None:
|
197
|
+
if name in self.tool_to_client:
|
198
|
+
logger.warning(f"Client {name} already exists. Skipping.")
|
199
|
+
return
|
200
|
+
self.clients.append(ClientTransport(name, config))
|
201
|
+
self.tool_to_client[name] = self.clients[-1]
|
202
|
+
logger.info(f"Added client: {name}")
|
203
|
+
await self._populate_tool_mapping()
|
204
|
+
|
205
|
+
async def remove_client(self, name: str) -> None:
|
206
|
+
if name not in self.tool_to_client:
|
207
|
+
logger.warning(f"Client {name} not found. Skipping.")
|
208
|
+
return
|
209
|
+
self.clients.remove(self.tool_to_client[name])
|
210
|
+
del self.tool_to_client[name]
|
211
|
+
logger.info(f"Removed client: {name}")
|
212
|
+
await self._populate_tool_mapping()
|
213
|
+
|
156
214
|
async def __aenter__(self):
|
157
|
-
"""Initialize the server connection."""
|
158
215
|
for client in self.clients:
|
159
216
|
await client.initialize(self._exit_stack)
|
160
217
|
await self._populate_tool_mapping()
|
161
218
|
return self
|
162
219
|
|
163
220
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
164
|
-
"""Clean up the server connection."""
|
165
221
|
self.clients.clear()
|
166
222
|
self.tool_to_client.clear()
|
167
223
|
self._mcp_tools.clear()
|
168
224
|
await self._exit_stack.aclose()
|
169
225
|
|
170
226
|
async def _populate_tool_mapping(self):
|
171
|
-
"""Populate the mapping from tool name to server."""
|
172
227
|
self.tool_to_client.clear()
|
173
228
|
self._mcp_tools.clear()
|
174
229
|
for client in self.clients:
|
175
230
|
try:
|
176
231
|
tools = await client.list_tools()
|
177
232
|
for tool in tools:
|
178
|
-
|
179
|
-
tool_name = tool.name
|
180
|
-
logger.info(f"Found tool: {tool_name} from client: {client.name}")
|
233
|
+
tool_name = getattr(tool, "name", None)
|
181
234
|
if tool_name:
|
182
|
-
self.tool_to_client
|
235
|
+
if tool_name not in self.tool_to_client:
|
236
|
+
self._mcp_tools.append(tool)
|
237
|
+
self.tool_to_client[tool_name] = client
|
238
|
+
logger.info(f"Found tool: {tool_name} from client: {client.name}")
|
239
|
+
else:
|
240
|
+
logger.warning(
|
241
|
+
f"Duplicate tool name '{tool_name}' found in client '{client.name}'. Skipping."
|
242
|
+
)
|
183
243
|
except Exception as e:
|
184
244
|
logger.warning(f"Failed to list tools for client {client.name}: {e}")
|
185
245
|
|
186
246
|
async def list_tools(self, format: Literal["mcp", "openai"] = "mcp") -> list[MCPTool | ChatCompletionToolParam]:
|
187
|
-
"""
|
247
|
+
"""
|
248
|
+
Lists all unique tools available from all managed clients.
|
249
|
+
|
250
|
+
Args:
|
251
|
+
format: The desired format for the returned tools.
|
252
|
+
|
253
|
+
Returns:
|
254
|
+
List of tools in the specified format.
|
255
|
+
|
256
|
+
Raises:
|
257
|
+
ValueError: If an unsupported format is requested.
|
258
|
+
"""
|
188
259
|
if format == "mcp":
|
189
260
|
return self._mcp_tools
|
190
261
|
elif format == "openai":
|
@@ -193,6 +264,14 @@ class MultiClientServer(Server):
|
|
193
264
|
raise ValueError(f"Invalid format: {format}")
|
194
265
|
|
195
266
|
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> MCPCallToolResult:
|
196
|
-
"""
|
197
|
-
|
267
|
+
"""
|
268
|
+
Calls a tool by routing the request to the appropriate ClientTransport.
|
269
|
+
|
270
|
+
Raises:
|
271
|
+
KeyError: If the tool_name is not found.
|
272
|
+
"""
|
273
|
+
client = self.tool_to_client.get(tool_name)
|
274
|
+
if not client:
|
275
|
+
logger.error(f"Tool '{tool_name}' not found in any client.")
|
276
|
+
return MCPCallToolResult(content=[], isError=True)
|
198
277
|
return await client.call_tool(tool_name, arguments)
|