zai-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zai/__init__.py +1 -0
- zai/__main__.py +4 -0
- zai/cli/__init__.py +1 -0
- zai/cli/common.py +16 -0
- zai/cli/integrations.py +319 -0
- zai/cli/interactive.py +518 -0
- zai/cli/settings.py +436 -0
- zai/cli/utilities.py +227 -0
- zai/cli/workflows.py +137 -0
- zai/commands/commit.md +24 -0
- zai/commands/explain.md +17 -0
- zai/commands/feature.md +34 -0
- zai/commands/fix.md +14 -0
- zai/commands/review.md +22 -0
- zai/config.py +307 -0
- zai/core/__init__.py +0 -0
- zai/core/agent.py +701 -0
- zai/core/cancellation.py +67 -0
- zai/core/commands.py +85 -0
- zai/core/context.py +299 -0
- zai/core/errors.py +125 -0
- zai/core/fallback.py +171 -0
- zai/core/hooks.py +115 -0
- zai/core/memory.py +57 -0
- zai/core/process.py +204 -0
- zai/core/repomap.py +381 -0
- zai/core/runtime.py +29 -0
- zai/core/security.py +33 -0
- zai/core/session.py +425 -0
- zai/core/storage.py +193 -0
- zai/core/streaming.py +157 -0
- zai/core/tool_schema.py +133 -0
- zai/core/undo.py +443 -0
- zai/core/watch.py +80 -0
- zai/main.py +210 -0
- zai/mcp/__init__.py +0 -0
- zai/mcp/client.py +431 -0
- zai/mcp/manager.py +118 -0
- zai/plugins/__init__.py +2 -0
- zai/plugins/base.py +49 -0
- zai/plugins/loader.py +404 -0
- zai/providers/__init__.py +22 -0
- zai/providers/anthropic.py +131 -0
- zai/providers/base.py +67 -0
- zai/providers/cerebras.py +57 -0
- zai/providers/gemini.py +119 -0
- zai/providers/groq.py +116 -0
- zai/providers/ollama.py +62 -0
- zai/providers/openai.py +124 -0
- zai/providers/openrouter.py +63 -0
- zai/providers/qwen.py +47 -0
- zai/skills/__init__.py +0 -0
- zai/skills/registry.py +52 -0
- zai/tools/__init__.py +0 -0
- zai/tools/browser.py +224 -0
- zai/tools/code_runner.py +49 -0
- zai/tools/files.py +53 -0
- zai/tools/git.py +38 -0
- zai/tools/search.py +157 -0
- zai/tools/vision.py +128 -0
- zai/ui/__init__.py +0 -0
- zai/ui/input.py +199 -0
- zai_cli-0.1.0.dist-info/METADATA +722 -0
- zai_cli-0.1.0.dist-info/RECORD +68 -0
- zai_cli-0.1.0.dist-info/WHEEL +5 -0
- zai_cli-0.1.0.dist-info/entry_points.txt +2 -0
- zai_cli-0.1.0.dist-info/licenses/LICENSE +21 -0
- zai_cli-0.1.0.dist-info/top_level.txt +1 -0
zai/core/fallback.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
from rich.console import Console
|
|
2
|
+
from ..providers import PROVIDERS
|
|
3
|
+
from ..providers.base import Message, Response
|
|
4
|
+
from ..config import get_model_config, load_config
|
|
5
|
+
from .errors import (
|
|
6
|
+
AllModelsFailedError,
|
|
7
|
+
NetworkError,
|
|
8
|
+
NoAPIKeyError,
|
|
9
|
+
ProviderError,
|
|
10
|
+
RateLimitError,
|
|
11
|
+
classify_provider_error,
|
|
12
|
+
)
|
|
13
|
+
from .runtime import plain_enabled
|
|
14
|
+
from .cancellation import OperationCancelled, raise_if_cancelled
|
|
15
|
+
|
|
16
|
+
# Module-level Rich console used by stream_with_fallback and chat_with_fallback
# for reply output and fallback status messages.
console = Console()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def format_model_selection(name: str) -> str:
    """Describe model *name* as ``"<name> -> <provider>/<model_id>"``.

    When ``get_model_config`` raises KeyError for *name*, the bare name is
    returned unchanged.
    """
    try:
        spec = get_model_config(name)
    except KeyError:
        return name
    provider, model_id = spec["provider"], spec["model_id"]
    return f"{name} -> {provider}/{model_id}"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_provider(name: str):
    """Instantiate and configure the provider backing model *name*.

    Returns None when ``get_model_config`` raises KeyError for *name* or when
    the config's ``provider`` key has no entry in PROVIDERS. Otherwise the
    new instance receives ``model_id`` and ``context_window`` from the model
    config, plus ``timeout`` (default 60) and ``retries`` (default 2).
    """
    try:
        spec = get_model_config(name)
    except KeyError:
        return None

    provider_cls = PROVIDERS.get(spec["provider"])
    if not provider_cls:
        return None

    instance = provider_cls()
    instance.model_id = spec["model_id"]
    instance.context_window = spec["context_window"]
    instance.timeout = spec.get("timeout", 60)
    instance.retries = spec.get("retries", 2)
    return instance
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _model_order(config: dict, preferred: str = None) -> list[str]:
|
|
44
|
+
first = preferred or config["default_model"]
|
|
45
|
+
if not config.get("auto_fallback", True):
|
|
46
|
+
return [first]
|
|
47
|
+
return [first] + [
|
|
48
|
+
model for model in config["fallback_order"]
|
|
49
|
+
if model != first
|
|
50
|
+
]
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def has_available_provider() -> bool:
    """Return True when any configured API provider or local provider is usable.

    Walks the default model order (``_model_order`` with no preferred model)
    and stops at the first model whose provider exists and reports
    ``is_available()``.
    """
    candidates = (get_provider(model_name) for model_name in _model_order(load_config()))
    return any(provider and provider.is_available() for provider in candidates)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def stream_with_fallback(
    messages: list[Message],
    system: str = "",
    preferred: str | None = None,
) -> tuple[str, str]:
    """Stream response with auto fallback. Returns (content, model_name).

    Candidate models come from ``_model_order`` (the preferred or default
    model first, then the configured fallbacks). A model is skipped when it
    has no provider or its provider reports ``is_available()`` as false.
    Before each attempt the history is compacted to the provider's
    ``context_window`` (128_000 when the provider has none). The reserve is
    the estimated system-prompt size plus 4096 tokens, and never less than
    4096.

    When the provider supports streaming and plain mode is off, the content
    is whatever ``stream_chat`` returns. It is presumably rendered while
    streaming, which is not visible here. Otherwise ``chat`` is used and its
    content is printed verbatim with markup disabled.

    Raises:
        OperationCancelled: as soon as cancellation is detected.
        AllModelsFailedError: when no model produced a reply. Its message is
            the last recorded error, or "" if nothing was recorded.
    """
    from rich.markup import escape

    config = load_config()
    order = _model_order(config, preferred)

    last_error = None
    for model_name in order:
        raise_if_cancelled()
        provider = get_provider(model_name)
        if not provider or not provider.is_available():
            continue
        # Everything interpolated into Rich markup below is escaped. Provider
        # error text can contain "[...]" sequences that Rich parses as tags;
        # an unmatched closing tag raises MarkupError from inside the except
        # handlers, which would abort the whole fallback loop.
        label = escape(str(model_name))
        try:
            from .context import compact_messages, estimate_text_tokens

            prepared_messages = compact_messages(
                messages,
                getattr(provider, "context_window", 128_000),
                reserve_tokens=max(4096, estimate_text_tokens(system) + 4096),
            )
            if provider.supports_streaming and not plain_enabled():
                content = provider.stream_chat(prepared_messages, system=system)
            else:
                response = provider.chat(prepared_messages, system=system)
                content = response.content
                console.print(content, markup=False)
            raise_if_cancelled()
            return content, model_name
        except OperationCancelled:
            raise
        except RateLimitError as error:
            last_error = str(error)
            console.print(f"[yellow]⚡ {label} limit hit — switching...[/yellow]")
        except NoAPIKeyError as error:
            # Skip silently, but remember why, so AllModelsFailedError is not
            # empty when missing keys were the only failures. Never overwrite
            # a more informative error from an earlier model.
            if last_error is None:
                last_error = str(error)
        except (NetworkError, ProviderError) as error:
            last_error = str(error)
            console.print(f"[yellow]Network issue with {label} — switching...[/yellow]")
        except Exception as error:
            normalized = classify_provider_error(provider.name, error)
            console.print(f"[dim]Error with {label}: {escape(str(normalized))}[/dim]")
            last_error = str(normalized)

    raise AllModelsFailedError(last_error or "")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def chat_with_fallback(
    messages: list[Message],
    system: str = "",
    preferred: str | None = None,
    tools: list[dict] | None = None,
) -> tuple[Response, str]:
    """Send a non-streaming chat request, falling back through models.

    Candidate models come from ``_model_order``. A model is skipped when it
    has no provider or its provider reports ``is_available()`` as false.

    Tool schemas in *tools* are passed natively to providers whose
    ``supports_native_tools`` is true. For other providers, the text from
    ``legacy_tool_instructions()`` is appended to the system prompt instead
    and no ``tools`` are sent.

    The history is compacted to the provider's ``context_window`` (128_000
    when unset). The reserve covers the estimated system prompt, the
    stringified native tool schemas and 4096 extra tokens.

    Returns:
        ``(response, model_name)`` for the first model that succeeds.

    Raises:
        OperationCancelled: as soon as cancellation is detected.
        AllModelsFailedError: when no model produced a reply. Its message is
            the last recorded error, or "" if nothing was recorded.
    """
    from rich.markup import escape

    config = load_config()
    order = _model_order(config, preferred)

    last_error = None
    for model_name in order:
        raise_if_cancelled()
        provider = get_provider(model_name)
        if not provider or not provider.is_available():
            continue
        # Escape anything embedded in Rich markup: error text may contain
        # "[...]" that Rich would parse as tags. A stray closing tag raises
        # MarkupError and would abort the fallback loop.
        label = escape(str(model_name))
        try:
            from .context import compact_messages, estimate_text_tokens

            provider_system = system
            provider_tools = tools if provider.supports_native_tools else None
            if tools and not provider.supports_native_tools:
                from .tool_schema import legacy_tool_instructions

                provider_system = f"{system}\n\n{legacy_tool_instructions()}"
            tool_schema_tokens = estimate_text_tokens(str(provider_tools or ""))
            prepared_messages = compact_messages(
                messages,
                getattr(provider, "context_window", 128_000),
                reserve_tokens=max(
                    4096,
                    estimate_text_tokens(provider_system) + tool_schema_tokens + 4096,
                ),
            )
            response = provider.chat(
                prepared_messages,
                system=provider_system,
                tools=provider_tools,
            )
            raise_if_cancelled()
            return response, model_name
        except OperationCancelled:
            raise
        except RateLimitError as error:
            console.print(f"[yellow]⚡ {label} limit hit — switching...[/yellow]")
            last_error = str(error)
        except NoAPIKeyError as error:
            # Skip silently, but keep a message for AllModelsFailedError
            # without overwriting a more informative earlier error.
            if last_error is None:
                last_error = str(error)
        except (NetworkError, ProviderError) as error:
            console.print(f"[yellow]Network issue with {label} — switching...[/yellow]")
            last_error = str(error)
        except Exception as error:
            normalized = classify_provider_error(provider.name, error)
            console.print(f"[dim]Error with {label}: {escape(str(normalized))}[/dim]")
            last_error = str(normalized)

    raise AllModelsFailedError(last_error or "")
|
zai/core/hooks.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from .process import run_direct
|
|
5
|
+
from .storage import atomic_write_json, read_json, update_json
|
|
6
|
+
|
|
7
|
+
# Per-user zai state directory; the hook registry is stored as JSON inside it.
ZAI_DIR = Path.home().joinpath(".zai")
HOOKS_FILE = ZAI_DIR.joinpath("hooks.json")

# Hook event names treated as valid. Nothing in this module enforces the list;
# presumably callers (e.g. the CLI) validate against it — confirm.
VALID_EVENTS = [
    "PreToolUse",
    "PostToolUse",
    "SessionStart",
    "SessionEnd",
    "UserPromptSubmit",
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_hooks() -> list:
    """Read the registered hooks from HOOKS_FILE.

    ``read_json`` is asked for a list, with ``[]`` as the fallback value.
    """
    fallback: list = []
    hooks = read_json(HOOKS_FILE, fallback, expected_type=list)
    return hooks
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def save_hooks(hooks: list) -> None:
    """Overwrite HOOKS_FILE with *hooks* using ``atomic_write_json``."""
    atomic_write_json(HOOKS_FILE, hooks)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def add_hook(event: str, command: str) -> dict:
    """Append a hook running *command* on *event* and return the stored entry.

    The new id is one greater than the largest integer ``id`` already present
    (1 for an empty registry). The read-modify-write goes through
    ``update_json`` on HOOKS_FILE.
    """
    created: dict = {}

    def update(hooks):
        nonlocal created
        # hooks.json may be hand-edited. Ignore entries without an integer id
        # rather than failing with KeyError/TypeError; fire() is equally
        # tolerant of malformed entries.
        existing_ids = (
            hook["id"]
            for hook in hooks
            if isinstance(hook, dict) and isinstance(hook.get("id"), int)
        )
        created = {
            "id": max(existing_ids, default=0) + 1,
            "event": event,
            "command": command,
        }
        return [*hooks, created]

    update_json(HOOKS_FILE, [], update, expected_type=list)
    return created
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def remove_hook(hook_id: int) -> bool:
    """Delete every hook whose ``id`` equals *hook_id*.

    Returns True if at least one entry was removed. Entries that are not
    dicts, or have no ``id``, are kept untouched instead of raising.
    """
    removed = False

    def update(hooks):
        nonlocal removed
        kept = [
            hook
            for hook in hooks
            if not (isinstance(hook, dict) and hook.get("id") == hook_id)
        ]
        removed = len(kept) != len(hooks)
        return kept

    update_json(HOOKS_FILE, [], update, expected_type=list)
    return removed
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def fire(event: str, data: dict | None = None) -> bool:
    """
    Fire an event. Returns False if any hook blocks (exit code 2).
    Hook scripts receive JSON on stdin.

    Every hook registered for *event* is run via ``run_direct`` with policy
    enforcement off and a 10 s timeout. Each receives ``json.dumps(data or {})``
    on stdin. Processing stops at the first hook that blocks:

    - exit code 2: blocks; its output, if any, is shown;
    - exit code 126 with a ``blocked_reason`` (run_direct refused to start
      it, e.g. the command contains shell syntax): blocks;
    - exit code 1 with output: shown as a note, does not block.

    Any other outcome is ignored. Hooks are best effort: a hook that raises
    while being run is reported and skipped, never crashing the caller.
    """
    from rich.console import Console
    from rich.markup import escape

    console = Console()
    hooks = load_hooks()
    payload = json.dumps(data or {})

    for hook in hooks:
        # hooks.json may be hand-edited; skip anything that is not an object.
        if not isinstance(hook, dict) or hook.get("event") != event:
            continue
        cmd = hook.get("command", "")
        if not cmd:
            continue
        try:
            result = run_direct(
                cmd,
                input_text=payload,
                timeout=10,
                enforce_policy=False,
            )
        except Exception as exc:
            console.print(f"[dim]Hook failed to run: {escape(str(exc))}[/dim]")
            continue
        # Hook output is arbitrary text. It is escaped so Rich cannot parse it
        # as markup. Previously a MarkupError here was swallowed before
        # `return False`, silently ignoring a blocking hook.
        if result.returncode == 2:
            if result.output:
                console.print(f"[yellow]Hook blocked:[/yellow] {escape(result.output)}")
            return False
        if result.returncode == 126 and result.blocked_reason:
            console.print(f"[yellow]Hook blocked:[/yellow] {escape(result.blocked_reason)}")
            return False
        if result.returncode == 1 and result.output:
            console.print(f"[dim]Hook note: {escape(result.output)}[/dim]")

    return True
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Built-in security checks for write_file tool.
# Maps a literal substring to the warning reported when it appears in content.
SECURITY_PATTERNS = {
    # Python
    "eval(": "eval() executes arbitrary code — security risk",
    "exec(": "exec() can run arbitrary code",
    "os.system(": "os.system() can allow command injection",
    "pickle.loads(": "pickle.loads() with untrusted data is dangerous",
    # Browser / React
    ".innerHTML =": "innerHTML with untrusted content causes XSS",
    "dangerouslySetInnerHTML": "dangerouslySetInnerHTML causes XSS — sanitize first",
    "document.write(": "document.write() is an XSS risk",
    # JavaScript / Node
    "new Function(": "new Function() evaluates arbitrary code",
    "child_process.exec(": "child_process.exec() allows shell injection",
}
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def check_security(content: str) -> list[str]:
    """Return list of security warnings found in content.

    Every SECURITY_PATTERNS key that occurs as a plain substring of *content*
    contributes its message, in the dict's insertion order.
    """
    return [
        warning
        for needle, warning in SECURITY_PATTERNS.items()
        if needle in content
    ]
|
zai/core/memory.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from ..config import MEMORY_FILE
|
|
4
|
+
from .storage import atomic_write_json, read_json, update_json
|
|
5
|
+
|
|
6
|
+
# Shape of a fresh memory file. The nested list/dict are single shared objects:
# deep-copy before mutating anything derived from this constant.
DEFAULT_MEMORY = dict(version=1, projects=[], last_session=None, preferences={})
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def load_memory() -> dict:
    """Return the stored memory with missing top-level keys filled from defaults.

    The defaults are deep-copied for every call. A shallow merge would hand
    out DEFAULT_MEMORY's own ``projects`` list and ``preferences`` dict
    whenever the file lacks those keys, so callers mutating the result (e.g.
    the list from ``get_projects()``) would corrupt the module-level default.
    """
    import copy

    data = read_json(MEMORY_FILE, copy.deepcopy(DEFAULT_MEMORY), expected_type=dict)
    return {**copy.deepcopy(DEFAULT_MEMORY), **data}
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def save_memory(memory: dict):
    """Persist *memory* to MEMORY_FILE, with missing top-level keys from DEFAULT_MEMORY."""
    merged = {**DEFAULT_MEMORY, **memory}
    atomic_write_json(MEMORY_FILE, merged)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def save_session(task: str, model: str, files: list[str] | None = None):
    """Store *task*, *model* and *files* as ``last_session`` in the memory file.

    The timestamp is taken inside the ``update_json`` callback and formatted
    as ``YYYY-MM-DD HH:MM`` (naive local time from ``datetime.now()``).
    """

    def _record(current):
        # "last_session" already exists in DEFAULT_MEMORY, so it keeps its
        # position in the merged dict and only its value is replaced.
        return {
            **DEFAULT_MEMORY,
            **current,
            "last_session": {
                "task": task,
                "model": model,
                "files_edited": files or [],
                "date": datetime.now().strftime("%Y-%m-%d %H:%M"),
            },
        }

    update_json(MEMORY_FILE, DEFAULT_MEMORY.copy(), _record, expected_type=dict)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_last_session() -> dict | None:
    """Return the ``last_session`` entry from load_memory(), or None if absent."""
    memory = load_memory()
    return memory.get("last_session")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def add_project(name: str, path: str, summary: str = ""):
    """Record the project at *path*, or refresh it if already known.

    An existing entry (matched by ``path``) gets today's ``last_worked`` date.
    Its ``summary`` is replaced only when a non-empty *summary* is given. An
    unknown path is appended as a new entry. Dates use ``YYYY-MM-DD`` from
    ``datetime.now()``.
    """

    def update(memory):
        memory = {**DEFAULT_MEMORY, **memory}
        # Work on a copy of the list. When the stored memory has no
        # "projects" key, the merged dict holds DEFAULT_MEMORY's own list,
        # and appending to it would pollute the module-level default.
        projects = list(memory.get("projects") or [])
        memory["projects"] = projects
        today = datetime.now().strftime("%Y-%m-%d")
        for project in projects:
            # Tolerate malformed entries in a hand-edited memory file.
            if isinstance(project, dict) and project.get("path") == path:
                project["last_worked"] = today
                project["summary"] = summary or project.get("summary", "")
                return memory
        projects.append({
            "name": name,
            "path": path,
            "summary": summary,
            "last_worked": today,
        })
        return memory

    update_json(MEMORY_FILE, DEFAULT_MEMORY.copy(), update, expected_type=dict)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_projects() -> list:
    """Return the ``projects`` list from load_memory() (``[]`` if the key is absent)."""
    memory = load_memory()
    return memory.get("projects", [])
|
zai/core/process.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import shlex
|
|
4
|
+
import subprocess
|
|
5
|
+
import time
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Callable
|
|
9
|
+
from .cancellation import CancellationToken, current_token
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Shell control syntax that split_command refuses: &&, ||, pipes, ;, &,
# redirections, backticks, newlines, $( substitution and %VAR% expansion.
SHELL_META = re.compile(r"(?:&&|\|\||[|;&<>`]|[\r\n]|\$\(|%\w+%)")

# Executables that would re-introduce a shell; classify_argv blocks them.
SHELL_INTERPRETERS = {
    "bash",
    "sh",
    "zsh",
    "fish",
    "wsl",
    "cmd",
    "cmd.exe",
    "pwsh",
    "powershell",
    "powershell.exe",
}

# Python interpreter / launcher names (POSIX and Windows spellings).
PYTHON_NAMES = {"py", "py.exe", "python", "python.exe", "python3", "python3.exe"}

# Package managers; running them needs approval as a dependency change.
PACKAGE_MANAGERS = {"npm", "pnpm", "yarn", "pip", "pip3", "poetry", "uv"}

# File-removal commands (POSIX, cmd and PowerShell); they need approval.
DELETE_TOOLS = {"remove-item", "rm", "rmdir", "del", "erase"}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class CommandResult:
    """Outcome of a command executed through ``run_direct``.

    Besides the child's own exit status, ``run_direct`` synthesises: 124
    timeout, 125 cancelled or not approved, 126 refused (parse error or
    policy), 127 executable not found, 1 other failures.
    """

    returncode: int
    stdout: str = ""
    stderr: str = ""
    cancelled: bool = False
    # Non-empty when the command was refused before it could run.
    blocked_reason: str = ""

    @property
    def output(self) -> str:
        """stdout followed by stderr, with surrounding whitespace stripped."""
        return "".join((self.stdout, self.stderr)).strip()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def split_command(command: str) -> list[str]:
    """Split a direct executable command without invoking a shell.

    Raises:
        ValueError: if the command is blank, contains anything SHELL_META
            matches, has unbalanced quoting, or yields no tokens.
    """
    text = command.strip()
    if not text:
        raise ValueError("empty command")
    if SHELL_META.search(text):
        raise ValueError("shell syntax is not allowed")

    windows = os.name == "nt"
    try:
        tokens = shlex.split(text, posix=not windows)
    except ValueError as exc:
        raise ValueError(f"invalid command quoting: {exc}") from exc

    if windows:
        # Non-POSIX shlex keeps the quotes, so peel one matching pair of
        # double quotes off each token.
        tokens = [
            tok[1:-1] if len(tok) >= 2 and tok.startswith('"') and tok.endswith('"') else tok
            for tok in tokens
        ]

    if not tokens:
        raise ValueError("empty command")
    return tokens
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _exe_name(argv: list[str]) -> str:
    """Return the lower-cased final path component of ``argv[0]``."""
    program = Path(argv[0])
    return program.name.lower()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def classify_argv(argv: list[str]) -> tuple[str, str]:
    """Classify direct argv as safe, approval, or blocked.

    Returns ``(risk, reason)``. ``risk`` is ``"safe"``, ``"approval"`` or
    ``"blocked"``, and ``reason`` is a short explanation ("" when safe). The
    policy is an allowlist: any executable not handled below is blocked.
    """
    executable = _exe_name(argv)
    args = list(argv[1:])
    # Subcommands and flags are mostly compared case-insensitively.
    lowered = [part.lower() for part in args]

    if executable in SHELL_INTERPRETERS:
        return "blocked", "shell interpreters are not allowed"
    if executable in DELETE_TOOLS:
        return "approval", "file deletion"
    if executable in {"format", "mkfs", "mkfs.ext4", "shutdown", "reboot", "halt"}:
        return "blocked", "destructive system operation"

    if executable == "git":
        if not lowered:
            return "safe", ""
        action, rest = lowered[0], lowered[1:]
        if action == "reset" and "--hard" in lowered:
            return "blocked", "destructive git operation"
        if action == "clean" and any(arg.startswith("-") and "f" in arg for arg in lowered):
            return "blocked", "destructive git operation"
        if action in {"status", "diff", "log", "branch", "show", "rev-parse", "--version"}:
            # --output <file> makes diff/log/show write a file.
            if any(arg == "--output" or arg.startswith("--output=") for arg in rest):
                return "approval", "git state change"
            if action == "branch":
                # "git branch" is read-only only with listing options. A
                # positional name creates a branch; -d/-D, -m/-M, -c/-C, -f,
                # -u, ... delete, rename, copy or move branches.
                listing_options = {
                    "-a", "--all", "-r", "--remotes", "-l", "--list",
                    "-v", "-vv", "--verbose", "--show-current",
                    "--color", "--no-color", "--column", "--no-column",
                    "--merged", "--no-merged", "--contains", "--no-contains",
                    "--points-at", "--sort", "--format", "-i", "--ignore-case",
                }
                if not all(arg.split("=", 1)[0] in listing_options for arg in rest):
                    return "approval", "git state change"
            return "safe", ""
        return "approval", "git state change"

    if executable in PYTHON_NAMES:
        # Compare with the original case: "-V" prints the version, while "-v"
        # enables verbose mode (and with no script starts an interactive REPL).
        if args in (["--version"], ["-V"]):
            return "safe", ""
        if len(lowered) >= 2 and lowered[0] == "-m":
            module = lowered[1]
            if module in {"pytest", "py_compile", "compileall"}:
                return "safe", ""
            if module in {"pip", "uv"}:
                return "approval", "dependency change"
        if "-c" in lowered:
            return "blocked", "inline interpreter execution"
        return "approval", "Python program execution"

    if executable in {"pytest", "pytest.exe"}:
        return "safe", ""
    if executable in PACKAGE_MANAGERS:
        return "approval", "dependency change"
    if executable in {"curl", "wget", "npx"}:
        return "approval", "network or package execution"

    return "blocked", f"executable is not allowlisted: {executable}"
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def run_direct(
    command: str | list[str],
    *,
    cwd: str | None = None,
    timeout: int = 30,
    input_text: str | None = None,
    approval: Callable[[str], bool] | None = None,
    enforce_policy: bool = True,
    cancellation_token: CancellationToken | None = None,
) -> CommandResult:
    """Run an executable directly with shell=False.

    *command* is either a string, tokenised by ``split_command`` (which
    rejects shell syntax), or an argv list. With *enforce_policy*, the argv is
    classified by ``classify_argv``. "blocked" commands are refused, and
    "approval" commands run only if *approval* is given and returns True for
    the reason string.

    Without a cancellation token (neither passed in nor current), the command
    runs through ``subprocess.run``. Otherwise it is polled every 0.1 s so a
    cancelled token can terminate it.

    Output is decoded as text with ``errors="replace"``. Undecodable bytes,
    e.g. binary or differently-encoded file content in a diff, therefore
    cannot turn a finished command into a decoding failure that discards
    its output.

    Returns a CommandResult. Synthesised exit codes: 124 timeout, 125
    cancelled or not approved, 126 refused (parse error / policy), 127
    executable not found, 1 any other failure. KeyboardInterrupt is re-raised
    after cancelling the token and terminating the child.
    """
    try:
        argv = split_command(command) if isinstance(command, str) else list(command)
    except ValueError as exc:
        return CommandResult(126, blocked_reason=str(exc))

    if enforce_policy:
        risk, reason = classify_argv(argv)
        if risk == "blocked":
            return CommandResult(126, blocked_reason=reason)
        if risk == "approval" and (approval is None or not approval(reason)):
            return CommandResult(125, cancelled=True)

    token = cancellation_token or current_token()
    if token and token.cancelled:
        return CommandResult(125, cancelled=True)
    if token is None:
        try:
            completed = subprocess.run(
                argv,
                shell=False,
                capture_output=True,
                text=True,
                errors="replace",
                cwd=cwd,
                timeout=timeout,
                input=input_text,
            )
            return CommandResult(
                completed.returncode,
                completed.stdout or "",
                completed.stderr or "",
            )
        except subprocess.TimeoutExpired:
            return CommandResult(124, stderr=f"command timed out after {timeout}s")
        except FileNotFoundError:
            return CommandResult(127, stderr=f"executable not found: {argv[0]}")
        except Exception as exc:
            return CommandResult(1, stderr=str(exc))

    process = None
    try:
        process = subprocess.Popen(
            argv,
            shell=False,
            stdin=subprocess.PIPE if input_text is not None else None,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            errors="replace",
            cwd=cwd,
        )
        started = time.monotonic()
        pending_input = input_text
        while True:
            if token and token.cancelled:
                process.terminate()
                try:
                    process.communicate(timeout=2)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.communicate()
                return CommandResult(125, cancelled=True)
            remaining = timeout - (time.monotonic() - started)
            if remaining <= 0:
                process.kill()
                process.communicate()
                return CommandResult(124, stderr=f"command timed out after {timeout}s")
            try:
                stdout, stderr = process.communicate(
                    input=pending_input,
                    timeout=min(0.1, remaining),
                )
                return CommandResult(
                    process.returncode,
                    stdout or "",
                    stderr or "",
                )
            except subprocess.TimeoutExpired:
                # communicate() may be retried after a timeout, but input may
                # only be sent on the first call.
                pending_input = None
    except FileNotFoundError:
        return CommandResult(127, stderr=f"executable not found: {argv[0]}")
    except KeyboardInterrupt:
        if token:
            token.cancel("cancelled by user")
        if process and process.poll() is None:
            process.terminate()
        raise
    except Exception as exc:
        # Don't leave a still-running child behind when something fails
        # after it was started.
        if process is not None and process.poll() is None:
            process.kill()
            process.wait()
        return CommandResult(1, stderr=str(exc))
|