zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/fallback.py ADDED
@@ -0,0 +1,171 @@
1
+ from rich.console import Console
2
+ from ..providers import PROVIDERS
3
+ from ..providers.base import Message, Response
4
+ from ..config import get_model_config, load_config
5
+ from .errors import (
6
+ AllModelsFailedError,
7
+ NetworkError,
8
+ NoAPIKeyError,
9
+ ProviderError,
10
+ RateLimitError,
11
+ classify_provider_error,
12
+ )
13
+ from .runtime import plain_enabled
14
+ from .cancellation import OperationCancelled, raise_if_cancelled
15
+
16
# Module-level Rich console shared by the fallback helpers in this module for
# status messages and for echoing non-streamed responses.
console = Console()
17
+
18
+
19
def format_model_selection(name: str) -> str:
    """Describe a configured model name as ``name -> provider/model_id``.

    When ``get_model_config`` does not know *name* (raises ``KeyError``), the
    bare *name* is returned unchanged.
    """
    try:
        cfg = get_model_config(name)
    except KeyError:
        return name
    provider_name = cfg["provider"]
    model_id = cfg["model_id"]
    return f"{name} -> {provider_name}/{model_id}"
25
+
26
+
27
def get_provider(name: str):
    """Instantiate and configure the provider backing model *name*.

    Returns ``None`` when the model is unknown to ``get_model_config`` or when
    its ``provider`` key has no class registered in ``PROVIDERS``. The new
    instance receives the model's ``model_id`` and ``context_window``, plus
    ``timeout`` and ``retries`` (defaulting to 60 and 2 when not configured).
    """
    try:
        cfg = get_model_config(name)
    except KeyError:
        return None

    provider_cls = PROVIDERS.get(cfg["provider"])
    if not provider_cls:
        return None

    instance = provider_cls()
    instance.model_id = cfg["model_id"]
    instance.context_window = cfg["context_window"]
    instance.timeout = cfg.get("timeout", 60)
    instance.retries = cfg.get("retries", 2)
    return instance
41
+
42
+
43
+ def _model_order(config: dict, preferred: str = None) -> list[str]:
44
+ first = preferred or config["default_model"]
45
+ if not config.get("auto_fallback", True):
46
+ return [first]
47
+ return [first] + [
48
+ model for model in config["fallback_order"]
49
+ if model != first
50
+ ]
51
+
52
+
53
def has_available_provider() -> bool:
    """Report whether any model in the configured try-order is usable.

    Walks the same order the fallback helpers use (default model, then the
    fallback list) and stops at the first provider whose ``is_available()``
    is truthy. Unknown models or providers are skipped.
    """
    candidates = (get_provider(name) for name in _model_order(load_config()))
    return any(candidate and candidate.is_available() for candidate in candidates)
61
+
62
+
63
def stream_with_fallback(messages: list[Message], system: str = "", preferred: str | None = None) -> tuple[str, str]:
    """Stream a reply, falling back through the configured models on failure.

    Each model from ``_model_order`` whose provider is available is tried in
    turn. Messages are compacted to fit the provider's context window (128k
    tokens assumed when the provider does not declare one), reserving room for
    the system prompt plus 4096 tokens. Streaming is used when the provider
    supports it and plain mode is off; otherwise the complete reply is fetched
    and printed.

    Model names and error text are escaped before being embedded in Rich
    markup. Otherwise an error message containing e.g. ``[/x]`` would raise
    ``MarkupError`` inside the handler and abort the fallback loop.

    Returns:
        ``(content, model_name)`` for the first model that succeeds.

    Raises:
        OperationCancelled: when the operation is cancelled.
        AllModelsFailedError: when no model produced a reply. It carries the
            last recorded error message, which is empty only if no provider
            was usable at all.
    """
    from rich.markup import escape

    config = load_config()
    order = _model_order(config, preferred)

    last_error = None
    for model_name in order:
        raise_if_cancelled()
        provider = get_provider(model_name)
        if not provider or not provider.is_available():
            continue
        label = escape(model_name)
        try:
            # Imported lazily (as in the original code), presumably to avoid
            # an import cycle.
            from .context import compact_messages, estimate_text_tokens

            prepared_messages = compact_messages(
                messages,
                getattr(provider, "context_window", 128_000),
                reserve_tokens=max(4096, estimate_text_tokens(system) + 4096),
            )
            if provider.supports_streaming and not plain_enabled():
                content = provider.stream_chat(prepared_messages, system=system)
            else:
                response = provider.chat(prepared_messages, system=system)
                content = response.content
                console.print(content, markup=False)
            raise_if_cancelled()
            return content, model_name
        except OperationCancelled:
            raise
        except RateLimitError as error:
            last_error = str(error)
            console.print(f"[yellow]⚡ {label} limit hit — switching...[/yellow]")
        except NoAPIKeyError as error:
            # Silent skip, but remember why so the final error is not empty.
            last_error = str(error)
        except NetworkError as error:
            last_error = str(error)
            console.print(f"[yellow]Network issue with {label} — switching...[/yellow]")
        except ProviderError as error:
            last_error = str(error)
            console.print(f"[yellow]Provider error with {label} — switching...[/yellow]")
        except Exception as error:
            normalized = classify_provider_error(provider.name, error)
            last_error = str(normalized)
            console.print(f"[dim]Error with {label}: {escape(last_error)}[/dim]")

    raise AllModelsFailedError(last_error or "")
109
+
110
+
111
def chat_with_fallback(
    messages: list[Message],
    system: str = "",
    preferred: str | None = None,
    tools: list[dict] | None = None,
) -> tuple[Response, str]:
    """Get a complete (non-streamed) reply, falling back through configured models.

    Providers with ``supports_native_tools`` receive *tools* directly. For
    other providers, the legacy tool instructions are appended to the system
    prompt instead. Messages are compacted to the provider's context window
    (128k assumed when undeclared). The reservation covers the system prompt,
    the serialized tool schema and an extra 4096 tokens (presumably headroom
    for the reply), with a minimum of 4096.

    Model names and error text are escaped before being embedded in Rich
    markup, so provider error messages cannot break the status output.

    Returns:
        ``(response, model_name)`` from the first model that succeeds.

    Raises:
        OperationCancelled: when the operation is cancelled.
        AllModelsFailedError: when no model produced a reply. It carries the
            last recorded error message.
    """
    from rich.markup import escape

    config = load_config()
    order = _model_order(config, preferred)

    last_error = None
    for model_name in order:
        raise_if_cancelled()
        provider = get_provider(model_name)
        if not provider or not provider.is_available():
            continue
        label = escape(model_name)
        try:
            from .context import compact_messages, estimate_text_tokens

            provider_system = system
            provider_tools = tools if provider.supports_native_tools else None
            if tools and not provider.supports_native_tools:
                from .tool_schema import legacy_tool_instructions

                provider_system = f"{system}\n\n{legacy_tool_instructions()}"
            tool_schema_tokens = estimate_text_tokens(str(provider_tools or ""))
            prepared_messages = compact_messages(
                messages,
                getattr(provider, "context_window", 128_000),
                reserve_tokens=max(
                    4096,
                    estimate_text_tokens(provider_system) + tool_schema_tokens + 4096,
                ),
            )
            response = provider.chat(
                prepared_messages,
                system=provider_system,
                tools=provider_tools,
            )
            raise_if_cancelled()
            return response, model_name
        except OperationCancelled:
            raise
        except RateLimitError as error:
            console.print(f"[yellow]⚡ {label} limit hit — switching...[/yellow]")
            last_error = str(error)
        except NoAPIKeyError as error:
            # Silent skip, but remember why so the final error is not empty.
            last_error = str(error)
        except NetworkError as error:
            console.print(f"[yellow]Network issue with {label} — switching...[/yellow]")
            last_error = str(error)
        except ProviderError as error:
            console.print(f"[yellow]Provider error with {label} — switching...[/yellow]")
            last_error = str(error)
        except Exception as error:
            normalized = classify_provider_error(provider.name, error)
            last_error = str(normalized)
            console.print(f"[dim]Error with {label}: {escape(last_error)}[/dim]")

    raise AllModelsFailedError(last_error or "")
zai/core/hooks.py ADDED
@@ -0,0 +1,115 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+ from .process import run_direct
5
+ from .storage import atomic_write_json, read_json, update_json
6
+
7
# Per-user zai state directory and the JSON file that persists configured hooks.
ZAI_DIR = Path.home() / ".zai"
HOOKS_FILE = ZAI_DIR / "hooks.json"

# Event names a hook may be registered for. Nothing in this module checks
# against this list (add_hook accepts any string); presumably the CLI
# validates user input against it — verify at the call site.
VALID_EVENTS = [
    "PreToolUse",
    "PostToolUse",
    "SessionStart",
    "SessionEnd",
    "UserPromptSubmit",
]
17
+
18
+
19
def load_hooks() -> list:
    """Return the persisted hook list from HOOKS_FILE.

    ``[]`` is passed to ``read_json`` as its default, presumably used when the
    file is missing or does not hold a list.
    """
    empty: list = []
    return read_json(HOOKS_FILE, empty, expected_type=list)
21
+
22
+
23
def save_hooks(hooks: list):
    """Overwrite HOOKS_FILE with *hooks* using ``atomic_write_json``."""
    destination = HOOKS_FILE
    atomic_write_json(destination, hooks)
25
+
26
+
27
def add_hook(event: str, command: str) -> dict:
    """Append a hook that runs *command* on *event* and return the new entry.

    The new ``id`` is one more than the largest integer id already stored
    (1 for an empty file). Entries that are not dicts or lack an integer
    ``id`` (e.g. from a hand-edited hooks.json) are ignored when choosing the
    id, instead of raising ``KeyError``/``TypeError``. The read-modify-write
    goes through ``update_json``.
    """
    created: dict = {}

    def update(hooks):
        nonlocal created
        existing_ids = [
            hook.get("id")
            for hook in hooks
            if isinstance(hook, dict)
        ]
        new_id = max((i for i in existing_ids if isinstance(i, int)), default=0) + 1
        created = {"id": new_id, "event": event, "command": command}
        return [*hooks, created]

    update_json(HOOKS_FILE, [], update, expected_type=list)
    return created
39
+
40
+
41
def remove_hook(hook_id: int) -> bool:
    """Delete every hook whose ``id`` equals *hook_id*.

    Returns True if at least one entry was removed. Malformed entries
    (non-dicts, or dicts without an ``id``) are kept as-is rather than
    raising ``KeyError``.
    """
    removed = False

    def update(hooks):
        nonlocal removed
        kept = [
            hook
            for hook in hooks
            if not (isinstance(hook, dict) and hook.get("id") == hook_id)
        ]
        removed = len(kept) != len(hooks)
        return kept

    update_json(HOOKS_FILE, [], update, expected_type=list)
    return removed
52
+
53
+
54
def fire(event: str, data: dict | None = None) -> bool:
    """Run every hook registered for *event*; return False if one blocks.

    Each matching hook's command is run through ``run_direct`` with a 10 s
    timeout and without command policy enforcement. *data* (default ``{}``)
    is passed as JSON on stdin. Outcomes per hook:

    * exit code 2: the hook blocks. Its output (if any) is shown and False is
      returned immediately.
    * exit code 126 with a ``blocked_reason`` (e.g. the command string could
      not be parsed by ``run_direct``): also treated as a block.
    * exit code 1 with output: shown as a note, and processing continues.

    Hook output is escaped before being printed as Rich markup. Previously,
    output containing e.g. ``[/x]`` raised inside the try, was swallowed, and
    let a blocking hook fail open. Non-dict entries in the hook file are
    skipped. An unexpected exception from one hook is reported and the
    remaining hooks still run (best-effort).
    """
    from rich.console import Console
    from rich.markup import escape

    payload = json.dumps(data or {})
    console = Console()

    for hook in load_hooks():
        if not isinstance(hook, dict) or hook.get("event") != event:
            continue
        cmd = hook.get("command", "")
        if not cmd:
            continue
        try:
            result = run_direct(
                cmd,
                input_text=payload,
                timeout=10,
                enforce_policy=False,
            )
        except Exception as exc:
            console.print(f"[dim]Hook error: {escape(str(exc))}[/dim]")
            continue

        if result.returncode == 2:
            if result.output:
                console.print(f"[yellow]Hook blocked:[/yellow] {escape(result.output)}")
            return False
        if result.returncode == 126 and result.blocked_reason:
            console.print(f"[yellow]Hook blocked:[/yellow] {escape(result.blocked_reason)}")
            return False
        if result.returncode == 1 and result.output:
            console.print(f"[dim]Hook note: {escape(result.output)}[/dim]")

    return True
93
+
94
+
95
# Built-in security checks for the write_file tool: each key is a literal
# substring to look for, each value the warning reported when it appears.
SECURITY_PATTERNS = {
    "eval(": "eval() executes arbitrary code — security risk",
    "exec(": "exec() can run arbitrary code",
    "os.system(": "os.system() can allow command injection",
    "pickle.loads(": "pickle.loads() with untrusted data is dangerous",
    ".innerHTML =": "innerHTML with untrusted content causes XSS",
    "dangerouslySetInnerHTML": "dangerouslySetInnerHTML causes XSS — sanitize first",
    "document.write(": "document.write() is an XSS risk",
    "new Function(": "new Function() evaluates arbitrary code",
    "child_process.exec(": "child_process.exec() allows shell injection",
}


def check_security(content: str) -> list[str]:
    """Collect the warning for every SECURITY_PATTERNS substring found in *content*.

    Matching is a plain, case-sensitive substring test. Warnings are returned
    in the dict's insertion order, at most one per pattern.
    """
    return [
        message
        for needle, message in SECURITY_PATTERNS.items()
        if needle in content
    ]
zai/core/memory.py ADDED
@@ -0,0 +1,57 @@
1
+ from datetime import datetime
2
+ from pathlib import Path
3
+ from ..config import MEMORY_FILE
4
+ from .storage import atomic_write_json, read_json, update_json
5
+
6
# Baseline shape of the memory file; stored data is merged over it.
# NOTE: DEFAULT_MEMORY.copy() and {**DEFAULT_MEMORY, ...} are shallow, so the
# nested "projects" list and "preferences" dict can end up shared with
# callers. Never mutate those containers through a merged result.
DEFAULT_MEMORY = {"version": 1, "projects": [], "last_session": None, "preferences": {}}
7
+
8
+
9
def load_memory() -> dict:
    """Load the memory file merged over a fresh copy of DEFAULT_MEMORY.

    Deep copies of the defaults are used, both as ``read_json``'s fallback and
    as the merge base. Callers that mutate the returned ``projects`` list or
    ``preferences`` dict therefore cannot alter the module-level
    DEFAULT_MEMORY; a shallow copy would share those containers.
    """
    import copy

    data = read_json(MEMORY_FILE, copy.deepcopy(DEFAULT_MEMORY), expected_type=dict)
    return {**copy.deepcopy(DEFAULT_MEMORY), **data}
12
+
13
+
14
def save_memory(memory: dict):
    """Atomically write *memory* to MEMORY_FILE, filling absent keys from DEFAULT_MEMORY."""
    merged = dict(DEFAULT_MEMORY)
    merged.update(memory)
    atomic_write_json(MEMORY_FILE, merged)
16
+
17
+
18
def save_session(task: str, model: str, files: list[str] | None = None):
    """Record *task*, *model* and the edited *files* as memory's ``last_session``.

    The timestamp is local time formatted ``YYYY-MM-DD HH:MM``. The change is
    applied through ``update_json`` on MEMORY_FILE.
    """
    def _record(current):
        merged = dict(DEFAULT_MEMORY)
        merged.update(current)
        merged["last_session"] = dict(
            task=task,
            model=model,
            files_edited=files or [],
            date=datetime.now().strftime("%Y-%m-%d %H:%M"),
        )
        return merged

    update_json(MEMORY_FILE, DEFAULT_MEMORY.copy(), _record, expected_type=dict)
30
+
31
+
32
def get_last_session() -> dict | None:
    """Return the stored ``last_session`` record, or None when none is recorded."""
    memory = load_memory()
    return memory.get("last_session")
34
+
35
+
36
def add_project(name: str, path: str, summary: str = ""):
    """Register *path* as a project, or refresh it if it is already known.

    If an entry with the same ``path`` exists, it gets today's
    ``last_worked`` date and, when *summary* is non-empty, the new summary;
    its name is left unchanged. Otherwise a new entry is appended.

    The projects list is copied before modification. Previously,
    ``setdefault`` could return the list inside DEFAULT_MEMORY (when the
    memory file did not exist yet), and ``append`` then mutated the
    module-level default. A null ``projects`` value and entries without a
    ``path`` are also tolerated.
    """
    def update(memory):
        memory = {**DEFAULT_MEMORY, **memory}
        today = datetime.now().strftime("%Y-%m-%d")
        projects = list(memory.get("projects") or [])
        memory["projects"] = projects
        for project in projects:
            if isinstance(project, dict) and project.get("path") == path:
                project["last_worked"] = today
                project["summary"] = summary or project.get("summary", "")
                return memory
        projects.append({
            "name": name,
            "path": path,
            "summary": summary,
            "last_worked": today,
        })
        return memory

    update_json(MEMORY_FILE, DEFAULT_MEMORY.copy(), update, expected_type=dict)
54
+
55
+
56
def get_projects() -> list:
    """Return a copy of the stored project list (``[]`` when absent or not a list).

    A copy is returned so callers cannot mutate a list that may be shared
    with DEFAULT_MEMORY through a shallow merge in ``load_memory``.
    """
    projects = load_memory().get("projects")
    return list(projects) if isinstance(projects, list) else []
zai/core/process.py ADDED
@@ -0,0 +1,204 @@
1
+ import os
2
+ import re
3
+ import shlex
4
+ import subprocess
5
+ import time
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Callable
9
+ from .cancellation import CancellationToken, current_token
10
+
11
+
12
# Shell syntax that split_command refuses: &&, ||, single | ; & < > `,
# CR/LF, $( command substitution, and %VAR% expansion.
SHELL_META = re.compile(r"(?:&&|\|\||[|;&<>`]|[\r\n]|\$\(|%\w+%)")
# Executables classify_argv always blocks, since they would reintroduce a shell.
SHELL_INTERPRETERS = {
    "bash", "sh", "zsh", "fish", "cmd", "cmd.exe",
    "powershell", "powershell.exe", "pwsh", "wsl",
}
# Python launcher names (POSIX and Windows spellings) given special rules.
PYTHON_NAMES = {"python", "python.exe", "python3", "python3.exe", "py", "py.exe"}
# Package managers; classify_argv requires approval to run them.
PACKAGE_MANAGERS = {"pip", "pip3", "uv", "poetry", "npm", "pnpm", "yarn"}
# File-deletion commands; classify_argv requires approval to run them.
DELETE_TOOLS = {"rm", "del", "erase", "rmdir", "remove-item"}
20
+
21
+
22
@dataclass
class CommandResult:
    """Outcome of a direct command execution.

    Besides the child's own exit status, run_direct uses these codes:
    124 timed out, 125 cancelled or approval declined, 126 blocked or
    unparsable command, 127 executable not found, 1 unexpected error.
    """

    returncode: int
    stdout: str = ""
    stderr: str = ""
    cancelled: bool = False
    blocked_reason: str = ""

    @property
    def output(self) -> str:
        """stdout followed directly by stderr, with outer whitespace stripped."""
        combined = "".join((self.stdout, self.stderr))
        return combined.strip()
33
+
34
+
35
def split_command(command: str) -> list[str]:
    """Tokenize *command* into argv for direct execution (no shell involved).

    Raises ValueError for an empty command, for anything matching SHELL_META,
    or for unbalanced quoting. On Windows, shlex runs in non-POSIX mode, which
    keeps quote characters, so one enclosing pair of double quotes is then
    stripped from each token.
    """
    text = command.strip()
    if not text:
        raise ValueError("empty command")
    if SHELL_META.search(text):
        raise ValueError("shell syntax is not allowed")

    windows = os.name == "nt"
    try:
        tokens = shlex.split(text, posix=not windows)
    except ValueError as exc:
        raise ValueError(f"invalid command quoting: {exc}") from exc

    if windows:
        def unwrap(token: str) -> str:
            enclosed = len(token) >= 2 and token.startswith('"') and token.endswith('"')
            return token[1:-1] if enclosed else token

        tokens = list(map(unwrap, tokens))

    if not tokens:
        raise ValueError("empty command")
    return tokens
54
+
55
+
56
+ def _exe_name(argv: list[str]) -> str:
57
+ return Path(argv[0]).name.lower()
58
+
59
+
60
def classify_argv(argv: list[str]) -> tuple[str, str]:
    """Classify direct argv as ``("safe" | "approval" | "blocked", reason)``.

    The executable is identified by its lower-cased basename, and arguments
    are compared lower-cased. Anything not explicitly recognised is blocked.
    ``reason`` is empty for "safe".

    ``git branch`` counts as "safe" only when it lists branches. Creating,
    renaming or deleting branches (``-d``/``-D``/``-m``/``-f`` or a positional
    branch name) changes repository state and needs approval. Previously
    every ``git branch`` invocation, including ``git branch -D x``, was
    "safe".
    """
    executable = _exe_name(argv)
    lowered = [part.lower() for part in argv[1:]]

    if executable in SHELL_INTERPRETERS:
        return "blocked", "shell interpreters are not allowed"
    if executable in DELETE_TOOLS:
        return "approval", "file deletion"
    if executable in {"format", "mkfs", "mkfs.ext4", "shutdown", "reboot", "halt"}:
        return "blocked", "destructive system operation"

    if executable == "git":
        if not lowered:
            return "safe", ""
        action, rest = lowered[0], lowered[1:]
        if action == "reset" and "--hard" in lowered:
            return "blocked", "destructive git operation"
        if action == "clean" and any(arg.startswith("-") and "f" in arg for arg in lowered):
            return "blocked", "destructive git operation"
        if action == "branch":
            # Flags that only list/show branches; anything else mutates refs.
            listing_flags = {
                "-a", "--all", "-r", "--remotes", "-v", "-vv", "--verbose",
                "-l", "--list", "--show-current", "--no-color", "--color",
            }
            if all(arg in listing_flags for arg in rest):
                return "safe", ""
            return "approval", "git state change"
        if action in {"status", "diff", "log", "show", "rev-parse", "--version"}:
            return "safe", ""
        return "approval", "git state change"

    if executable in PYTHON_NAMES:
        if lowered in (["--version"], ["-v"]):
            return "safe", ""
        if len(lowered) >= 2 and lowered[0] == "-m":
            module = lowered[1]
            if module in {"pytest", "py_compile", "compileall"}:
                return "safe", ""
            if module in {"pip", "uv"}:
                return "approval", "dependency change"
        if "-c" in lowered:
            return "blocked", "inline interpreter execution"
        return "approval", "Python program execution"

    if executable in {"pytest", "pytest.exe"}:
        return "safe", ""
    if executable in PACKAGE_MANAGERS:
        return "approval", "dependency change"
    if executable in {"curl", "wget", "npx"}:
        return "approval", "network or package execution"

    return "blocked", f"executable is not allowlisted: {executable}"
105
+
106
+
107
def run_direct(
    command: str | list[str],
    *,
    cwd: str | None = None,
    timeout: int = 30,
    input_text: str | None = None,
    approval: Callable[[str], bool] | None = None,
    enforce_policy: bool = True,
    cancellation_token: CancellationToken | None = None,
) -> CommandResult:
    """Run an executable directly with ``shell=False``.

    *command* is either a string (tokenized by ``split_command``) or an argv
    list. With *enforce_policy*, ``classify_argv`` decides whether the
    command runs, needs ``approval(reason)`` to return truthy, or is blocked.
    When a cancellation token is in effect (passed in or ``current_token()``),
    the child is polled so that cancellation can stop it.

    Result codes besides the child's own exit status: 124 timed out,
    125 cancelled or approval declined, 126 blocked/unparsable/empty command,
    127 executable not found, 1 other launch or runtime error.

    Child output is decoded with ``errors="replace"``, so bytes invalid in the
    locale encoding no longer turn a finished command into error code 1.

    Raises:
        KeyboardInterrupt: re-raised (after cancelling the token and
            terminating the child) when Ctrl-C arrives during a cancellable run.
    """
    try:
        argv = split_command(command) if isinstance(command, str) else list(command)
    except ValueError as exc:
        return CommandResult(126, blocked_reason=str(exc))
    if not argv:
        # An empty argv list would otherwise raise IndexError on argv[0].
        return CommandResult(126, blocked_reason="empty command")

    if enforce_policy:
        risk, reason = classify_argv(argv)
        if risk == "blocked":
            return CommandResult(126, blocked_reason=reason)
        if risk == "approval" and (approval is None or not approval(reason)):
            return CommandResult(125, cancelled=True)

    token = cancellation_token or current_token()
    if token and token.cancelled:
        return CommandResult(125, cancelled=True)
    if token is None:
        return _run_blocking(argv, cwd, timeout, input_text)
    return _run_cancellable(argv, cwd, timeout, input_text, token)


def _run_blocking(
    argv: list[str],
    cwd: str | None,
    timeout: int,
    input_text: str | None,
) -> CommandResult:
    """Run *argv* to completion with ``subprocess.run`` (no cancellation polling)."""
    try:
        completed = subprocess.run(
            argv,
            shell=False,
            capture_output=True,
            text=True,
            errors="replace",
            cwd=cwd,
            timeout=timeout,
            input=input_text,
        )
    except subprocess.TimeoutExpired:
        return CommandResult(124, stderr=f"command timed out after {timeout}s")
    except FileNotFoundError:
        return CommandResult(127, stderr=f"executable not found: {argv[0]}")
    except Exception as exc:
        return CommandResult(1, stderr=str(exc))
    return CommandResult(
        completed.returncode,
        completed.stdout or "",
        completed.stderr or "",
    )


def _run_cancellable(
    argv: list[str],
    cwd: str | None,
    timeout: int,
    input_text: str | None,
    token: CancellationToken,
) -> CommandResult:
    """Run *argv*, checking *token* roughly every 0.1 s so it can be cancelled."""
    process = None
    try:
        process = subprocess.Popen(
            argv,
            shell=False,
            stdin=subprocess.PIPE if input_text is not None else None,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            errors="replace",
            cwd=cwd,
        )
        started = time.monotonic()
        # communicate() rejects new input once communication has started. After
        # a poll timeout it keeps feeding the remaining input by itself, so
        # input is only passed on the first call.
        pending_input = input_text
        while True:
            if token and token.cancelled:
                process.terminate()
                try:
                    process.communicate(timeout=2)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.communicate()
                return CommandResult(125, cancelled=True)
            remaining = timeout - (time.monotonic() - started)
            if remaining <= 0:
                process.kill()
                process.communicate()
                return CommandResult(124, stderr=f"command timed out after {timeout}s")
            try:
                stdout, stderr = process.communicate(
                    input=pending_input,
                    timeout=min(0.1, remaining),
                )
            except subprocess.TimeoutExpired:
                pending_input = None
                continue
            return CommandResult(process.returncode, stdout or "", stderr or "")
    except FileNotFoundError:
        return CommandResult(127, stderr=f"executable not found: {argv[0]}")
    except KeyboardInterrupt:
        if token:
            token.cancel("cancelled by user")
        if process and process.poll() is None:
            process.terminate()
        raise
    except Exception as exc:
        # Do not leave a running child behind when something unexpected fails.
        if process is not None and process.poll() is None:
            try:
                process.kill()
            except OSError:
                pass
        return CommandResult(1, stderr=str(exc))