zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/main.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
import typer
|
|
2
|
+
from typing import Optional
|
|
3
|
+
from typer.core import TyperGroup
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
from rich.panel import Panel
|
|
6
|
+
from rich.prompt import Prompt, Confirm
|
|
7
|
+
|
|
8
|
+
from .config import load_config
|
|
9
|
+
from .core.fallback import (
|
|
10
|
+
format_model_selection,
|
|
11
|
+
has_available_provider,
|
|
12
|
+
stream_with_fallback,
|
|
13
|
+
)
|
|
14
|
+
from .core.context import ContextManager
|
|
15
|
+
from .core.memory import save_session, get_last_session
|
|
16
|
+
from .core.errors import show_error, AllModelsFailedError
|
|
17
|
+
from .core.runtime import configure as configure_runtime
|
|
18
|
+
from .core.runtime import plain_enabled, print_exception
|
|
19
|
+
from .cli.settings import register_settings_commands
|
|
20
|
+
from .cli.integrations import register_integration_commands
|
|
21
|
+
from .cli.utilities import register_utility_commands
|
|
22
|
+
from .cli.workflows import register_workflow_commands
|
|
23
|
+
from .cli.common import closest_command as _closest_command
|
|
24
|
+
from .cli.interactive import run_interactive
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class FuzzyTyperGroup(TyperGroup):
    """Accept safe, unambiguous typos in top-level command names.

    Lookup first tries the exact name; only when that fails is the name
    passed to ``_closest_command`` together with the registered commands.
    """

    def get_command(self, ctx, cmd_name):
        """Resolve ``cmd_name`` to a command, falling back to a close match.

        Returns the exact command when it exists, the corrected command when
        ``_closest_command`` proposes one (after printing a notice), and
        ``None`` otherwise.
        """
        exact = super().get_command(ctx, cmd_name)
        if exact is not None:
            return exact

        # presumably returns a falsy value when no candidate is close enough
        # -- TODO confirm against cli.common.closest_command
        suggestion = _closest_command(cmd_name, self.list_commands(ctx))
        if not suggestion:
            return None

        console.print(f"[yellow]Command '{cmd_name}' not found; using '{suggestion}'.[/yellow]")
        return super().get_command(ctx, suggestion)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Root Typer application; FuzzyTyperGroup resolves mistyped command names.
app = typer.Typer(
    cls=FuzzyTyperGroup,
    help="zai - Your personal AI CLI. Free. Fast. Smart.",
)

# Hand the app to the settings and integration modules so they can add
# their commands (these two registrars only need the app itself).
register_settings_commands(app)
register_integration_commands(app)
|
|
48
|
+
import io
import sys


def _utf8_stream(stream):
    """Return ``stream`` set up to emit UTF-8 with ``errors='replace'``.

    The stream is reconfigured in place when it supports ``reconfigure()``,
    which keeps the same object and its buffering mode. Streams that only
    expose a binary ``buffer`` are re-wrapped instead. Anything else
    (``None``, ``StringIO``-like replacements) is returned untouched rather
    than raising at import time.
    """
    if stream is None:
        return stream

    reconfigure = getattr(stream, "reconfigure", None)
    if callable(reconfigure):
        try:
            reconfigure(encoding="utf-8", errors="replace")
            return stream
        except (OSError, ValueError):
            # Fall through to re-wrapping the underlying buffer.
            pass

    buffer = getattr(stream, "buffer", None)
    if buffer is None:
        return stream
    return io.TextIOWrapper(
        buffer,
        encoding="utf-8",
        errors="replace",
        # Keep interactive output flushing per line like the original stream.
        line_buffering=getattr(stream, "line_buffering", False),
    )


def _ensure_utf8_stdio():
    """Switch stdout and stderr to UTF-8 when stdout uses another encoding."""
    encoding = getattr(sys.stdout, "encoding", None)
    if encoding and encoding.lower() != 'utf-8':
        sys.stdout = _utf8_stream(sys.stdout)
        sys.stderr = _utf8_stream(sys.stderr)


_ensure_utf8_stdio()
|
|
52
|
+
# Shared Rich console used for all output produced by this module.
console = Console()

# Single conversation context shared by every chat call in this process.
context = ContextManager()

# System prompt sent with every request. The "**File:** name" + fenced-block
# layout it asks for is what _parse_files_from_response looks for first.
SYSTEM_PROMPT = (
    "You are zai, a smart AI assistant running in the terminal. Be concise, helpful, and direct. "
    "IMPORTANT: When creating files, ALWAYS use this exact format for each file:\n"
    "**File:** filename.ext\n```lang\n...code...\n```\n"
    "Use this format every time so files can be automatically saved to disk."
)

# These registrars also receive the shared context and the system prompt.
register_utility_commands(app, context, SYSTEM_PROMPT)
register_workflow_commands(app, context, SYSTEM_PROMPT)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _parse_files_from_response(content: str) -> dict:
    """Extract filename→code pairs from AI response.

    Four layouts are recognised, in priority order:

    1. ``**File:** name.ext`` (any ``**... File ...:**`` label) followed by
       a fenced code block.
    2. Any other ``**Label:** name.ext`` line followed by a fenced block.
    3. A fenced block whose first line is a ``#``, ``//`` or ``<!--`` comment
       naming the file.
    4. A bare (optionally backtick-wrapped) ``name.ext`` line followed by a
       fenced block.

    A filename found by an earlier layout is never replaced by a later one.
    Within layout 1 the last block for a filename wins.

    Args:
        content: Raw model response; ``None`` or ``""`` yields ``{}``.

    Returns:
        Mapping of filename to the stripped code of its block.
    """
    import re

    if not content:
        return {}

    # Responses with Windows line endings would otherwise match nothing,
    # because every pattern expects a bare "\n" around the fence.
    content = content.replace("\r\n", "\n")

    filename = r'([\w./\\-]+\.\w+)'
    # Fence info strings are not always plain words (c++, c#, objective-c).
    lang = r'[\w+#.-]'
    fenced_body = r'```' + lang + r'*\n(.*?)```'

    files = {}

    # Pattern 1: **Anything File:** filename.ext (e.g. **HTML File:** index.html)
    labelled_file = (
        r'\*\*[\w\s]*[Ff]ile[\w\s]*:\*\*\s*' + filename + r'[^\n]*\n' + fenced_body
    )
    for match in re.finditer(labelled_file, content, re.DOTALL):
        files[match.group(1).strip()] = match.group(2).strip()

    fallbacks = (
        # Pattern 2: any other **Label:** filename.ext line
        r'\*\*[\w\s]*:\*\*\s*' + filename + r'\s*\n' + fenced_body,
        # Pattern 3: ```lang\n# filename or // filename or <!-- filename -->
        r'```' + lang + r'+\n(?:#|//|<!--)\s*' + filename + r'[^\n]*\n(.*?)```',
        # Pattern 4: filename.ext on its own line (or backtick-wrapped)
        r'(?:^|\n)`?' + filename + r'`?\s*\n' + fenced_body,
    )
    for pattern in fallbacks:
        for match in re.finditer(pattern, content, re.DOTALL):
            # setdefault keeps the first hit, so earlier layouts take priority.
            files.setdefault(match.group(1).strip(), match.group(2).strip())

    return files
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _chat(message: str, model: str = None) -> str:
    """Send one user message through the shared context and return the reply.

    Args:
        message: User text appended to the conversation before the request.
        model: Preferred model; when falsy, ``default_model`` from the
            loaded config is used.

    Returns:
        The assistant's reply, or ``""`` when no provider is available or
        the request failed (the error has already been displayed).
    """
    settings = load_config()
    preferred = model or settings["default_model"]
    context.set_model(preferred)

    if not has_available_provider():
        show_error("No AI provider available. Run: zai setup, or start Ollama.")
        return ""

    context.add("user", message)
    if context.is_near_limit():
        console.print("[yellow]Context compressing...[/yellow]")
        context.compress(preferred)

    try:
        reply, used_model = stream_with_fallback(
            context.get_messages(),
            system=SYSTEM_PROMPT,
            preferred=preferred,
        )
    except AllModelsFailedError as failure:
        show_error(str(failure))
        return ""
    except Exception as failure:
        print_exception(console, failure)
        return ""

    context.add("assistant", reply)

    if used_model != preferred and not plain_enabled():
        console.print(f"[dim]Switched to: {format_model_selection(used_model)}[/dim]")

    # Plain mode prints no metadata; the session is recorded either way,
    # always with the raw (unformatted) model identifier.
    if not plain_enabled():
        label = format_model_selection(used_model)
        console.print(f"[dim]── {label} ──[/dim]")
        if settings.get("show_token_count") and not plain_enabled():
            console.print(f"[dim]Tokens: {context.usage_bar()}[/dim]")

    save_session(task=message[:80], model=used_model)
    return reply
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@app.command()
def chat(
    message: Optional[str] = typer.Argument(None, help="Message to send"),
    model: Optional[str] = typer.Option(None, "--model", "-m", help="Model to use"),
):
    """Chat with AI. Supports auto model fallback."""
    # One-shot mode: a message on the command line is answered and we exit.
    if message:
        _chat(message, model)
        return

    # Interactive mode: optionally seed the context with the previous task.
    previous = get_last_session()
    if previous:
        console.print(f"[dim]Last session: {previous['task']} ({previous['date']})[/dim]")
        resume = Confirm.ask("Continue last session?", default=False)
        if resume:
            context.add("assistant", f"Continuing from: {previous['task']}")

    console.print(Panel("[bold cyan]zai[/bold cyan] — type your message. [dim]Ctrl+C to exit.[/dim]"))
    while True:
        try:
            line = Prompt.ask("[cyan]you[/cyan]")
            if not line.strip():
                # Blank input: just prompt again.
                continue
            _chat(line, model)
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C / Ctrl+D at the prompt or mid-request ends the session.
            console.print("\n[dim]Goodbye![/dim]")
            break
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
@app.command("ask")
def ask(
    message: str = typer.Argument(..., help="Message to send to AI"),
    model: Optional[str] = typer.Option(None, "--model", "-m"),
):
    """Send a quick message to AI (shortcut for chat)."""
    # Unlike `chat`, the message is required, so this never goes interactive.
    _chat(message, model=model)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    model: Optional[str] = typer.Option(None, "--model", "-m"),
    version: bool = typer.Option(False, "--version", "-v"),
    debug: bool = typer.Option(False, "--debug", help="Show full tracebacks"),
    plain: bool = typer.Option(
        False,
        "--plain",
        help="Emit plain one-shot output without streaming UI or metadata",
    ),
):
    """zai - Your personal AI CLI."""
    # Runtime flags are applied before anything else so later output honours them.
    configure_runtime(debug=debug, plain=plain)

    if version:
        # Imported lazily: only needed when the version is requested.
        from . import __version__

        console.print(f"zai v{__version__}")
    elif ctx.invoked_subcommand is None:
        # Bare `zai` with no sub-command starts the interactive session.
        run_interactive(model)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
# Run the Typer app when this module is executed directly as a script.
if __name__ == "__main__":
    app()
|
zai/mcp/__init__.py
ADDED
|
File without changes
|
zai/mcp/client.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
1
|
+
"""MCP JSON-RPC 2.0 client over the standard stdio transport."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import atexit
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import shlex
|
|
8
|
+
import subprocess
|
|
9
|
+
import threading
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import Callable, Optional
|
|
13
|
+
|
|
14
|
+
from rich.console import Console
|
|
15
|
+
from ..core.cancellation import current_token
|
|
16
|
+
|
|
17
|
+
# Shared Rich console for connection status and error messages.
console = Console()

# Protocol revision this client requests in its "initialize" call.
PROTOCOL_VERSION = "2025-06-18"
# Revisions accepted back from the server; any other value aborts start().
SUPPORTED_PROTOCOL_VERSIONS = {PROTOCOL_VERSION, "2024-11-05"}
# Default number of seconds to wait for the response to a request.
DEFAULT_REQUEST_TIMEOUT = 30.0

# Clients registered by connect(), keyed by server name.
_active_clients: dict[str, "MCPClient"] = {}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class MCPError(RuntimeError):
    """Raised when talking to an MCP server fails."""


class MCPTimeoutError(MCPError):
    """Raised when a request receives no response within its timeout."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class _PendingRequest:
    """Rendezvous between a caller waiting on a request and the reader thread."""

    # Set once a response arrives, or when outstanding requests are failed.
    event: threading.Event = field(default_factory=threading.Event)
    # Raw JSON-RPC response message; stays None if no response arrived.
    response: Optional[dict] = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class MCPClient:
    """Connect to one MCP server process using newline-delimited JSON-RPC.

    The server runs as a child process. Requests are written to its stdin as
    one JSON object per line; a background thread reads its stdout and hands
    each response to the waiting caller through a ``_PendingRequest``. A
    second thread drains stderr.
    """

    def __init__(
        self,
        name: str,
        command: str | list[str],
        env: dict | None = None,
        request_timeout: float = DEFAULT_REQUEST_TIMEOUT,
        progress_callback: Callable[[dict], None] | None = None,
    ):
        """Store the connection settings; the process is launched by ``start``.

        Args:
            name: Label for this server, used in thread names and messages.
            command: Server command line, pre-split or as a single string.
            env: Extra environment variables layered over ``os.environ``.
            request_timeout: Default seconds to wait for a response.
            progress_callback: Called with the ``params`` of every
                ``notifications/progress`` message, on the reader thread.
        """
        self.name = name
        self.command = command
        self.env = env or {}
        self.request_timeout = request_timeout
        self.progress_callback = progress_callback
        self.process: Optional[subprocess.Popen] = None
        self._id = 0
        self._tools: list[dict] = []
        # Requests awaiting a response, keyed by JSON-RPC id.
        self._pending: dict[int, _PendingRequest] = {}
        self._pending_lock = threading.Lock()
        self._write_lock = threading.Lock()
        self._reader_thread: threading.Thread | None = None
        self._stderr_thread: threading.Thread | None = None
        self._reader_error: str | None = None
        self.server_info: dict = {}
        self.server_capabilities: dict = {}
        self.protocol_version: str | None = None
        self.ready = False

    def _argv(self) -> list[str]:
        """Return the argument vector used to launch the server."""
        if isinstance(self.command, list):
            argv = list(self.command)
        else:
            argv = shlex.split(self.command, posix=os.name != "nt")
        if (
            os.name == "nt"
            and argv
            and argv[0].lower() in {"npx", "npx.cmd", "npm", "npm.cmd"}
        ):
            # presumably because npx/npm are .cmd shims on Windows that must
            # be run through cmd -- TODO confirm
            return ["cmd", "/c", *argv]
        return argv

    def start(self) -> bool:
        """Launch the server, perform the initialize handshake, load tools.

        Returns:
            ``True`` when the server is connected and ready. On any failure
            the error is printed, the process is stopped and ``False`` is
            returned.
        """
        env = {**os.environ, **self.env}
        try:
            self.process = subprocess.Popen(
                self._argv(),
                shell=False,
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding="utf-8",
                errors="replace",
                bufsize=1,
                env=env,
            )
            self._reader_thread = threading.Thread(
                target=self._read_loop,
                name=f"zai-mcp-{self.name}",
                daemon=True,
            )
            self._reader_thread.start()
            self._stderr_thread = threading.Thread(
                target=self._drain_stderr,
                name=f"zai-mcp-stderr-{self.name}",
                daemon=True,
            )
            self._stderr_thread.start()

            response = self._request("initialize", {
                "protocolVersion": PROTOCOL_VERSION,
                "capabilities": {},
                "clientInfo": {"name": "zai", "version": "0.1.0"},
            })
            # Report the server's own error instead of the misleading
            # "unsupported protocol version: missing".
            if "error" in response:
                failure = response["error"]
                detail = failure.get("message") if isinstance(failure, dict) else None
                raise MCPError(detail or "initialize failed")
            result = response.get("result") or {}
            selected = result.get("protocolVersion")
            if selected not in SUPPORTED_PROTOCOL_VERSIONS:
                raise MCPError(
                    f"unsupported protocol version: {selected or 'missing'}"
                )
            self.protocol_version = selected
            self.server_capabilities = result.get("capabilities", {})
            self.server_info = result.get("serverInfo", {})
            self._notify("notifications/initialized", {})

            if "tools" in self.server_capabilities:
                self._tools = self._fetch_tools()
            self.ready = True
            console.print(
                f"[green]MCP connected:[/green] {self.name} "
                f"({len(self._tools)} tools, {self.protocol_version})"
            )
            return True
        except Exception as error:
            console.print(f"[red]MCP error ({self.name}): {error}[/red]")
            self.stop()
            return False

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (allocated under the lock)."""
        with self._pending_lock:
            self._id += 1
            return self._id

    def _write(self, message: dict) -> None:
        """Serialize ``message`` and send it to the server as one line.

        Raises:
            MCPError: If the process is not running or the pipe write fails.
        """
        # Work on a local reference: stop() may clear self.process from
        # another thread between the check and the write.
        process = self.process
        if not process or process.poll() is not None or not process.stdin:
            raise MCPError("server process is not running")
        payload = json.dumps(message, separators=(",", ":"))
        try:
            with self._write_lock:
                process.stdin.write(payload + "\n")
                process.stdin.flush()
        except (OSError, ValueError) as error:
            # Broken pipe / closed stdin: surface it as the error type that
            # every caller of _write already handles.
            raise MCPError(f"failed to write to server: {error}") from error

    def _notify(self, method: str, params: dict | None = None) -> None:
        """Send a JSON-RPC notification (no id, so no response is expected)."""
        self._write({
            "jsonrpc": "2.0",
            "method": method,
            "params": params or {},
        })

    def _request(
        self,
        method: str,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        """Send a request and block until its response arrives.

        Args:
            method: JSON-RPC method name.
            params: Request parameters; ``None`` is sent as ``{}``.
            timeout: Seconds to wait; ``None`` uses ``self.request_timeout``.

        Returns:
            The full response message (it may carry ``result`` or ``error``).

        Raises:
            MCPTimeoutError: No response arrived in time.
            MCPError: The call was cancelled, the write failed, or the
                connection closed before a response arrived.
        """
        request_id = self._next_id()
        pending = _PendingRequest()
        with self._pending_lock:
            self._pending[request_id] = pending
        try:
            self._write({
                "jsonrpc": "2.0",
                "id": request_id,
                "method": method,
                "params": params or {},
            })
            wait_for = self.request_timeout if timeout is None else timeout
            deadline = time.monotonic() + max(wait_for, 0.0)
            # Wait in short slices so cancellation is noticed promptly.
            while not pending.event.is_set():
                token = current_token()
                if token and token.cancelled:
                    try:
                        self._notify("notifications/cancelled", {
                            "requestId": request_id,
                            "reason": token.reason,
                        })
                    except MCPError:
                        # Best effort: the server may already be gone.
                        pass
                    raise MCPError(token.reason)
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                pending.event.wait(min(0.05, remaining))
            if not pending.event.is_set():
                try:
                    self._notify("notifications/cancelled", {
                        "requestId": request_id,
                        "reason": f"zai timed out after {wait_for:g}s",
                    })
                except MCPError:
                    # Best effort: the timeout is reported either way.
                    pass
                raise MCPTimeoutError(
                    f"{method} timed out after {wait_for:g}s"
                )
            # Event set without a response means the reader or stop() failed
            # all outstanding requests.
            if pending.response is None:
                raise MCPError(self._reader_error or "server connection closed")
            return pending.response
        finally:
            with self._pending_lock:
                self._pending.pop(request_id, None)

    # Compatibility for callers/tests that used the old method name.
    def _send(
        self,
        method: str,
        params: dict | None = None,
        timeout: float | None = None,
    ) -> dict:
        """Alias of ``_request`` kept for backward compatibility."""
        return self._request(method, params, timeout)

    def _read_loop(self) -> None:
        """Reader-thread body: dispatch every JSON line from server stdout.

        Responses wake the matching pending request; everything else goes to
        ``_handle_server_message``. When the stream ends or an error occurs,
        all outstanding requests are failed.
        """
        try:
            if not self.process or not self.process.stdout:
                return
            for line in self.process.stdout:
                if not line.strip():
                    continue
                try:
                    message = json.loads(line)
                except json.JSONDecodeError:
                    # Ignore stray non-JSON output on stdout.
                    continue
                if not isinstance(message, dict):
                    # Valid JSON that is not an object (array, number, ...)
                    # must not take the whole reader down.
                    continue
                if "id" in message and ("result" in message or "error" in message):
                    with self._pending_lock:
                        try:
                            pending = self._pending.get(message["id"])
                        except TypeError:
                            # Unhashable id: cannot belong to any request.
                            pending = None
                    if pending:
                        pending.response = message
                        pending.event.set()
                    continue
                self._handle_server_message(message)
        except Exception as error:
            self._reader_error = str(error)
        finally:
            self._fail_pending()

    def _drain_stderr(self) -> None:
        """Drain server logs so a full stderr pipe cannot block the process."""
        try:
            if not self.process or not self.process.stderr:
                return
            for _line in self.process.stderr:
                pass
        except Exception:
            # Nothing useful to do with a failing log stream.
            pass

    def _handle_server_message(self, message: dict) -> None:
        """Handle a notification or request sent by the server.

        Progress notifications go to ``progress_callback``; a tools
        list-changed notification triggers a background refresh; any
        server-to-client request is answered with "method not found".
        """
        method = message.get("method")
        params = message.get("params", {})
        if method == "notifications/progress":
            if self.progress_callback:
                self.progress_callback(params)
            return
        if method == "notifications/tools/list_changed" and self.ready:
            # Avoid making a blocking request on the reader thread.
            threading.Thread(
                target=self._refresh_tools,
                name=f"zai-mcp-refresh-{self.name}",
                daemon=True,
            ).start()
            return
        if "id" in message and method:
            try:
                self._write({
                    "jsonrpc": "2.0",
                    "id": message["id"],
                    "error": {
                        "code": -32601,
                        "message": f"Client method not supported: {method}",
                    },
                })
            except MCPError:
                # The server is gone; nobody is left to receive the reply.
                pass

    def _fail_pending(self) -> None:
        """Wake every waiting request without giving it a response."""
        with self._pending_lock:
            pending_requests = list(self._pending.values())
        for pending in pending_requests:
            pending.event.set()

    def _refresh_tools(self) -> None:
        """Reload the tool list, keeping the old list if the reload fails."""
        try:
            self._tools = self._fetch_tools()
        except MCPError:
            pass

    def _fetch_tools(self) -> list[dict]:
        """Fetch every page of ``tools/list`` and return the combined tools.

        Raises:
            MCPError: If the server answers with an error or a request fails.
        """
        tools: list[dict] = []
        cursor = None
        while True:
            params = {"cursor": cursor} if cursor else {}
            response = self._request("tools/list", params)
            if "error" in response:
                raise MCPError(
                    response["error"].get("message", "tools/list failed")
                )
            result = response.get("result", {})
            tools.extend(result.get("tools", []))
            cursor = result.get("nextCursor")
            if not cursor:
                return tools

    def call_tool(
        self,
        tool_name: str,
        arguments: dict,
        timeout: float | None = None,
    ) -> str:
        """Invoke a tool and return its result as text.

        Failures are reported in the returned string (prefixed with
        ``MCP timeout:``, ``MCP error:`` or ``MCP tool error:``) rather than
        raised.

        Args:
            tool_name: Name of the tool to call.
            arguments: Arguments passed to the tool.
            timeout: Seconds to wait; ``None`` uses the client default.

        Returns:
            The joined text content, the JSON-encoded ``structuredContent``
            when there is no text, or ``"Done"`` when there is neither.
        """
        if not self.ready:
            return f"MCP server '{self.name}' not ready"
        if "tools" not in self.server_capabilities:
            return f"MCP server '{self.name}' does not advertise tools"
        try:
            response = self._request(
                "tools/call",
                {"name": tool_name, "arguments": arguments},
                timeout=timeout,
            )
        except MCPTimeoutError as error:
            return f"MCP timeout: {error}"
        except MCPError as error:
            return f"MCP error: {error}"
        if "error" in response:
            return f"MCP error: {response['error'].get('message', 'Unknown')}"
        result = response.get("result", {})
        content = result.get("content", [])
        # Only text items are rendered; malformed (non-object) items are skipped.
        parts = [
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        if result.get("isError"):
            details = "\n".join(parts) or "Unknown error"
            return f"MCP tool error: {details}"
        if parts:
            return "\n".join(parts)
        structured = result.get("structuredContent")
        return json.dumps(structured) if structured is not None else "Done"

    def get_tools(self) -> list[dict]:
        """Return a copy of the cached tool list."""
        return list(self._tools)

    def stop(self) -> None:
        """Shut the server down, escalating from EOF to terminate to kill.

        Closing stdin gives the server a chance to exit on its own; each
        stage waits up to two seconds. Outstanding requests are failed
        afterwards in every case.
        """
        process = self.process
        self.ready = False
        if not process:
            return
        try:
            if process.stdin and not process.stdin.closed:
                process.stdin.close()
            process.wait(timeout=2)
        except subprocess.TimeoutExpired:
            process.terminate()
            try:
                process.wait(timeout=2)
            except subprocess.TimeoutExpired:
                process.kill()
                try:
                    process.wait(timeout=2)
                except subprocess.TimeoutExpired:
                    pass
        except Exception:
            # e.g. closing stdin failed: fall back to killing the process.
            try:
                process.kill()
            except Exception:
                pass
        finally:
            self.process = None
            self._fail_pending()
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def connect(
    name: str,
    command: str | list[str],
    env: dict | None = None,
    request_timeout: float = DEFAULT_REQUEST_TIMEOUT,
) -> bool:
    """Start and register an MCP server connection.

    Any existing connection under ``name`` is stopped first. The new client
    is registered only if it starts successfully.

    Returns:
        ``True`` when the server started and was registered, else ``False``.
    """
    disconnect(name)
    client = MCPClient(name, command, env, request_timeout=request_timeout)
    started = client.start()
    if not started:
        return False
    _active_clients[name] = client
    return True
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def disconnect(name: str) -> None:
    """Unregister the client called ``name`` and stop it, if one exists."""
    existing = _active_clients.pop(name, None)
    if not existing:
        return
    existing.stop()
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def get_all_tools() -> dict[str, list[dict]]:
    """Return the tool lists of all ready clients, keyed by server name."""
    tools_by_server = {}
    for name, client in _active_clients.items():
        # Clients that are registered but not ready are left out.
        if client.ready:
            tools_by_server[name] = client.get_tools()
    return tools_by_server
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def call_mcp_tool(
    server: str,
    tool: str,
    arguments: dict,
    timeout: float | None = None,
) -> str:
    """Call ``tool`` on the registered server named ``server``.

    Returns the client's ``call_tool`` result, or a hint on how to connect
    when no client is registered under that name.
    """
    client = _active_clients.get(server)
    if client:
        return client.call_tool(tool, arguments, timeout=timeout)
    return f"MCP server '{server}' not connected. Run: zai mcp connect {server}"
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def stop_all() -> None:
    """Stop every registered client and empty the registry."""
    # Snapshot first so the registry is not iterated while clients shut down.
    registered = list(_active_clients.values())
    for client in registered:
        client.stop()
    _active_clients.clear()
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def active_count() -> int:
    """Return how many registered clients are currently ready."""
    ready_clients = [client for client in _active_clients.values() if client.ready]
    return len(ready_clients)
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
# Stop every registered MCP server process when the interpreter exits.
atexit.register(stop_all)
|