zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/mcp/manager.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from ..config import ZAI_DIR
|
|
5
|
+
from ..core.storage import atomic_write_json, read_json, update_json
|
|
6
|
+
|
|
7
|
+
# Persisted MCP server entries: a JSON list read by load_mcp() and written by
# save_mcp()/add_server()/remove_server().
MCP_FILE = ZAI_DIR / "mcp.json"

# Built-in MCP server templates that add_server() can install by name.
# Keys of each entry:
#   description -- human-readable summary
#   command     -- executable to launch
#   args        -- argv template; "{cwd}" and "{<env_key>}" placeholders are
#                  expanded by build_server_command()
#   env_key     -- environment variable the server requires, or None
KNOWN_SERVERS = {
    "filesystem": {
        "description": "Secure file operations within the current project",
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-filesystem", "{cwd}"],
        "env_key": None,
    },
    "memory": {
        "description": "Knowledge-graph persistent memory",
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-memory"],
        "env_key": None,
    },
    "github": {
        "description": "GitHub repositories, issues, and pull requests (archived reference server)",
        "command": "npx",
        "args": ["-y", "@modelcontextprotocol/server-github"],
        "env_key": "GITHUB_PERSONAL_ACCESS_TOKEN",
    },
    "postgres": {
        "description": "Read-only PostgreSQL schema and query access (archived reference server)",
        "command": "npx",
        # NOTE(review): the connection URL (possibly including credentials) is
        # passed on the command line, where other local users may see it.
        "args": ["-y", "@modelcontextprotocol/server-postgres", "{DATABASE_URL}"],
        "env_key": "DATABASE_URL",
    },
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def load_mcp() -> list[dict]:
    """Return the persisted MCP server entries from MCP_FILE.

    ``[]`` is passed to ``read_json`` as the fallback value, and the stored
    JSON is expected to be a list.
    """
    fallback: list[dict] = []
    return read_json(MCP_FILE, fallback, expected_type=list)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save_mcp(servers: list[dict]) -> None:
    """Overwrite MCP_FILE with *servers* using an atomic JSON write."""
    target = MCP_FILE
    atomic_write_json(target, servers)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def add_server(name: str) -> dict | None:
    """Register the built-in MCP server *name* in MCP_FILE.

    Returns:
        ``None`` if *name* is not a key of ``KNOWN_SERVERS``. Otherwise the
        server entry ``{"name": name, **template}``. If an entry with that
        name was already saved, the file is left unchanged and the returned
        dict additionally carries ``"already": True``.

    List values (e.g. ``args``) are copied from the template, so mutating the
    returned or persisted entry can never alter ``KNOWN_SERVERS`` itself.
    """
    if name not in KNOWN_SERVERS:
        return None
    template = KNOWN_SERVERS[name]
    server = {"name": name}
    server.update(
        {
            key: list(value) if isinstance(value, list) else value
            for key, value in template.items()
        }
    )
    already = False

    def update(servers):
        # Runs inside update_json; leaves the file untouched when present.
        nonlocal already
        if any(item.get("name") == name for item in servers):
            already = True
            return servers
        return [*servers, server]

    update_json(MCP_FILE, [], update, expected_type=list)
    # Both result shapes carry the same keys ("name" included).
    return {"already": True, **server} if already else server
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def remove_server(name: str) -> bool:
    """Delete every saved MCP entry whose ``name`` equals *name*.

    Returns True when at least one entry was dropped from MCP_FILE.
    """
    outcome = {"removed": False}

    def _drop(entries):
        kept = [entry for entry in entries if entry.get("name") != name]
        outcome["removed"] = len(kept) != len(entries)
        return kept

    update_json(MCP_FILE, [], _drop, expected_type=list)
    return outcome["removed"]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def list_servers() -> list[dict]:
    """Return every configured MCP server entry (delegates to load_mcp)."""
    servers = load_mcp()
    return servers
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def build_server_command(
    server: dict,
    cwd: str | Path | None = None,
) -> tuple[list[str], dict[str, str]]:
    """Build an ``(argv, env)`` pair from a persisted MCP server entry.

    ``{cwd}`` in an argument expands to the resolved working directory
    (*cwd*, or the process's current directory). If the server declares an
    ``env_key``, that environment variable must be set. For known servers the
    key comes from ``KNOWN_SERVERS``; otherwise it is taken from the entry.
    The variable's value is then available as ``{<env_key>}`` in arguments
    and is returned as the env mapping. ``GITHUB_TOKEN`` is accepted as a
    fallback for ``GITHUB_PERSONAL_ACCESS_TOKEN``.

    Entries without a usable ``command``/``args`` pair fall back to the known
    server definition, or to ``npx -y <package>`` for legacy entries that
    only store ``package``. On Windows, ``npx``/``npm`` are wrapped in
    ``cmd /c``.

    Raises:
        ValueError: a required environment variable is unset, no command can
            be determined, or an argument contains an unknown or malformed
            ``{placeholder}``.
    """
    current_dir = str(Path(cwd or os.getcwd()).resolve())
    known = KNOWN_SERVERS.get(server.get("name", ""), {})
    # Known servers always use the built-in env_key, ignoring any saved value.
    env_key = known.get("env_key") if known else server.get("env_key")
    replacements = {"cwd": current_dir}
    if env_key:
        value = os.getenv(env_key)
        if not value and env_key == "GITHUB_PERSONAL_ACCESS_TOKEN":
            value = os.getenv("GITHUB_TOKEN")
        if not value:
            raise ValueError(f"set {env_key} to enable {server.get('name', 'server')}")
        replacements[env_key] = value

    command = server.get("command")
    args = server.get("args")
    if not command or not isinstance(args, list):
        # Migrate older saved entries to the current known command first.
        if known:
            command = known.get("command")
            args = known.get("args", [])
        else:
            package = server.get("package")
            if not package:
                raise ValueError("MCP server has no command configuration")
            command = "npx"
            args = ["-y", package]

    rendered_args = []
    for argument in args:
        try:
            rendered_args.append(str(argument).format_map(replacements))
        except (KeyError, IndexError, AttributeError, ValueError) as error:
            # An unknown, positional or malformed "{...}" (e.g. a literal brace
            # in a hand-edited entry) becomes a ValueError like the other
            # configuration errors above, instead of a bare KeyError.
            raise ValueError(
                f"invalid placeholder in argument {argument!r} for "
                f"{server.get('name', 'server')}: {error}"
            ) from error
    argv = [str(command), *rendered_args]
    # Compare the string form so non-string JSON values cannot raise TypeError.
    if os.name == "nt" and argv[0] in ("npx", "npm"):
        # npx/npm are typically .cmd shims on Windows and need cmd to launch.
        argv = ["cmd", "/c", *argv]
    environment = {env_key: replacements[env_key]} if env_key else {}
    return argv, environment
|
zai/plugins/__init__.py
ADDED
zai/plugins/base.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class BasePlugin(ABC):
    """Base class for zai plugins.

    Subclasses set the descriptive class attributes and override the
    ``get_*`` hooks to contribute tools, skills and slash commands. The
    plugin loader compares ``permissions`` against the plugin's manifest.
    """

    name: str = ""
    description: str = ""
    version: str = "0.1.0"
    author: str = ""
    permissions: tuple[str, ...] = ()

    def setup(self) -> None:
        """Called once when plugin loads. Override for initialization."""
        pass

    def get_tools(self) -> dict:
        """
        Return {tool_name: callable} dict.
        Each callable receives **kwargs from the agent and returns a str result.
        """
        return {}

    def get_skills(self) -> dict:
        """
        Return {skill_name: {"description": str, "prompt": callable(ctx) -> str}} dict.
        Skills become available via: zai skill <name> <file>
        """
        return {}

    def get_commands(self) -> dict:
        """
        Return {cmd_name: {"description": str, "body": str}} dict.
        Commands become available as /<cmd_name> in interactive mode.
        """
        return {}

    def get_agent_description(self) -> str:
        """Return a system-prompt snippet showing how to call this plugin's tools.

        Each tool is listed as an example ``<tool_call>`` followed by the
        first line (at most 80 chars) of the tool's docstring. The JSON
        payload is produced with ``json.dumps``, so plugin or tool names
        containing quotes or backslashes still yield valid JSON. Returns ``""``
        when the plugin exposes no tools.
        """
        import json

        tools = self.get_tools()
        if not tools:
            return ""
        lines = [f"\nPlugin '{self.name}' tools (use structured plugin_call):"]
        for tool_name, func in tools.items():
            doc = (func.__doc__ or "").strip().split("\n")[0][:80]
            payload = json.dumps(
                {
                    "name": "plugin_call",
                    "arguments": {
                        "plugin": self.name,
                        "tool": tool_name,
                        "arguments": {"arg": "value"},
                    },
                },
                separators=(",", ":"),
                ensure_ascii=False,
            )
            lines.append(f"<tool_call>{payload}</tool_call> ({doc})")
        return "\n".join(lines)
|
zai/plugins/loader.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Plugin loader — finds and loads plugins from:
|
|
3
|
+
1. ~/.zai/plugins/*.py (user-installed)
|
|
4
|
+
2. pip packages with entry_point group 'zai.plugins'
|
|
5
|
+
"""
|
|
6
|
+
import importlib.util
|
|
7
|
+
import hashlib
|
|
8
|
+
import json
|
|
9
|
+
import re
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from rich.console import Console
|
|
12
|
+
from .base import BasePlugin
|
|
13
|
+
from ..core.storage import atomic_write_json, atomic_write_text, read_json
|
|
14
|
+
|
|
15
|
+
# Rich console used to print plugin load failures.
console = Console()

# Per-user plugin locations and state files.
# NOTE(review): these hard-code ~/.zai instead of a shared config directory
# constant — confirm they should not follow the config module's setting.
PLUGIN_DIR = Path.home() / ".zai" / "plugins"
# JSON list of plugin names that load_all() skips.
DISABLED_FILE = Path.home() / ".zai" / "plugins_disabled.json"
# JSON object mapping "file:<name>" / "entrypoint:<name>" to a SHA-256 hex
# fingerprint recorded by trust_plugin().
TRUST_FILE = Path.home() / ".zai" / "plugins_trusted.json"

# Registries rebuilt by load_all(): loaded plugin instances and per-name errors.
_loaded: dict[str, BasePlugin] = {}
_load_errors: dict[str, str] = {}
# Valid local plugin names: an alphanumeric first character followed by up to
# 127 characters from [A-Za-z0-9_.-].
PLUGIN_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_.-]{0,127}$")
# A local plugin "<name>.py" needs a sibling "<name>.plugin.json" manifest.
MANIFEST_SUFFIX = ".plugin.json"
# The only manifest_version _validate_manifest() accepts.
MANIFEST_VERSION = 1
# Permissions a manifest may declare; any other value is rejected.
ALLOWED_PERMISSIONS = {
    "project_read",
    "project_write",
    "network",
    "subprocess",
    "secrets",
}
# Validated manifests of loaded plugins, keyed by plugin name.
_manifests: dict[str, dict] = {}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _load_disabled() -> set:
    """Return the names of disabled plugins as a set."""
    stored = read_json(DISABLED_FILE, [], expected_type=list)
    return {*stored}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _save_disabled(names: set):
    """Persist the disabled plugin names as a sorted JSON list."""
    ordered = sorted(names)
    atomic_write_json(DISABLED_FILE, ordered)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _load_trusted() -> dict[str, str]:
    """Return the trust store mapping ``"file:<name>"``/``"entrypoint:<name>"`` to fingerprints."""
    empty: dict[str, str] = {}
    return read_json(TRUST_FILE, empty, expected_type=dict)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _save_trusted(entries: dict[str, str]) -> None:
    """Atomically overwrite the trust store with *entries*."""
    destination = TRUST_FILE
    atomic_write_json(destination, entries)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def local_manifest_path(name: str) -> Path | None:
    """Return ``PLUGIN_DIR/<name>.plugin.json``.

    Returns None when *name* is rejected by ``local_plugin_path``.
    """
    if local_plugin_path(name) is None:
        return None
    return PLUGIN_DIR / (name + MANIFEST_SUFFIX)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _validate_manifest(data: object, expected_name: str) -> dict:
    """Check a parsed plugin manifest and return a normalized copy.

    The manifest must be a dict with the supported ``manifest_version``, a
    ``name`` equal to *expected_name*, and ``permissions`` given as a list of
    strings drawn from ``ALLOWED_PERMISSIONS``. The returned copy has its
    permissions de-duplicated and sorted, and ``description``, ``version``
    and ``source`` default to ``""``, ``"0.1.0"`` and ``"local"``.

    Raises:
        ValueError: describing the first failed check.
    """
    if not isinstance(data, dict):
        raise ValueError("plugin manifest must be a JSON object")
    if data.get("manifest_version") != MANIFEST_VERSION:
        raise ValueError(f"manifest_version must be {MANIFEST_VERSION}")
    if data.get("name") != expected_name:
        raise ValueError("plugin manifest name does not match plugin name")

    requested = data.get("permissions")
    well_formed = isinstance(requested, list) and all(
        isinstance(entry, str) for entry in requested
    )
    if not well_formed:
        raise ValueError("plugin permissions must be a list of strings")
    granted = set(requested)
    disallowed = granted - ALLOWED_PERMISSIONS
    if disallowed:
        raise ValueError("unknown plugin permissions: " + ", ".join(sorted(disallowed)))

    result = dict(data)
    result["permissions"] = sorted(granted)
    for key, default in (("description", ""), ("version", "0.1.0"), ("source", "local")):
        result.setdefault(key, default)
    return result
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _read_local_manifest(name: str) -> dict:
    """Load and validate the ``<name>.plugin.json`` manifest next to a local plugin.

    Raises:
        ValueError: the manifest is missing, unreadable, not valid JSON, or
            fails ``_validate_manifest``.
    """
    manifest_file = local_manifest_path(name)
    if manifest_file is None or not manifest_file.is_file():
        raise ValueError(
            f"Missing {name}{MANIFEST_SUFFIX}. Plugins require a reviewed manifest."
        )
    try:
        raw = manifest_file.read_text(encoding="utf-8")
        parsed = json.loads(raw)
    except (OSError, UnicodeError, json.JSONDecodeError) as error:
        raise ValueError(f"Invalid plugin manifest: {error}") from error
    return _validate_manifest(parsed, name)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _file_fingerprint(path: Path, manifest: dict) -> str:
|
|
97
|
+
manifest_bytes = json.dumps(
|
|
98
|
+
manifest,
|
|
99
|
+
sort_keys=True,
|
|
100
|
+
separators=(",", ":"),
|
|
101
|
+
).encode("utf-8")
|
|
102
|
+
return hashlib.sha256(path.read_bytes() + b"\0" + manifest_bytes).hexdigest()
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _entrypoint_manifest(entry_point) -> dict:
    """Load and validate the ``zai-plugin.json`` shipped by an entry point's package.

    Uses the first file named ``zai-plugin.json`` in the distribution's file
    list.

    Raises:
        ValueError: the entry point has no distribution, the package lacks
            the manifest, or the manifest is unreadable or invalid.
    """
    dist = getattr(entry_point, "dist", None)
    if dist is None:
        raise ValueError("Plugin entry point has no distribution metadata")
    packaged_files = getattr(dist, "files", None) or []
    manifest_item = next(
        (entry for entry in packaged_files if Path(str(entry)).name == "zai-plugin.json"),
        None,
    )
    if manifest_item is None:
        raise ValueError("Packaged plugin is missing zai-plugin.json")
    location = dist.locate_file(manifest_item)
    try:
        parsed = json.loads(Path(location).read_text(encoding="utf-8"))
    except (OSError, UnicodeError, json.JSONDecodeError) as error:
        raise ValueError(f"Invalid packaged plugin manifest: {error}") from error
    return _validate_manifest(parsed, entry_point.name)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _entrypoint_fingerprint(entry_point, manifest: dict) -> str:
|
|
121
|
+
distribution = getattr(entry_point, "dist", None)
|
|
122
|
+
dist_name = getattr(distribution, "name", "") or ""
|
|
123
|
+
version = getattr(distribution, "version", "") or ""
|
|
124
|
+
file_records = []
|
|
125
|
+
for item in getattr(distribution, "files", None) or []:
|
|
126
|
+
recorded_hash = getattr(item, "hash", None)
|
|
127
|
+
file_records.append(f"{item}:{recorded_hash or ''}")
|
|
128
|
+
return hashlib.sha256(
|
|
129
|
+
(
|
|
130
|
+
f"{entry_point.value}|{dist_name}|{version}|"
|
|
131
|
+
+ "|".join(sorted(file_records))
|
|
132
|
+
+ "|"
|
|
133
|
+
+ json.dumps(manifest, sort_keys=True, separators=(",", ":"))
|
|
134
|
+
).encode("utf-8")
|
|
135
|
+
).hexdigest()
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _entry_points():
|
|
139
|
+
from importlib.metadata import entry_points
|
|
140
|
+
|
|
141
|
+
return entry_points(group="zai.plugins")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def local_plugin_path(name: str) -> Path | None:
    """Return ``PLUGIN_DIR/<name>.py`` for a safe plugin name, else None.

    The name must match ``PLUGIN_NAME_PATTERN`` and must not contain ``..``
    or a path separator. This keeps the result inside PLUGIN_DIR.
    """
    if not PLUGIN_NAME_PATTERN.fullmatch(name) or any(
        marker in name for marker in ("..", "/", "\\")
    ):
        return None
    return PLUGIN_DIR / (name + ".py")
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def trust_plugin(name: str) -> bool:
    """Trust the current code fingerprint without importing or executing it.

    A local ``PLUGIN_DIR/<name>.py`` takes precedence over an installed entry
    point with the same name. The stored fingerprint also covers the
    validated manifest, so load_all() rejects the plugin again after either
    one changes.

    Returns:
        True when a fingerprint was recorded. False when the local manifest
        is missing or invalid, when no entry point named *name* exists, or
        when an entry-point plugin could not be fingerprinted or saved.
    """
    trusted = _load_trusted()

    def record(key: str, fingerprint: str) -> bool:
        trusted[key] = fingerprint
        _save_trusted(trusted)
        _load_errors.pop(name, None)
        return True

    local_path = local_plugin_path(name)
    if local_path and local_path.is_file():
        try:
            manifest = _read_local_manifest(name)
        except ValueError:
            return False
        return record(f"file:{name}", _file_fingerprint(local_path, manifest))

    try:
        match = next((ep for ep in _entry_points() if ep.name == name), None)
        if match is None:
            return False
        manifest = _entrypoint_manifest(match)
        return record(f"entrypoint:{name}", _entrypoint_fingerprint(match, manifest))
    except Exception:
        return False
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def revoke_plugin_trust(name: str) -> bool:
    """Forget the stored fingerprints for *name* and unload it.

    Removes both the ``file:`` and ``entrypoint:`` trust entries. The plugin
    instance and its manifest are dropped from the registries, so
    ``get_manifest`` no longer reports a plugin that is no longer loaded.

    Returns:
        True if at least one trust entry was removed.
    """
    trusted = _load_trusted()
    removed = False
    for key in (f"file:{name}", f"entrypoint:{name}"):
        # Pop every key; don't short-circuit after the first hit.
        if trusted.pop(key, None) is not None:
            removed = True
    if removed:
        _save_trusted(trusted)
    _loaded.pop(name, None)
    _manifests.pop(name, None)
    return removed
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def load_all() -> dict[str, BasePlugin]:
    """Load all plugins. Returns {name: plugin_instance}.

    Resets the module registries (loaded plugins, load errors, manifests) and
    then loads:

    1. user plugins from ``PLUGIN_DIR/*.py``, each of which needs a sibling
       ``<name>.plugin.json`` manifest, and
    2. installed packages registered in the ``zai.plugins`` entry-point group.

    Plugin code runs only after its manifest validates, its current
    fingerprint matches the one stored by ``trust_plugin``, and its name is
    not disabled. Per-plugin failures are recorded (see ``get_errors()``)
    rather than raised.
    """
    global _loaded, _load_errors, _manifests
    _loaded = {}
    _load_errors = {}
    _manifests = {}
    disabled = _load_disabled()
    trusted = _load_trusted()

    _load_local_plugins(disabled, trusted)
    _load_entrypoint_plugins(disabled, trusted)
    return _loaded


def _load_local_plugins(disabled: set, trusted: dict[str, str]) -> None:
    """Load trusted, enabled ``*.py`` plugins from PLUGIN_DIR in file-name order."""
    PLUGIN_DIR.mkdir(parents=True, exist_ok=True)
    for py_file in sorted(PLUGIN_DIR.glob("*.py")):
        name = py_file.stem
        try:
            manifest = _read_local_manifest(name)
            fingerprint = _file_fingerprint(py_file, manifest)
        except (ValueError, OSError) as error:
            _load_errors[name] = str(error)
            continue
        if trusted.get(f"file:{name}") != fingerprint:
            _load_errors[name] = (
                "Plugin is untrusted or changed. Review it, then run: "
                f"zai plugin trust {name}"
            )
            continue
        # _read_local_manifest validates manifest["name"] == name, so this
        # skips disabled plugins before any of their code is executed.
        if manifest["name"] in disabled:
            continue
        _load_file(py_file, disabled, manifest)


def _load_entrypoint_plugins(disabled: set, trusted: dict[str, str]) -> None:
    """Load trusted, enabled plugins from the ``zai.plugins`` entry-point group."""
    try:
        for ep in _entry_points():
            try:
                manifest = _entrypoint_manifest(ep)
                if trusted.get(f"entrypoint:{ep.name}") != _entrypoint_fingerprint(
                    ep,
                    manifest,
                ):
                    _load_errors[ep.name] = (
                        "Plugin package is untrusted or changed. Review it, then run: "
                        f"zai plugin trust {ep.name}"
                    )
                    continue
                # manifest["name"] == ep.name (checked by _validate_manifest),
                # so disabled packages are skipped before they are imported.
                if ep.name in disabled:
                    continue
                plugin_cls = ep.load()
                instance = plugin_cls()
                if not isinstance(instance, BasePlugin):
                    raise ValueError("plugin entry point did not produce a BasePlugin instance")
                if not instance.name:
                    instance.name = ep.name
                # Same checks as local plugins: a package must not register
                # under another plugin's name or with undeclared permissions.
                if instance.name != manifest["name"]:
                    raise ValueError("loaded plugin name does not match its manifest")
                if tuple(manifest["permissions"]) != tuple(
                    sorted(set(instance.permissions))
                ):
                    raise ValueError("plugin class permissions do not match its manifest")
                instance.setup()
                _loaded[instance.name] = instance
                _manifests[instance.name] = manifest
            except Exception as e:
                _load_errors[ep.name] = str(e)
    except Exception:
        # Best effort: unreadable package metadata must not make load_all()
        # raise; local plugins have already been loaded at this point.
        pass
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _load_file(py_file: Path, disabled: set, manifest: dict):
    """Execute a local plugin file and register its module-level ``plugin`` object.

    The caller is expected to have validated *manifest* and checked the
    file's trust fingerprint. Disabled plugins are skipped before the module
    is executed, so disabling a plugin also stops its import-time code from
    running. The loaded instance's name and permissions must match the
    manifest. Failures are recorded in ``_load_errors`` under the file stem
    and printed to the console; nothing is raised.
    """
    from rich.markup import escape

    try:
        # The instance name must equal manifest["name"] (checked below), so
        # testing the manifest name here matches disable_plugin(<name>).
        if manifest["name"] in disabled:
            return
        spec = importlib.util.spec_from_file_location(py_file.stem, py_file)
        if spec is None or spec.loader is None:
            raise ValueError("Cannot create plugin module specification")
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        instance = getattr(mod, "plugin", None)
        if not isinstance(instance, BasePlugin):
            _load_errors[py_file.stem] = "No `plugin = MyPlugin()` found at module level"
            return
        if not instance.name:
            instance.name = py_file.stem
        if instance.name != manifest["name"]:
            raise ValueError("loaded plugin name does not match its manifest")
        if tuple(manifest["permissions"]) != tuple(
            sorted(set(instance.permissions))
        ):
            raise ValueError("plugin class permissions do not match its manifest")
        instance.setup()
        _loaded[instance.name] = instance
        _manifests[instance.name] = manifest
    except Exception as e:
        _load_errors[py_file.stem] = str(e)
        # Escape user-controlled text: a message such as "[/x]" would otherwise
        # be parsed as Rich markup and raise MarkupError from this handler.
        console.print(
            f"[yellow]Plugin error ({escape(py_file.name)}):[/yellow] {escape(str(e))}"
        )
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def get_loaded() -> dict[str, BasePlugin]:
    """Return the live ``{name: plugin}`` registry filled by load_all (not a copy)."""
    registry = _loaded
    return registry
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def get_errors() -> dict[str, str]:
    """Return the live ``{plugin_name: error_message}`` map from the last load."""
    errors = _load_errors
    return errors
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def get_manifest(name: str) -> dict | None:
    """Return a copy of the validated manifest of the loaded plugin *name*.

    Returns None when no manifest is registered under *name*. The copy is
    deep, so callers mutating nested values (e.g. the ``permissions`` list)
    cannot alter the manifest the loader keeps.
    """
    import copy

    manifest = _manifests.get(name)
    if not manifest:
        return None
    return copy.deepcopy(manifest)
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def disable_plugin(name: str) -> bool:
    """Add *name* to the disabled list and unload it from this session.

    The plugin instance and its manifest are both dropped from the
    registries, keeping ``get_manifest`` consistent with ``get_loaded``.
    Always returns True.
    """
    disabled = _load_disabled()
    disabled.add(name)
    _save_disabled(disabled)
    _loaded.pop(name, None)
    _manifests.pop(name, None)
    return True
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def enable_plugin(name: str) -> bool:
    """Remove *name* from the disabled list.

    Returns False if the plugin was not disabled. It takes effect on the next
    load_all().
    """
    disabled = _load_disabled()
    was_disabled = name in disabled
    if was_disabled:
        disabled.remove(name)
        _save_disabled(disabled)
    return was_disabled
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def get_all_tools() -> dict[str, tuple]:
    """Collect the tools exposed by every loaded plugin.

    Returns:
        ``{"<plugin_name>.<tool_name>": (plugin_name, tool_name, callable)}``.
        If two entries produce the same key, the later plugin wins.
    """
    return {
        f"{plugin_name}.{tool_name}": (plugin_name, tool_name, func)
        for plugin_name, plugin in _loaded.items()
        for tool_name, func in plugin.get_tools().items()
    }
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def get_all_skills() -> dict:
    """Merge the skills of all loaded plugins; later plugins override same-named skills."""
    merged = {}
    for loaded_plugin in _loaded.values():
        merged |= loaded_plugin.get_skills()
    return merged
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def get_all_commands() -> dict:
    """Merge the slash commands of all loaded plugins; later plugins override same-named ones."""
    merged = {}
    for loaded_plugin in _loaded.values():
        merged |= loaded_plugin.get_commands()
    return merged
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def get_agent_descriptions() -> str:
    """Join the non-empty tool descriptions of all loaded plugins, one per line."""
    snippets = (loaded.get_agent_description() for loaded in _loaded.values())
    return "\n".join(filter(None, snippets))
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def scaffold_plugin(name: str) -> str:
    """Create a starter plugin file in ~/.zai/plugins/. Returns file path.

    *name* is normalized to a plugin identifier of lower-case letters,
    digits and underscores. That identifier names the ``.py`` file, the
    ``.plugin.json`` manifest and the plugin itself. The raw *name* appears
    in the generated source only inside escaped string literals, so quotes,
    braces, backslashes or newlines cannot break the generated module or
    inject code into it. The new plugin is trusted immediately.

    NOTE: existing files with the same normalized name are overwritten.

    Raises:
        ValueError: *name* has no letter or digit, or the normalized name is
            not a valid local plugin name (e.g. longer than 128 characters).
    """
    PLUGIN_DIR.mkdir(parents=True, exist_ok=True)
    safe = re.sub(r"[^a-z0-9_]", "_", name.lower().replace("-", "_")).strip("_")
    if not safe:
        raise ValueError("Plugin name must contain a letter or number")
    if local_plugin_path(safe) is None:
        raise ValueError(f"Plugin name {safe!r} is not a valid plugin name")

    class_name = f"{safe.title().replace('_', '')}Plugin"
    if not class_name.isidentifier():
        # e.g. "123" -> "123Plugin" is not a valid Python class name.
        class_name = f"Zai{class_name}"

    def literal(text: str) -> str:
        # A JSON string is also a valid Python string literal.
        return json.dumps(text, ensure_ascii=False)

    doc_text = literal(f"zai plugin: {name}")[1:-1]
    # The generated hello() uses an f-string; braces from *name* must be doubled.
    greeting = "f" + literal(
        "Hello {name} from "
        + name.replace("{", "{{").replace("}", "}}")
        + " plugin!"
    )
    source = "\n".join(
        [
            f'"""{doc_text}"""',
            "from zai.plugins.base import BasePlugin",
            "",
            "",
            f"class {class_name}(BasePlugin):",
            f'    name = "{safe}"',
            f"    description = {literal(f'My {name} plugin')}",
            '    version = "0.1.0"',
            '    author = ""',
            "",
            "    permissions = ()",
            "",
            "    def setup(self):",
            "        pass  # runs once on load",
            "",
            "    def get_tools(self):",
            "        return {",
            '            "hello": self.hello,',
            "        }",
            "",
            '    def hello(self, name="world", **kwargs):',
            '        """Say hello from the plugin."""',
            f"        return {greeting}",
            "",
            "    def get_skills(self):",
            "        return {",
            f'            "{safe}_skill": {{',
            '                "description": "Example skill",',
            '                "prompt": lambda ctx: f"Process this:\\n\\n{ctx}",',
            "            }",
            "        }",
            "",
            "    def get_commands(self):",
            "        return {",
            f'            "{safe}": {{',
            f"                \"description\": {literal(f'Run {name} command')},",
            f"                \"body\": {literal(f'You are helping with {name}. $ARGUMENTS')},",
            "            }",
            "        }",
            "",
            "",
            f"plugin = {class_name}()",
            "",
        ]
    )

    path = PLUGIN_DIR / f"{safe}.py"
    manifest_path = PLUGIN_DIR / f"{safe}{MANIFEST_SUFFIX}"
    atomic_write_text(path, source, mode=0o644, lock=False)
    atomic_write_json(
        manifest_path,
        {
            "manifest_version": MANIFEST_VERSION,
            "name": safe,
            "description": f"My {name} plugin",
            "version": "0.1.0",
            "source": "local",
            "permissions": [],
        },
        lock=False,
    )
    trust_plugin(safe)
    return str(path)
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .gemini import GeminiProvider
|
|
2
|
+
from .groq import GroqProvider
|
|
3
|
+
from .cerebras import CerebrasProvider
|
|
4
|
+
from .openrouter import OpenRouterProvider
|
|
5
|
+
from .qwen import QwenProvider
|
|
6
|
+
from .anthropic import AnthropicProvider
|
|
7
|
+
from .openai import OpenAIProvider
|
|
8
|
+
from .ollama import OllamaProvider
|
|
9
|
+
|
|
10
|
+
# Maps a provider name (and a few model-style aliases) to its provider class.
PROVIDERS = {
    "gemini": GeminiProvider,
    "groq": GroqProvider,
    "cerebras": CerebrasProvider,
    "openrouter": OpenRouterProvider,
    "qwen": QwenProvider,
    "anthropic": AnthropicProvider,
    "openai": OpenAIProvider,
    "ollama": OllamaProvider,
    # aliases matching MODELS keys in config.py
    "claude": AnthropicProvider,
    "gpt4o": OpenAIProvider,
}
|