zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
import httpx
|
|
2
|
+
from .base import BaseProvider, Message, Response
|
|
3
|
+
from ..config import get_api_key
|
|
4
|
+
from ..core.errors import (
|
|
5
|
+
NoAPIKeyError,
|
|
6
|
+
AuthenticationError,
|
|
7
|
+
RateLimitError,
|
|
8
|
+
NetworkError,
|
|
9
|
+
classify_provider_error,
|
|
10
|
+
)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OpenRouterProvider(BaseProvider):
    """Chat provider for OpenRouter's OpenAI-compatible chat-completions API."""

    name = "openrouter"
    model_id = "meta-llama/llama-3.3-70b-instruct:free"
    context_window = 128000

    def is_available(self) -> bool:
        """Return True when an OpenRouter API key is configured."""
        return bool(get_api_key("openrouter"))

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send *messages* to OpenRouter and return the first completion.

        Args:
            messages: Conversation history; each item's ``role`` and
                ``content`` are forwarded unchanged.
            system: System prompt; a default zai prompt is used when empty.
            tools: Accepted for interface compatibility; this provider does
                not send it to the API.

        Returns:
            A ``Response`` with the first choice's message content, this
            provider's model id, and ``usage.total_tokens`` (0 when absent).

        Raises:
            NoAPIKeyError: no OpenRouter key is configured.
            RateLimitError: the API answered HTTP 429.
            AuthenticationError: the API answered HTTP 401.
            NetworkError: connect/timeout failures persisted through all
                ``self.retries + 1`` attempts.
            Whatever ``classify_provider_error`` returns for any other HTTP
            failure, a non-JSON body, or a payload of unexpected shape.
        """
        key = get_api_key("openrouter")
        if not key:
            raise NoAPIKeyError("openrouter")

        formatted = [
            {"role": "system", "content": system or "You are zai, a helpful AI assistant."}
        ]
        formatted += [{"role": m.role, "content": m.content} for m in messages]

        for attempt in range(self.retries + 1):
            try:
                response = httpx.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers={"Authorization": f"Bearer {key}"},
                    json={"model": self.model_id, "messages": formatted},
                    timeout=self.timeout,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as error:
                # Transport failures are retried; only the final one surfaces.
                if attempt == self.retries:
                    raise NetworkError(str(error)) from error
                continue
            if response.status_code == 429:
                raise RateLimitError("openrouter")
            if response.status_code == 401:
                raise AuthenticationError("openrouter", "invalid API credentials")
            try:
                response.raise_for_status()
                data = response.json()
                content = data["choices"][0]["message"]["content"]
                # "usage" may be missing or null; report 0 tokens in that case.
                tokens_used = (data.get("usage") or {}).get("total_tokens") or 0
            except Exception as error:
                # HTTP errors, non-JSON bodies and payloads lacking the
                # expected choices/message shape are all classified here
                # instead of leaking raw KeyError/IndexError/TypeError.
                raise classify_provider_error("openrouter", error) from error
            return Response(
                content=content,
                model=self.model_id,
                tokens_used=tokens_used,
            )
        # Only reachable when self.retries is negative (the loop never runs).
        raise NetworkError("OpenRouter request failed")
|
zai/providers/qwen.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import httpx
|
|
2
|
+
from .base import BaseProvider, Message, Response
|
|
3
|
+
from ..config import get_api_key
|
|
4
|
+
from ..core.errors import NetworkError, NoAPIKeyError, classify_provider_error
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class QwenProvider(BaseProvider):
    """Chat provider for Qwen via DashScope's OpenAI-compatible endpoint."""

    name = "qwen"
    model_id = "qwen-turbo"
    context_window = 1000000

    def is_available(self) -> bool:
        """Return True when a Qwen API key is configured."""
        return bool(get_api_key("qwen"))

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send *messages* to Qwen and return the first completion.

        Args:
            messages: Conversation history; each item's ``role`` and
                ``content`` are forwarded unchanged.
            system: System prompt; a default zai prompt is used when empty.
            tools: Accepted for interface compatibility; this provider does
                not send it to the API.

        Returns:
            A ``Response`` with the first choice's message content, this
            provider's model id, and ``usage.total_tokens`` (0 when absent).

        Raises:
            NoAPIKeyError: no Qwen key is configured.
            NetworkError: connect/timeout failures persisted through all
                ``self.retries + 1`` attempts.
            Whatever ``classify_provider_error`` returns for HTTP failures,
            a non-JSON body, or a payload of unexpected shape.
        """
        key = get_api_key("qwen")
        if not key:
            raise NoAPIKeyError("qwen")
        formatted = [
            {"role": "system", "content": system or "You are zai, a helpful AI assistant."}
        ]
        formatted += [{"role": m.role, "content": m.content} for m in messages]
        for attempt in range(self.retries + 1):
            try:
                response = httpx.post(
                    "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
                    headers={"Authorization": f"Bearer {key}"},
                    json={"model": self.model_id, "messages": formatted},
                    timeout=self.timeout,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as error:
                # Transport failures are retried; only the final one surfaces.
                if attempt == self.retries:
                    raise NetworkError(str(error)) from error
                continue
            try:
                response.raise_for_status()
                data = response.json()
                content = data["choices"][0]["message"]["content"]
                # "usage" may be missing or null; report 0 tokens in that case.
                tokens_used = (data.get("usage") or {}).get("total_tokens") or 0
            except Exception as error:
                # Malformed payloads are classified like HTTP/JSON failures
                # instead of leaking raw KeyError/IndexError/TypeError.
                raise classify_provider_error("qwen", error) from error
            return Response(
                content=content,
                model=self.model_id,
                tokens_used=tokens_used,
            )
        # Only reachable when self.retries is negative (the loop never runs).
        raise NetworkError("Qwen request failed")
|
zai/skills/__init__.py
ADDED
|
File without changes
|
zai/skills/registry.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from ..tools.files import read_file
|
|
2
|
+
|
|
3
|
+
def _appending_prompt(instruction: str):
    """Build a prompt callable: *instruction*, a blank line, then the context."""
    return lambda ctx: f"{instruction}\n\n{ctx}"


# Built-in skills: name -> human description + callable that renders the
# prompt for a given context string.
SKILLS = {
    "review": {
        "description": "Review code and suggest improvements",
        "prompt": _appending_prompt(
            "Do a thorough code review of the following. Point out bugs, security issues, and improvements:"
        ),
    },
    "test": {
        "description": "Generate unit tests",
        "prompt": _appending_prompt(
            "Write comprehensive unit tests for this code using pytest:"
        ),
    },
    "commit": {
        "description": "Generate git commit message",
        "prompt": _appending_prompt(
            "Write a clear, concise git commit message for these changes (conventional commits format):"
        ),
    },
    "docs": {
        "description": "Generate documentation",
        "prompt": _appending_prompt(
            "Write clear documentation (docstrings + README section) for this code:"
        ),
    },
    "fix": {
        "description": "Find and fix bugs",
        "prompt": _appending_prompt(
            "Find all bugs in this code and provide the fixed version with explanation:"
        ),
    },
    "explain": {
        "description": "Explain code clearly",
        "prompt": _appending_prompt(
            "Explain this code clearly, step by step, for someone learning:"
        ),
    },
    "refactor": {
        "description": "Refactor code for quality",
        "prompt": _appending_prompt(
            "Refactor this code for better readability, performance, and best practices:"
        ),
    },
    "summarize": {
        "description": "Summarize text or file",
        "prompt": _appending_prompt(
            "Summarize the following in clear bullet points:"
        ),
    },
    "translate": {
        "description": "Translate text",
        "prompt": _appending_prompt(
            "Translate the following text to English (or specify target language):"
        ),
    },
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_skill_prompt(
    skill_name: str,
    file_path: str | None = None,
    text: str | None = None,
) -> str | None:
    """Render the prompt for the skill named *skill_name*.

    The context fed to the skill is, in priority order: the contents of
    *file_path* (via ``read_file``), then *text*, then an empty string.

    Returns:
        The rendered prompt, or None when *skill_name* is not a known skill.

    Raises:
        FileError: propagated from ``read_file`` when *file_path* is given
            but cannot be read.
    """
    skill = SKILLS.get(skill_name)
    if not skill:
        return None
    if file_path:
        context = read_file(file_path)
    else:
        context = text or ""
    return skill["prompt"](context)
|
zai/tools/__init__.py
ADDED
|
File without changes
|
zai/tools/browser.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
import ipaddress
|
|
2
|
+
import socket
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from urllib.parse import urlsplit
|
|
5
|
+
|
|
6
|
+
from ..core.errors import FileError, ZaiError
|
|
7
|
+
from ..core.security import resolve_project_path
|
|
8
|
+
|
|
9
|
+
# Largest declared Content-Length (in bytes) accepted for a navigated page.
MAX_DOCUMENT_BYTES = 2_000_000
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BrowserError(ZaiError):
    """Raised when a browser-automation tool cannot complete its request."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_playwright():
    """Return Playwright's ``sync_playwright`` factory, imported lazily.

    The import is deferred until a browser tool actually runs, so the rest of
    the CLI works without Playwright installed.

    Raises:
        BrowserError: Playwright cannot be imported; the original
            ImportError is chained as the cause.
    """
    try:
        from playwright.sync_api import sync_playwright
    except ImportError as error:
        raise BrowserError(
            "Playwright not installed. Run: pip install playwright && "
            "playwright install chromium"
        ) from error
    return sync_playwright
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _validate_public_url(url: str) -> str:
|
|
29
|
+
"""Allow only credential-free public HTTP(S) destinations."""
|
|
30
|
+
try:
|
|
31
|
+
parsed = urlsplit(url)
|
|
32
|
+
port = parsed.port
|
|
33
|
+
except ValueError as error:
|
|
34
|
+
raise BrowserError(f"Invalid URL: {error}")
|
|
35
|
+
if parsed.scheme.lower() not in {"http", "https"}:
|
|
36
|
+
raise BrowserError("Only http:// and https:// URLs are allowed")
|
|
37
|
+
if not parsed.hostname:
|
|
38
|
+
raise BrowserError("URL must include a hostname")
|
|
39
|
+
if parsed.username is not None or parsed.password is not None:
|
|
40
|
+
raise BrowserError("URLs containing credentials are not allowed")
|
|
41
|
+
|
|
42
|
+
hostname = parsed.hostname.rstrip(".").lower()
|
|
43
|
+
if hostname == "localhost" or hostname.endswith(".localhost"):
|
|
44
|
+
raise BrowserError("Local and private network URLs are blocked")
|
|
45
|
+
try:
|
|
46
|
+
addresses = {
|
|
47
|
+
item[4][0]
|
|
48
|
+
for item in socket.getaddrinfo(
|
|
49
|
+
hostname,
|
|
50
|
+
port or (443 if parsed.scheme.lower() == "https" else 80),
|
|
51
|
+
type=socket.SOCK_STREAM,
|
|
52
|
+
)
|
|
53
|
+
}
|
|
54
|
+
except socket.gaierror as error:
|
|
55
|
+
raise BrowserError(f"Cannot resolve URL hostname: {error}")
|
|
56
|
+
if not addresses:
|
|
57
|
+
raise BrowserError("URL hostname did not resolve")
|
|
58
|
+
for address in addresses:
|
|
59
|
+
try:
|
|
60
|
+
if not ipaddress.ip_address(address).is_global:
|
|
61
|
+
raise BrowserError("Local and private network URLs are blocked")
|
|
62
|
+
except ValueError:
|
|
63
|
+
raise BrowserError("URL resolved to an invalid network address")
|
|
64
|
+
return url
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _safe_output_path(path: str, project_dir: str = ".") -> Path:
    """Resolve a screenshot destination inside *project_dir*.

    The parent directory of the resolved path is created if missing.

    Returns:
        The path returned by ``resolve_project_path``.

    Raises:
        BrowserError: the suffix is not .png/.jpg/.jpeg, or
            ``resolve_project_path`` rejects the path (its FileError is
            re-raised as BrowserError with the cause chained).
    """
    if Path(path).suffix.lower() not in {".png", ".jpg", ".jpeg"}:
        raise BrowserError("Screenshot output must use .png, .jpg, or .jpeg")
    try:
        output = resolve_project_path(project_dir, path)
    except FileError as error:
        raise BrowserError(str(error)) from error
    output.parent.mkdir(parents=True, exist_ok=True)
    return output
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _prepare_page(browser, initial_url: str):
    """Open *initial_url* in a fresh, locked-down context of *browser*.

    The context blocks service workers and downloads. Every request matching
    ``http://**/*`` or ``https://**/*`` is re-validated with
    ``_validate_public_url`` and aborted when it targets a non-public
    destination.

    Returns:
        ``(context, page)``. On success the caller owns the context and must
        close it; on failure the context is closed here before re-raising.

    Raises:
        BrowserError: the initial or final URL is not public, or the
            document's declared Content-Length exceeds MAX_DOCUMENT_BYTES.
        Any Playwright error raised while creating the page or navigating.

    NOTE(review): Playwright documents that route handlers are only called
    for the first URL of a redirect chain, so intermediate redirect hops are
    not validated individually; only the final ``page.url`` is re-checked.
    """
    context = browser.new_context(
        service_workers="block",
        accept_downloads=False,
    )
    try:
        page = context.new_page()

        def guard_request(route):
            try:
                _validate_public_url(route.request.url)
            except BrowserError:
                route.abort("blockedbyclient")
                return
            route.continue_()

        page.route("http://**/*", guard_request)
        page.route("https://**/*", guard_request)
        response = page.goto(
            _validate_public_url(initial_url),
            timeout=30000,
            wait_until="domcontentloaded",
        )
        if response is not None:
            content_length = response.headers.get("content-length")
            declared = None
            if content_length:
                try:
                    declared = int(content_length)
                except ValueError:
                    declared = None  # malformed header: nothing to enforce
            if declared is not None and declared > MAX_DOCUMENT_BYTES:
                raise BrowserError("Webpage exceeded the 2 MB document limit")
        # The page may have been redirected; the landing URL must be public too.
        _validate_public_url(page.url)
    except Exception:
        # The caller never receives the context on failure, so close it here.
        try:
            context.close()
        except Exception:
            pass  # best effort: keep the original error, not the cleanup one
        raise
    return context, page
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _run_on_page(url: str, action, describe_failure) -> str:
    """Open *url* in headless Chromium and return ``action(page)``.

    Shared lifecycle for the public browser tools below: the page is set up
    with ``_prepare_page``, its context and the browser are always closed,
    ``BrowserError`` propagates unchanged, and any other exception is
    re-raised as ``BrowserError(describe_failure(error))`` with the original
    chained as its cause.
    """
    sync_playwright = _get_playwright()
    try:
        with sync_playwright() as playwright:
            browser = playwright.chromium.launch(headless=True)
            try:
                context, page = _prepare_page(browser, url)
                try:
                    return action(page)
                finally:
                    context.close()
            finally:
                browser.close()
    except BrowserError:
        raise
    except Exception as error:
        raise BrowserError(describe_failure(error)) from error


def _visible_text(page, limit: int, empty_message: str) -> str:
    """Return at most *limit* chars of ``document.body.innerText``, or *empty_message* if empty."""
    text = page.evaluate("() => document.body.innerText")
    return text[:limit] if text else empty_message


def scrape(url: str) -> str:
    """Open a public URL and return bounded visible text content.

    Returns at most 8000 characters of the page's visible text, or
    "No content found." when it has none.

    Raises:
        BrowserError: validation or browser failure.
    """
    return _run_on_page(
        url,
        lambda page: _visible_text(page, 8000, "No content found."),
        lambda error: f"Cannot scrape {url}: {error}",
    )


def screenshot(url: str, path: str = "screenshot.png", project_dir: str = ".") -> str:
    """Take a screenshot of a public URL inside the current project.

    The full page is captured to *path*, which must resolve inside
    *project_dir* and end in .png/.jpg/.jpeg.

    Returns:
        The resolved output path as a string.

    Raises:
        BrowserError: invalid output path, validation or browser failure.
    """
    output = _safe_output_path(path, project_dir)

    def capture(page) -> str:
        page.screenshot(path=str(output), full_page=True)
        return str(output)

    return _run_on_page(url, capture, lambda error: f"Screenshot failed: {error}")


def click_and_get(url: str, selector: str) -> str:
    """Click an element on a public page and return bounded content.

    After clicking *selector* (5 s timeout) the page must reach network idle
    within 10 s and still be on a public URL.

    Returns:
        At most 8000 characters of visible text, or "No content after click.".

    Raises:
        BrowserError: validation or browser failure (including timeouts).
    """

    def click(page) -> str:
        page.click(selector, timeout=5000)
        page.wait_for_load_state("networkidle", timeout=10000)
        _validate_public_url(page.url)
        return _visible_text(page, 8000, "No content after click.")

    return _run_on_page(url, click, lambda error: f"Click failed: {error}")


def fill_form(url: str, fields: dict[str, str], submit_selector: str | None = None) -> str:
    """Fill a form on a public page and optionally submit it.

    Args:
        url: Public page containing the form.
        fields: Mapping of CSS selector -> value to type into that field.
        submit_selector: Element to click afterwards; when given, the page
            must then reach network idle within 10 s on a public URL.

    Returns:
        At most 5000 characters of visible text, or "Done." when empty.

    Raises:
        BrowserError: validation or browser failure.
    """

    def fill(page) -> str:
        for selector, value in fields.items():
            page.fill(selector, value)
        if submit_selector:
            page.click(submit_selector)
            page.wait_for_load_state("networkidle", timeout=10000)
            _validate_public_url(page.url)
        return _visible_text(page, 5000, "Done.")

    return _run_on_page(url, fill, lambda error: f"Form fill failed: {error}")


def run_js(url: str, script: str) -> str:
    """Run JavaScript on a public page and return a bounded result.

    Returns:
        ``str()`` of the evaluation result, truncated to 8000 characters.

    Raises:
        BrowserError: validation or browser/script failure.
    """
    return _run_on_page(
        url,
        lambda page: str(page.evaluate(script))[:8000],
        lambda error: f"JS execution failed: {error}",
    )
|
zai/tools/code_runner.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
import sys
|
|
3
|
+
import tempfile
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def run_python(code: str, timeout: int = 30) -> str:
    """Run a snippet of Python source in a separate interpreter and report its output.

    *code* is written to a temporary ``.py`` file encoded as UTF-8, the
    encoding Python assumes for source files, so non-ASCII literals behave
    the same regardless of the platform's locale. The temporary file is
    always deleted. Child output is decoded with undecodable bytes replaced
    rather than failing the whole run.

    Returns:
        Stripped stdout; an "Output:"/"Errors:" report when both stdout and
        stderr have content; an "Error:" report when only stderr does;
        "No output." when neither does; or an "Error: ..." message on
        timeout or when the script cannot be written or launched.
    """
    handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    )
    tmp = handle.name
    try:
        with handle:
            handle.write(code)
        result = subprocess.run(
            [sys.executable, tmp],
            capture_output=True,
            text=True,
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f"Error: Code timed out after {timeout} seconds."
    except Exception as e:
        return f"Error: {e}"
    finally:
        # Runs even if writing failed, so the temp file never leaks.
        Path(tmp).unlink(missing_ok=True)
    out = result.stdout.strip()
    err = result.stderr.strip()
    if err:
        return f"Output:\n{out}\n\nErrors:\n{err}" if out else f"Error:\n{err}"
    return out or "No output."
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def run_file(path: str, timeout: int = 30) -> str:
    """Execute the Python script at *path* with ``sys.executable`` and report its output.

    Child output is decoded with undecodable bytes replaced, so a stray
    binary byte no longer discards everything the script printed.

    Returns:
        Stripped stdout; an "Output:"/"Errors:" report when both stdout and
        stderr have content; an "Error:" report when only stderr does;
        "No output." when neither does; or an "Error: ..." message on
        timeout or launch failure.

    Raises:
        FileError: *path* does not exist or is not a regular file.
    """
    from ..core.errors import FileError

    target = Path(path)
    if not target.exists():
        raise FileError(f"File not found: {path}")
    if not target.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        result = subprocess.run(
            [sys.executable, str(target)],
            capture_output=True,
            text=True,
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f"Error: Timed out after {timeout}s."
    except Exception as e:
        return f"Error: {e}"
    out = result.stdout.strip()
    err = result.stderr.strip()
    if err:
        return f"Output:\n{out}\n\nErrors:\n{err}" if out else f"Error:\n{err}"
    return out or "No output."
|
zai/tools/files.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from ..core.errors import FileError
|
|
3
|
+
from ..core.storage import atomic_write_text
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def read_file(path: str) -> str:
    """Return the text of the file at *path*, decoded as UTF-8.

    Bytes that are not valid UTF-8 are silently dropped (``errors="ignore"``),
    so the result is lossy for binary or differently-encoded files.

    Raises:
        FileError: the path is missing, is not a regular file, or cannot be
            read; the underlying exception is chained as the cause.
    """
    p = Path(path)
    if not p.exists():
        raise FileError(f"File not found: {path}")
    if not p.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        return p.read_text(encoding="utf-8", errors="ignore")
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except Exception as error:
        raise FileError(f"Cannot read {path}: {error}") from error
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def write_file(path: str, content: str):
    """Write *content* to *path* via ``atomic_write_text``, creating parent directories.

    The mode passed along is 0o644 for a new file and the existing file's
    permission bits otherwise, so rewriting a file (e.g. an executable
    script) does not reset its permissions.

    Raises:
        FileError: the directory or file cannot be written; the underlying
            exception is chained as the cause.
    """
    import stat

    p = Path(path)
    try:
        p.parent.mkdir(parents=True, exist_ok=True)
        try:
            mode = stat.S_IMODE(p.stat().st_mode)
        except FileNotFoundError:
            mode = 0o644
        # NOTE(review): assumes atomic_write_text applies *mode* to the file it
        # creates — confirm against core.storage.
        atomic_write_text(p, content, mode=mode, lock=False)
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except Exception as error:
        raise FileError(f"Cannot write {path}: {error}") from error
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def edit_file(path: str, old: str, new: str) -> bool:
    """Replace the first occurrence of *old* with *new* in the UTF-8 file at *path*.

    The file is rewritten via ``atomic_write_text`` with its current
    permission bits, so editing an executable script keeps it executable.

    Returns:
        True when a replacement was written; False when *old* does not occur
        (the file is left untouched).

    Raises:
        FileError: the path is missing or not a regular file, is not valid
            UTF-8, or cannot be read or written; the underlying exception is
            chained as the cause.
    """
    import stat

    p = Path(path)
    if not p.exists():
        raise FileError(f"File not found: {path}")
    if not p.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        # Strict decoding on purpose: ignoring bad bytes here would silently
        # destroy them when the file is written back.
        content = p.read_text(encoding="utf-8")
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except (UnicodeDecodeError, OSError) as error:
        raise FileError(f"Cannot read {path}: {error}") from error
    if old not in content:
        return False
    try:
        mode = stat.S_IMODE(p.stat().st_mode)
        atomic_write_text(
            p,
            content.replace(old, new, 1),
            mode=mode,
            lock=False,
        )
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except Exception as error:
        raise FileError(f"Cannot write {path}: {error}") from error
    return True
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def list_files(directory: str = ".", pattern: str = "*") -> list[str]:
    """List regular files under *directory* (recursively) whose names match *pattern*.

    Returns:
        The matching paths as strings, in ``Path.rglob`` order.

    Raises:
        FileError: *directory* does not exist or is not a directory.
    """
    root = Path(directory)
    if not root.exists():
        raise FileError(f"Directory not found: {directory}")
    if not root.is_dir():
        raise FileError(f"Not a directory: {directory}")
    regular_files = filter(Path.is_file, root.rglob(pattern))
    return list(map(str, regular_files))
|
zai/tools/git.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def run_git(args: list[str], timeout: float = 30) -> str:
    """Run ``git`` with *args* in the current directory and return its text output.

    Output is decoded as UTF-8 with undecodable bytes replaced, because
    diffs can carry arbitrary file bytes that a locale codec may reject.

    Args:
        args: Arguments after ``git`` (any sequence of strings).
        timeout: Seconds before the command is abandoned (default 30).

    Returns:
        Stripped stdout, or stripped stderr when stdout is empty (so git's
        own error messages surface). A missing git binary or any other
        failure, including a timeout, is reported as an "Error: ..." string
        instead of raising.
    """
    try:
        result = subprocess.run(
            ["git", *args],
            capture_output=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
        )
    except FileNotFoundError:
        return "Error: git not found. Please install git."
    except Exception as e:
        return f"Error: {e}"
    return result.stdout.strip() or result.stderr.strip()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_diff() -> str:
    """Return the output of ``git diff HEAD`` (working tree compared with HEAD)."""
    diff_args = ["diff", "HEAD"]
    return run_git(diff_args)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_staged_diff() -> str:
    """Return the output of ``git diff --staged`` (changes in the index)."""
    staged_args = ["diff", "--staged"]
    return run_git(staged_args)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_status() -> str:
    """Return the output of ``git status --short``."""
    status_args = ["status", "--short"]
    return run_git(status_args)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def get_log(n: int = 10) -> str:
    """Return the last *n* commits via ``git log -<n> --oneline``."""
    count_flag = "-" + str(n)
    return run_git(["log", count_flag, "--oneline"])
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def commit(message: str) -> str:
    """Run ``git commit -m <message>`` and return git's output."""
    commit_args = ["commit", "-m", message]
    return run_git(commit_args)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_branch() -> str:
    """Return the output of ``git branch --show-current``."""
    branch_args = ["branch", "--show-current"]
    return run_git(branch_args)
|
zai/tools/search.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Bounded web-search provider adapters."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from html.parser import HTMLParser
|
|
6
|
+
from urllib.parse import parse_qs, unquote, urlsplit
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
|
|
10
|
+
from ..core.errors import NetworkError
|
|
11
|
+
|
|
12
|
+
# Cap, in bytes, on a search response body (checked against Content-Length and while streaming).
MAX_SEARCH_RESPONSE_BYTES = 1_000_000
|
|
13
|
+
# Default and upper bound for the number of results returned by one search.
MAX_RESULTS = 8
|
|
14
|
+
# Endpoint that DuckDuckGoHTMLSearchProvider POSTs the query form ("q") to.
DUCKDUCKGO_HTML_URL = "https://html.duckduckgo.com/html/"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class SearchResult:
    """One web-search hit; immutable (frozen dataclass)."""

    # Result title as displayed by the search backend.
    title: str
    # Destination URL of the result.
    url: str
    # Short excerpt describing the result; empty when none was provided.
    snippet: str = ""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SearchProvider:
    """Base class for general web-search backends.

    Subclasses set ``name`` and override ``search``.
    """

    name = "search"

    def search(self, query: str, *, limit: int = MAX_RESULTS) -> list[SearchResult]:
        """Return results for *query*, up to *limit*; must be overridden."""
        raise NotImplementedError()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class _DuckDuckGoHTMLParser(HTMLParser):
    """Collect results from DuckDuckGo's HTML results page.

    A result starts at an ``<a class="result__a">`` element (title text and
    link). The text of a following ``result__snippet`` element (``<a>`` or
    ``<div>``) becomes that result's snippet. If a title is dropped (no text
    or no URL), its snippet is discarded instead of overwriting the snippet
    of an earlier result.
    """

    def __init__(self) -> None:
        super().__init__(convert_charrefs=True)
        self.results: list[SearchResult] = []
        # Text chunks of the currently open title / snippet element, or None.
        self._title_parts: list[str] | None = None
        self._snippet_parts: list[str] | None = None
        # Unwrapped href of the currently open title link.
        self._url = ""
        # Index in self.results that the next snippet belongs to; None when
        # the most recent title did not produce a result.
        self._snippet_target: int | None = None

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        attributes = dict(attrs)
        classes = set((attributes.get("class") or "").split())
        if tag == "a" and "result__a" in classes:
            self._url = _unwrap_duckduckgo_url(attributes.get("href") or "")
            self._title_parts = []
            # A new result begins: an older result must not receive its snippet.
            self._snippet_target = None
        elif tag in {"a", "div"} and "result__snippet" in classes:
            self._snippet_parts = []

    def handle_data(self, data: str) -> None:
        if self._title_parts is not None:
            self._title_parts.append(data)
        if self._snippet_parts is not None:
            self._snippet_parts.append(data)

    def handle_endtag(self, tag: str) -> None:
        if tag == "a" and self._title_parts is not None:
            # Collapse all whitespace runs to single spaces.
            title = " ".join("".join(self._title_parts).split())
            if title and self._url:
                self.results.append(SearchResult(title=title, url=self._url))
                self._snippet_target = len(self.results) - 1
            self._title_parts = None
            self._url = ""
        if tag in {"a", "div"} and self._snippet_parts is not None:
            snippet = " ".join("".join(self._snippet_parts).split())
            if snippet and self._snippet_target is not None:
                target = self.results[self._snippet_target]
                # SearchResult is frozen, so rebuild it with the snippet.
                self.results[self._snippet_target] = SearchResult(
                    title=target.title,
                    url=target.url,
                    snippet=snippet,
                )
            self._snippet_parts = None
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _unwrap_duckduckgo_url(url: str) -> str:
|
|
76
|
+
parsed = urlsplit(url)
|
|
77
|
+
redirected = parse_qs(parsed.query).get("uddg")
|
|
78
|
+
return unquote(redirected[0]) if redirected else url
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _read_bounded_response(
    client: httpx.Client,
    url: str,
    *,
    data: dict[str, str],
) -> str:
    """POST *data* to *url* and return the decoded body, capped at MAX_SEARCH_RESPONSE_BYTES.

    The body is streamed so an oversized response is abandoned as soon as the
    cap is crossed. A malformed Content-Length header is ignored (the
    streaming check still bounds the body), matching how the browser tools
    treat that header.

    Raises:
        NetworkError: the declared or actual body size exceeds the cap.
        httpx.HTTPStatusError: non-2xx status (from ``raise_for_status``).
        Other httpx errors propagate unchanged.
    """
    with client.stream("POST", url, data=data) as response:
        response.raise_for_status()
        content_length = response.headers.get("content-length")
        declared = None
        if content_length:
            try:
                declared = int(content_length)
            except ValueError:
                declared = None
        if declared is not None and declared > MAX_SEARCH_RESPONSE_BYTES:
            raise NetworkError("Search response exceeded the 1 MB limit.")
        body = bytearray()
        for chunk in response.iter_bytes():
            body.extend(chunk)
            if len(body) > MAX_SEARCH_RESPONSE_BYTES:
                raise NetworkError("Search response exceeded the 1 MB limit.")
        return body.decode(response.encoding or "utf-8", errors="replace")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class DuckDuckGoHTMLSearchProvider(SearchProvider):
    """General web results from DuckDuckGo's HTML search endpoint."""

    name = "duckduckgo"

    def search(self, query: str, *, limit: int = MAX_RESULTS) -> list[SearchResult]:
        """Search DuckDuckGo for *query* and return up to *limit* results.

        Whitespace in *query* is collapsed before sending, and *limit* is
        clamped to the range 1..MAX_RESULTS.

        Raises:
            NetworkError: empty query, oversized response, or any httpx
                connection, timeout or HTTP failure.
        """
        query = " ".join(query.split())
        if not query:
            raise NetworkError("Search query cannot be empty.")
        headers = {
            "User-Agent": "zai-cli/0.1 (+https://github.com/zai-cli/zai)",
            "Accept": "text/html,application/xhtml+xml",
        }
        try:
            with httpx.Client(
                headers=headers,
                timeout=httpx.Timeout(10.0),
                follow_redirects=True,
            ) as client:
                html = _read_bounded_response(
                    client,
                    DUCKDUCKGO_HTML_URL,
                    data={"q": query},
                )
        except NetworkError:
            raise
        except httpx.ConnectError as error:
            raise NetworkError("Cannot connect. Check your internet.") from error
        except httpx.TimeoutException as error:
            raise NetworkError("Search timed out.") from error
        except httpx.HTTPError as error:
            raise NetworkError(f"Search provider error: {error}") from error

        parser = _DuckDuckGoHTMLParser()
        parser.feed(html)
        # close() forces HTMLParser to process any data it is still buffering,
        # as if the document had ended, before the results are read.
        parser.close()
        return parser.results[: max(1, min(limit, MAX_RESULTS))]
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _format_results(query: str, results: list[SearchResult]) -> str:
|
|
139
|
+
if not results:
|
|
140
|
+
return f"No results found for: {query}"
|
|
141
|
+
lines = []
|
|
142
|
+
for index, result in enumerate(results, 1):
|
|
143
|
+
lines.append(f"{index}. {result.title}\n {result.url}")
|
|
144
|
+
if result.snippet:
|
|
145
|
+
lines.append(f" {result.snippet}")
|
|
146
|
+
return "\n".join(lines)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def web_search(query: str, provider: SearchProvider | None = None) -> str:
    """Search through an injectable provider and return bounded plain text.

    *provider* defaults to a fresh DuckDuckGoHTMLSearchProvider.

    Raises:
        NetworkError: raised by the provider, or wrapping any other failure
            as "Search failed: ...".
    """
    backend = provider if provider else DuckDuckGoHTMLSearchProvider()
    try:
        hits = backend.search(query)
        report = _format_results(query, hits)
    except NetworkError:
        raise
    except Exception as error:
        raise NetworkError(f"Search failed: {error}") from error
    return report
|