zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68):
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/repomap.py ADDED
@@ -0,0 +1,381 @@
1
+ """Fast, bounded, gitignore-aware repository mapping."""
2
+ from __future__ import annotations
3
+
4
+ import ast
5
+ import hashlib
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+
11
+ from .context import estimate_text_tokens
12
+ from .storage import atomic_write_json, read_json
13
+
14
# Directory names that are never descended into while walking the repository.
SKIP_DIRS = {
    # version control and tool state
    ".git",
    ".zai",
    # caches and virtual environments
    "__pycache__",
    ".pytest_cache",
    ".mypy_cache",
    ".ruff_cache",
    ".venv",
    "venv",
    # third-party code
    "node_modules",
    "vendor",
    # build and report outputs
    "build",
    "dist",
    "target",
    ".next",
    ".nuxt",
    "coverage",
}

# File extensions that are always skipped (bytecode, lockfiles, media, archives, binaries).
SKIP_EXTS = {
    ".pyc", ".pyo", ".lock", ".sum",
    ".png", ".jpg", ".jpeg", ".gif", ".ico", ".svg",
    ".woff", ".woff2", ".ttf", ".eot",
    ".mp4", ".pdf",
    ".zip", ".tar", ".gz", ".7z",
    ".exe", ".dll", ".so", ".dylib",
}

# File extensions eligible for indexing.
CODE_EXTS = {
    # Python / JavaScript / TypeScript
    ".py", ".js", ".mjs", ".cjs", ".ts", ".tsx", ".jsx",
    # compiled languages
    ".go", ".rs", ".java", ".cpp", ".cc", ".c", ".h", ".hpp", ".cs",
    ".swift", ".kt", ".kts",
    # scripting and query languages
    ".rb", ".php", ".sh", ".ps1", ".sql",
    # web front-end
    ".vue", ".svelte", ".html", ".css", ".scss",
    # configuration and documentation
    ".toml", ".yaml", ".yml", ".json", ".md",
}

# Lower-cased file names ranked ahead of everything else in the map.
HIGH_PRIORITY_NAMES = {
    "readme.md",
    "pyproject.toml",
    "package.json",
    "cargo.toml",
    "go.mod",
    "requirements.txt",
    "dockerfile",
    "compose.yml",
    "compose.yaml",
    "makefile",
    "justfile",
}

# Lower-cased file names that are treated as credentials and never indexed.
SECRET_NAMES = {
    ".env",
    ".env.local",
    ".env.production",
    "credentials.json",
    "secrets.json",
    "id_rsa",
    "id_ed25519",
}
# Key / certificate container extensions treated as credentials.
SECRET_SUFFIXES = {".pem", ".key", ".p12", ".pfx"}

DEFAULT_MAX_FILES = 500  # entries rendered into the map
DEFAULT_MAX_SCAN_FILES = 10_000  # files visited by the walk
DEFAULT_MAX_FILE_BYTES = 512 * 1024  # 512 KiB per file
DEFAULT_MAX_REPO_BYTES = 100 * 1024 * 1024  # 100 MiB in total
DEFAULT_PROMPT_TOKENS = 12_000
CACHE_VERSION = 1
47
+
48
+
49
@dataclass
class RepoMapStats:
    """Running counters for a single repository-map build.

    Every field starts at zero / False; the scan and index phases mutate one
    instance in place, and most counters are echoed in the map's footer.
    """

    # Files visited by the directory walk (capped by max_scan_files).
    scanned: int = 0
    # Files whose entry was written into the map.
    indexed: int = 0
    # Directories/files excluded by SKIP_DIRS, ignore patterns or the secret filter.
    ignored: int = 0
    # Files whose extension is skipped or not a recognised code extension.
    unsupported: int = 0
    # Files larger than the per-file byte limit.
    oversized: int = 0
    # Files detected as binary (NUL byte near the start).
    binary: int = 0
    # Symbolic links encountered; they are never followed.
    symlinks: int = 0
    # stat() failures while scanning or indexing.
    errors: int = 0
    # Files whose symbols were reused from the on-disk cache.
    cache_hits: int = 0
    # Files whose symbols had to be extracted again.
    cache_misses: int = 0
    # True when a scan, file-count or byte limit cut the map short.
    truncated: bool = False
    # Total size in bytes of the files considered for indexing.
    bytes_considered: int = 0
63
+
64
+
65
@dataclass
class RepoMapResult:
    """A rendered repository map together with the data behind it."""

    # Human-readable map: header, one line per file plus its symbols, footer.
    text: str
    # Counters gathered while scanning and indexing.
    stats: RepoMapStats
    # POSIX-style relative paths of the indexed files, in rendered order.
    files: list[str]
70
+
71
+
72
+ def _extract_python_symbols(path: str) -> list[str]:
73
+ try:
74
+ tree = ast.parse(Path(path).read_text(encoding="utf-8", errors="ignore"))
75
+ symbols = []
76
+ for node in tree.body:
77
+ if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
78
+ symbols.append(f"def {node.name}()")
79
+ elif isinstance(node, ast.ClassDef):
80
+ methods = [
81
+ f".{item.name}()"
82
+ for item in node.body
83
+ if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
84
+ ]
85
+ suffix = f": {', '.join(methods[:3])}" if methods else ""
86
+ symbols.append(f"class {node.name}{suffix}")
87
+ return symbols[:12]
88
+ except (OSError, SyntaxError, UnicodeError):
89
+ return []
90
+
91
+
92
+ def _extract_regex_symbols(content: str, suffix: str) -> list[str]:
93
+ patterns = {
94
+ ".go": r"^\s*(?:func|type)\s+([A-Za-z_]\w*)",
95
+ ".rs": r"^\s*(?:pub\s+)?(?:fn|struct|enum|trait|impl)\s+([A-Za-z_]\w*)",
96
+ ".java": r"^\s*(?:public|private|protected)?\s*(?:class|interface|enum)\s+([A-Za-z_]\w*)",
97
+ ".cs": r"^\s*(?:public|private|internal|protected)?\s*(?:class|interface|enum|record)\s+([A-Za-z_]\w*)",
98
+ ".rb": r"^\s*(?:def|class|module)\s+([A-Za-z_]\w*[!?=]?)",
99
+ ".php": r"^\s*(?:public\s+|private\s+|protected\s+)?(?:function|class|interface|trait)\s+([A-Za-z_]\w*)",
100
+ ".swift": r"^\s*(?:func|class|struct|enum|protocol)\s+([A-Za-z_]\w*)",
101
+ ".kt": r"^\s*(?:fun|class|object|interface|data\s+class)\s+([A-Za-z_]\w*)",
102
+ ".kts": r"^\s*(?:fun|class|object|interface|data\s+class)\s+([A-Za-z_]\w*)",
103
+ }
104
+ pattern = patterns.get(suffix)
105
+ if not pattern:
106
+ return []
107
+ return [
108
+ match.group(0).strip()[:100]
109
+ for match in re.finditer(pattern, content, re.MULTILINE)
110
+ ][:12]
111
+
112
+
113
+ def _extract_js_symbols(content: str) -> list[str]:
114
+ pattern = re.compile(
115
+ r"^\s*(?:export\s+(?:default\s+)?)?"
116
+ r"(?:async\s+)?(?:function|class|const|let)\s+([A-Za-z_$][\w$]*)",
117
+ re.MULTILINE,
118
+ )
119
+ return [match.group(0).strip()[:100] for match in pattern.finditer(content)][:12]
120
+
121
+
122
def _extract_symbols(path: Path, content: str | None = None) -> list[str]:
    """Pick a symbol extractor based on the (lower-cased) extension of *path*.

    Python files are handed to the AST extractor, which reads *path* itself.
    For other types the text is read from disk unless *content* is given;
    an unreadable file yields ``[]``.
    """
    ext = path.suffix.lower()
    if ext == ".py":
        return _extract_python_symbols(str(path))

    if content is None:
        try:
            content = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            return []

    script_like = {".js", ".mjs", ".cjs", ".ts", ".tsx", ".jsx", ".vue", ".svelte"}
    if ext not in script_like:
        return _extract_regex_symbols(content, ext)
    return _extract_js_symbols(content)
134
+
135
+
136
def _is_secret(path: Path) -> bool:
    """Return True when *path* looks like it holds credentials and must be skipped.

    Matches, case-insensitively:
    * exact names in SECRET_NAMES and any ``.env.*`` variant;
    * key/certificate extensions in SECRET_SUFFIXES;
    * ``secrets``/``credentials`` data files in any indexed config format
      (``.json``, ``.yaml``, ``.yml``, ``.toml``), extending the explicit
      ``secrets.json``/``credentials.json`` entries to their YAML/TOML twins;
    * anything under a ``.ssh``, ``.aws`` or ``.gnupg`` directory.
    """
    name = path.name.lower()
    suffix = path.suffix.lower()
    if name in SECRET_NAMES or name.startswith(".env."):
        return True
    if suffix in SECRET_SUFFIXES:
        return True
    # Restricted to data formats so source modules like ``credentials.py``
    # (code that *handles* credentials) are still indexed.
    if path.stem.lower() in {"secrets", "credentials"} and suffix in {
        ".json", ".yaml", ".yml", ".toml",
    }:
        return True
    return any(part.lower() in {".ssh", ".aws", ".gnupg"} for part in path.parts)
144
+
145
+
146
+ def _looks_binary(path: Path) -> bool:
147
+ try:
148
+ with path.open("rb") as handle:
149
+ chunk = handle.read(4096)
150
+ return b"\0" in chunk
151
+ except OSError:
152
+ return True
153
+
154
+
155
+ def _load_gitignore(root: Path, extra_patterns: list[str] | None = None):
156
+ try:
157
+ import pathspec
158
+ except ImportError:
159
+ return None
160
+ patterns = []
161
+ for filename in (".gitignore", ".ignore"):
162
+ path = root / filename
163
+ if path.is_file():
164
+ try:
165
+ patterns.extend(path.read_text(encoding="utf-8", errors="ignore").splitlines())
166
+ except OSError:
167
+ pass
168
+ patterns.extend(extra_patterns or [])
169
+ return pathspec.GitIgnoreSpec.from_lines(patterns)
170
+
171
+
172
+ def _ignored(spec, relative: str, is_dir: bool = False) -> bool:
173
+ if spec is None:
174
+ return False
175
+ candidate = relative.replace("\\", "/") + ("/" if is_dir else "")
176
+ return spec.match_file(candidate)
177
+
178
+
179
# Directory names whose contents are treated as tests (ranked last).
_TEST_DIR_NAMES = {"test", "tests", "spec", "__tests__"}
# File-stem endings used by common test naming conventions
# (foo_test.go, foo.test.ts, foo_spec.rb, foo.spec.js).
_TEST_STEM_SUFFIXES = ("_test", ".test", "_spec", ".spec")


def _is_test_path(relative: Path) -> bool:
    """Return True if *relative* sits in a test directory at any depth or is named like a test."""
    if any(part.lower() in _TEST_DIR_NAMES for part in relative.parts[:-1]):
        return True
    stem = relative.stem.lower()
    return stem.startswith("test_") or stem.endswith(_TEST_STEM_SUFFIXES)


def _priority(relative: Path) -> tuple:
    """Sort key ranking candidate files for the map (smaller sorts first).

    Groups:
    * 0 = names in HIGH_PRIORITY_NAMES (manifests, readme, build files);
    * 1 = primary-language sources;
    * 2 = other CODE_EXTS files;
    * 3 = tests and everything else.

    Tests are detected by directory at any depth (e.g. Maven's
    ``src/test/java``, nested ``pkg/tests``) and by file name. Ties are broken
    by path depth, then by lower-cased POSIX path.
    """
    suffix = relative.suffix.lower()
    if relative.name.lower() in HIGH_PRIORITY_NAMES:
        group = 0
    elif _is_test_path(relative):
        group = 3
    elif suffix in {".py", ".ts", ".tsx", ".js", ".jsx", ".go", ".rs", ".java", ".cs"}:
        group = 1
    elif suffix in CODE_EXTS:
        group = 2
    else:
        group = 3
    return group, len(relative.parts), relative.as_posix().lower()
194
+
195
+
196
+ def _cache_path(root: Path) -> Path:
197
+ return root / ".zai" / "cache" / "repomap.json"
198
+
199
+
200
def _load_cache(root: Path) -> dict:
    """Load the symbol cache for *root*, or return a fresh empty cache.

    The stored cache is discarded when it is not a mapping, when its
    ``version`` differs from CACHE_VERSION, or when its ``files`` entry is not
    a mapping (e.g. a hand-edited or corrupted file). Callers can therefore
    always treat ``cache["files"]`` as a dict.
    """
    fresh = {"version": CACHE_VERSION, "files": {}}
    data = read_json(_cache_path(root), {}, expected_type=dict, quarantine=True)
    if not isinstance(data, dict):
        return fresh
    if data.get("version") != CACHE_VERSION or not isinstance(data.get("files"), dict):
        return fresh
    return data
205
+
206
+
207
+ def _cache_key(path: Path) -> str:
208
+ stat = path.stat()
209
+ return f"{stat.st_mtime_ns}:{stat.st_size}"
210
+
211
+
212
+ def _content_hash(path: Path) -> str:
213
+ digest = hashlib.sha256()
214
+ with path.open("rb") as handle:
215
+ for chunk in iter(lambda: handle.read(1024 * 1024), b""):
216
+ digest.update(chunk)
217
+ return digest.hexdigest()
218
+
219
+
220
def _scan_candidates(
    root: Path,
    stats: RepoMapStats,
    max_scan_files: int,
    ignore_patterns: list[str] | None = None,
    progress_callback=None,
) -> list[tuple[Path, Path, int]]:
    """Walk *root* and collect indexable files as ``(path, relative, size)`` tuples.

    Directories are visited in sorted order.
    * Symlinked directories, SKIP_DIRS and ignored directories are pruned.
    * Files are dropped if they are symlinks, ignored or secret-looking.
    * Files are also dropped if their extension is in SKIP_EXTS, or is not
      in CODE_EXTS and the name is not in HIGH_PRIORITY_NAMES.

    Every decision is tallied on *stats*. The walk stops after
    *max_scan_files* files, setting ``stats.truncated``.
    *progress_callback*, if given, is called with *stats* every 250 files.
    """
    spec = _load_gitignore(root, ignore_patterns)
    candidates: list[tuple[Path, Path, int]] = []

    def _on_walk_error(_error: OSError) -> None:
        # Without onerror, os.walk silently skips unreadable directories.
        stats.errors += 1

    for current, dirs, files in os.walk(
        root, topdown=True, onerror=_on_walk_error, followlinks=False
    ):
        current_path = Path(current)
        kept_dirs = []
        for dirname in sorted(dirs):
            path = current_path / dirname
            relative = path.relative_to(root)
            if path.is_symlink():
                stats.symlinks += 1
            elif dirname in SKIP_DIRS or _ignored(spec, relative.as_posix(), True):
                stats.ignored += 1
            else:
                kept_dirs.append(dirname)
        # Pruning in place stops os.walk from descending into skipped dirs.
        dirs[:] = kept_dirs

        for filename in sorted(files):
            if stats.scanned >= max_scan_files:
                stats.truncated = True
                return candidates
            stats.scanned += 1
            if progress_callback and stats.scanned % 250 == 0:
                progress_callback(stats)
            path = current_path / filename
            relative = path.relative_to(root)
            if path.is_symlink():
                stats.symlinks += 1
                continue
            if _ignored(spec, relative.as_posix()) or _is_secret(relative):
                stats.ignored += 1
                continue
            suffix = path.suffix.lower()
            # Manifests such as Dockerfile, Makefile, justfile, go.mod or
            # requirements.txt have no CODE_EXTS suffix but are exactly the
            # files _priority ranks first, so let them through.
            is_high_priority = filename.lower() in HIGH_PRIORITY_NAMES
            if suffix in SKIP_EXTS or (suffix not in CODE_EXTS and not is_high_priority):
                stats.unsupported += 1
                continue
            try:
                size = path.stat().st_size
            except OSError:
                stats.errors += 1
                continue
            candidates.append((path, relative, size))
    return candidates
269
+
270
+
271
def _render_repomap_text(
    root: Path,
    entries: list[tuple[str, int, list[str]]],
    stats: RepoMapStats,
) -> str:
    """Format indexed ``(relative, size, symbols)`` entries plus the stats footer."""
    lines = [f"Repo: {root.name}", ""]
    for relative, size, symbols in entries:
        lines.append(f"{relative} ({size // 1024}kb)")
        lines.extend(f" {symbol}" for symbol in symbols)
    lines.extend([
        "",
        f"Indexed: {stats.indexed} | Scanned: {stats.scanned} | "
        f"Cache: {stats.cache_hits} hits/{stats.cache_misses} misses",
        f"Skipped: ignored {stats.ignored}, unsupported {stats.unsupported}, "
        f"oversized {stats.oversized}, binary {stats.binary}, "
        f"symlinks {stats.symlinks}, errors {stats.errors}",
    ])
    if stats.truncated:
        lines.append("Note: repository map was truncated by configured limits.")
    return "\n".join(lines)


def build_repomap_result(
    directory: str = ".",
    max_files: int = DEFAULT_MAX_FILES,
    *,
    max_scan_files: int = DEFAULT_MAX_SCAN_FILES,
    max_file_bytes: int = DEFAULT_MAX_FILE_BYTES,
    max_repo_bytes: int = DEFAULT_MAX_REPO_BYTES,
    use_cache: bool = True,
    ignore_patterns: list[str] | None = None,
    progress_callback=None,
) -> RepoMapResult:
    """Scan *directory* and build a bounded, symbol-annotated repository map.

    Args:
        directory: Repository root; resolved to an absolute path.
        max_files: Maximum number of files rendered into the map.
        max_scan_files: Maximum number of files visited during the walk.
        max_file_bytes: Files larger than this are skipped as oversized.
        max_repo_bytes: Cumulative size budget; indexing stops once exceeded.
        use_cache: Reuse and update ``.zai/cache/repomap.json`` under the root.
        ignore_patterns: Extra gitignore-style patterns applied to the walk.
        progress_callback: Called with the running stats every 250 scanned files.

    Returns:
        A RepoMapResult with the rendered text, the stats and the indexed
        relative paths in rendered order.

    Raises:
        ValueError: if *directory* is not an existing directory.
    """
    root = Path(directory).resolve()
    if not root.is_dir():
        raise ValueError(f"Repository directory not found: {directory}")
    stats = RepoMapStats()
    candidates = _scan_candidates(
        root,
        stats,
        max_scan_files,
        ignore_patterns,
        progress_callback,
    )
    candidates.sort(key=lambda item: _priority(item[1]))
    cache = _load_cache(root) if use_cache else {"version": CACHE_VERSION, "files": {}}
    old_files = cache.get("files", {})
    if not isinstance(old_files, dict):
        old_files = {}
    new_files: dict[str, dict] = {}
    entries: list[tuple[str, int, list[str]]] = []

    for path, relative, size in candidates:
        if len(entries) >= max_files:
            stats.truncated = True
            break
        if size > max_file_bytes:
            stats.oversized += 1
            continue
        if stats.bytes_considered + size > max_repo_bytes:
            stats.truncated = True
            break
        if _looks_binary(path):
            stats.binary += 1
            continue
        stats.bytes_considered += size
        relative_text = relative.as_posix()
        try:
            key = _cache_key(path)
        except OSError:
            stats.errors += 1
            continue

        cached = old_files.get(relative_text)
        # Cache entries come from a JSON file on disk; only trust well-formed ones.
        if (
            isinstance(cached, dict)
            and cached.get("key") == key
            and isinstance(cached.get("symbols", []), list)
        ):
            symbols = cached.get("symbols", [])
            digest = cached.get("sha256")
            stats.cache_hits += 1
        else:
            try:
                digest = _content_hash(path)
            except OSError:
                # The file vanished or became unreadable after it was scanned.
                stats.errors += 1
                continue
            symbols = _extract_symbols(path)
            stats.cache_misses += 1

        new_files[relative_text] = {
            "key": key,
            "sha256": digest,
            "symbols": symbols,
            "size": size,
        }
        entries.append((relative_text, size, symbols))

    stats.indexed = len(entries)
    if use_cache and new_files != old_files:
        try:
            atomic_write_json(_cache_path(root), {
                "version": CACHE_VERSION,
                "root": str(root),
                "files": new_files,
            })
        except OSError:
            # The cache is only an optimisation: a read-only checkout or a
            # permission problem must not prevent building the map.
            pass

    return RepoMapResult(
        _render_repomap_text(root, entries, stats),
        stats,
        [relative_text for relative_text, _size, _symbols in entries],
    )
363
+
364
+
365
def build_repomap(directory: str = ".", max_files: int = DEFAULT_MAX_FILES) -> str:
    """Build the repository map for *directory* and return only its text."""
    result = build_repomap_result(directory, max_files=max_files)
    return result.text
367
+
368
+
369
def get_repomap_prompt(
    directory: str = ".",
    *,
    max_tokens: int = DEFAULT_PROMPT_TOKENS,
) -> str:
    """Wrap the repository map of *directory* in prompt text sized to *max_tokens*.

    The budget left after the fixed prefix/suffix is converted to characters
    at roughly 3 characters per token, with a floor of 512 characters. An
    over-long map is cut at the last complete line within that budget, so no
    path or symbol is emitted half-written, and a truncation marker is
    appended.
    """
    result = build_repomap_result(directory)
    prefix = "Here is the repository structure:\n\n"
    suffix = "\n\nUse this map to understand the codebase when answering questions."
    available_chars = max(512, (max_tokens - estimate_text_tokens(prefix + suffix)) * 3)
    text = result.text
    if len(text) > available_chars:
        cut = text.rfind("\n", 0, available_chars)
        if cut <= 0:
            # A single line longer than the budget: fall back to a hard cut.
            cut = available_chars
        text = text[:cut].rstrip("\n") + "\n\n[Repository map truncated to token budget]"
    return f"{prefix}{text}{suffix}"
zai/core/runtime.py ADDED
@@ -0,0 +1,29 @@
1
+ """Process-wide CLI presentation and diagnostics settings."""
2
+ from __future__ import annotations
3
+
4
+ from rich.console import Console
5
+
6
# Process-wide switches; configure() sets them and the *_enabled() helpers read them.
_debug: bool = False  # print full tracebacks instead of one-line errors
_plain: bool = False  # plain-output presentation mode
8
+
9
+
10
def configure(*, debug: bool = False, plain: bool = False) -> None:
    """Set the process-wide debug and plain-output flags (both default to off)."""
    global _debug, _plain
    _debug, _plain = debug, plain
14
+
15
+
16
def debug_enabled() -> bool:
    """Report the debug flag most recently set by configure()."""
    enabled = _debug
    return enabled
18
+
19
+
20
def plain_enabled() -> bool:
    """Report the plain-output flag most recently set by configure()."""
    enabled = _plain
    return enabled
22
+
23
+
24
def print_exception(console: Console, error: Exception) -> None:
    """Report an unexpected exception on *console*.

    In debug mode a full Rich traceback of *error* is printed (locals hidden).
    Otherwise a one-line error plus a hint to rerun with ``--debug`` is
    printed. The error text is markup-escaped so brackets in messages (paths,
    ``[/x]``, list reprs) print literally. Unescaped, Rich would parse them as
    markup, and an unmatched closing tag would raise inside the error handler.
    """
    if _debug:
        from rich.traceback import Traceback

        # Build from *error* itself rather than sys.exc_info() so this also
        # works when called outside an ``except`` block.
        console.print(
            Traceback.from_exception(
                type(error), error, error.__traceback__, show_locals=False
            )
        )
        return

    from rich.markup import escape

    console.print(f"[red]Error:[/red] Unexpected error: {escape(str(error))}")
    console.print("[dim]Run with --debug for full traceback.[/dim]")
zai/core/security.py ADDED
@@ -0,0 +1,33 @@
1
+ from pathlib import Path
2
+
3
+ from .errors import FileError
4
+ from .process import classify_argv, split_command
5
+
6
+
7
+ def resolve_project_path(cwd: str | Path, user_path: str, *, allow_root: bool = False) -> Path:
8
+ """Resolve a user-controlled path and require it to stay inside the project."""
9
+ if not user_path or "\x00" in user_path:
10
+ raise FileError("Invalid empty path")
11
+
12
+ root = Path(cwd).resolve()
13
+ candidate = Path(user_path).expanduser()
14
+ if not candidate.is_absolute():
15
+ candidate = root / candidate
16
+ candidate = candidate.resolve()
17
+
18
+ try:
19
+ candidate.relative_to(root)
20
+ except ValueError as exc:
21
+ raise FileError(f"Path must stay inside the current project: {user_path}") from exc
22
+
23
+ if candidate == root and not allow_root:
24
+ raise FileError("Operation on the project root is not allowed")
25
+ return candidate
26
+
27
+
28
def classify_command(command: str) -> tuple[str, str]:
    """Classify a command string for execution policy.

    Returns a ``(verdict, reason)`` pair with verdict ``safe``, ``approval`` or
    ``blocked``, as decided by ``classify_argv`` on the tokenised argv. Any
    ValueError from tokenising or classifying yields a ``blocked`` verdict
    with the error text as the reason.
    """
    try:
        argv = split_command(command)
        return classify_argv(argv)
    except ValueError as exc:
        return "blocked", str(exc)