ata-coder 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. ata_coder/__init__.py +1 -0
  2. ata_coder/agent.py +874 -0
  3. ata_coder/agent_compact.py +190 -0
  4. ata_coder/agent_controller.py +218 -0
  5. ata_coder/agent_extension.py +69 -0
  6. ata_coder/agent_routing.py +105 -0
  7. ata_coder/agent_subsystems.py +72 -0
  8. ata_coder/agent_tools.py +318 -0
  9. ata_coder/agent_undo.py +63 -0
  10. ata_coder/anthropic_client.py +465 -0
  11. ata_coder/change_tracker.py +368 -0
  12. ata_coder/clawd_integration.py +574 -0
  13. ata_coder/commands/__init__.py +128 -0
  14. ata_coder/commands/_core.py +184 -0
  15. ata_coder/commands/_safety.py +95 -0
  16. ata_coder/commands/_settings.py +241 -0
  17. ata_coder/commands/_workflow.py +451 -0
  18. ata_coder/commands.py +974 -0
  19. ata_coder/config.py +257 -0
  20. ata_coder/core/__init__.py +35 -0
  21. ata_coder/core/events.py +73 -0
  22. ata_coder/core/queue.py +85 -0
  23. ata_coder/core/state.py +17 -0
  24. ata_coder/event_queue.py +5 -0
  25. ata_coder/extension.py +654 -0
  26. ata_coder/extensions/__init__.py +1 -0
  27. ata_coder/extensions/hello_skill.py +47 -0
  28. ata_coder/fool_proof.py +295 -0
  29. ata_coder/git_workflow.py +371 -0
  30. ata_coder/gui.py +511 -0
  31. ata_coder/llm_client.py +543 -0
  32. ata_coder/main.py +814 -0
  33. ata_coder/mcp_client.py +1095 -0
  34. ata_coder/memory.py +539 -0
  35. ata_coder/model_registry.py +134 -0
  36. ata_coder/model_router.py +105 -0
  37. ata_coder/permissions.py +274 -0
  38. ata_coder/privilege.py +464 -0
  39. ata_coder/project.py +273 -0
  40. ata_coder/prompt_template.py +423 -0
  41. ata_coder/prompts/auto-mode.md +7 -0
  42. ata_coder/prompts/coding-rules.md +40 -0
  43. ata_coder/prompts/execution-guardrails.md +14 -0
  44. ata_coder/prompts/memory-system.md +24 -0
  45. ata_coder/prompts/output-style.md +23 -0
  46. ata_coder/prompts/safety.md +17 -0
  47. ata_coder/prompts/slash-commands.md +24 -0
  48. ata_coder/prompts/sub-agents.md +38 -0
  49. ata_coder/prompts/system-reminders.md +17 -0
  50. ata_coder/prompts/system.md +105 -0
  51. ata_coder/prompts/tool-policy.md +46 -0
  52. ata_coder/repl_theme.py +99 -0
  53. ata_coder/repl_tracker.py +89 -0
  54. ata_coder/repl_ui.py +1214 -0
  55. ata_coder/safety_guard.py +434 -0
  56. ata_coder/self_correct.py +346 -0
  57. ata_coder/server.py +882 -0
  58. ata_coder/server_session.py +159 -0
  59. ata_coder/server_shell.py +129 -0
  60. ata_coder/session.py +431 -0
  61. ata_coder/settings.py +439 -0
  62. ata_coder/setup_wizard.py +136 -0
  63. ata_coder/skill_extension.py +92 -0
  64. ata_coder/skills/architect/SKILL.md +42 -0
  65. ata_coder/skills/code-reviewer/SKILL.md +37 -0
  66. ata_coder/skills/codecraft/SKILL.md +452 -0
  67. ata_coder/skills/debugger/SKILL.md +45 -0
  68. ata_coder/skills/doc-writer/SKILL.md +36 -0
  69. ata_coder/skills/general-coder/SKILL.md +76 -0
  70. ata_coder/skills/math-calculator/README.md +40 -0
  71. ata_coder/skills/math-calculator/SKILL.md +59 -0
  72. ata_coder/skills/math-calculator/handler.py +103 -0
  73. ata_coder/skills/math-calculator/prompts/system.md +8 -0
  74. ata_coder/skills/math-calculator/requirements.txt +2 -0
  75. ata_coder/skills/math-calculator/resources/constants.json +8 -0
  76. ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
  77. ata_coder/skills/security-auditor/SKILL.md +40 -0
  78. ata_coder/skills/test-writer/SKILL.md +36 -0
  79. ata_coder/skills/weather-skill/README.md +45 -0
  80. ata_coder/skills/weather-skill/handler.py +76 -0
  81. ata_coder/skills/weather-skill/manifest.json +48 -0
  82. ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
  83. ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
  84. ata_coder/skills/weather-skill/requirements.txt +1 -0
  85. ata_coder/skills/weather-skill/resources/city_list.json +17 -0
  86. ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
  87. ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
  88. ata_coder/skills/weather-skill/weather_utils.py +50 -0
  89. ata_coder/skills.py +1014 -0
  90. ata_coder/sub_agent.py +273 -0
  91. ata_coder/sub_agent_manager.py +203 -0
  92. ata_coder/system_prompt_builder.py +146 -0
  93. ata_coder/task_planner.py +391 -0
  94. ata_coder/terminal.py +318 -0
  95. ata_coder/test_runner.py +219 -0
  96. ata_coder/thread_supervisor.py +195 -0
  97. ata_coder/tool_defs.py +335 -0
  98. ata_coder/tools/__init__.py +11 -0
  99. ata_coder/tools/definitions.py +335 -0
  100. ata_coder/tools/executor.py +1036 -0
  101. ata_coder/tools/result.py +26 -0
  102. ata_coder/tools/subagent.py +332 -0
  103. ata_coder/tools/web.py +361 -0
  104. ata_coder/tools.py +1576 -0
  105. ata_coder/types.py +92 -0
  106. ata_coder/utils.py +113 -0
  107. ata_coder/web/css/style.css +180 -0
  108. ata_coder/web/index.html +84 -0
  109. ata_coder/web/js/app.js +489 -0
  110. ata_coder/web/package-lock.json +25 -0
  111. ata_coder/web/package.json +10 -0
  112. ata_coder/web/tsconfig.json +13 -0
  113. ata_coder-2.4.2.dist-info/METADATA +799 -0
  114. ata_coder-2.4.2.dist-info/RECORD +118 -0
  115. ata_coder-2.4.2.dist-info/WHEEL +5 -0
  116. ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
  117. ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
  118. ata_coder-2.4.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1036 @@
1
+ """
2
+ Tool system for the ATA Coder.
3
+
4
+ Provides a set of tools the agent can use:
5
+ - read_file: Read file contents
6
+ - write_file: Create or overwrite a file
7
+ - edit_file: Precise string replacement in a file
8
+ - run_shell: Execute a shell command
9
+ - grep: Search file contents with regex
10
+ - glob: Find files matching a pattern
11
+ - list_dir: List directory contents
12
+ - web_search: Search the web (optional)
13
+ """
14
+
15
+ import asyncio
16
+ import logging
17
+ import os
18
+ import re
19
+ import shutil
20
+ import fnmatch
21
+ from pathlib import Path
22
+ from typing import Any, Callable
23
+
24
# asyncio.get_running_loop, bound once at import time. ToolExecutor.__del__
# calls it through this alias because resolving ``asyncio`` from inside
# __del__ while the interpreter shuts down can fail with
# "ImportError: sys.meta_path is None".
_asyncio_get_running_loop = asyncio.get_running_loop


from ..config import AgentConfig
from .result import ToolResult
from .web import WebToolsMixin
from .subagent import SubAgentToolsMixin

# Module logger shared by every tool handler below.
logger = logging.getLogger(name=__name__)
35
+
36
+
37
# ── Tool executor ────────────────────────────────────────────────────────────

class ToolExecutor(WebToolsMixin, SubAgentToolsMixin):
    """Dispatch tool calls by name and hold the workspace state they share."""

    # Upper bounds on result sizes for the search / listing tools.
    MAX_GREP_RESULTS = 100
    MAX_GREP_PER_FILE = 20
    MAX_GLOB_RESULTS = 200
    MAX_DIR_ENTRIES = 500

    # Directory names skipped during recursive operations.
    # NOTE(review): "*.egg-info" is glob-shaped; it only takes effect if the
    # consumer matches with fnmatch rather than plain set membership — confirm.
    SKIP_DIRS = {
        # dependency / virtual-environment trees
        "node_modules", "bower_components", "__pypackages__", "venv", ".venv",
        # version control and editor metadata
        ".git", ".idea", ".vs", ".vscode",
        # build output
        "dist", "build", "target", ".next", ".eggs", "*.egg-info", ".terraform",
        # tool caches and reports
        "__pycache__", ".pytest_cache", ".mypy_cache", ".ruff_cache", ".tox",
        "htmlcov", ".coverage",
    }
56
+
57
def __init__(self, config: AgentConfig | None = None):
    """Bind the executor to *config* (a default AgentConfig when omitted).

    The workspace root is ``config.workspace_dir`` resolved to an absolute
    path. The MCP client, sub-agent manager, edit callback and streaming
    callback start unset and are injected later through their setters.
    """
    self.config = config or AgentConfig()
    self.workspace = Path(self.config.workspace_dir).resolve()

    # Collaborators wired in after construction.
    self._mcp = None  # see set_mcp_client()
    self._sub_agent_mgr: Any = None  # see set_sub_agent_manager()
    self._edit_callback: Callable[[str, str], None] | None = None  # see on_edit()
    # (tool_name, chunk) hook for live tool output; see set_stream_callback().
    self._stream_cb: Callable[[str, str], None] | None = None

    # "Read once" cache used by read_file: resolved path -> (mtime,
    # cached_at, content). Entries are checked against the file's mtime and
    # re-read from disk once older than _FILE_CACHE_TTL seconds (measured
    # from cached_at), which catches edits that leave mtime unchanged. Dict
    # order doubles as LRU order: hits move an entry to the end, eviction
    # drops entries from the front.
    self._file_cache: dict[str, tuple[float, float, str]] = {}
    self._file_cache_max_entries = 50
    self._FILE_CACHE_TTL = 30.0  # seconds before a cache entry is revalidated
    self._cache_dir: Path | None = None  # optional on-disk mirror, see setup_file_cache()

    # Built once up front: scans dir(self) for the _tool_* handlers.
    self._handlers: dict[str, Callable] = self._build_handlers()
78
+
79
+ def set_stream_callback(self, cb: Callable[[str, str], None] | None) -> None:
80
+ """Set callback for real-time tool output streaming.
81
+
82
+ ``cb(tool_name, chunk)`` is called with incremental output chunks
83
+ during long-running tool execution (e.g. run_shell).
84
+ Set to None to disable streaming.
85
+ """
86
+ self._stream_cb = cb
87
+
88
def _build_handlers(self) -> dict[str, Callable]:
    """Return a dispatch table mapping tool name -> bound ``_tool_<name>`` method.

    Every attribute whose name starts with ``_tool_`` is fetched with
    getattr (so properties are evaluated), and only callable values are
    kept. Non-callable matches, such as plain data attributes, are skipped.
    """
    prefix = "_tool_"
    table: dict[str, Callable] = {}
    for attr_name in dir(self):
        if not attr_name.startswith(prefix):
            continue
        candidate = getattr(self, attr_name)
        if not callable(candidate):
            continue
        table[attr_name[len(prefix):]] = candidate
    return table
103
+
104
def set_sub_agent_manager(self, mgr: Any) -> None:
    """Attach the SubAgentManager used by the spawn/collect sub-agent tools."""
    self._sub_agent_mgr = mgr
107
+
108
def set_mcp_client(self, mcp: Any) -> None:
    """Attach the MCPClient that backs the mcp_search tool."""
    self._mcp = mcp
111
+
112
+ def setup_file_cache(self, cache_dir: str | Path) -> None:
113
+ """Create session cache directory. Call once before running the agent."""
114
+ self._cache_dir = Path(cache_dir)
115
+ self._cache_dir.mkdir(parents=True, exist_ok=True)
116
+
117
def clear_file_cache(self) -> None:
    """Empty the in-memory read cache and delete the on-disk mirror directory."""
    self._file_cache.clear()
    cache_dir = self._cache_dir
    if cache_dir and cache_dir.exists():
        # Best effort: a partially removed mirror is not worth failing over.
        shutil.rmtree(cache_dir, ignore_errors=True)
    self._cache_dir = None
123
+
124
def close(self) -> None:
    """Release held resources: the file caches and, when present, ``self._http``.

    Errors raised while closing the HTTP client are ignored. The attribute is
    reset to None afterwards, so calling close() repeatedly is harmless.
    """
    self.clear_file_cache()
    http = getattr(self, "_http", None)
    if http is None:
        return
    try:
        http.close()
    except Exception:
        pass
    self._http = None
133
+
134
def __del__(self) -> None:
    """Best-effort cleanup when the executor is garbage-collected.

    ``close()`` runs only when no asyncio event loop is running in the
    current thread. Inside a running loop this method does nothing and
    leaves cleanup to an explicit ``close()``.

    The loop probe uses the module-level ``_asyncio_get_running_loop``
    alias, so no import happens here; an import during interpreter shutdown
    can fail with ``ImportError: sys.meta_path is None``. Any exception
    from ``close()`` is suppressed.
    """
    try:
        _asyncio_get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    if loop_running:
        return
    try:
        self.close()
    except Exception:
        pass
154
+
155
def on_edit(self, callback: Callable[[str, str], None]) -> None:
    """Subscribe ``callback(file_path, old_content)`` to file-edit notifications.

    Only one subscriber is kept; a later call replaces the earlier one.
    """
    self._edit_callback = callback
158
+
159
def _notify_edit(self, file_path: str, old_content: str) -> None:
    """Pass an edit to the registered callback (used for UI diffs); never raises.

    Does nothing when no callback is set. If the callback raises, the error
    is logged with its traceback and swallowed.
    """
    callback = self._edit_callback
    if not callback:
        return
    try:
        callback(file_path, old_content)
    except Exception:
        logger.exception("Edit callback failed for %s", file_path)
166
+
167
def _resolve_path(self, file_path: str) -> Path:
    """Resolve *file_path* against the workspace and confine it there.

    Relative paths are taken relative to ``self.workspace``; absolute paths
    are used as given. The candidate is fully resolved (``..`` segments and
    symlinks followed) before the containment check. So ``../../etc/passwd``,
    or a symlink pointing outside the workspace, is rejected.

    Returns:
        The resolved absolute path.
    Raises:
        ValueError: if the resolved path lies outside the workspace.
    """
    candidate = Path(file_path)
    if not candidate.is_absolute():
        candidate = self.workspace / candidate
    resolved = candidate.resolve()
    # Checked with is_relative_to rather than try/relative_to: raising from
    # inside an ``except`` chained an irrelevant internal ValueError onto the
    # traceback ("During handling of the above exception ...").
    if not resolved.is_relative_to(self.workspace.resolve()):
        raise ValueError(
            f"Path traversal blocked: {file_path} → {resolved} "
            f"is outside workspace {self.workspace}"
        )
    return resolved
185
+
186
def _ensure_in_workspace(self, path: Path) -> Path:
    """Return *path* unchanged, logging a warning if it resolves outside the workspace.

    Advisory only: nothing is blocked here. The hard boundary is enforced
    by ``_resolve_path``.
    """
    if not path.resolve().is_relative_to(self.workspace):
        logger.warning("Path outside workspace: %s", path)
    return path
196
+
197
# Global output cap applied to every tool result. At roughly four characters
# per token this is ~25k tokens: generous for legitimate work, but a single
# result can no longer swallow the whole context window.
MAX_OUTPUT_CHARS = 100_000

async def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
    """Run the handler registered for *tool_name*, passing *arguments* as kwargs.

    Unknown tools and exceptions raised by a handler are returned as failed
    ToolResults rather than propagated. Output longer than MAX_OUTPUT_CHARS
    is cut to that length and annotated with a truncation note.
    """
    try:
        handler = self._handlers[tool_name]
    except KeyError:
        return ToolResult(
            success=False,
            output="",
            error=f"Unknown tool: {tool_name}",
        )

    try:
        result = await handler(**arguments)
    except Exception as exc:
        logger.exception("Tool %s failed", tool_name)
        return ToolResult(
            success=False, output="", error=f"{type(exc).__name__}: {exc}"
        )

    cap = self.MAX_OUTPUT_CHARS
    full_len = len(result.output)
    if full_len <= cap:
        return result

    result.output = result.output[:cap] + (
        f"\n\n... [truncated {full_len - cap:,} "
        f"chars — result was {full_len:,} chars total]"
    )
    logger.warning(
        "Tool %s output truncated: %d → %d chars",
        tool_name, full_len, cap,
    )
    return result
237
+
238
# ── File tools ───────────────────────────────────────────────────────────

# ── Output limits (prevent token bloat from large file reads) ──────
MAX_READ_LINES = 2000  # auto-truncate reads without an explicit limit
MAX_READ_CHARS = 80_000  # hard cap on output chars (~20k tokens)

async def _tool_read_file(
    self,
    file_path: str,
    offset: int | None = None,
    limit: int | None = None,
) -> ToolResult:
    """Read a file and return its lines prefixed with 1-based line numbers.

    Without *limit*, at most ``MAX_READ_LINES`` lines are returned. The
    formatted output is also capped at ``MAX_READ_CHARS`` characters. Use
    *offset* (1-based first line) and *limit* to page through larger files.

    Content is cached in ``self._file_cache`` (keyed by resolved path) and
    revalidated against the file's mtime and ``_FILE_CACHE_TTL``. A repeat
    request for the whole file returns a short "[cached]" note instead of
    the content, but only if an earlier call delivered the complete,
    untruncated file. Partial or truncated reads never count as "already
    in context".

    Returns:
        A ToolResult with the numbered listing. An error result is returned
        when the path does not exist, is a directory, or cannot be read.
    Raises:
        ValueError: from ``_resolve_path`` when *file_path* escapes the workspace.
    """
    import time as _time

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"File not found: {path}",
        )
    if path.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Path is a directory, not a file: {path}",
        )

    cache_key = str(path.resolve())
    current_mtime = path.stat().st_mtime
    # Resolved paths whose complete, untruncated content was already returned
    # by this tool. Created lazily so instances built without it still work.
    sent_full: set[str] | None = getattr(self, "_file_cache_sent", None)
    if sent_full is None:
        sent_full = set()
        self._file_cache_sent = sent_full

    # ── "Read once" cache: serve unchanged files from memory ──────────
    raw: str | None = None
    from_disk = False
    entry = self._file_cache.get(cache_key)
    if entry is not None:
        cached_mtime, cached_at, cached_content = entry
        still_fresh = (_time.time() - cached_at) <= self._FILE_CACHE_TTL
        if still_fresh and cached_mtime == current_mtime:
            # Move to the end (most recently used) but keep the original
            # cached_at. Refreshing it on every hit meant the TTL never
            # expired for frequently read files, so modifications that
            # don't change mtime were never picked up.
            del self._file_cache[cache_key]
            self._file_cache[cache_key] = entry
            if offset is None and limit is None and cache_key in sent_full:
                # Whole file requested again and already delivered in full:
                # don't resend the content.
                total = len(cached_content.splitlines())
                chars = len(cached_content)
                return ToolResult(
                    success=True,
                    output=(
                        f"[cached] {file_path} — {total} lines, {chars:,} chars.\n"
                        f"Already in conversation context from earlier read. "
                        f"Use offset/limit to read specific sections if needed."
                    ),
                )
            raw = cached_content

    if raw is None:
        from_disk = True
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                raw = f.read()
        except Exception as e:
            return ToolResult(
                success=False, output="", error=f"Cannot read file: {e}"
            )
        # Pop first so a refreshed entry moves to the end (plain assignment
        # to an existing key keeps its old, least-recently-used position).
        self._file_cache.pop(cache_key, None)
        self._file_cache[cache_key] = (current_mtime, _time.time(), raw)
        # LRU eviction: dicts keep insertion order and hits move entries to
        # the end, so the leading keys are the least recently used ones.
        overflow = len(self._file_cache) - self._file_cache_max_entries
        if overflow > 0:
            for stale_key in list(self._file_cache)[:overflow]:
                del self._file_cache[stale_key]
                sent_full.discard(stale_key)
        # Mirror to the on-disk cache dir if configured.
        if self._cache_dir:
            # Cross-platform safe filename: strip a Windows drive letter and
            # turn path separators into underscores.
            mirror_name = cache_key
            if len(mirror_name) >= 2 and mirror_name[1] == ":":
                mirror_name = mirror_name[2:]  # strip "C:"
            safe_name = mirror_name.lstrip("\\/").replace("\\", "_").replace("/", "_")
            try:
                (self._cache_dir / safe_name).write_text(raw, encoding="utf-8")
            except Exception:
                # Best effort: the in-memory cache already holds the content.
                logger.debug("Could not mirror %s to cache dir", path, exc_info=True)

    lines = raw.splitlines(keepends=True)
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"
    total_lines = len(lines)
    user_specified_range = limit is not None

    # Default cap: when the caller didn't ask for a specific range,
    # truncate to MAX_READ_LINES so one file doesn't blow the context.
    effective_limit = limit if user_specified_range else min(total_lines, self.MAX_READ_LINES)
    start = max(0, (offset or 1) - 1)
    # Never let end fall below start: a negative limit used to produce
    # lines[0:-1] (everything but the last line) instead of nothing.
    end = max(start, min(total_lines, start + effective_limit))

    selected = lines[start:end]
    was_truncated = not user_specified_range and end < total_lines

    # Format with line numbers, stopping at the hard character cap.
    output_lines: list[str] = []
    char_truncated = False
    chars = 0
    for lineno, line in enumerate(selected, start=start + 1):
        formatted = f"{lineno:6d}\t{line.rstrip()}"
        chars += len(formatted) + 1  # +1 for the joining newline
        if chars > self.MAX_READ_CHARS:
            output_lines.append(f"... (output truncated at {self.MAX_READ_CHARS} chars, use offset/limit to read more)")
            char_truncated = True
            break
        output_lines.append(formatted)

    shown_lines = len(output_lines) - (1 if char_truncated else 0)
    # Record whether the caller has now seen the entire file. A fresh disk
    # read that did not show everything clears any earlier "seen" mark,
    # because the content may have changed since then.
    if start == 0 and shown_lines == total_lines:
        sent_full.add(cache_key)
    elif from_disk:
        sent_full.discard(cache_key)

    header = f"File: {path} (lines {start+1}-{start+shown_lines} of {total_lines}"
    if was_truncated or char_truncated:
        header += ", truncated — use offset/limit to page"
    header += ")\n"
    return ToolResult(success=True, output=header + "\n".join(output_lines))
379
+
380
async def _tool_write_file(
    self, file_path: str, content: str
) -> ToolResult:
    """Create or overwrite *file_path* with *content*, encoded as UTF-8.

    Missing parent directories are created. When an existing non-empty file
    is overwritten, its previous text goes to the edit callback so the UI
    can show a diff. Any cached copy in ``self._file_cache`` is dropped so
    the next read_file sees the new text.

    Returns:
        A ToolResult reporting size in bytes and line count, or an error
        result if the file cannot be written.
    Raises:
        ValueError: from ``_resolve_path`` when the path escapes the workspace.
    """
    path = self._resolve_path(file_path)
    self._ensure_in_workspace(path)

    # Previous text, used only for the diff; stays empty for new files.
    old_content = ""
    if path.exists():
        try:
            with open(path, "r", encoding="utf-8") as f:
                old_content = f.read()
        except Exception:
            # e.g. a binary / non-UTF-8 file: overwrite anyway, just no diff.
            old_content = ""

    # Invalidate before writing. On filesystems with coarse mtime resolution
    # the rewrite may not change st_mtime, and read_file would otherwise keep
    # serving (or reporting as "[cached]") the old text.
    self._file_cache.pop(str(path.resolve()), None)

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        size = path.stat().st_size
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot write file: {e}"
        )

    # Notify UI for diff display if overwriting.
    if old_content:
        self._notify_edit(str(path), old_content)

    # splitlines counts a final line that lacks "\n"; count("\n") did not.
    line_count = len(content.splitlines())
    return ToolResult(
        success=True,
        output=f"File written: {path} ({size} bytes, {line_count} lines)",
    )
414
+
415
async def _tool_edit_file(
    self, file_path: str, old_string: str, new_string: str
) -> ToolResult:
    """Replace the single occurrence of *old_string* in a file with *new_string*.

    *old_string* must occur exactly once verbatim. Multiple matches are
    rejected, so an edit never lands somewhere unintended; this check now
    runs before the CST path, which previously bypassed it. If there is no
    verbatim match in a ``.py``/``.pyi`` file, ``_cst_edit`` is tried. It
    returns None when libcst is missing or nothing matches.

    On success the previous content is sent to the edit callback (for diff
    display) and any cached copy in ``self._file_cache`` is dropped.

    Returns:
        A ToolResult. Failures (identical strings, missing file, no match or
        an ambiguous match, I/O errors) are reported in ``error``.
    Raises:
        ValueError: from ``_resolve_path`` when the path escapes the workspace.
    """
    if old_string == new_string:
        return ToolResult(
            success=False,
            output="",
            error="old_string and new_string are identical",
        )

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"File not found: {path}",
        )

    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot read file: {e}"
        )

    count = content.count(old_string)
    if count > 1:
        return ToolResult(
            success=False,
            output="",
            error=f"old_string found {count} times in file. Must be unique. Use a larger string with more surrounding context.",
        )

    new_content: str | None = None
    label = "File edited"
    if count == 1:
        # Exact, unique match: plain text replacement is precise.
        new_content = content.replace(old_string, new_string, 1)
    elif path.suffix in (".py", ".pyi"):
        # No verbatim match: let the CST matcher try, which can find
        # statements regardless of their indentation.
        new_content = self._cst_edit(content, old_string, new_string, path)
        label = "File edited (AST)"

    if new_content is None:
        return ToolResult(
            success=False,
            output="",
            error="old_string not found in file. Check whitespace/indentation.",
        )

    # Invalidate before writing so read_file cannot serve the pre-edit text
    # if the filesystem's mtime resolution hides the change.
    self._file_cache.pop(str(path.resolve()), None)
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(new_content)
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot write file: {e}"
        )

    self._notify_edit(str(path), content)
    return ToolResult(
        success=True,
        output=f"{label}: {path} (1 replacement)",
    )
495
+
496
@staticmethod
def _cst_edit(content: str, old_str: str, new_str: str, path: Any) -> str | None:
    """Try to apply a Python edit by matching libcst nodes instead of raw text.

    *old_str* and *new_str* must each parse as exactly one top-level
    statement. Two strategies are tried on the parsed *content*:

    1. ``_NodeReplacer`` swaps a node deep-equal to *old_str*'s statement
       for *new_str*'s statement.
    2. When both are one-statement simple lines (``x = 1``, ``foo()`` ...),
       ``_StatementReplacer`` compares only the inner small statement.
       Blank lines or comments above the target line, which make the
       whole-line comparison of strategy 1 fail, therefore don't matter.

    *path* is unused; it is kept for call-site compatibility.

    Returns:
        The edited source, or None when libcst is not installed, parsing
        fails, or neither strategy matches (callers fall back to text).
    """
    try:
        import libcst as cst
    except ImportError:
        return None  # libcst not installed — use text fallback

    try:
        tree = cst.parse_module(content)
        # Parse each fragment once; both strategies share the nodes.
        old_body = cst.parse_module(old_str + "\n").body
        new_body = cst.parse_module(new_str + "\n").body
    except Exception:
        return None  # Can't parse — fall back to text
    if len(old_body) != 1 or len(new_body) != 1:
        return None
    old_node, new_node = old_body[0], new_body[0]

    # Strategy 1: whole-node replacement.
    try:
        replacer = _NodeReplacer(old_node, new_node)
        new_tree = tree.visit(replacer)
        if replacer.found:
            return new_tree.code
    except Exception:
        pass

    # Strategy 2: unwrap both sides to their single small statement.
    # _StatementReplacer compares against the statement *inside* each line,
    # so passing whole SimpleStatementLine nodes could never match.
    if (
        isinstance(old_node, cst.SimpleStatementLine)
        and isinstance(new_node, cst.SimpleStatementLine)
        and len(old_node.body) == 1
        and len(new_node.body) == 1
    ):
        try:
            stmt_replacer = _StatementReplacer(old_node.body[0], new_node.body[0])
            new_tree = tree.visit(stmt_replacer)
            if stmt_replacer.found:
                return new_tree.code
        except Exception:
            pass

    return None  # All CST strategies failed — fall back to text
551
+
552
+
553
# ── CST Transformers (libcst is optional) ────────────────────────────────
# NOTE(review): _cst_edit looks these classes up as plain global names, so
# this block must sit at module scope — confirm against the real file.

try:
    import libcst as _cst_lib

    class _NodeReplacer(_cst_lib.CSTTransformer):
        """Replace the FIRST node deep-equal to *old_node* with *new_node*.

        Only that one node is replaced. The previous version swapped every
        later deep-equal node too (its on_leave matched by equality once
        ``found`` was set), silently editing several places while callers
        reported "1 replacement".
        """

        def __init__(self, old_node: _cst_lib.CSTNode, new_node: _cst_lib.CSTNode):
            super().__init__()
            self.old_node = old_node
            self.new_node = new_node
            self.found = False
            # Identity of the matched original node; on_leave receives the
            # same object that on_visit saw.
            self._target: _cst_lib.CSTNode | None = None

        def on_visit(self, node: _cst_lib.CSTNode) -> bool:
            # Cheap flag check first, so deep_equals stops running after a match.
            if not self.found and node.deep_equals(self.old_node):
                self.found = True
                self._target = node
                return False  # don't descend into the node being replaced
            return True

        def on_leave(self, original_node: _cst_lib.CSTNode, updated_node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
            if self._target is not None and original_node is self._target:
                return self.new_node
            return updated_node

    class _StatementReplacer(_cst_lib.CSTTransformer):
        """Replace the first one-statement line whose statement deep-equals *old_stmt*.

        Accepts small statements (Assign, Expr, ...) or whole
        ``SimpleStatementLine`` nodes holding exactly one statement. The
        latter are unwrapped, so the comparison ignores the line's leading
        blank lines and comments. The swap happens only when the new
        statement is itself a small statement, because a compound statement
        cannot live inside a SimpleStatementLine.
        """

        def __init__(self, old_stmt: _cst_lib.CSTNode, new_stmt: _cst_lib.CSTNode):
            super().__init__()
            self.old_stmt = self._unwrap(old_stmt)
            self.new_stmt = self._unwrap(new_stmt)
            self.found = False

        @staticmethod
        def _unwrap(node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
            """Return the single inner statement of a one-statement line, else *node*."""
            if isinstance(node, _cst_lib.SimpleStatementLine) and len(node.body) == 1:
                return node.body[0]
            return node

        def leave_SimpleStatementLine(
            self, original_node: _cst_lib.SimpleStatementLine, updated_node: _cst_lib.SimpleStatementLine
        ):
            if (
                not self.found
                and len(updated_node.body) == 1
                and isinstance(self.new_stmt, _cst_lib.BaseSmallStatement)
                and updated_node.body[0].deep_equals(self.old_stmt)
            ):
                self.found = True
                return updated_node.with_changes(body=[self.new_stmt])
            return updated_node

except ImportError:
    # libcst not installed — _cst_edit never reaches these (it returns None
    # first), so Python edits fall back to text replacement.
    _NodeReplacer = None  # type: ignore
    _StatementReplacer = None  # type: ignore
598
+
599
+ # ── Rename symbol (AST-aware) ────────────────────────────────────────────
600
+
601
async def _tool_rename_symbol(
    self, file_path: str, old_name: str, new_name: str,
    symbol_type: str = "variable"
) -> ToolResult:
    """Rename a symbol inside one Python file using libcst.

    *symbol_type* selects the transformer:
    - ``"function"`` → ``_FunctionRenamer``
    - ``"class"`` → ``_ClassRenamer``
    - any other value → ``_VariableRenamer``

    These classes are defined elsewhere in the module. Which nodes they
    rewrite is up to them; the intent is to leave strings, comments and
    same-named imports alone, but confirm that against the renamers.

    Returns:
        An error ToolResult in these cases:
        - the names are identical, or not valid identifiers;
        - the file is missing or is not ``.py``/``.pyi``;
        - libcst is not installed, or the file does not parse;
        - nothing was renamed.
        On success the old text goes to the edit callback, and any cached
        copy in ``self._file_cache`` is dropped.
    Raises:
        ValueError: from ``_resolve_path`` when the path escapes the workspace.
    """
    if old_name == new_name:
        return ToolResult(success=False, output="", error="Names are identical.")
    if not old_name.isidentifier() or not new_name.isidentifier():
        return ToolResult(success=False, output="", error="Names must be valid Python identifiers.")

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(success=False, output="", error=f"File not found: {path}")
    if path.suffix not in (".py", ".pyi"):
        return ToolResult(success=False, output="", error="rename_symbol only works with Python files.")

    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot read file: {e}")

    try:
        import libcst as cst
    except ImportError:
        return ToolResult(success=False, output="", error="libcst not installed. Run: pip install libcst")

    try:
        tree = cst.parse_module(content)
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot parse Python file: {e}")

    # Choose the right renamer for the symbol type.
    if symbol_type == "function":
        renamer = _FunctionRenamer(old_name, new_name)
    elif symbol_type == "class":
        renamer = _ClassRenamer(old_name, new_name)
    else:
        renamer = _VariableRenamer(old_name, new_name)

    new_tree = tree.visit(renamer)
    if not renamer.changes:
        return ToolResult(success=False, output="", error=f"Symbol '{old_name}' not found in file.")

    new_content = new_tree.code

    # Invalidate before writing so read_file cannot serve the pre-rename text
    # when the filesystem's mtime resolution hides the change.
    self._file_cache.pop(str(path.resolve()), None)
    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(new_content)
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot write file: {e}")

    self._notify_edit(str(path), content)
    return ToolResult(
        success=True,
        output=f"Renamed {renamer.changes} occurrence(s) of '{old_name}' → '{new_name}' in {path}",
    )
660
+
661
+ # ── Shell tool ───────────────────────────────────────────────────────────
662
+
663
async def _tool_run_shell(
    self, command: str, timeout: int = 120
) -> ToolResult:
    """Execute a shell command in the workspace and collect its output.

    Args:
        command: Command line passed to ``asyncio.create_subprocess_shell``;
            it runs with the workspace as working directory and stdin
            attached to DEVNULL.
        timeout: Seconds to wait for the process to exit before it is killed.

    Returns:
        A ``ToolResult``. On completion ``output`` holds stdout, then any
        stderr under a ``[stderr]`` marker, an ``[exit code: N]`` line for a
        non-zero exit, and a truncation note when the combined output
        exceeded ``MAX_OUTPUT_BYTES``; ``success`` is True only for exit
        code 0. A blocked pattern, a timeout or a spawn failure yields
        ``success=False`` with ``error`` set.
    """
    import codecs

    # Hard block: reject the command if any configured pattern occurs in it
    # (case-insensitive substring match) — unforgivable operations.
    cmd_lower = command.lower().strip()
    for blocked in self.config.blocked_commands:
        if blocked.lower() in cmd_lower:
            return ToolResult(
                success=False,
                output="",
                error=f"Blocked command pattern detected: {blocked}",
            )

    # Soft check only: a first word outside the whitelist is logged, not
    # rejected (safety_guard is described as the real gate).
    words = command.split()
    first_word = words[0] if words else ""
    if first_word and first_word not in self.config.allowed_commands:
        logger.debug("Command '%s' not in allowed_commands whitelist, proceeding", first_word)

    MAX_OUTPUT_BYTES = 500_000  # cap on stdout+stderr COMBINED, to bound memory
    kept_bytes = 0  # bytes retained so far across both streams
    truncated = False  # set as soon as any output had to be dropped

    def _emit(text: str) -> None:
        """Forward decoded output to the UI stream callback, best-effort."""
        if not text:
            return
        try:
            self._stream_cb("run_shell", text)
        except Exception:
            # A failing UI callback must never abort the command itself.
            logger.debug("run_shell stream callback failed", exc_info=True)

    async def _read_stream(stream, chunks: list[bytes]) -> None:
        """Drain *stream* into *chunks* until EOF, sharing one byte budget."""
        nonlocal kept_bytes, truncated
        # Incremental decoding keeps a multi-byte UTF-8 character that is
        # split across two reads intact in the real-time UI stream.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        while True:
            try:
                chunk = await stream.read(65536)
            except Exception:
                return  # pipe closed or broken (e.g. after kill): stop reading
            if chunk:
                room = MAX_OUTPUT_BYTES - kept_bytes
                if len(chunk) <= room:
                    chunks.append(chunk)
                    kept_bytes += len(chunk)
                else:
                    # Keep whatever still fits instead of dropping the whole
                    # chunk, and remember that output was cut.
                    if room > 0:
                        chunks.append(chunk[:room])
                        kept_bytes += room
                    truncated = True
            if self._stream_cb:
                # The UI sees every byte, even past the cap; final=True at EOF
                # flushes any dangling partial character.
                _emit(decoder.decode(chunk, final=not chunk))
            if not chunk:
                return  # EOF

    stdout_chunks: list[bytes] = []
    stderr_chunks: list[bytes] = []
    readers: list[asyncio.Task] = []
    proc = None
    try:
        # DEVNULL on stdin prevents hangs when child processes try to read
        # from inherited stdin (common cause of "exec task freeze").
        proc = await asyncio.create_subprocess_shell(
            command,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.workspace),
        )

        # Drain both pipes concurrently so neither fills up and blocks the
        # child (pipe-buffer deadlock).
        readers = [
            asyncio.create_task(_read_stream(proc.stdout, stdout_chunks)),
            asyncio.create_task(_read_stream(proc.stderr, stderr_chunks)),
        ]

        # Wait for the process to exit (with timeout), NOT for the pipes to
        # close: child processes inheriting the pipes can keep them open.
        try:
            await asyncio.wait_for(proc.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            # Kill first (lets the pipes reach EOF), then drop the readers.
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            for task in readers:
                task.cancel()
            try:
                await proc.wait()
            except Exception:
                pass
            return ToolResult(
                success=False, output="",
                error=f"Command timed out after {timeout}s",
            )

        # Process exited: give each reader a short grace period for data
        # still buffered in the pipe, then give up on it.
        for task in readers:
            try:
                await asyncio.wait_for(task, timeout=5)
            except asyncio.TimeoutError:
                task.cancel()

        stdout_bytes = b"".join(stdout_chunks)
        stderr_bytes = b"".join(stderr_chunks)
        output = stdout_bytes.decode("utf-8", errors="replace")
        if stderr_bytes:
            output += f"\n[stderr]\n{stderr_bytes.decode('utf-8', errors='replace')}"
        returncode = proc.returncode or 0
        if returncode != 0:
            output += f"\n[exit code: {returncode}]"
        if truncated:
            output += "\n[output truncated — exceeded 500KB limit]"
        return ToolResult(
            success=returncode == 0,
            output=output.strip() or "(no output)",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Command failed: {e}"
        )
    finally:
        # Runs on every exit path, including cancellation of this coroutine
        # (CancelledError is a BaseException and is not caught above).
        for task in readers:
            if not task.done():
                task.cancel()
        if proc is not None:
            if proc.returncode is None:
                # Still running (we were cancelled or hit an unexpected
                # error): kill it so it does not outlive the tool call.
                try:
                    proc.kill()
                except OSError:
                    pass
            # Explicitly close pipes to prevent "I/O operation on closed pipe"
            # during BaseSubprocessTransport.__del__ at GC time.
            # NOTE(review): asyncio.StreamReader has no close() method, so for
            # stdout/stderr this likely raises AttributeError (swallowed) and
            # stdin is None with DEVNULL — confirm the intended cleanup.
            for pipe in (proc.stdin, proc.stdout, proc.stderr):
                if pipe is not None:
                    try:
                        pipe.close()
                    except Exception:
                        pass
785
+
786
+ # ── Search tools ─────────────────────────────────────────────────────────
787
+
788
async def _tool_grep(
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
    case_sensitive: bool = False,
) -> ToolResult:
    """Search file contents with a regular expression.

    Args:
        pattern: Regular expression applied to each line with ``re.search``.
        path: Directory to search recursively, or a single file to search;
            defaults to the workspace root (".").
        glob: Optional ``fnmatch`` pattern matched against each file's base
            name only.
        case_sensitive: When False (default) the regex is compiled with
            ``re.IGNORECASE``.

    Returns:
        ToolResult listing matching lines grouped under workspace-relative
        paths, capped per file (``MAX_GREP_PER_FILE``) and overall
        (``MAX_GREP_RESULTS`` output lines); ``success=False`` when the path
        does not exist or the regex is invalid.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        regex = re.compile(pattern, flags)
    except re.error as e:
        return ToolResult(
            success=False, output="", error=f"Invalid regex: {e}"
        )

    def _iter_files():
        """Yield ``(full_path, base_name)`` for every file to be searched."""
        if search_dir.is_file():
            # os.walk() yields nothing for a file path, so search it directly.
            yield str(search_dir), search_dir.name
            return
        for root, dirs, files in os.walk(search_dir):
            # Prune hidden and configured skip directories in place.
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".") and d not in self.SKIP_DIRS
            ]
            for fname in files:
                yield os.path.join(root, fname), fname

    def _scan_file(full_path: str):
        """Return ``(formatted match lines, match count)``, or None if unreadable."""
        lines: list[str] = []
        hits = 0
        try:
            with open(full_path, "r", encoding="utf-8", errors="replace") as fh:
                # Iterate lazily: large files are not loaded whole, and
                # reading stops as soon as the per-file cap is exceeded.
                for line_no, line_text in enumerate(fh, 1):
                    if not regex.search(line_text):
                        continue
                    if hits >= self.MAX_GREP_PER_FILE:
                        # Only flag truncation when a match is really dropped.
                        lines.append(" ... (truncated)")
                        break
                    lines.append(f" {line_no}: {line_text.rstrip()[:200]}")
                    hits += 1
        except OSError:
            return None
        return lines, hits

    # Run filesystem walk + file reads in thread pool to avoid
    # blocking the asyncio event loop on large codebases.
    def _do_grep():
        results: list[str] = []
        total_matches = 0
        for full_path, fname in _iter_files():
            if glob and not fnmatch.fnmatch(fname, glob):
                continue
            scanned = _scan_file(full_path)
            if scanned is None:
                continue  # unreadable file: skip it
            matches_in_file, hits = scanned
            if matches_in_file:
                try:
                    rel_path = os.path.relpath(full_path, self.workspace)
                except ValueError:
                    # relpath raises when the paths are on different drives (Windows).
                    rel_path = full_path
                results.append(f"{rel_path}:")
                results.extend(matches_in_file)
                total_matches += hits
            if len(results) >= self.MAX_GREP_RESULTS:
                results.append("... (result limit reached)")
                break
        return results, total_matches

    results, total_matches = await self._run_in_thread(_do_grep)

    if not results:
        return ToolResult(
            success=True,
            output=f"No matches found for pattern: {pattern}",
        )
    return ToolResult(
        success=True,
        output=f"Found {total_matches} matches:\n\n" + "\n".join(results),
    )
873
+
874
async def _tool_glob(
    self,
    pattern: str,
    path: str | None = None,
) -> ToolResult:
    """Find files (and directories) matching a glob pattern.

    Args:
        pattern: Glob pattern relative to the search directory. A pattern
            that does not contain ``**`` is made recursive by prefixing
            ``**/``.
        path: Directory to search from; defaults to the workspace root (".").

    Returns:
        ToolResult listing up to 200 sorted matches as workspace-relative
        paths with sizes (0 for non-files), plus a count of the remainder;
        ``success=False`` if the search directory does not exist.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    import glob as glob_mod

    max_listed = 200

    # Auto-add **/ prefix for recursive matching if not already present
    if "**" not in pattern:
        pattern = f"**/{pattern}"
    # Escape the directory prefix so glob metacharacters in the path itself
    # ("[", "*", "?") match literally; only *pattern* is treated as a glob.
    search_pattern = os.path.join(glob_mod.escape(str(search_dir)), pattern)

    def _size_of(p: str) -> int:
        """Size in bytes for regular files, 0 otherwise or if stat fails."""
        try:
            return os.path.getsize(p) if os.path.isfile(p) else 0
        except OSError:
            return 0  # vanished or unreadable between glob and stat

    def _do_glob():
        # Run the filesystem walk in a worker thread (as _tool_grep does) so
        # large trees do not block the event loop.
        found = sorted(glob_mod.glob(search_pattern, recursive=True))
        lines: list[str] = []
        for m in found[:max_listed]:
            try:
                rel = os.path.relpath(m, self.workspace)
            except ValueError:
                rel = m
            lines.append(f" {rel} ({_size_of(m):,} bytes)")
        return len(found), lines

    total, output_lines = await self._run_in_thread(_do_glob)

    if not total:
        return ToolResult(
            success=True, output=f"No files matching: {pattern}"
        )

    if total > max_listed:
        output_lines.append(
            f" ... and {total - max_listed} more files"
        )

    return ToolResult(
        success=True,
        output=f"Found {total} files matching '{pattern}':\n"
        + "\n".join(output_lines),
    )
922
+
923
async def _tool_list_dir(
    self,
    path: str | None = None,
    recursive: bool = False,
) -> ToolResult:
    """List a directory's contents with file sizes.

    Args:
        path: Directory to list, resolved via ``_resolve_path``; defaults to
            the workspace root (".").
        recursive: When True, walk the whole tree (pruning hidden and
            ``SKIP_DIRS`` directories) and indent entries by depth; otherwise
            list only the immediate children, directories first.

    Returns:
        ToolResult with a ``Directory:`` header and at most 500 entries plus
        a count of omitted ones; ``success=False`` if *path* is missing, is
        not a directory, or cannot be read.
    """
    target = self._resolve_path(path or ".")
    if not target.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {target}",
        )
    if not target.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Not a directory: {target}",
        )

    def _file_size(p) -> int | None:
        """Size of *p* in bytes, or None if it cannot be stat'ed."""
        try:
            return os.path.getsize(p)
        except OSError:
            return None  # e.g. broken symlink, permission denied, vanished

    output_lines = [f"Directory: {target}"]
    entries: list[str] = []

    if recursive:
        for root, dirs, files in os.walk(target):
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".") and d not in self.SKIP_DIRS
            ]
            # Depth below target from a real relative path; a textual
            # replace() miscounts when the target path text recurs deeper.
            rel_root = os.path.relpath(root, target)
            level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
            indent = " " * level
            if level > 0:
                entries.append(f"{indent}{os.path.basename(root)}/")
            for f in sorted(files):
                size = _file_size(os.path.join(root, f))
                if size is None:
                    entries.append(f"{indent} {f}")
                else:
                    entries.append(f"{indent} {f} ({size:,}B)")
    else:
        try:
            items = sorted(target.iterdir(), key=lambda x: (not x.is_dir(), x.name))
        except OSError as e:
            return ToolResult(
                success=False,
                output="",
                error=f"Cannot list directory: {e}",
            )
        for item in items:
            suffix = "/" if item.is_dir() else ""
            size = ""
            if item.is_file():
                nbytes = _file_size(item)
                if nbytes is not None:
                    size = f" ({nbytes:,}B)"
            entries.append(f" {item.name}{suffix}{size}")

    output_lines.extend(entries[:500])
    if len(entries) > 500:
        output_lines.append(f" ... and {len(entries) - 500} more entries")

    return ToolResult(success=True, output="\n".join(output_lines))
974
+
975
+
976
+ # ── Factory ──────────────────────────────────────────────────────────────────
977
+
978
def create_tool_executor(workspace_dir: str | None = None) -> ToolExecutor:
    """Build a ``ToolExecutor``, optionally rooted at *workspace_dir*.

    A truthy *workspace_dir* is wrapped in a new ``AgentConfig`` that is
    handed to the executor; otherwise ``ToolExecutor`` is constructed with
    no arguments.
    """
    from ..config import AgentConfig

    if not workspace_dir:
        return ToolExecutor()
    config = AgentConfig(workspace_dir=workspace_dir)
    return ToolExecutor(config)
985
+
986
+
987
+ # ── CST Renamers (for rename_symbol tool) ───────────────────────────────────
988
+ # These use libcst to safely rename symbols without touching strings or comments.
989
+
990
+ class _SymbolRenamer:
991
+ """Base class for CST symbol renamers — shared leave_Name/leave_Call logic."""
992
+
993
+ def __init__(self, old_name: str, new_name: str):
994
+ self.old = old_name
995
+ self.new = new_name
996
+ self.changes = 0
997
+
998
+ def leave_Name(self, original_node, updated_node):
999
+ import libcst as cst
1000
+ if isinstance(updated_node, cst.Name) and updated_node.value == self.old:
1001
+ self.changes += 1
1002
+ return updated_node.with_changes(value=self.new)
1003
+ return updated_node
1004
+
1005
+ def leave_Call(self, original_node, updated_node):
1006
+ import libcst as cst
1007
+ if isinstance(updated_node.func, cst.Name) and updated_node.func.value == self.old:
1008
+ self.changes += 1
1009
+ return updated_node.with_changes(func=cst.Name(value=self.new))
1010
+ return updated_node
1011
+
1012
+
1013
class _VariableRenamer(_SymbolRenamer):
    """Renamer for plain variable references.

    Relies solely on the inherited ``leave_Name``/``leave_Call`` hooks, which
    operate on ``Name`` nodes — string literals and comments are untouched.
    """
1015
+
1016
+
1017
class _FunctionRenamer(_SymbolRenamer):
    """Renamer for functions/methods: their ``def`` names plus the call sites
    handled by the inherited hooks."""

    def leave_FunctionDef(self, original_node, updated_node):
        """Rename a ``def`` whose name equals the old symbol."""
        import libcst as cst

        # NOTE(review): leave_Name may already have rewritten the def's name
        # node by the time this runs, in which case this branch won't fire —
        # confirm.
        if updated_node.name.value != self.old:
            return updated_node
        self.changes += 1
        return updated_node.with_changes(name=cst.Name(value=self.new))
1026
+
1027
+
1028
+ class _ClassRenamer(_SymbolRenamer):
1029
+ """Rename class definitions and constructor calls."""
1030
+
1031
+ def leave_ClassDef(self, original_node, updated_node):
1032
+ import libcst as cst
1033
+ if updated_node.name.value == self.old:
1034
+ self.changes += 1
1035
+ return updated_node.with_changes(name=cst.Name(value=self.new))
1036
+ return updated_node