zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/streaming.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from rich.console import Console
|
|
2
|
+
from rich.live import Live
|
|
3
|
+
from rich.text import Text
|
|
4
|
+
|
|
5
|
+
# Module-wide Rich console shared by every stream_* renderer in this file.
console = Console()
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def stream_groq(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2,
):
    """Stream a Groq chat completion to the terminal and return the full text.

    The growing reply is rendered in a Rich ``Live`` region (15 refreshes per
    second) on the module-level ``console``.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``.
        system: System prompt, sent as a leading ``"system"`` message.
        model_id: Groq model identifier.
        api_key: Groq API key.
        timeout: Client request timeout in seconds.
        retries: Maximum automatic retries performed by the Groq client.

    Returns:
        The concatenated ``delta.content`` of every streamed chunk.
    """
    from groq import Groq

    client = Groq(api_key=api_key, timeout=timeout, max_retries=retries)
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    stream = client.chat.completions.create(
        model=model_id,
        messages=formatted,
        stream=True,
    )
    full = ""
    # Entering the stream as a context manager closes the underlying HTTP
    # response even when rendering is interrupted (e.g. Ctrl-C) mid-reply.
    with stream, Live(Text(""), refresh_per_second=15, console=console) as live:
        for chunk in stream:
            # Metadata-only chunks (usage / filter results) may carry no choices.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            full += delta
            live.update(Text(full))
    return full
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def stream_gemini(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60,
):
    """Stream a Gemini reply to the terminal and return the complete text.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``. Only ``"user"`` keeps its role; every other role is
            sent to Gemini as ``"model"``.
        system: System prompt, passed as ``system_instruction``.
        model_id: Gemini model identifier.
        api_key: Google GenAI API key.
        timeout: Request timeout in seconds (multiplied by 1000 and passed to
            ``HttpOptions``).

    Returns:
        The concatenated text of all streamed chunks (chunks without text
        contribute nothing).
    """
    from google import genai
    from google.genai import types

    timeout_ms = int(timeout * 1000)
    client = genai.Client(
        api_key=api_key,
        http_options=types.HttpOptions(timeout=timeout_ms),
    )

    history = []
    for turn in messages:
        role = turn.role if turn.role == "user" else "model"
        history.append(
            types.Content(role=role, parts=[types.Part.from_text(text=turn.content)])
        )

    reply = ""
    with Live(Text(""), refresh_per_second=15, console=console) as live:
        pieces = client.models.generate_content_stream(
            model=model_id,
            contents=history,
            config=types.GenerateContentConfig(system_instruction=system),
        )
        for piece in pieces:
            reply += piece.text or ""
            live.update(Text(reply))
    return reply
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def stream_anthropic(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2, max_tokens: int = 8096,
):
    """Stream an Anthropic Messages reply to the terminal and return the text.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``.
        system: System prompt, passed through the dedicated ``system``
            parameter rather than as a message.
        model_id: Anthropic model identifier.
        api_key: Anthropic API key.
        timeout: Client request timeout in seconds.
        retries: Maximum automatic retries performed by the Anthropic client.
        max_tokens: Upper bound on generated tokens. Defaults to the value this
            function always used, so existing callers are unaffected.
            NOTE(review): 8096 may have been intended as 8192 — confirm.

    Returns:
        The concatenated text deltas of the streamed reply.
    """
    import anthropic

    client = anthropic.Anthropic(
        api_key=api_key, timeout=timeout, max_retries=retries,
    )
    # No system entry here: the system prompt travels in ``system=`` below.
    formatted = [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with client.messages.stream(
        model=model_id,
        max_tokens=max_tokens,
        system=system,
        messages=formatted,
    ) as stream:
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for delta in stream.text_stream:
                full += delta
                live.update(Text(full))
    return full
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def stream_openai(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2,
):
    """Stream an OpenAI chat completion to the terminal and return the text.

    The growing reply is rendered in a Rich ``Live`` region (15 refreshes per
    second) on the module-level ``console``.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``.
        system: System prompt, sent as a leading ``"system"`` message.
        model_id: OpenAI model identifier.
        api_key: OpenAI API key.
        timeout: Client request timeout in seconds.
        retries: Maximum automatic retries performed by the OpenAI client.

    Returns:
        The concatenated ``delta.content`` of every streamed chunk.
    """
    from openai import OpenAI

    client = OpenAI(api_key=api_key, timeout=timeout, max_retries=retries)
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    stream = client.chat.completions.create(
        model=model_id,
        messages=formatted,
        stream=True,
    )
    # Entering the stream as a context manager closes the underlying HTTP
    # response even when rendering is interrupted (e.g. Ctrl-C) mid-reply.
    with stream, Live(Text(""), refresh_per_second=15, console=console) as live:
        for chunk in stream:
            # Metadata-only chunks (usage / content-filter results) may carry
            # an empty ``choices`` list.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            full += delta
            live.update(Text(full))
    return full
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def stream_ollama(
    messages: list, system: str, model_id: str, timeout: float = 120,
    base_url: str = "http://localhost:11434",
):
    """Stream a chat reply from an Ollama server and return the full text.

    Posts to ``{base_url}/api/chat`` with ``"stream": True`` and renders the
    growing reply in a Rich ``Live`` region on the module-level ``console``.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``.
        system: System prompt, sent as a leading ``"system"`` message.
        model_id: Ollama model name.
        timeout: httpx timeout in seconds.
        base_url: Ollama server root. Defaults to the address this function
            always used, so existing callers are unaffected.

    Returns:
        The concatenated ``message.content`` of every streamed JSON line.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    import json

    import httpx

    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with httpx.stream(
        "POST",
        f"{base_url.rstrip('/')}/api/chat",
        json={"model": model_id, "messages": formatted, "stream": True},
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for line in response.iter_lines():
                if not line:
                    continue
                # Best-effort: skip lines that are not valid JSON objects, but
                # let any other error surface instead of silently hiding it.
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    continue
                message = data.get("message")
                delta = message.get("content") if isinstance(message, dict) else None
                if isinstance(delta, str) and delta:
                    full += delta
                    live.update(Text(full))
                # The final line of a streamed reply carries ``"done": true``.
                if data.get("done"):
                    break
    return full
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def stream_openrouter(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60,
):
    """Stream an OpenRouter chat completion (SSE) and return the full text.

    Args:
        messages: Conversation turns; each item must expose ``.role`` and
            ``.content``.
        system: System prompt, sent as a leading ``"system"`` message.
        model_id: OpenRouter model identifier.
        api_key: OpenRouter API key, sent as a Bearer token.
        timeout: httpx timeout in seconds. Defaults to the previously
            hard-coded 60, matching the other streamers' ``timeout`` parameter.

    Returns:
        The concatenated ``choices[0].delta.content`` of every SSE event.

    Raises:
        httpx.HTTPStatusError: If the server responds with an error status.
    """
    import json

    import httpx

    prefix = "data: "
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with httpx.stream(
        "POST",
        "https://openrouter.ai/api/v1/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"model": model_id, "messages": formatted, "stream": True},
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for line in response.iter_lines():
                # Ignore blank separators and SSE comment / keep-alive lines.
                if not line.startswith(prefix):
                    continue
                payload = line[len(prefix):]
                if payload == "[DONE]":
                    break
                # Best-effort: skip malformed events, but let other errors surface.
                try:
                    data = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                choices = data.get("choices") if isinstance(data, dict) else None
                if not choices:
                    continue
                # ``content`` may be JSON null on role-only or final deltas.
                delta = (choices[0].get("delta") or {}).get("content") or ""
                if delta:
                    full += delta
                    live.update(Text(full))
    return full
|
zai/core/tool_schema.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from typing import Any, Literal
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Base for every tool-argument model: unknown keys are rejected (extra="forbid"),
# and subclasses inherit that setting. Config is given as a class keyword.
class StrictArguments(BaseModel, extra="forbid"):
    pass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Arguments for tools that take a single project path (read_file, create_folder),
# and the base of the write/edit argument models.
class PathArguments(StrictArguments):
    # Required, non-empty path string.
    path: str = Field(..., min_length=1)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Arguments for write_file: the inherited required `path` plus the file body.
class WriteFileArguments(PathArguments):
    # Required; may be the empty string (no min_length).
    content: str = Field(...)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Arguments for edit_file: replace the exact `search` text with `replace`
# in the file at the inherited `path`.
class EditFileArguments(PathArguments):
    # Text to find; must be non-empty.
    search: str = Field(..., min_length=1)
    # Replacement text; may be empty (deletes the match).
    replace: str = Field(...)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Arguments for list_files; the directory is optional and defaults to ".".
class ListFilesArguments(StrictArguments):
    path: str = Field(default=".")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Arguments for rename_path: both endpoints are required and non-empty.
class RenamePathArguments(StrictArguments):
    source: str = Field(..., min_length=1)
    destination: str = Field(..., min_length=1)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Arguments for run_command: a single required, non-empty command string.
class RunCommandArguments(StrictArguments):
    command: str = Field(..., min_length=1)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Arguments for mcp_call: which server/tool to invoke, plus a free-form
# argument object forwarded as-is (defaults to an empty dict).
class MCPCallArguments(StrictArguments):
    server: str = Field(..., min_length=1)
    tool: str = Field(..., min_length=1)
    arguments: dict[str, Any] = Field(default_factory=dict)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Arguments for plugin_call: which plugin/tool to invoke, plus a free-form
# argument object forwarded as-is (defaults to an empty dict).
class PluginCallArguments(StrictArguments):
    plugin: str = Field(..., min_length=1)
    tool: str = Field(..., min_length=1)
    arguments: dict[str, Any] = Field(default_factory=dict)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Outer shape of a tool call: a known tool name plus a raw argument object.
# Unknown top-level keys are rejected; the arguments themselves are validated
# separately against the per-tool model in ARGUMENT_MODELS.
class ToolCallEnvelope(BaseModel, extra="forbid"):
    # Keep in sync with ARGUMENT_MODELS keys: validate_tool_call indexes it by name.
    name: Literal[
        "write_file", "read_file", "edit_file", "create_folder", "rename_path",
        "list_files", "run_command", "mcp_call", "plugin_call",
    ]
    arguments: dict[str, Any] = Field(default_factory=dict)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Tool name -> pydantic model that validates that tool's arguments.
# Iteration order here is the order tools are advertised to providers.
ARGUMENT_MODELS: dict[str, type[StrictArguments]] = dict(
    write_file=WriteFileArguments,
    read_file=PathArguments,
    edit_file=EditFileArguments,
    create_folder=PathArguments,
    rename_path=RenamePathArguments,
    list_files=ListFilesArguments,
    run_command=RunCommandArguments,
    mcp_call=MCPCallArguments,
    plugin_call=PluginCallArguments,
)
|
|
76
|
+
|
|
77
|
+
# Human-readable description for each tool; must cover every ARGUMENT_MODELS key
# because get_tool_definitions looks each one up by name.
TOOL_DESCRIPTIONS: dict[str, str] = dict(
    write_file="Create a new text file. Use edit_file for an existing file.",
    read_file="Read a text file inside the current project.",
    edit_file="Replace one exact text occurrence in an existing file.",
    create_folder="Create a directory inside the current project.",
    rename_path="Rename or move a file or directory inside the current project.",
    list_files="List files in a project directory.",
    run_command="Run an approved direct command in the current project.",
    mcp_call="Call a tool exposed by a connected MCP server.",
    plugin_call="Call a tool exposed by an enabled zai plugin.",
)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def get_tool_definitions() -> list[dict[str, Any]]:
    """Build provider-neutral JSON Schema function definitions.

    Returns one ``{"name", "description", "parameters"}`` dict per entry of
    ``ARGUMENT_MODELS``, in that mapping's order. Each ``parameters`` schema
    is the model's JSON Schema with the top-level ``title`` removed,
    ``additionalProperties`` forced to ``False`` and every property listed in
    ``required``.
    """

    def _parameters(model) -> dict[str, Any]:
        # Derive the cleaned-up parameter schema for one argument model.
        schema = model.model_json_schema()
        schema.pop("title", None)
        schema["additionalProperties"] = False
        # Native providers are most reliable when every property is explicit.
        schema["required"] = [*schema.get("properties", {})]
        return schema

    return [
        {
            "name": tool_name,
            "description": TOOL_DESCRIPTIONS[tool_name],
            "parameters": _parameters(arg_model),
        }
        for tool_name, arg_model in ARGUMENT_MODELS.items()
    ]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def legacy_tool_instructions() -> str:
    """Build tool-calling instructions for providers without native function calling.

    The text asks the model to wrap each call in ``<tool_call>`` tags as a JSON
    object. It lists every tool from ``get_tool_definitions`` with its
    parameter names and its description. This gives text-only providers the
    same usage guidance that native providers get through each definition's
    ``description`` field, such as preferring ``edit_file`` for existing files.

    Returns:
        The newline-joined instruction block.
    """
    lines = [
        "Your API does not support native tools here. Emit each call as JSON:",
        '<tool_call>{"name":"tool_name","arguments":{...}}</tool_call>',
        "Valid tools:",
    ]
    for tool in get_tool_definitions():
        properties = ", ".join(tool["parameters"].get("properties", {}))
        lines.append(f"- {tool['name']}({properties}): {tool['description']}")
    lines.append("JSON must use double quotes and must not contain trailing commas.")
    return "\n".join(lines)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def validate_tool_call(data: Any) -> tuple[str, dict[str, Any]]:
    """Validate a structured tool call in two stages.

    First the outer ``{"name", "arguments"}`` envelope is checked. Then the
    arguments are checked against the tool-specific model from
    ``ARGUMENT_MODELS``.

    Returns:
        ``(tool_name, arguments)``, where ``arguments`` is the validated model
        dumped back to a plain dict (defaults filled in).

    Raises:
        pydantic.ValidationError: If either stage fails.
    """
    envelope = ToolCallEnvelope.model_validate(data)
    tool_name = envelope.name
    argument_model = ARGUMENT_MODELS[tool_name]
    validated = argument_model.model_validate(envelope.arguments)
    return tool_name, validated.model_dump()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def format_validation_error(error: ValidationError) -> str:
    """Render a pydantic ``ValidationError`` as one compact line.

    Each issue becomes ``"<dotted.location>: <message>"`` and issues are
    joined with ``"; "``. An issue that applies to the whole input has an
    empty ``loc``, for example when the payload is not an object at all. Such
    an issue is rendered as the bare message rather than with a dangling
    ``": "`` prefix.
    """
    issues = []
    for item in error.errors(include_url=False):
        location = ".".join(str(part) for part in item["loc"])
        message = item["msg"]
        issues.append(f"{location}: {message}" if location else message)
    return "; ".join(issues)
|