ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1036 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tool system for the ATA Coder.
|
|
3
|
+
|
|
4
|
+
Provides a set of tools the agent can use:
|
|
5
|
+
- read_file: Read file contents
|
|
6
|
+
- write_file: Create or overwrite a file
|
|
7
|
+
- edit_file: Precise string replacement in a file
|
|
8
|
+
- run_shell: Execute a shell command
|
|
9
|
+
- grep: Search file contents with regex
|
|
10
|
+
- glob: Find files matching a pattern
|
|
11
|
+
- list_dir: List directory contents
|
|
12
|
+
- web_search: Search the web (optional)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
import re
|
|
19
|
+
import shutil
|
|
20
|
+
import fnmatch
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any, Callable
|
|
23
|
+
|
|
24
|
+
# Cached reference for __del__ — avoids "import asyncio" during interpreter
|
|
25
|
+
# shutdown, which fails with "ImportError: sys.meta_path is None".
|
|
26
|
+
_asyncio_get_running_loop = asyncio.get_running_loop
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
from ..config import AgentConfig
|
|
30
|
+
from .result import ToolResult
|
|
31
|
+
from .web import WebToolsMixin
|
|
32
|
+
from .subagent import SubAgentToolsMixin
|
|
33
|
+
|
|
34
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ── Tool executor ────────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
class ToolExecutor(WebToolsMixin, SubAgentToolsMixin):
    """Executes tool calls and manages workspace context.

    Tool calls are dispatched by ``execute()`` through ``self._handlers``,
    a name -> handler table built from this object's ``_tool_*`` methods
    (mixin methods included).
    """

    # Result limits
    # NOTE(review): the code that enforces these (grep/glob/list_dir tools,
    # presumably) is outside this view — confirm they are all still used.
    MAX_GREP_RESULTS = 100
    MAX_GREP_PER_FILE = 20
    MAX_GLOB_RESULTS = 200
    MAX_DIR_ENTRIES = 500

    # Directories to skip during recursive operations
    # NOTE(review): "*.egg-info" is a glob pattern and ".coverage" is usually a
    # file, not a directory. If SKIP_DIRS is only checked with exact ``in``
    # membership, these two entries never match — confirm against the
    # directory-walking code.
    SKIP_DIRS = {
        "node_modules", "__pycache__", ".git", "venv", ".venv",
        "dist", "build", "target", ".next", ".pytest_cache",
        ".mypy_cache", ".ruff_cache", ".tox", ".idea", ".vs",
        ".vscode", "bower_components", ".terraform", ".eggs",
        "*.egg-info", "htmlcov", ".coverage", "__pypackages__",
    }
|
|
56
|
+
|
|
57
|
+
def __init__(self, config: AgentConfig | None = None):
    """Create an executor bound to the workspace named in *config*.

    Args:
        config: Agent configuration. A default ``AgentConfig()`` is built
            when this is ``None`` (or otherwise falsy).
    """
    self.config = config if config else AgentConfig()
    self.workspace = Path(self.config.workspace_dir).resolve()

    # Collaborators injected later by the agent.
    self._mcp = None  # see set_mcp_client()
    self._edit_callback: Callable[[str, str], None] | None = None  # see on_edit()

    # In-memory read cache: path -> (mtime, cached_at, content).
    # "Read only once": each file is read from disk once per session and
    # served from memory afterwards. Entries older than _FILE_CACHE_TTL
    # seconds are re-read to catch external modifications that don't update
    # mtime. LRU eviction: hits move the key to the end of the dict, and
    # eviction drops keys from the front (least recently used).
    self._file_cache: dict[str, tuple[float, float, str]] = {}
    self._file_cache_max_entries = 50
    self._FILE_CACHE_TTL = 30.0  # seconds before a cache entry is revalidated
    self._cache_dir: Path | None = None  # optional on-disk mirror, see setup_file_cache()

    # Sub-agent manager, assigned by the agent via set_sub_agent_manager().
    self._sub_agent_mgr: Any = None

    # Dispatch table tool_name -> bound handler, computed once up front.
    self._handlers: dict[str, Callable] = self._build_handlers()

    # Live-output hook (tool_name, chunk); see set_stream_callback().
    self._stream_cb: Callable[[str, str], None] | None = None
|
|
78
|
+
|
|
79
|
+
def set_stream_callback(self, cb: Callable[[str, str], None] | None) -> None:
|
|
80
|
+
"""Set callback for real-time tool output streaming.
|
|
81
|
+
|
|
82
|
+
``cb(tool_name, chunk)`` is called with incremental output chunks
|
|
83
|
+
during long-running tool execution (e.g. run_shell).
|
|
84
|
+
Set to None to disable streaming.
|
|
85
|
+
"""
|
|
86
|
+
self._stream_cb = cb
|
|
87
|
+
|
|
88
|
+
def _build_handlers(self) -> dict[str, Callable]:
|
|
89
|
+
"""Build a dispatch table: tool_name → handler callable.
|
|
90
|
+
|
|
91
|
+
Only registers callable methods — prevents accidental registration
|
|
92
|
+
of mixin properties, class attributes, or helper objects whose names
|
|
93
|
+
happen to match the ``_tool_`` prefix.
|
|
94
|
+
"""
|
|
95
|
+
handlers: dict[str, Callable] = {}
|
|
96
|
+
for name in dir(self):
|
|
97
|
+
if name.startswith("_tool_"):
|
|
98
|
+
attr = getattr(self, name)
|
|
99
|
+
if callable(attr):
|
|
100
|
+
tool_name = name[len("_tool_"):]
|
|
101
|
+
handlers[tool_name] = attr
|
|
102
|
+
return handlers
|
|
103
|
+
|
|
104
|
+
def set_sub_agent_manager(self, mgr: Any) -> None:
    """Attach the SubAgentManager backing the spawn/collect sub-agent tools.

    Args:
        mgr: Manager instance supplied by the agent; stored as-is.
    """
    self._sub_agent_mgr = mgr
|
|
107
|
+
|
|
108
|
+
def set_mcp_client(self, mcp: Any) -> None:
    """Attach the MCPClient used by the mcp_search tool.

    Args:
        mcp: Client instance supplied by the agent; stored as-is.
    """
    self._mcp = mcp
|
|
111
|
+
|
|
112
|
+
def setup_file_cache(self, cache_dir: str | Path) -> None:
|
|
113
|
+
"""Create session cache directory. Call once before running the agent."""
|
|
114
|
+
self._cache_dir = Path(cache_dir)
|
|
115
|
+
self._cache_dir.mkdir(parents=True, exist_ok=True)
|
|
116
|
+
|
|
117
|
+
def clear_file_cache(self) -> None:
    """Empty the in-memory read cache and delete the on-disk mirror directory.

    ``_cache_dir`` is reset to ``None`` only when the directory actually
    existed and was removed (removal errors are ignored).
    """
    self._file_cache.clear()
    mirror = self._cache_dir
    if not mirror or not mirror.exists():
        return
    shutil.rmtree(mirror, ignore_errors=True)
    self._cache_dir = None
|
|
123
|
+
|
|
124
|
+
def close(self) -> None:
    """Release held resources: the file cache and any HTTP client.

    ``_http`` is not created by this class (presumably an httpx client set
    up by a mixin), so it is looked up defensively. The HTTP client is
    closed even if clearing the file cache raises. Errors from closing the
    client are logged at debug level and otherwise ignored. ``_http`` is
    reset to ``None`` once it has been closed.
    """
    try:
        self.clear_file_cache()
    finally:
        http = getattr(self, "_http", None)
        if http is not None:
            try:
                http.close()
            except Exception:
                # Best effort: a failing close must not break shutdown.
                logger.debug("Ignoring error while closing HTTP client", exc_info=True)
            self._http = None
|
|
133
|
+
|
|
134
|
+
def __del__(self) -> None:
    """Best-effort safety net: release resources when garbage-collected.

    close() is attempted only when no asyncio event loop is running in the
    current thread, i.e. when ``get_running_loop()`` raises RuntimeError.
    When a loop *is* running, this method does nothing.

    The loop check goes through the module-level cached reference
    ``_asyncio_get_running_loop`` instead of an ``import asyncio`` here,
    because importing inside __del__ during interpreter shutdown fails
    with ``ImportError: sys.meta_path is None``.

    NOTE(review): in the running-loop case close() is skipped entirely, so
    the cache directory and HTTP client are left to other cleanup paths —
    confirm that this is intended.
    """
    try:
        _asyncio_get_running_loop()
    except RuntimeError:
        # No running loop — safe to close synchronously
        try:
            self.close()
        except Exception:
            # Exceptions must never escape __del__ (e.g. if __init__ failed
            # part-way, or during interpreter shutdown).
            pass
|
|
154
|
+
|
|
155
|
+
def on_edit(self, callback: Callable[[str, str], None]) -> None:
    """Register the edit-notification hook, called as ``callback(file_path, old_content)``.

    Replaces any previously registered callback.
    """
    self._edit_callback = callback
|
|
158
|
+
|
|
159
|
+
def _notify_edit(self, file_path: str, old_content: str) -> None:
|
|
160
|
+
"""Notify the UI of a file edit for diff display."""
|
|
161
|
+
if self._edit_callback:
|
|
162
|
+
try:
|
|
163
|
+
self._edit_callback(file_path, old_content)
|
|
164
|
+
except Exception:
|
|
165
|
+
logger.exception("Edit callback failed for %s", file_path)
|
|
166
|
+
|
|
167
|
+
def _resolve_path(self, file_path: str) -> Path:
|
|
168
|
+
"""Resolve a file path relative to workspace.
|
|
169
|
+
|
|
170
|
+
Raises ValueError if the resolved path escapes the workspace via
|
|
171
|
+
path traversal (e.g., ``../../etc/passwd``).
|
|
172
|
+
"""
|
|
173
|
+
p = Path(file_path)
|
|
174
|
+
if not p.is_absolute():
|
|
175
|
+
p = self.workspace / p
|
|
176
|
+
resolved = p.resolve()
|
|
177
|
+
try:
|
|
178
|
+
resolved.relative_to(self.workspace.resolve())
|
|
179
|
+
except ValueError:
|
|
180
|
+
raise ValueError(
|
|
181
|
+
f"Path traversal blocked: {file_path} → {resolved} "
|
|
182
|
+
f"is outside workspace {self.workspace}"
|
|
183
|
+
)
|
|
184
|
+
return resolved
|
|
185
|
+
|
|
186
|
+
def _ensure_in_workspace(self, path: Path) -> Path:
|
|
187
|
+
"""
|
|
188
|
+
Ensure path is within workspace for safety.
|
|
189
|
+
For files outside workspace, allow with a warning.
|
|
190
|
+
"""
|
|
191
|
+
try:
|
|
192
|
+
path.resolve().relative_to(self.workspace)
|
|
193
|
+
except ValueError:
|
|
194
|
+
logger.warning("Path outside workspace: %s", path)
|
|
195
|
+
return path
|
|
196
|
+
|
|
197
|
+
# Global output cap — every tool result is trimmed to this many chars.
# At roughly 4 chars per token that is ~25k tokens: still generous for
# legitimate work, but it prevents a single result from eating the whole
# context window.
MAX_OUTPUT_CHARS = 100_000

async def execute(self, tool_name: str, arguments: dict[str, Any] | None) -> ToolResult:
    """Dispatch a tool call to its ``_tool_<name>`` handler.

    Args:
        tool_name: Tool name as registered in ``self._handlers``.
        arguments: Keyword arguments for the handler. ``None`` is treated
            as "no arguments", because some model APIs send null for
            argument-less calls.

    Returns:
        The handler's ToolResult, with ``output`` capped at MAX_OUTPUT_CHARS.
        An unknown tool, or an ``Exception`` raised by the handler (including
        a TypeError from unexpected or missing arguments), is reported as a
        failed ToolResult instead of being propagated.
    """
    handler = self._handlers.get(tool_name)
    if handler is None:
        return ToolResult(
            success=False,
            output="",
            error=f"Unknown tool: {tool_name}",
        )

    kwargs = {} if arguments is None else arguments
    try:
        result = await handler(**kwargs)
    except Exception as e:
        # Only Exception is caught, so task cancellation (a BaseException)
        # still propagates to the caller.
        logger.exception("Tool %s failed", tool_name)
        return ToolResult(
            success=False, output="", error=f"{type(e).__name__}: {e}"
        )

    # ── Global size cap ──────────────────────────────────────────
    cap = self.MAX_OUTPUT_CHARS
    original_len = len(result.output) if result.output else 0
    if original_len > cap:
        result.output = (
            result.output[:cap]
            + f"\n\n... [truncated {original_len - cap:,} "
            f"chars — result was {original_len:,} chars total]"
        )
        logger.warning(
            "Tool %s output truncated: %d → %d chars",
            tool_name, original_len, cap,
        )

    return result
|
|
237
|
+
|
|
238
|
+
# ── File tools ───────────────────────────────────────────────────────────

# ── Output limits (prevent token bloat from large file reads) ──────
MAX_READ_LINES = 2000  # auto-truncate reads without an explicit limit
MAX_READ_CHARS = 80_000  # hard cap on output chars (~20k tokens)

async def _tool_read_file(
    self,
    file_path: str,
    offset: int | None = None,
    limit: int | None = None,
) -> ToolResult:
    """Read a text file and return it with 1-based line numbers.

    Args:
        file_path: File to read, relative to the workspace or absolute.
        offset: 1-based line to start at. ``None`` or ``0`` means line 1;
            negative values are clamped to line 1.
        limit: Maximum number of lines to return. When omitted, output is
            capped at MAX_READ_LINES; ``limit <= 0`` returns no lines.

    Returns:
        On success, a header ``File: <path> (lines A-B of N[, truncated ...])``
        followed by tab-separated numbered lines, with trailing whitespace
        stripped from each line. Output is also capped at MAX_READ_CHARS.
        Fails when the path is missing, is a directory, cannot be read, or
        *offset* lies past the end of the file.

    Raises:
        ValueError: from _resolve_path when *file_path* escapes the workspace.

    Caching:
        Contents are kept in ``self._file_cache``, keyed by resolved path.
        While an entry is younger than _FILE_CACHE_TTL and the file's mtime
        is unchanged, a repeat whole-file read (no offset/limit) returns only
        a short "[cached]" note instead of resending the content, and ranged
        reads are served from memory. Undecodable bytes are replaced (U+FFFD).
    """
    import time as _time

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"File not found: {path}",
        )
    if path.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Path is a directory, not a file: {path}",
        )

    # ── File cache: "read only once" — return a short note on re-read ────
    cache_key = str(path.resolve())
    current_mtime = path.stat().st_mtime
    raw: str | None = None

    cached = self._file_cache.get(cache_key)
    if cached is not None:
        cached_mtime, cached_at, cached_content = cached
        still_fresh = (_time.time() - cached_at) <= self._FILE_CACHE_TTL
        if still_fresh and cached_mtime == current_mtime:
            # LRU refresh: re-insert to move the key to the end. cached_at is
            # deliberately kept. Resetting it on every hit would stop
            # frequently read entries from ever reaching the TTL, so external
            # edits that leave mtime unchanged would never be picked up.
            del self._file_cache[cache_key]
            self._file_cache[cache_key] = cached
            if offset is None and limit is None:
                # Whole-file re-read — don't send the content again.
                total = len(cached_content.splitlines())
                chars = len(cached_content)
                return ToolResult(
                    success=True,
                    output=(
                        f"[cached] {file_path} — {total} lines, {chars:,} chars.\n"
                        "Already in conversation context from earlier read. "
                        "Use offset/limit to read specific sections if needed."
                    ),
                )
            raw = cached_content  # specific section — serve from memory

    if raw is None:
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                raw = f.read()
        except Exception as e:
            return ToolResult(
                success=False, output="", error=f"Cannot read file: {e}"
            )
        # Pop before inserting: assigning to an existing key keeps its old
        # position in the dict, which would leave a just-refreshed entry at
        # the least-recently-used end.
        self._file_cache.pop(cache_key, None)
        self._file_cache[cache_key] = (current_mtime, _time.time(), raw)
        # LRU eviction: dicts preserve insertion order, so the first keys
        # are the least recently used.
        overflow = len(self._file_cache) - self._file_cache_max_entries
        if overflow > 0:
            for stale_key in list(self._file_cache)[:overflow]:
                del self._file_cache[stale_key]
        # Mirror to the on-disk cache dir if one is configured. The filename
        # is made flat and cross-platform: a Windows drive letter is dropped
        # and path separators become underscores.
        if self._cache_dir:
            flat = cache_key
            if len(flat) >= 2 and flat[1] == ":":
                flat = flat[2:]  # strip "C:"
            safe_name = flat.lstrip("\\/").replace("\\", "_").replace("/", "_")
            try:
                (self._cache_dir / safe_name).write_text(raw, encoding="utf-8")
            except Exception:
                # The mirror is optional; failing to write it is not an error.
                logger.debug(
                    "Could not mirror %s into %s", path, self._cache_dir,
                    exc_info=True,
                )

    lines = raw.splitlines()
    total_lines = len(lines)

    # Clamp the start before computing the end, so a negative offset does
    # not shrink the window.
    start = max(0, (offset or 1) - 1)
    if start > 0 and start >= total_lines:
        return ToolResult(
            success=False,
            output="",
            error=f"offset {offset} is past the end of {path} ({total_lines} lines)",
        )

    # Default cap: when the caller didn't ask for a specific range, stop at
    # MAX_READ_LINES so one file doesn't blow the context.
    max_lines = self.MAX_READ_LINES if limit is None else max(0, limit)
    end = min(total_lines, start + max_lines)
    # Only the implicit default cap counts as truncation here. An explicit
    # limit that stops early is exactly what the caller asked for.
    was_truncated = limit is None and end < total_lines

    # Format with line numbers, also enforcing the hard char cap.
    output_lines: list[str] = []
    char_truncated = False
    chars = 0
    for lineno, line in enumerate(lines[start:end], start=start + 1):
        formatted = f"{lineno:6d}\t{line.rstrip()}"
        chars += len(formatted) + 1  # +1 for the joining newline
        if chars > self.MAX_READ_CHARS:
            output_lines.append(f"... (output truncated at {self.MAX_READ_CHARS} chars, use offset/limit to read more)")
            char_truncated = True
            break
        output_lines.append(formatted)

    shown_lines = len(output_lines) - (1 if char_truncated else 0)
    header = f"File: {path} (lines {start+1}-{start+shown_lines} of {total_lines}"
    if was_truncated or char_truncated:
        header += ", truncated — use offset/limit to page"
    header += ")\n"
    return ToolResult(success=True, output=header + "\n".join(output_lines))
|
|
379
|
+
|
|
380
|
+
async def _tool_write_file(
    self, file_path: str, content: str
) -> ToolResult:
    """Create or overwrite a UTF-8 text file, creating parent directories.

    If the file already existed with non-empty, UTF-8-decodable contents,
    those previous contents are passed to the edit callback so the UI can
    show a diff.

    Args:
        file_path: Destination, relative to the workspace or absolute.
        content: Complete new file contents.

    Returns:
        ToolResult reporting the size on disk in bytes and the line count,
        or an error when the directory or file cannot be written.
    """
    path = self._resolve_path(file_path)
    self._ensure_in_workspace(path)

    # Capture old content for the diff (if the file exists).
    old_content = ""
    if path.exists():
        try:
            with open(path, "r", encoding="utf-8") as f:
                old_content = f.read()
        except Exception:
            # Best effort: an unreadable previous version only means no diff.
            logger.debug("Could not read previous contents of %s", path, exc_info=True)

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        size = path.stat().st_size
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot write file: {e}"
        )

    # Invalidate any cached read of this file. read_file revalidates by
    # mtime, but a write that lands within the same mtime tick would
    # otherwise be served stale from the cache.
    self._file_cache.pop(str(path.resolve()), None)

    # Notify the UI for diff display when overwriting.
    if old_content:
        self._notify_edit(str(path), old_content)

    # splitlines() counts a final line without a trailing newline, unlike
    # counting "\n" characters.
    line_count = len(content.splitlines())
    return ToolResult(
        success=True,
        output=f"File written: {path} ({size} bytes, {line_count} lines)",
    )
|
|
414
|
+
|
|
415
|
+
async def _tool_edit_file(
    self, file_path: str, old_string: str, new_string: str
) -> ToolResult:
    """Replace the single occurrence of *old_string* in a file with *new_string*.

    For ``.py``/``.pyi`` files, a libcst-based replacement (_cst_edit) is
    tried first. If it does not apply, plain text replacement is used. Both
    routes follow the same uniqueness rule: an *old_string* that occurs more
    than once is rejected before any edit is attempted.

    Returns:
        ToolResult describing the edit, or an error when the strings are
        identical, the file is missing or unreadable, *old_string* is not
        found or is ambiguous, or the write fails.
    """
    if old_string == new_string:
        return ToolResult(
            success=False,
            output="",
            error="old_string and new_string are identical",
        )

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"File not found: {path}",
        )

    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot read file: {e}"
        )

    # Ambiguity check first, so the CST route cannot silently apply an edit
    # that the text route would refuse.
    count = content.count(old_string)
    if count > 1:
        return ToolResult(
            success=False,
            output="",
            error=f"old_string found {count} times in file. Must be unique. Use a larger string with more surrounding context.",
        )

    def _commit(new_content: str, label: str) -> ToolResult:
        """Write *new_content*, drop the stale read-cache entry, and notify the UI."""
        try:
            with open(path, "w", encoding="utf-8") as f:
                f.write(new_content)
        except Exception as e:
            return ToolResult(
                success=False, output="", error=f"Cannot write file: {e}"
            )
        # An edit within the same mtime tick would otherwise be served stale
        # by read_file's cache.
        self._file_cache.pop(str(path.resolve()), None)
        self._notify_edit(str(path), content)  # pre-edit content, for the diff
        return ToolResult(
            success=True,
            output=f"File edited{label}: {path} (1 replacement)",
        )

    # ── CST-based edit for Python files ──────────────────────────────
    if path.suffix in (".py", ".pyi"):
        cst_content = self._cst_edit(content, old_string, new_string, path)
        if cst_content is not None:
            return _commit(cst_content, " (AST)")
        # CST edit not applicable — fall through to text replacement.

    # ── Text-based fallback ──────────────────────────────────────────
    if count == 0:
        return ToolResult(
            success=False,
            output="",
            error="old_string not found in file. Check whitespace/indentation.",
        )
    return _commit(content.replace(old_string, new_string, 1), "")
|
|
495
|
+
|
|
496
|
+
@staticmethod
def _cst_edit(content: str, old_str: str, new_str: str, path: Any) -> str | None:
    """Try to apply an edit to Python source through libcst.

    *old_str* and *new_str* must each parse as exactly one top-level
    statement. The statement in *content* matching *old_str* is swapped for
    the statement parsed from *new_str*, and the tree is re-emitted. Two
    matching strategies are tried in order: whole-node deep equality
    (_NodeReplacer), then simple-statement-line matching (_StatementReplacer).

    Args:
        content: Full source of the file being edited.
        old_str: Source of the statement to replace.
        new_str: Source of the replacement statement.
        path: Only used in log messages.

    Returns:
        The rewritten source, or ``None`` when the CST route does not apply:
        libcst is missing, an input fails to parse, a snippet is not exactly
        one statement, or no strategy finds a match. The caller then falls
        back to plain text replacement.
    """
    try:
        import libcst as cst
    except ImportError:
        return None  # libcst not installed — use text fallback

    # Parse everything once; any parse failure means "use the text route".
    try:
        tree = cst.parse_module(content)
        old_module = cst.parse_module(old_str + "\n")
        new_module = cst.parse_module(new_str + "\n")
    except Exception:
        return None
    if len(old_module.body) != 1 or len(new_module.body) != 1:
        return None
    old_node = old_module.body[0]
    new_node = new_module.body[0]

    # The transformer classes are resolved at call time.
    # NOTE(review): if they are not module-level names in the real file
    # (e.g. defined inside the class body), this raises NameError and the CST
    # route can never run — the debug log below makes that visible.
    try:
        strategies = (_NodeReplacer, _StatementReplacer)
    except NameError:
        logger.debug("CST transformers unavailable; text edit for %s", path, exc_info=True)
        return None

    for replacer_cls in strategies:
        try:
            transformer = replacer_cls(old_node, new_node)
            new_tree = tree.visit(transformer)
            if transformer.found:
                return new_tree.code
        except Exception:
            logger.debug(
                "CST strategy %s failed for %s",
                getattr(replacer_cls, "__name__", replacer_cls), path,
                exc_info=True,
            )

    return None  # All CST strategies failed — fall back to text
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
# ── CST Transformers (libcst optional — guarded at module level) ──────────
|
|
554
|
+
|
|
555
|
+
try:
|
|
556
|
+
import libcst as _cst_lib
|
|
557
|
+
|
|
558
|
+
class _NodeReplacer(_cst_lib.CSTTransformer):
|
|
559
|
+
"""Replace a specific CST node with another, preserving all else."""
|
|
560
|
+
|
|
561
|
+
def __init__(self, old_node: _cst_lib.CSTNode, new_node: _cst_lib.CSTNode):
|
|
562
|
+
self.old_node = old_node
|
|
563
|
+
self.new_node = new_node
|
|
564
|
+
self.found = False
|
|
565
|
+
|
|
566
|
+
def on_visit(self, node: _cst_lib.CSTNode) -> bool:
|
|
567
|
+
if node.deep_equals(self.old_node) and not self.found:
|
|
568
|
+
self.found = True
|
|
569
|
+
return False
|
|
570
|
+
return True
|
|
571
|
+
|
|
572
|
+
def on_leave(self, original_node: _cst_lib.CSTNode, updated_node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
|
|
573
|
+
if original_node.deep_equals(self.old_node) and self.found:
|
|
574
|
+
return self.new_node
|
|
575
|
+
return updated_node
|
|
576
|
+
|
|
577
|
+
class _StatementReplacer(_cst_lib.CSTTransformer):
|
|
578
|
+
"""Replace a statement within a body, matching by deep equality."""
|
|
579
|
+
|
|
580
|
+
def __init__(self, old_stmt: _cst_lib.CSTNode, new_stmt: _cst_lib.CSTNode):
|
|
581
|
+
self.old_stmt = old_stmt
|
|
582
|
+
self.new_stmt = new_stmt
|
|
583
|
+
self.found = False
|
|
584
|
+
|
|
585
|
+
def leave_SimpleStatementLine(
|
|
586
|
+
self, original_node: _cst_lib.SimpleStatementLine, updated_node: _cst_lib.SimpleStatementLine
|
|
587
|
+
):
|
|
588
|
+
if not self.found and len(updated_node.body) == 1:
|
|
589
|
+
if updated_node.body[0].deep_equals(self.old_stmt):
|
|
590
|
+
self.found = True
|
|
591
|
+
return updated_node.with_changes(body=[self.new_stmt])
|
|
592
|
+
return updated_node
|
|
593
|
+
|
|
594
|
+
except ImportError:
|
|
595
|
+
# libcst not installed — AST editing will fall back to text replacement
|
|
596
|
+
_NodeReplacer = None # type: ignore
|
|
597
|
+
_StatementReplacer = None # type: ignore
|
|
598
|
+
|
|
599
|
+
# ── Rename symbol (AST-aware) ────────────────────────────────────────────

async def _tool_rename_symbol(
    self, file_path: str, old_name: str, new_name: str,
    symbol_type: str = "variable"
) -> ToolResult:
    """Rename a Python symbol in one file using libcst.

    Args:
        file_path: Target ``.py``/``.pyi`` file (workspace-relative or absolute).
        old_name: Current identifier.
        new_name: Replacement identifier. It must be a valid identifier and
            not a reserved keyword (``"class".isidentifier()`` is True, so
            the keyword check is needed separately).
        symbol_type: ``"function"`` or ``"class"`` selects a dedicated
            renamer. Any other value uses the variable renamer.

    Returns:
        ToolResult with the number of occurrences changed, or an error when
        validation, parsing, or writing fails, or nothing matched.

    NOTE(review): which occurrences are renamed (e.g. whether strings,
    comments, or imports of the same name are left alone) is decided by
    _FunctionRenamer / _ClassRenamer / _VariableRenamer, which are defined
    outside this view — confirm there.
    """
    import keyword

    if old_name == new_name:
        return ToolResult(success=False, output="", error="Names are identical.")
    if not old_name.isidentifier() or not new_name.isidentifier():
        return ToolResult(success=False, output="", error="Names must be valid Python identifiers.")
    if keyword.iskeyword(new_name):
        return ToolResult(success=False, output="", error=f"'{new_name}' is a reserved Python keyword.")

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(success=False, output="", error=f"File not found: {path}")
    if path.suffix not in (".py", ".pyi"):
        return ToolResult(success=False, output="", error="rename_symbol only works with Python files.")

    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot read file: {e}")

    try:
        import libcst as cst
    except ImportError:
        return ToolResult(success=False, output="", error="libcst not installed. Run: pip install libcst")

    try:
        tree = cst.parse_module(content)
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot parse Python file: {e}")

    # Choose the right renamer for the symbol type
    if symbol_type == "function":
        renamer = _FunctionRenamer(old_name, new_name)
    elif symbol_type == "class":
        renamer = _ClassRenamer(old_name, new_name)
    else:
        renamer = _VariableRenamer(old_name, new_name)

    new_tree = tree.visit(renamer)
    if not renamer.changes:
        return ToolResult(success=False, output="", error=f"Symbol '{old_name}' not found in file.")

    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(new_tree.code)
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot write file: {e}")

    # Drop any cached read so read_file can't serve the pre-rename text when
    # the write lands within the same mtime tick.
    self._file_cache.pop(str(path.resolve()), None)
    self._notify_edit(str(path), content)
    return ToolResult(
        success=True,
        output=f"Renamed {renamer.changes} occurrence(s) of '{old_name}' → '{new_name}' in {path}",
    )
|
|
660
|
+
|
|
661
|
+
# ── Shell tool ───────────────────────────────────────────────────────────
|
|
662
|
+
|
|
663
|
+
async def _tool_run_shell(
    self, command: str, timeout: int = 120
) -> ToolResult:
    """Execute a shell command in the workspace and capture its output.

    The command runs via ``asyncio.create_subprocess_shell`` with
    ``cwd=self.workspace`` and stdin set to DEVNULL. Output chunks are also
    forwarded to ``self._stream_cb`` (when set) as they arrive.

    Args:
        command: Shell command line.
        timeout: Seconds to wait for the process to exit before killing it.

    Returns:
        ToolResult holding stdout, followed by a ``[stderr]`` section, an
        ``[exit code: N]`` marker for non-zero exits, and a truncation notice
        when stdout+stderr together exceeded ``MAX_OUTPUT_BYTES``.
        ``success`` is True only for exit code 0. Blocked commands, timeouts
        and spawn errors return ``success=False`` with ``error`` set.
    """
    # Safety checks
    cmd_lower = command.lower().strip()

    # Check blocked patterns (hard block — unforgivable operations)
    for blocked in self.config.blocked_commands:
        if blocked.lower() in cmd_lower:
            return ToolResult(
                success=False,
                output="",
                error=f"Blocked command pattern detected: {blocked}",
            )

    # Check if first word is allowed (soft warning — safety_guard is the real gate)
    first_word = command.strip().split()[0] if command.strip() else ""
    if first_word and first_word not in self.config.allowed_commands:
        logger.debug("Command '%s' not in allowed_commands whitelist, proceeding", first_word)

    MAX_OUTPUT_BYTES = 500_000  # cap total stdout+stderr to prevent memory exhaustion
    proc = None
    reader_tasks: list = []
    # One byte budget shared by both readers, so the cap really applies to
    # stdout+stderr combined; `truncated` records whether anything was dropped.
    remaining = MAX_OUTPUT_BYTES
    truncated = False
    try:
        # DEVNULL on stdin prevents hangs when child processes try to read
        # from inherited stdin (common cause of "exec task freeze").
        proc = await asyncio.create_subprocess_shell(
            command,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.workspace),
        )

        # Stream stdout+stderr concurrently — avoids pipe-buffer deadlock
        # and hangs from child processes inheriting pipes.
        # When a stream_callback is set, emit chunks in real-time so the
        # user sees command progress instead of waiting for completion.
        async def _read_stream(stream, chunks: list[bytes]):
            nonlocal remaining, truncated
            while True:
                try:
                    chunk = await stream.read(65536)
                except (asyncio.CancelledError, Exception):
                    return
                if not chunk:
                    return
                if len(chunk) > remaining:
                    truncated = True
                if remaining > 0:
                    # Keep the part of the chunk that still fits in the budget.
                    kept = chunk[:remaining]
                    chunks.append(kept)
                    remaining -= len(kept)
                # Real-time streaming to UI
                if self._stream_cb:
                    try:
                        text = chunk.decode("utf-8", errors="replace")
                        self._stream_cb("run_shell", text)
                    except Exception:
                        pass

        stdout_chunks: list[bytes] = []
        stderr_chunks: list[bytes] = []
        stdout_task = asyncio.create_task(_read_stream(proc.stdout, stdout_chunks))
        stderr_task = asyncio.create_task(_read_stream(proc.stderr, stderr_chunks))
        reader_tasks = [stdout_task, stderr_task]

        # Wait for process to exit (with timeout), NOT for pipes to close
        try:
            await asyncio.wait_for(proc.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            # Kill the process FIRST — sends EOF to pipes, unblocking reader tasks
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            # Cancel reader tasks AFTER kill (kill unblocks pipe reads)
            stdout_task.cancel()
            stderr_task.cancel()
            try:
                await proc.wait()
            except Exception:
                pass
            return ToolResult(
                success=False, output="",
                error=f"Command timed out after {timeout}s",
            )

        # Process exited — wait briefly for remaining pipe data
        for task in reader_tasks:
            try:
                await asyncio.wait_for(task, timeout=5)
            except asyncio.TimeoutError:
                task.cancel()

        stdout_bytes = b"".join(stdout_chunks)
        stderr_bytes = b"".join(stderr_chunks)
        output = stdout_bytes.decode("utf-8", errors="replace")
        if stderr_bytes:
            output += f"\n[stderr]\n{stderr_bytes.decode('utf-8', errors='replace')}"
        returncode = proc.returncode or 0
        if returncode != 0:
            output += f"\n[exit code: {returncode}]"
        if truncated:
            output += "\n[output truncated — exceeded 500KB limit]"
        return ToolResult(
            success=returncode == 0,
            output=output.strip() or "(no output)",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Command failed: {e}"
        )
    finally:
        # Stop readers still running, e.g. when this coroutine was cancelled.
        for task in reader_tasks:
            if not task.done():
                task.cancel()
        if proc is not None:
            # Never leave the child running when exiting early (cancellation
            # or an unexpected error before it finished). Best-effort only.
            if proc.returncode is None:
                try:
                    proc.kill()
                except Exception:
                    pass
            # Explicitly close pipes to prevent "I/O operation on closed pipe"
            # during BaseSubprocessTransport.__del__ at GC time.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                if pipe is not None:
                    try:
                        pipe.close()
                    except Exception:
                        pass
|
|
785
|
+
|
|
786
|
+
# ── Search tools ─────────────────────────────────────────────────────────
|
|
787
|
+
|
|
788
|
+
async def _tool_grep(
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
    case_sensitive: bool = False,
) -> ToolResult:
    """Search file contents with a regular expression.

    Args:
        pattern: Regex applied with ``re.search`` to every line.
        path: File or directory to search (default: workspace root). For a
            directory, hidden directories and ``SKIP_DIRS`` are pruned.
        glob: Optional ``fnmatch`` pattern matched against file names.
        case_sensitive: Matching is case-insensitive unless True.

    Returns:
        ToolResult listing ``relative/path:`` headers followed by numbered
        matching lines (each cut to 200 chars). At most ``MAX_GREP_PER_FILE``
        matches are shown per file, and collection stops once the output
        reaches ``MAX_GREP_RESULTS`` lines. A missing path or invalid regex
        yields ``success=False``.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        regex = re.compile(pattern, flags)
    except re.error as e:
        return ToolResult(
            success=False, output="", error=f"Invalid regex: {e}"
        )

    max_per_file = self.MAX_GREP_PER_FILE
    max_results = self.MAX_GREP_RESULTS

    def _iter_files():
        """Yield (full_path, file_name) for every file under the search root."""
        if search_dir.is_file():
            # A single file was given; os.walk() would yield nothing for it.
            yield str(search_dir), search_dir.name
            return
        for root, dirs, files in os.walk(search_dir):
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".")
                and d not in self.SKIP_DIRS
            ]
            for fname in files:
                yield os.path.join(root, fname), fname

    def _grep_file(full_path):
        """Return (formatted match lines, match count), or None if unreadable."""
        lines: list[str] = []
        count = 0
        try:
            with open(full_path, "r", encoding="utf-8", errors="replace") as fh:
                # Iterate lazily instead of readlines() so huge files are not
                # loaded into memory at once.
                for line_no, line_text in enumerate(fh, 1):
                    if not regex.search(line_text):
                        continue
                    if count >= max_per_file:
                        # Flag truncation only when a match is actually dropped.
                        lines.append(" ... (truncated)")
                        break
                    lines.append(f" {line_no}: {line_text.rstrip()[:200]}")
                    count += 1
        except Exception:
            return None
        return lines, count

    # Run filesystem walk + file reads in thread pool to avoid
    # blocking the asyncio event loop on large codebases.
    def _do_grep():
        results: list[str] = []
        total_matches = 0
        for full_path, fname in _iter_files():
            if glob and not fnmatch.fnmatch(fname, glob):
                continue
            grepped = _grep_file(full_path)
            if grepped is None:
                continue
            matches_in_file, count = grepped
            if not matches_in_file:
                continue
            total_matches += count
            try:
                rel_path = os.path.relpath(full_path, self.workspace)
            except ValueError:
                rel_path = full_path
            results.append(f"{rel_path}:")
            results.extend(matches_in_file)
            if len(results) >= max_results:
                results.append("... (result limit reached)")
                break
        return results, total_matches

    results, total_matches = await self._run_in_thread(_do_grep)

    if not results:
        return ToolResult(
            success=True,
            output=f"No matches found for pattern: {pattern}",
        )
    return ToolResult(
        success=True,
        output=f"Found {total_matches} matches:\n\n" + "\n".join(results),
    )
|
|
873
|
+
|
|
874
|
+
async def _tool_glob(
    self,
    pattern: str,
    path: str | None = None,
) -> ToolResult:
    """Find files by glob pattern.

    Patterns without ``**`` are made recursive by prefixing ``**/``, so
    ``*.py`` matches at any depth below *path* (default: workspace root).
    The search runs via ``_run_in_thread`` so large trees do not block the
    event loop.

    Returns:
        ToolResult listing up to 200 sorted matches (workspace-relative, with
        byte sizes; non-files and entries that cannot be stat'ed show 0),
        plus a count of the remaining matches. A missing *path* yields
        ``success=False``.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    import glob as glob_mod

    # Auto-add **/ prefix for recursive matching if not already present
    if "**" not in pattern:
        pattern = f"**/{pattern}"
    # Escape the base directory so glob metacharacters in its own name
    # ("[", "*", "?") are matched literally instead of as wildcards.
    search_pattern = os.path.join(glob_mod.escape(str(search_dir)), pattern)
    max_listed = 200

    def _collect():
        found = sorted(glob_mod.glob(search_pattern, recursive=True))
        lines = []
        for m in found[:max_listed]:
            try:
                rel = os.path.relpath(m, self.workspace)
            except ValueError:
                rel = m
            try:
                size = os.path.getsize(m) if os.path.isfile(m) else 0
            except OSError:
                # The file vanished or became unreadable between glob and stat.
                size = 0
            lines.append(f" {rel} ({size:,} bytes)")
        return found, lines

    matches, output_lines = await self._run_in_thread(_collect)

    if not matches:
        return ToolResult(
            success=True, output=f"No files matching: {pattern}"
        )

    if len(matches) > max_listed:
        output_lines.append(
            f" ... and {len(matches) - max_listed} more files"
        )

    return ToolResult(
        success=True,
        output=f"Found {len(matches)} files matching '{pattern}':\n"
        + "\n".join(output_lines),
    )
|
|
922
|
+
|
|
923
|
+
async def _tool_list_dir(
    self,
    path: str | None = None,
    recursive: bool = False,
) -> ToolResult:
    """List directory contents.

    Non-recursive mode lists the immediate children: directories first, then
    files, with file sizes. Recursive mode walks the tree, pruning hidden
    directories and ``SKIP_DIRS``, and indents entries by depth. At most 500
    entries are shown. The filesystem work runs via ``_run_in_thread``.

    Returns:
        ToolResult with the listing, or ``success=False`` when *path* is
        missing, is not a directory, or cannot be read.
    """
    target = self._resolve_path(path or ".")
    if not target.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {target}",
        )
    if not target.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Not a directory: {target}",
        )

    max_entries = 500

    def _collect():
        entries: list[str] = []
        if recursive:
            for root, dirs, files in os.walk(target):
                dirs[:] = [
                    d for d in dirs
                    if not d.startswith(".") and d not in self.SKIP_DIRS
                ]
                # Depth below target: 0 for target itself, 1 for its children, ...
                rel_root = os.path.relpath(root, target)
                level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
                indent = " " * level
                if level > 0:
                    entries.append(f"{indent}{os.path.basename(root)}/")
                for f in sorted(files):
                    try:
                        size = os.path.getsize(os.path.join(root, f))
                    except OSError:
                        # Broken symlink or file removed mid-walk: list it without a size.
                        entries.append(f"{indent} {f}")
                    else:
                        entries.append(f"{indent} {f} ({size:,}B)")
        else:
            items = sorted(target.iterdir(), key=lambda x: (not x.is_dir(), x.name))
            for item in items:
                suffix = "/" if item.is_dir() else ""
                size = ""
                if item.is_file():
                    try:
                        size = f" ({item.stat().st_size:,}B)"
                    except OSError:
                        pass
                entries.append(f" {item.name}{suffix}{size}")
        return entries

    try:
        entries = await self._run_in_thread(_collect)
    except OSError as e:
        return ToolResult(success=False, output="", error=f"Cannot list directory: {e}")

    output_lines = [f"Directory: {target}"]
    output_lines.extend(entries[:max_entries])
    if len(entries) > max_entries:
        output_lines.append(f" ... and {len(entries) - max_entries} more entries")

    return ToolResult(success=True, output="\n".join(output_lines))
|
|
974
|
+
|
|
975
|
+
|
|
976
|
+
# ── Factory ──────────────────────────────────────────────────────────────────
|
|
977
|
+
|
|
978
|
+
def create_tool_executor(workspace_dir: str | None = None) -> ToolExecutor:
    """Build a ToolExecutor, optionally rooted at *workspace_dir*.

    A falsy *workspace_dir* (``None`` or ``""``) yields an executor built
    with no explicit configuration.
    """
    from ..config import AgentConfig

    if not workspace_dir:
        return ToolExecutor()
    return ToolExecutor(AgentConfig(workspace_dir=workspace_dir))
|
|
985
|
+
|
|
986
|
+
|
|
987
|
+
# ── CST Renamers (for rename_symbol tool) ───────────────────────────────────
|
|
988
|
+
# These use libcst to safely rename symbols without touching strings or comments.
|
|
989
|
+
|
|
990
|
+
try:
    from libcst import CSTTransformer as _RenamerBase
except ImportError:
    # libcst is optional; without it the renamers cannot run, and
    # _tool_rename_symbol reports the missing dependency at call time.
    _RenamerBase = object  # type: ignore[misc,assignment]


class _SymbolRenamer(_RenamerBase):
    """Base class for CST symbol renamers — shared leave_Name/leave_Call logic.

    When libcst is installed this subclasses ``libcst.CSTTransformer``.
    ``CSTNode.visit()`` drives its visitor through the transformer's
    ``on_visit``/``on_leave`` hooks, which dispatch to the ``leave_*``
    methods below; a plain object lacks those hooks and the visit fails.

    Every ``Name`` node whose value equals *old_name* is renamed, regardless
    of scope or syntactic role. In libcst's tree this includes attribute
    names, keyword-argument names and names in import statements.
    ``changes`` counts the renames performed.
    """

    def __init__(self, old_name: str, new_name: str):
        # Initialises CSTTransformer's own state when it is the base class.
        super().__init__()
        self.old = old_name
        self.new = new_name
        self.changes = 0

    def leave_Name(self, original_node, updated_node):
        import libcst as cst
        if isinstance(updated_node, cst.Name) and updated_node.value == self.old:
            self.changes += 1
            return updated_node.with_changes(value=self.new)
        return updated_node

    def leave_Call(self, original_node, updated_node):
        # libcst leaves children before their parent, so a bare ``old(...)``
        # callee has normally been renamed by leave_Name already; this only
        # acts if the callee still carries the old name.
        import libcst as cst
        if isinstance(updated_node.func, cst.Name) and updated_node.func.value == self.old:
            self.changes += 1
            return updated_node.with_changes(func=cst.Name(value=self.new))
        return updated_node
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
class _VariableRenamer(_SymbolRenamer):
    """Renamer used for ``symbol_type="variable"`` (and any unrecognised type).

    It adds nothing to the base class: matching identifier ``Name`` nodes are
    rewritten by the inherited hooks, while string literals and comments,
    which are not ``Name`` nodes, are left alone.
    """
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
class _FunctionRenamer(_SymbolRenamer):
    """Renamer used for ``symbol_type="function"`` (definitions and call sites).

    The ``def`` name is itself a ``Name`` child of the FunctionDef node. In
    libcst it is normally renamed by the inherited ``leave_Name`` before
    ``leave_FunctionDef`` runs, so the hook below only acts if the
    definition still carries the old name.
    """

    def leave_FunctionDef(self, original_node, updated_node):
        import libcst as cst
        if updated_node.name.value != self.old:
            return updated_node
        self.changes += 1
        renamed = cst.Name(value=self.new)
        return updated_node.with_changes(name=renamed)
|
|
1026
|
+
|
|
1027
|
+
|
|
1028
|
+
class _ClassRenamer(_SymbolRenamer):
    """Renamer used for ``symbol_type="class"`` (definitions and constructor calls).

    The class name is a ``Name`` child of the ClassDef node. In libcst it is
    normally renamed by the inherited ``leave_Name`` before ``leave_ClassDef``
    runs, so the hook below only acts if the definition still carries the
    old name.
    """

    def leave_ClassDef(self, original_node, updated_node):
        import libcst as cst
        if updated_node.name.value != self.old:
            return updated_node
        self.changes += 1
        renamed = cst.Name(value=self.new)
        return updated_node.with_changes(name=renamed)
|