zai-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zai/__init__.py +1 -0
  2. zai/__main__.py +4 -0
  3. zai/cli/__init__.py +1 -0
  4. zai/cli/common.py +16 -0
  5. zai/cli/integrations.py +319 -0
  6. zai/cli/interactive.py +518 -0
  7. zai/cli/settings.py +436 -0
  8. zai/cli/utilities.py +227 -0
  9. zai/cli/workflows.py +137 -0
  10. zai/commands/commit.md +24 -0
  11. zai/commands/explain.md +17 -0
  12. zai/commands/feature.md +34 -0
  13. zai/commands/fix.md +14 -0
  14. zai/commands/review.md +22 -0
  15. zai/config.py +307 -0
  16. zai/core/__init__.py +0 -0
  17. zai/core/agent.py +701 -0
  18. zai/core/cancellation.py +67 -0
  19. zai/core/commands.py +85 -0
  20. zai/core/context.py +299 -0
  21. zai/core/errors.py +125 -0
  22. zai/core/fallback.py +171 -0
  23. zai/core/hooks.py +115 -0
  24. zai/core/memory.py +57 -0
  25. zai/core/process.py +204 -0
  26. zai/core/repomap.py +381 -0
  27. zai/core/runtime.py +29 -0
  28. zai/core/security.py +33 -0
  29. zai/core/session.py +425 -0
  30. zai/core/storage.py +193 -0
  31. zai/core/streaming.py +157 -0
  32. zai/core/tool_schema.py +133 -0
  33. zai/core/undo.py +443 -0
  34. zai/core/watch.py +80 -0
  35. zai/main.py +210 -0
  36. zai/mcp/__init__.py +0 -0
  37. zai/mcp/client.py +431 -0
  38. zai/mcp/manager.py +118 -0
  39. zai/plugins/__init__.py +2 -0
  40. zai/plugins/base.py +49 -0
  41. zai/plugins/loader.py +404 -0
  42. zai/providers/__init__.py +22 -0
  43. zai/providers/anthropic.py +131 -0
  44. zai/providers/base.py +67 -0
  45. zai/providers/cerebras.py +57 -0
  46. zai/providers/gemini.py +119 -0
  47. zai/providers/groq.py +116 -0
  48. zai/providers/ollama.py +62 -0
  49. zai/providers/openai.py +124 -0
  50. zai/providers/openrouter.py +63 -0
  51. zai/providers/qwen.py +47 -0
  52. zai/skills/__init__.py +0 -0
  53. zai/skills/registry.py +52 -0
  54. zai/tools/__init__.py +0 -0
  55. zai/tools/browser.py +224 -0
  56. zai/tools/code_runner.py +49 -0
  57. zai/tools/files.py +53 -0
  58. zai/tools/git.py +38 -0
  59. zai/tools/search.py +157 -0
  60. zai/tools/vision.py +128 -0
  61. zai/ui/__init__.py +0 -0
  62. zai/ui/input.py +199 -0
  63. zai_cli-0.1.0.dist-info/METADATA +722 -0
  64. zai_cli-0.1.0.dist-info/RECORD +68 -0
  65. zai_cli-0.1.0.dist-info/WHEEL +5 -0
  66. zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
  67. zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
  68. zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/streaming.py ADDED
@@ -0,0 +1,157 @@
1
+ from rich.console import Console
2
+ from rich.live import Live
3
+ from rich.text import Text
4
+
5
# Module-wide Rich console shared by every stream_* helper in this module;
# each helper renders its in-progress reply into it through a `Live` display.
console = Console()
6
+
7
+
8
def stream_groq(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2,
):
    """Stream a Groq chat completion with a live preview and return the full text.

    Args:
        messages: Conversation turns; each item must expose ``role`` and
            ``content`` attributes.
        system: System prompt, sent as the first (``"system"``) message.
        model_id: Groq model identifier.
        api_key: Groq API key.
        timeout: Request timeout passed to the Groq client.
        retries: ``max_retries`` passed to the Groq client.

    Returns:
        The concatenation of every streamed content delta.
    """
    from groq import Groq

    client = Groq(api_key=api_key, timeout=timeout, max_retries=retries)
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    stream = client.chat.completions.create(
        model=model_id,
        messages=formatted,
        stream=True,
    )
    full = ""
    with Live(Text(""), refresh_per_second=15, console=console) as live:
        for chunk in stream:
            # OpenAI-compatible streams may emit chunks with an empty
            # `choices` list (e.g. usage/metadata trailers); indexing [0]
            # on those would raise IndexError and abort the whole reply.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            full += delta
            live.update(Text(full))
    return full
28
+
29
+
30
def stream_gemini(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60,
):
    """Stream a Gemini response with a live preview and return the full text.

    Any message whose role is not ``"user"`` is sent with Gemini's
    ``"model"`` role. The system prompt is passed as ``system_instruction``.
    ``timeout`` is given in seconds and converted to the milliseconds that
    ``HttpOptions`` expects.
    """
    from google import genai
    from google.genai import types

    timeout_ms = int(timeout * 1000)
    client = genai.Client(
        api_key=api_key,
        http_options=types.HttpOptions(timeout=timeout_ms),
    )

    contents = []
    for message in messages:
        gemini_role = message.role if message.role == "user" else "model"
        text_part = types.Part.from_text(text=message.content)
        contents.append(types.Content(role=gemini_role, parts=[text_part]))

    accumulated = ""
    with Live(Text(""), refresh_per_second=15, console=console) as live:
        response_stream = client.models.generate_content_stream(
            model=model_id,
            contents=contents,
            config=types.GenerateContentConfig(system_instruction=system),
        )
        for piece in response_stream:
            # A chunk's .text may be None (no text parts); treat it as empty.
            accumulated += piece.text or ""
            live.update(Text(accumulated))
    return accumulated
58
+
59
+
60
def stream_anthropic(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2, *, max_tokens: int = 8096,
):
    """Stream an Anthropic Messages reply with a live preview and return the text.

    Args:
        messages: Conversation turns; each item must expose ``role`` and
            ``content``. They are forwarded unchanged.
            NOTE(review): the Messages API only accepts "user"/"assistant"
            roles here; confirm callers never pass a "system" turn.
        system: System prompt, passed via the dedicated ``system`` parameter.
        model_id: Anthropic model identifier.
        api_key: Anthropic API key.
        timeout: Request timeout passed to the Anthropic client.
        retries: ``max_retries`` passed to the Anthropic client.
        max_tokens: Upper bound on generated tokens (keyword-only). The
            default keeps the previous hard-coded 8096.
            NOTE(review): 8096 may have been intended as 8192; confirm.

    Returns:
        The concatenation of every streamed text delta.
    """
    import anthropic

    client = anthropic.Anthropic(
        api_key=api_key, timeout=timeout, max_retries=retries,
    )
    formatted = [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with client.messages.stream(
        model=model_id,
        max_tokens=max_tokens,
        system=system,
        messages=formatted,
    ) as stream:
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for delta in stream.text_stream:
                full += delta
                live.update(Text(full))
    return full
81
+
82
+
83
def stream_openai(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60, retries: int = 2,
):
    """Stream an OpenAI chat completion with a live preview and return the full text.

    Args:
        messages: Conversation turns; each item must expose ``role`` and
            ``content`` attributes.
        system: System prompt, sent as the first (``"system"``) message.
        model_id: OpenAI model identifier.
        api_key: OpenAI API key.
        timeout: Request timeout passed to the OpenAI client.
        retries: ``max_retries`` passed to the OpenAI client.

    Returns:
        The concatenation of every streamed content delta.
    """
    from openai import OpenAI

    client = OpenAI(api_key=api_key, timeout=timeout, max_retries=retries)
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    stream = client.chat.completions.create(
        model=model_id,
        messages=formatted,
        stream=True,
    )
    with Live(Text(""), refresh_per_second=15, console=console) as live:
        for chunk in stream:
            # The API documents chunks whose `choices` list is empty (e.g. a
            # trailing usage chunk, or content-filter preambles from
            # compatible endpoints); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content or ""
            full += delta
            live.update(Text(full))
    return full
103
+
104
+
105
def stream_ollama(
    messages: list, system: str, model_id: str, timeout: float = 120,
    *, base_url: str = "http://localhost:11434",
):
    """Stream an Ollama ``/api/chat`` reply with a live preview and return the text.

    The response is read line by line, with one JSON object per line. The
    text of each object's ``message.content`` is appended to the result.
    Lines that are not valid JSON, or that carry no string content, are
    skipped (best effort). Other errors, including render failures, now
    propagate instead of being silently swallowed.

    Args:
        messages: Conversation turns; each item must expose ``role`` and
            ``content`` attributes.
        system: System prompt, sent as the first (``"system"``) message.
        model_id: Name of the Ollama model.
        timeout: httpx timeout for the request.
        base_url: Ollama server root (keyword-only). Defaults to the
            previously hard-coded local address.

    Returns:
        The concatenated assistant text.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    import httpx
    import json

    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with httpx.stream(
        "POST",
        f"{base_url.rstrip('/')}/api/chat",
        json={"model": model_id, "messages": formatted, "stream": True},
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # Best effort: ignore a malformed line, keep streaming.
                    continue
                message = data.get("message") if isinstance(data, dict) else None
                delta = message.get("content") if isinstance(message, dict) else None
                # NOTE(review): objects carrying an "error" key are skipped
                # here as before; consider surfacing them to the caller.
                if isinstance(delta, str) and delta:
                    full += delta
                    live.update(Text(full))
    return full
131
+
132
+
133
def stream_openrouter(
    messages: list, system: str, model_id: str, api_key: str,
    timeout: float = 60,
):
    """Stream an OpenRouter chat completion (SSE) with a live preview; return the text.

    Only ``data: `` lines are parsed. Blank separators and SSE comment lines
    are ignored, and ``data: [DONE]`` ends the stream. Payloads that are not
    valid JSON or lack ``choices[0].delta`` are skipped, and ``null``
    content is treated as empty.

    Args:
        messages: Conversation turns; each item must expose ``role`` and
            ``content`` attributes.
        system: System prompt, sent as the first (``"system"``) message.
        model_id: OpenRouter model identifier.
        api_key: OpenRouter API key, sent as a Bearer token.
        timeout: httpx timeout for the request. Defaults to the previously
            hard-coded 60, matching the other stream_* helpers.

    Returns:
        The concatenated assistant text.

    Raises:
        httpx.HTTPStatusError: If the server answers with an error status.
    """
    import httpx
    import json

    prefix = "data: "
    formatted = [{"role": "system", "content": system}]
    formatted += [{"role": m.role, "content": m.content} for m in messages]
    full = ""
    with httpx.stream(
        "POST",
        "https://openrouter.ai/api/v1/chat/completions",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"model": model_id, "messages": formatted, "stream": True},
        timeout=timeout,
    ) as response:
        response.raise_for_status()
        with Live(Text(""), refresh_per_second=15, console=console) as live:
            for line in response.iter_lines():
                if not line.startswith(prefix):
                    continue
                payload = line[len(prefix):]
                if payload == "[DONE]":
                    break
                try:
                    data = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                try:
                    content = data["choices"][0]["delta"].get("content")
                except (KeyError, IndexError, TypeError, AttributeError):
                    # Not a content chunk (e.g. an error or metadata object).
                    continue
                if isinstance(content, str) and content:
                    full += content
                    live.update(Text(full))
    return full
zai/core/tool_schema.py ADDED
@@ -0,0 +1,133 @@
1
+ from typing import Any, Literal
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field, ValidationError
4
+
5
+
6
class StrictArguments(BaseModel):
    # Common base for the per-tool argument models below: extra="forbid"
    # makes validation reject any key a subclass does not declare.
    # Documented with comments on purpose: pydantic copies a class docstring
    # into model_json_schema()["description"], which would alter the tool
    # schemas produced from these models.
    model_config = ConfigDict(extra="forbid")
8
+
9
+
10
class PathArguments(StrictArguments):
    # Arguments for tools that take a single path (read_file and
    # create_folder use it directly); also the base for write/edit models.
    # `path` is required and must be non-empty.
    path: str = Field(min_length=1)
12
+
13
+
14
class WriteFileArguments(PathArguments):
    # write_file arguments: the inherited non-empty `path` plus the full
    # file text. `content` has no length constraint, so "" is accepted.
    content: str
16
+
17
+
18
class EditFileArguments(PathArguments):
    # edit_file arguments: `search` is the exact text to locate (must be
    # non-empty) and `replace` is its replacement (may be empty, i.e. a
    # deletion).
    search: str = Field(min_length=1)
    replace: str
21
+
22
+
23
class ListFilesArguments(StrictArguments):
    # list_files arguments. Unlike PathArguments, `path` is optional
    # (defaults to ".") and has no min_length, so "" also validates.
    path: str = "."
25
+
26
+
27
class RenamePathArguments(StrictArguments):
    # rename_path arguments: both the existing path and the new path are
    # required and must be non-empty.
    source: str = Field(min_length=1)
    destination: str = Field(min_length=1)
30
+
31
+
32
class RunCommandArguments(StrictArguments):
    # run_command arguments. Validation here only requires a non-empty
    # string; NOTE(review): approval/parsing of the command happens
    # elsewhere and is not visible in this module.
    command: str = Field(min_length=1)
34
+
35
+
36
class MCPCallArguments(StrictArguments):
    # mcp_call arguments: non-empty server and tool names plus a free-form
    # argument dict (a fresh {} per instance when omitted). The inner dict
    # is not validated here; presumably the MCP tool checks it. Verify.
    server: str = Field(min_length=1)
    tool: str = Field(min_length=1)
    arguments: dict[str, Any] = Field(default_factory=dict)
40
+
41
+
42
class PluginCallArguments(StrictArguments):
    # plugin_call arguments: non-empty plugin and tool names plus a
    # free-form argument dict (a fresh {} per instance when omitted). The
    # inner dict is not validated here.
    plugin: str = Field(min_length=1)
    tool: str = Field(min_length=1)
    arguments: dict[str, Any] = Field(default_factory=dict)
46
+
47
+
48
class ToolCallEnvelope(BaseModel):
    # Outer shape of a structured tool call: {"name": ..., "arguments": {...}}.
    # Unknown top-level keys are rejected. `arguments` is only checked to be
    # a dict here; validate_tool_call re-validates it with the tool's model.
    model_config = ConfigDict(extra="forbid")

    # Keep this list in sync with the keys of ARGUMENT_MODELS:
    # validate_tool_call indexes ARGUMENT_MODELS by this value, so a name
    # accepted here but missing there would raise KeyError, not
    # ValidationError.
    name: Literal[
        "write_file",
        "read_file",
        "edit_file",
        "create_folder",
        "rename_path",
        "list_files",
        "run_command",
        "mcp_call",
        "plugin_call",
    ]
    arguments: dict[str, Any] = Field(default_factory=dict)
63
+
64
+
65
# Tool name -> pydantic model that validates that tool's arguments.
# Insertion order is the order in which tools are advertised by
# get_tool_definitions().
ARGUMENT_MODELS = dict(
    write_file=WriteFileArguments,
    read_file=PathArguments,
    edit_file=EditFileArguments,
    create_folder=PathArguments,
    rename_path=RenamePathArguments,
    list_files=ListFilesArguments,
    run_command=RunCommandArguments,
    mcp_call=MCPCallArguments,
    plugin_call=PluginCallArguments,
)
76
+
77
# Model-facing one-line description for each tool, keyed by tool name.
# get_tool_definitions() looks up every ARGUMENT_MODELS name here, so each
# tool must have an entry.
TOOL_DESCRIPTIONS = dict(
    write_file="Create a new text file. Use edit_file for an existing file.",
    read_file="Read a text file inside the current project.",
    edit_file="Replace one exact text occurrence in an existing file.",
    create_folder="Create a directory inside the current project.",
    rename_path="Rename or move a file or directory inside the current project.",
    list_files="List files in a project directory.",
    run_command="Run an approved direct command in the current project.",
    mcp_call="Call a tool exposed by a connected MCP server.",
    plugin_call="Call a tool exposed by an enabled zai plugin.",
)
88
+
89
+
90
def _tool_definition(tool_name, argument_model) -> dict[str, Any]:
    """Build the function definition for one tool from its argument model."""
    parameters = argument_model.model_json_schema()
    parameters.pop("title", None)
    parameters["additionalProperties"] = False
    # Native providers are most reliable when every property is explicit,
    # so even defaulted properties are listed as required.
    parameters["required"] = [*parameters.get("properties", {})]
    return {
        "name": tool_name,
        "description": TOOL_DESCRIPTIONS[tool_name],
        "parameters": parameters,
    }


def get_tool_definitions() -> list[dict[str, Any]]:
    """Return one provider-neutral JSON Schema function definition per tool.

    Each entry has ``name``, ``description`` (from TOOL_DESCRIPTIONS) and
    ``parameters``. ``parameters`` is the argument model's JSON Schema with
    its top-level title removed, extra properties disallowed, and all
    properties required. Entries follow ARGUMENT_MODELS order.
    """
    return [
        _tool_definition(tool_name, argument_model)
        for tool_name, argument_model in ARGUMENT_MODELS.items()
    ]
105
+
106
+
107
def legacy_tool_instructions() -> str:
    """Build prompt text teaching a provider without native tools how to call them.

    The provider is told to wrap each call in ``<tool_call>...</tool_call>``
    tags containing a JSON object. For every tool from get_tool_definitions(),
    the text lists its argument names and its description. The description is
    the same text native providers receive in their function definitions.

    Returns:
        The newline-joined instruction text.
    """
    lines = [
        "Your API does not support native tools here. Emit each call as JSON:",
        '<tool_call>{"name":"tool_name","arguments":{...}}</tool_call>',
        "Valid tools:",
    ]
    for tool in get_tool_definitions():
        properties = ", ".join(tool["parameters"].get("properties", {}))
        lines.append(f"- {tool['name']}: {properties}")
        # The description is the only place the model learns usage rules
        # such as "Use edit_file for an existing file"; without it the
        # legacy path would see bare argument names only.
        lines.append(f"  {tool['description']}")
    lines.append("JSON must use double quotes and must not contain trailing commas.")
    return "\n".join(lines)
119
+
120
+
121
def validate_tool_call(data: Any) -> tuple[str, dict[str, Any]]:
    """Validate a structured tool call and return ``(name, normalized_arguments)``.

    The outer envelope is checked first, then ``arguments`` is validated
    against the model registered for that tool name. Either step raises
    pydantic's ValidationError on bad input. The returned arguments are the
    model's ``model_dump()``, so defaults are filled in.
    """
    parsed = ToolCallEnvelope.model_validate(data)
    tool_name = parsed.name
    argument_model = ARGUMENT_MODELS[tool_name]
    normalized = argument_model.model_validate(parsed.arguments).model_dump()
    return tool_name, normalized
126
+
127
+
128
def format_validation_error(error: "ValidationError") -> str:
    """Flatten a pydantic ValidationError into a single human-readable line.

    Each issue is rendered as ``"<loc>: <msg>"``, where ``<loc>`` is the
    dotted error location (e.g. ``"arguments.path"`` or ``"items.0"``).
    Issues are joined with ``"; "``. An issue with an empty location is
    rendered as the bare message rather than with a dangling ``": "``
    prefix. Empty locations occur when the input as a whole is rejected,
    e.g. a non-dict payload.

    Args:
        error: The validation error to format. Only ``errors()`` is used.

    Returns:
        One string summarizing every issue; empty if there are none.
    """
    issues = []
    for item in error.errors(include_url=False):
        location = ".".join(str(part) for part in item["loc"])
        message = item["msg"]
        issues.append(f"{location}: {message}" if location else message)
    return "; ".join(issues)