zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/cancellation.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Cooperative cancellation shared by providers, tools, and integrations."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import threading
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from contextvars import ContextVar
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import Iterator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OperationCancelled(RuntimeError):
    """Raised when a cooperative cancellation token has been tripped.

    Derives from ``RuntimeError`` so generic ``except RuntimeError`` handlers
    also catch it.
    """
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CancellationToken:
    """Cancellation flag backed by a ``threading.Event``.

    ``reason`` holds the message used when :class:`OperationCancelled` is
    raised; it defaults to ``"operation cancelled"``.
    """

    _event: threading.Event = field(default_factory=threading.Event)
    reason: str = "operation cancelled"

    def cancel(self, reason: str = "operation cancelled") -> None:
        """Store *reason* and trip the flag (a later call overwrites the reason)."""
        self.reason = reason
        self._event.set()

    @property
    def cancelled(self) -> bool:
        """Whether :meth:`cancel` has been called on this token."""
        return self._event.is_set()

    def raise_if_cancelled(self) -> None:
        """Raise :class:`OperationCancelled` carrying ``reason`` once cancelled."""
        if not self.cancelled:
            return
        raise OperationCancelled(self.reason)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Context-local slot holding the token bound by the innermost active
# ``operation()`` block (set on entry, reset on exit). ContextVar values are
# per-context, so threads and asyncio tasks each see their own copy; the
# default of None means "no cancellable operation in progress".
_current: ContextVar[CancellationToken | None] = ContextVar(
    "zai_cancellation_token",
    default=None,
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def current_token() -> CancellationToken | None:
    """Return the cancellation token bound to the current context.

    Returns ``None`` when no :func:`operation` block is active.
    """
    active = _current.get()
    return active
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def raise_if_cancelled() -> None:
    """Raise :class:`OperationCancelled` if the current context's token is cancelled.

    Does nothing when called outside any :func:`operation` block.
    """
    active = current_token()
    if not active:
        return
    active.raise_if_cancelled()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def cancel_current(reason: str = "operation cancelled") -> bool:
    """Cancel the token bound to the current context with *reason*.

    Returns ``True`` when a token was found and cancelled, ``False`` when no
    :func:`operation` block is active.
    """
    active = current_token()
    if active:
        active.cancel(reason)
        return True
    return False
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@contextmanager
def operation(token: CancellationToken | None = None) -> Iterator[CancellationToken]:
    """Bind a cancellation token to the current context for a ``with`` body.

    Uses *token* when supplied, otherwise a fresh :class:`CancellationToken`,
    and yields it. A ``KeyboardInterrupt`` raised in the body cancels the token
    with reason ``"cancelled by user"`` and is re-raised. The previous context
    binding is restored on exit in every case.
    """
    bound = token or CancellationToken()
    previous = _current.set(bound)
    try:
        try:
            yield bound
        except KeyboardInterrupt:
            bound.cancel("cancelled by user")
            raise
    finally:
        _current.reset(previous)
|
zai/core/commands.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Slash commands system — inspired by Claude Code.
|
|
3
|
+
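Commands are .md files with optional YAML-style frontmatter, loaded from the bundled zai/commands/ directory and then from ~/.zai/commands/ (a user file overrides a built-in of the same name).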
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from .process import run_direct
|
|
9
|
+
|
|
10
|
+
# User-defined command files. load_commands() reads these after the built-ins,
# so a user file with the same stem replaces the built-in command.
COMMANDS_DIR = Path.home() / ".zai" / "commands"
# Commands shipped with the package: <package>/commands, two levels up from
# this file (zai/core/commands.py -> zai/commands).
BUILTIN_DIR = Path(__file__).parent.parent / "commands"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _parse_frontmatter(text: str) -> tuple[dict, str]:
|
|
15
|
+
"""Parse YAML frontmatter from markdown file. Returns (meta, body)."""
|
|
16
|
+
if not text.startswith("---"):
|
|
17
|
+
return {}, text
|
|
18
|
+
end = text.find("\n---", 3)
|
|
19
|
+
if end == -1:
|
|
20
|
+
return {}, text
|
|
21
|
+
frontmatter = text[3:end].strip()
|
|
22
|
+
body = text[end + 4:].strip()
|
|
23
|
+
meta = {}
|
|
24
|
+
for line in frontmatter.splitlines():
|
|
25
|
+
if ":" in line:
|
|
26
|
+
k, v = line.split(":", 1)
|
|
27
|
+
meta[k.strip()] = v.strip()
|
|
28
|
+
return meta, body
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _substitute_shell(prompt: str) -> str:
|
|
32
|
+
"""Replace !`command` with command output (like Claude Code does)."""
|
|
33
|
+
pattern = r'!`([^`]+)`'
|
|
34
|
+
def run(m):
|
|
35
|
+
cmd = m.group(1)
|
|
36
|
+
result = run_direct(cmd, timeout=10)
|
|
37
|
+
if result.blocked_reason:
|
|
38
|
+
return f"(blocked command substitution: {result.blocked_reason})"
|
|
39
|
+
if result.cancelled:
|
|
40
|
+
return "(blocked command substitution: approval required)"
|
|
41
|
+
if result.returncode != 0:
|
|
42
|
+
return f"(failed: {cmd}: {result.output})"
|
|
43
|
+
return result.output
|
|
44
|
+
return re.sub(pattern, run, prompt)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def load_commands() -> dict:
    """Load all available slash commands.

    Built-in commands (``BUILTIN_DIR``) are read first and user commands
    (``COMMANDS_DIR``) second, so a user file with the same stem overrides the
    built-in one. Files within a directory are read in sorted order. Files that
    cannot be read are skipped instead of aborting the whole load.

    Returns:
        ``{name: {"meta", "body", "path", "description"}}`` keyed by file stem.
    """
    commands: dict = {}
    for search_dir in (BUILTIN_DIR, COMMANDS_DIR):
        if not search_dir.is_dir():
            continue
        for md_file in sorted(search_dir.glob("*.md")):
            try:
                # utf-8-sig strips a leading byte-order mark so the "---"
                # frontmatter delimiter is still detected.
                text = md_file.read_text(encoding="utf-8-sig", errors="ignore")
            except OSError:
                # e.g. a directory named "x.md", a dangling symlink or a
                # permission error: skip it so one bad entry does not disable
                # every other command.
                continue
            meta, body = _parse_frontmatter(text)
            commands[md_file.stem] = {
                "meta": meta,
                "body": body,
                "path": str(md_file),
                "description": meta.get("description", ""),
            }
    return commands
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_command_prompt(name: str, arguments: str = "") -> str | None:
    """Render slash command *name* into a prompt; ``None`` if it does not exist.

    ``$ARGUMENTS`` in the command body is replaced by *arguments* first, and
    any !`command` substitutions in the result are expanded afterwards. Because
    of that order, !`...` text supplied in *arguments* is expanded as well.
    """
    command = load_commands().get(name)
    if command is None:
        return None
    with_arguments = command["body"].replace("$ARGUMENTS", arguments)
    return _substitute_shell(with_arguments)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def list_commands() -> list[dict]:
    """Return ``[{"name": ..., "description": ...}, ...]`` for every command.

    Entries follow the order of ``load_commands()``.
    """
    return [
        dict(name=key, description=entry["description"])
        for key, entry in load_commands().items()
    ]
|
zai/core/context.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""Shared, model-aware conversation context management."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import json
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import replace
|
|
7
|
+
|
|
8
|
+
from ..providers.base import Message
|
|
9
|
+
|
|
10
|
+
# Default context window (estimated tokens) for ContextManager; replaced by the
# model's configured "context_window" when set_model() finds one.
DEFAULT_CONTEXT_WINDOW = 128_000
# Tokens subtracted from the context window before fitting the input in
# compact_messages() (presumably head-room for the model's reply).
DEFAULT_OUTPUT_RESERVE = 8_192
# Character cap applied to each tool result by bound_tool_result().
MAX_TOOL_RESULT_CHARS = 16_000
# Character budget for the bullet lines of a compaction summary (the header
# lines are not counted).
SUMMARY_MAX_CHARS = 12_000
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def estimate_text_tokens(text: str) -> int:
    """Return a conservative, provider-neutral token estimate for *text*.

    Takes the larger of a per-character guess (about 3 characters per token,
    rounded up) and a per-word guess (1.35 tokens per whitespace-separated
    word). Empty text costs 0; any other text costs at least 1.
    """
    if not text:
        return 0
    by_chars = (len(text) + 2) // 3
    by_words = int(len(re.findall(r"\S+", text)) * 1.35)
    return max(1, by_chars, by_words)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def estimate_message_tokens(message: Message) -> int:
    """Estimate the token cost of *message*, including any tool calls.

    Sums a flat 8-token per-message overhead, the estimate for the text
    content, and the estimate for a JSON rendering of the tool calls
    (``id``, ``name``, ``arguments`` of each call).
    """
    calls = message.tool_calls
    serialized_calls = (
        json.dumps(
            [{"id": c.id, "name": c.name, "arguments": c.arguments} for c in calls],
            ensure_ascii=False,
        )
        if calls
        else ""
    )
    per_message_overhead = 8
    return (
        per_message_overhead
        + estimate_text_tokens(message.content)
        + estimate_text_tokens(serialized_calls)
    )
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def bound_tool_result(content: str, max_chars: int = MAX_TOOL_RESULT_CHARS) -> str:
    """Shorten *content* to about *max_chars* characters, keeping both ends.

    A non-positive limit yields ``""``, and content within the limit is returned
    unchanged. Otherwise the first two thirds of the budget (rounded down) come
    from the start and the rest from the end, joined by a marker that states
    how many characters were dropped. The marker is extra, so the result is
    somewhat longer than *max_chars*.
    """
    if max_chars <= 0:
        return ""
    overflow = len(content) - max_chars
    if overflow <= 0:
        return content
    keep_head = max_chars * 2 // 3
    keep_tail = max_chars - keep_head
    marker = f"\n\n[... {overflow} characters omitted from tool result ...]\n\n"
    return f"{content[:keep_head]}{marker}{content[-keep_tail:]}"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _fit_message(message: Message, token_budget: int) -> Message | None:
    """Return *message*, content-truncated if needed, so it fits *token_budget*.

    A budget of 8 tokens or fewer cannot hold a message with any content (8 is
    the per-message overhead in ``estimate_message_tokens``), so ``None`` is
    returned. A message that already fits is returned unchanged. Otherwise the
    content is shortened with ``bound_tool_result`` to the longest character
    limit, found by binary search, whose result fits.

    Returns ``None`` when no non-empty truncation fits. That includes messages
    with empty content that are still over budget (e.g. because of a large
    tool-call payload): trimming content cannot shrink them, and returning them
    would break the caller's budget.
    """
    if token_budget <= 8:
        return None
    if estimate_message_tokens(message) <= token_budget:
        return message
    # Search character limits 1..len(content). Limit 0 would only empty the
    # message, which is not useful. For empty content the range is empty and
    # the loop is skipped, so None is returned instead of an oversized message.
    low, high = 1, len(message.content)
    best: str | None = None
    while low <= high:
        middle = (low + high) // 2
        candidate_content = bound_tool_result(message.content, middle)
        candidate = replace(message, content=candidate_content)
        if estimate_message_tokens(candidate) <= token_budget:
            best = candidate_content
            low = middle + 1
        else:
            high = middle - 1
    if best is None:
        return None
    return replace(message, content=best)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _context_units(messages: list[Message]) -> list[list[Message]]:
|
|
76
|
+
"""Keep assistant tool calls and their tool results indivisible."""
|
|
77
|
+
units: list[list[Message]] = []
|
|
78
|
+
index = 0
|
|
79
|
+
while index < len(messages):
|
|
80
|
+
message = messages[index]
|
|
81
|
+
unit = [message]
|
|
82
|
+
index += 1
|
|
83
|
+
if message.role == "assistant" and message.tool_calls:
|
|
84
|
+
call_ids = {call.id for call in message.tool_calls}
|
|
85
|
+
while (
|
|
86
|
+
index < len(messages)
|
|
87
|
+
and messages[index].role == "tool"
|
|
88
|
+
and messages[index].tool_call_id in call_ids
|
|
89
|
+
):
|
|
90
|
+
unit.append(messages[index])
|
|
91
|
+
index += 1
|
|
92
|
+
units.append(unit)
|
|
93
|
+
return units
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _unit_tokens(unit: list[Message]) -> int:
    """Return the summed token estimate of every message in *unit*."""
    return sum(map(estimate_message_tokens, unit))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _summary(messages: list[Message]) -> Message:
    """Condense *messages* into one pinned ``user`` message.

    Each message becomes a ``- LABEL: excerpt`` bullet. Whitespace in the
    content is collapsed and the excerpt is capped at 800 characters. A message
    with no text but with tool calls is described by its tool names, and
    messages with neither are skipped. Bullets are added in the given order
    until the next one would exceed ``SUMMARY_MAX_CHARS``; the two header lines
    and newlines are not counted against that budget.
    """
    budget = SUMMARY_MAX_CHARS
    lines = [
        "[Compacted earlier context]",
        "The following facts were retained from older conversation turns:",
    ]
    for message in messages:
        text = " ".join(message.content.split())
        if not text and message.tool_calls:
            text = "Requested tools: " + ", ".join(
                call.name for call in message.tool_calls
            )
        if not text:
            continue
        label = (
            f"TOOL {message.tool_name or 'result'}"
            if message.role == "tool"
            else message.role.upper()
        )
        bullet = f"- {label}: {text[:800]}"
        if len(bullet) > budget:
            break
        lines.append(bullet)
        budget -= len(bullet)
    return Message(role="user", content="\n".join(lines), pinned=True)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def compact_messages(
    messages: list[Message],
    max_tokens: int,
    *,
    reserve_tokens: int = DEFAULT_OUTPUT_RESERVE,
) -> list[Message]:
    """Fit messages to a context budget while preserving pinned and recent units.

    Args:
        messages: Conversation in chronological order. It is not mutated; a new
            list is returned.
        max_tokens: Context window size in estimated tokens.
        reserve_tokens: Tokens subtracted from ``max_tokens`` before fitting.

    Returns:
        Messages whose estimated total is within the input budget
        ``max(128, max_tokens - reserve_tokens)``. The steps are:

        1. Each tool result is capped with ``bound_tool_result``. If everything
           then fits, that list is returned as-is.
        2. Messages are grouped into units (an assistant tool call plus its
           results). Units containing a pinned message are always kept and are
           placed first. The newest regular unit is always kept. Older regular
           units are added newest-first while they fit; a unit that does not
           fit is skipped, so gaps are possible.
        3. Omitted messages are condensed by ``_summary`` into a pinned message,
           placed after the pinned units if it fits.
        4. A final pass truncates the first message that overflows the budget
           and drops every message after it.
    """
    # NOTE(review): with the default 8_192 reserve, a context window of about
    # 8k tokens or less collapses the input budget to the 128-token floor.
    # Confirm that is intended for small-context models.
    input_budget = max(128, max_tokens - reserve_tokens)
    # Cap each tool result; other roles keep their content unchanged.
    normalized = [
        replace(
            message,
            content=(
                bound_tool_result(message.content)
                if message.role == "tool"
                else message.content
            ),
        )
        for message in messages
    ]
    if sum(estimate_message_tokens(message) for message in normalized) <= input_budget:
        return normalized

    units = _context_units(normalized)
    pinned_units = [
        unit for unit in units if any(message.pinned for message in unit)
    ]
    regular_units = [
        unit for unit in units if not any(message.pinned for message in unit)
    ]
    # Hold back a quarter of the budget, clamped to 64..512 tokens, for the summary.
    summary_reserve = min(512, max(64, input_budget // 4))
    selection_budget = max(64, input_budget - summary_reserve)
    selected: list[list[Message]] = []
    # Pinned units are always kept, so they are charged up front.
    used = sum(_unit_tokens(unit) for unit in pinned_units)

    # Walk newest -> oldest, keeping every unit that still fits.
    for position, unit in enumerate(reversed(regular_units)):
        cost = _unit_tokens(unit)
        if position == 0:
            # The latest user/tool turn must never disappear entirely.
            selected.append(unit)
            used += cost
            continue
        if used + cost > selection_budget:
            # Skip rather than stop: an older, smaller unit may still fit.
            continue
        selected.append(unit)
        used += cost
    # Restore chronological order.
    selected.reverse()

    # Identity-based membership: each normalized message is a distinct object
    # created by replace(), so id() is unique while the list is alive.
    selected_ids = {id(message) for unit in selected for message in unit}
    omitted = [
        message
        for unit in regular_units
        for message in unit
        if id(message) not in selected_ids
    ]
    result = [message for unit in pinned_units for message in unit]
    if omitted and used < input_budget:
        summary = _summary(omitted)
        if estimate_message_tokens(summary) > summary_reserve:
            # Shrink to roughly summary_reserve tokens (estimate_text_tokens
            # counts about 3 chars per token). The -32 presumably leaves room
            # for the truncation marker.
            summary = replace(
                summary,
                content=bound_tool_result(
                    summary.content,
                    max(128, summary_reserve * 3 - 32),
                ),
            )
        if used + estimate_message_tokens(summary) <= input_budget:
            result.append(summary)
    result.extend(message for unit in selected for message in unit)

    # A single pinned or recent message may itself exceed the model limit.
    final: list[Message] = []
    remaining = input_budget
    for message in result:
        cost = estimate_message_tokens(message)
        if cost <= remaining:
            final.append(message)
            remaining -= cost
            continue
        # First overflowing message: keep a truncated copy if one fits, then
        # stop, which drops every later message.
        # NOTE(review): if pinned units alone overflow, this break drops the
        # latest turn despite the guarantee above. Truncating inside an
        # assistant+tool unit can also leave tool_calls without their tool
        # results. Confirm providers accept that.
        fitted = _fit_message(message, remaining)
        if fitted is not None:
            final.append(fitted)
        break
    return final
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
class ContextManager:
    """Stores a conversation and fits it to a model's context window.

    ``max_tokens`` starts at the given value (default ``DEFAULT_CONTEXT_WINDOW``).
    Whenever a model is set whose config lookup succeeds, it is replaced by that
    model's ``context_window``.
    """

    def __init__(
        self,
        max_tokens: int = DEFAULT_CONTEXT_WINDOW,
        model: str | None = None,
    ):
        self.messages: list[Message] = []
        self.max_tokens = max_tokens
        self.model = model
        if model:
            self.set_model(model)

    def set_model(self, model: str | None) -> None:
        """Switch to *model* and adopt its configured context window if known.

        When the config lookup raises ``KeyError``, the current ``max_tokens``
        is kept.
        """
        self.model = model
        if not model:
            return
        try:
            from ..config import get_model_config

            window = get_model_config(model)["context_window"]
        except KeyError:
            return
        self.max_tokens = window

    def add(
        self,
        role: str,
        content: str,
        *,
        pinned: bool = False,
        **message_fields,
    ) -> Message:
        """Build a Message from the arguments, append it and return it."""
        created = Message(role=role, content=content, pinned=pinned, **message_fields)
        self.messages.append(created)
        return created

    def add_message(self, message: Message, *, pinned: bool | None = None) -> None:
        """Append *message*; if *pinned* is given, append a copy with that flag."""
        if pinned is not None:
            message = replace(message, pinned=pinned)
        self.messages.append(message)

    def replace_messages(self, messages: list[Message]) -> None:
        """Replace the stored history in place, keeping the same list object."""
        self.messages[:] = messages

    def is_near_limit(self, threshold: float = 0.85) -> bool:
        """Whether the estimated history size is at least *threshold* of the window."""
        limit = self.max_tokens * threshold
        return self.token_count() >= limit

    def compress(self, preferred: str | None = None) -> None:
        """Compact the stored history in place with the default output reserve.

        If *preferred* is given, that model is set first.
        """
        if preferred:
            self.set_model(preferred)
        compacted = compact_messages(self.messages, self.max_tokens)
        self.messages[:] = compacted

    def prepared(
        self,
        *,
        model: str | None = None,
        reserve_tokens: int = DEFAULT_OUTPUT_RESERVE,
    ) -> list[Message]:
        """Return a compacted copy of the history; the stored messages are not modified.

        If *model* is given, that model is set first.
        """
        if model:
            self.set_model(model)
        return compact_messages(
            self.messages,
            self.max_tokens,
            reserve_tokens=reserve_tokens,
        )

    def update_limit(self, new_max: int) -> None:
        """Override the context window size, in tokens."""
        self.max_tokens = new_max

    def get_messages(self) -> list[Message]:
        """Return the live message list (not a copy)."""
        return self.messages

    def token_count(self) -> int:
        """Return the estimated token total of the stored history."""
        return sum(map(estimate_message_tokens, self.messages))

    def usage_bar(self) -> str:
        """Render a 30-cell ``[###---] NN%`` gauge of context usage."""
        width = 30
        ratio = self.token_count() / max(1, self.max_tokens)
        filled = min(int(ratio * width), width)
        return "[" + "#" * filled + "-" * (width - filled) + f"] {int(ratio * 100)}%"
|
zai/core/errors.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
from rich.console import Console
|
|
2
|
+
from .runtime import print_exception
|
|
3
|
+
|
|
4
|
+
# Module-wide Rich console used by show_error(), show_warning() and handle().
console = Console()
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ZaiError(Exception):
    """Root of the zai exception hierarchy.

    The ``handle`` decorator reports any ZaiError as a one-line error message.
    """
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class NoAPIKeyError(ZaiError):
    """Raised when no API key is available for *provider*.

    The provider name is kept in ``provider``; the message points to ``zai setup``.
    """

    def __init__(self, provider: str):
        self.provider = provider
        message = f"No API key for {provider}. Run: zai setup"
        super().__init__(message)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ProviderError(ZaiError):
    """A failure reported by, or while talking to, a model provider.

    Subclasses override ``category``, which appears in the message as
    ``"<provider> <category> error: <detail>"``. The raw detail is kept in
    ``detail`` and the provider name in ``provider``.
    """

    category = "provider"

    def __init__(self, provider: str, message: str):
        self.provider = provider
        self.detail = message
        summary = f"{provider} {self.category} error: {message}"
        super().__init__(summary)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AuthenticationError(ProviderError):
    """Provider rejected the credentials (classified from HTTP 401/403 or auth wording)."""

    category = "authentication"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class QuotaError(ProviderError):
    """Provider quota or rate limit exhausted (classified from HTTP 429 or quota wording)."""

    category = "quota"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class ModelNotFoundError(ProviderError):
    """Requested model is unknown to the provider (classified from HTTP 404 or wording)."""

    category = "model-not-found"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class ProviderTimeoutError(ProviderError):
    """A provider request timed out."""

    category = "timeout"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class MalformedResponseError(ProviderError):
    """Provider response had an unexpected shape (mapped from KeyError/TypeError/ValueError)."""

    category = "malformed-response"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class AllModelsFailedError(ZaiError):
    """Raised when all models failed; *last_error*, if given, is appended to the message."""

    def __init__(self, last_error: str = ""):
        base = "All models failed. Check API keys with `zai setup` or verify Ollama."
        suffix = f" Last error: {last_error}" if last_error else ""
        super().__init__(base + suffix)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class RateLimitError(QuotaError):
    """Quota error whose detail is always ``"rate limit hit"``."""

    def __init__(self, provider: str):
        super().__init__(provider, message="rate limit hit")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class NetworkError(ZaiError):
    """Connectivity failure; an empty *msg* falls back to a generic hint."""

    def __init__(self, msg: str = ""):
        detail = msg if msg else "Check your internet connection."
        super().__init__(f"Network error: {detail}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def classify_provider_error(provider: str, error: Exception) -> ZaiError:
    """Convert SDK/HTTP failures into stable user-facing categories.

    An existing ``ZaiError`` is returned unchanged. Otherwise the HTTP status
    and the lower-cased message text are checked, in order, for:

    1. authentication
    2. quota / rate limiting
    3. unknown model
    4. timeout
    5. network problems

    The status is read from ``error.response.status_code``, falling back to
    ``error.status_code``. Built-in ``TimeoutError`` and ``ConnectionError``
    (including subclasses) are also recognised by type, since their messages
    often lack the keywords. Any remaining ``KeyError``/``TypeError``/
    ``ValueError`` becomes ``MalformedResponseError``; anything else becomes a
    generic ``ProviderError``.

    Args:
        provider: Provider name used in the resulting message.
        error: Exception raised by the provider SDK or HTTP client.

    Returns:
        A ``ZaiError`` instance. It is returned, not raised.
    """
    if isinstance(error, ZaiError):
        return error
    status = getattr(getattr(error, "response", None), "status_code", None)
    if status is None:
        # Several SDK exception types expose the status on the error itself.
        status = getattr(error, "status_code", None)
    if not isinstance(status, int):
        # Ignore non-integer values; they could even be unhashable, which
        # would make the set-membership test below raise.
        status = None
    # Exceptions such as TimeoutError() often stringify to "", which yields a
    # useless "... error: " message; fall back to the exception type name.
    text = str(error) or type(error).__name__
    lowered = text.lower()
    if status in {401, 403} or any(
        marker in lowered for marker in ("api key", "unauthorized", "authentication")
    ):
        return AuthenticationError(provider, text)
    if status == 429 or any(
        marker in lowered for marker in ("quota", "rate limit", "resource exhausted")
    ):
        return QuotaError(provider, text)
    if status == 404 or any(
        marker in lowered for marker in ("model not found", "unknown model", "does not exist")
    ):
        return ModelNotFoundError(provider, text)
    if isinstance(error, TimeoutError) or any(
        marker in lowered for marker in ("timeout", "timed out", "deadline exceeded")
    ):
        return ProviderTimeoutError(provider, text)
    if isinstance(error, ConnectionError) or any(
        marker in lowered for marker in ("connect", "network", "dns", "connection")
    ):
        return NetworkError(f"{provider}: {text}")
    if isinstance(error, (KeyError, TypeError, ValueError)):
        return MalformedResponseError(provider, text)
    return ProviderError(provider, text)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class FileError(ZaiError):
    """``ZaiError`` subclass for file errors; ``handle`` prints it via ``show_error``."""
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def show_error(msg: str):
    """Print *msg* on the module console after a red ``Error:`` label."""
    line = f"[red]Error:[/red] {msg}"
    console.print(line)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def show_warning(msg: str):
    """Print *msg* on the module console after a yellow ``Warning:`` label."""
    line = f"[yellow]Warning:[/yellow] {msg}"
    console.print(line)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def handle(func):
    """Decorator for CLI commands that reports errors instead of propagating them.

    * Any ``ZaiError`` (including ``NoAPIKeyError``, ``AllModelsFailedError``,
      ``NetworkError`` and ``FileError``) is printed with ``show_error``.
    * ``KeyboardInterrupt`` prints a dim "Cancelled." line.
    * Any other ``Exception`` is rendered with ``print_exception``.

    In those cases the wrapper returns ``None``; otherwise it returns whatever
    *func* returns. ``functools.wraps`` copies ``__name__``, ``__doc__``,
    ``__qualname__``, ``__module__`` and sets ``__wrapped__``, so introspection
    such as ``inspect.signature`` sees the original function's signature.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ZaiError as e:
            # Every ZaiError subclass is reported the same way.
            show_error(str(e))
        except KeyboardInterrupt:
            console.print("\n[dim]Cancelled.[/dim]")
        except Exception as e:
            print_exception(console, e)

    return wrapper
|