zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,67 @@
1
+ """Cooperative cancellation shared by providers, tools, and integrations."""
2
+ from __future__ import annotations
3
+
4
+ import threading
5
+ from contextlib import contextmanager
6
+ from contextvars import ContextVar
7
+ from dataclasses import dataclass, field
8
+ from typing import Iterator
9
+
10
+
11
class OperationCancelled(RuntimeError):
    """Raised by ``CancellationToken.raise_if_cancelled`` once the token is tripped."""
13
+
14
+
15
@dataclass
class CancellationToken:
    """Cancellation flag backed by a ``threading.Event``.

    ``cancel`` records a human-readable reason and sets the event; work in
    progress polls ``cancelled`` or calls ``raise_if_cancelled`` to stop.
    """

    # Event whose "set" state means the token has been cancelled.
    _event: threading.Event = field(default_factory=threading.Event)
    # Message used for OperationCancelled once the token is tripped.
    reason: str = "operation cancelled"

    def cancel(self, reason: str = "operation cancelled") -> None:
        """Mark the token cancelled and remember *reason*."""
        self.reason = reason
        self._event.set()

    @property
    def cancelled(self) -> bool:
        """True once :meth:`cancel` has been called."""
        event = self._event
        return event.is_set()

    def raise_if_cancelled(self) -> None:
        """Raise :class:`OperationCancelled` carrying ``reason`` if cancelled."""
        if not self.cancelled:
            return
        raise OperationCancelled(self.reason)
31
+
32
+
33
+ _current: ContextVar[CancellationToken | None] = ContextVar(
34
+ "zai_cancellation_token",
35
+ default=None,
36
+ )
37
+
38
+
39
def current_token() -> CancellationToken | None:
    """Return the token bound to the current context by ``operation()``, or ``None``."""
    token = _current.get()
    return token
41
+
42
+
43
def raise_if_cancelled() -> None:
    """Raise ``OperationCancelled`` if the current context's token was cancelled.

    Does nothing when no token is bound (no active ``operation()``).
    """
    active = current_token()
    if not active:
        return
    active.raise_if_cancelled()
47
+
48
+
49
def cancel_current(reason: str = "operation cancelled") -> bool:
    """Cancel the token bound to the current context with *reason*.

    Returns:
        True if a token was bound and has now been cancelled, False if no
        token was bound.
    """
    active = current_token()
    if active:
        active.cancel(reason)
        return True
    return False
55
+
56
+
57
@contextmanager
def operation(token: CancellationToken | None = None) -> Iterator[CancellationToken]:
    """Bind a cancellation token to the current context for a ``with`` body.

    Uses *token* when one is supplied, otherwise creates a fresh
    ``CancellationToken``. A ``KeyboardInterrupt`` escaping the body cancels
    the token with reason "cancelled by user" and is re-raised. The previous
    binding is restored on exit in every case.

    Yields:
        The token that is active for the duration of the block.
    """
    if token:
        active = token
    else:
        active = CancellationToken()
    previous = _current.set(active)
    try:
        yield active
    except KeyboardInterrupt:
        active.cancel("cancelled by user")
        raise
    finally:
        _current.reset(previous)
zai/core/commands.py ADDED
@@ -0,0 +1,85 @@
1
+ """
2
+ Slash commands system — inspired by Claude Code.
3
+ Commands are .md files with YAML frontmatter in ~/.zai/commands/
4
+ """
5
+ import os
6
+ import re
7
+ from pathlib import Path
8
+ from .process import run_direct
9
+
10
# User-defined slash commands: ~/.zai/commands/*.md
COMMANDS_DIR = Path.home().joinpath(".zai", "commands")
# Commands shipped inside the package (two levels up from this module, then "commands").
BUILTIN_DIR = Path(__file__).parent.parent.joinpath("commands")
12
+
13
+
14
+ def _parse_frontmatter(text: str) -> tuple[dict, str]:
15
+ """Parse YAML frontmatter from markdown file. Returns (meta, body)."""
16
+ if not text.startswith("---"):
17
+ return {}, text
18
+ end = text.find("\n---", 3)
19
+ if end == -1:
20
+ return {}, text
21
+ frontmatter = text[3:end].strip()
22
+ body = text[end + 4:].strip()
23
+ meta = {}
24
+ for line in frontmatter.splitlines():
25
+ if ":" in line:
26
+ k, v = line.split(":", 1)
27
+ meta[k.strip()] = v.strip()
28
+ return meta, body
29
+
30
+
31
def _substitute_shell(prompt: str) -> str:
    """Replace every !`command` in *prompt* with that command's output.

    Each command goes through ``run_direct`` with a 10 second timeout. A
    blocked, cancelled, or non-zero-exit command is replaced by a
    parenthesised note instead of its output.
    """

    def _expand(match: re.Match) -> str:
        command = match.group(1)
        outcome = run_direct(command, timeout=10)
        if outcome.blocked_reason:
            return f"(blocked command substitution: {outcome.blocked_reason})"
        if outcome.cancelled:
            return "(blocked command substitution: approval required)"
        if outcome.returncode == 0:
            return outcome.output
        return f"(failed: {command}: {outcome.output})"

    return re.sub(r'!`([^`]+)`', _expand, prompt)
45
+
46
+
47
def load_commands() -> dict:
    """Load every available slash command.

    Built-in commands (``BUILTIN_DIR``) are read first, then user commands
    (``COMMANDS_DIR``), so a user file with the same stem overrides the
    built-in one. A file that cannot be read is skipped so that it does not
    take every other command down with it. Examples are a permission error or
    a directory matching ``*.md``.

    Returns:
        ``{name: {"meta": dict, "body": str, "path": str, "description": str}}``
        keyed by the markdown file's stem.
    """
    commands = {}

    # Later directories win: user commands shadow built-ins of the same name.
    for search_dir in (BUILTIN_DIR, COMMANDS_DIR):
        if not search_dir.exists():
            continue
        for md_file in sorted(search_dir.glob("*.md")):
            try:
                text = md_file.read_text(encoding="utf-8", errors="ignore")
            except OSError:
                continue
            meta, body = _parse_frontmatter(text)
            commands[md_file.stem] = {
                "meta": meta,
                "body": body,
                "path": str(md_file),
                "description": meta.get("description", ""),
            }

    return commands
67
+
68
+
69
def get_command_prompt(name: str, arguments: str = "") -> str | None:
    """Render slash command *name* into a prompt.

    In the command body, ``$ARGUMENTS`` is replaced by *arguments*. Each
    !`command` is replaced by its output; any ``$ARGUMENTS`` inside such a
    command is filled in before it runs. Both substitutions are made in a
    single pass over the command template. Text supplied through *arguments*
    is therefore never scanned for !`...` and cannot trigger shell execution.

    Args:
        name: Command name (markdown file stem).
        arguments: Free text typed after the command.

    Returns:
        The rendered prompt, or ``None`` if there is no command called *name*.
    """
    commands = load_commands()
    if name not in commands:
        return None
    body = commands[name]["body"]

    def _render(match: re.Match) -> str:
        snippet = match.group(0)
        if snippet == "$ARGUMENTS":
            # Inserted literally; re.sub does not rescan replacement text.
            return arguments
        # A !`command` from the template: fill in arguments, then run it.
        return _substitute_shell(snippet.replace("$ARGUMENTS", arguments))

    return re.sub(r"!`[^`]+`|\$ARGUMENTS", _render, body)
78
+
79
+
80
def list_commands() -> list[dict]:
    """Return ``[{"name": ..., "description": ...}]`` for every loaded command."""
    listing = []
    for command_name, command in load_commands().items():
        listing.append({"name": command_name, "description": command["description"]})
    return listing
zai/core/context.py ADDED
@@ -0,0 +1,299 @@
1
+ """Shared, model-aware conversation context management."""
2
+ from __future__ import annotations
3
+
4
+ import json
5
+ import re
6
+ from dataclasses import replace
7
+
8
+ from ..providers.base import Message
9
+
10
# Token limit used by ContextManager until a known model supplies its own.
DEFAULT_CONTEXT_WINDOW = 128 * 1000
# Tokens subtracted from the window before fitting input (room left for output).
DEFAULT_OUTPUT_RESERVE = 8 * 1024
# Default character cap applied to each tool result by bound_tool_result.
MAX_TOOL_RESULT_CHARS = 16 * 1000
# Character budget for the bullet lines of a compaction summary.
SUMMARY_MAX_CHARS = 12 * 1000
14
+
15
+
16
def estimate_text_tokens(text: str) -> int:
    """Return a rough, deliberately generous token count for *text*.

    The estimate is the largest of three values:

    - 1;
    - one token per three characters, rounded up;
    - 1.35 tokens per whitespace-separated word, truncated.

    Empty (or falsy) text costs 0.
    """
    if not text:
        return 0
    char_estimate = -(-len(text) // 3)  # ceil(len / 3)
    word_count = sum(1 for _ in re.finditer(r"\S+", text))
    word_estimate = int(word_count * 1.35)
    return max(char_estimate, word_estimate, 1)
22
+
23
+
24
def estimate_message_tokens(message: Message) -> int:
    """Estimate the tokens occupied by *message*.

    The result is a fixed 8-token overhead, plus the estimate for the text
    content, plus the estimate for a JSON rendering of any tool calls (their
    id, name and arguments).
    """
    payload = ""
    if message.tool_calls:
        serialized_calls = [
            {"id": c.id, "name": c.name, "arguments": c.arguments}
            for c in message.tool_calls
        ]
        payload = json.dumps(serialized_calls, ensure_ascii=False)
    content_tokens = estimate_text_tokens(message.content)
    payload_tokens = estimate_text_tokens(payload)
    return 8 + content_tokens + payload_tokens
39
+
40
+
41
def bound_tool_result(content: str, max_chars: int = MAX_TOOL_RESULT_CHARS) -> str:
    """Shorten *content* to roughly *max_chars* characters, keeping both ends.

    Content within the limit is returned unchanged. Otherwise two thirds of
    the allowance come from the start and the rest from the end. Between them
    is a marker stating how many characters were dropped, so the result is
    slightly longer than *max_chars*. A non-positive limit yields ``""``.
    """
    if max_chars <= 0:
        return ""
    overflow = len(content) - max_chars
    if overflow <= 0:
        return content
    keep_head = (2 * max_chars) // 3
    keep_tail = max_chars - keep_head
    marker = f"\n\n[... {overflow} characters omitted from tool result ...]\n\n"
    return "".join((content[:keep_head], marker, content[-keep_tail:]))
54
+
55
+
56
def _fit_message(message: Message, token_budget: int) -> Message | None:
    """Return *message*, trimmed if needed so its estimate fits *token_budget*.

    A message that already fits is returned unchanged. Otherwise the function
    binary-searches for the largest ``bound_tool_result`` character limit
    whose trimmed content fits the budget.

    Returns ``None`` in two cases:

    - the budget is 8 tokens or less;
    - only an empty trimmed content would fit a message that had content.

    A message whose content was already empty comes back as a copy with
    empty content.
    """
    if token_budget <= 8:
        return None
    if estimate_message_tokens(message) <= token_budget:
        return message
    lo, hi = 0, len(message.content)
    fitted_content = ""
    while lo <= hi:
        limit = (lo + hi) // 2
        trial = bound_tool_result(message.content, limit)
        if estimate_message_tokens(replace(message, content=trial)) > token_budget:
            hi = limit - 1
        else:
            fitted_content = trial
            lo = limit + 1
    if fitted_content or not message.content:
        return replace(message, content=fitted_content)
    return None
73
+
74
+
75
+ def _context_units(messages: list[Message]) -> list[list[Message]]:
76
+ """Keep assistant tool calls and their tool results indivisible."""
77
+ units: list[list[Message]] = []
78
+ index = 0
79
+ while index < len(messages):
80
+ message = messages[index]
81
+ unit = [message]
82
+ index += 1
83
+ if message.role == "assistant" and message.tool_calls:
84
+ call_ids = {call.id for call in message.tool_calls}
85
+ while (
86
+ index < len(messages)
87
+ and messages[index].role == "tool"
88
+ and messages[index].tool_call_id in call_ids
89
+ ):
90
+ unit.append(messages[index])
91
+ index += 1
92
+ units.append(unit)
93
+ return units
94
+
95
+
96
def _unit_tokens(unit: list[Message]) -> int:
    """Return the summed token estimate of every message in *unit*."""
    total = 0
    for message in unit:
        total += estimate_message_tokens(message)
    return total
98
+
99
+
100
def _summary(messages: list[Message]) -> Message:
    """Condense *messages* into one pinned ``user`` message of bullet lines.

    Each message becomes ``- ROLE: text``, or ``- TOOL <name>: text`` for tool
    results. Whitespace is collapsed and the text is cut to 800 characters. A
    message with no text but with tool calls lists the requested tool names.
    Messages with nothing to show are skipped. Lines are appended in order
    until the next line would exceed what is left of ``SUMMARY_MAX_CHARS``.
    """
    lines = [
        "[Compacted earlier context]",
        "The following facts were retained from older conversation turns:",
    ]
    budget = SUMMARY_MAX_CHARS
    for message in messages:
        label = (
            f"TOOL {message.tool_name or 'result'}"
            if message.role == "tool"
            else message.role.upper()
        )
        text = " ".join(message.content.split())
        if not text and message.tool_calls:
            requested = ", ".join(c.name for c in message.tool_calls)
            text = "Requested tools: " + requested
        if not text:
            continue
        entry = f"- {label}: {text[:800]}"
        if len(entry) > budget:
            break
        lines.append(entry)
        budget -= len(entry)
    return Message(role="user", content="\n".join(lines), pinned=True)
125
+
126
+
127
def compact_messages(
    messages: list[Message],
    max_tokens: int,
    *,
    reserve_tokens: int = DEFAULT_OUTPUT_RESERVE,
) -> list[Message]:
    """Fit messages to a context budget while preserving pinned and recent units.

    Steps:
    1. Tool results are capped with ``bound_tool_result``. If everything then
       fits in ``max_tokens - reserve_tokens`` (never less than 128), that
       list is returned.
    2. Messages are grouped into units with ``_context_units`` (assistant tool
       calls stay with their results). Units containing any pinned message
       are always kept and are moved to the front of the result.
    3. Regular units are taken newest-first while they fit the selection
       budget, which is the input budget minus a summary reserve. The newest
       unit is always taken.
    4. Regular messages that were not taken are condensed by ``_summary``
       into one pinned message. If it fits, it is placed after the pinned
       units and before the taken units.
    5. Messages are then laid out in order against the input budget. The
       first message that does not fit is trimmed with ``_fit_message`` (or
       dropped), and everything after it is discarded.

    Args:
        messages: Conversation history (presumably oldest first — TODO confirm).
        max_tokens: Context window size in estimated tokens.
        reserve_tokens: Tokens held back from the window, presumably for the
            model's reply.

    Returns:
        A new list; the input list is not modified.
    """
    input_budget = max(128, max_tokens - reserve_tokens)
    # Cap every tool result first; this alone often brings the history in budget.
    normalized = [
        replace(
            message,
            content=(
                bound_tool_result(message.content)
                if message.role == "tool"
                else message.content
            ),
        )
        for message in messages
    ]
    if sum(estimate_message_tokens(message) for message in normalized) <= input_budget:
        return normalized

    units = _context_units(normalized)
    pinned_units = [
        unit for unit in units if any(message.pinned for message in unit)
    ]
    regular_units = [
        unit for unit in units if not any(message.pinned for message in unit)
    ]
    # Room kept aside for the summary of dropped turns: a quarter of the
    # budget, clamped to the range [64, 512] tokens.
    summary_reserve = min(512, max(64, input_budget // 4))
    selection_budget = max(64, input_budget - summary_reserve)
    selected: list[list[Message]] = []
    # Pinned units are always kept, so they are charged up front.
    used = sum(_unit_tokens(unit) for unit in pinned_units)

    for position, unit in enumerate(reversed(regular_units)):
        cost = _unit_tokens(unit)
        if position == 0:
            # The latest user/tool turn must never disappear entirely.
            selected.append(unit)
            used += cost
            continue
        # NOTE(review): `continue` (not `break`) keeps scanning older units, so a
        # small old unit can be kept after a larger newer one was skipped; the
        # kept units may then be non-contiguous, while the summary of the
        # skipped ones is placed before all of them.
        if used + cost > selection_budget:
            continue
        selected.append(unit)
        used += cost
    selected.reverse()

    # Identity (not equality) decides which messages were dropped.
    selected_ids = {id(message) for unit in selected for message in unit}
    omitted = [
        message
        for unit in regular_units
        for message in unit
        if id(message) not in selected_ids
    ]
    result = [message for unit in pinned_units for message in unit]
    if omitted and used < input_budget:
        summary = _summary(omitted)
        if estimate_message_tokens(summary) > summary_reserve:
            # Convert the token reserve to characters (~3 chars/token, the same
            # ratio estimate_text_tokens uses), minus presumably some slack for
            # the omission marker.
            summary = replace(
                summary,
                content=bound_tool_result(
                    summary.content,
                    max(128, summary_reserve * 3 - 32),
                ),
            )
        if used + estimate_message_tokens(summary) <= input_budget:
            result.append(summary)
    result.extend(message for unit in selected for message in unit)

    # A single pinned or recent message may itself exceed the model limit.
    # NOTE(review): the loop stops at the first message that does not fit. If
    # pinned content alone exceeds the budget, the latest turn that the
    # selection loop guarantees is dropped here. It can also separate an
    # assistant tool call from some of its tool results.
    final: list[Message] = []
    remaining = input_budget
    for message in result:
        cost = estimate_message_tokens(message)
        if cost <= remaining:
            final.append(message)
            remaining -= cost
            continue
        fitted = _fit_message(message, remaining)
        if fitted is not None:
            final.append(fitted)
        break
    return final
210
+
211
+
212
class ContextManager:
    """Conversation history plus the token limit used to compact it.

    The limit starts at ``max_tokens`` (``DEFAULT_CONTEXT_WINDOW`` by default).
    Whenever a model whose config is known gets set, the limit becomes that
    model's ``context_window``.
    """

    def __init__(
        self,
        max_tokens: int = DEFAULT_CONTEXT_WINDOW,
        model: str | None = None,
    ):
        self.messages: list[Message] = []
        self.max_tokens = max_tokens
        self.model = model
        if model:
            self.set_model(model)

    def set_model(self, model: str | None) -> None:
        """Record *model* and adopt its configured context window if known.

        When ``get_model_config`` raises ``KeyError`` the current limit is kept.
        """
        self.model = model
        if not model:
            return
        try:
            from ..config import get_model_config

            self.max_tokens = get_model_config(model)["context_window"]
        except KeyError:
            pass

    def add(
        self,
        role: str,
        content: str,
        *,
        pinned: bool = False,
        **message_fields,
    ) -> Message:
        """Build a Message from the arguments, append it, and return it."""
        new_message = Message(role=role, content=content, pinned=pinned, **message_fields)
        self.messages.append(new_message)
        return new_message

    def add_message(self, message: Message, *, pinned: bool | None = None) -> None:
        """Append *message*; a non-None *pinned* overrides its pinned flag on a copy."""
        if pinned is not None:
            message = replace(message, pinned=pinned)
        self.messages.append(message)

    def replace_messages(self, messages: list[Message]) -> None:
        """Swap in *messages* in place, so existing references to the list see them."""
        self.messages[:] = messages

    def is_near_limit(self, threshold: float = 0.85) -> bool:
        """True when the estimated history size reaches *threshold* of the limit."""
        limit = self.max_tokens * threshold
        return self.token_count() >= limit

    def compress(self, preferred: str | None = None) -> None:
        """Compact the stored history in place, optionally switching model first."""
        if preferred:
            self.set_model(preferred)
        compacted = compact_messages(self.messages, self.max_tokens)
        self.messages[:] = compacted

    def prepared(
        self,
        *,
        model: str | None = None,
        reserve_tokens: int = DEFAULT_OUTPUT_RESERVE,
    ) -> list[Message]:
        """Return a compacted copy of the history; ``self.messages`` is not replaced."""
        if model:
            self.set_model(model)
        return compact_messages(
            self.messages,
            self.max_tokens,
            reserve_tokens=reserve_tokens,
        )

    def update_limit(self, new_max: int) -> None:
        """Set the token limit directly."""
        self.max_tokens = new_max

    def get_messages(self) -> list[Message]:
        """Return the live message list (not a copy)."""
        return self.messages

    def token_count(self) -> int:
        """Sum of ``estimate_message_tokens`` over the stored history."""
        total = 0
        for message in self.messages:
            total += estimate_message_tokens(message)
        return total

    def usage_bar(self) -> str:
        """Render usage as a 30-cell ``#``/``-`` bar followed by a percentage."""
        width = 30
        ratio = self.token_count() / max(1, self.max_tokens)
        filled = min(int(ratio * width), width)
        return "[" + "#" * filled + "-" * (width - filled) + f"] {int(ratio * 100)}%"
zai/core/errors.py ADDED
@@ -0,0 +1,125 @@
1
+ from rich.console import Console
2
+ from .runtime import print_exception
3
+
4
# Module-wide Rich console used by show_error, show_warning and handle.
console = Console()
5
+
6
+
7
class ZaiError(Exception):
    """Base class for zai's expected errors; ``handle`` prints their message."""
9
+
10
+
11
class NoAPIKeyError(ZaiError):
    """No API key is configured for *provider*; the message suggests ``zai setup``."""

    def __init__(self, provider: str):
        message = f"No API key for {provider}. Run: zai setup"
        self.provider = provider
        super().__init__(message)
15
+
16
+
17
class ProviderError(ZaiError):
    """A failure reported by a model provider.

    Subclasses override ``category``. The message is formatted as
    ``"<provider> <category> error: <detail>"``, and the raw detail is kept
    in ``self.detail``.
    """

    category = "provider"

    def __init__(self, provider: str, message: str):
        self.provider, self.detail = provider, message
        super().__init__(f"{provider} {self.category} error: {message}")
24
+
25
+
26
class AuthenticationError(ProviderError):
    """Credentials were rejected (classified from HTTP 401/403 or auth keywords)."""

    category = "authentication"
28
+
29
+
30
class QuotaError(ProviderError):
    """Quota exhausted or rate limited (classified from HTTP 429 or quota keywords)."""

    category = "quota"
32
+
33
+
34
class ModelNotFoundError(ProviderError):
    """The requested model is unknown (classified from HTTP 404 or similar text)."""

    category = "model-not-found"
36
+
37
+
38
class ProviderTimeoutError(ProviderError):
    """The provider request timed out."""

    category = "timeout"
40
+
41
+
42
class MalformedResponseError(ProviderError):
    """Unusable provider response; classify_provider_error maps KeyError/TypeError/ValueError here."""

    category = "malformed-response"
44
+
45
+
46
class AllModelsFailedError(ZaiError):
    """Every candidate model failed; optionally appends the last error text."""

    def __init__(self, last_error: str = ""):
        parts = ["All models failed. Check API keys with `zai setup` or verify Ollama."]
        if last_error:
            parts.append(f"Last error: {last_error}")
        super().__init__(" ".join(parts))
52
+
53
+
54
class RateLimitError(QuotaError):
    """The provider throttled requests; reported under the ``quota`` category."""

    def __init__(self, provider: str):
        super().__init__(provider, message="rate limit hit")
57
+
58
+
59
class NetworkError(ZaiError):
    """Connectivity failure; an empty *msg* falls back to a generic hint."""

    def __init__(self, msg: str = ""):
        detail = msg if msg else "Check your internet connection."
        super().__init__(f"Network error: {detail}")
62
+
63
+
64
def classify_provider_error(provider: str, error: Exception) -> ZaiError:
    """Convert SDK/HTTP failures into stable user-facing categories.

    Existing ``ZaiError`` instances pass through unchanged. Otherwise the
    category is chosen from two signals:

    - the HTTP status, taken from ``error.response.status_code`` or, failing
      that, an integer ``error.status_code``;
    - keywords in the error text.

    Categories are checked in this order: authentication, quota, model not
    found, timeout, network, malformed response. Anything else becomes a
    generic ``ProviderError``. Built-in ``TimeoutError`` and
    ``ConnectionError`` are also recognised by type.

    Args:
        provider: Provider name used in the resulting message.
        error: Exception raised by the provider client.

    Returns:
        A ``ZaiError`` instance describing *error*.
    """
    if isinstance(error, ZaiError):
        return error
    status = getattr(getattr(error, "response", None), "status_code", None)
    if status is None:
        # Some SDK exceptions expose the status on the exception itself.
        direct = getattr(error, "status_code", None)
        if isinstance(direct, int):
            status = direct
    # Bare exceptions such as TimeoutError() stringify to ""; fall back to the
    # type name so the user sees something and keyword matching still works.
    text = str(error) or type(error).__name__
    lowered = text.lower()
    if status in {401, 403} or any(
        marker in lowered for marker in ("api key", "unauthorized", "authentication")
    ):
        return AuthenticationError(provider, text)
    if status == 429 or any(
        marker in lowered for marker in ("quota", "rate limit", "resource exhausted")
    ):
        return QuotaError(provider, text)
    if status == 404 or any(
        marker in lowered for marker in ("model not found", "unknown model", "does not exist")
    ):
        return ModelNotFoundError(provider, text)
    if isinstance(error, TimeoutError) or any(
        marker in lowered for marker in ("timeout", "timed out", "deadline exceeded")
    ):
        return ProviderTimeoutError(provider, text)
    if isinstance(error, ConnectionError) or any(
        marker in lowered for marker in ("connect", "network", "dns", "connection")
    ):
        return NetworkError(f"{provider}: {text}")
    if isinstance(error, (KeyError, TypeError, ValueError)):
        return MalformedResponseError(provider, text)
    return ProviderError(provider, text)
90
+
91
+
92
class FileError(ZaiError):
    """A file-related failure reported to the user."""
94
+
95
+
96
def show_error(msg: str):
    """Print *msg* (interpreted as Rich markup) after a red ``Error:`` prefix."""
    prefix = "[red]Error:[/red]"
    console.print(f"{prefix} {msg}")
98
+
99
+
100
def show_warning(msg: str):
    """Print *msg* (interpreted as Rich markup) after a yellow ``Warning:`` prefix."""
    prefix = "[yellow]Warning:[/yellow]"
    console.print(f"{prefix} {msg}")
102
+
103
+
104
def handle(func):
    """Decorator — catches ZaiError and unexpected errors for CLI commands.

    The wrapper returns *func*'s result on success. On failure it returns
    ``None`` after reporting:

    - ``ZaiError`` (every subclass) is printed as a one-line error message;
    - ``KeyboardInterrupt`` prints "Cancelled.";
    - any other exception is passed to ``print_exception``.

    ``functools.wraps`` preserves *func*'s name, docstring, qualname, module,
    annotations and ``__wrapped__``, so ``inspect.signature`` reports the real
    parameters.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ZaiError as e:
            # Error text can come straight from providers/SDKs; escape it so
            # brackets are not parsed as Rich markup (which may raise
            # MarkupError or silently hide text).
            from rich.markup import escape

            show_error(escape(str(e)))
        except KeyboardInterrupt:
            console.print("\n[dim]Cancelled.[/dim]")
        except Exception as e:
            print_exception(console, e)

    return wrapper