ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/mcp_client.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MCP (Model Context Protocol) client — full-spec implementation.
|
|
3
|
+
|
|
4
|
+
Supports MCP servers over:
|
|
5
|
+
- stdio (subprocess): spawns the server as a child process
|
|
6
|
+
- HTTP/SSE: connects to a remote MCP server
|
|
7
|
+
|
|
8
|
+
Implements: capability negotiation, tools, resources, prompts, ping,
|
|
9
|
+
progress notifications, cancellation, resource templates, logging,
|
|
10
|
+
completion, roots.
|
|
11
|
+
|
|
12
|
+
Spec: https://spec.modelcontextprotocol.io/
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
import time
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any, Callable
|
|
23
|
+
|
|
24
|
+
import httpx
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
30
|
+
# JSON-RPC 2.0 standard error codes
|
|
31
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
32
|
+
|
|
33
|
+
class JsonRpcError(Exception):
    """Exception that carries the fields of a JSON-RPC error object.

    Attributes:
        code: Numeric JSON-RPC error code.
        message: Human-readable description; also used as the exception text.
        data: Optional extra payload attached to the error (``None`` if absent).
    """

    def __init__(self, code: int, message: str, data: Any = None):
        # Hand the message to Exception so that str(err) returns it.
        super().__init__(message)
        self.code, self.message, self.data = code, message, data
|
|
40
|
+
|
|
41
|
+
# Error codes predefined by JSON-RPC 2.0.
PARSE_ERROR = -32700       # invalid JSON received
INVALID_REQUEST = -32600   # payload is not a valid request object
METHOD_NOT_FOUND = -32601  # method does not exist / is not available
INVALID_PARAMS = -32602    # invalid method parameters
INTERNAL_ERROR = -32603    # internal JSON-RPC error

# Codes used by this MCP client in addition to the predefined ones.
# SERVER_NOT_INITIALIZED falls inside the implementation-defined server-error
# range (-32000 to -32099); REQUEST_CANCELLED (-32800) lies outside that range.
SERVER_NOT_INITIALIZED = -32002
REQUEST_CANCELLED = -32800
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
53
|
+
# JSON-RPC types
|
|
54
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
55
|
+
|
|
56
|
+
JsonRpcId = str | int
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class JsonRpcRequest:
    """Plain record holding the members of a JSON-RPC request message."""

    # Protocol version tag.
    jsonrpc: str = "2.0"
    # Name of the method to invoke.
    method: str = ""
    # Method parameters; the factory gives every instance its own dict.
    params: dict[str, Any] = field(default_factory=dict)
    # Request identifier (defaults to the empty string).
    id: JsonRpcId = ""
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@dataclass
class JsonRpcResponse:
    """Plain record holding the members of a JSON-RPC response message."""

    # Protocol version tag.
    jsonrpc: str = "2.0"
    # Result payload; stays None when the response carries none.
    result: Any = None
    # Error object (a dict), or None when no error was reported.
    error: dict[str, Any] | None = None
    # Identifier of the request this response belongs to.
    id: JsonRpcId = ""
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
76
|
+
# MCP Server connection — base class
|
|
77
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
78
|
+
|
|
79
|
+
class MCPServerConnection:
|
|
80
|
+
"""
|
|
81
|
+
A connection to a single MCP server.
|
|
82
|
+
|
|
83
|
+
Handles JSON-RPC communication, capability negotiation,
|
|
84
|
+
tool/resource/prompt discovery, ping, progress, and cancellation.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
PROTOCOL_VERSION = "2025-03-26"
|
|
88
|
+
|
|
89
|
+
def __init__(self, name: str):
|
|
90
|
+
self.name = name
|
|
91
|
+
self._tools: list[dict[str, Any]] = []
|
|
92
|
+
self._resources: list[dict[str, Any]] = []
|
|
93
|
+
self._resource_templates: list[dict[str, Any]] = []
|
|
94
|
+
self._prompts: list[dict[str, Any]] = []
|
|
95
|
+
self._initialized: bool = False
|
|
96
|
+
self._server_info: dict[str, Any] = {}
|
|
97
|
+
self._capabilities: dict[str, Any] = {}
|
|
98
|
+
self._server_capabilities: dict[str, Any] = {}
|
|
99
|
+
self._roots: list[dict[str, Any]] = []
|
|
100
|
+
self._pong_received: bool = False
|
|
101
|
+
|
|
102
|
+
@property
|
|
103
|
+
def tools(self) -> list[dict[str, Any]]:
|
|
104
|
+
return self._tools
|
|
105
|
+
|
|
106
|
+
@property
|
|
107
|
+
def resources(self) -> list[dict[str, Any]]:
|
|
108
|
+
return self._resources
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
def resource_templates(self) -> list[dict[str, Any]]:
|
|
112
|
+
return self._resource_templates
|
|
113
|
+
|
|
114
|
+
@property
|
|
115
|
+
def prompts(self) -> list[dict[str, Any]]:
|
|
116
|
+
return self._prompts
|
|
117
|
+
|
|
118
|
+
@property
|
|
119
|
+
def initialized(self) -> bool:
|
|
120
|
+
return self._initialized
|
|
121
|
+
|
|
122
|
+
@property
|
|
123
|
+
def server_info(self) -> dict[str, Any]:
|
|
124
|
+
return self._server_info
|
|
125
|
+
|
|
126
|
+
@property
|
|
127
|
+
def capabilities(self) -> dict[str, Any]:
|
|
128
|
+
return self._capabilities
|
|
129
|
+
|
|
130
|
+
@property
|
|
131
|
+
def server_capabilities(self) -> dict[str, Any]:
|
|
132
|
+
return self._server_capabilities
|
|
133
|
+
|
|
134
|
+
# ── Abstract methods ──
|
|
135
|
+
|
|
136
|
+
async def start(self) -> None:
|
|
137
|
+
raise NotImplementedError
|
|
138
|
+
|
|
139
|
+
async def stop(self) -> None:
|
|
140
|
+
raise NotImplementedError
|
|
141
|
+
|
|
142
|
+
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
|
143
|
+
raise NotImplementedError
|
|
144
|
+
|
|
145
|
+
async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
|
|
146
|
+
raise NotImplementedError
|
|
147
|
+
|
|
148
|
+
# ── Capability checks ──
|
|
149
|
+
|
|
150
|
+
def has_capability(self, cap: str) -> bool:
|
|
151
|
+
"""Check if the server supports a given capability namespace."""
|
|
152
|
+
return cap in self._server_capabilities
|
|
153
|
+
|
|
154
|
+
def has_subcapability(self, cap: str, sub: str) -> bool:
|
|
155
|
+
"""Check if the server supports a sub-capability (e.g. tools→listChanged)."""
|
|
156
|
+
caps = self._server_capabilities.get(cap, {})
|
|
157
|
+
return isinstance(caps, dict) and sub in caps
|
|
158
|
+
|
|
159
|
+
# ── Lifecycle ──
|
|
160
|
+
|
|
161
|
+
async def initialize(self, client_capabilities: dict[str, Any] | None = None) -> None:
|
|
162
|
+
"""Send initialize request and negotiate capabilities."""
|
|
163
|
+
caps = {
|
|
164
|
+
"tools": {},
|
|
165
|
+
"resources": {"subscribe": True},
|
|
166
|
+
"prompts": {},
|
|
167
|
+
"logging": {},
|
|
168
|
+
}
|
|
169
|
+
if client_capabilities:
|
|
170
|
+
caps.update(client_capabilities)
|
|
171
|
+
|
|
172
|
+
init_result = await self.send_request("initialize", {
|
|
173
|
+
"protocolVersion": self.PROTOCOL_VERSION,
|
|
174
|
+
"capabilities": caps,
|
|
175
|
+
"clientInfo": {
|
|
176
|
+
"name": "ata-coder",
|
|
177
|
+
"version": "2.3.0",
|
|
178
|
+
},
|
|
179
|
+
})
|
|
180
|
+
self._server_info = init_result.get("serverInfo", {})
|
|
181
|
+
self._server_capabilities = init_result.get("capabilities", {})
|
|
182
|
+
self._initialized = True
|
|
183
|
+
|
|
184
|
+
# Send initialized notification
|
|
185
|
+
await self.send_notification("notifications/initialized", {})
|
|
186
|
+
|
|
187
|
+
logger.info(
|
|
188
|
+
"[%s] Initialized: %s v%s (caps: %s)",
|
|
189
|
+
self.name,
|
|
190
|
+
self._server_info.get("name", "unknown"),
|
|
191
|
+
self._server_info.get("version", "?"),
|
|
192
|
+
", ".join(self._server_capabilities) or "none",
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
async def discover(self) -> None:
|
|
196
|
+
"""Discover tools, resources, resource templates, and prompts."""
|
|
197
|
+
if not self._initialized:
|
|
198
|
+
raise JsonRpcError(SERVER_NOT_INITIALIZED, "Server not initialized")
|
|
199
|
+
|
|
200
|
+
# Tools
|
|
201
|
+
if self.has_capability("tools"):
|
|
202
|
+
result = await self.send_request("tools/list", {})
|
|
203
|
+
self._tools = result.get("tools", [])
|
|
204
|
+
logger.info("[%s] Discovered %d tools", self.name, len(self._tools))
|
|
205
|
+
|
|
206
|
+
# Resources
|
|
207
|
+
if self.has_capability("resources"):
|
|
208
|
+
try:
|
|
209
|
+
result = await self.send_request("resources/list", {})
|
|
210
|
+
self._resources = result.get("resources", [])
|
|
211
|
+
logger.info("[%s] Discovered %d resources", self.name, len(self._resources))
|
|
212
|
+
except Exception:
|
|
213
|
+
pass
|
|
214
|
+
|
|
215
|
+
# Resource templates
|
|
216
|
+
if self.has_capability("resources"):
|
|
217
|
+
try:
|
|
218
|
+
result = await self.send_request("resources/templates/list", {})
|
|
219
|
+
self._resource_templates = result.get("resourceTemplates", [])
|
|
220
|
+
logger.info("[%s] Discovered %d resource templates", self.name, len(self._resource_templates))
|
|
221
|
+
except Exception:
|
|
222
|
+
pass
|
|
223
|
+
|
|
224
|
+
# Prompts
|
|
225
|
+
if self.has_capability("prompts"):
|
|
226
|
+
try:
|
|
227
|
+
result = await self.send_request("prompts/list", {})
|
|
228
|
+
self._prompts = result.get("prompts", [])
|
|
229
|
+
logger.info("[%s] Discovered %d prompts", self.name, len(self._prompts))
|
|
230
|
+
except Exception:
|
|
231
|
+
pass
|
|
232
|
+
|
|
233
|
+
# ── Tool calling ──
|
|
234
|
+
|
|
235
|
+
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
|
236
|
+
"""Call a tool on this MCP server."""
|
|
237
|
+
if not self.has_capability("tools"):
|
|
238
|
+
raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support tools")
|
|
239
|
+
return await self.send_request("tools/call", {
|
|
240
|
+
"name": tool_name,
|
|
241
|
+
"arguments": arguments,
|
|
242
|
+
})
|
|
243
|
+
|
|
244
|
+
# ── Resource reading ──
|
|
245
|
+
|
|
246
|
+
async def read_resource(self, uri: str) -> Any:
|
|
247
|
+
"""Read a resource by URI."""
|
|
248
|
+
if not self.has_capability("resources"):
|
|
249
|
+
raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support resources")
|
|
250
|
+
return await self.send_request("resources/read", {"uri": uri})
|
|
251
|
+
|
|
252
|
+
async def subscribe_resource(self, uri: str) -> None:
|
|
253
|
+
"""Subscribe to resource updates."""
|
|
254
|
+
if not self.has_subcapability("resources", "subscribe"):
|
|
255
|
+
return
|
|
256
|
+
await self.send_notification("resources/subscribe", {"uri": uri})
|
|
257
|
+
logger.info("[%s] Subscribed to: %s", self.name, uri)
|
|
258
|
+
|
|
259
|
+
# ── Prompts ──
|
|
260
|
+
|
|
261
|
+
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> Any:
|
|
262
|
+
"""Get a prompt by name with optional arguments."""
|
|
263
|
+
if not self.has_capability("prompts"):
|
|
264
|
+
raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support prompts")
|
|
265
|
+
params: dict[str, Any] = {"name": name}
|
|
266
|
+
if arguments:
|
|
267
|
+
params["arguments"] = arguments
|
|
268
|
+
return await self.send_request("prompts/get", params)
|
|
269
|
+
|
|
270
|
+
# ── Ping ──
|
|
271
|
+
|
|
272
|
+
async def ping(self, timeout: float = 10.0) -> bool:
|
|
273
|
+
"""Ping the server. Returns True if alive."""
|
|
274
|
+
try:
|
|
275
|
+
await asyncio.wait_for(self.send_request("ping", {}), timeout=timeout)
|
|
276
|
+
return True
|
|
277
|
+
except Exception:
|
|
278
|
+
return False
|
|
279
|
+
|
|
280
|
+
# ── Completion ──
|
|
281
|
+
|
|
282
|
+
async def complete(self, ref: dict[str, Any], argument: dict[str, Any]) -> Any:
|
|
283
|
+
"""Request auto-completion for a prompt or resource template argument."""
|
|
284
|
+
return await self.send_request("completion/complete", {
|
|
285
|
+
"ref": ref,
|
|
286
|
+
"argument": argument,
|
|
287
|
+
})
|
|
288
|
+
|
|
289
|
+
# ── Roots ──
|
|
290
|
+
|
|
291
|
+
async def set_roots(self, roots: list[dict[str, Any]]) -> None:
|
|
292
|
+
"""Inform the server about root directories."""
|
|
293
|
+
self._roots = roots
|
|
294
|
+
await self.send_notification("notifications/roots/list_changed", {"roots": roots})
|
|
295
|
+
logger.info("[%s] Updated roots: %d", self.name, len(roots))
|
|
296
|
+
|
|
297
|
+
# ── Logging ──
|
|
298
|
+
|
|
299
|
+
async def set_log_level(self, level: str) -> None:
|
|
300
|
+
"""Set the log level on the server (debug/info/notice/warning/error/critical)."""
|
|
301
|
+
await self.send_notification("logging/setLevel", {"level": level})
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
305
|
+
# Stdio connection
|
|
306
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
307
|
+
|
|
308
|
+
class StdioMCPConnection(MCPServerConnection):
    """MCP connection over stdio (subprocess).

    The server is spawned as a child process.  Messages are newline-delimited
    JSON written to its stdin and read from its stdout by a background reader
    task.  Responses are matched to waiting callers through ``_pending``
    (request id → event) and ``_results`` (request id → response).
    """

    # Request-id counter; class-level, so ids are unique across all instances.
    _next_req_id = 0

    # Buffer limit for the child's stdout reader.  asyncio's default is
    # 64 KiB per line, and readline() raises once a single message exceeds it.
    _STREAM_LIMIT = 16 * 1024 * 1024

    # Log-level names a server may send, mapped to stdlib logging levels.
    # A lookup table is used so server-supplied text never selects an
    # arbitrary attribute of the logger.
    _LOG_LEVELS = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "notice": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
        "alert": logging.CRITICAL,
        "emergency": logging.CRITICAL,
    }

    def __init__(self, name: str, command: str, args: list[str] | None = None,
                 env: dict[str, str] | None = None, cwd: str | None = None):
        """
        Args:
            name: Label used in log messages.
            command: Executable to spawn.
            args: Command-line arguments for ``command``.
            env: Environment passed to the child process as given.
            cwd: Working directory passed to the child process as given.
        """
        super().__init__(name)
        self.command = command
        self.args = args or []
        self.env = env
        self.cwd = cwd
        self._process: asyncio.subprocess.Process | None = None
        self._pending: dict[JsonRpcId, asyncio.Event] = {}
        self._results: dict[JsonRpcId, JsonRpcResponse] = {}
        self._reader_task: asyncio.Task | None = None
        self._running = False
        self._on_progress: Callable[[int, int, str | None], None] | None = None
        # Strong references to fire-and-forget tasks (re-discovery), so they
        # are not garbage-collected mid-flight and can be cancelled on stop().
        self._background: set[asyncio.Task] = set()

    @classmethod
    def _next_id(cls) -> str:
        """Return a fresh request id (decimal string of the class counter)."""
        cls._next_req_id += 1
        return str(cls._next_req_id)

    def on_progress(self, callback: Callable[[int, int, str | None], None]) -> None:
        """Register a callback for progress notifications.

        The callback is invoked as ``callback(progress, total, progress_token)``.
        """
        self._on_progress = callback

    # ── Start / Stop ──

    async def start(self) -> None:
        """Start the MCP server process, then initialize and discover.

        Raises:
            RuntimeError: If the process cannot be spawned.
            Exception: Anything raised by ``initialize``/``discover``; the
                connection is stopped before the error is re-raised.
        """
        logger.info("[%s] Starting: %s %s", self.name, self.command, " ".join(self.args))

        try:
            self._process = await asyncio.create_subprocess_exec(
                self.command, *self.args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL,
                env=self.env,
                cwd=self.cwd,
                limit=self._STREAM_LIMIT,
            )
        except FileNotFoundError as e:
            raise RuntimeError(
                f"MCP server command not found: {self.command}. "
                f"Install it or check the path."
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to start MCP server: {e}") from e

        self._running = True
        self._reader_task = asyncio.create_task(self._read_loop())

        try:
            await self.initialize()
            await self.discover()
        except Exception:
            await self.stop()
            raise

    async def stop(self) -> None:
        """Stop the MCP server process and release everything waiting on it."""
        self._running = False

        # Background re-discovery tasks would otherwise wait on a dead pipe.
        for task in list(self._background):
            task.cancel()
        self._background.clear()

        # Cancel and await reader task FIRST — it holds the stdout pipe open.
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
            try:
                await self._reader_task
            except (asyncio.CancelledError, Exception):
                pass
        self._reader_task = None

        # Terminate/kill the process
        proc = self._process
        self._process = None
        if proc is not None:
            try:
                proc.terminate()
                try:
                    await asyncio.wait_for(proc.wait(), timeout=5)
                except asyncio.TimeoutError:
                    # Did not exit after terminate(); escalate to kill().
                    try:
                        proc.kill()
                        await asyncio.wait_for(proc.wait(), timeout=3)
                    except Exception:
                        pass
            except Exception:
                # terminate() itself failed (e.g. process already gone).
                try:
                    proc.kill()
                except Exception:
                    pass
            # Explicitly close pipes to prevent "I/O operation on closed pipe"
            # during BaseSubprocessTransport.__del__ at GC time.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                close = getattr(pipe, "close", None)
                if close is None:
                    continue  # pipe absent, or a reader object without close()
                try:
                    close()
                except Exception as exc:
                    logger.debug("[%s] Pipe close failed: %s", self.name, exc)

        # Release all pending requests
        self._release_pending()
        self._pending.clear()
        self._results.clear()

        logger.info("[%s] Stopped", self.name)

    def _release_pending(self) -> None:
        """Wake every caller blocked in :meth:`send_request`."""
        for evt in self._pending.values():
            evt.set()

    # ── Message I/O ──

    async def _send_raw(self, msg: dict[str, Any]) -> None:
        """Send a raw JSON-RPC message to the server as one UTF-8 line.

        Raises:
            RuntimeError: If the process is not running or the write fails.
        """
        if not self._process or not self._process.stdin:
            raise RuntimeError("MCP server not running")
        line = json.dumps(msg, ensure_ascii=False) + "\n"
        try:
            self._process.stdin.write(line.encode("utf-8"))
            await self._process.stdin.drain()
        except Exception as e:
            raise RuntimeError(f"Failed to send to MCP server: {e}") from e

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request and wait for the response.

        Waits up to 120 s for ``initialize`` and 60 s for every other method.

        Returns:
            The ``result`` member of the response.

        Raises:
            JsonRpcError: On timeout, when no response was delivered (e.g. the
                connection stopped), or when the server returned an error.
            RuntimeError: If the message could not be written.
        """
        req_id = self._next_id()
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
            "id": req_id,
        }

        event = asyncio.Event()
        self._pending[req_id] = event
        timeout = 120 if method == "initialize" else 60

        try:
            await self._send_raw(msg)
            await asyncio.wait_for(event.wait(), timeout=timeout)
            response = self._results.get(req_id)
        except asyncio.TimeoutError:
            raise JsonRpcError(INTERNAL_ERROR, f"MCP request timeout: {method}") from None
        finally:
            # Always drop the bookkeeping — on success, timeout, send failure
            # and caller cancellation alike — so the maps cannot grow forever.
            self._pending.pop(req_id, None)
            self._results.pop(req_id, None)

        if response is None:
            raise JsonRpcError(INTERNAL_ERROR, f"No response for request: {method}")

        if response.error:
            raise JsonRpcError(
                response.error.get("code", INTERNAL_ERROR),
                response.error.get("message", "unknown"),
                response.error.get("data"),
            )

        return response.result

    async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
        """Send a JSON-RPC notification (no response expected)."""
        await self._send_raw({
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        })

    # ── Cancellation ──

    async def cancel_request(self, req_id: JsonRpcId) -> None:
        """Cancel an in-flight request.

        Tells the server via ``notifications/cancelled`` and wakes the local
        waiter, which then raises ``JsonRpcError`` with ``REQUEST_CANCELLED``
        instead of blocking until its timeout.
        """
        await self.send_notification("notifications/cancelled", {
            "requestId": req_id,
            "reason": "User cancelled",
        })
        event = self._pending.pop(req_id, None)
        self._results.pop(req_id, None)
        if event is not None:
            # Hand the waiter a synthetic error response so it fails promptly.
            self._results[req_id] = JsonRpcResponse(
                error={"code": REQUEST_CANCELLED, "message": "Request cancelled"},
                id=req_id,
            )
            event.set()

    # ── Read loop ──

    async def _read_loop(self) -> None:
        """Background task: reads JSON-RPC messages from the server's stdout.

        Ends on EOF, on a read error, or on cancellation.  Whenever it ends,
        callers still waiting for a response are woken so they fail fast.
        """
        try:
            while self._running and self._process and self._process.stdout:
                line_bytes = await self._process.stdout.readline()
                if not line_bytes:
                    break  # EOF: the server closed its stdout
                line = line_bytes.decode("utf-8", errors="replace")

                try:
                    msg = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue  # not JSON (e.g. stray output on stdout); skip

                try:
                    await self._dispatch(msg)
                except Exception:
                    # One bad message or failing handler must not end the reader.
                    logger.exception("[%s] Error handling message", self.name)
        except asyncio.CancelledError:
            raise
        except Exception:
            if self._running:
                logger.exception("[%s] Read error", self.name)
        finally:
            self._release_pending()

    async def _dispatch(self, msg: Any) -> None:
        """Route one decoded message: response, server request, or notification."""
        if not isinstance(msg, dict):
            return  # only JSON objects are handled

        msg_id = msg.get("id")
        method = msg.get("method")

        # ── Response to our request ──
        if method is None:
            if msg_id is not None and msg_id in self._pending:
                self._results[msg_id] = JsonRpcResponse(
                    jsonrpc=msg.get("jsonrpc", "2.0"),
                    result=msg.get("result"),
                    error=msg.get("error"),
                    id=msg_id,
                )
                self._pending[msg_id].set()
            return

        if not isinstance(method, str) or not method:
            return

        # ── Server → client request ── (``is not None`` keeps id 0 valid)
        if msg_id is not None:
            await self._handle_server_request(msg)
            return

        # ── Notification from server ──
        params = msg.get("params")
        await self._handle_notification(method, params if isinstance(params, dict) else {})

    async def _handle_server_request(self, msg: dict[str, Any]) -> None:
        """Handle a request from the server (e.g. sampling/createMessage).

        ``ping`` is answered with an empty result and ``roots/list`` with the
        roots stored by ``set_roots``.  Everything else gets METHOD_NOT_FOUND;
        full sampling support would require an LLM callback from the agent.
        """
        method = msg.get("method", "")
        reply: dict[str, Any] = {"jsonrpc": "2.0", "id": msg.get("id")}

        if method == "ping":
            reply["result"] = {}
        elif method == "roots/list":
            reply["result"] = {"roots": self._roots}
        else:
            reply["error"] = {
                "code": METHOD_NOT_FOUND,
                "message": f"Method not supported by this client: {method}",
            }
        await self._send_raw(reply)

    def _schedule_rediscover(self) -> None:
        """Run :meth:`_rediscover` as a background task."""
        task = asyncio.create_task(self._rediscover())
        self._background.add(task)
        task.add_done_callback(self._background.discard)

    async def _rediscover(self) -> None:
        """Re-run discovery, logging (not raising) any failure."""
        try:
            await self.discover()
        except Exception as exc:
            logger.warning("[%s] Re-discovery failed: %s", self.name, exc)

    async def _handle_notification(self, method: str, params: dict[str, Any]) -> None:
        """Handle a notification from the server."""
        if method == "notifications/progress":
            # Progress token + progress + total
            progress_token = params.get("progressToken")
            progress = params.get("progress", 0)
            total = params.get("total", 0)
            if self._on_progress:
                self._on_progress(progress, total, progress_token)
            logger.debug("[%s] Progress: %s/%s", self.name, progress, total)

        elif method == "notifications/resources/updated":
            uri = params.get("uri", "?")
            logger.info("[%s] Resource updated: %s", self.name, uri)

        elif method in (
            "notifications/resources/list_changed",
            "notifications/tools/list_changed",
            "notifications/prompts/list_changed",
        ):
            label = {
                "notifications/resources/list_changed": "Resource",
                "notifications/tools/list_changed": "Tool",
                "notifications/prompts/list_changed": "Prompt",
            }[method]
            logger.info("[%s] %s list changed — re-discovering", self.name, label)
            # This coroutine runs inside the reader task.  Awaiting discover()
            # here would block the only task able to read the responses that
            # discover() is waiting for, so it is run in the background.
            self._schedule_rediscover()

        elif method == "notifications/message":
            # Server→client log message
            level = params.get("level", "info")
            data = params.get("data", "")
            levelno = self._LOG_LEVELS.get(level, logging.INFO) if isinstance(level, str) else logging.INFO
            logger.log(levelno, "[%s] %s", self.name, data)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
596
|
+
# HTTP / SSE connection
|
|
597
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
598
|
+
|
|
599
|
+
class HTTPMCPConnection(MCPServerConnection):
    """MCP connection over HTTP (Streamable HTTP transport).

    Every JSON-RPC message is POSTed to ``url`` with a blocking ``httpx``
    client that runs in a worker thread.  A reply is either a JSON body or a
    ``text/event-stream`` body whose ``data:`` lines carry JSON-RPC messages.
    """

    # Header the server uses to assign a session; echoed on later requests.
    SESSION_HEADER = "Mcp-Session-Id"

    def __init__(self, name: str, url: str, headers: dict[str, str] | None = None):
        """
        Args:
            name: Label used in log messages.
            url: Endpoint to POST to; trailing slashes are removed.
            headers: Extra HTTP headers; they override the defaults set in
                :meth:`start`.
        """
        super().__init__(name)
        self.url = url.rstrip("/")
        self._headers = headers or {}
        self._client: httpx.Client | None = None
        self._id_counter = 0
        self._session_id: str | None = None

    def _next_id(self) -> str:
        """Return a fresh request id (per-connection counter, as a string)."""
        self._id_counter += 1
        return str(self._id_counter)

    async def start(self) -> None:
        """Create the HTTP client, then initialize and discover.

        Raises:
            Exception: Anything raised by ``initialize``/``discover``; the
                client is closed before the error is re-raised.
        """
        self._client = httpx.Client(
            timeout=httpx.Timeout(120.0, connect=30.0),
            headers={
                "Content-Type": "application/json",
                # Streamable HTTP clients must accept both reply encodings.
                "Accept": "application/json, text/event-stream",
                **self._headers,
            },
        )
        logger.info("[%s] Connecting to %s", self.name, self.url)

        try:
            await self.initialize()
            await self.discover()
        except Exception:
            await self.stop()
            raise

    async def stop(self) -> None:
        """Close the HTTP client (if any) and forget the session id."""
        if self._client:
            self._client.close()
            self._client = None
        self._session_id = None
        logger.info("[%s] Disconnected", self.name)

    def _post(self, msg: dict[str, Any]) -> httpx.Response:
        """POST one message (blocking) and return the successful response.

        Raises:
            RuntimeError: If the client has not been started.
            httpx.HTTPStatusError: On a 4xx/5xx reply.
        """
        if not self._client:
            raise RuntimeError("MCP HTTP client not connected")
        extra = {self.SESSION_HEADER: self._session_id} if self._session_id else None
        response = self._client.post(self.url, json=msg, headers=extra)
        response.raise_for_status()
        # Remember a session id handed out by the server so that later
        # requests can present it.
        session_id = response.headers.get(self.SESSION_HEADER)
        if session_id:
            self._session_id = session_id
        return response

    @staticmethod
    def _to_error(error: Any) -> "JsonRpcError":
        """Build a JsonRpcError from the ``error`` member of a reply."""
        if not isinstance(error, dict):
            return JsonRpcError(INTERNAL_ERROR, str(error))
        return JsonRpcError(
            error.get("code", INTERNAL_ERROR),
            error.get("message", "unknown"),
            error.get("data"),
        )

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request and return its ``result``.

        Raises:
            JsonRpcError: If the reply carries an error or is not a JSON object.
            httpx.HTTPError: On transport or HTTP-status failures.
        """
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
            "id": self._next_id(),
        }
        response = await asyncio.to_thread(self._post, msg)

        # Handle SSE stream for streaming responses
        ct = response.headers.get("content-type", "")
        if "text/event-stream" in ct:
            return self._read_sse(response)

        data = response.json()
        if not isinstance(data, dict):
            raise JsonRpcError(INTERNAL_ERROR, f"Malformed response for request: {method}")
        if data.get("error"):
            raise self._to_error(data["error"])
        return data.get("result")

    async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
        """Send a JSON-RPC notification; the HTTP reply body is ignored."""
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        }
        await asyncio.to_thread(self._post, msg)

    @staticmethod
    def _read_sse(response: httpx.Response) -> Any:
        """Read SSE stream, collect the final result.

        Each ``data:`` line is parsed as one JSON-RPC message.  The last
        non-null ``result`` seen is returned (``None`` if there was none).

        Raises:
            JsonRpcError: As soon as a message with an ``error`` member is seen.
        """
        result = None
        for line in response.iter_lines():
            # The space after the colon is optional in SSE.
            if not line.startswith("data:"):
                continue
            try:
                data = json.loads(line[5:].strip())
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            if data.get("result") is not None:
                result = data["result"]
            if data.get("error"):
                raise HTTPMCPConnection._to_error(data["error"])
        return result
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
696
|
+
# MCP Client — manages multiple connections
|
|
697
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
698
|
+
|
|
699
|
+
@dataclass
class MCPServerConfig:
    """Settings describing how to reach one MCP server.

    ``name`` is the only required field. ``transport`` defaults to
    ``"stdio"``; the remaining fields are grouped by the transport that
    uses them.
    """

    name: str
    transport: str = "stdio"

    # Fields for the "stdio" transport.
    command: str = ""
    args: list[str] = field(default_factory=list)
    env: dict[str, str] = field(default_factory=dict)
    cwd: str = ""

    # Fields for the "http" transport.
    url: str = ""
    headers: dict[str, str] = field(default_factory=dict)
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
class MCPClient:
    """
    MCP client managing multiple MCP server connections.

    Discovers tools, resources, prompts from all servers.
    Provides unified search, caching, and health monitoring.

    Tools are exposed under a prefixed name (``mcp__<server>__<tool>``,
    shortened when it would exceed 64 characters). The same prefixed name is
    returned by ``get_tools`` and accepted by ``call_tool`` / ``is_mcp_tool``.

    Connecting is asynchronous. Servers passed to the constructor are only
    recorded; call ``await client.connect_pending()`` (or ``add_server``) to
    actually connect them.
    """

    def __init__(self, servers: list[MCPServerConfig] | None = None):
        """Create an empty client.

        Args:
            servers: Optional server configurations. ``__init__`` cannot await
                ``add_server``, so they are queued for ``connect_pending``.
        """
        self._connections: dict[str, MCPServerConnection] = {}
        # prefixed tool name -> owning server name
        self._tool_to_server: dict[str, str] = {}
        # prefixed tool name -> tool name as reported by the server
        self._tool_to_original: dict[str, str] = {}
        self._all_tools: list[dict[str, Any]] = []
        # uri -> (content, fetch timestamp, owning server name)
        self._resource_cache: OrderedDict[str, tuple[Any, float, str]] = OrderedDict()
        self._resource_cache_max = 64
        self._resource_cache_ttl = 300.0
        self._health_task: asyncio.Task | None = None
        self._health_interval = 60.0
        self._health_running = False
        self._on_health_fail: Callable[[str], None] | None = None
        self._pending_servers: list[MCPServerConfig] = list(servers or [])

    # ── Server lifecycle ───────────────────────────────────────────────────

    async def connect_pending(self) -> None:
        """Connect the servers that were passed to the constructor.

        Servers are connected in order. If one fails, its exception propagates
        and the servers after it stay queued for a later call.
        """
        while self._pending_servers:
            await self.add_server(self._pending_servers.pop(0))

    async def add_server(self, config: MCPServerConfig) -> None:
        """Add and connect to an MCP server.

        If a server with the same name is already connected, it is replaced
        once the new connection has started: the old connection is stopped and
        its tools and cached resources are dropped.

        Raises:
            ValueError: if ``config.transport`` is neither "stdio" nor "http".
            Exception: whatever starting or registering the connection raises;
                the new connection is stopped and left unregistered.
        """
        if config.transport == "stdio":
            conn = StdioMCPConnection(
                name=config.name,
                command=config.command,
                args=config.args,
                env=config.env or None,
                cwd=config.cwd or None,
            )
        elif config.transport == "http":
            conn = HTTPMCPConnection(
                name=config.name,
                url=config.url,
                headers=config.headers or None,
            )
        else:
            raise ValueError(f"Unknown transport: {config.transport}")

        try:
            await conn.start()
        except Exception:
            await self._stop_quietly(config.name, conn)
            raise

        # Retire a previous connection of the same name before registering, so
        # it is neither leaked nor leaves duplicate tools behind.
        previous = self._connections.pop(config.name, None)
        if previous is not None:
            self._forget_server(config.name)
            await self._stop_quietly(config.name, previous)

        try:
            self._connections[config.name] = conn
            self._register_server_tools(config.name, conn)
        except Exception:
            # Do not keep a half-registered, already stopped connection around.
            self._connections.pop(config.name, None)
            self._forget_server(config.name)
            await self._stop_quietly(config.name, conn)
            raise

        logger.info(
            "Added MCP server '%s': %d tools, %d resources, %d prompts",
            config.name, len(conn.tools), len(conn.resources), len(conn.prompts),
        )

    async def remove_server(self, name: str) -> None:
        """Disconnect and remove an MCP server.

        The server's tools and cached resources are dropped even if stopping
        the connection raises; that exception still propagates.
        """
        conn = self._connections.pop(name, None)
        try:
            if conn is not None:
                await conn.stop()
        finally:
            self._forget_server(name)
        logger.info("Removed MCP server '%s'", name)

    async def stop_all(self) -> None:
        """Stop all MCP server connections."""
        await self._stop_health_monitor()
        for name, conn in list(self._connections.items()):
            # Best effort: one failing connection must not block the others.
            await self._stop_quietly(name, conn)
        self._connections.clear()
        self._all_tools.clear()
        self._tool_to_server.clear()
        self._tool_to_original.clear()
        self._resource_cache.clear()
        logger.info("All MCP servers stopped")

    async def _stop_quietly(self, name: str, conn: MCPServerConnection) -> None:
        """Stop ``conn``, logging instead of raising if that fails."""
        try:
            await conn.stop()
        except Exception as exc:
            logger.warning("[%s] Error while stopping MCP connection: %s", name, exc)

    def _forget_server(self, name: str) -> None:
        """Drop the tools and cached resources recorded for server ``name``."""
        self._unregister_server_tools(name)
        stale = [uri for uri, entry in self._resource_cache.items() if entry[2] == name]
        for uri in stale:
            del self._resource_cache[uri]

    # ── Tool registration ───────────────────────────────────────────────────

    def _make_tool_name(self, server_name: str, tool_name: str) -> str:
        """Build the prefixed name a tool is exposed under.

        Names longer than 64 characters are shortened to the first 20
        characters of the server name plus the last 30 of the tool name. If
        the name is already registered, a numeric suffix is appended so two
        tools never share a routing key.
        """
        prefixed = f"mcp__{server_name}__{tool_name}"
        if len(prefixed) > 64:
            suffix = tool_name[-30:] if len(tool_name) > 30 else tool_name
            prefixed = f"mcp__{server_name[:20]}__{suffix}"
            logger.warning("MCP tool name truncated: %s", prefixed)
        if prefixed in self._tool_to_server:
            base = prefixed[:60]
            counter = 2
            while f"{base}_{counter}" in self._tool_to_server:
                counter += 1
            prefixed = f"{base}_{counter}"
            logger.warning("MCP tool name collision, registered as: %s", prefixed)
        return prefixed

    def _register_server_tools(self, server_name: str, conn: MCPServerConnection) -> None:
        """Register tools from a server connection.

        Each tool dict is annotated in place with ``_mcp_server``,
        ``_mcp_original_name`` and ``_mcp_prefixed_name``.
        """
        for tool in conn.tools:
            tool_name = tool["name"]
            prefixed = self._make_tool_name(server_name, tool_name)
            self._tool_to_server[prefixed] = server_name
            self._tool_to_original[prefixed] = tool_name
            tool["_mcp_server"] = server_name
            tool["_mcp_original_name"] = tool_name
            tool["_mcp_prefixed_name"] = prefixed
            self._all_tools.append(tool)

    def _unregister_server_tools(self, server_name: str) -> None:
        """Remove every registered tool that belongs to ``server_name``."""
        self._all_tools = [t for t in self._all_tools if t.get("_mcp_server") != server_name]
        stale = [key for key, owner in self._tool_to_server.items() if owner == server_name]
        for key in stale:
            del self._tool_to_server[key]
            self._tool_to_original.pop(key, None)

    def refresh_tools(self, server_name: str | None = None) -> None:
        """Re-register tools from one or all servers (synchronous variant).

        A synchronous caller cannot await ``discover()``. When it returns a
        coroutine, that coroutine is closed and only the connection's current
        tool list is re-registered; use ``refresh_tools_async`` to really
        re-discover.
        """
        names = [server_name] if server_name else list(self._connections)
        for name in names:
            conn = self._connections.get(name)
            if conn is None:
                continue
            outcome = conn.discover()
            if asyncio.iscoroutine(outcome):
                # Close it so it does not trigger a "never awaited" warning.
                outcome.close()
                logger.warning(
                    "[%s] discover() is asynchronous; re-registered known tools only "
                    "(use refresh_tools_async to re-discover)", name,
                )
            self._unregister_server_tools(name)
            self._register_server_tools(name, conn)

    async def refresh_tools_async(self, server_name: str | None = None) -> None:
        """Re-discover and re-register tools from one or all servers.

        Discovery runs before the old tools are dropped, so a failing
        ``discover()`` leaves the previous registration untouched.
        """
        names = [server_name] if server_name else list(self._connections)
        for name in names:
            conn = self._connections.get(name)
            if conn is None:
                continue
            await conn.discover()
            self._unregister_server_tools(name)
            self._register_server_tools(name, conn)

    # ── Tool access ─────────────────────────────────────────────────────────

    def get_tools(self) -> list[dict[str, Any]]:
        """Get all tools as OpenAI function tool definitions.

        The ``name`` of each definition is the registered prefixed name, i.e.
        exactly what ``call_tool`` and ``is_mcp_tool`` accept.
        """
        openai_tools = []
        for tool in self._all_tools:
            server = tool.get("_mcp_server", "?")
            original = tool.get("_mcp_original_name", tool.get("name", "?"))
            exposed = tool.get("_mcp_prefixed_name") or f"mcp__{server}__{original}"
            openai_tools.append({
                "type": "function",
                "function": {
                    "name": exposed,
                    "description": tool.get("description") or f"MCP tool: {original}",
                    "parameters": tool.get("inputSchema", {
                        "type": "object", "properties": {},
                    }),
                },
            })
        return openai_tools

    async def call_tool(self, prefixed_name: str, arguments: dict[str, Any]) -> Any:
        """Call an MCP tool by its prefixed name.

        Raises:
            ValueError: if the name is not a registered MCP tool.
            RuntimeError: if the owning server is no longer connected.
        """
        server_name = self._tool_to_server.get(prefixed_name)
        if not server_name:
            raise ValueError(f"Unknown MCP tool: {prefixed_name}")

        conn = self._connections.get(server_name)
        if conn is None:
            raise RuntimeError(f"MCP server not connected: {server_name}")

        original = self._tool_to_original.get(prefixed_name)
        if original is None:
            raise ValueError(f"Tool not found: {prefixed_name}")
        return await conn.call_tool(original, arguments)

    def is_mcp_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` is a registered prefixed MCP tool name."""
        return tool_name.startswith("mcp__") and tool_name in self._tool_to_server

    # ── Prompts ─────────────────────────────────────────────────────────────

    def list_prompts(self) -> list[dict[str, Any]]:
        """List all prompts from all servers, each tagged with ``_mcp_server``."""
        result: list[dict[str, Any]] = []
        for name, conn in self._connections.items():
            for p in conn.prompts:
                result.append({**p, "_mcp_server": name})
        return result

    async def get_prompt(self, server: str, prompt_name: str,
                         arguments: dict[str, str] | None = None) -> Any:
        """Get a prompt from a specific server.

        Raises:
            ValueError: if ``server`` is not a connected server name.
        """
        conn = self._connections.get(server)
        if conn is None:
            raise ValueError(f"Server not found: {server}")
        return await conn.get_prompt(prompt_name, arguments)

    # ── Search ──────────────────────────────────────────────────────────────

    def search_tools(self, query: str, limit: int = 20) -> list[dict[str, Any]]:
        """Search MCP tools across all servers by case-insensitive substring.

        Ranking: exact name match, then name prefix, then name substring, then
        description-only match; ties are ordered by tool name.
        """
        q = query.lower().strip()
        if not q:
            return []

        scored: list[tuple[int, dict[str, Any]]] = []
        for tool in self._all_tools:
            # "or" guards against a present-but-null name or description.
            name_l = (tool.get("name") or "").lower()
            desc_l = (tool.get("description") or "").lower()
            if name_l == q:
                score = 3
            elif name_l.startswith(q):
                score = 2
            elif q in name_l:
                score = 1
            elif q in desc_l:
                score = 0
            else:
                continue
            scored.append((score, tool))

        scored.sort(key=lambda item: (-item[0], item[1].get("name") or ""))
        return [tool for _, tool in scored[:limit]]

    def search_resources(self, query: str, limit: int = 20) -> list[dict[str, Any]]:
        """Search MCP resources by URI, name or description across all servers."""
        q = query.lower().strip()
        if not q:
            return []

        results: list[dict[str, Any]] = []
        for conn in self._connections.values():
            for res in conn.resources:
                haystacks = (
                    (res.get("uri") or "").lower(),
                    (res.get("name") or "").lower(),
                    (res.get("description") or "").lower(),
                )
                if any(q in text for text in haystacks):
                    results.append({**res, "_mcp_server": conn.name})
        results.sort(key=lambda r: r.get("name") or r.get("uri") or "")
        return results[:limit]

    def get_all_resources(self) -> list[dict[str, Any]]:
        """Return all discovered resources from all servers, sorted by name/URI."""
        results: list[dict[str, Any]] = []
        for conn in self._connections.values():
            for res in conn.resources:
                results.append({**res, "_mcp_server": conn.name})
        results.sort(key=lambda r: r.get("name") or r.get("uri") or "")
        return results

    # ── Resource cache ──────────────────────────────────────────────────────

    def _cache_get(self, uri: str, now: float) -> dict[str, Any] | None:
        """Return the cached entry for ``uri`` if still fresh, else ``None``.

        An expired entry is removed from the cache.
        """
        entry = self._resource_cache.get(uri)
        if entry is None:
            return None
        content, fetched_at, server = entry
        if now - fetched_at < self._resource_cache_ttl:
            self._resource_cache.move_to_end(uri)
            return {"content": content, "cached": True, "server": server}
        del self._resource_cache[uri]
        return None

    def _cache_put(self, uri: str, result: Any, now: float, server: str) -> dict[str, Any]:
        """Store a freshly read resource, evicting the oldest entries if full."""
        content = result.get("contents", result)
        while self._resource_cache and len(self._resource_cache) >= self._resource_cache_max:
            self._resource_cache.popitem(last=False)
        self._resource_cache[uri] = (content, now, server)
        return {"content": content, "cached": False, "server": server}

    def _find_resource_owner(self, uri: str) -> MCPServerConnection | None:
        """Find the connection that lists ``uri`` or has a matching template.

        Exact resource listings win over templates. A template matches when
        the URI starts with the template text before its first ``{``.
        """
        for conn in self._connections.values():
            if any(res.get("uri") == uri for res in conn.resources):
                return conn
        for conn in self._connections.values():
            for tmpl in conn.resource_templates:
                tmpl_uri = tmpl.get("uriTemplate", "")
                if not tmpl_uri:
                    # A template without a URI must not claim every resource.
                    continue
                if uri.startswith(tmpl_uri.split("{")[0]):
                    return conn
        return None

    def cached_read_resource(self, uri: str) -> dict[str, Any]:
        """Read a resource with LRU+TTL caching (synchronous variant).

        Returns:
            ``{"content", "cached", "server"}``.

        Raises:
            ValueError: if no connected server lists or templates the URI.
            RuntimeError: if the owning connection's ``read_resource`` returns
                a coroutine; use ``cached_read_resource_async`` in that case.
        """
        now = time.time()
        hit = self._cache_get(uri, now)
        if hit is not None:
            return hit

        conn = self._find_resource_owner(uri)
        if conn is None:
            raise ValueError(f"Resource not found on any server: {uri}")

        result = conn.read_resource(uri)
        if asyncio.iscoroutine(result):
            # NOTE(review): read_resource is defined outside this class; the
            # other connection calls here are awaited, so it is presumably a
            # coroutine too - confirm and prefer the async variant.
            result.close()
            raise RuntimeError(
                "read_resource() is asynchronous; use cached_read_resource_async()"
            )
        return self._cache_put(uri, result, now, conn.name)

    async def cached_read_resource_async(self, uri: str) -> dict[str, Any]:
        """Read a resource with LRU+TTL caching, awaiting the read if needed.

        Same result shape and ``ValueError`` as ``cached_read_resource``.
        """
        now = time.time()
        hit = self._cache_get(uri, now)
        if hit is not None:
            return hit

        conn = self._find_resource_owner(uri)
        if conn is None:
            raise ValueError(f"Resource not found on any server: {uri}")

        result = conn.read_resource(uri)
        if asyncio.iscoroutine(result) or asyncio.isfuture(result):
            result = await result
        return self._cache_put(uri, result, now, conn.name)

    def invalidate_resource_cache(self, uri: str | None = None) -> None:
        """Invalidate one cached resource, or the whole cache when no URI is given."""
        if uri:
            self._resource_cache.pop(uri, None)
        else:
            self._resource_cache.clear()

    # ── Health monitoring ───────────────────────────────────────────────────

    def on_health_fail(self, callback: Callable[[str], None]) -> None:
        """Register a callback for health check failures.

        The callback receives the name of the failing server.
        """
        self._on_health_fail = callback

    def start_health_monitor(self, interval: float = 60.0) -> None:
        """Start periodic health checks (ping every N seconds).

        Does nothing if a monitor task is already running. Uses
        ``asyncio.create_task``, so it must be called while an event loop is
        running.
        """
        task = self._health_task
        if self._health_running and task is not None and not task.done():
            return
        self._health_interval = interval
        self._health_running = True
        self._health_task = asyncio.create_task(self._health_loop())
        logger.info("MCP health monitor started (interval=%.0fs)", interval)

    async def _stop_health_monitor(self) -> None:
        """Cancel the health monitor task, if any, and wait for it to finish."""
        self._health_running = False
        if self._health_task and not self._health_task.done():
            self._health_task.cancel()
            try:
                await self._health_task
            except asyncio.CancelledError:
                pass
        self._health_task = None

    def _notify_health_fail(self, name: str) -> None:
        """Invoke the health-fail callback once, isolating its errors."""
        callback = self._on_health_fail
        if not callback:
            return
        try:
            callback(name)
        except Exception:
            # A faulty callback must not terminate the monitor loop.
            logger.exception("[%s] Health-fail callback raised", name)

    async def _health_loop(self) -> None:
        """Ping every connection once per interval until the monitor stops."""
        while self._health_running:
            await asyncio.sleep(self._health_interval)
            if not self._health_running:
                break
            for name, conn in list(self._connections.items()):
                try:
                    if await conn.ping(timeout=10):
                        continue
                    logger.warning("[%s] Health check failed: no response", name)
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.warning("[%s] Health check error: %s", name, e)
                self._notify_health_fail(name)

    # ── Properties ──────────────────────────────────────────────────────────

    @property
    def connected_servers(self) -> list[str]:
        """Names of the currently registered connections."""
        return list(self._connections.keys())

    @property
    def tool_count(self) -> int:
        """Number of registered tools across all servers."""
        return len(self._all_tools)

    @property
    def resource_count(self) -> int:
        """Number of resources listed by all connections."""
        return sum(len(c.resources) for c in self._connections.values())
|
|
1048
|
+
|
|
1049
|
+
|
|
1050
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
1051
|
+
# MCP config file support
|
|
1052
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
1053
|
+
|
|
1054
|
+
def load_mcp_config(config_path: str | Path) -> list[MCPServerConfig]:
    """
    Read a JSON file and build one MCPServerConfig per "mcpServers" entry.

    The entry's key becomes the server name. Keys missing from an entry fall
    back to the defaults used below ("stdio" transport, empty strings, empty
    list/dicts). A file without "mcpServers" yields an empty list.

    Example config.json:
    {
        "mcpServers": {
            "filesystem": {
                "transport": "stdio",
                "command": "npx",
                "args": ["-y", "@anthropic/mcp-filesystem", "/path/to/allowed"]
            },
            "github": {
                "transport": "stdio",
                "command": "npx",
                "args": ["-y", "@anthropic/mcp-github"],
                "env": {"GITHUB_TOKEN": "ghp_xxx"}
            },
            "remote-api": {
                "transport": "http",
                "url": "https://mcp.example.com/mcp",
                "headers": {"Authorization": "Bearer xxx"}
            }
        }
    }
    """
    with open(config_path, "r", encoding="utf-8") as handle:
        document = json.load(handle)

    return [
        MCPServerConfig(
            name=server_name,
            transport=entry.get("transport", "stdio"),
            command=entry.get("command", ""),
            args=entry.get("args", []),
            env=entry.get("env", {}),
            cwd=entry.get("cwd", ""),
            url=entry.get("url", ""),
            headers=entry.get("headers", {}),
        )
        for server_name, entry in document.get("mcpServers", {}).items()
    ]
|