gnosys-strata 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gnosys_strata-1.1.4.dist-info/METADATA +140 -0
- gnosys_strata-1.1.4.dist-info/RECORD +28 -0
- gnosys_strata-1.1.4.dist-info/WHEEL +4 -0
- gnosys_strata-1.1.4.dist-info/entry_points.txt +2 -0
- strata/__init__.py +6 -0
- strata/__main__.py +6 -0
- strata/cli.py +364 -0
- strata/config.py +310 -0
- strata/logging_config.py +109 -0
- strata/main.py +6 -0
- strata/mcp_client_manager.py +282 -0
- strata/mcp_proxy/__init__.py +7 -0
- strata/mcp_proxy/auth_provider.py +200 -0
- strata/mcp_proxy/client.py +162 -0
- strata/mcp_proxy/transport/__init__.py +7 -0
- strata/mcp_proxy/transport/base.py +104 -0
- strata/mcp_proxy/transport/http.py +80 -0
- strata/mcp_proxy/transport/stdio.py +69 -0
- strata/server.py +216 -0
- strata/tools.py +714 -0
- strata/treeshell_functions.py +397 -0
- strata/utils/__init__.py +0 -0
- strata/utils/bm25_search.py +181 -0
- strata/utils/catalog.py +82 -0
- strata/utils/dict_utils.py +29 -0
- strata/utils/field_search.py +233 -0
- strata/utils/shared_search.py +202 -0
- strata/utils/tool_integration.py +269 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
"""MCP Client Manager for managing multiple MCP server connections."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import logging
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from strata.config import MCPServerConfig, MCPServerList
|
|
9
|
+
from strata.mcp_proxy.client import MCPClient
|
|
10
|
+
from strata.mcp_proxy.transport.http import HTTPTransport
|
|
11
|
+
from strata.mcp_proxy.transport.stdio import StdioTransport
|
|
12
|
+
from strata.utils.catalog import ToolCatalog
|
|
13
|
+
|
|
14
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MCPClientManager:
    """Manages multiple MCP client connections based on configuration.

    The manager owns one ``MCPClient`` (and its transport) per connected
    server, keyed by server name, and can be used as an async context
    manager: entering connects every enabled server, exiting disconnects
    all of them.
    """

    def __init__(self, config_path: Optional[Path] = None, server_names: Optional[List[str]] = None):
        """Initialize the MCP client manager.

        Args:
            config_path: Optional path to configuration file.
                        If None, uses default from MCPServerList.
            server_names: Optional list of specific server names to initialize.
                         If None, all enabled servers will be initialized.
        """
        self.server_list = MCPServerList(config_path)
        self.server_names = server_names  # Specific servers to manage
        self.active_clients: Dict[str, MCPClient] = {}
        self.active_transports: Dict[str, HTTPTransport | StdioTransport] = {}
        # Cache of current server configs for comparison during sync
        self.cached_configs: List[MCPServerConfig] = []
        self.catalog = ToolCatalog()
        # Mutex to prevent concurrent sync operations
        self._sync_lock = asyncio.Lock()
        # Strong references to fire-and-forget tasks. The event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it completes.
        self._background_tasks: set[asyncio.Task] = set()

    async def initialize_from_config(self) -> Dict[str, bool]:
        """Initialize MCP clients from configuration.

        Only initializes servers that are enabled in the configuration.
        If server_names was specified in __init__, only those servers will be initialized.
        A failure to connect one server is logged and recorded; it does not
        stop the remaining servers from being tried.

        Returns:
            Dict mapping server names to success status (True if connected)
        """
        results: Dict[str, bool] = {}
        enabled_servers = self.server_list.list_servers(enabled_only=True)

        # Filter servers if specific names were provided
        if self.server_names:
            wanted = set(self.server_names)
            enabled_servers = [s for s in enabled_servers if s.name in wanted]

        for server in enabled_servers:
            try:
                await self._connect_server(server)
                results[server.name] = True
                logger.info(f"Successfully connected to MCP server: {server.name}")
            except Exception as e:
                results[server.name] = False
                logger.error(f"Failed to connect to MCP server {server.name}: {e}")

        # Cache all server configs (both enabled and disabled) for future comparisons
        self.cached_configs = self.server_list.list_servers()

        return results

    async def authenticate_server(self, server_name: str) -> None:
        """Authenticate a single MCP server.

        Runs ``initialize()`` followed by ``disconnect()`` on the server's
        active client. Errors are logged, not raised. Does nothing when
        ``server_name`` has no active client.

        Args:
            server_name: Name of the server to authenticate
        """
        client = self.active_clients.get(server_name)
        if client is None:
            return

        # NOTE(review): the client stays registered in active_clients after
        # being disconnected here -- confirm that is intended.
        try:
            await client.initialize()
            await client.disconnect()
            logger.info(f"Initialized MCP server: {server_name}")
        except Exception as e:
            logger.error(f"Error initializing {server_name}: {e}")

    async def _connect_server(self, server: MCPServerConfig) -> None:
        """Connect to a single MCP server.

        Builds an HTTP transport for ``"sse"``/``"http"`` servers and a stdio
        transport for everything else, connects a client over it and
        registers both under ``server.name``.

        Args:
            server: Server configuration

        Raises:
            ValueError: If the URL (HTTP/SSE) or command (stdio) is missing
            Exception: If connection fails
        """
        # Create transport based on type
        if server.type in ("sse", "http"):
            if not server.url:
                raise ValueError(f"Server {server.name} has no URL configured")

            transport = HTTPTransport(
                server_name=server.name,
                url=server.url,
                mode=server.type,  # "http" or "sse"  # type: ignore
                headers=server.headers,
                auth=server.auth,
            )
        else:  # stdio/command
            if not server.command:
                raise ValueError(f"Server {server.name} has no command configured")
            transport = StdioTransport(
                command=server.command, args=server.args, env=server.env
            )

        # Create client and connect before registering anything, so a failed
        # connect leaves the manager state untouched.
        client = MCPClient(transport)
        await client.connect()

        # Store active client and transport
        self.active_clients[server.name] = client
        self.active_transports[server.name] = transport

        # Auto-populate catalog on first connect (background task, non-blocking)
        if not self.catalog.get_tools(server.name):
            task = asyncio.create_task(self._populate_catalog(server.name, client))
            self._background_tasks.add(task)
            task.add_done_callback(self._background_tasks.discard)

    async def _populate_catalog(self, server_name: str, client: MCPClient) -> None:
        """Fetch the server's tool list and store it in the catalog (best effort)."""
        try:
            tools = await client.list_tools()
            self.catalog.update_server(server_name, tools)
            logger.info(f"Added {server_name} to catalog (first connection)")
        except Exception as e:
            logger.warning(f"Could not update catalog for {server_name}: {e}")

    async def _disconnect_server(self, server_name: str) -> None:
        """Disconnect from a single MCP server.

        The client and transport are always dropped from the active maps,
        even when the disconnect itself fails (the error is only logged).

        Args:
            server_name: Name of the server to disconnect
        """
        client = self.active_clients.get(server_name)
        if client is None:
            return

        try:
            await client.disconnect()
            logger.info(f"Disconnected from MCP server: {server_name}")
        except Exception as e:
            logger.error(f"Error disconnecting from {server_name}: {e}")
        finally:
            # Remove from active clients and transports
            self.active_clients.pop(server_name, None)
            self.active_transports.pop(server_name, None)

    async def sync_with_config(self, new_servers: Dict[str, MCPServerConfig]) -> None:
        """Sync the manager state with new configuration.

        This method handles all changes: add, remove, enable, disable, and config updates.
        Uses a mutex lock to prevent concurrent sync operations.
        A failure on one server is logged and does not abort the sync.

        Args:
            new_servers: New server configurations from config file
        """
        # NOTE(review): unlike initialize_from_config, this does not apply the
        # self.server_names filter, so a sync may connect servers outside the
        # requested subset -- confirm whether that is intended.
        async with self._sync_lock:
            # Create lookup for current cached configs by name
            cached_by_name = {config.name: config for config in self.cached_configs}

            # Find servers to remove (in active clients but not in new config)
            servers_to_remove = set(self.active_clients) - set(new_servers)
            for server_name in servers_to_remove:
                await self._disconnect_server(server_name)
                logger.info(f"Removed MCP server: {server_name}")

            # Process each server in new config
            for server_name, new_config in new_servers.items():
                try:
                    is_active = server_name in self.active_clients
                    config_changed = cached_by_name.get(server_name) != new_config

                    if new_config.enabled:
                        if not is_active:
                            # Server is enabled but not connected, connect it
                            await self._connect_server(new_config)
                            logger.info(f"Connected to MCP server: {server_name}")
                        elif config_changed:
                            # Server is active but config changed, reconnect
                            await self._disconnect_server(server_name)
                            await self._connect_server(new_config)
                            logger.info(
                                f"Reconnected MCP server with new config: {server_name}"
                            )
                        # If server is active and config unchanged, do nothing
                    elif is_active:
                        # Server is disabled but still connected, disconnect it
                        await self._disconnect_server(server_name)
                        logger.info(f"Disabled MCP server: {server_name}")
                except Exception as e:
                    logger.error(f"Failed to sync server {server_name}: {e}")

            # Update cached configs with new config
            self.cached_configs = list(new_servers.values())

    def get_client(self, server_name: str) -> MCPClient:
        """Get an active MCP client by server name.

        Args:
            server_name: Name of the server

        Returns:
            The MCPClient registered for ``server_name``

        Raises:
            KeyError: If there is no active client for ``server_name``
        """
        return self.active_clients[server_name]

    def list_active_servers(self) -> list[str]:
        """List names of all active (connected) servers.

        Returns:
            List of active server names
        """
        return list(self.active_clients)

    def is_connected(self, server_name: str) -> bool:
        """Check if a server is currently connected.

        Args:
            server_name: Name of the server

        Returns:
            True if connected, False otherwise
        """
        client = self.active_clients.get(server_name)
        return client is not None and client.is_connected()

    async def disconnect_all(self) -> None:
        """Disconnect from all active MCP servers."""
        # Snapshot the names: _disconnect_server mutates active_clients.
        for server_name in list(self.active_clients):
            await self._disconnect_server(server_name)
        logger.info("Disconnected from all MCP servers")

    async def reconnect_server(self, server_name: str) -> bool:
        """Reconnect to a server (disconnect if connected, then connect).

        Args:
            server_name: Name of the server to reconnect

        Returns:
            True if successfully reconnected, False otherwise (unknown
            server, disabled server, or connection failure)
        """
        server = self.server_list.get_server(server_name)
        if server is None:
            logger.error(f"Server not found: {server_name}")
            return False

        if not server.enabled:
            logger.error(f"Cannot reconnect disabled server: {server_name}")
            return False

        try:
            # Disconnect if active
            if server_name in self.active_clients:
                await self._disconnect_server(server_name)

            # Reconnect
            await self._connect_server(server)
            logger.info(f"Reconnected to MCP server: {server_name}")
            return True

        except Exception as e:
            logger.error(f"Failed to reconnect to server {server_name}: {e}")
            return False

    async def __aenter__(self):
        """Enter async context manager."""
        await self.initialize_from_config()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager."""
        await self.disconnect_all()
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""MCP Proxy module for connecting to and interacting with MCP servers."""
|
|
2
|
+
|
|
3
|
+
from .client import MCPClient
|
|
4
|
+
from .transport import HTTPTransport, StdioTransport, Transport
|
|
5
|
+
from .auth_provider import create_oauth_provider
|
|
6
|
+
|
|
7
|
+
# Public names re-exported by this package.
__all__ = ["MCPClient", "StdioTransport", "HTTPTransport", "Transport", "create_oauth_provider"]
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
import time
|
|
3
|
+
import webbrowser
|
|
4
|
+
import os
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from http.server import BaseHTTPRequestHandler, HTTPServer
|
|
9
|
+
from urllib.parse import parse_qs, urlparse
|
|
10
|
+
|
|
11
|
+
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
|
12
|
+
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Set up logging
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
# Directory under which per-server token files are stored. It is a relative
# path, so it resolves against the process's current working directory.
TOKEN_DIRECTORY = '.tokens'
|
|
19
|
+
|
|
20
|
+
class LocalTokenStorage(TokenStorage):
    """File-backed OAuth token storage with an in-memory cache.

    Tokens are persisted as JSON at
    ``<TOKEN_DIRECTORY>/<server_name>/tokens.json`` and cached on the
    instance after the first successful load. Client registration info is
    kept in memory only.
    """

    def __init__(self, server_name: str = "default"):
        """Create storage for one server.

        Args:
            server_name: Name used as the sub-directory for the token file.
        """
        self.server_name: str = server_name
        self._tokens: OAuthToken | None = None
        self._client_info: OAuthClientInformationFull | None = None
        self.token_lock = threading.Lock()
        # Not used by this class itself; kept because it is a public attribute.
        self.info_lock = threading.Lock()
        self.TOKEN_PATH = os.path.join(TOKEN_DIRECTORY, self.server_name, "tokens.json")

    async def get_tokens(self) -> OAuthToken | None:
        """Return the cached tokens, loading them from disk on first use.

        Returns:
            The stored tokens, or None when no token file exists or it
            cannot be read or validated (the failure is logged).
        """
        # Check and load under the same lock so a concurrent set_tokens()
        # cannot interleave between the check and the read.
        with self.token_lock:
            if self._tokens is None and os.path.exists(self.TOKEN_PATH):
                try:
                    with open(self.TOKEN_PATH, "r", encoding="utf-8") as f:
                        self._tokens = OAuthToken.model_validate(json.load(f))
                except Exception as e:
                    # An unreadable or corrupt token file is treated as "no tokens".
                    logger.warning("Error loading tokens from %s: %s", self.TOKEN_PATH, e)
        return self._tokens

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Cache the tokens and persist them to the token file.

        Args:
            tokens: Tokens to store.
        """
        # Serialize before touching the file so a serialization failure
        # cannot leave a truncated token file behind.
        payload = tokens.model_dump(exclude_none=True, mode='json')
        with self.token_lock:
            self._tokens = tokens
            os.makedirs(os.path.dirname(self.TOKEN_PATH), exist_ok=True)
            # Tokens are credentials: create the file owner-read/write only.
            fd = os.open(self.TOKEN_PATH, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(payload, f)

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the client registration info held in memory, if any."""
        return self._client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store the client registration info in memory (not persisted)."""
        self._client_info = client_info
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class CallbackHandler(BaseHTTPRequestHandler):
    """Simple HTTP handler to capture OAuth callback.

    The result of the redirect is written into the ``callback_data`` dict
    supplied at construction, under the keys ``authorization_code``,
    ``state`` and ``error``.
    """

    def __init__(self, request, client_address, server, callback_data):
        """Initialize with callback data storage."""
        # Must be assigned before super().__init__(): the base class handles
        # the request from inside its constructor.
        self.callback_data = callback_data
        super().__init__(request, client_address, server)

    def _send_html(self, status, body):
        """Send an HTML response with the given status code and body bytes."""
        self.send_response(status)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle GET request from OAuth redirect.

        Responds 200 when a ``code`` parameter is present, 400 when an
        ``error`` parameter is present, and 404 otherwise.
        """
        import html  # local import: only used to escape the error page

        query_params = parse_qs(urlparse(self.path).query)

        if "code" in query_params:
            # Publish state before the code: another thread polls for the
            # code and reads state right after seeing it.
            self.callback_data["state"] = query_params.get("state", [None])[0]
            self.callback_data["authorization_code"] = query_params["code"][0]
            self._send_html(200, b"""
                <html>
                <body>
                    <h1>Authorization Successful!</h1>
                    <p>You can close this window and return to the terminal.</p>
                    <script>setTimeout(() => window.close(), 2000);</script>
                </body>
                </html>
            """)
        elif "error" in query_params:
            error = query_params["error"][0]
            self.callback_data["error"] = error
            # The error text comes straight from the request URL; escape it
            # so it cannot inject markup into the page.
            self._send_html(
                400,
                f"""
                <html>
                <body>
                    <h1>Authorization Failed</h1>
                    <p>Error: {html.escape(error)}</p>
                    <p>You can close this window and return to the terminal.</p>
                </body>
                </html>
            """.encode(),
            )
        else:
            self.send_response(404)
            self.end_headers()

    def log_message(self, format, *args):
        """Suppress default logging."""
        pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class CallbackServer:
    """Simple server to handle OAuth callbacks.

    Runs an ``HTTPServer`` on ``localhost`` in a daemon thread; the request
    handler records the redirect's result in ``callback_data``.
    """

    def __init__(self, port=3000):
        """Prepare (but do not start) a callback server.

        Args:
            port: TCP port to listen on once ``start()`` is called.
        """
        self.port = port
        self.server = None
        self.thread = None
        self.callback_data = {"authorization_code": None, "state": None, "error": None}

    def _create_handler_with_data(self):
        """Create a handler class with access to callback data."""
        callback_data = self.callback_data

        # HTTPServer instantiates the handler with a fixed three-argument
        # signature, so the shared dict is bound through a subclass.
        class DataCallbackHandler(CallbackHandler):
            def __init__(self, request, client_address, server):
                super().__init__(request, client_address, server, callback_data)

        return DataCallbackHandler

    def start(self):
        """Start the callback server in a background thread."""
        handler_class = self._create_handler_with_data()
        self.server = HTTPServer(("localhost", self.port), handler_class)
        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.thread.start()
        logger.info(f"🖥️  Started callback server on http://localhost:{self.port}")

    def stop(self):
        """Stop the callback server.

        Safe to call when the server was never started, and safe to call
        more than once.
        """
        server, thread = self.server, self.thread
        self.server = None
        self.thread = None
        if server:
            server.shutdown()
            server.server_close()
        if thread:
            thread.join(timeout=1)

    def wait_for_callback(self, timeout=300):
        """Wait for OAuth callback with timeout.

        Blocks the calling thread, polling every 0.1 s.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            The authorization code received by the handler.

        Raises:
            Exception: If the redirect carried an OAuth error.
            TimeoutError: If nothing arrived within ``timeout`` seconds.
        """
        # Monotonic clock: wall-clock adjustments must not stretch or cut
        # the wait.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if self.callback_data["authorization_code"]:
                return self.callback_data["authorization_code"]
            if self.callback_data["error"]:
                raise Exception(f"OAuth error: {self.callback_data['error']}")
            time.sleep(0.1)
        raise TimeoutError("Timeout waiting for OAuth callback")

    def get_state(self):
        """Get the received state parameter."""
        return self.callback_data["state"]
|
|
159
|
+
|
|
160
|
+
def create_oauth_provider(server_name: str, url: str) -> OAuthClientProvider:
    """Create OAuth authentication provider.

    The provider opens the authorization URL in the user's browser and
    receives the redirect on a temporary local HTTP server.

    Args:
        server_name: Name of the server; selects the token storage location.
        url: MCP endpoint URL. A trailing ``/mcp`` or ``/sse`` is removed
            to obtain the URL handed to the provider as ``server_url``.

    Returns:
        A configured ``OAuthClientProvider``.
    """
    import asyncio  # local import: used only by the callback wait below

    callback_port = 3030  # must match the port in redirect_uris
    callback_timeout = 300  # seconds to wait for the user to authorize
    client_metadata_dict = {
        "client_name": "KLAVIS Strata MCP Router",
        "redirect_uris": [f"http://localhost:{callback_port}/callback"],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "client_secret_post",
    }

    async def callback_handler() -> tuple[str, str | None]:
        """Wait for OAuth callback and return auth code and state."""
        callback_server = CallbackServer(port=callback_port)
        callback_server.start()
        logger.info("⏳ Waiting for authorization callback...")
        try:
            # Poll with asyncio.sleep instead of the blocking
            # wait_for_callback(): a time.sleep loop here would freeze the
            # event loop, and every other connection on it, for up to the
            # full timeout.
            result = callback_server.callback_data
            deadline = time.monotonic() + callback_timeout
            while time.monotonic() < deadline:
                if result["authorization_code"]:
                    return result["authorization_code"], callback_server.get_state()
                if result["error"]:
                    raise Exception(f"OAuth error: {result['error']}")
                await asyncio.sleep(0.1)
            raise Exception("Timeout waiting for OAuth callback")
        finally:
            callback_server.stop()

    async def _default_redirect_handler(authorization_url: str) -> None:
        """Default redirect handler that opens the URL in a browser."""
        logger.info(f"Opening browser for authorization: {authorization_url}")
        webbrowser.open(authorization_url)

    # Strip the transport endpoint suffix ("/mcp" and "/sse" are both
    # four characters long).
    auth_url = url[:-4] if url.endswith(("/mcp", "/sse")) else url

    return OAuthClientProvider(
        server_url=auth_url,
        client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
        storage=LocalTokenStorage(server_name),
        redirect_handler=_default_redirect_handler,
        callback_handler=callback_handler,
    )
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""MCP Client for connecting to and interacting with MCP servers."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from mcp import types
|
|
7
|
+
|
|
8
|
+
from .transport import Transport
|
|
9
|
+
|
|
10
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MCPClient:
    """Client for connecting to MCP servers using various transports.

    Usage:
        # With stdio transport
        transport = StdioTransport("docker", ["run", "-i", "my-server"])
        client = MCPClient(transport)
        await client.connect()

        # With HTTP/SSE transport
        transport = HTTPTransport("http://localhost:8080", mode="sse")
        client = MCPClient(transport)
        await client.connect()
    """

    def __init__(self, transport: Transport):
        """Initialize the MCP client with a transport.

        Args:
            transport: Transport instance (StdioTransport or HTTPTransport)
        """
        self.transport = transport
        # Tool list from the last successful list_tools(); None until fetched.
        self._tools_cache: Optional[List[Dict[str, Any]]] = None

    async def initialize(self) -> None:
        """Initialize the MCP client by delegating to the transport's ``initialize()``."""
        await self.transport.initialize()

    async def connect(self) -> None:
        """Connect to the MCP server."""
        await self.transport.connect()
        logger.info(
            f"Connected to MCP server using {self.transport.__class__.__name__}"
        )

    async def disconnect(self) -> None:
        """Disconnect from the MCP server and drop the cached tool list."""
        try:
            await self.transport.disconnect()
        finally:
            # Clear the cache even if the transport fails to shut down
            # cleanly, so a later connection never serves stale tools.
            self._tools_cache = None
        logger.info("Disconnected from MCP server")

    def is_connected(self) -> bool:
        """Check if connected to an MCP server."""
        return self.transport.is_connected()

    async def list_tools(self, use_cache: bool = True) -> List[Dict[str, Any]]:
        """List available tools from the MCP server.

        Args:
            use_cache: Whether to use cached tools if available

        Returns:
            List of tool definitions with name, description, and inputSchema;
            ``title`` and ``outputSchema`` are included only when set.

        Raises:
            RuntimeError: If the transport is not connected
        """
        if not self.transport.is_connected():
            raise RuntimeError("Not connected to any MCP server")

        if use_cache and self._tools_cache is not None:
            return self._tools_cache

        # Get tools from server
        session = self.transport.get_session()
        response = await session.list_tools()

        # Convert to dict format
        tools = []
        for tool in response.tools:
            tool_dict = {
                "name": tool.name,
                "description": tool.description,
                "inputSchema": tool.inputSchema,
            }
            # Add optional fields only if they have values
            for optional_field in ("title", "outputSchema"):
                value = getattr(tool, optional_field, None)
                if value:
                    tool_dict[optional_field] = value
            tools.append(tool_dict)

        self._tools_cache = tools
        logger.info(f"Retrieved {len(tools)} tools from MCP server")

        return tools

    async def call_tool(
        self, tool_name: str, arguments: Dict[str, Any]
    ) -> List[types.ContentBlock]:
        """Call a tool on the MCP server.

        Args:
            tool_name: Name of the tool to call
            arguments: Arguments to pass to the tool

        Returns:
            Tool execution result from MCP server

        Raises:
            RuntimeError: If the transport is not connected, or if the
                server reports the call as an error (the message carries
                the text of the returned content blocks)
        """
        if not self.transport.is_connected():
            raise RuntimeError("Not connected to any MCP server")

        # NOTE(review): arguments are logged verbatim at INFO level; confirm
        # they can never carry secrets.
        logger.info(f"Calling tool '{tool_name}' with arguments: {arguments}")

        # Call the tool and return result directly
        session = self.transport.get_session()
        result = await session.call_tool(tool_name, arguments)
        if not result.isError:
            return result.content

        # Extract error text from content blocks
        blocks = result.content or []
        logger.debug(f"Error result content: {result.content}")
        logger.debug(f"Error result content type: {type(result.content)}")
        logger.debug(f"Error result content length: {len(blocks)}")

        error_messages = []
        for i, content_block in enumerate(blocks):
            logger.debug(f"Content block {i}: {content_block}, type: {type(content_block)}")
            text = getattr(content_block, "text", None)
            if text is not None:
                logger.debug(f"Content block {i} has text: {text}")
                # str(): joining below must not fail on non-string text
                error_messages.append(str(text))
            else:
                logger.debug(f"Content block {i} attributes: {dir(content_block)}")

        if error_messages:
            error_text = "\n".join(error_messages)
        else:
            error_text = f"Unknown error (no error text in {len(blocks)} content blocks)"

        logger.error(f"Tool '{tool_name}' returned error: {error_text}")
        raise RuntimeError(f"Tool '{tool_name}' error: {error_text}")

    async def get_tool_schema(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Get the schema for a specific tool.

        Args:
            tool_name: Name of the tool

        Returns:
            Tool schema or None if not found
        """
        tools = await self.list_tools()
        return next((tool for tool in tools if tool["name"] == tool_name), None)

    async def __aenter__(self):
        """Enter async context manager."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit async context manager."""
        await self.disconnect()
|