ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/tools.py
ADDED
|
@@ -0,0 +1,1576 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tool system for the ATA Coder.
|
|
3
|
+
|
|
4
|
+
Provides a set of tools the agent can use:
|
|
5
|
+
- read_file: Read file contents
|
|
6
|
+
- write_file: Create or overwrite a file
|
|
7
|
+
- edit_file: Precise string replacement in a file
|
|
8
|
+
- run_shell: Execute a shell command
|
|
9
|
+
- grep: Search file contents with regex
|
|
10
|
+
- glob: Find files matching a pattern
|
|
11
|
+
- list_dir: List directory contents
|
|
12
|
+
- web_search: Search the web (optional)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import html
|
|
17
|
+
import html.parser
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import re
|
|
21
|
+
import shutil
|
|
22
|
+
import subprocess
|
|
23
|
+
import fnmatch
|
|
24
|
+
import time
|
|
25
|
+
import urllib.parse
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
from typing import Any, Callable
|
|
28
|
+
|
|
29
|
+
import httpx
|
|
30
|
+
|
|
31
|
+
from .config import AgentConfig
|
|
32
|
+
from .clawd_integration import get_clawd
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ── Tool result type ─────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
class ToolResult:
|
|
40
|
+
"""Result of executing a tool."""
|
|
41
|
+
|
|
42
|
+
def __init__(self, success: bool, output: str, error: str = ""):
|
|
43
|
+
self.success = success
|
|
44
|
+
self.output = output
|
|
45
|
+
self.error = error
|
|
46
|
+
|
|
47
|
+
def to_message(self) -> str:
|
|
48
|
+
"""Format as a message to the LLM."""
|
|
49
|
+
if self.success:
|
|
50
|
+
return self.output
|
|
51
|
+
return f"Error: {self.error}\n\n{self.output}".strip()
|
|
52
|
+
|
|
53
|
+
def to_tool_result(self, tool_call_id: str) -> dict:
|
|
54
|
+
"""Format as an OpenAI tool result message."""
|
|
55
|
+
return {
|
|
56
|
+
"role": "tool",
|
|
57
|
+
"tool_call_id": tool_call_id,
|
|
58
|
+
"content": self.to_message(),
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ── Tool definitions (OpenAI function format) ────────────────────────────────
|
|
63
|
+
# Extracted to tool_defs.py to keep this file under 1600 lines.
|
|
64
|
+
|
|
65
|
+
from .tool_defs import TOOL_DEFINITIONS # noqa: E402
|
|
66
|
+
|
|
67
|
+
# TOOL_DEFINITIONS are now in tool_defs.py (extracted to keep this file focused
|
|
68
|
+
# on implementations). The import above re-exports them transparently.
|
|
69
|
+
|
|
70
|
+
# ── Tool implementations ─────────────────────────────────────────────────────
|
|
71
|
+
|
|
72
|
+
class ToolExecutor:
|
|
73
|
+
"""Executes tool calls and manages workspace context."""
|
|
74
|
+
|
|
75
|
+
# Result limits
|
|
76
|
+
MAX_GREP_RESULTS = 100
|
|
77
|
+
MAX_GREP_PER_FILE = 20
|
|
78
|
+
MAX_GLOB_RESULTS = 200
|
|
79
|
+
MAX_DIR_ENTRIES = 500
|
|
80
|
+
|
|
81
|
+
# Directories to skip during recursive operations
|
|
82
|
+
SKIP_DIRS = {
|
|
83
|
+
"node_modules", "__pycache__", ".git", "venv", ".venv",
|
|
84
|
+
"dist", "build", "target", ".next", ".pytest_cache",
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
def __init__(self, config: AgentConfig | None = None):
|
|
88
|
+
self.config = config or AgentConfig()
|
|
89
|
+
self.workspace = Path(self.config.workspace_dir).resolve()
|
|
90
|
+
self._mcp = None # set via set_mcp_client()
|
|
91
|
+
self._edit_callback: Callable[[str, str], None] | None = None
|
|
92
|
+
# File read cache: path -> (mtime, content). "只读一遍" — files
|
|
93
|
+
# are read from disk once per session and served from memory after.
|
|
94
|
+
self._file_cache: dict[str, tuple[float, str]] = {}
|
|
95
|
+
self._file_cache_max_entries = 50 # LRU eviction threshold
|
|
96
|
+
self._cache_dir: Path | None = None
|
|
97
|
+
# Sub-agent manager (set by agent)
|
|
98
|
+
self._sub_agent_mgr: Any = None
|
|
99
|
+
# Pre-built handler map (validated at init time)
|
|
100
|
+
self._handlers: dict[str, Callable] = self._build_handlers()
|
|
101
|
+
|
|
102
|
+
def _build_handlers(self) -> dict[str, Callable]:
|
|
103
|
+
"""Build a dispatch table: tool_name → handler callable."""
|
|
104
|
+
handlers: dict[str, Callable] = {}
|
|
105
|
+
for name in dir(self):
|
|
106
|
+
if name.startswith("_tool_"):
|
|
107
|
+
tool_name = name[len("_tool_"):]
|
|
108
|
+
handlers[tool_name] = getattr(self, name)
|
|
109
|
+
return handlers
|
|
110
|
+
|
|
111
|
+
def set_sub_agent_manager(self, mgr: Any) -> None:
|
|
112
|
+
"""Set the SubAgentManager for spawn/collect sub-agent tool support."""
|
|
113
|
+
self._sub_agent_mgr = mgr
|
|
114
|
+
|
|
115
|
+
def set_mcp_client(self, mcp: Any) -> None:
|
|
116
|
+
"""Set the MCPClient for mcp_search tool support."""
|
|
117
|
+
self._mcp = mcp
|
|
118
|
+
|
|
119
|
+
def setup_file_cache(self, cache_dir: str | Path) -> None:
|
|
120
|
+
"""Create session cache directory. Call once before running the agent."""
|
|
121
|
+
self._cache_dir = Path(cache_dir)
|
|
122
|
+
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
|
123
|
+
|
|
124
|
+
def clear_file_cache(self) -> None:
|
|
125
|
+
"""Remove all cached files and the cache directory."""
|
|
126
|
+
self._file_cache.clear()
|
|
127
|
+
if self._cache_dir and self._cache_dir.exists():
|
|
128
|
+
shutil.rmtree(self._cache_dir, ignore_errors=True)
|
|
129
|
+
self._cache_dir = None
|
|
130
|
+
|
|
131
|
+
def close(self) -> None:
|
|
132
|
+
"""Release all held resources (httpx clients, file caches)."""
|
|
133
|
+
self.clear_file_cache()
|
|
134
|
+
if hasattr(self, "_http") and self._http is not None:
|
|
135
|
+
try:
|
|
136
|
+
self._http.close()
|
|
137
|
+
except Exception:
|
|
138
|
+
pass
|
|
139
|
+
self._http = None
|
|
140
|
+
|
|
141
|
+
def __del__(self) -> None:
|
|
142
|
+
"""Safety net: ensure httpx client is closed on GC."""
|
|
143
|
+
try:
|
|
144
|
+
self.close()
|
|
145
|
+
except Exception:
|
|
146
|
+
pass
|
|
147
|
+
|
|
148
|
+
def on_edit(self, callback: Callable[[str, str], None]) -> None:
|
|
149
|
+
"""Register callback for edit notifications: callback(file_path, old_content)."""
|
|
150
|
+
self._edit_callback = callback
|
|
151
|
+
|
|
152
|
+
def _notify_edit(self, file_path: str, old_content: str) -> None:
|
|
153
|
+
"""Notify the UI of a file edit for diff display."""
|
|
154
|
+
if self._edit_callback:
|
|
155
|
+
try:
|
|
156
|
+
self._edit_callback(file_path, old_content)
|
|
157
|
+
except Exception:
|
|
158
|
+
logger.exception("Edit callback failed for %s", file_path)
|
|
159
|
+
|
|
160
|
+
def _resolve_path(self, file_path: str) -> Path:
|
|
161
|
+
"""Resolve a file path relative to workspace.
|
|
162
|
+
|
|
163
|
+
Raises ValueError if the resolved path escapes the workspace via
|
|
164
|
+
path traversal (e.g., ``../../etc/passwd``).
|
|
165
|
+
"""
|
|
166
|
+
p = Path(file_path)
|
|
167
|
+
if not p.is_absolute():
|
|
168
|
+
p = self.workspace / p
|
|
169
|
+
resolved = p.resolve()
|
|
170
|
+
try:
|
|
171
|
+
resolved.relative_to(self.workspace.resolve())
|
|
172
|
+
except ValueError:
|
|
173
|
+
raise ValueError(
|
|
174
|
+
f"Path traversal blocked: {file_path} → {resolved} "
|
|
175
|
+
f"is outside workspace {self.workspace}"
|
|
176
|
+
)
|
|
177
|
+
return resolved
|
|
178
|
+
|
|
179
|
+
def _ensure_in_workspace(self, path: Path) -> Path:
|
|
180
|
+
"""
|
|
181
|
+
Ensure path is within workspace for safety.
|
|
182
|
+
For files outside workspace, allow with a warning.
|
|
183
|
+
"""
|
|
184
|
+
try:
|
|
185
|
+
path.resolve().relative_to(self.workspace)
|
|
186
|
+
except ValueError:
|
|
187
|
+
logger.warning("Path outside workspace: %s", path)
|
|
188
|
+
return path
|
|
189
|
+
|
|
190
|
+
# Global output cap — every tool result is trimmed to this many chars.
|
|
191
|
+
# ~25k tokens, still generous for legitimate work but prevents a single
|
|
192
|
+
# result from eating the whole context window.
|
|
193
|
+
MAX_OUTPUT_CHARS = 100_000
|
|
194
|
+
|
|
195
|
+
async def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
|
|
196
|
+
"""Dispatch a tool call to the appropriate handler.
|
|
197
|
+
|
|
198
|
+
All results are capped at MAX_OUTPUT_CHARS to prevent any single
|
|
199
|
+
tool response from blowing up the conversation context.
|
|
200
|
+
"""
|
|
201
|
+
handler = self._handlers.get(tool_name)
|
|
202
|
+
if handler is None:
|
|
203
|
+
return ToolResult(
|
|
204
|
+
success=False,
|
|
205
|
+
output="",
|
|
206
|
+
error=f"Unknown tool: {tool_name}",
|
|
207
|
+
)
|
|
208
|
+
try:
|
|
209
|
+
result = await handler(**arguments)
|
|
210
|
+
except Exception as e:
|
|
211
|
+
logger.exception("Tool %s failed", tool_name)
|
|
212
|
+
return ToolResult(
|
|
213
|
+
success=False, output="", error=f"{type(e).__name__}: {e}"
|
|
214
|
+
)
|
|
215
|
+
|
|
216
|
+
# ── Global size cap ──────────────────────────────────────────
|
|
217
|
+
if len(result.output) > self.MAX_OUTPUT_CHARS:
|
|
218
|
+
original_len = len(result.output)
|
|
219
|
+
cut = result.output[:self.MAX_OUTPUT_CHARS]
|
|
220
|
+
result.output = (
|
|
221
|
+
cut + f"\n\n... [truncated {original_len - self.MAX_OUTPUT_CHARS:,} "
|
|
222
|
+
f"chars — result was {original_len:,} chars total]"
|
|
223
|
+
)
|
|
224
|
+
logger.warning(
|
|
225
|
+
"Tool %s output truncated: %d → %d chars",
|
|
226
|
+
tool_name, original_len, self.MAX_OUTPUT_CHARS,
|
|
227
|
+
)
|
|
228
|
+
|
|
229
|
+
return result
|
|
230
|
+
|
|
231
|
+
# ── File tools ───────────────────────────────────────────────────────────
|
|
232
|
+
|
|
233
|
+
# ── Output limits (prevent token bloat from large file reads) ──────
|
|
234
|
+
MAX_READ_LINES = 2000 # auto-truncate reads without an explicit limit
|
|
235
|
+
MAX_READ_CHARS = 80_000 # hard cap on output chars (~20k tokens)
|
|
236
|
+
|
|
237
|
+
async def _tool_read_file(
|
|
238
|
+
self,
|
|
239
|
+
file_path: str,
|
|
240
|
+
offset: int | None = None,
|
|
241
|
+
limit: int | None = None,
|
|
242
|
+
) -> ToolResult:
|
|
243
|
+
"""Read a file with optional line range.
|
|
244
|
+
|
|
245
|
+
When *limit* is not set, output is capped at MAX_READ_LINES to
|
|
246
|
+
prevent a single file read from eating tens of thousands of context
|
|
247
|
+
tokens. Use *offset* + *limit* to page through larger files.
|
|
248
|
+
"""
|
|
249
|
+
path = self._resolve_path(file_path)
|
|
250
|
+
if not path.exists():
|
|
251
|
+
return ToolResult(
|
|
252
|
+
success=False,
|
|
253
|
+
output="",
|
|
254
|
+
error=f"File not found: {path}",
|
|
255
|
+
)
|
|
256
|
+
if path.is_dir():
|
|
257
|
+
return ToolResult(
|
|
258
|
+
success=False,
|
|
259
|
+
output="",
|
|
260
|
+
error=f"Path is a directory, not a file: {path}",
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
# ── File cache: "只读一遍" — return short note on re-read ────
|
|
264
|
+
cache_key = str(path.resolve())
|
|
265
|
+
current_mtime = path.stat().st_mtime
|
|
266
|
+
needs_disk_read = True
|
|
267
|
+
|
|
268
|
+
if cache_key in self._file_cache:
|
|
269
|
+
cached_mtime, cached_content = self._file_cache[cache_key]
|
|
270
|
+
if cached_mtime == current_mtime:
|
|
271
|
+
if offset is not None or limit is not None:
|
|
272
|
+
# Specific section — serve from memory, skip disk
|
|
273
|
+
needs_disk_read = False
|
|
274
|
+
lines = cached_content.splitlines(keepends=True)
|
|
275
|
+
if lines and not lines[-1].endswith("\n"):
|
|
276
|
+
lines[-1] += "\n"
|
|
277
|
+
else:
|
|
278
|
+
# Whole file re-read — DON'T send content again
|
|
279
|
+
total = cached_content.count("\n") + 1
|
|
280
|
+
chars = len(cached_content)
|
|
281
|
+
return ToolResult(
|
|
282
|
+
success=True,
|
|
283
|
+
output=(
|
|
284
|
+
f"[cached] {file_path} — {total} lines, {chars:,} chars.\n"
|
|
285
|
+
f"Already in conversation context from earlier read. "
|
|
286
|
+
f"Use offset/limit to read specific sections if needed."
|
|
287
|
+
),
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
if needs_disk_read:
|
|
291
|
+
try:
|
|
292
|
+
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
|
293
|
+
raw = f.read()
|
|
294
|
+
except Exception as e:
|
|
295
|
+
return ToolResult(
|
|
296
|
+
success=False, output="", error=f"Cannot read file: {e}"
|
|
297
|
+
)
|
|
298
|
+
self._file_cache[cache_key] = (current_mtime, raw)
|
|
299
|
+
# LRU eviction: drop oldest entries when cache exceeds limit
|
|
300
|
+
if len(self._file_cache) > self._file_cache_max_entries:
|
|
301
|
+
overflow = len(self._file_cache) - self._file_cache_max_entries
|
|
302
|
+
oldest = list(self._file_cache.keys())[:overflow]
|
|
303
|
+
for k in oldest:
|
|
304
|
+
del self._file_cache[k]
|
|
305
|
+
lines = raw.splitlines(keepends=True)
|
|
306
|
+
if lines and not lines[-1].endswith("\n"):
|
|
307
|
+
lines[-1] += "\n"
|
|
308
|
+
# Mirror to disk cache dir if configured
|
|
309
|
+
if self._cache_dir:
|
|
310
|
+
# Cross-platform safe filename: strip drive letter on Windows,
|
|
311
|
+
# replace path separators with underscores
|
|
312
|
+
resolved = str(path.resolve())
|
|
313
|
+
if len(resolved) >= 2 and resolved[1] == ":":
|
|
314
|
+
resolved = resolved[2:] # strip "C:"
|
|
315
|
+
safe_name = resolved.lstrip("\\/").replace("\\", "_").replace("/", "_")
|
|
316
|
+
try:
|
|
317
|
+
(self._cache_dir / safe_name).write_text(raw, encoding="utf-8")
|
|
318
|
+
except Exception:
|
|
319
|
+
pass
|
|
320
|
+
|
|
321
|
+
total_lines = len(lines)
|
|
322
|
+
user_specified_range = limit is not None
|
|
323
|
+
|
|
324
|
+
# Default cap: when the caller didn't ask for a specific range,
|
|
325
|
+
# truncate to MAX_READ_LINES so one file doesn't blow the context.
|
|
326
|
+
effective_limit = limit if user_specified_range else min(
|
|
327
|
+
len(lines), self.MAX_READ_LINES
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
start = (offset or 1) - 1
|
|
331
|
+
end = start + effective_limit
|
|
332
|
+
|
|
333
|
+
# Clamp
|
|
334
|
+
start = max(0, start)
|
|
335
|
+
end = min(len(lines), end)
|
|
336
|
+
|
|
337
|
+
selected = lines[start:end]
|
|
338
|
+
was_truncated = (end - start) < (total_lines - start) if not user_specified_range else False
|
|
339
|
+
# Also truncate if the user explicitly asked for a range but total
|
|
340
|
+
# output still exceeds the hard char cap (safety net).
|
|
341
|
+
char_truncated = False
|
|
342
|
+
|
|
343
|
+
# Format with line numbers
|
|
344
|
+
output_lines: list[str] = []
|
|
345
|
+
chars = 0
|
|
346
|
+
for i, line in enumerate(selected, start=start + 1):
|
|
347
|
+
formatted = f"{i:6d}\t{line.rstrip()}"
|
|
348
|
+
chars += len(formatted) + 1 # +1 for newline
|
|
349
|
+
if chars > self.MAX_READ_CHARS:
|
|
350
|
+
output_lines.append(f"... (output truncated at {self.MAX_READ_CHARS} chars, use offset/limit to read more)")
|
|
351
|
+
char_truncated = True
|
|
352
|
+
break
|
|
353
|
+
output_lines.append(formatted)
|
|
354
|
+
|
|
355
|
+
truncated = was_truncated or char_truncated
|
|
356
|
+
shown_lines = len(output_lines) - (1 if char_truncated else 0)
|
|
357
|
+
header = f"File: {path} (lines {start+1}-{start+shown_lines} of {total_lines}"
|
|
358
|
+
if truncated:
|
|
359
|
+
header += f", truncated — use offset/limit to page"
|
|
360
|
+
header += ")\n"
|
|
361
|
+
return ToolResult(success=True, output=header + "\n".join(output_lines))
|
|
362
|
+
|
|
363
|
+
async def _tool_write_file(
|
|
364
|
+
self, file_path: str, content: str
|
|
365
|
+
) -> ToolResult:
|
|
366
|
+
"""Create or overwrite a file. Captures old content for diff display."""
|
|
367
|
+
path = self._resolve_path(file_path)
|
|
368
|
+
self._ensure_in_workspace(path)
|
|
369
|
+
|
|
370
|
+
# Capture old content for diff (if file exists)
|
|
371
|
+
old_content = ""
|
|
372
|
+
if path.exists():
|
|
373
|
+
try:
|
|
374
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
375
|
+
old_content = f.read()
|
|
376
|
+
except Exception:
|
|
377
|
+
pass
|
|
378
|
+
|
|
379
|
+
try:
|
|
380
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
381
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
382
|
+
f.write(content)
|
|
383
|
+
size = path.stat().st_size
|
|
384
|
+
|
|
385
|
+
# Notify UI for diff display if overwriting
|
|
386
|
+
if old_content:
|
|
387
|
+
self._notify_edit(str(path), old_content)
|
|
388
|
+
|
|
389
|
+
return ToolResult(
|
|
390
|
+
success=True,
|
|
391
|
+
output=f"File written: {path} ({size} bytes, {content.count(chr(10))} lines)",
|
|
392
|
+
)
|
|
393
|
+
except Exception as e:
|
|
394
|
+
return ToolResult(
|
|
395
|
+
success=False, output="", error=f"Cannot write file: {e}"
|
|
396
|
+
)
|
|
397
|
+
|
|
398
|
+
async def _tool_edit_file(
|
|
399
|
+
self, file_path: str, old_string: str, new_string: str
|
|
400
|
+
) -> ToolResult:
|
|
401
|
+
"""Replace text in a file. Uses CST for Python (preserves formatting),
|
|
402
|
+
falls back to text replacement for other languages."""
|
|
403
|
+
if old_string == new_string:
|
|
404
|
+
return ToolResult(
|
|
405
|
+
success=False,
|
|
406
|
+
output="",
|
|
407
|
+
error="old_string and new_string are identical",
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
path = self._resolve_path(file_path)
|
|
411
|
+
if not path.exists():
|
|
412
|
+
return ToolResult(
|
|
413
|
+
success=False,
|
|
414
|
+
output="",
|
|
415
|
+
error=f"File not found: {path}",
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
try:
|
|
419
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
420
|
+
content = f.read()
|
|
421
|
+
except Exception as e:
|
|
422
|
+
return ToolResult(
|
|
423
|
+
success=False, output="", error=f"Cannot read file: {e}"
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
# Store old content for diff display (via UI callback)
|
|
427
|
+
old_content = content
|
|
428
|
+
|
|
429
|
+
# ── CST-based edit for Python files ──────────────────────────────
|
|
430
|
+
if path.suffix == ".py" or path.suffix == ".pyi":
|
|
431
|
+
result = self._cst_edit(content, old_string, new_string, path)
|
|
432
|
+
if result is not None:
|
|
433
|
+
new_content = result
|
|
434
|
+
try:
|
|
435
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
436
|
+
f.write(new_content)
|
|
437
|
+
self._notify_edit(str(path), old_content)
|
|
438
|
+
return ToolResult(
|
|
439
|
+
success=True,
|
|
440
|
+
output=f"File edited (AST): {path} (1 replacement)",
|
|
441
|
+
)
|
|
442
|
+
except Exception as e:
|
|
443
|
+
return ToolResult(
|
|
444
|
+
success=False, output="", error=f"Cannot write file: {e}"
|
|
445
|
+
)
|
|
446
|
+
# CST edit failed — fall through to text-based replacement
|
|
447
|
+
|
|
448
|
+
# ── Text-based fallback ──────────────────────────────────────────
|
|
449
|
+
count = content.count(old_string)
|
|
450
|
+
if count == 0:
|
|
451
|
+
return ToolResult(
|
|
452
|
+
success=False,
|
|
453
|
+
output="",
|
|
454
|
+
error="old_string not found in file. Check whitespace/indentation.",
|
|
455
|
+
)
|
|
456
|
+
if count > 1:
|
|
457
|
+
return ToolResult(
|
|
458
|
+
success=False,
|
|
459
|
+
output="",
|
|
460
|
+
error=f"old_string found {count} times in file. Must be unique. Use a larger string with more surrounding context.",
|
|
461
|
+
)
|
|
462
|
+
|
|
463
|
+
new_content = content.replace(old_string, new_string, 1)
|
|
464
|
+
try:
|
|
465
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
466
|
+
f.write(new_content)
|
|
467
|
+
|
|
468
|
+
self._notify_edit(str(path), old_content)
|
|
469
|
+
|
|
470
|
+
return ToolResult(
|
|
471
|
+
success=True,
|
|
472
|
+
output=f"File edited: {path} (1 replacement)",
|
|
473
|
+
)
|
|
474
|
+
except Exception as e:
|
|
475
|
+
return ToolResult(
|
|
476
|
+
success=False, output="", error=f"Cannot write file: {e}"
|
|
477
|
+
)
|
|
478
|
+
|
|
479
|
+
@staticmethod
|
|
480
|
+
def _cst_edit(content: str, old_str: str, new_str: str, path: Any) -> str | None:
|
|
481
|
+
"""Attempt a CST-based edit for Python files using libcst.
|
|
482
|
+
|
|
483
|
+
Parses the file as a Concrete Syntax Tree, finds the node matching
|
|
484
|
+
old_str, replaces it with new_str parsed as the same node type,
|
|
485
|
+
and returns the formatted code. Preserves ALL formatting.
|
|
486
|
+
|
|
487
|
+
Returns None if the edit cannot be performed via CST (fallback to text).
|
|
488
|
+
"""
|
|
489
|
+
try:
|
|
490
|
+
import libcst as cst
|
|
491
|
+
except ImportError:
|
|
492
|
+
return None # libcst not installed — use text fallback
|
|
493
|
+
|
|
494
|
+
try:
|
|
495
|
+
tree = cst.parse_module(content)
|
|
496
|
+
except Exception:
|
|
497
|
+
return None # Can't parse — fall back to text
|
|
498
|
+
|
|
499
|
+
# Strategy: parse old_str as a statement, find it in the tree, replace
|
|
500
|
+
try:
|
|
501
|
+
# Try parsing old_str as a full module body (the typical case for
|
|
502
|
+
# function/method/class-level edits)
|
|
503
|
+
old_module = cst.parse_module(old_str + "\n")
|
|
504
|
+
if len(old_module.body) == 1:
|
|
505
|
+
old_node = old_module.body[0]
|
|
506
|
+
# Parse new_str as the same node type
|
|
507
|
+
new_module = cst.parse_module(new_str + "\n")
|
|
508
|
+
if len(new_module.body) == 1:
|
|
509
|
+
new_node = new_module.body[0]
|
|
510
|
+
transformer = _NodeReplacer(old_node, new_node)
|
|
511
|
+
new_tree = tree.visit(transformer)
|
|
512
|
+
if transformer.found:
|
|
513
|
+
return new_tree.code
|
|
514
|
+
except Exception:
|
|
515
|
+
pass
|
|
516
|
+
|
|
517
|
+
# Strategy 2: try parsing old_str as a simple statement line
|
|
518
|
+
try:
|
|
519
|
+
old_module = cst.parse_module(old_str + "\n")
|
|
520
|
+
if len(old_module.body) == 1:
|
|
521
|
+
old_stmt = old_module.body[0]
|
|
522
|
+
# For simple statements, use a body-statement-level replacer
|
|
523
|
+
new_module = cst.parse_module(new_str + "\n")
|
|
524
|
+
if len(new_module.body) == 1:
|
|
525
|
+
new_stmt = new_module.body[0]
|
|
526
|
+
transformer = _StatementReplacer(old_stmt, new_stmt)
|
|
527
|
+
new_tree = tree.visit(transformer)
|
|
528
|
+
if transformer.found:
|
|
529
|
+
return new_tree.code
|
|
530
|
+
except Exception:
|
|
531
|
+
pass
|
|
532
|
+
|
|
533
|
+
return None # All CST strategies failed — fall back to text
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
# ── CST Transformers (libcst optional — guarded at module level) ──────────
|
|
537
|
+
|
|
538
|
+
try:
|
|
539
|
+
import libcst as _cst_lib
|
|
540
|
+
|
|
541
|
+
class _NodeReplacer(_cst_lib.CSTTransformer):
|
|
542
|
+
"""Replace a specific CST node with another, preserving all else."""
|
|
543
|
+
|
|
544
|
+
def __init__(self, old_node: _cst_lib.CSTNode, new_node: _cst_lib.CSTNode):
|
|
545
|
+
self.old_node = old_node
|
|
546
|
+
self.new_node = new_node
|
|
547
|
+
self.found = False
|
|
548
|
+
|
|
549
|
+
def on_visit(self, node: _cst_lib.CSTNode) -> bool:
|
|
550
|
+
if node.deep_equals(self.old_node) and not self.found:
|
|
551
|
+
self.found = True
|
|
552
|
+
return False
|
|
553
|
+
return True
|
|
554
|
+
|
|
555
|
+
def on_leave(self, original_node: _cst_lib.CSTNode, updated_node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
|
|
556
|
+
if original_node.deep_equals(self.old_node) and self.found:
|
|
557
|
+
return self.new_node
|
|
558
|
+
return updated_node
|
|
559
|
+
|
|
560
|
+
class _StatementReplacer(_cst_lib.CSTTransformer):
|
|
561
|
+
"""Replace a statement within a body, matching by deep equality."""
|
|
562
|
+
|
|
563
|
+
def __init__(self, old_stmt: _cst_lib.CSTNode, new_stmt: _cst_lib.CSTNode):
|
|
564
|
+
self.old_stmt = old_stmt
|
|
565
|
+
self.new_stmt = new_stmt
|
|
566
|
+
self.found = False
|
|
567
|
+
|
|
568
|
+
def leave_SimpleStatementLine(
|
|
569
|
+
self, original_node: _cst_lib.SimpleStatementLine, updated_node: _cst_lib.SimpleStatementLine
|
|
570
|
+
):
|
|
571
|
+
if not self.found and len(updated_node.body) == 1:
|
|
572
|
+
if updated_node.body[0].deep_equals(self.old_stmt):
|
|
573
|
+
self.found = True
|
|
574
|
+
return updated_node.with_changes(body=[self.new_stmt])
|
|
575
|
+
return updated_node
|
|
576
|
+
|
|
577
|
+
except ImportError:
|
|
578
|
+
# libcst not installed — AST editing will fall back to text replacement
|
|
579
|
+
_NodeReplacer = None # type: ignore
|
|
580
|
+
_StatementReplacer = None # type: ignore
|
|
581
|
+
|
|
582
|
+
# ── Rename symbol (AST-aware) ────────────────────────────────────────────
|
|
583
|
+
|
|
584
|
+
async def _tool_rename_symbol(
|
|
585
|
+
self, file_path: str, old_name: str, new_name: str,
|
|
586
|
+
symbol_type: str = "variable"
|
|
587
|
+
) -> ToolResult:
|
|
588
|
+
"""Safely rename a Python symbol using libcst AST matching.
|
|
589
|
+
Never touches strings, comments, or imports of the same name.
|
|
590
|
+
"""
|
|
591
|
+
if old_name == new_name:
|
|
592
|
+
return ToolResult(success=False, output="", error="Names are identical.")
|
|
593
|
+
if not old_name.isidentifier() or not new_name.isidentifier():
|
|
594
|
+
return ToolResult(success=False, output="", error="Names must be valid Python identifiers.")
|
|
595
|
+
|
|
596
|
+
path = self._resolve_path(file_path)
|
|
597
|
+
if not path.exists():
|
|
598
|
+
return ToolResult(success=False, output="", error=f"File not found: {path}")
|
|
599
|
+
if path.suffix not in (".py", ".pyi"):
|
|
600
|
+
return ToolResult(success=False, output="", error="rename_symbol only works with Python files.")
|
|
601
|
+
|
|
602
|
+
try:
|
|
603
|
+
with open(path, "r", encoding="utf-8") as f:
|
|
604
|
+
content = f.read()
|
|
605
|
+
except Exception as e:
|
|
606
|
+
return ToolResult(success=False, output="", error=f"Cannot read file: {e}")
|
|
607
|
+
|
|
608
|
+
try:
|
|
609
|
+
import libcst as cst
|
|
610
|
+
except ImportError:
|
|
611
|
+
return ToolResult(success=False, output="", error="libcst not installed. Run: pip install libcst")
|
|
612
|
+
|
|
613
|
+
try:
|
|
614
|
+
tree = cst.parse_module(content)
|
|
615
|
+
except Exception as e:
|
|
616
|
+
return ToolResult(success=False, output="", error=f"Cannot parse Python file: {e}")
|
|
617
|
+
|
|
618
|
+
# Choose the right renamer for the symbol type
|
|
619
|
+
if symbol_type == "function":
|
|
620
|
+
renamer = _FunctionRenamer(old_name, new_name)
|
|
621
|
+
elif symbol_type == "class":
|
|
622
|
+
renamer = _ClassRenamer(old_name, new_name)
|
|
623
|
+
else:
|
|
624
|
+
renamer = _VariableRenamer(old_name, new_name)
|
|
625
|
+
|
|
626
|
+
new_tree = tree.visit(renamer)
|
|
627
|
+
if not renamer.changes:
|
|
628
|
+
return ToolResult(success=False, output="", error=f"Symbol '{old_name}' not found in file.")
|
|
629
|
+
|
|
630
|
+
new_content = new_tree.code
|
|
631
|
+
old_content = content
|
|
632
|
+
|
|
633
|
+
try:
|
|
634
|
+
with open(path, "w", encoding="utf-8") as f:
|
|
635
|
+
f.write(new_content)
|
|
636
|
+
self._notify_edit(str(path), old_content)
|
|
637
|
+
return ToolResult(
|
|
638
|
+
success=True,
|
|
639
|
+
output=f"Renamed {renamer.changes} occurrence(s) of '{old_name}' → '{new_name}' in {path}",
|
|
640
|
+
)
|
|
641
|
+
except Exception as e:
|
|
642
|
+
return ToolResult(success=False, output="", error=f"Cannot write file: {e}")
|
|
643
|
+
|
|
644
|
+
# ── Shell tool ───────────────────────────────────────────────────────────
|
|
645
|
+
|
|
646
|
+
async def _tool_run_shell(
    self, command: str, timeout: int = 120
) -> ToolResult:
    """Execute a shell command in the workspace and capture its output.

    Args:
        command: Command line handed to ``asyncio.create_subprocess_shell``
            with the workspace as the working directory.
        timeout: Seconds to wait for the command to finish.

    Returns:
        ToolResult whose output is stdout, followed by a ``[stderr]`` section
        and an ``[exit code: N]`` marker when applicable. ``success`` is True
        only for exit code 0. Commands containing a configured blocked
        pattern are refused without being started. On timeout the child
        process is killed so it does not keep running in the background.
    """
    # Safety checks
    cmd_lower = command.lower().strip()

    # Check blocked patterns (hard block — unforgivable operations)
    for blocked in self.config.blocked_commands:
        if blocked.lower() in cmd_lower:
            return ToolResult(
                success=False,
                output="",
                error=f"Blocked command pattern detected: {blocked}",
            )

    # Check if first word is allowed (soft warning — safety_guard is the real gate)
    first_word = command.strip().split()[0] if command.strip() else ""
    if first_word and first_word not in self.config.allowed_commands:
        logger.debug("Command '%s' not in allowed_commands whitelist, proceeding", first_word)

    proc = None
    try:
        proc = await asyncio.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.workspace),
        )
        stdout_bytes, stderr_bytes = await asyncio.wait_for(
            proc.communicate(), timeout=timeout
        )
        output = stdout_bytes.decode("utf-8", errors="replace")
        if stderr_bytes:
            output += f"\n[stderr]\n{stderr_bytes.decode('utf-8', errors='replace')}"
        returncode = proc.returncode or 0
        if returncode != 0:
            output += f"\n[exit code: {returncode}]"
        return ToolResult(
            success=returncode == 0,
            output=output.strip() or "(no output)",
        )
    except asyncio.TimeoutError:
        # wait_for() only cancels communicate(); the child process itself
        # keeps running unless it is killed explicitly.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # exited between the timeout firing and kill()
            try:
                # Bounded so a stubborn child cannot hang the tool call.
                await asyncio.wait_for(proc.wait(), timeout=5)
            except Exception:
                logger.debug("Could not reap timed-out process %s", proc.pid)
        return ToolResult(
            success=False,
            output="",
            error=f"Command timed out after {timeout}s",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Command failed: {e}"
        )
|
|
697
|
+
|
|
698
|
+
# ── Search tools ─────────────────────────────────────────────────────────
|
|
699
|
+
|
|
700
|
+
async def _tool_grep(
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
    case_sensitive: bool = False,
) -> ToolResult:
    """Search file contents with regex.

    Args:
        pattern: Python ``re`` pattern, matched against each line.
        path: Directory to search recursively, or a single file. It is
            resolved via ``_resolve_path`` and defaults to ".".
        glob: Optional ``fnmatch`` pattern, applied to file *names* only.
        case_sensitive: Match case exactly. The default ignores case.

    Returns:
        ToolResult listing ``<relative path>:`` headers, each followed by
        ``<line>: <text>`` hits (text cut to 200 chars). At most
        MAX_GREP_PER_FILE hits are listed per file. Listing stops after the
        file that brings the output to MAX_GREP_RESULTS lines.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        regex = re.compile(pattern, flags)
    except re.error as e:
        return ToolResult(
            success=False, output="", error=f"Invalid regex: {e}"
        )

    def _candidate_files():
        """Yield (full_path, file_name) pairs to search."""
        if search_dir.is_file():
            # os.walk() yields nothing for a file, so handle it directly.
            yield str(search_dir), search_dir.name
            return
        for root, dirs, files in os.walk(search_dir):
            # Prune hidden and configured directories in place so os.walk
            # does not descend into them.
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".")
                and d not in self.SKIP_DIRS
            ]
            for fname in files:
                yield os.path.join(root, fname), fname

    # Run filesystem walk + file reads in thread pool to avoid
    # blocking the asyncio event loop on large codebases.
    def _do_grep():
        results: list[str] = []
        total_matches = 0
        for full_path, fname in _candidate_files():
            if glob and not fnmatch.fnmatch(fname, glob):
                continue

            try:
                rel_path = os.path.relpath(full_path, self.workspace)
            except ValueError:
                # e.g. a different drive on Windows
                rel_path = full_path

            try:
                with open(full_path, "r", encoding="utf-8", errors="replace") as fh:
                    file_lines = fh.readlines()
            except Exception:
                continue  # unreadable file: skip it rather than fail the search

            matches_in_file = []
            for line_no, line_text in enumerate(file_lines, 1):
                if regex.search(line_text):
                    matches_in_file.append(
                        f" {line_no}: {line_text.rstrip()[:200]}"
                    )
                    total_matches += 1
                    if len(matches_in_file) >= self.MAX_GREP_PER_FILE:
                        matches_in_file.append(" ... (truncated)")
                        break

            if matches_in_file:
                results.append(f"{rel_path}:")
                results.extend(matches_in_file)
                if len(results) >= self.MAX_GREP_RESULTS:
                    results.append("... (result limit reached)")
                    break
        return results, total_matches

    results, total_matches = await self._run_in_thread(_do_grep)

    if not results:
        return ToolResult(
            success=True,
            output=f"No matches found for pattern: {pattern}",
        )
    return ToolResult(
        success=True,
        output=f"Found {total_matches} matches:\n\n" + "\n".join(results),
    )
|
|
785
|
+
|
|
786
|
+
async def _tool_glob(
    self,
    pattern: str,
    path: str | None = None,
) -> ToolResult:
    """Find files by glob pattern.

    Patterns without ``**`` are made recursive by prefixing ``**/``. The
    filesystem walk and per-match size lookups run in a worker thread, as
    in ``_tool_grep``, so large trees do not block the event loop.

    Args:
        pattern: ``glob`` pattern, relative to the search directory.
        path: Directory to search, resolved via ``_resolve_path``. It
            defaults to ".".

    Returns:
        ToolResult listing up to 200 sorted, workspace-relative matches with
        their sizes in bytes (0 for non-files), plus a count of the rest.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    import glob as glob_mod

    # Auto-add **/ prefix for recursive matching if not already present
    if "**" not in pattern:
        pattern = f"**/{pattern}"
    search_pattern = str(search_dir / pattern)
    max_shown = 200

    def _do_glob():
        found = sorted(glob_mod.glob(search_pattern, recursive=True))
        lines: list[str] = []
        for m in found[:max_shown]:
            try:
                rel = os.path.relpath(m, self.workspace)
            except ValueError:
                rel = m
            try:
                size = os.path.getsize(m) if os.path.isfile(m) else 0
            except OSError:
                # Removed or unreadable between glob() and the size lookup.
                size = 0
            lines.append(f" {rel} ({size:,} bytes)")
        return found, lines

    matches, output_lines = await self._run_in_thread(_do_glob)

    if not matches:
        return ToolResult(
            success=True, output=f"No files matching: {pattern}"
        )

    if len(matches) > max_shown:
        output_lines.append(
            f" ... and {len(matches) - max_shown} more files"
        )

    return ToolResult(
        success=True,
        output=f"Found {len(matches)} files matching '{pattern}':\n"
        + "\n".join(output_lines),
    )
|
|
834
|
+
|
|
835
|
+
async def _tool_list_dir(
    self,
    path: str | None = None,
    recursive: bool = False,
) -> ToolResult:
    """List directory contents.

    Args:
        path: Directory to list, resolved via ``_resolve_path``. It defaults
            to ".".
        recursive: Walk the whole tree and indent entries by depth. Hidden
            directories and SKIP_DIRS are skipped. When False, list one
            level, directories first.

    Returns:
        ToolResult with a ``Directory: <path>`` header followed by at most
        500 entries. Files show their size in bytes; sizes that cannot be
        read (broken symlink, permissions) are reported instead of aborting
        the listing.
    """
    target = self._resolve_path(path or ".")
    if not target.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {target}",
        )
    if not target.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Not a directory: {target}",
        )

    output_lines = [f"Directory: {target}"]
    entries: list[str] = []

    if recursive:
        for root, dirs, files in os.walk(target):
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".") and d not in self.SKIP_DIRS
            ]
            # Depth below the walk root. relpath is used because plain
            # string replacement miscounts when the target's path text
            # reappears deeper in the tree (or when the target is "/").
            rel_root = os.path.relpath(root, target)
            level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
            indent = " " * level
            if level > 0:
                entries.append(f"{indent}{os.path.basename(root)}/")
            for f in sorted(files):
                fp = os.path.join(root, f)
                try:
                    size = os.path.getsize(fp)
                except OSError:
                    entries.append(f"{indent} {f} (size unavailable)")
                    continue
                entries.append(f"{indent} {f} ({size:,}B)")
    else:
        try:
            items = sorted(target.iterdir(), key=lambda x: (not x.is_dir(), x.name))
        except OSError as e:
            return ToolResult(
                success=False,
                output="",
                error=f"Cannot list directory: {e}",
            )
        for item in items:
            suffix = "/" if item.is_dir() else ""
            size = ""
            if item.is_file():
                try:
                    size = f" ({item.stat().st_size:,}B)"
                except OSError:
                    size = ""
            entries.append(f" {item.name}{suffix}{size}")

    output_lines.extend(entries[:500])
    if len(entries) > 500:
        output_lines.append(f" ... and {len(entries) - 500} more entries")

    return ToolResult(success=True, output="\n".join(output_lines))
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
# ── Web tools ──────────────────────────────────────────────────────────
|
|
889
|
+
|
|
890
|
+
# Internal HTTP client (lazy-init, shared across web tools)
|
|
891
|
+
# Backing slot for the ``http`` property: None until the property first
# creates the client, then reused by every web tool on this instance.
_http: httpx.Client | None = None
|
|
892
|
+
|
|
893
|
+
async def _run_in_thread(self, func, *args, **kwargs):
|
|
894
|
+
"""Run a sync function in a thread pool to avoid blocking the event loop."""
|
|
895
|
+
from functools import partial
|
|
896
|
+
loop = asyncio.get_running_loop()
|
|
897
|
+
return await loop.run_in_executor(None, partial(func, *args, **kwargs))
|
|
898
|
+
|
|
899
|
+
@property
def http(self) -> httpx.Client:
    """Shared synchronous HTTP client for the web tools, built on first use.

    The client uses a 30 s timeout, follows redirects, and sends an
    ATA-Coder User-Agent plus HTML-oriented Accept headers. Individual
    requests may override these headers.
    """
    client = self._http
    if client is None:
        default_headers = {
            "User-Agent": (
                "ATA-Coder/2.0 (AI Coding Assistant; "
                "+https://github.com/ata-coder/ata-coder)"
            ),
            "Accept": "text/html,application/xhtml+xml,*/*",
            "Accept-Language": "en-US,zh-CN;q=0.9",
        }
        client = httpx.Client(
            timeout=httpx.Timeout(30.0),
            follow_redirects=True,
            headers=default_headers,
        )
        self._http = client
    return client
|
|
915
|
+
|
|
916
|
+
async def _tool_web_search(
    self,
    query: str,
    max_results: int = 10,
) -> ToolResult:
    """Search the web with tiered fallback: Bing → Baidu → Google.

    All three use web scraping (no API key required).
    Set ATA_CODER_SEARCH_BACKEND to force a single backend:
    "bing" / "baidu" / "google" / "duckduckgo"
    The value is trimmed and lower-cased, then looked up as this object's
    ``_search_<name>`` method. A missing or non-callable attribute is
    reported as an unsupported backend.

    Args:
        query: Search terms.
        max_results: Number of hits to show, clamped to 1..20.

    Returns:
        The formatted results of the first backend that returns any hits.
        If none does, a failure whose error lists each backend's problem.
    """
    max_results = min(max(max_results, 1), 20)
    # Normalize so "Bing" or " bing " select the same method as "bing".
    forced = os.environ.get("ATA_CODER_SEARCH_BACKEND", "").strip().lower()

    errors: list[str] = []

    # Build fallback chain: respect forced backend, otherwise tiered
    if forced:
        candidate = getattr(self, f"_search_{forced}", None)
        chain = [(forced, candidate if callable(candidate) else None)]
    else:
        chain = [
            ("Bing", self._search_bing),
            ("Baidu", self._search_baidu),
            ("Google", self._search_google),
        ]

    for name, searcher in chain:
        if searcher is None:
            errors.append(f"{name}: unsupported backend")
            continue
        try:
            # Run sync search in thread pool to avoid blocking event loop
            results = await self._run_in_thread(searcher, query)
            if results:
                return self._format_search_results(query, results, max_results, name)
            errors.append(f"{name} returned no results")
        except httpx.TimeoutException:
            errors.append(f"{name} timed out")
        except httpx.HTTPStatusError as e:
            errors.append(f"{name} HTTP {e.response.status_code}")
        except Exception as e:
            errors.append(f"{name}: {e}")

    return ToolResult(
        success=False, output="",
        error=f"Search failed: {'; '.join(errors)}"
    )
|
|
964
|
+
|
|
965
|
+
def _search_bing(self, query: str) -> list[dict[str, str]]:
|
|
966
|
+
"""Search Bing (web scraping, no API key)."""
|
|
967
|
+
import urllib.parse
|
|
968
|
+
url = f"https://www.bing.com/search?q={urllib.parse.quote(query)}&setlang=en"
|
|
969
|
+
headers = {
|
|
970
|
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
|
971
|
+
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
|
972
|
+
"Accept-Language": "en-US,en;q=0.9",
|
|
973
|
+
}
|
|
974
|
+
resp = self.http.get(url, headers=headers)
|
|
975
|
+
resp.raise_for_status()
|
|
976
|
+
|
|
977
|
+
results: list[dict[str, str]] = []
|
|
978
|
+
# Bing results are in <li class="b_algo"> blocks
|
|
979
|
+
blocks = re.findall(
|
|
980
|
+
r'<li[^>]*class="[^"]*b_algo[^"]*"[^>]*>(.*?)</li>',
|
|
981
|
+
resp.text, re.DOTALL | re.IGNORECASE,
|
|
982
|
+
)
|
|
983
|
+
for block in blocks:
|
|
984
|
+
# Title + link in <h2><a href="...">title</a></h2>
|
|
985
|
+
m = re.search(r'<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
|
|
986
|
+
if not m:
|
|
987
|
+
continue
|
|
988
|
+
href = html.unescape(m.group(1).strip())
|
|
989
|
+
title = re.sub(r'<[^>]+>', '', m.group(2)).strip()
|
|
990
|
+
if not title or not href.startswith("http"):
|
|
991
|
+
continue
|
|
992
|
+
# Snippet in <p> or <div class="b_caption">
|
|
993
|
+
snippet = ""
|
|
994
|
+
sm = re.search(
|
|
995
|
+
r'<(?:p|div)[^>]*class="[^"]*(?:b_caption|b_lineclamp)[^"]*"[^>]*>(.*?)</(?:p|div)>',
|
|
996
|
+
block, re.DOTALL | re.IGNORECASE,
|
|
997
|
+
)
|
|
998
|
+
if sm:
|
|
999
|
+
snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
|
|
1000
|
+
snippet = html.unescape(snippet)
|
|
1001
|
+
results.append({"title": title, "url": href, "snippet": snippet})
|
|
1002
|
+
|
|
1003
|
+
return results
|
|
1004
|
+
|
|
1005
|
+
def _search_baidu(self, query: str) -> list[dict[str, str]]:
|
|
1006
|
+
"""Search Baidu (web scraping, no API key)."""
|
|
1007
|
+
import urllib.parse
|
|
1008
|
+
url = f"https://www.baidu.com/s?wd={urllib.parse.quote(query)}&ie=utf-8"
|
|
1009
|
+
headers = {
|
|
1010
|
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
|
1011
|
+
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
|
1012
|
+
"Accept-Language": "zh-CN,zh;q=0.9",
|
|
1013
|
+
}
|
|
1014
|
+
resp = self.http.get(url, headers=headers)
|
|
1015
|
+
resp.raise_for_status()
|
|
1016
|
+
|
|
1017
|
+
results: list[dict[str, str]] = []
|
|
1018
|
+
# Baidu results: <div class="result c-container"> or <div class="c-container">
|
|
1019
|
+
blocks = re.findall(
|
|
1020
|
+
r'<div[^>]*class="[^"]*(?:result|c-container)[^"]*"[^>]*>(.*?)</div>\s*(?=<div[^>]*class="[^"]*(?:result|c-container)|$)',
|
|
1021
|
+
resp.text, re.DOTALL | re.IGNORECASE,
|
|
1022
|
+
)
|
|
1023
|
+
if not blocks:
|
|
1024
|
+
# Fallback: match h3 titles with links
|
|
1025
|
+
blocks = re.findall(
|
|
1026
|
+
r'<div[^>]*class="[^"]*c-container[^"]*"[^>]*>(.*?)</div>',
|
|
1027
|
+
resp.text, re.DOTALL | re.IGNORECASE,
|
|
1028
|
+
)
|
|
1029
|
+
|
|
1030
|
+
for block in blocks:
|
|
1031
|
+
m = re.search(r'<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
|
|
1032
|
+
if not m:
|
|
1033
|
+
continue
|
|
1034
|
+
href = html.unescape(m.group(1).strip())
|
|
1035
|
+
title = re.sub(r'<[^>]+>', '', m.group(2)).strip()
|
|
1036
|
+
if not title or not href.startswith("http"):
|
|
1037
|
+
continue
|
|
1038
|
+
snippet = ""
|
|
1039
|
+
sm = re.search(
|
|
1040
|
+
r'<(?:span|div|p)[^>]*class="[^"]*(?:content-right_[^"]*|c-abstract|content)[^"]*"[^>]*>(.*?)</(?:span|div|p)>',
|
|
1041
|
+
block, re.DOTALL | re.IGNORECASE,
|
|
1042
|
+
)
|
|
1043
|
+
if sm:
|
|
1044
|
+
snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
|
|
1045
|
+
snippet = html.unescape(snippet)
|
|
1046
|
+
results.append({"title": title, "url": href, "snippet": snippet})
|
|
1047
|
+
|
|
1048
|
+
return results
|
|
1049
|
+
|
|
1050
|
+
def _search_google(self, query: str) -> list[dict[str, str]]:
|
|
1051
|
+
"""Search Google (web scraping, no API key)."""
|
|
1052
|
+
import urllib.parse
|
|
1053
|
+
url = f"https://www.google.com/search?q={urllib.parse.quote(query)}&hl=en"
|
|
1054
|
+
headers = {
|
|
1055
|
+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
|
|
1056
|
+
"(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
|
1057
|
+
"Accept-Language": "en-US,en;q=0.9",
|
|
1058
|
+
}
|
|
1059
|
+
resp = self.http.get(url, headers=headers)
|
|
1060
|
+
resp.raise_for_status()
|
|
1061
|
+
|
|
1062
|
+
results: list[dict[str, str]] = []
|
|
1063
|
+
# Google results are in <div class="g"> or <div data-sokoban-container>
|
|
1064
|
+
blocks = re.findall(
|
|
1065
|
+
r'<(?:div|li)[^>]*\b(?:class="g\b|data-sokoban-container)[^>]*>(.*?)</(?:div|li)>',
|
|
1066
|
+
resp.text, re.DOTALL | re.IGNORECASE,
|
|
1067
|
+
)
|
|
1068
|
+
for block in blocks:
|
|
1069
|
+
# Title + link: <h3>...<a href="...">title</a></h3>
|
|
1070
|
+
m = re.search(r'<a[^>]*href="(/url\?q=|)([^"&]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
|
|
1071
|
+
if not m:
|
|
1072
|
+
continue
|
|
1073
|
+
href = html.unescape(m.group(2).strip())
|
|
1074
|
+
if not href.startswith("http"):
|
|
1075
|
+
href = "https://www.google.com" + m.group(1) + m.group(2)
|
|
1076
|
+
title = re.sub(r'<[^>]+>', '', m.group(3)).strip()
|
|
1077
|
+
if not title:
|
|
1078
|
+
continue
|
|
1079
|
+
# Snippet: <span class="aCOpRe"> or various other classes
|
|
1080
|
+
snippet = ""
|
|
1081
|
+
sm = re.search(
|
|
1082
|
+
r'<(?:span|div)[^>]*\b(?:class="[^"]*(?:\baCOpRe\b|st\b)[^"]*")[^>]*>(.*?)</(?:span|div)>',
|
|
1083
|
+
block, re.DOTALL | re.IGNORECASE,
|
|
1084
|
+
)
|
|
1085
|
+
if sm:
|
|
1086
|
+
snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
|
|
1087
|
+
snippet = html.unescape(snippet)
|
|
1088
|
+
results.append({"title": title, "url": href, "snippet": snippet})
|
|
1089
|
+
|
|
1090
|
+
return results
|
|
1091
|
+
|
|
1092
|
+
@staticmethod
def _format_search_results(
    query: str, results: list[dict[str, str]], max_results: int, source: str
) -> ToolResult:
    """Render search hits as a numbered list naming the backend used.

    Only the first ``max_results`` hits are shown. Each hit renders as a
    bold title line, a URL line, an optional snippet line, and a blank
    separator line. Titles and snippets are passed through
    ``html.unescape``.
    """
    rendered = [f"Search results for: {query} (via {source})\n"]
    for rank, hit in enumerate(results[:max_results], start=1):
        rendered += [
            f"{rank}. **{html.unescape(hit['title'])}**",
            f" {hit['url']}",
        ]
        snippet = hit.get("snippet")
        if snippet:
            rendered.append(f" {html.unescape(snippet)}")
        rendered.append("")
    return ToolResult(success=True, output="\n".join(rendered))
|
|
1104
|
+
|
|
1105
|
+
@staticmethod
|
|
1106
|
+
def _parse_ddg_lite(html_text: str) -> list[dict[str, str]]:
|
|
1107
|
+
"""Extract search results from DuckDuckGo Lite HTML."""
|
|
1108
|
+
results: list[dict[str, str]] = []
|
|
1109
|
+
|
|
1110
|
+
# DDG Lite: results are in <a> tags with class="result-link"
|
|
1111
|
+
# and snippets in <td class="result-snippet">
|
|
1112
|
+
link_pattern = re.compile(
|
|
1113
|
+
r'<a[^>]*href="([^"]*)"[^>]*class="[^"]*result-link[^"]*"[^>]*>(.*?)</a>',
|
|
1114
|
+
re.DOTALL | re.IGNORECASE,
|
|
1115
|
+
)
|
|
1116
|
+
snippet_pattern = re.compile(
|
|
1117
|
+
r'<td[^>]*class="[^"]*result-snippet[^"]*"[^>]*>(.*?)</td>',
|
|
1118
|
+
re.DOTALL | re.IGNORECASE,
|
|
1119
|
+
)
|
|
1120
|
+
|
|
1121
|
+
links = link_pattern.findall(html_text)
|
|
1122
|
+
snippets = snippet_pattern.findall(html_text)
|
|
1123
|
+
|
|
1124
|
+
for i, (href, title) in enumerate(links):
|
|
1125
|
+
href = html.unescape(href.strip())
|
|
1126
|
+
title = re.sub(r'<[^>]+>', '', title).strip()
|
|
1127
|
+
if not title:
|
|
1128
|
+
continue
|
|
1129
|
+
|
|
1130
|
+
# Pick corresponding snippet
|
|
1131
|
+
snippet = ""
|
|
1132
|
+
if i < len(snippets):
|
|
1133
|
+
snippet = re.sub(r'<[^>]+>', '', snippets[i])
|
|
1134
|
+
snippet = html.unescape(snippet.strip())
|
|
1135
|
+
|
|
1136
|
+
results.append({
|
|
1137
|
+
"title": title,
|
|
1138
|
+
"url": href,
|
|
1139
|
+
"snippet": snippet[:300],
|
|
1140
|
+
})
|
|
1141
|
+
|
|
1142
|
+
return results
|
|
1143
|
+
|
|
1144
|
+
async def _tool_web_fetch(self, url: str) -> ToolResult:
    """Fetch a URL and return its textual content.

    HTML (``text/html`` or ``application/xhtml+xml``) is reduced to readable
    text with ``_extract_text``. Other textual payloads (any ``text/*`` type,
    JSON, XML and ``+json``/``+xml`` types) are returned verbatim, so
    whitespace such as the indentation of a raw source file survives. The
    result is truncated to 15,000 characters.

    Args:
        url: Absolute ``http://`` or ``https://`` URL.

    Returns:
        ToolResult with ``Content from: <url>`` followed by the text. On an
        invalid URL, timeout, HTTP error or unsupported content type, a
        failure result is returned.
    """
    if not url.startswith(("http://", "https://")):
        return ToolResult(
            success=False, output="",
            error="Invalid URL: must start with http:// or https://"
        )

    def _do_fetch():
        return self.http.get(url)

    try:
        resp = await self._run_in_thread(_do_fetch)
        resp.raise_for_status()
    except httpx.TimeoutException:
        return ToolResult(
            success=False, output="",
            error=f"Request timed out: {url}"
        )
    except httpx.HTTPStatusError as e:
        return ToolResult(
            success=False, output="",
            error=f"HTTP {e.response.status_code} for {url}"
        )
    except Exception as e:
        return ToolResult(
            success=False, output="",
            error=f"Fetch failed: {e}"
        )

    content_type = resp.headers.get("content-type", "")
    # Drop parameters such as "; charset=utf-8"; media types are case-insensitive.
    mime = content_type.split(";", 1)[0].strip().lower()
    is_html = mime in ("text/html", "application/xhtml+xml")
    is_text = (
        mime.startswith("text/")
        or mime in ("application/json", "application/xml", "application/javascript")
        or mime.endswith(("+json", "+xml"))
    )
    if not (is_html or is_text):
        return ToolResult(
            success=False, output="",
            error=(
                f"Cannot process content type: {content_type}. "
                "Only text, HTML, JSON and XML content is supported."
            ),
        )

    # Only HTML goes through the tag stripper: it also collapses runs of
    # spaces/tabs, which would destroy the layout of plain-text payloads.
    text = self._extract_text(resp.text, url) if is_html else resp.text

    # Truncate
    MAX_CHARS = 15_000
    if len(text) > MAX_CHARS:
        text = text[:MAX_CHARS] + (
            f"\n\n... [truncated {len(text) - MAX_CHARS:,} "
            f"chars from {url}]"
        )

    return ToolResult(
        success=True,
        output=f"Content from: {url}\n\n{text}",
    )
|
|
1195
|
+
|
|
1196
|
+
# ── Sub-agent tools ──────────────────────────────────────────────────
|
|
1197
|
+
|
|
1198
|
+
async def _tool_spawn_subagent(self, task: str, skill: str = "",
                               model: str = "") -> ToolResult:
    """Spawn a sub-agent to work on a task in parallel.

    Args:
        task: Task description handed to ``SubAgentManager.spawn``.
        skill: Optional skill prompt, forwarded as ``skill_prompt``.
        model: Optional model override. An empty string is passed as None.

    Returns:
        ToolResult with the new sub-agent id and usage hints, or a failure
        when no manager is available or ``spawn`` raises RuntimeError.
    """
    if not self._sub_agent_mgr:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available. "
                  "Ensure agent_controller is used.",
        )
    try:
        agent_id = self._sub_agent_mgr.spawn(
            task=task,
            skill_prompt=skill,
            model=model or None,
        )
    except RuntimeError as e:
        return ToolResult(
            success=False, output="",
            error=f"Cannot spawn sub-agent: {e}",
        )

    # Clawd: SubagentStart — signalled only after a successful spawn, so a
    # failed spawn never leaves a start without the matching subagent_stop()
    # that _tool_collect_subagent sends.
    get_clawd().subagent_start()

    return ToolResult(
        success=True,
        output=(
            f"Sub-agent spawned: {agent_id}\n"
            "Status: running\n"
            f"Active sub-agents: {self._sub_agent_mgr.active_count}\n\n"
            f"Use collect_subagent('{agent_id}') to retrieve results, "
            "or list_subagents() to check all statuses."
        ),
    )
|
|
1231
|
+
|
|
1232
|
+
async def _tool_collect_subagent(self, agent_id: str,
                                 timeout: float = 300.0) -> ToolResult:
    """Collect results from a spawned sub-agent.

    ``SubAgentManager.collect`` is a synchronous call: its result is used
    directly, not awaited. It presumably waits up to ``timeout`` seconds,
    so it runs on a worker thread via ``_run_in_thread`` instead of
    stalling the event loop.

    Args:
        agent_id: Id returned by spawn_subagent.
        timeout: Forwarded to ``collect``.

    Returns:
        ToolResult with the tool-call count and result text on success,
        otherwise a failure carrying the sub-agent's error.
    """
    if not self._sub_agent_mgr:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available.",
        )
    result = await self._run_in_thread(
        self._sub_agent_mgr.collect, agent_id, timeout=timeout
    )

    # Clawd: SubagentStop
    get_clawd().subagent_stop()

    if not result.success:
        return ToolResult(
            success=False,
            output=f"Sub-agent {agent_id} failed: {result.error}",
            error=result.error,
        )
    lines = [
        f"Sub-agent {agent_id} completed successfully.",
        f"Tool calls: {result.tool_call_count}",
        "",
        "Result:",
        result.result or "(empty)",
    ]
    return ToolResult(success=True, output="\n".join(lines))
|
|
1260
|
+
|
|
1261
|
+
async def _tool_list_subagents(self) -> ToolResult:
    """Report every sub-agent known to the manager, one status line each.

    Each line shows a status icon (❓ for unknown statuses), the agent id,
    its status and its tool-call count.
    """
    manager = self._sub_agent_mgr
    if not manager:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available.",
        )

    agents = manager.list_all()
    if not agents:
        return ToolResult(success=True, output="No sub-agents.")

    icon_for = {"running": "🔄", "done": "✅",
                "failed": "❌", "cancelled": "⏹️"}
    report = [f"Sub-agents ({len(agents)} total):", ""]
    report.extend(
        f" {icon_for.get(agent.status, '❓')} {agent.id} — {agent.status} "
        f"(tool_calls={agent.tool_call_count})"
        for agent in agents
    )
    return ToolResult(success=True, output="\n".join(report))
|
|
1281
|
+
|
|
1282
|
+
async def _tool_mcp_search(self, query: str, type: str = "all") -> ToolResult:
    """Search MCP tools and resources across all connected servers.

    Args:
        query: Search text forwarded to ``search_tools`` and/or
            ``search_resources``. At most 20 hits of each kind are requested.
        type: Which catalogue to search: "tools", "resources" or "all",
            matched case-insensitively. The name shadows the builtin but is
            kept for interface compatibility.

    Returns:
        ToolResult listing matching tools/resources grouped by kind. A
        failure is returned when MCP is not configured or ``type`` is not
        one of the accepted values.
    """
    if not self._mcp:
        return ToolResult(
            success=False, output="",
            error="MCP not configured. Add MCP servers via --mcp-config.",
        )

    # An unrecognized value used to skip both searches and misleadingly
    # report "nothing found"; reject it explicitly instead.
    kind = (type or "all").strip().lower()
    if kind not in ("tools", "resources", "all"):
        return ToolResult(
            success=False, output="",
            error=f"Invalid type '{type}': expected 'tools', 'resources' or 'all'.",
        )

    servers = self._mcp.connected_servers
    if not servers:
        return ToolResult(success=True, output="No MCP servers connected.")

    lines = [f"MCP search results for '{query}' across {len(servers)} server(s):", ""]
    found = 0

    if kind in ("tools", "all"):
        tools = self._mcp.search_tools(query, limit=20)
        if tools:
            lines.append(f" Tools ({len(tools)}):")
            for t in tools:
                name = t.get("name", "?")
                desc = (t.get("description") or "")[:100]
                server = t.get("_mcp_server", "?")
                lines.append(f" ● {name} @{server}")
                if desc:
                    lines.append(f" {desc}")
            found += len(tools)
        else:
            lines.append(" Tools: none found")

    if kind in ("resources", "all"):
        resources = self._mcp.search_resources(query, limit=20)
        if resources:
            lines.append(f"\n Resources ({len(resources)}):")
            for r in resources:
                uri = r.get("uri", "?")
                name = r.get("name", "")
                desc = (r.get("description") or "")[:80]
                server = r.get("_mcp_server", "?")
                label = name or uri
                lines.append(f" ● {label} @{server}")
                if desc:
                    lines.append(f" {desc}")
            found += len(resources)
        else:
            lines.append("\n Resources: none found")

    if found == 0:
        return ToolResult(
            success=True,
            output=f"No MCP tools or resources found matching '{query}'.\n"
                   f"Connected servers: {', '.join(servers)}.",
        )

    return ToolResult(success=True, output="\n".join(lines))
|
|
1337
|
+
|
|
1338
|
+
async def _tool_analyze_image(self, image_path: str, prompt: str = "Describe this image in detail.") -> ToolResult:
    """Analyze an image using a multimodal vision model.

    Uses the configured vision model, falling back to the main LLM config.
    Configure via ~/.ata_coder/settings.json:
        {"vision": {"model": "...", "api_base": "...", "api_key": "..."}}
    Or env vars: VISION_MODEL, VISION_API_BASE, VISION_API_KEY.

    The image is sent base64-encoded as a data URL to the OpenAI-style
    ``<api_base>/chat/completions`` endpoint. The blocking HTTP call runs
    on a worker thread so the event loop stays responsive.

    Args:
        image_path: Image file. ``~`` is expanded, and relative paths are
            resolved against the workspace (the cwd run_shell uses).
        prompt: Instruction sent together with the image.

    Returns:
        ToolResult whose output starts with ``[Vision: <model> | <tokens> tokens]``.
        A failure is returned for a missing file, unsupported format,
        missing API key/base, or an HTTP/network error.
    """
    import base64
    from pathlib import Path

    img_path = Path(image_path).expanduser()
    if not img_path.is_absolute():
        img_path = Path(self.workspace) / img_path
    if not img_path.exists():
        return ToolResult(
            success=False, output="",
            error=f"Image not found: {image_path}",
        )

    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
    }
    ext = img_path.suffix.lower()
    mime = mime_types.get(ext)
    if mime is None:
        return ToolResult(
            success=False, output="",
            error=f"Unsupported image format: {ext}. Supported: {', '.join(sorted(mime_types))}",
        )

    try:
        with open(img_path, "rb") as f:
            img_b64 = base64.standard_b64encode(f.read()).decode("ascii")
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Failed to read image: {e}")

    # ── Resolve vision config ──
    # Priority: env var > settings.json > main api config
    from .settings import get_settings
    settings = get_settings()

    # API key: VISION_API_KEY env > settings.json vision.api_key > main api key
    api_key = (
        os.environ.get("VISION_API_KEY", "")
        or settings.vision_api_key
        or os.environ.get("ATA_CODER_API_KEY", "")
        or os.environ.get("OPENAI_API_KEY", "")
        or settings.api_key
    )
    if not api_key:
        return ToolResult(
            success=False, output="",
            error="No API key configured. Set ATA_CODER_API_KEY or add vision.api_key in ~/.ata_coder/settings.json.",
        )

    # API base: VISION_API_BASE env > settings.json vision.api_base > main base_url
    api_base = (
        os.environ.get("VISION_API_BASE", "")
        or settings.vision_api_base
        or os.environ.get("ATA_CODER_BASE_URL", "")
        or os.environ.get("OPENAI_BASE_URL", "")
        or settings.api_base_url
    )
    if not api_base:
        # Without this check urlopen fails on "/chat/completions" with an
        # opaque "unknown url type" error.
        return ToolResult(
            success=False, output="",
            error="No API base URL configured. Set VISION_API_BASE or ATA_CODER_BASE_URL, or add vision.api_base in ~/.ata_coder/settings.json.",
        )

    # Model: VISION_MODEL env > settings.json vision.model > main model
    model = (
        os.environ.get("VISION_MODEL", "")
        or settings.vision_model
        or os.environ.get("ATA_CODER_DEFAULT_MODEL", "")
        or os.environ.get("OPENAI_MODEL", "")
        or settings.default_model
    )

    body = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": f"data:{mime};base64,{img_b64}",
                    "detail": "auto"
                }},
            ]
        }],
        "max_tokens": 2048,
        "temperature": 0.3,
    }

    import json as _json
    from urllib.error import HTTPError
    from urllib.request import Request, urlopen

    def _post_chat_completion() -> dict:
        """POST ``body`` and return the decoded JSON reply (blocking)."""
        req = Request(
            f"{api_base.rstrip('/')}/chat/completions",
            data=_json.dumps(body).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
        )
        with urlopen(req, timeout=120) as resp:
            return _json.loads(resp.read().decode("utf-8"))

    try:
        # urlopen can block for up to 120 s; keep it off the event loop.
        result = await self._run_in_thread(_post_chat_completion)
        # Tolerate empty/null "choices", "message" and "usage" fields.
        choice = (result.get("choices") or [{}])[0]
        content = (choice.get("message") or {}).get("content") or "(no response)"
        tokens = (result.get("usage") or {}).get("total_tokens", "?")
        return ToolResult(
            success=True,
            output=f"[Vision: {model} | {tokens} tokens]\n\n{content}",
        )
    except HTTPError as e:
        error_body = e.read().decode("utf-8", errors="replace")[:300]
        return ToolResult(
            success=False, output="",
            error=f"Vision API error {e.code}: {error_body}",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="",
            error=f"Vision API call failed: {e}",
        )
|
|
1462
|
+
|
|
1463
|
+
@staticmethod
def _extract_text(html_text: str, url: str = "") -> str:
    """Strip HTML down to readable plain text.

    Content of ``script``, ``style``, ``noscript`` and ``iframe`` elements,
    and of page-chrome elements (``nav``, ``footer``, ``header``, ``aside``),
    is dropped, including when such elements are nested. Block-level tags
    (paragraphs, headings, list items, table rows, ...) become line breaks.
    If the stdlib parser raises, a cruder regex strip is used instead: it
    removes scripts and styles, replaces remaining tags with spaces and
    unescapes entities.

    Args:
        html_text: Raw HTML markup.
        url: Page URL. Not used by the current implementation; accepted so
            callers can pass it.

    Returns:
        The extracted text, with runs of spaces/tabs collapsed to one space,
        spaces adjacent to line breaks removed, at most one blank line in a
        row, and no leading/trailing whitespace.
    """

    class _TextExtractor(html.parser.HTMLParser):
        """HTMLParser that accumulates visible text fragments in ``parts``."""

        def __init__(self):
            super().__init__()
            self.parts: list[str] = []
            # Depth of currently-open skip-tags; text is kept only at depth 0.
            self._skip_count = 0
            self._skip_tags = {"script", "style", "noscript", "iframe",
                               "nav", "footer", "header", "aside"}
            # Tags that open/close a visual block and therefore emit a newline.
            self._block_tags = {"div", "p", "h1", "h2", "h3", "h4", "h5",
                                "h6", "li", "tr", "section", "article",
                                "pre", "blockquote", "table", "ul", "ol",
                                "dl", "br", "hr"}

        def handle_starttag(self, tag, attrs):
            tag = tag.lower()
            if tag in self._skip_tags:
                self._skip_count += 1
            elif tag in self._block_tags:
                self.parts.append("\n")

        def handle_endtag(self, tag):
            tag = tag.lower()
            # The > 0 guard keeps a stray closing tag from driving depth negative.
            if tag in self._skip_tags and self._skip_count > 0:
                self._skip_count -= 1
            elif tag in self._block_tags:
                self.parts.append("\n")

        def handle_data(self, data):
            if self._skip_count == 0:
                text = data.strip()
                if text:
                    # Trailing space keeps adjacent inline fragments apart.
                    self.parts.append(text + " ")

    try:
        extractor = _TextExtractor()
        extractor.feed(html_text)
        # feed() may leave trailing text buffered (e.g. "AT&T" at the very
        # end, which could be a half-received character reference);
        # close() flushes it so it is not silently dropped.
        extractor.close()
        raw = "".join(extractor.parts)
    except Exception:
        # Fallback: regex strip
        raw = re.sub(r'<script[^>]*>.*?</script>', '', html_text, flags=re.DOTALL | re.IGNORECASE)
        raw = re.sub(r'<style[^>]*>.*?</style>', '', raw, flags=re.DOTALL | re.IGNORECASE)
        raw = re.sub(r'<[^>]+>', ' ', raw)
        raw = html.unescape(raw)

    # Collapse whitespace
    raw = re.sub(r'[ \t]+', ' ', raw)
    # Drop spaces hugging line breaks (handle_data leaves one after each
    # fragment) so the blank-line collapse below sees plain newline runs.
    raw = re.sub(r' *\n *', '\n', raw)
    raw = re.sub(r'\n{3,}', '\n\n', raw)
    return raw.strip()
|
|
1514
|
+
|
|
1515
|
+
|
|
1516
|
+
# ── Factory ──────────────────────────────────────────────────────────────────
|
|
1517
|
+
|
|
1518
|
+
def create_tool_executor(workspace_dir: str | None = None) -> ToolExecutor:
    """Build a ``ToolExecutor``, optionally rooted at ``workspace_dir``.

    Args:
        workspace_dir: Directory to use as the executor's workspace. When it
            is ``None`` or empty, the executor is constructed with no
            explicit config.

    Returns:
        A new ``ToolExecutor`` instance.
    """
    from .config import AgentConfig

    if not workspace_dir:
        # Nothing to override: let ToolExecutor pick its own defaults.
        return ToolExecutor()

    config = AgentConfig(workspace_dir=workspace_dir)
    return ToolExecutor(config)
|
|
1525
|
+
|
|
1526
|
+
|
|
1527
|
+
# ── CST Renamers (for rename_symbol tool) ───────────────────────────────────
|
|
1528
|
+
# These use libcst to safely rename symbols without touching strings or comments.
|
|
1529
|
+
|
|
1530
|
+
class _SymbolRenamer:
    """Shared libcst rename callbacks used by the rename_symbol tool.

    Holds the identifier being replaced (``old``), its replacement (``new``)
    and a running count of rewritten nodes (``changes``). Subclasses add
    definition-specific hooks on top of the ``leave_Name`` / ``leave_Call``
    callbacks defined here.
    """

    def __init__(self, old_name: str, new_name: str):
        self.old, self.new = old_name, new_name
        # Incremented once per node rewritten by any leave_* hook.
        self.changes = 0

    def leave_Name(self, original_node, updated_node):
        """Return ``updated_node`` renamed to ``new`` if its value is ``old``."""
        import libcst as cst

        is_target = isinstance(updated_node, cst.Name) and updated_node.value == self.old
        if not is_target:
            return updated_node
        self.changes += 1
        return updated_node.with_changes(value=self.new)

    def leave_Call(self, original_node, updated_node):
        """Point a call whose callee is the bare name ``old`` at ``new``."""
        import libcst as cst

        callee = updated_node.func
        if not (isinstance(callee, cst.Name) and callee.value == self.old):
            return updated_node
        # NOTE(review): libcst runs leave_Name on the callee before leave_Call,
        # so the callee has presumably already been renamed here; this branch
        # then seems to fire only when old == new (counting twice) — confirm.
        self.changes += 1
        return updated_node.with_changes(func=cst.Name(value=self.new))
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
class _VariableRenamer(_SymbolRenamer):
    """Rename variable references through the inherited ``Name`` hooks.

    Adds no behaviour of its own: every ``Name`` node whose value matches the
    old identifier is rewritten by ``_SymbolRenamer``. String literals and
    comments are not ``Name`` nodes, so their text is left untouched.
    """
|
|
1555
|
+
|
|
1556
|
+
|
|
1557
|
+
class _FunctionRenamer(_SymbolRenamer):
    """Rename a function or method definition together with its references.

    References and call sites go through the inherited ``leave_Name`` /
    ``leave_Call`` hooks; this subclass also handles the ``def`` itself.
    """

    def leave_FunctionDef(self, original_node, updated_node):
        """Give a ``def`` named ``old`` the name ``new`` and count the rename."""
        import libcst as cst

        if updated_node.name.value != self.old:
            return updated_node
        # NOTE(review): the def's name node is a child visited by leave_Name
        # first, so it is likely already renamed by now — confirm whether this
        # branch is reachable other than when old == new.
        self.changes += 1
        return updated_node.with_changes(name=cst.Name(value=self.new))
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
class _ClassRenamer(_SymbolRenamer):
|
|
1569
|
+
"""Rename class definitions and constructor calls."""
|
|
1570
|
+
|
|
1571
|
+
def leave_ClassDef(self, original_node, updated_node):
    """Give a ``class`` named ``old`` the name ``new`` and count the rename."""
    import libcst as cst

    if updated_node.name.value != self.old:
        return updated_node
    # NOTE(review): the class's name node is a child visited by leave_Name
    # first, so it is likely already renamed by now — confirm whether this
    # branch is reachable other than when old == new.
    self.changes += 1
    return updated_node.with_changes(name=cst.Name(value=self.new))
|