zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/main.py ADDED
@@ -0,0 +1,210 @@
1
+ import typer
2
+ from typing import Optional
3
+ from typer.core import TyperGroup
4
+ from rich.console import Console
5
+ from rich.panel import Panel
6
+ from rich.prompt import Prompt, Confirm
7
+
8
+ from .config import load_config
9
+ from .core.fallback import (
10
+ format_model_selection,
11
+ has_available_provider,
12
+ stream_with_fallback,
13
+ )
14
+ from .core.context import ContextManager
15
+ from .core.memory import save_session, get_last_session
16
+ from .core.errors import show_error, AllModelsFailedError
17
+ from .core.runtime import configure as configure_runtime
18
+ from .core.runtime import plain_enabled, print_exception
19
+ from .cli.settings import register_settings_commands
20
+ from .cli.integrations import register_integration_commands
21
+ from .cli.utilities import register_utility_commands
22
+ from .cli.workflows import register_workflow_commands
23
+ from .cli.common import closest_command as _closest_command
24
+ from .cli.interactive import run_interactive
25
+
26
+
27
class FuzzyTyperGroup(TyperGroup):
    """Accept safe, unambiguous typos in top-level command names."""

    def get_command(self, ctx, cmd_name):
        """Resolve *cmd_name*, falling back to the closest registered command.

        Returns the exact match when there is one; otherwise the command
        suggested by ``closest_command`` (after printing a notice), or None
        when no correction is available.
        """
        command = super().get_command(ctx, cmd_name)
        if command is not None:
            return command

        corrected = _closest_command(cmd_name, self.list_commands(ctx))
        if not corrected:
            return None
        # The notice goes to stderr so it never mixes into command output that
        # may be piped elsewhere (e.g. `--plain` one-shot answers on stdout).
        Console(stderr=True).print(
            f"[yellow]Command '{cmd_name}' not found; using '{corrected}'.[/yellow]"
        )
        return super().get_command(ctx, corrected)
40
+
41
+
42
app = typer.Typer(
    help="zai - Your personal AI CLI. Free. Fast. Smart.",
    cls=FuzzyTyperGroup,
)
register_settings_commands(app)
register_integration_commands(app)
import io
import sys


def _ensure_utf8_stream(stream_name: str) -> None:
    """Switch ``sys.<stream_name>`` to UTF-8 (errors="replace") if it uses another encoding.

    Prefers ``TextIOWrapper.reconfigure`` so the existing stream object and
    its buffering mode are kept; falls back to re-wrapping the binary buffer.
    A stream that is missing (``None``, e.g. under a windowless interpreter),
    has no encoding, or has no buffer to wrap is left untouched instead of
    crashing the import.
    """
    stream = getattr(sys, stream_name, None)
    encoding = getattr(stream, "encoding", None)
    if not encoding or encoding.lower().replace("_", "-") in ("utf-8", "utf8"):
        return
    reconfigure = getattr(stream, "reconfigure", None)
    if callable(reconfigure):
        reconfigure(encoding="utf-8", errors="replace")
        return
    buffer = getattr(stream, "buffer", None)
    if buffer is not None:
        setattr(
            sys,
            stream_name,
            io.TextIOWrapper(buffer, encoding="utf-8", errors="replace"),
        )


# Output contains non-ASCII characters (e.g. the "──" model banner); force
# UTF-8 so legacy console code pages cannot raise UnicodeEncodeError.
# Each stream is checked on its own, since their encodings can differ.
_ensure_utf8_stream("stdout")
_ensure_utf8_stream("stderr")
console = Console()
53
+
54
# Conversation state shared by every command in this process.
context = ContextManager()

# System prompt sent with every chat turn. It asks the model to announce each
# file with a "**File:** name" header followed by a fenced code block, which
# is one of the layouts _parse_files_from_response recognises.
SYSTEM_PROMPT = "\n".join(
    [
        "You are zai, a smart AI assistant running in the terminal. "
        "Be concise, helpful, and direct. "
        "IMPORTANT: When creating files, ALWAYS use this exact format for each file:",
        "**File:** filename.ext",
        "```lang",
        "...code...",
        "```",
        "Use this format every time so files can be automatically saved to disk.",
    ]
)

# Commands that need the shared context and prompt are registered last.
register_utility_commands(app, context, SYSTEM_PROMPT)
register_workflow_commands(app, context, SYSTEM_PROMPT)
66
+
67
+
68
import re

# Info string allowed after an opening ``` fence: word characters plus the
# punctuation used by common language tags such as ``c++``, ``c#`` or
# ``objective-c`` (the previous ``[\w]*`` silently skipped those blocks).
_FENCE_LANG = r"[\w+#.-]*"
# Capturing group for a path-like token ending in an extension (``src/app.py``).
_FILENAME_GROUP = r"([\w./\\-]+\.\w+)"

# Every pattern captures (filename, code); listed in priority order.
_FILE_PATTERNS = (
    # 1: **Anything File:** filename.ext  (e.g. **HTML File:** index.html)
    re.compile(
        r"\*\*[\w\s]*[Ff]ile[\w\s]*:\*\*\s*" + _FILENAME_GROUP
        + r"[^\n]*\n```" + _FENCE_LANG + r"\n(.*?)```",
        re.DOTALL,
    ),
    # 2: any bold label, e.g. **File Name:** or **Filename:**
    re.compile(
        r"\*\*[\w\s]*:\*\*\s*" + _FILENAME_GROUP
        + r"\s*\n```" + _FENCE_LANG + r"\n(.*?)```",
        re.DOTALL,
    ),
    # 3: filename in a leading comment: ```lang\n# f / // f / <!-- f -->
    re.compile(
        r"```[\w+#.-]+\n(?:#|//|<!--)\s*" + _FILENAME_GROUP + r"[^\n]*\n(.*?)```",
        re.DOTALL,
    ),
    # 4: filename.ext alone on its line (optionally backticked) before a fence
    re.compile(
        r"(?:^|\n)`?" + _FILENAME_GROUP + r"`?\s*\n```" + _FENCE_LANG + r"\n(.*?)```",
        re.DOTALL,
    ),
)


def _parse_files_from_response(content: str) -> dict:
    """Extract filename→code pairs from AI response.

    Recognises four layouts (see ``_FILE_PATTERNS``). For the explicit
    "**... File:** name" layout a later block with the same name replaces an
    earlier one; the looser layouts only add names not already found. Code is
    returned with surrounding whitespace stripped.

    NOTE(review): filenames may contain ``..`` or absolute paths; whoever
    writes them to disk must confine them to the intended directory.
    """
    files: dict = {}
    primary, *fallbacks = _FILE_PATTERNS
    for match in primary.finditer(content):
        files[match.group(1).strip()] = match.group(2).strip()
    for pattern in fallbacks:
        for match in pattern.finditer(content):
            files.setdefault(match.group(1).strip(), match.group(2).strip())
    return files
100
+
101
+
102
def _chat(message: str, model: Optional[str] = None) -> str:
    """Send *message* through the shared context and return the reply text.

    Args:
        message: The user's turn.
        model: Preferred model; defaults to ``config["default_model"]``.
            ``stream_with_fallback`` may answer with a different model.

    Returns:
        The assistant reply, or "" when no provider is available or the
        request fails (the error is shown to the user).
    """
    config = load_config()
    preferred = model or config["default_model"]
    context.set_model(preferred)

    if not has_available_provider():
        show_error("No AI provider available. Run: zai setup, or start Ollama.")
        return ""

    context.add("user", message)

    if context.is_near_limit():
        console.print("[yellow]Context compressing...[/yellow]")
        context.compress(preferred)

    try:
        content, used_model = stream_with_fallback(
            context.get_messages(),
            system=SYSTEM_PROMPT,
            preferred=preferred,
        )
    except AllModelsFailedError as e:
        # NOTE(review): the user turn added above stays in the context without
        # a reply; confirm whether ContextManager should drop it here.
        show_error(str(e))
        return ""
    except Exception as e:
        print_exception(console, e)
        return ""

    context.add("assistant", content)

    # Plain mode emits only the answer: no model banner or token usage.
    if not plain_enabled():
        label = format_model_selection(used_model)
        if used_model != preferred:
            console.print(f"[dim]Switched to: {label}[/dim]")
        console.print(f"[dim]── {label} ──[/dim]")
        if config.get("show_token_count"):
            console.print(f"[dim]Tokens: {context.usage_bar()}[/dim]")

    # The session records the raw model identifier, not the display label.
    save_session(task=message[:80], model=used_model)
    return content
149
+
150
+
151
@app.command()
def chat(
    message: Optional[str] = typer.Argument(None, help="Message to send"),
    model: Optional[str] = typer.Option(None, "--model", "-m", help="Model to use"),
):
    """Chat with AI. Supports auto model fallback."""
    # One-shot mode: a message on the command line gets a single reply.
    if message:
        _chat(message, model)
        return

    last = get_last_session()
    if last:
        console.print(f"[dim]Last session: {last['task']} ({last['date']})[/dim]")
        try:
            resume = Confirm.ask("Continue last session?", default=False)
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C / closed stdin here previously escaped as a traceback;
            # treat it like exiting the chat loop below.
            console.print("\n[dim]Goodbye![/dim]")
            return
        if resume:
            context.add("assistant", f"Continuing from: {last['task']}")

    console.print(Panel("[bold cyan]zai[/bold cyan] — type your message. [dim]Ctrl+C to exit.[/dim]"))
    while True:
        try:
            msg = Prompt.ask("[cyan]you[/cyan]")
            if msg.strip():
                _chat(msg, model)
        except (KeyboardInterrupt, EOFError):
            console.print("\n[dim]Goodbye![/dim]")
            break
176
+
177
+
178
@app.command("ask")
def ask(
    message: str = typer.Argument(..., help="Message to send to AI"),
    model: Optional[str] = typer.Option(None, "--model", "-m"),
):
    """Send a quick message to AI (shortcut for chat)."""
    # Same single-turn path that `chat MESSAGE` takes; the reply text itself
    # is not needed here, so the return value is discarded.
    _chat(message=message, model=model)
185
+
186
+
187
@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    model: Optional[str] = typer.Option(None, "--model", "-m"),
    version: bool = typer.Option(False, "--version", "-v"),
    debug: bool = typer.Option(False, "--debug", help="Show full tracebacks"),
    plain: bool = typer.Option(
        False,
        "--plain",
        help="Emit plain one-shot output without streaming UI or metadata",
    ),
):
    """zai - Your personal AI CLI."""
    # Apply global flags first so any subcommand runs with them in effect.
    configure_runtime(debug=debug, plain=plain)
    if version:
        from . import __version__

        console.print(f"zai v{__version__}")
        # A plain `return` let `zai --version <command>` go on to run the
        # subcommand; exiting stops processing after printing the version.
        raise typer.Exit()
    if ctx.invoked_subcommand is None:
        # NOTE(review): this group-level --model only reaches interactive
        # mode; subcommands read their own -m option.
        run_interactive(model)
207
+
208
+
209
# Allow running this module directly as a script.
if __name__ == "__main__":
    app()
zai/mcp/__init__.py ADDED
File without changes
zai/mcp/client.py ADDED
@@ -0,0 +1,431 @@
1
+ """MCP JSON-RPC 2.0 client over the standard stdio transport."""
2
+ from __future__ import annotations
3
+
4
+ import atexit
5
+ import json
6
+ import os
7
+ import shlex
8
+ import subprocess
9
+ import threading
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from typing import Callable, Optional
13
+
14
+ from rich.console import Console
15
+ from ..core.cancellation import current_token
16
+
17
# Module-level console for connection status and error messages.
console = Console()
18
+
19
# MCP protocol revision this client requests during `initialize`.
PROTOCOL_VERSION = "2025-06-18"
# Revisions the server may answer with that this client accepts. 2025-03-26
# (between the two previously listed) was rejected before although the
# requests used here (initialize, tools/list, tools/call, progress and
# cancellation notifications) are defined in it as well.
SUPPORTED_PROTOCOL_VERSIONS = {PROTOCOL_VERSION, "2025-03-26", "2024-11-05"}
# Seconds a request waits for its response before it is abandoned.
DEFAULT_REQUEST_TIMEOUT = 30.0

# Connected clients keyed by server name; maintained by connect()/disconnect().
_active_clients: dict[str, "MCPClient"] = {}
24
+
25
+
26
class MCPError(RuntimeError):
    """Raised for MCP transport, protocol or server-reported failures."""
28
+
29
+
30
class MCPTimeoutError(MCPError):
    """Raised when an MCP request gets no response before its deadline."""
32
+
33
+
34
@dataclass
class _PendingRequest:
    """Hand-off slot between a caller awaiting a response and the reader thread.

    The reader stores the response message and sets ``event``. If the event
    is set while ``response`` is still None, the connection went away.
    """

    # Signalled once a response arrives or pending requests are failed.
    event: threading.Event = field(default_factory=threading.Event)
    # The raw JSON-RPC response message, filled in by the reader thread.
    response: Optional[dict] = field(default=None)
38
+
39
+
40
class MCPClient:
    """Connect to one MCP server process using newline-delimited JSON-RPC.

    ``start()`` spawns the server (never through a shell), performs the
    ``initialize`` handshake and, if the server advertises tools, fetches the
    tool list. A background reader thread hands responses to waiting
    ``_request`` callers by id and handles server-initiated messages; a
    second thread drains stderr so the child cannot block on a full pipe.
    """

    def __init__(
        self,
        name: str,
        command: str | list[str],
        env: dict | None = None,
        request_timeout: float = DEFAULT_REQUEST_TIMEOUT,
        progress_callback: Callable[[dict], None] | None = None,
    ):
        self.name = name
        self.command = command
        # Extra environment variables layered over os.environ at spawn time.
        self.env = env or {}
        self.request_timeout = request_timeout
        self.progress_callback = progress_callback
        self.process: Optional[subprocess.Popen] = None
        self._id = 0
        self._tools: list[dict] = []
        # In-flight requests keyed by JSON-RPC id; guarded by _pending_lock.
        self._pending: dict[int, _PendingRequest] = {}
        self._pending_lock = threading.Lock()
        # Serialises writes so concurrent messages never interleave on stdin.
        self._write_lock = threading.Lock()
        self._reader_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        # Reason the reader thread stopped, when it stopped on an exception.
        self._reader_error: str | None = None
        self.server_info: dict = {}
        self.server_capabilities: dict = {}
        self.protocol_version: str | None = None
        self.ready = False

    @staticmethod
    def _unquote(token: str) -> str:
        """Drop one pair of matching outer quotes left in place by non-POSIX shlex."""
        if len(token) >= 2 and token[0] == token[-1] and token[0] in "\"'":
            return token[1:-1]
        return token

    @staticmethod
    def _error_text(error, default: str) -> str:
        """Return the ``message`` of a JSON-RPC error object, or *default*."""
        if isinstance(error, dict):
            return str(error.get("message", default))
        return str(error) if error else default

    def _argv(self) -> list[str]:
        """Build the argv list passed to ``subprocess.Popen``.

        A list command is used as-is; a string is split shell-style. On
        Windows, npm/npx (matched by the names below, including their
        ``.cmd`` forms) are run through ``cmd /c``.
        """
        if isinstance(self.command, list):
            argv = list(self.command)
        elif os.name == "nt":
            # Non-POSIX splitting keeps backslashes in Windows paths but also
            # keeps surrounding quotes, which Popen would then re-quote.
            argv = [
                self._unquote(token)
                for token in shlex.split(self.command, posix=False)
            ]
        else:
            argv = shlex.split(self.command)
        if (
            os.name == "nt"
            and argv
            and argv[0].lower() in {"npx", "npx.cmd", "npm", "npm.cmd"}
        ):
            return ["cmd", "/c", *argv]
        return argv

    def start(self) -> bool:
        """Spawn the server and complete the MCP handshake.

        Returns True once the server is initialised (and its tools fetched,
        if advertised). On any failure the error is printed, the process is
        stopped and False is returned.
        """
        env = {**os.environ, **self.env}
        try:
            self.process = subprocess.Popen(
                self._argv(),
                shell=False,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="replace",
                bufsize=1,
                env=env,
            )
            self._reader_thread = threading.Thread(
                target=self._read_loop,
                name=f"zai-mcp-{self.name}",
                daemon=True,
            )
            self._reader_thread.start()
            self._stderr_thread = threading.Thread(
                target=self._drain_stderr,
                name=f"zai-mcp-stderr-{self.name}",
                daemon=True,
            )
            self._stderr_thread.start()

            response = self._request("initialize", {
                "protocolVersion": PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": {"name": "zai", "version": "0.1.0"},
            })
            if "error" in response:
                # Report the server's reason rather than the misleading
                # "unsupported protocol version: missing".
                raise MCPError(
                    "initialize failed: "
                    + self._error_text(response["error"], "unknown error")
                )
            result = response.get("result") or {}
            selected = result.get("protocolVersion")
            if selected not in SUPPORTED_PROTOCOL_VERSIONS:
                raise MCPError(
                    f"unsupported protocol version: {selected or 'missing'}"
                )
            self.protocol_version = selected
            # `or {}` also covers servers that send null for these fields.
            self.server_capabilities = result.get("capabilities") or {}
            self.server_info = result.get("serverInfo") or {}
            self._notify("notifications/initialized", {})

            if "tools" in self.server_capabilities:
                self._tools = self._fetch_tools()
            self.ready = True
            console.print(
                f"[green]MCP connected:[/green] {self.name} "
                f"({len(self._tools)} tools, {self.protocol_version})"
            )
            return True
        except Exception as error:
            console.print(f"[red]MCP error ({self.name}): {error}[/red]")
            self.stop()
            return False

    def _next_id(self) -> int:
        """Return a fresh request id (monotonically increasing per client)."""
        with self._pending_lock:
            self._id += 1
            return self._id

    def _write(self, message: dict) -> None:
        """Serialise *message* as one JSON line on the server's stdin.

        Raises MCPError if the process is gone or the write fails.
        """
        # Read the attribute once: stop() may set it to None concurrently.
        process = self.process
        if not process or process.poll() is not None or not process.stdin:
            raise MCPError("server process is not running")
        payload = json.dumps(message, separators=(",", ":"))
        with self._write_lock:
            try:
                process.stdin.write(payload + "\n")
                process.stdin.flush()
            except (OSError, ValueError) as error:
                # BrokenPipeError when the server exited, ValueError when
                # stop() already closed stdin; callers only handle MCPError.
                raise MCPError(f"failed to write to server: {error}") from error

    def _notify(self, method: str, params: dict | None = None) -> None:
        """Send a JSON-RPC notification (no id, no response expected)."""
        self._write({
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        })

    def _cancel_quietly(self, request_id: int, reason) -> None:
        """Best-effort ``notifications/cancelled`` for *request_id*."""
        try:
            self._notify("notifications/cancelled", {
                "requestId": request_id,
                "reason": reason,
            })
        except MCPError:
            pass  # the server is unreachable; there is nothing left to cancel

    def _request(
        self,
        method: str,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        """Send a request and block until its response arrives.

        Returns the raw response message, which may carry an ``error``.

        Raises:
            MCPTimeoutError: no response within *timeout* seconds
                (default ``self.request_timeout``).
            MCPError: the current cancellation token fired, the write failed,
                or the connection closed before a response arrived.

        On timeout or cancellation the server is sent a best-effort
        ``notifications/cancelled`` first.
        """
        request_id = self._next_id()
        pending = _PendingRequest()
        with self._pending_lock:
            self._pending[request_id] = pending
        try:
            self._write({
                "jsonrpc": "2.0",
                "id": request_id,
                "method": method,
                "params": params or {},
            })
            wait_for = self.request_timeout if timeout is None else timeout
            deadline = time.monotonic() + max(wait_for, 0.0)
            # Wait in short slices so a cancellation is noticed promptly.
            while not pending.event.is_set():
                token = current_token()
                if token and token.cancelled:
                    self._cancel_quietly(request_id, token.reason)
                    raise MCPError(token.reason)
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                pending.event.wait(min(0.05, remaining))
            if not pending.event.is_set():
                self._cancel_quietly(
                    request_id, f"zai timed out after {wait_for:g}s"
                )
                raise MCPTimeoutError(
                    f"{method} timed out after {wait_for:g}s"
                )
            if pending.response is None:
                # Woken by _fail_pending: the connection is gone.
                raise MCPError(self._reader_error or "server connection closed")
            return pending.response
        finally:
            with self._pending_lock:
                self._pending.pop(request_id, None)

    # Compatibility for callers/tests that used the old method name.
    def _send(
        self,
        method: str,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        return self._request(method, params, timeout)

    def _read_loop(self) -> None:
        """Reader thread: parse stdout lines and route each message.

        Undecodable lines and non-object messages are skipped. When stdout
        closes or reading fails, the client is marked not ready and every
        waiting request is woken so it can fail.
        """
        try:
            process = self.process
            if not process or not process.stdout:
                return
            for line in process.stdout:
                if not line.strip():
                    continue
                try:
                    decoded = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # JSON-RPC batches (allowed by older protocol revisions)
                # arrive as arrays; anything that is not an object is ignored.
                messages = decoded if isinstance(decoded, list) else [decoded]
                for message in messages:
                    if isinstance(message, dict):
                        self._route(message)
        except Exception as error:
            self._reader_error = str(error)
        finally:
            self.ready = False
            self._fail_pending()

    def _route(self, message: dict) -> None:
        """Deliver a response to its waiting request, or handle a server message."""
        if "id" in message and ("result" in message or "error" in message):
            with self._pending_lock:
                try:
                    pending = self._pending.get(message["id"])
                except TypeError:
                    # Unhashable id from a misbehaving server: cannot match.
                    pending = None
            if pending:
                pending.response = message
                pending.event.set()
            return
        self._handle_server_message(message)

    def _drain_stderr(self) -> None:
        """Drain server logs so a full stderr pipe cannot block the process."""
        try:
            if not self.process or not self.process.stderr:
                return
            for _line in self.process.stderr:
                pass
        except Exception:
            pass

    def _respond(self, request_id, payload: dict) -> None:
        """Best-effort reply to a server-initiated request."""
        try:
            self._write({"jsonrpc": "2.0", "id": request_id, **payload})
        except MCPError:
            pass  # the server went away; nobody is waiting for the reply

    def _handle_server_message(self, message: dict) -> None:
        """Handle a notification or request sent by the server."""
        method = message.get("method")
        params = message.get("params") or {}
        if method == "notifications/progress":
            if self.progress_callback:
                try:
                    self.progress_callback(params)
                except Exception as error:
                    # A faulty callback must not kill the reader thread, which
                    # would strand every in-flight request.
                    console.print(
                        f"[yellow]MCP progress callback failed ({self.name}): "
                        f"{error}[/yellow]"
                    )
            return
        if method == "ping" and "id" in message:
            # The spec requires answering ping with an empty result; replying
            # "method not found" can make servers treat the client as dead.
            self._respond(message["id"], {"result": {}})
            return
        if method == "notifications/tools/list_changed" and self.ready:
            # Avoid making a blocking request on the reader thread.
            threading.Thread(
                target=self._refresh_tools,
                name=f"zai-mcp-refresh-{self.name}",
                daemon=True,
            ).start()
            return
        if "id" in message and method:
            self._respond(message["id"], {
                "error": {
                    "code": -32601,
                    "message": f"Client method not supported: {method}",
                },
            })

    def _fail_pending(self) -> None:
        """Wake every waiting request; with no response set, each one fails."""
        with self._pending_lock:
            pending_requests = list(self._pending.values())
        for pending in pending_requests:
            pending.event.set()

    def _refresh_tools(self) -> None:
        """Re-fetch the tool list after a list_changed notification (best effort)."""
        try:
            self._tools = self._fetch_tools()
        except MCPError:
            pass  # keep the previous list if the refresh fails

    def _fetch_tools(self) -> list[dict]:
        """Collect all pages of ``tools/list``.

        Stops early if the server repeats a cursor, which would otherwise
        loop forever. Raises MCPError if the server returns an error.
        """
        tools: list[dict] = []
        seen_cursors: set[str] = set()
        cursor = None
        while True:
            params = {"cursor": cursor} if cursor else {}
            response = self._request("tools/list", params)
            if "error" in response:
                raise MCPError(
                    self._error_text(response["error"], "tools/list failed")
                )
            result = response.get("result") or {}
            tools.extend(result.get("tools") or [])
            cursor = result.get("nextCursor")
            if not cursor:
                return tools
            marker = json.dumps(cursor, sort_keys=True)
            if marker in seen_cursors:
                return tools
            seen_cursors.add(marker)

    def call_tool(
        self,
        tool_name: str,
        arguments: dict,
        timeout: float | None = None,
    ) -> str:
        """Invoke a server tool and return its output as text.

        Transport and protocol failures are returned as strings starting with
        "MCP timeout", "MCP error" or "MCP tool error" rather than raised.
        Text content items are joined with newlines. When there are none,
        ``structuredContent`` is returned as JSON, or "Done" if it is absent.
        """
        if not self.ready:
            return f"MCP server '{self.name}' not ready"
        if "tools" not in self.server_capabilities:
            return f"MCP server '{self.name}' does not advertise tools"
        try:
            response = self._request(
                "tools/call",
                {"name": tool_name, "arguments": arguments},
                timeout=timeout,
            )
        except MCPTimeoutError as error:
            return f"MCP timeout: {error}"
        except MCPError as error:
            return f"MCP error: {error}"
        if "error" in response:
            return f"MCP error: {self._error_text(response['error'], 'Unknown')}"
        result = response.get("result") or {}
        content = result.get("content") or []
        parts = [
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        if result.get("isError"):
            details = "\n".join(parts) or "Unknown error"
            return f"MCP tool error: {details}"
        if parts:
            return "\n".join(parts)
        structured = result.get("structuredContent")
        return json.dumps(structured) if structured is not None else "Done"

    def get_tools(self) -> list[dict]:
        """Return a shallow copy of the last fetched tool list."""
        return list(self._tools)

    def stop(self) -> None:
        """Shut the server down and fail any waiting requests.

        Closes stdin and waits up to 2s for the process to exit. If it has
        not exited, it is terminated and then killed, each followed by
        another 2s wait.
        """
        process = self.process
        self.ready = False
        if not process:
            return
        try:
            if process.stdin and not process.stdin.closed:
                process.stdin.close()
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            process.terminate()
            try:
                process.wait(timeout=2)
            except subprocess.TimeoutExpired:
                process.kill()
                try:
                    process.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    pass
        except Exception:
            try:
                process.kill()
            except Exception:
                pass
        finally:
            self.process = None
            self._fail_pending()
378
+
379
+
380
def connect(
    name: str,
    command: str | list[str],
    env: dict | None = None,
    request_timeout: float = DEFAULT_REQUEST_TIMEOUT,
) -> bool:
    """Start an MCP server and register it under *name*.

    Any existing connection with that name is stopped first. If the new
    server then fails to start, nothing is registered under the name.
    Returns True when the server started and completed its handshake.
    """
    disconnect(name)
    candidate = MCPClient(name, command, env, request_timeout=request_timeout)
    started = candidate.start()
    if started:
        _active_clients[name] = candidate
    return started
393
+
394
+
395
def disconnect(name: str) -> None:
    """Unregister and stop the named server; do nothing if it is not connected."""
    existing = _active_clients.pop(name, None)
    if existing is None:
        return
    existing.stop()
399
+
400
+
401
def get_all_tools() -> dict[str, list[dict]]:
    """Map each ready server's name to a copy of its tool list."""
    tools_by_server: dict[str, list[dict]] = {}
    for server_name, client in _active_clients.items():
        if not client.ready:
            continue
        tools_by_server[server_name] = client.get_tools()
    return tools_by_server
407
+
408
+
409
def call_mcp_tool(
    server: str,
    tool: str,
    arguments: dict,
    timeout: float | None = None,
) -> str:
    """Call *tool* on the named server, returning its text result.

    Returns a hint to connect the server first if it is not registered.
    """
    target = _active_clients.get(server)
    if target is not None:
        return target.call_tool(tool, arguments, timeout=timeout)
    return f"MCP server '{server}' not connected. Run: zai mcp connect {server}"
419
+
420
+
421
def stop_all() -> None:
    """Stop every registered MCP server, then empty the registry."""
    # Snapshot first: iterating the live dict view while stopping is avoided.
    snapshot = tuple(_active_clients.values())
    for client in snapshot:
        client.stop()
    _active_clients.clear()
425
+
426
+
427
def active_count() -> int:
    """Return how many registered MCP servers are currently ready."""
    return len([client for client in _active_clients.values() if client.ready])
429
+
430
+
431
+ atexit.register(stop_all)