zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,63 @@
1
+ import httpx
2
+ from .base import BaseProvider, Message, Response
3
+ from ..config import get_api_key
4
+ from ..core.errors import (
5
+ NoAPIKeyError,
6
+ AuthenticationError,
7
+ RateLimitError,
8
+ NetworkError,
9
+ classify_provider_error,
10
+ )
11
+
12
+
13
class OpenRouterProvider(BaseProvider):
    """Chat provider that talks to OpenRouter's chat-completions endpoint."""

    name = "openrouter"
    model_id = "meta-llama/llama-3.3-70b-instruct:free"
    context_window = 128000

    def is_available(self) -> bool:
        """Return True when an API key for "openrouter" is configured."""
        return bool(get_api_key("openrouter"))

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send the conversation to OpenRouter and return the first choice.

        Args:
            messages: Conversation turns; each contributes ``role`` and ``content``.
            system: System prompt. A default zai prompt is used when empty.
            tools: Unused by this method; it is not forwarded in the request.

        Returns:
            A ``Response`` holding the first choice's message content, the
            model id, and the reported total token count (0 when absent).

        Raises:
            NoAPIKeyError: No API key is configured.
            RateLimitError: The API answered with HTTP 429.
            AuthenticationError: The API answered with HTTP 401.
            NetworkError: Connection errors or timeouts persisted through
                every attempt (``self.retries + 1`` in total).
            Exception: Whatever ``classify_provider_error`` returns for other
                HTTP-status or JSON-decoding failures.
        """
        key = get_api_key("openrouter")
        if not key:
            raise NoAPIKeyError("openrouter")

        formatted = [{"role": "system", "content": system or "You are zai, a helpful AI assistant."}]
        formatted += [{"role": m.role, "content": m.content} for m in messages]

        for attempt in range(self.retries + 1):
            # Only the HTTP call can raise transport errors, so keep the
            # retry guard as narrow as possible.
            try:
                response = httpx.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers={"Authorization": f"Bearer {key}"},
                    json={"model": self.model_id, "messages": formatted},
                    timeout=self.timeout,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as error:
                if attempt == self.retries:
                    # Chain the cause so the transport failure stays visible.
                    raise NetworkError(str(error)) from error
                continue

            # Rate-limit and auth failures are never retried.
            if response.status_code == 429:
                raise RateLimitError("openrouter")
            if response.status_code == 401:
                raise AuthenticationError(
                    "openrouter", "invalid API credentials"
                )
            try:
                response.raise_for_status()
                data = response.json()
            except Exception as error:
                raise classify_provider_error("openrouter", error) from error
            return Response(
                content=data["choices"][0]["message"]["content"],
                model=self.model_id,
                tokens_used=data.get("usage", {}).get("total_tokens", 0),
            )
        # Reached only when the loop body never ran (e.g. negative retries).
        raise NetworkError("OpenRouter request failed")
zai/providers/qwen.py ADDED
@@ -0,0 +1,47 @@
1
+ import httpx
2
+ from .base import BaseProvider, Message, Response
3
+ from ..config import get_api_key
4
+ from ..core.errors import NetworkError, NoAPIKeyError, classify_provider_error
5
+
6
+
7
class QwenProvider(BaseProvider):
    """Chat provider that talks to DashScope's OpenAI-compatible Qwen endpoint."""

    name = "qwen"
    model_id = "qwen-turbo"
    context_window = 1000000

    def is_available(self) -> bool:
        """Return True when an API key for "qwen" is configured."""
        return bool(get_api_key("qwen"))

    def chat(
        self,
        messages: list[Message],
        system: str = "",
        tools: list[dict] | None = None,
    ) -> Response:
        """Send the conversation to Qwen and return the first choice.

        Args:
            messages: Conversation turns; each contributes ``role`` and ``content``.
            system: System prompt. A default zai prompt is used when empty.
            tools: Unused by this method; it is not forwarded in the request.

        Returns:
            A ``Response`` holding the first choice's message content, the
            model id, and the reported total token count (0 when absent).

        Raises:
            NoAPIKeyError: No API key is configured.
            NetworkError: Connection errors or timeouts persisted through
                every attempt (``self.retries + 1`` in total).
            Exception: Whatever ``classify_provider_error`` returns for
                HTTP-status or JSON-decoding failures.
        """
        key = get_api_key("qwen")
        if not key:
            raise NoAPIKeyError("qwen")
        formatted = [{"role": "system", "content": system or "You are zai, a helpful AI assistant."}]
        formatted += [{"role": m.role, "content": m.content} for m in messages]
        for attempt in range(self.retries + 1):
            # Only the HTTP call can raise transport errors, so keep the
            # retry guard as narrow as possible.
            try:
                response = httpx.post(
                    "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
                    headers={"Authorization": f"Bearer {key}"},
                    json={"model": self.model_id, "messages": formatted},
                    timeout=self.timeout,
                )
            except (httpx.ConnectError, httpx.TimeoutException) as error:
                if attempt == self.retries:
                    # Chain the cause so the transport failure stays visible.
                    raise NetworkError(str(error)) from error
                continue

            try:
                response.raise_for_status()
                data = response.json()
            except Exception as error:
                raise classify_provider_error("qwen", error) from error
            return Response(
                content=data["choices"][0]["message"]["content"],
                model=self.model_id,
                tokens_used=data.get("usage", {}).get("total_tokens", 0),
            )
        # Reached only when the loop body never ran (e.g. negative retries).
        raise NetworkError("Qwen request failed")
zai/skills/__init__.py ADDED
File without changes
zai/skills/registry.py ADDED
@@ -0,0 +1,52 @@
1
+ from ..tools.files import read_file
2
+
3
def _skill(description: str, instruction: str) -> dict:
    """Build one registry entry: a description plus a prompt renderer.

    The renderer appends the supplied context to the instruction, separated
    by a blank line.
    """
    return {
        "description": description,
        "prompt": lambda ctx: f"{instruction}\n\n{ctx}",
    }


# Registry of built-in skills, keyed by skill name. Each entry exposes
# "description" (short label) and "prompt" (callable: context -> prompt text).
SKILLS = {
    "review": _skill(
        "Review code and suggest improvements",
        "Do a thorough code review of the following. Point out bugs, security issues, and improvements:",
    ),
    "test": _skill(
        "Generate unit tests",
        "Write comprehensive unit tests for this code using pytest:",
    ),
    "commit": _skill(
        "Generate git commit message",
        "Write a clear, concise git commit message for these changes (conventional commits format):",
    ),
    "docs": _skill(
        "Generate documentation",
        "Write clear documentation (docstrings + README section) for this code:",
    ),
    "fix": _skill(
        "Find and fix bugs",
        "Find all bugs in this code and provide the fixed version with explanation:",
    ),
    "explain": _skill(
        "Explain code clearly",
        "Explain this code clearly, step by step, for someone learning:",
    ),
    "refactor": _skill(
        "Refactor code for quality",
        "Refactor this code for better readability, performance, and best practices:",
    ),
    "summarize": _skill(
        "Summarize text or file",
        "Summarize the following in clear bullet points:",
    ),
    "translate": _skill(
        "Translate text",
        "Translate the following text to English (or specify target language):",
    ),
}
41
+
42
+
43
def get_skill_prompt(skill_name: str, file_path: str = None, text: str = None) -> str | None:
    """Render the prompt for a named skill.

    The context is taken from ``file_path`` (read via ``read_file``) when it
    is given, otherwise from ``text``, otherwise it is the empty string.

    Returns:
        The rendered prompt, or None when ``skill_name`` is not registered.
    """
    entry = SKILLS.get(skill_name)
    if not entry:
        return None
    render = entry["prompt"]
    if file_path:
        # A file takes precedence over inline text.
        return render(read_file(file_path))
    return render(text if text else "")
zai/tools/__init__.py ADDED
File without changes
zai/tools/browser.py ADDED
@@ -0,0 +1,224 @@
1
+ import ipaddress
2
+ import socket
3
+ from pathlib import Path
4
+ from urllib.parse import urlsplit
5
+
6
+ from ..core.errors import FileError, ZaiError
7
+ from ..core.security import resolve_project_path
8
+
9
# Largest declared Content-Length (in bytes) accepted for a loaded page: 2 MB.
MAX_DOCUMENT_BYTES = 2 * 1000 * 1000
10
+
11
+
12
class BrowserError(ZaiError):
    """Error raised by the browser tools (URL validation, Playwright, page actions)."""
14
+
15
+
16
def _get_playwright():
    """Return Playwright's ``sync_playwright`` factory, importing it lazily.

    Raises:
        BrowserError: Playwright is not importable; the message carries the
            install commands.
    """
    try:
        from playwright.sync_api import sync_playwright
    except ImportError:
        install_hint = (
            "Playwright not installed. Run: pip install playwright && "
            "playwright install chromium"
        )
        raise BrowserError(install_hint)
    return sync_playwright
26
+
27
+
28
def _validate_public_url(url: str) -> str:
    """Allow only credential-free public HTTP(S) destinations.

    The URL must parse, use http or https, name a host, carry no user info,
    not be a localhost name, and resolve exclusively to globally routable
    addresses.

    Returns:
        The original ``url`` unchanged when every check passes.

    Raises:
        BrowserError: Any of the checks above fails.
    """
    try:
        parts = urlsplit(url)
        explicit_port = parts.port  # accessing .port validates it
    except ValueError as error:
        raise BrowserError(f"Invalid URL: {error}")

    scheme = parts.scheme.lower()
    if scheme not in {"http", "https"}:
        raise BrowserError("Only http:// and https:// URLs are allowed")
    if not parts.hostname:
        raise BrowserError("URL must include a hostname")
    if not (parts.username is None and parts.password is None):
        raise BrowserError("URLs containing credentials are not allowed")

    # Normalise a trailing root dot and case before the localhost check.
    host = parts.hostname.rstrip(".").lower()
    if host == "localhost" or host.endswith(".localhost"):
        raise BrowserError("Local and private network URLs are blocked")

    fallback_port = 443 if scheme == "https" else 80
    try:
        resolved = {
            info[4][0]
            for info in socket.getaddrinfo(
                host,
                explicit_port or fallback_port,
                type=socket.SOCK_STREAM,
            )
        }
    except socket.gaierror as error:
        raise BrowserError(f"Cannot resolve URL hostname: {error}")
    if not resolved:
        raise BrowserError("URL hostname did not resolve")

    # Every resolved address must be public; one private hit blocks the URL.
    for candidate in resolved:
        try:
            parsed_address = ipaddress.ip_address(candidate)
        except ValueError:
            raise BrowserError("URL resolved to an invalid network address")
        if not parsed_address.is_global:
            raise BrowserError("Local and private network URLs are blocked")
    return url
65
+
66
+
67
def _safe_output_path(path: str, project_dir: str = ".") -> Path:
    """Resolve a screenshot destination via ``resolve_project_path``.

    Only ``.png``, ``.jpg`` and ``.jpeg`` suffixes (case-insensitive) are
    accepted. The parent directory is created when missing.

    Raises:
        BrowserError: The suffix is not allowed, or ``resolve_project_path``
            rejected the path with a ``FileError``.
    """
    allowed_suffixes = {".png", ".jpg", ".jpeg"}
    suffix = Path(path).suffix.lower()
    if suffix not in allowed_suffixes:
        raise BrowserError("Screenshot output must use .png, .jpg, or .jpeg")
    try:
        destination = resolve_project_path(project_dir, path)
    except FileError as error:
        raise BrowserError(str(error))
    destination.parent.mkdir(parents=True, exist_ok=True)
    return destination
76
+
77
+
78
def _prepare_page(browser, initial_url: str):
    """Open ``initial_url`` in a fresh, request-guarded browser context.

    Every http/https request issued by the page is validated with
    ``_validate_public_url`` and aborted when validation fails. After
    navigation, the declared Content-Length (if any) is checked against
    ``MAX_DOCUMENT_BYTES`` and the final page URL is validated again.

    Args:
        browser: A Playwright browser instance (sync API).
        initial_url: The URL to open; it must pass public-URL validation.

    Returns:
        ``(context, page)``. The caller owns the context and must close it.

    Raises:
        BrowserError: The initial or final URL is not public, or the
            declared document size exceeds the limit.
    """
    context = browser.new_context(
        service_workers="block",
        accept_downloads=False,
    )
    try:
        page = context.new_page()

        def guard_request(route):
            # Block any request whose target is not a public destination.
            try:
                _validate_public_url(route.request.url)
            except BrowserError:
                route.abort("blockedbyclient")
                return
            route.continue_()

        page.route("http://**/*", guard_request)
        page.route("https://**/*", guard_request)
        response = page.goto(
            _validate_public_url(initial_url),
            timeout=30000,
            wait_until="domcontentloaded",
        )
        if response is not None:
            content_length = response.headers.get("content-length")
            if content_length:
                try:
                    too_large = int(content_length) > MAX_DOCUMENT_BYTES
                except ValueError:
                    # A malformed header is ignored rather than treated as fatal.
                    too_large = False
                if too_large:
                    raise BrowserError("Webpage exceeded the 2 MB document limit")
        _validate_public_url(page.url)
    except BaseException:
        # The caller never receives the context on failure, so release it
        # here instead of leaking it until the browser shuts down.
        context.close()
        raise
    return context, page
110
+
111
+
112
def scrape(url: str) -> str:
    """Open a public URL and return bounded visible text content.

    Returns:
        Up to the first 8000 characters of ``document.body.innerText``, or
        "No content found." when the page yields no text.

    Raises:
        BrowserError: Validation failed, Playwright is missing, or any other
            error occurred while scraping.
    """
    launcher = _get_playwright()
    try:
        with launcher() as pw:
            chromium = pw.chromium.launch(headless=True)
            try:
                ctx, tab = _prepare_page(chromium, url)
                try:
                    body_text = tab.evaluate("() => document.body.innerText")
                    if not body_text:
                        return "No content found."
                    return body_text[:8000]
                finally:
                    ctx.close()
            finally:
                chromium.close()
    except BrowserError:
        # Already in the tool's own error type; pass it through untouched.
        raise
    except Exception as error:
        raise BrowserError(f"Cannot scrape {url}: {error}")
131
+
132
+
133
def screenshot(url: str, path: str = "screenshot.png", project_dir: str = ".") -> str:
    """Take a full-page screenshot of a public URL inside the current project.

    Returns:
        The resolved output path as a string.

    Raises:
        BrowserError: The output path or URL was rejected, Playwright is
            missing, or the capture failed.
    """
    destination = _safe_output_path(path, project_dir)
    launcher = _get_playwright()
    try:
        with launcher() as pw:
            chromium = pw.chromium.launch(headless=True)
            try:
                ctx, tab = _prepare_page(chromium, url)
                try:
                    saved_to = str(destination)
                    tab.screenshot(path=saved_to, full_page=True)
                    return saved_to
                finally:
                    ctx.close()
            finally:
                chromium.close()
    except BrowserError:
        # Already in the tool's own error type; pass it through untouched.
        raise
    except Exception as error:
        raise BrowserError(f"Screenshot failed: {error}")
153
+
154
+
155
def click_and_get(url: str, selector: str) -> str:
    """Click an element on a public page and return bounded content.

    After the click the function waits for network idle, re-validates the
    page URL, and reads the body text.

    Returns:
        Up to the first 8000 characters of the page text, or
        "No content after click." when there is none.

    Raises:
        BrowserError: Validation failed, Playwright is missing, or the click
            or any later step failed.
    """
    launcher = _get_playwright()
    try:
        with launcher() as pw:
            chromium = pw.chromium.launch(headless=True)
            try:
                ctx, tab = _prepare_page(chromium, url)
                try:
                    tab.click(selector, timeout=5000)
                    tab.wait_for_load_state("networkidle", timeout=10000)
                    # The click may have navigated; check where we ended up.
                    _validate_public_url(tab.url)
                    body_text = tab.evaluate("() => document.body.innerText")
                    if not body_text:
                        return "No content after click."
                    return body_text[:8000]
                finally:
                    ctx.close()
            finally:
                chromium.close()
    except BrowserError:
        # Already in the tool's own error type; pass it through untouched.
        raise
    except Exception as error:
        raise BrowserError(f"Click failed: {error}")
177
+
178
+
179
def fill_form(url: str, fields: dict[str, str], submit_selector: str = None) -> str:
    """Fill a form on a public page and optionally submit it.

    Args:
        url: Page to open; it must pass public-URL validation.
        fields: Mapping of element selector to the value to fill in.
        submit_selector: When given, this element is clicked after filling,
            followed by a wait for network idle and a URL re-validation.

    Returns:
        Up to the first 5000 characters of the page text, or "Done." when
        there is none.

    Raises:
        BrowserError: Validation failed, Playwright is missing, or any form
            step failed.
    """
    launcher = _get_playwright()
    try:
        with launcher() as pw:
            chromium = pw.chromium.launch(headless=True)
            try:
                ctx, tab = _prepare_page(chromium, url)
                try:
                    for field_selector, field_value in fields.items():
                        tab.fill(field_selector, field_value)
                    if submit_selector:
                        tab.click(submit_selector)
                        tab.wait_for_load_state("networkidle", timeout=10000)
                        # Submitting may have navigated; check the destination.
                        _validate_public_url(tab.url)
                    body_text = tab.evaluate("() => document.body.innerText")
                    if not body_text:
                        return "Done."
                    return body_text[:5000]
                finally:
                    ctx.close()
            finally:
                chromium.close()
    except BrowserError:
        # Already in the tool's own error type; pass it through untouched.
        raise
    except Exception as error:
        raise BrowserError(f"Form fill failed: {error}")
204
+
205
+
206
def run_js(url: str, script: str) -> str:
    """Run JavaScript on a public page and return a bounded result.

    Returns:
        ``str()`` of the evaluation result, truncated to 8000 characters.

    Raises:
        BrowserError: Validation failed, Playwright is missing, or the
            script evaluation failed.
    """
    launcher = _get_playwright()
    try:
        with launcher() as pw:
            chromium = pw.chromium.launch(headless=True)
            try:
                ctx, tab = _prepare_page(chromium, url)
                try:
                    outcome = tab.evaluate(script)
                    rendered = str(outcome)
                    return rendered[:8000]
                finally:
                    ctx.close()
            finally:
                chromium.close()
    except BrowserError:
        # Already in the tool's own error type; pass it through untouched.
        raise
    except Exception as error:
        raise BrowserError(f"JS execution failed: {error}")
@@ -0,0 +1,49 @@
1
+ import subprocess
2
+ import sys
3
+ import tempfile
4
+ from pathlib import Path
5
+
6
+
7
def run_python(code: str, timeout: int = 30) -> str:
    """Run a Python snippet in a subprocess and return its output as text.

    The snippet is written to a temporary ``.py`` file, executed with the
    current interpreter, and the file is removed afterwards.

    Args:
        code: Python source to execute.
        timeout: Seconds to wait before giving up.

    Returns:
        The stripped stdout; stdout plus stderr (or stderr alone) when the
        process wrote to stderr; "No output." when both are empty; or an
        "Error: ..." message on timeout or launch failure.
    """
    # Python reads source files as UTF-8 by default, so the temp file must be
    # written as UTF-8 regardless of the locale's preferred encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", delete=False, encoding="utf-8"
    ) as f:
        f.write(code)
        tmp = f.name
    try:
        result = subprocess.run(
            [sys.executable, tmp],
            capture_output=True, text=True, timeout=timeout
        )
        out = result.stdout.strip()
        err = result.stderr.strip()
        if err:
            return f"Output:\n{out}\n\nErrors:\n{err}" if out else f"Error:\n{err}"
        return out or "No output."
    except subprocess.TimeoutExpired:
        return f"Error: Code timed out after {timeout} seconds."
    except Exception as e:
        return f"Error: {e}"
    finally:
        # delete=False above means cleanup is our responsibility.
        Path(tmp).unlink(missing_ok=True)
27
+
28
+
29
def run_file(path: str, timeout: int = 30) -> str:
    """Run an existing Python file in a subprocess and return its output.

    Returns:
        The stripped stdout; stdout plus stderr (or stderr alone) when the
        process wrote to stderr; "No output." when both are empty; or an
        "Error: ..." message on timeout or launch failure.

    Raises:
        FileError: ``path`` does not exist or is not a regular file.
    """
    from ..core.errors import FileError

    script = Path(path)
    if not script.exists():
        raise FileError(f"File not found: {path}")
    if not script.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        completed = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        stdout_text = completed.stdout.strip()
        stderr_text = completed.stderr.strip()
        if not stderr_text:
            return stdout_text or "No output."
        if stdout_text:
            return f"Output:\n{stdout_text}\n\nErrors:\n{stderr_text}"
        return f"Error:\n{stderr_text}"
    except subprocess.TimeoutExpired:
        return f"Error: Timed out after {timeout}s."
    except Exception as e:
        return f"Error: {e}"
zai/tools/files.py ADDED
@@ -0,0 +1,53 @@
1
+ from pathlib import Path
2
+ from ..core.errors import FileError
3
+ from ..core.storage import atomic_write_text
4
+
5
+
6
def read_file(path: str) -> str:
    """Return the text of ``path`` decoded as UTF-8.

    Undecodable bytes are dropped (``errors="ignore"``).

    Raises:
        FileError: The path is missing, is not a regular file, is not
            readable, or any other read error occurred.
    """
    target = Path(path)
    if not target.exists():
        raise FileError(f"File not found: {path}")
    if not target.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        text = target.read_text(encoding="utf-8", errors="ignore")
    except PermissionError:
        raise FileError(f"Permission denied: {path}")
    except Exception as e:
        raise FileError(f"Cannot read {path}: {e}")
    return text
18
+
19
+
20
def write_file(path: str, content: str):
    """Write ``content`` to ``path`` via ``atomic_write_text``.

    Missing parent directories are created first. The file is written with
    mode 0o644 and without locking.

    Raises:
        FileError: Permission was denied or any other error occurred.
    """
    destination = Path(path)
    try:
        destination.parent.mkdir(parents=True, exist_ok=True)
        atomic_write_text(destination, content, mode=0o644, lock=False)
    except PermissionError:
        raise FileError(f"Permission denied: {path}")
    except Exception as e:
        raise FileError(f"Cannot write {path}: {e}")
29
+
30
+
31
def edit_file(path: str, old: str, new: str) -> bool:
    """Replace the first occurrence of ``old`` with ``new`` in a UTF-8 file.

    Args:
        path: File to edit in place.
        old: Text to look for.
        new: Replacement text.

    Returns:
        True when a replacement was written, False when ``old`` does not
        occur in the file (the file is left untouched).

    Raises:
        FileError: The path is missing or not a regular file, or reading or
            writing failed. This matches how ``read_file`` and
            ``write_file`` report errors.
    """
    p = Path(path)
    if not p.exists():
        raise FileError(f"File not found: {path}")
    if not p.is_file():
        raise FileError(f"Not a file: {path}")
    try:
        # Strict decoding on purpose: editing a file whose bytes were
        # silently dropped would corrupt it on write-back.
        content = p.read_text(encoding="utf-8")
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except (OSError, UnicodeDecodeError) as error:
        raise FileError(f"Cannot read {path}: {error}") from error
    if old not in content:
        return False
    try:
        atomic_write_text(
            p,
            content.replace(old, new, 1),
            mode=0o644,
            lock=False,
        )
    except PermissionError as error:
        raise FileError(f"Permission denied: {path}") from error
    except Exception as error:
        raise FileError(f"Cannot write {path}: {error}") from error
    return True
45
+
46
+
47
def list_files(directory: str = ".", pattern: str = "*") -> list[str]:
    """Recursively list the regular files under ``directory`` matching ``pattern``.

    Returns:
        The matching paths as strings, in the order ``Path.rglob`` yields them.

    Raises:
        FileError: ``directory`` does not exist or is not a directory.
    """
    root = Path(directory)
    if not root.exists():
        raise FileError(f"Directory not found: {directory}")
    if not root.is_dir():
        raise FileError(f"Not a directory: {directory}")
    matches = []
    for candidate in root.rglob(pattern):
        if candidate.is_file():
            matches.append(str(candidate))
    return matches
zai/tools/git.py ADDED
@@ -0,0 +1,38 @@
1
+ import subprocess
2
+
3
+
4
def run_git(args: list[str]) -> str:
    """Run ``git`` with ``args`` and return its output as text.

    Returns:
        The stripped stdout, or the stripped stderr when stdout is empty.
        Failures are reported as an "Error: ..." string, not raised.
    """
    try:
        command = ["git"] + args
        completed = subprocess.run(
            command,
            capture_output=True,
            text=True,
            timeout=30,
        )
        stdout_text = completed.stdout.strip()
        if stdout_text:
            return stdout_text
        return completed.stderr.strip()
    except FileNotFoundError:
        return "Error: git not found. Please install git."
    except Exception as e:
        return f"Error: {e}"
15
+
16
+
17
def get_diff() -> str:
    """Return the output of ``git diff HEAD``."""
    diff_args = ["diff", "HEAD"]
    return run_git(diff_args)
19
+
20
+
21
def get_staged_diff() -> str:
    """Return the output of ``git diff --staged``."""
    staged_args = ["diff", "--staged"]
    return run_git(staged_args)
23
+
24
+
25
def get_status() -> str:
    """Return the output of ``git status --short``."""
    status_args = ["status", "--short"]
    return run_git(status_args)
27
+
28
+
29
def get_log(n: int = 10) -> str:
    """Return the output of ``git log -<n> --oneline``."""
    count_flag = f"-{n}"
    return run_git(["log", count_flag, "--oneline"])
31
+
32
+
33
def commit(message: str) -> str:
    """Run ``git commit -m <message>`` and return git's output."""
    commit_args = ["commit", "-m", message]
    return run_git(commit_args)
35
+
36
+
37
def get_branch() -> str:
    """Return the output of ``git branch --show-current``."""
    branch_args = ["branch", "--show-current"]
    return run_git(branch_args)
zai/tools/search.py ADDED
@@ -0,0 +1,157 @@
1
+ """Bounded web-search provider adapters."""
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from html.parser import HTMLParser
6
+ from urllib.parse import parse_qs, unquote, urlsplit
7
+
8
+ import httpx
9
+
10
+ from ..core.errors import NetworkError
11
+
12
# Upper bound, in bytes, on a search response body (1 MB).
MAX_SEARCH_RESPONSE_BYTES = 1000 * 1000
# Largest number of results a search returns.
MAX_RESULTS = 8
# DuckDuckGo's HTML search endpoint.
DUCKDUCKGO_HTML_URL = "https://html.duckduckgo.com/html/"
15
+
16
+
17
@dataclass(frozen=True)
class SearchResult:
    """One immutable web-search hit."""

    # Result title text.
    title: str
    # Target URL of the result.
    url: str
    # Optional summary text; empty when none was found.
    snippet: str = ""
22
+
23
+
24
class SearchProvider:
    """Interface implemented by general web-search backends."""

    # Identifier of the backend; subclasses override it.
    name = "search"

    def search(self, query: str, *, limit: int = MAX_RESULTS) -> list[SearchResult]:
        """Return results for ``query``; concrete backends must override this."""
        raise NotImplementedError
31
+
32
+
33
class _DuckDuckGoHTMLParser(HTMLParser):
    """Collect ``SearchResult`` items from DuckDuckGo's HTML result page.

    Title links carry the ``result__a`` class; snippets are ``a`` or ``div``
    elements carrying ``result__snippet``. A snippet is attached to the most
    recently collected result.
    """

    def __init__(self) -> None:
        super().__init__(convert_charrefs=True)
        self.results: list[SearchResult] = []
        # None means "not currently inside that element".
        self._title_parts: list[str] | None = None
        self._snippet_parts: list[str] | None = None
        self._url = ""

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        """Start capturing text when a title link or snippet element opens."""
        attr_map = dict(attrs)
        css_classes = set((attr_map.get("class") or "").split())
        if tag == "a" and "result__a" in css_classes:
            self._url = _unwrap_duckduckgo_url(attr_map.get("href") or "")
            self._title_parts = []
            return
        if tag in {"a", "div"} and "result__snippet" in css_classes:
            self._snippet_parts = []

    def handle_data(self, data: str) -> None:
        """Append text to whichever capture buffers are currently open."""
        for buffer in (self._title_parts, self._snippet_parts):
            if buffer is not None:
                buffer.append(data)

    def handle_endtag(self, tag: str) -> None:
        """Finish a title or snippet capture and record what was collected."""
        if tag == "a" and self._title_parts is not None:
            # Collapse runs of whitespace into single spaces.
            title_text = " ".join("".join(self._title_parts).split())
            if title_text and self._url:
                self.results.append(SearchResult(title=title_text, url=self._url))
            self._title_parts = None
            self._url = ""
        if tag in {"a", "div"} and self._snippet_parts is not None:
            snippet_text = " ".join("".join(self._snippet_parts).split())
            if snippet_text and self.results:
                # SearchResult is frozen, so swap in a copy carrying the snippet.
                latest = self.results[-1]
                self.results[-1] = SearchResult(
                    title=latest.title,
                    url=latest.url,
                    snippet=snippet_text,
                )
            self._snippet_parts = None
73
+
74
+
75
def _unwrap_duckduckgo_url(url: str) -> str:
    """Return the real target of a DuckDuckGo redirect link.

    Redirect links carry the destination in the ``uddg`` query parameter.
    When that parameter is absent, ``url`` is returned unchanged.
    """
    parsed = urlsplit(url)
    redirected = parse_qs(parsed.query).get("uddg")
    # parse_qs already percent-decodes values; decoding a second time would
    # corrupt targets that legitimately contain escapes (e.g. "%20").
    return redirected[0] if redirected else url
79
+
80
+
81
def _read_bounded_response(
    client: httpx.Client,
    url: str,
    *,
    data: dict[str, str],
) -> str:
    """POST ``data`` to ``url`` and return the body, capped in size.

    The body is streamed and the read is abandoned as soon as it exceeds
    ``MAX_SEARCH_RESPONSE_BYTES``. A declared Content-Length above the limit
    is rejected before any body is read.

    Returns:
        The body decoded with the response's encoding (UTF-8 when unknown),
        with undecodable bytes replaced.

    Raises:
        NetworkError: The declared or actual body size exceeds the limit.
        httpx.HTTPStatusError: The response status indicates an error.
    """
    with client.stream("POST", url, data=data) as response:
        response.raise_for_status()
        content_length = response.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                # Malformed header: ignore it and rely on the streamed count.
                declared_size = None
            if declared_size is not None and declared_size > MAX_SEARCH_RESPONSE_BYTES:
                raise NetworkError("Search response exceeded the 1 MB limit.")
        body = bytearray()
        for chunk in response.iter_bytes():
            body.extend(chunk)
            # Enforce the cap on real bytes too; the header may be absent or wrong.
            if len(body) > MAX_SEARCH_RESPONSE_BYTES:
                raise NetworkError("Search response exceeded the 1 MB limit.")
        return body.decode(response.encoding or "utf-8", errors="replace")
98
+
99
+
100
class DuckDuckGoHTMLSearchProvider(SearchProvider):
    """General web results from DuckDuckGo's HTML search endpoint."""

    name = "duckduckgo"

    def search(self, query: str, *, limit: int = MAX_RESULTS) -> list[SearchResult]:
        """Search DuckDuckGo and return at most ``limit`` parsed results.

        The query's whitespace is collapsed first. ``limit`` is clamped to
        the range 1..``MAX_RESULTS``.

        Raises:
            NetworkError: The query is empty, or the request failed
                (connection, timeout, HTTP error, oversized response).
        """
        normalized = " ".join(query.split())
        if not normalized:
            raise NetworkError("Search query cannot be empty.")
        request_headers = {
            "User-Agent": "zai-cli/0.1 (+https://github.com/zai-cli/zai)",
            "Accept": "text/html,application/xhtml+xml",
        }
        try:
            with httpx.Client(
                headers=request_headers,
                timeout=httpx.Timeout(10.0),
                follow_redirects=True,
            ) as session:
                page_html = _read_bounded_response(
                    session,
                    DUCKDUCKGO_HTML_URL,
                    data={"q": normalized},
                )
        except NetworkError:
            raise
        # Specific transport errors come first; the generic httpx.HTTPError
        # handler below is the fallback.
        except httpx.ConnectError as error:
            raise NetworkError("Cannot connect. Check your internet.") from error
        except httpx.TimeoutException as error:
            raise NetworkError("Search timed out.") from error
        except httpx.HTTPError as error:
            raise NetworkError(f"Search provider error: {error}") from error

        extractor = _DuckDuckGoHTMLParser()
        extractor.feed(page_html)
        keep = max(1, min(limit, MAX_RESULTS))
        return extractor.results[:keep]
136
+
137
+
138
def _format_results(query: str, results: list[SearchResult]) -> str:
    """Render search results as numbered plain text.

    Each result yields a "<n>. <title>" line followed by its URL line and,
    when present, a snippet line. With no results, a "No results found"
    message naming ``query`` is returned instead.
    """
    if not results:
        return f"No results found for: {query}"
    rendered = []
    position = 0
    for hit in results:
        position += 1
        rendered.append(f"{position}. {hit.title}\n {hit.url}")
        if hit.snippet:
            rendered.append(f" {hit.snippet}")
    return "\n".join(rendered)
147
+
148
+
149
def web_search(query: str, provider: SearchProvider | None = None) -> str:
    """Search through an injectable provider and return bounded plain text.

    Args:
        query: The search terms.
        provider: Backend to use; a ``DuckDuckGoHTMLSearchProvider`` is
            created when none (or a falsy value) is supplied.

    Raises:
        NetworkError: The backend raised one, or any other exception
            occurred (wrapped as "Search failed: ...").
    """
    backend = provider or DuckDuckGoHTMLSearchProvider()
    try:
        hits = backend.search(query)
        return _format_results(query, hits)
    except NetworkError:
        raise
    except Exception as error:
        raise NetworkError(f"Search failed: {error}") from error