ata-coder 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118):
  1. ata_coder/__init__.py +1 -0
  2. ata_coder/agent.py +874 -0
  3. ata_coder/agent_compact.py +190 -0
  4. ata_coder/agent_controller.py +218 -0
  5. ata_coder/agent_extension.py +69 -0
  6. ata_coder/agent_routing.py +105 -0
  7. ata_coder/agent_subsystems.py +72 -0
  8. ata_coder/agent_tools.py +318 -0
  9. ata_coder/agent_undo.py +63 -0
  10. ata_coder/anthropic_client.py +465 -0
  11. ata_coder/change_tracker.py +368 -0
  12. ata_coder/clawd_integration.py +574 -0
  13. ata_coder/commands/__init__.py +128 -0
  14. ata_coder/commands/_core.py +184 -0
  15. ata_coder/commands/_safety.py +95 -0
  16. ata_coder/commands/_settings.py +241 -0
  17. ata_coder/commands/_workflow.py +451 -0
  18. ata_coder/commands.py +974 -0
  19. ata_coder/config.py +257 -0
  20. ata_coder/core/__init__.py +35 -0
  21. ata_coder/core/events.py +73 -0
  22. ata_coder/core/queue.py +85 -0
  23. ata_coder/core/state.py +17 -0
  24. ata_coder/event_queue.py +5 -0
  25. ata_coder/extension.py +654 -0
  26. ata_coder/extensions/__init__.py +1 -0
  27. ata_coder/extensions/hello_skill.py +47 -0
  28. ata_coder/fool_proof.py +295 -0
  29. ata_coder/git_workflow.py +371 -0
  30. ata_coder/gui.py +511 -0
  31. ata_coder/llm_client.py +543 -0
  32. ata_coder/main.py +814 -0
  33. ata_coder/mcp_client.py +1095 -0
  34. ata_coder/memory.py +539 -0
  35. ata_coder/model_registry.py +134 -0
  36. ata_coder/model_router.py +105 -0
  37. ata_coder/permissions.py +274 -0
  38. ata_coder/privilege.py +464 -0
  39. ata_coder/project.py +273 -0
  40. ata_coder/prompt_template.py +423 -0
  41. ata_coder/prompts/auto-mode.md +7 -0
  42. ata_coder/prompts/coding-rules.md +40 -0
  43. ata_coder/prompts/execution-guardrails.md +14 -0
  44. ata_coder/prompts/memory-system.md +24 -0
  45. ata_coder/prompts/output-style.md +23 -0
  46. ata_coder/prompts/safety.md +17 -0
  47. ata_coder/prompts/slash-commands.md +24 -0
  48. ata_coder/prompts/sub-agents.md +38 -0
  49. ata_coder/prompts/system-reminders.md +17 -0
  50. ata_coder/prompts/system.md +105 -0
  51. ata_coder/prompts/tool-policy.md +46 -0
  52. ata_coder/repl_theme.py +99 -0
  53. ata_coder/repl_tracker.py +89 -0
  54. ata_coder/repl_ui.py +1214 -0
  55. ata_coder/safety_guard.py +434 -0
  56. ata_coder/self_correct.py +346 -0
  57. ata_coder/server.py +882 -0
  58. ata_coder/server_session.py +159 -0
  59. ata_coder/server_shell.py +129 -0
  60. ata_coder/session.py +431 -0
  61. ata_coder/settings.py +439 -0
  62. ata_coder/setup_wizard.py +136 -0
  63. ata_coder/skill_extension.py +92 -0
  64. ata_coder/skills/architect/SKILL.md +42 -0
  65. ata_coder/skills/code-reviewer/SKILL.md +37 -0
  66. ata_coder/skills/codecraft/SKILL.md +452 -0
  67. ata_coder/skills/debugger/SKILL.md +45 -0
  68. ata_coder/skills/doc-writer/SKILL.md +36 -0
  69. ata_coder/skills/general-coder/SKILL.md +76 -0
  70. ata_coder/skills/math-calculator/README.md +40 -0
  71. ata_coder/skills/math-calculator/SKILL.md +59 -0
  72. ata_coder/skills/math-calculator/handler.py +103 -0
  73. ata_coder/skills/math-calculator/prompts/system.md +8 -0
  74. ata_coder/skills/math-calculator/requirements.txt +2 -0
  75. ata_coder/skills/math-calculator/resources/constants.json +8 -0
  76. ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
  77. ata_coder/skills/security-auditor/SKILL.md +40 -0
  78. ata_coder/skills/test-writer/SKILL.md +36 -0
  79. ata_coder/skills/weather-skill/README.md +45 -0
  80. ata_coder/skills/weather-skill/handler.py +76 -0
  81. ata_coder/skills/weather-skill/manifest.json +48 -0
  82. ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
  83. ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
  84. ata_coder/skills/weather-skill/requirements.txt +1 -0
  85. ata_coder/skills/weather-skill/resources/city_list.json +17 -0
  86. ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
  87. ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
  88. ata_coder/skills/weather-skill/weather_utils.py +50 -0
  89. ata_coder/skills.py +1014 -0
  90. ata_coder/sub_agent.py +273 -0
  91. ata_coder/sub_agent_manager.py +203 -0
  92. ata_coder/system_prompt_builder.py +146 -0
  93. ata_coder/task_planner.py +391 -0
  94. ata_coder/terminal.py +318 -0
  95. ata_coder/test_runner.py +219 -0
  96. ata_coder/thread_supervisor.py +195 -0
  97. ata_coder/tool_defs.py +335 -0
  98. ata_coder/tools/__init__.py +11 -0
  99. ata_coder/tools/definitions.py +335 -0
  100. ata_coder/tools/executor.py +1036 -0
  101. ata_coder/tools/result.py +26 -0
  102. ata_coder/tools/subagent.py +332 -0
  103. ata_coder/tools/web.py +361 -0
  104. ata_coder/tools.py +1576 -0
  105. ata_coder/types.py +92 -0
  106. ata_coder/utils.py +113 -0
  107. ata_coder/web/css/style.css +180 -0
  108. ata_coder/web/index.html +84 -0
  109. ata_coder/web/js/app.js +489 -0
  110. ata_coder/web/package-lock.json +25 -0
  111. ata_coder/web/package.json +10 -0
  112. ata_coder/web/tsconfig.json +13 -0
  113. ata_coder-2.4.2.dist-info/METADATA +799 -0
  114. ata_coder-2.4.2.dist-info/RECORD +118 -0
  115. ata_coder-2.4.2.dist-info/WHEEL +5 -0
  116. ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
  117. ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
  118. ata_coder-2.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1095 @@
1
+ """
2
+ MCP (Model Context Protocol) client — full-spec implementation.
3
+
4
+ Supports MCP servers over:
5
+ - stdio (subprocess): spawns the server as a child process
6
+ - HTTP/SSE: connects to a remote MCP server
7
+
8
+ Implements: capability negotiation, tools, resources, prompts, ping,
9
+ progress notifications, cancellation, resource templates, logging,
10
+ completion, roots.
11
+
12
+ Spec: https://spec.modelcontextprotocol.io/
13
+ """
14
+
15
+ import asyncio
16
+ import json
17
+ import logging
18
+ import time
19
+ from collections import OrderedDict
20
+ from dataclasses import dataclass, field
21
+ from pathlib import Path
22
+ from typing import Any, Callable
23
+
24
+ import httpx
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ # ═══════════════════════════════════════════════════════════════════════════════
30
+ # JSON-RPC 2.0 standard error codes
31
+ # ═══════════════════════════════════════════════════════════════════════════════
32
+
class JsonRpcError(Exception):
    """Exception carrying the fields of a JSON-RPC 2.0 error object.

    Attributes:
        code: Numeric JSON-RPC error code.
        message: Human-readable description; also the text of ``str(exc)``.
        data: Optional extra payload that accompanied the error.
    """

    def __init__(self, code: int, message: str, data: Any = None):
        # Hand the message to the base class so str(exc) and exc.args
        # behave like any ordinary exception.
        super().__init__(message)
        self.code, self.message, self.data = code, message, data
40
+
# Standard JSON-RPC 2.0 error codes
PARSE_ERROR = -32700
INVALID_REQUEST = -32600
METHOD_NOT_FOUND = -32601
INVALID_PARAMS = -32602
INTERNAL_ERROR = -32603
# MCP-specific (server error range: -32000 to -32099)
SERVER_NOT_INITIALIZED = -32002
# NOTE(review): -32800 is outside the -32000..-32099 range named above; it
# looks like the LSP-style "request cancelled" code — confirm the intended value.
REQUEST_CANCELLED = -32800
50
+
51
+
52
+ # ═══════════════════════════════════════════════════════════════════════════════
53
+ # JSON-RPC types
54
+ # ═══════════════════════════════════════════════════════════════════════════════
55
+
# Type of a JSON-RPC message id: either a string or an integer.
JsonRpcId = str | int
57
+
58
+
@dataclass
class JsonRpcRequest:
    """Field layout of a JSON-RPC 2.0 request message."""

    jsonrpc: str = "2.0"  # protocol version tag
    method: str = ""  # name of the method being invoked
    params: dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
    id: JsonRpcId = ""  # request identifier
65
+
66
+
@dataclass
class JsonRpcResponse:
    """Field layout of a JSON-RPC 2.0 response message.

    The stdio read loop builds one of these per matched response and
    ``send_request`` raises ``JsonRpcError`` when ``error`` is set.
    """

    jsonrpc: str = "2.0"  # protocol version tag
    result: Any = None  # payload of a successful call
    error: dict[str, Any] | None = None  # error object (code/message/data) on failure
    id: JsonRpcId = ""  # id of the request this answers
73
+
74
+
75
+ # ═══════════════════════════════════════════════════════════════════════════════
76
+ # MCP Server connection — base class
77
+ # ═══════════════════════════════════════════════════════════════════════════════
78
+
79
+ class MCPServerConnection:
80
+ """
81
+ A connection to a single MCP server.
82
+
83
+ Handles JSON-RPC communication, capability negotiation,
84
+ tool/resource/prompt discovery, ping, progress, and cancellation.
85
+ """
86
+
87
+ PROTOCOL_VERSION = "2025-03-26"
88
+
89
+ def __init__(self, name: str):
90
+ self.name = name
91
+ self._tools: list[dict[str, Any]] = []
92
+ self._resources: list[dict[str, Any]] = []
93
+ self._resource_templates: list[dict[str, Any]] = []
94
+ self._prompts: list[dict[str, Any]] = []
95
+ self._initialized: bool = False
96
+ self._server_info: dict[str, Any] = {}
97
+ self._capabilities: dict[str, Any] = {}
98
+ self._server_capabilities: dict[str, Any] = {}
99
+ self._roots: list[dict[str, Any]] = []
100
+ self._pong_received: bool = False
101
+
102
+ @property
103
+ def tools(self) -> list[dict[str, Any]]:
104
+ return self._tools
105
+
106
+ @property
107
+ def resources(self) -> list[dict[str, Any]]:
108
+ return self._resources
109
+
110
+ @property
111
+ def resource_templates(self) -> list[dict[str, Any]]:
112
+ return self._resource_templates
113
+
114
+ @property
115
+ def prompts(self) -> list[dict[str, Any]]:
116
+ return self._prompts
117
+
118
+ @property
119
+ def initialized(self) -> bool:
120
+ return self._initialized
121
+
122
+ @property
123
+ def server_info(self) -> dict[str, Any]:
124
+ return self._server_info
125
+
126
+ @property
127
+ def capabilities(self) -> dict[str, Any]:
128
+ return self._capabilities
129
+
130
+ @property
131
+ def server_capabilities(self) -> dict[str, Any]:
132
+ return self._server_capabilities
133
+
134
+ # ── Abstract methods ──
135
+
136
+ async def start(self) -> None:
137
+ raise NotImplementedError
138
+
139
+ async def stop(self) -> None:
140
+ raise NotImplementedError
141
+
142
+ async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
143
+ raise NotImplementedError
144
+
145
+ async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
146
+ raise NotImplementedError
147
+
148
+ # ── Capability checks ──
149
+
150
+ def has_capability(self, cap: str) -> bool:
151
+ """Check if the server supports a given capability namespace."""
152
+ return cap in self._server_capabilities
153
+
154
+ def has_subcapability(self, cap: str, sub: str) -> bool:
155
+ """Check if the server supports a sub-capability (e.g. tools→listChanged)."""
156
+ caps = self._server_capabilities.get(cap, {})
157
+ return isinstance(caps, dict) and sub in caps
158
+
159
+ # ── Lifecycle ──
160
+
161
+ async def initialize(self, client_capabilities: dict[str, Any] | None = None) -> None:
162
+ """Send initialize request and negotiate capabilities."""
163
+ caps = {
164
+ "tools": {},
165
+ "resources": {"subscribe": True},
166
+ "prompts": {},
167
+ "logging": {},
168
+ }
169
+ if client_capabilities:
170
+ caps.update(client_capabilities)
171
+
172
+ init_result = await self.send_request("initialize", {
173
+ "protocolVersion": self.PROTOCOL_VERSION,
174
+ "capabilities": caps,
175
+ "clientInfo": {
176
+ "name": "ata-coder",
177
+ "version": "2.3.0",
178
+ },
179
+ })
180
+ self._server_info = init_result.get("serverInfo", {})
181
+ self._server_capabilities = init_result.get("capabilities", {})
182
+ self._initialized = True
183
+
184
+ # Send initialized notification
185
+ await self.send_notification("notifications/initialized", {})
186
+
187
+ logger.info(
188
+ "[%s] Initialized: %s v%s (caps: %s)",
189
+ self.name,
190
+ self._server_info.get("name", "unknown"),
191
+ self._server_info.get("version", "?"),
192
+ ", ".join(self._server_capabilities) or "none",
193
+ )
194
+
195
+ async def discover(self) -> None:
196
+ """Discover tools, resources, resource templates, and prompts."""
197
+ if not self._initialized:
198
+ raise JsonRpcError(SERVER_NOT_INITIALIZED, "Server not initialized")
199
+
200
+ # Tools
201
+ if self.has_capability("tools"):
202
+ result = await self.send_request("tools/list", {})
203
+ self._tools = result.get("tools", [])
204
+ logger.info("[%s] Discovered %d tools", self.name, len(self._tools))
205
+
206
+ # Resources
207
+ if self.has_capability("resources"):
208
+ try:
209
+ result = await self.send_request("resources/list", {})
210
+ self._resources = result.get("resources", [])
211
+ logger.info("[%s] Discovered %d resources", self.name, len(self._resources))
212
+ except Exception:
213
+ pass
214
+
215
+ # Resource templates
216
+ if self.has_capability("resources"):
217
+ try:
218
+ result = await self.send_request("resources/templates/list", {})
219
+ self._resource_templates = result.get("resourceTemplates", [])
220
+ logger.info("[%s] Discovered %d resource templates", self.name, len(self._resource_templates))
221
+ except Exception:
222
+ pass
223
+
224
+ # Prompts
225
+ if self.has_capability("prompts"):
226
+ try:
227
+ result = await self.send_request("prompts/list", {})
228
+ self._prompts = result.get("prompts", [])
229
+ logger.info("[%s] Discovered %d prompts", self.name, len(self._prompts))
230
+ except Exception:
231
+ pass
232
+
233
+ # ── Tool calling ──
234
+
235
+ async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
236
+ """Call a tool on this MCP server."""
237
+ if not self.has_capability("tools"):
238
+ raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support tools")
239
+ return await self.send_request("tools/call", {
240
+ "name": tool_name,
241
+ "arguments": arguments,
242
+ })
243
+
244
+ # ── Resource reading ──
245
+
246
+ async def read_resource(self, uri: str) -> Any:
247
+ """Read a resource by URI."""
248
+ if not self.has_capability("resources"):
249
+ raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support resources")
250
+ return await self.send_request("resources/read", {"uri": uri})
251
+
252
+ async def subscribe_resource(self, uri: str) -> None:
253
+ """Subscribe to resource updates."""
254
+ if not self.has_subcapability("resources", "subscribe"):
255
+ return
256
+ await self.send_notification("resources/subscribe", {"uri": uri})
257
+ logger.info("[%s] Subscribed to: %s", self.name, uri)
258
+
259
+ # ── Prompts ──
260
+
261
+ async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> Any:
262
+ """Get a prompt by name with optional arguments."""
263
+ if not self.has_capability("prompts"):
264
+ raise JsonRpcError(METHOD_NOT_FOUND, "Server does not support prompts")
265
+ params: dict[str, Any] = {"name": name}
266
+ if arguments:
267
+ params["arguments"] = arguments
268
+ return await self.send_request("prompts/get", params)
269
+
270
+ # ── Ping ──
271
+
272
+ async def ping(self, timeout: float = 10.0) -> bool:
273
+ """Ping the server. Returns True if alive."""
274
+ try:
275
+ await asyncio.wait_for(self.send_request("ping", {}), timeout=timeout)
276
+ return True
277
+ except Exception:
278
+ return False
279
+
280
+ # ── Completion ──
281
+
282
+ async def complete(self, ref: dict[str, Any], argument: dict[str, Any]) -> Any:
283
+ """Request auto-completion for a prompt or resource template argument."""
284
+ return await self.send_request("completion/complete", {
285
+ "ref": ref,
286
+ "argument": argument,
287
+ })
288
+
289
+ # ── Roots ──
290
+
291
+ async def set_roots(self, roots: list[dict[str, Any]]) -> None:
292
+ """Inform the server about root directories."""
293
+ self._roots = roots
294
+ await self.send_notification("notifications/roots/list_changed", {"roots": roots})
295
+ logger.info("[%s] Updated roots: %d", self.name, len(roots))
296
+
297
+ # ── Logging ──
298
+
299
+ async def set_log_level(self, level: str) -> None:
300
+ """Set the log level on the server (debug/info/notice/warning/error/critical)."""
301
+ await self.send_notification("logging/setLevel", {"level": level})
302
+
303
+
304
+ # ═══════════════════════════════════════════════════════════════════════════════
305
+ # Stdio connection
306
+ # ═══════════════════════════════════════════════════════════════════════════════
307
+
class StdioMCPConnection(MCPServerConnection):
    """MCP connection over stdio (subprocess).

    Messages are exchanged as newline-delimited JSON on the child's
    stdin/stdout. A background reader task matches responses to waiting
    requests and dispatches server requests and notifications.
    """

    _next_req_id = 0

    # Buffer limit for the child's stdout reader. asyncio's default (64 KiB)
    # makes readline() fail on a single large message, which would end the
    # read loop.
    _STREAM_LIMIT = 16 * 1024 * 1024

    # Server-supplied log level names → stdlib levels. A fixed table avoids
    # looking up arbitrary server-controlled attribute names on the logger.
    _LOG_LEVELS = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "notice": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
        "alert": logging.CRITICAL,
        "emergency": logging.CRITICAL,
    }

    # list_changed notifications → word used in the log line.
    _LIST_CHANGED = {
        "notifications/resources/list_changed": "Resource",
        "notifications/tools/list_changed": "Tool",
        "notifications/prompts/list_changed": "Prompt",
    }

    def __init__(self, name: str, command: str, args: list[str] | None = None,
                 env: dict[str, str] | None = None, cwd: str | None = None):
        super().__init__(name)
        self.command = command
        self.args = args or []
        self.env = env
        self.cwd = cwd
        self._process: asyncio.subprocess.Process | None = None
        self._pending: dict[JsonRpcId, asyncio.Event] = {}
        self._results: dict[JsonRpcId, JsonRpcResponse] = {}
        self._cancelled: set[JsonRpcId] = set()
        self._reader_task: asyncio.Task | None = None
        self._background_tasks: set[asyncio.Task] = set()
        self._running = False
        self._on_progress: Callable[[int, int, str | None], None] | None = None

    @classmethod
    def _next_id(cls) -> str:
        """Return the next request id (a process-wide counter, as a string)."""
        cls._next_req_id += 1
        return str(cls._next_req_id)

    def on_progress(self, callback: Callable[[int, int, str | None], None]) -> None:
        """Register a callback for progress notifications.

        The callback receives ``(progress, total, progress_token)``.
        """
        self._on_progress = callback

    # ── Start / Stop ──

    async def start(self) -> None:
        """Start the MCP server process, then initialize and discover.

        Raises:
            RuntimeError: If the process cannot be spawned.
            Exception: Whatever ``initialize()``/``discover()`` raise; the
                process is stopped before the error propagates.
        """
        import os  # local import: only needed to build the child environment

        logger.info("[%s] Starting: %s %s", self.name, self.command, " ".join(self.args))

        # Overlay the configured variables on the parent environment so the
        # child keeps PATH etc.; passing only self.env would replace it all.
        child_env = {**os.environ, **self.env} if self.env else None

        try:
            self._process = await asyncio.create_subprocess_exec(
                self.command, *self.args,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.DEVNULL,
                env=child_env,
                cwd=self.cwd,
                limit=self._STREAM_LIMIT,
            )
        except FileNotFoundError as e:
            raise RuntimeError(
                f"MCP server command not found: {self.command}. "
                f"Install it or check the path."
            ) from e
        except Exception as e:
            raise RuntimeError(f"Failed to start MCP server: {e}") from e

        self._running = True
        self._reader_task = asyncio.create_task(self._read_loop())

        try:
            await self.initialize()
            await self.discover()
        except Exception:
            await self.stop()
            raise

    async def stop(self) -> None:
        """Stop the MCP server process and release every waiting request."""
        self._running = False

        # Cancel and await reader task FIRST — it holds the stdout pipe open.
        if self._reader_task and not self._reader_task.done():
            self._reader_task.cancel()
            try:
                await self._reader_task
            except (asyncio.CancelledError, Exception):
                pass
        self._reader_task = None

        # Cancel background re-discovery tasks (never the task running stop()).
        current = asyncio.current_task()
        leftovers = [t for t in self._background_tasks if t is not current]
        for task in leftovers:
            task.cancel()
        if leftovers:
            await asyncio.gather(*leftovers, return_exceptions=True)
        self._background_tasks.clear()

        # Terminate/kill the process
        proc = self._process
        self._process = None
        if proc is not None:
            try:
                proc.terminate()
                try:
                    await asyncio.wait_for(proc.wait(), timeout=5)
                except asyncio.TimeoutError:
                    try:
                        proc.kill()
                        await asyncio.wait_for(proc.wait(), timeout=3)
                    except Exception:
                        pass
            except Exception:
                # e.g. the process already exited
                try:
                    proc.kill()
                except Exception:
                    pass
            # Explicitly close pipes to prevent "I/O operation on closed pipe"
            # during BaseSubprocessTransport.__del__ at GC time.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                close = getattr(pipe, "close", None)  # readers expose no close()
                if close is not None:
                    try:
                        close()
                    except Exception:
                        pass

        # Release all pending requests
        for evt in self._pending.values():
            evt.set()
        self._pending.clear()
        self._results.clear()
        self._cancelled.clear()

        logger.info("[%s] Stopped", self.name)

    # ── Message I/O ──

    async def _send_raw(self, msg: dict[str, Any]) -> None:
        """Send a raw JSON-RPC message to the server.

        Raises:
            RuntimeError: If the process is not running or the write fails.
        """
        if not self._process or not self._process.stdin:
            raise RuntimeError("MCP server not running")
        line = json.dumps(msg, ensure_ascii=False) + "\n"
        try:
            self._process.stdin.write(line.encode("utf-8"))
            await self._process.stdin.drain()
        except Exception as e:
            raise RuntimeError(f"Failed to send to MCP server: {e}") from e

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request and wait for the response.

        Waits up to 120 s for ``initialize`` and 60 s for anything else.

        Raises:
            RuntimeError: If the message cannot be written.
            JsonRpcError: On timeout, cancellation, a missing response (e.g.
                the connection stopped), or an error reply from the server.
        """
        req_id = self._next_id()
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
            "id": req_id,
        }

        event = asyncio.Event()
        self._pending[req_id] = event

        # Wait with timeout
        timeout = 120 if method == "initialize" else 60
        try:
            await self._send_raw(msg)
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            self._results.pop(req_id, None)
            raise JsonRpcError(INTERNAL_ERROR, f"MCP request timeout: {method}") from None
        finally:
            # Always drop the waiter — on success too, otherwise _pending
            # grows by one entry per request.
            self._pending.pop(req_id, None)

        response = self._results.pop(req_id, None)

        if req_id in self._cancelled:
            self._cancelled.discard(req_id)
            raise JsonRpcError(REQUEST_CANCELLED, f"MCP request cancelled: {method}")

        if response is None:
            raise JsonRpcError(INTERNAL_ERROR, f"No response for request: {method}")

        if response.error:
            raise JsonRpcError(
                response.error.get("code", INTERNAL_ERROR),
                response.error.get("message", "unknown"),
                response.error.get("data"),
            )

        return response.result

    async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
        """Send a JSON-RPC notification (no response expected)."""
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        }
        await self._send_raw(msg)

    # ── Cancellation ──

    async def cancel_request(self, req_id: JsonRpcId) -> None:
        """Cancel an in-flight request.

        Notifies the server and wakes the local waiter, whose
        ``send_request`` call then raises ``REQUEST_CANCELLED``.
        """
        await self.send_notification("notifications/cancelled", {
            "requestId": req_id,
            "reason": "User cancelled",
        })
        self._results.pop(req_id, None)
        event = self._pending.pop(req_id, None)
        if event is not None:
            # Wake the waiter now instead of leaving it to hit the timeout.
            self._cancelled.add(req_id)
            event.set()

    # ── Read loop ──

    async def _read_loop(self) -> None:
        """Background task: reads JSON-RPC messages from the server's stdout.

        Ends on EOF, on a read error, or when cancelled by ``stop()``. On exit
        every waiting request is woken so it fails fast rather than timing out.
        """
        try:
            while self._running and self._process and self._process.stdout:
                try:
                    line_bytes = await self._process.stdout.readline()
                except asyncio.CancelledError:
                    raise
                except Exception:
                    if self._running:
                        logger.exception("[%s] Read error", self.name)
                    break
                if not line_bytes:
                    break  # EOF: the server closed stdout / exited

                line = line_bytes.decode("utf-8", errors="replace")
                try:
                    msg = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue  # not a JSON line
                if not isinstance(msg, dict):
                    continue  # valid JSON, but not a JSON-RPC message object

                try:
                    await self._dispatch(msg)
                except asyncio.CancelledError:
                    raise
                except Exception:
                    # One bad message must not take the whole connection down.
                    if self._running:
                        logger.exception("[%s] Failed to handle message", self.name)
        finally:
            for evt in list(self._pending.values()):
                evt.set()

    def _match_pending(self, msg_id: Any) -> JsonRpcId | None:
        """Return the key in ``_pending`` that *msg_id* answers, or None.

        Our ids are strings; a server that echoes ``"7"`` back as ``7`` is
        still matched.
        """
        if not isinstance(msg_id, (str, int)):
            return None
        if msg_id in self._pending:
            return msg_id
        as_text = str(msg_id)
        return as_text if as_text in self._pending else None

    async def _dispatch(self, msg: dict[str, Any]) -> None:
        """Route one decoded message: response, server request, or notification."""
        msg_id = msg.get("id")
        method = msg.get("method")

        # ── Response to our request ──
        if method is None:
            key = self._match_pending(msg_id) if msg_id is not None else None
            if key is not None:
                self._results[key] = JsonRpcResponse(
                    jsonrpc=msg.get("jsonrpc", "2.0"),
                    result=msg.get("result"),
                    error=msg.get("error"),
                    id=key,
                )
                self._pending[key].set()
            return

        if not method:
            return

        # ── Server → client request ── (id 0 is a valid id, so test for None)
        if msg_id is not None:
            await self._handle_server_request(msg)

        # ── Notification from server ──
        elif "id" not in msg:
            await self._handle_notification(method, msg.get("params") or {})

    async def _handle_server_request(self, msg: dict[str, Any]) -> None:
        """Handle a request from the server (e.g. sampling/createMessage).

        ``ping`` and ``roots/list`` are answered; everything else gets
        method-not-found. Full sampling support would require an LLM callback
        from the agent.
        """
        method = msg.get("method", "")
        req_id = msg.get("id")

        reply: dict[str, Any] = {"jsonrpc": "2.0", "id": req_id}
        if method == "ping":
            reply["result"] = {}
        elif method == "roots/list":
            reply["result"] = {"roots": self._roots}
        else:
            reply["error"] = {
                "code": METHOD_NOT_FOUND,
                "message": f"Method not supported by this client: {method}",
            }
        await self._send_raw(reply)

    def _spawn_background(self, coro: Any) -> None:
        """Run *coro* as a task and keep a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _rediscover(self) -> None:
        """Re-run discovery after a list_changed notification; log failures."""
        try:
            await self.discover()
        except Exception:
            logger.warning("[%s] Re-discovery failed", self.name, exc_info=True)

    async def _handle_notification(self, method: str, params: dict[str, Any]) -> None:
        """Handle a notification from the server."""
        if not isinstance(params, dict):
            params = {}

        if method == "notifications/progress":
            # Progress token + progress + total
            progress_token = params.get("progressToken")
            progress = params.get("progress", 0)
            total = params.get("total", 0)
            if self._on_progress:
                try:
                    self._on_progress(progress, total, progress_token)
                except Exception:
                    # A faulty callback must not stop the read loop.
                    logger.exception("[%s] Progress callback failed", self.name)
            # %s rather than %d: the values may be floats or null.
            logger.debug("[%s] Progress: %s/%s", self.name, progress, total)

        elif method == "notifications/resources/updated":
            uri = params.get("uri", "?")
            logger.info("[%s] Resource updated: %s", self.name, uri)

        elif method in self._LIST_CHANGED:
            logger.info("[%s] %s list changed — re-discovering",
                        self.name, self._LIST_CHANGED[method])
            # discover() waits for responses that only the read loop can
            # deliver, and this handler runs inside the read loop — awaiting
            # it here would stall until every request timed out.
            self._spawn_background(self._rediscover())

        elif method == "notifications/message":
            # Server→client log message
            level = str(params.get("level", "info")).lower()
            data = params.get("data", "")
            logger.log(self._LOG_LEVELS.get(level, logging.INFO), "[%s] %s", self.name, data)
593
+
594
+
595
+ # ═══════════════════════════════════════════════════════════════════════════════
596
+ # HTTP / SSE connection
597
+ # ═══════════════════════════════════════════════════════════════════════════════
598
+
class HTTPMCPConnection(MCPServerConnection):
    """MCP connection over HTTP (Streamable HTTP transport).

    Every JSON-RPC message is POSTed to ``url``; the reply is either a JSON
    body or a ``text/event-stream`` body carrying the response.
    """

    def __init__(self, name: str, url: str, headers: dict[str, str] | None = None):
        super().__init__(name)
        self.url = url.rstrip("/")
        self._headers = headers or {}
        self._client: httpx.Client | None = None
        self._id_counter = 0

    def _next_id(self) -> str:
        """Return the next per-connection request id as a string."""
        self._id_counter += 1
        return str(self._id_counter)

    async def start(self) -> None:
        """Create the HTTP client, then initialize and discover.

        On failure the client is closed before the error propagates.
        """
        self._client = httpx.Client(
            timeout=httpx.Timeout(120.0, connect=30.0),
            headers={
                "Content-Type": "application/json",
                # The transport requires the client to accept both reply
                # forms; servers may reject requests without this header.
                "Accept": "application/json, text/event-stream",
                **self._headers,
            },
        )
        logger.info("[%s] Connecting to %s", self.name, self.url)

        try:
            await self.initialize()
            await self.discover()
        except Exception:
            await self.stop()
            raise

    async def stop(self) -> None:
        """Close the HTTP client (no-op if not connected)."""
        if self._client:
            self._client.close()
            self._client = None
        logger.info("[%s] Disconnected", self.name)

    def _post(self, msg: dict[str, Any]) -> httpx.Response:
        """POST one message (blocking; callers run it in a worker thread).

        Raises:
            RuntimeError: If the client is not connected.
            httpx.HTTPStatusError: On a 4xx/5xx reply.
        """
        client = self._client
        if not client:
            raise RuntimeError("MCP HTTP client not connected")
        response = client.post(self.url, json=msg)
        response.raise_for_status()
        # A session id handed out by the server must accompany every later
        # request, so store it as a default header on the client.
        session_id = response.headers.get("mcp-session-id")
        if session_id:
            client.headers["Mcp-Session-Id"] = session_id
        return response

    async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
        """Send a JSON-RPC request and return its ``result``.

        Raises:
            JsonRpcError: If the reply carries an error object or is not a
                JSON object.
        """
        req_id = self._next_id()
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
            "id": req_id,
        }
        response = await asyncio.to_thread(self._post, msg)

        # Handle SSE stream for streaming responses
        ct = response.headers.get("content-type", "")
        if "text/event-stream" in ct:
            return self._read_sse(response, req_id)

        data = response.json()
        if not isinstance(data, dict):
            raise JsonRpcError(INTERNAL_ERROR, f"Invalid response for request: {method}")
        error = data.get("error")
        if error:
            raise JsonRpcError(
                error.get("code", INTERNAL_ERROR),
                error.get("message", "unknown"),
                error.get("data"),
            )
        return data.get("result")

    async def send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
        """Send a JSON-RPC notification (no id; the reply body is ignored)."""
        msg = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        }
        await asyncio.to_thread(self._post, msg)

    @staticmethod
    def _read_sse(response: httpx.Response, request_id: JsonRpcId | None = None) -> Any:
        """Read SSE stream, collect the final result.

        Args:
            response: The ``text/event-stream`` reply.
            request_id: When given, messages carrying a different id are
                skipped so unrelated traffic on the stream is not mistaken
                for the answer.

        Raises:
            JsonRpcError: If a matching message carries an error object.
        """
        result = None
        for line in response.iter_lines():
            # The space after "data:" is optional in the SSE format.
            if not line.startswith("data:"):
                continue
            try:
                data = json.loads(line[5:].strip())
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue

            msg_id = data.get("id")
            if (request_id is not None and msg_id is not None
                    and str(msg_id) != str(request_id)):
                continue

            if data.get("result") is not None:
                result = data["result"]
            error = data.get("error")
            if error:
                raise JsonRpcError(
                    error.get("code", INTERNAL_ERROR),
                    error.get("message", "unknown"),
                    error.get("data"),
                )
        return result
693
+
694
+
695
+ # ═══════════════════════════════════════════════════════════════════════════════
696
+ # MCP Client — manages multiple connections
697
+ # ═══════════════════════════════════════════════════════════════════════════════
698
+
@dataclass
class MCPServerConfig:
    """Configuration for a single MCP server."""

    name: str  # key under which the connection is registered
    transport: str = "stdio"  # "stdio" or "http"; anything else is rejected by add_server
    # stdio config
    command: str = ""  # executable to spawn
    args: list[str] = field(default_factory=list)  # arguments passed to the executable
    env: dict[str, str] = field(default_factory=dict)  # empty is passed on as None
    cwd: str = ""  # empty is passed on as None
    # http config
    url: str = ""  # endpoint the HTTP transport posts to
    headers: dict[str, str] = field(default_factory=dict)  # extra HTTP headers; empty → None
712
+
713
+
714
+ class MCPClient:
715
+ """
716
+ MCP client managing multiple MCP server connections.
717
+
718
+ Discovers tools, resources, prompts from all servers.
719
+ Provides unified search, caching, and health monitoring.
720
+ """
721
+
def __init__(self, servers: list[MCPServerConfig] | None = None):
    """Create the client and, if possible, start connecting to *servers*.

    ``add_server`` is a coroutine, so it cannot simply be called here.
    When an event loop is already running, one connection task per server
    is scheduled (kept in ``_startup_tasks``; failures are logged).
    Without a running loop the configs are kept in ``_deferred_servers``
    for the caller to pass to ``await add_server(cfg)`` later.

    Args:
        servers: Optional server configurations to connect to.
    """
    self._connections: dict[str, MCPServerConnection] = {}
    self._tool_to_server: dict[str, str] = {}
    self._all_tools: list[dict[str, Any]] = []
    self._resource_cache: OrderedDict[str, tuple[Any, float]] = OrderedDict()
    self._resource_cache_max = 64
    self._resource_cache_ttl = 300.0
    self._health_task: asyncio.Task | None = None
    self._health_interval = 60.0
    self._health_running = False
    self._on_health_fail: Callable[[str], None] | None = None
    self._startup_tasks: list[asyncio.Task] = []
    self._deferred_servers: list[MCPServerConfig] = []

    if not servers:
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is None:
        # Nothing can be awaited here; creating the coroutines anyway would
        # only produce "coroutine was never awaited" warnings.
        self._deferred_servers.extend(servers)
        logger.warning(
            "MCPClient created outside a running event loop; %d server(s) "
            "not connected yet — call `await add_server(cfg)` for each",
            len(self._deferred_servers),
        )
        return

    for cfg in servers:
        def _report(task: asyncio.Task, server_name: str = cfg.name) -> None:
            # Retrieve the outcome so a failure is logged, not lost.
            if task.cancelled():
                return
            exc = task.exception()
            if exc is not None:
                logger.error("Failed to start MCP server '%s': %s", server_name, exc)

        task = loop.create_task(self.add_server(cfg))
        task.add_done_callback(_report)
        self._startup_tasks.append(task)
737
+
738
+ # ── Server lifecycle ───────────────────────────────────────────────────
739
+
async def add_server(self, config: MCPServerConfig) -> None:
    """Add and connect to an MCP server.

    A server already registered under ``config.name`` is removed (and its
    connection stopped) first, so it is not leaked.

    Raises:
        ValueError: If ``config.transport`` is neither "stdio" nor "http".
        Exception: Whatever starting the connection or registering its
            tools raises; the partial registration is rolled back first.
    """
    if config.transport == "stdio":
        conn = StdioMCPConnection(
            name=config.name,
            command=config.command,
            args=config.args,
            env=config.env or None,
            cwd=config.cwd or None,
        )
    elif config.transport == "http":
        conn = HTTPMCPConnection(
            name=config.name,
            url=config.url,
            headers=config.headers or None,
        )
    else:
        raise ValueError(f"Unknown transport: {config.transport}")

    if config.name in self._connections:
        logger.warning("MCP server '%s' already registered — replacing it", config.name)
        await self.remove_server(config.name)

    try:
        await conn.start()
        self._connections[config.name] = conn
        self._register_server_tools(config.name, conn)
        logger.info(
            "Added MCP server '%s': %d tools, %d resources, %d prompts",
            config.name, len(conn.tools), len(conn.resources), len(conn.prompts),
        )
    except Exception:
        # Roll back so a stopped connection and half-registered tools are
        # not left behind.
        if self._connections.get(config.name) is conn:
            del self._connections[config.name]
        self._all_tools = [
            t for t in self._all_tools if t.get("_mcp_server") != config.name
        ]
        self._tool_to_server = {
            k: v for k, v in self._tool_to_server.items() if v != config.name
        }
        try:
            await conn.stop()
        except Exception:
            logger.debug("Error while stopping MCP server '%s'", config.name, exc_info=True)
        raise
773
+
async def remove_server(self, name: str) -> None:
    """Disconnect and remove an MCP server.

    The server's tools and cached resources are dropped even when stopping
    the connection raises; the error still propagates afterwards.
    """
    conn = self._connections.pop(name, None)
    try:
        if conn:
            await conn.stop()
    finally:
        self._all_tools = [t for t in self._all_tools if t.get("_mcp_server") != name]
        self._tool_to_server = {k: v for k, v in self._tool_to_server.items() if v != name}
        # Purge cache entries from this server
        prefix = f"{name}:"
        self._resource_cache = OrderedDict(
            (k, v) for k, v in self._resource_cache.items()
            if not k.startswith(prefix)
        )
    logger.info("Removed MCP server '%s'", name)
787
+
async def stop_all(self) -> None:
    """Shut down health monitoring and every connection, then reset state.

    Connections are stopped one after another; a failure to stop one is
    ignored so the remaining ones are still stopped.
    """
    await self._stop_health_monitor()

    # Snapshot first: the registry is emptied below.
    for connection in tuple(self._connections.values()):
        try:
            await connection.stop()
        except Exception:
            continue

    for registry in (
        self._connections,
        self._all_tools,
        self._tool_to_server,
        self._resource_cache,
    ):
        registry.clear()

    logger.info("All MCP servers stopped")
801
+
802
+ # ── Tool registration ───────────────────────────────────────────────────
803
+
804
+ def _register_server_tools(self, server_name: str, conn: MCPServerConnection) -> None:
805
+ """Register tools from a server connection."""
806
+ for tool in conn.tools:
807
+ tool_name = tool["name"]
808
+ prefixed = f"mcp__{server_name}__{tool_name}"
809
+ if len(prefixed) > 64:
810
+ suffix = tool_name[-30:] if len(tool_name) > 30 else tool_name
811
+ prefixed = f"mcp__{server_name[:20]}__{suffix}"
812
+ logger.warning("MCP tool name truncated: %s", prefixed)
813
+ self._tool_to_server[prefixed] = server_name
814
+ tool["_mcp_server"] = server_name
815
+ tool["_mcp_original_name"] = tool_name
816
+ self._all_tools.append(tool)
817
+
818
+ def refresh_tools(self, server_name: str | None = None) -> None:
819
+ """Re-discover and re-register tools from one or all servers."""
820
+ names = [server_name] if server_name else list(self._connections)
821
+ for name in names:
822
+ conn = self._connections.get(name)
823
+ if not conn:
824
+ continue
825
+ # Remove old tools for this server
826
+ self._all_tools = [t for t in self._all_tools if t.get("_mcp_server") != name]
827
+ self._tool_to_server = {k: v for k, v in self._tool_to_server.items() if v != name}
828
+ # Re-discover and register
829
+ conn.discover()
830
+ self._register_server_tools(name, conn)
831
+
832
+ # ── Tool access ─────────────────────────────────────────────────────────
833
+
834
+ def get_tools(self) -> list[dict[str, Any]]:
835
+ """Get all tools as OpenAI function tool definitions."""
836
+ openai_tools = []
837
+ for tool in self._all_tools:
838
+ server = tool.get("_mcp_server", "?")
839
+ original = tool.get("_mcp_original_name", tool.get("name", "?"))
840
+ openai_tools.append({
841
+ "type": "function",
842
+ "function": {
843
+ "name": f"mcp__{server}__{original}",
844
+ "description": tool.get("description", f"MCP tool: {tool['name']}"),
845
+ "parameters": tool.get("inputSchema", {
846
+ "type": "object", "properties": {},
847
+ }),
848
+ },
849
+ })
850
+ return openai_tools
851
+
852
+ async def call_tool(self, prefixed_name: str, arguments: dict[str, Any]) -> Any:
853
+ """Call an MCP tool by its prefixed name."""
854
+ server_name = self._tool_to_server.get(prefixed_name)
855
+ if not server_name:
856
+ raise ValueError(f"Unknown MCP tool: {prefixed_name}")
857
+
858
+ conn = self._connections.get(server_name)
859
+ if not conn:
860
+ raise RuntimeError(f"MCP server not connected: {server_name}")
861
+
862
+ for tool in self._all_tools:
863
+ srv = tool.get("_mcp_server")
864
+ original = tool.get("_mcp_original_name")
865
+ if srv == server_name and f"mcp__{srv}__{original}" == prefixed_name:
866
+ return await conn.call_tool(tool["_mcp_original_name"], arguments)
867
+
868
+ raise ValueError(f"Tool not found: {prefixed_name}")
869
+
870
def is_mcp_tool(self, tool_name: str) -> bool:
    """Return True when ``tool_name`` is a registered ``mcp__``-prefixed tool."""
    if not tool_name.startswith("mcp__"):
        return False
    return tool_name in self._tool_to_server
872
+
873
+ # ── Prompts ─────────────────────────────────────────────────────────────
874
+
875
def list_prompts(self) -> list[dict[str, Any]]:
    """Collect the prompts of every connected server.

    Returns:
        Copies of the prompt dicts, each extended with an ``_mcp_server``
        key holding the connection-table key of its server.
    """
    return [
        {**prompt, "_mcp_server": server_key}
        for server_key, connection in self._connections.items()
        for prompt in connection.prompts
    ]
882
+
883
async def get_prompt(self, server: str, prompt_name: str,
                     arguments: dict[str, str] | None = None) -> Any:
    """Fetch a prompt from one named server.

    Args:
        server: Connection-table key of the server to ask.
        prompt_name: Prompt name, forwarded to the connection.
        arguments: Optional prompt arguments, forwarded unchanged.

    Raises:
        ValueError: If no connection is stored under ``server``.
    """
    target = self._connections.get(server)
    if target:
        return await target.get_prompt(prompt_name, arguments)
    raise ValueError(f"Server not found: {server}")
890
+
891
+ # ── Search ──────────────────────────────────────────────────────────────
892
+
893
def search_tools(self, query: str, limit: int = 20) -> list[dict[str, Any]]:
    """Search MCP tools across all servers by case-insensitive substring.

    Ranking: exact name match (3), name starts with the query (2), name
    contains the query (1), only the description contains it (0). Ties
    are ordered by tool name.

    Args:
        query: Search text; surrounding whitespace is ignored.
        limit: Maximum number of tools returned.

    Returns:
        The matching registered tool dicts, best match first; empty for a
        blank query.
    """
    q = query.lower().strip()
    if not q:
        return []

    scored: list[tuple[int, dict[str, Any]]] = []
    for tool in self._all_tools:
        # `or ""` also covers an explicit null from the server, which a
        # .get() default would not.
        name_l = (tool.get("name") or "").lower()
        if name_l == q:
            score = 3
        elif name_l.startswith(q):
            score = 2
        elif q in name_l:
            score = 1
        elif q in (tool.get("description") or "").lower():
            score = 0
        else:
            continue
        scored.append((score, tool))

    scored.sort(key=lambda x: (-x[0], x[1].get("name") or ""))
    return [t for _, t in scored[:limit]]
919
+
920
def search_resources(self, query: str, limit: int = 20) -> list[dict[str, Any]]:
    """Search MCP resources across all servers.

    A resource matches when the query is a case-insensitive substring of
    its ``uri``, ``name`` or ``description``.

    Args:
        query: Search text; surrounding whitespace is ignored.
        limit: Maximum number of resources returned.

    Returns:
        Copies of the matching resource dicts with an added
        ``_mcp_server`` key, sorted by name (URI when there is no name);
        empty for a blank query.
    """
    q = query.lower().strip()
    if not q:
        return []

    results: list[dict[str, Any]] = []
    for conn in self._connections.values():
        for res in conn.resources:
            # `or ""` tolerates fields that are present but null.
            haystacks = (
                (res.get(field) or "").lower()
                for field in ("uri", "name", "description")
            )
            if any(q in text for text in haystacks):
                results.append({**res, "_mcp_server": conn.name})
    results.sort(key=lambda r: r.get("name") or r.get("uri") or "")
    return results[:limit]
936
+
937
def get_all_resources(self) -> list[dict[str, Any]]:
    """List every resource discovered on any connected server.

    Returns:
        Copies of the resource dicts, each tagged with ``_mcp_server`` (the
        connection's ``name``), ordered by resource name with the URI as
        the fallback sort key.
    """
    collected = [
        {**entry, "_mcp_server": server.name}
        for server in self._connections.values()
        for entry in server.resources
    ]
    return sorted(collected, key=lambda item: item.get("name", item.get("uri", "")))
945
+
946
+ # ── Resource cache ──────────────────────────────────────────────────────
947
+
948
def cached_read_resource(self, uri: str) -> dict[str, Any]:
    """Read a resource through the LRU + TTL cache.

    A fresh cache entry is returned directly (with an empty ``server``
    field). Otherwise the owning server is looked up - first by exact URI
    among the listed resources of every server, then by the literal prefix
    of each resource template - the resource is read from it, and the
    result is cached, evicting the oldest entry when the cache is full.

    Returns:
        ``{"content": ..., "cached": bool, "server": str}``.

    Raises:
        ValueError: If no server lists the URI or has a matching template.
    """
    stamp = time.time()

    if uri in self._resource_cache:
        cached_content, stored_at = self._resource_cache[uri]
        if stamp - stored_at < self._resource_cache_ttl:
            # Mark as most recently used.
            self._resource_cache.move_to_end(uri)
            return {"content": cached_content, "cached": True, "server": ""}
        # Expired: drop it and fall through to a real read.
        del self._resource_cache[uri]

    def find_owner():
        servers = self._connections.values()
        # Exact listings on any server win over template matches.
        for server in servers:
            for listed in server.resources:
                if listed.get("uri") == uri:
                    return server
        for server in servers:
            for template in server.resource_templates:
                # Simple match: the text before the first "{" (the whole
                # template when it has no placeholder) must prefix the URI.
                literal = template.get("uriTemplate", "").split("{")[0]
                if uri.startswith(literal):
                    return server
        return None

    owner = find_owner()
    if owner is None:
        raise ValueError(f"Resource not found on any server: {uri}")

    # NOTE(review): read_resource is called without await - confirm it is a
    # synchronous method on the connection classes.
    response = owner.read_resource(uri)
    content = response.get("contents", response)
    if len(self._resource_cache) >= self._resource_cache_max:
        self._resource_cache.popitem(last=False)
    self._resource_cache[uri] = (content, stamp)
    return {"content": content, "cached": False, "server": owner.name}
984
+
985
def invalidate_resource_cache(self, uri: str | None = None) -> None:
    """Invalidate cached resources.

    Args:
        uri: Cache key to drop. ``None`` (the default) empties the whole
            cache; any string, including ``""``, only drops that one key
            and is a no-op when the key is not cached.
    """
    # Identity check, not truthiness: an empty URI must not wipe the cache.
    if uri is not None:
        self._resource_cache.pop(uri, None)
    else:
        self._resource_cache.clear()
991
+
992
+ # ── Health monitoring ───────────────────────────────────────────────────
993
+
994
def on_health_fail(self, callback: Callable[[str], None]) -> None:
    """Install the health-failure callback.

    The callback replaces any previously installed one and is stored as
    ``_on_health_fail``; the health loop calls it with the server name.
    """
    self._on_health_fail = callback
997
+
998
def start_health_monitor(self, interval: float = 60.0) -> None:
    """Start periodic health checks (ping every N seconds).

    Does nothing when the monitor is already running. Must be called from
    within a running asyncio event loop, because the monitor runs as a
    task on that loop.

    Args:
        interval: Seconds to sleep between two rounds of pings.

    Raises:
        RuntimeError: If there is no running event loop. No state is
            changed in that case, so a later call can still start the
            monitor.
    """
    if self._health_running:
        return
    # Resolve the loop before touching any state: failing after the flag is
    # set would make every later call return early without a task.
    loop = asyncio.get_running_loop()
    self._health_interval = interval
    self._health_task = loop.create_task(self._health_loop())
    # The task cannot run before this function returns to the loop, so the
    # flag is in place when the loop body first checks it.
    self._health_running = True
    logger.info("MCP health monitor started (interval=%.0fs)", interval)
1006
+
1007
+ async def _stop_health_monitor(self) -> None:
1008
+ self._health_running = False
1009
+ if self._health_task and not self._health_task.done():
1010
+ self._health_task.cancel()
1011
+ try:
1012
+ await self._health_task
1013
+ except asyncio.CancelledError:
1014
+ pass
1015
+ self._health_task = None
1016
+
1017
+ async def _health_loop(self) -> None:
1018
+ while self._health_running:
1019
+ await asyncio.sleep(self._health_interval)
1020
+ if not self._health_running:
1021
+ break
1022
+ for name, conn in list(self._connections.items()):
1023
+ try:
1024
+ if not await conn.ping(timeout=10):
1025
+ logger.warning("[%s] Health check failed: no response", name)
1026
+ if self._on_health_fail:
1027
+ self._on_health_fail(name)
1028
+ except asyncio.CancelledError:
1029
+ raise
1030
+ except Exception as e:
1031
+ logger.warning("[%s] Health check error: %s", name, e)
1032
+ if self._on_health_fail:
1033
+ self._on_health_fail(name)
1034
+
1035
+ # ── Properties ──────────────────────────────────────────────────────────
1036
+
1037
@property
def connected_servers(self) -> list[str]:
    """Names of the servers currently in the connection table, in insertion order."""
    return [*self._connections]
1040
+
1041
@property
def tool_count(self) -> int:
    """Number of tool entries currently registered across all servers."""
    registered = self._all_tools
    return len(registered)
1044
+
1045
@property
def resource_count(self) -> int:
    """Total number of resources listed by all connected servers."""
    total = 0
    for connection in self._connections.values():
        total += len(connection.resources)
    return total
1048
+
1049
+
1050
+ # ═══════════════════════════════════════════════════════════════════════════════
1051
+ # MCP config file support
1052
+ # ═══════════════════════════════════════════════════════════════════════════════
1053
+
1054
def load_mcp_config(config_path: str | Path) -> list[MCPServerConfig]:
    """
    Load MCP server configurations from a JSON file.

    Every entry under the top-level ``"mcpServers"`` object becomes one
    ``MCPServerConfig`` named after its key. A missing or ``null``
    ``"mcpServers"`` yields an empty list, and ``null`` field values are
    treated like absent ones. ``transport`` defaults to ``"stdio"``.

    Example config.json:
    {
        "mcpServers": {
            "filesystem": {
                "transport": "stdio",
                "command": "npx",
                "args": ["-y", "@anthropic/mcp-filesystem", "/path/to/allowed"]
            },
            "github": {
                "transport": "stdio",
                "command": "npx",
                "args": ["-y", "@anthropic/mcp-github"],
                "env": {"GITHUB_TOKEN": "ghp_xxx"}
            },
            "remote-api": {
                "transport": "http",
                "url": "https://mcp.example.com/mcp",
                "headers": {"Authorization": "Bearer xxx"}
            }
        }
    }

    Args:
        config_path: Path of the UTF-8 encoded JSON file.

    Returns:
        One ``MCPServerConfig`` per configured server, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    servers = []
    # `or {}` / `or []` / `or ""`: JSON null must behave like a missing key.
    for name, cfg in (data.get("mcpServers") or {}).items():
        servers.append(MCPServerConfig(
            name=name,
            transport=cfg.get("transport", "stdio"),
            command=cfg.get("command") or "",
            args=cfg.get("args") or [],
            env=cfg.get("env") or {},
            cwd=cfg.get("cwd") or "",
            url=cfg.get("url") or "",
            headers=cfg.get("headers") or {},
        ))
    return servers