ata-coder 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. ata_coder/__init__.py +1 -0
  2. ata_coder/agent.py +874 -0
  3. ata_coder/agent_compact.py +190 -0
  4. ata_coder/agent_controller.py +218 -0
  5. ata_coder/agent_extension.py +69 -0
  6. ata_coder/agent_routing.py +105 -0
  7. ata_coder/agent_subsystems.py +72 -0
  8. ata_coder/agent_tools.py +318 -0
  9. ata_coder/agent_undo.py +63 -0
  10. ata_coder/anthropic_client.py +465 -0
  11. ata_coder/change_tracker.py +368 -0
  12. ata_coder/clawd_integration.py +574 -0
  13. ata_coder/commands/__init__.py +128 -0
  14. ata_coder/commands/_core.py +184 -0
  15. ata_coder/commands/_safety.py +95 -0
  16. ata_coder/commands/_settings.py +241 -0
  17. ata_coder/commands/_workflow.py +451 -0
  18. ata_coder/commands.py +974 -0
  19. ata_coder/config.py +257 -0
  20. ata_coder/core/__init__.py +35 -0
  21. ata_coder/core/events.py +73 -0
  22. ata_coder/core/queue.py +85 -0
  23. ata_coder/core/state.py +17 -0
  24. ata_coder/event_queue.py +5 -0
  25. ata_coder/extension.py +654 -0
  26. ata_coder/extensions/__init__.py +1 -0
  27. ata_coder/extensions/hello_skill.py +47 -0
  28. ata_coder/fool_proof.py +295 -0
  29. ata_coder/git_workflow.py +371 -0
  30. ata_coder/gui.py +511 -0
  31. ata_coder/llm_client.py +543 -0
  32. ata_coder/main.py +814 -0
  33. ata_coder/mcp_client.py +1095 -0
  34. ata_coder/memory.py +539 -0
  35. ata_coder/model_registry.py +134 -0
  36. ata_coder/model_router.py +105 -0
  37. ata_coder/permissions.py +274 -0
  38. ata_coder/privilege.py +464 -0
  39. ata_coder/project.py +273 -0
  40. ata_coder/prompt_template.py +423 -0
  41. ata_coder/prompts/auto-mode.md +7 -0
  42. ata_coder/prompts/coding-rules.md +40 -0
  43. ata_coder/prompts/execution-guardrails.md +14 -0
  44. ata_coder/prompts/memory-system.md +24 -0
  45. ata_coder/prompts/output-style.md +23 -0
  46. ata_coder/prompts/safety.md +17 -0
  47. ata_coder/prompts/slash-commands.md +24 -0
  48. ata_coder/prompts/sub-agents.md +38 -0
  49. ata_coder/prompts/system-reminders.md +17 -0
  50. ata_coder/prompts/system.md +105 -0
  51. ata_coder/prompts/tool-policy.md +46 -0
  52. ata_coder/repl_theme.py +99 -0
  53. ata_coder/repl_tracker.py +89 -0
  54. ata_coder/repl_ui.py +1214 -0
  55. ata_coder/safety_guard.py +434 -0
  56. ata_coder/self_correct.py +346 -0
  57. ata_coder/server.py +882 -0
  58. ata_coder/server_session.py +159 -0
  59. ata_coder/server_shell.py +129 -0
  60. ata_coder/session.py +431 -0
  61. ata_coder/settings.py +439 -0
  62. ata_coder/setup_wizard.py +136 -0
  63. ata_coder/skill_extension.py +92 -0
  64. ata_coder/skills/architect/SKILL.md +42 -0
  65. ata_coder/skills/code-reviewer/SKILL.md +37 -0
  66. ata_coder/skills/codecraft/SKILL.md +452 -0
  67. ata_coder/skills/debugger/SKILL.md +45 -0
  68. ata_coder/skills/doc-writer/SKILL.md +36 -0
  69. ata_coder/skills/general-coder/SKILL.md +76 -0
  70. ata_coder/skills/math-calculator/README.md +40 -0
  71. ata_coder/skills/math-calculator/SKILL.md +59 -0
  72. ata_coder/skills/math-calculator/handler.py +103 -0
  73. ata_coder/skills/math-calculator/prompts/system.md +8 -0
  74. ata_coder/skills/math-calculator/requirements.txt +2 -0
  75. ata_coder/skills/math-calculator/resources/constants.json +8 -0
  76. ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
  77. ata_coder/skills/security-auditor/SKILL.md +40 -0
  78. ata_coder/skills/test-writer/SKILL.md +36 -0
  79. ata_coder/skills/weather-skill/README.md +45 -0
  80. ata_coder/skills/weather-skill/handler.py +76 -0
  81. ata_coder/skills/weather-skill/manifest.json +48 -0
  82. ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
  83. ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
  84. ata_coder/skills/weather-skill/requirements.txt +1 -0
  85. ata_coder/skills/weather-skill/resources/city_list.json +17 -0
  86. ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
  87. ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
  88. ata_coder/skills/weather-skill/weather_utils.py +50 -0
  89. ata_coder/skills.py +1014 -0
  90. ata_coder/sub_agent.py +273 -0
  91. ata_coder/sub_agent_manager.py +203 -0
  92. ata_coder/system_prompt_builder.py +146 -0
  93. ata_coder/task_planner.py +391 -0
  94. ata_coder/terminal.py +318 -0
  95. ata_coder/test_runner.py +219 -0
  96. ata_coder/thread_supervisor.py +195 -0
  97. ata_coder/tool_defs.py +335 -0
  98. ata_coder/tools/__init__.py +11 -0
  99. ata_coder/tools/definitions.py +335 -0
  100. ata_coder/tools/executor.py +1036 -0
  101. ata_coder/tools/result.py +26 -0
  102. ata_coder/tools/subagent.py +332 -0
  103. ata_coder/tools/web.py +361 -0
  104. ata_coder/tools.py +1576 -0
  105. ata_coder/types.py +92 -0
  106. ata_coder/utils.py +113 -0
  107. ata_coder/web/css/style.css +180 -0
  108. ata_coder/web/index.html +84 -0
  109. ata_coder/web/js/app.js +489 -0
  110. ata_coder/web/package-lock.json +25 -0
  111. ata_coder/web/package.json +10 -0
  112. ata_coder/web/tsconfig.json +13 -0
  113. ata_coder-2.4.2.dist-info/METADATA +799 -0
  114. ata_coder-2.4.2.dist-info/RECORD +118 -0
  115. ata_coder-2.4.2.dist-info/WHEEL +5 -0
  116. ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
  117. ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
  118. ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/tools.py ADDED
@@ -0,0 +1,1576 @@
1
+ """
2
+ Tool system for the ATA Coder.
3
+
4
+ Provides a set of tools the agent can use:
5
+ - read_file: Read file contents
6
+ - write_file: Create or overwrite a file
7
+ - edit_file: Precise string replacement in a file
8
+ - run_shell: Execute a shell command
9
+ - grep: Search file contents with regex
10
+ - glob: Find files matching a pattern
11
+ - list_dir: List directory contents
12
+ - web_search: Search the web (optional)
13
+ """
14
+
15
+ import asyncio
16
+ import html
17
+ import html.parser
18
+ import logging
19
+ import os
20
+ import re
21
+ import shutil
22
+ import subprocess
23
+ import fnmatch
24
+ import time
25
+ import urllib.parse
26
+ from pathlib import Path
27
+ from typing import Any, Callable
28
+
29
+ import httpx
30
+
31
+ from .config import AgentConfig
32
+ from .clawd_integration import get_clawd
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # ── Tool result type ─────────────────────────────────────────────────────────
38
+
39
class ToolResult:
    """Outcome of a single tool invocation.

    Instances carry three plain fields: ``success`` (whether the tool
    succeeded), ``output`` (text produced by the tool) and ``error``
    (failure message, empty by default).
    """

    def __init__(self, success: bool, output: str, error: str = ""):
        self.success, self.output, self.error = success, output, error

    def to_message(self) -> str:
        """Render the result as the text that is handed back to the LLM.

        Successful results are returned verbatim; failures are rendered as
        ``Error: <error>`` followed by the output, with surrounding
        whitespace stripped.
        """
        if not self.success:
            failure_text = f"Error: {self.error}\n\n{self.output}"
            return failure_text.strip()
        return self.output

    def to_tool_result(self, tool_call_id: str) -> dict:
        """Wrap the rendered message in an OpenAI ``tool`` role message."""
        message = dict(role="tool", tool_call_id=tool_call_id)
        message["content"] = self.to_message()
        return message
60
+
61
+
62
+ # ── Tool definitions (OpenAI function format) ────────────────────────────────
63
+ # Extracted to tool_defs.py to keep this file under 1600 lines.
64
+
65
+ from .tool_defs import TOOL_DEFINITIONS # noqa: E402
66
+
67
+ # TOOL_DEFINITIONS are now in tool_defs.py (extracted to keep this file focused
68
+ # on implementations). The import above re-exports them transparently.
69
+
70
+ # ── Tool implementations ─────────────────────────────────────────────────────
71
+
72
+ class ToolExecutor:
73
+ """Executes tool calls and manages workspace context."""
74
+
75
+ # Result limits
76
+ MAX_GREP_RESULTS = 100
77
+ MAX_GREP_PER_FILE = 20
78
+ MAX_GLOB_RESULTS = 200
79
+ MAX_DIR_ENTRIES = 500
80
+
81
+ # Directories to skip during recursive operations
82
+ SKIP_DIRS = {
83
+ "node_modules", "__pycache__", ".git", "venv", ".venv",
84
+ "dist", "build", "target", ".next", ".pytest_cache",
85
+ }
86
+
87
def __init__(self, config: AgentConfig | None = None):
    """Bind the executor to a configuration and set up its internal state.

    A default ``AgentConfig()`` is created when *config* is falsy.  The
    workspace root is taken from ``config.workspace_dir`` and resolved to
    an absolute path.
    """
    if not config:
        config = AgentConfig()
    self.config = config
    self.workspace = Path(self.config.workspace_dir).resolve()

    # Collaborators injected later through the set_*/on_* methods.
    self._mcp = None  # set via set_mcp_client()
    self._sub_agent_mgr: Any = None  # set via set_sub_agent_manager()
    self._edit_callback: Callable[[str, str], None] | None = None

    # File read cache: path -> (mtime, content).  "Read only once": a file
    # is read from disk once per session and served from memory afterwards.
    self._file_cache: dict[str, tuple[float, str]] = {}
    self._file_cache_max_entries = 50  # eviction threshold
    self._cache_dir: Path | None = None

    # Dispatch table built last, once the instance is otherwise ready.
    self._handlers: dict[str, Callable] = self._build_handlers()
101
+
102
+ def _build_handlers(self) -> dict[str, Callable]:
103
+ """Build a dispatch table: tool_name → handler callable."""
104
+ handlers: dict[str, Callable] = {}
105
+ for name in dir(self):
106
+ if name.startswith("_tool_"):
107
+ tool_name = name[len("_tool_"):]
108
+ handlers[tool_name] = getattr(self, name)
109
+ return handlers
110
+
111
def set_sub_agent_manager(self, mgr: Any) -> None:
    """Attach the sub-agent manager used by the spawn/collect sub-agent tools."""
    self._sub_agent_mgr = mgr
114
+
115
def set_mcp_client(self, mcp: Any) -> None:
    """Attach the MCP client that backs the mcp_search tool."""
    self._mcp = mcp
118
+
119
+ def setup_file_cache(self, cache_dir: str | Path) -> None:
120
+ """Create session cache directory. Call once before running the agent."""
121
+ self._cache_dir = Path(cache_dir)
122
+ self._cache_dir.mkdir(parents=True, exist_ok=True)
123
+
124
def clear_file_cache(self) -> None:
    """Empty the in-memory file cache and delete the on-disk cache directory.

    Removal of the directory is best-effort (errors are ignored); the
    executor forgets the directory afterwards.
    """
    self._file_cache.clear()
    cache_dir = self._cache_dir
    if cache_dir and cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)
    self._cache_dir = None
130
+
131
def close(self) -> None:
    """Release held resources: the file cache and, if present, the HTTP client.

    Errors raised while closing the client are ignored; the ``_http``
    attribute is reset to ``None`` once a client has been handled.
    """
    self.clear_file_cache()
    http_client = getattr(self, "_http", None)
    if http_client is None:
        return
    try:
        http_client.close()
    except Exception:
        pass
    self._http = None
140
+
141
+ def __del__(self) -> None:
142
+ """Safety net: ensure httpx client is closed on GC."""
143
+ try:
144
+ self.close()
145
+ except Exception:
146
+ pass
147
+
148
def on_edit(self, callback: Callable[[str, str], None]) -> None:
    """Install the edit hook, later invoked as ``callback(file_path, old_content)``."""
    self._edit_callback = callback
151
+
152
+ def _notify_edit(self, file_path: str, old_content: str) -> None:
153
+ """Notify the UI of a file edit for diff display."""
154
+ if self._edit_callback:
155
+ try:
156
+ self._edit_callback(file_path, old_content)
157
+ except Exception:
158
+ logger.exception("Edit callback failed for %s", file_path)
159
+
160
+ def _resolve_path(self, file_path: str) -> Path:
161
+ """Resolve a file path relative to workspace.
162
+
163
+ Raises ValueError if the resolved path escapes the workspace via
164
+ path traversal (e.g., ``../../etc/passwd``).
165
+ """
166
+ p = Path(file_path)
167
+ if not p.is_absolute():
168
+ p = self.workspace / p
169
+ resolved = p.resolve()
170
+ try:
171
+ resolved.relative_to(self.workspace.resolve())
172
+ except ValueError:
173
+ raise ValueError(
174
+ f"Path traversal blocked: {file_path} → {resolved} "
175
+ f"is outside workspace {self.workspace}"
176
+ )
177
+ return resolved
178
+
179
+ def _ensure_in_workspace(self, path: Path) -> Path:
180
+ """
181
+ Ensure path is within workspace for safety.
182
+ For files outside workspace, allow with a warning.
183
+ """
184
+ try:
185
+ path.resolve().relative_to(self.workspace)
186
+ except ValueError:
187
+ logger.warning("Path outside workspace: %s", path)
188
+ return path
189
+
190
+ # Global output cap — every tool result is trimmed to this many chars.
191
+ # ~25k tokens, still generous for legitimate work but prevents a single
192
+ # result from eating the whole context window.
193
+ MAX_OUTPUT_CHARS = 100_000
194
+
195
async def execute(self, tool_name: str, arguments: dict[str, Any]) -> ToolResult:
    """Run the handler registered for *tool_name* with *arguments*.

    Unknown tools and handlers that raise are reported as failed
    ``ToolResult`` objects rather than exceptions.  Output longer than
    ``MAX_OUTPUT_CHARS`` is cut and a truncation note is appended, so one
    tool response cannot blow up the conversation context.
    """
    handler = self._handlers.get(tool_name)
    if handler is None:
        return ToolResult(
            success=False, output="", error=f"Unknown tool: {tool_name}"
        )

    try:
        result = await handler(**arguments)
    except Exception as exc:
        logger.exception("Tool %s failed", tool_name)
        failure = f"{type(exc).__name__}: {exc}"
        return ToolResult(success=False, output="", error=failure)

    # ── Global size cap ──────────────────────────────────────────────
    cap = self.MAX_OUTPUT_CHARS
    full_len = len(result.output)
    if full_len <= cap:
        return result

    dropped = full_len - cap
    result.output = (
        result.output[:cap] + f"\n\n... [truncated {dropped:,} "
        f"chars — result was {full_len:,} chars total]"
    )
    logger.warning(
        "Tool %s output truncated: %d → %d chars",
        tool_name, full_len, cap,
    )
    return result
230
+
231
+ # ── File tools ───────────────────────────────────────────────────────────
232
+
233
+ # ── Output limits (prevent token bloat from large file reads) ──────
234
+ MAX_READ_LINES = 2000 # auto-truncate reads without an explicit limit
235
+ MAX_READ_CHARS = 80_000 # hard cap on output chars (~20k tokens)
236
+
237
async def _tool_read_file(
    self,
    file_path: str,
    offset: int | None = None,
    limit: int | None = None,
) -> ToolResult:
    """Read a text file and return it with line numbers.

    Args:
        file_path: Path to read; resolved through ``_resolve_path``.
        offset: 1-based number of the first line to return.  ``None``,
            zero and negative values all mean "start at line 1".
        limit: Maximum number of lines to return.  When omitted, the read
            is capped at ``MAX_READ_LINES``.  Negative values are treated
            as 0 (previously they produced a negative slice end and
            returned almost the whole file).

    Returns:
        A ``ToolResult`` whose output is a header line followed by the
        numbered lines, or a failed result when the path does not exist,
        is a directory, or cannot be read.  Output is additionally capped
        at ``MAX_READ_CHARS`` characters.

    Caching:
        File text is kept in ``self._file_cache`` keyed by resolved path
        and validated by mtime.  Re-reading an unchanged file *without*
        offset/limit returns a short ``[cached]`` note instead of the
        content; ranged reads of an unchanged file are served from memory.
        Entries are evicted least-recently-used first once the cache holds
        more than ``self._file_cache_max_entries`` files.
    """
    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(
            success=False, output="", error=f"File not found: {path}"
        )
    if path.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Path is a directory, not a file: {path}",
        )

    def _as_lines(text: str) -> list[str]:
        # Split keeping line endings and make sure the last line ends in "\n".
        pieces = text.splitlines(keepends=True)
        if pieces and not pieces[-1].endswith("\n"):
            pieces[-1] += "\n"
        return pieces

    # ── File cache ("read only once") ────────────────────────────────
    cache_key = str(path.resolve())
    current_mtime = path.stat().st_mtime
    lines: list[str] | None = None

    cached_entry = self._file_cache.get(cache_key)
    if cached_entry is not None and cached_entry[0] == current_mtime:
        cached_content = cached_entry[1]
        # Dicts preserve insertion order, so re-inserting moves this entry
        # to the most-recently-used end and keeps eviction truly LRU.
        self._file_cache[cache_key] = self._file_cache.pop(cache_key)
        if offset is None and limit is None:
            # Whole-file re-read: do not send the content again.  The line
            # count uses the same splitting as the normal header.
            total = len(cached_content.splitlines())
            chars = len(cached_content)
            return ToolResult(
                success=True,
                output=(
                    f"[cached] {file_path} — {total} lines, {chars:,} chars.\n"
                    "Already in conversation context from earlier read. "
                    "Use offset/limit to read specific sections if needed."
                ),
            )
        # Specific section requested: serve from memory, skip the disk.
        lines = _as_lines(cached_content)

    if lines is None:
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                raw = f.read()
        except Exception as e:
            return ToolResult(
                success=False, output="", error=f"Cannot read file: {e}"
            )

        # Remove a stale entry first so the refreshed one lands at the
        # most-recently-used end instead of keeping its old position.
        self._file_cache.pop(cache_key, None)
        self._file_cache[cache_key] = (current_mtime, raw)
        overflow = len(self._file_cache) - self._file_cache_max_entries
        if overflow > 0:
            for stale_key in list(self._file_cache)[:overflow]:
                del self._file_cache[stale_key]

        lines = _as_lines(raw)

        # Mirror to the disk cache directory if one is configured.
        if self._cache_dir:
            # Cross-platform safe filename: strip a Windows drive letter
            # and replace path separators with underscores.
            resolved = cache_key
            if len(resolved) >= 2 and resolved[1] == ":":
                resolved = resolved[2:]  # strip "C:"
            safe_name = resolved.lstrip("\\/").replace("\\", "_").replace("/", "_")
            try:
                (self._cache_dir / safe_name).write_text(raw, encoding="utf-8")
            except (OSError, ValueError) as e:
                # The mirror is best-effort; a failure must not fail the read.
                logger.debug("Could not mirror %s to cache dir: %s", path, e)

    total_lines = len(lines)
    user_specified_range = limit is not None

    # Default cap: when the caller did not ask for a specific number of
    # lines, stop at MAX_READ_LINES so one file cannot blow the context.
    if user_specified_range:
        effective_limit = max(0, limit)
    else:
        effective_limit = min(total_lines, self.MAX_READ_LINES)

    # Clamp the start BEFORE deriving the end, so a zero/negative offset
    # does not shrink the requested window.
    start = max(0, (offset or 1) - 1)
    end = min(total_lines, start + effective_limit)

    selected = lines[start:end]
    was_truncated = not user_specified_range and end < total_lines

    # Format with line numbers, enforcing the hard character cap even when
    # the caller asked for an explicit range (safety net).
    char_truncated = False
    output_lines: list[str] = []
    chars = 0
    for i, line in enumerate(selected, start=start + 1):
        formatted = f"{i:6d}\t{line.rstrip()}"
        chars += len(formatted) + 1  # +1 for the newline added by join
        if chars > self.MAX_READ_CHARS:
            output_lines.append(
                f"... (output truncated at {self.MAX_READ_CHARS} chars, use offset/limit to read more)"
            )
            char_truncated = True
            break
        output_lines.append(formatted)

    truncated = was_truncated or char_truncated
    # The truncation marker is not a file line, so exclude it from the count.
    shown_lines = len(output_lines) - (1 if char_truncated else 0)
    header = f"File: {path} (lines {start+1}-{start+shown_lines} of {total_lines}"
    if truncated:
        header += ", truncated — use offset/limit to page"
    header += ")\n"
    return ToolResult(success=True, output=header + "\n".join(output_lines))
362
+
363
async def _tool_write_file(
    self, file_path: str, content: str
) -> ToolResult:
    """Create or overwrite a file with *content*.

    Args:
        file_path: Destination path; resolved through ``_resolve_path``.
            Missing parent directories are created.
        content: Text to write, encoded as UTF-8.

    Returns:
        A ``ToolResult`` reporting the written size in bytes and the number
        of lines, or a failed result if the file cannot be written.

    When the file already existed with non-empty, UTF-8 readable content,
    that previous content is passed to ``_notify_edit`` so a diff can be
    displayed.
    """
    path = self._resolve_path(file_path)
    self._ensure_in_workspace(path)

    # Capture the old content for the diff (best-effort).
    old_content = ""
    if path.exists():
        try:
            with open(path, "r", encoding="utf-8") as f:
                old_content = f.read()
        except (OSError, UnicodeDecodeError) as e:
            # Not being able to show a diff must not block the write.
            logger.debug("Could not capture previous content of %s: %s", path, e)

    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
        size = path.stat().st_size

        # Notify the UI for diff display when overwriting existing content.
        if old_content:
            self._notify_edit(str(path), old_content)

        # Count a final line that has no trailing newline; counting only
        # "\n" characters reported 0 lines for single-line content.
        line_count = content.count("\n")
        if content and not content.endswith("\n"):
            line_count += 1

        return ToolResult(
            success=True,
            output=f"File written: {path} ({size} bytes, {line_count} lines)",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Cannot write file: {e}"
        )
397
+
398
async def _tool_edit_file(
    self, file_path: str, old_string: str, new_string: str
) -> ToolResult:
    """Replace one occurrence of *old_string* with *new_string* in a file.

    For ``.py``/``.pyi`` files a CST-based edit (``_cst_edit``) is tried
    first; when it yields nothing, or for any other file type, a plain
    text replacement is used.  The text replacement requires *old_string*
    to occur exactly once.  After a successful write the previous content
    is passed to ``_notify_edit``.
    """

    def _fail(message: str) -> ToolResult:
        return ToolResult(success=False, output="", error=message)

    if old_string == new_string:
        return _fail("old_string and new_string are identical")

    path = self._resolve_path(file_path)
    if not path.exists():
        return _fail(f"File not found: {path}")

    try:
        with open(path, "r", encoding="utf-8") as src:
            original_text = src.read()
    except Exception as e:
        return _fail(f"Cannot read file: {e}")

    def _commit(updated_text: str, success_message: str) -> ToolResult:
        # Write the new text, notify the UI with the old text, report success.
        try:
            with open(path, "w", encoding="utf-8") as dst:
                dst.write(updated_text)
            self._notify_edit(str(path), original_text)
            return ToolResult(success=True, output=success_message)
        except Exception as e:
            return _fail(f"Cannot write file: {e}")

    # ── CST-based edit for Python files ──────────────────────────────
    if path.suffix in (".py", ".pyi"):
        cst_result = self._cst_edit(original_text, old_string, new_string, path)
        if cst_result is not None:
            return _commit(cst_result, f"File edited (AST): {path} (1 replacement)")
        # CST edit not possible: fall through to text-based replacement.

    # ── Text-based fallback ──────────────────────────────────────────
    occurrences = original_text.count(old_string)
    if occurrences == 0:
        return _fail("old_string not found in file. Check whitespace/indentation.")
    if occurrences > 1:
        return _fail(
            f"old_string found {occurrences} times in file. Must be unique. Use a larger string with more surrounding context."
        )

    return _commit(
        original_text.replace(old_string, new_string, 1),
        f"File edited: {path} (1 replacement)",
    )
478
+
479
+ @staticmethod
480
+ def _cst_edit(content: str, old_str: str, new_str: str, path: Any) -> str | None:
481
+ """Attempt a CST-based edit for Python files using libcst.
482
+
483
+ Parses the file as a Concrete Syntax Tree, finds the node matching
484
+ old_str, replaces it with new_str parsed as the same node type,
485
+ and returns the formatted code. Preserves ALL formatting.
486
+
487
+ Returns None if the edit cannot be performed via CST (fallback to text).
488
+ """
489
+ try:
490
+ import libcst as cst
491
+ except ImportError:
492
+ return None # libcst not installed — use text fallback
493
+
494
+ try:
495
+ tree = cst.parse_module(content)
496
+ except Exception:
497
+ return None # Can't parse — fall back to text
498
+
499
+ # Strategy: parse old_str as a statement, find it in the tree, replace
500
+ try:
501
+ # Try parsing old_str as a full module body (the typical case for
502
+ # function/method/class-level edits)
503
+ old_module = cst.parse_module(old_str + "\n")
504
+ if len(old_module.body) == 1:
505
+ old_node = old_module.body[0]
506
+ # Parse new_str as the same node type
507
+ new_module = cst.parse_module(new_str + "\n")
508
+ if len(new_module.body) == 1:
509
+ new_node = new_module.body[0]
510
+ transformer = _NodeReplacer(old_node, new_node)
511
+ new_tree = tree.visit(transformer)
512
+ if transformer.found:
513
+ return new_tree.code
514
+ except Exception:
515
+ pass
516
+
517
+ # Strategy 2: try parsing old_str as a simple statement line
518
+ try:
519
+ old_module = cst.parse_module(old_str + "\n")
520
+ if len(old_module.body) == 1:
521
+ old_stmt = old_module.body[0]
522
+ # For simple statements, use a body-statement-level replacer
523
+ new_module = cst.parse_module(new_str + "\n")
524
+ if len(new_module.body) == 1:
525
+ new_stmt = new_module.body[0]
526
+ transformer = _StatementReplacer(old_stmt, new_stmt)
527
+ new_tree = tree.visit(transformer)
528
+ if transformer.found:
529
+ return new_tree.code
530
+ except Exception:
531
+ pass
532
+
533
+ return None # All CST strategies failed — fall back to text
534
+
535
+
536
+ # ── CST Transformers (libcst optional — guarded at module level) ──────────
537
+
538
try:
    import libcst as _cst_lib

    class _NodeReplacer(_cst_lib.CSTTransformer):
        """Replace the node that deep-equals ``old_node`` with ``new_node``.

        Only the first matching node is ever replaced.  ``found`` is True
        only while exactly one node has matched, so an ambiguous edit
        (several equal nodes) is reported as "not found" and the caller can
        fall back to its own uniqueness handling.  ``match_count`` exposes
        the number of matching nodes seen.
        """

        def __init__(self, old_node: _cst_lib.CSTNode, new_node: _cst_lib.CSTNode):
            super().__init__()
            self.old_node = old_node
            self.new_node = new_node
            self.found = False
            self.match_count = 0
            # Identity of the first matching node; on_leave replaces only it.
            self._target = None

        def on_visit(self, node: _cst_lib.CSTNode) -> bool:
            if not node.deep_equals(self.old_node):
                return True
            self.match_count += 1
            if self._target is None:
                self._target = node
            self.found = self.match_count == 1
            # No need to descend into a node that already matched.
            return False

        def on_leave(self, original_node: _cst_lib.CSTNode, updated_node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
            # Compare by identity: comparing by equality here would replace
            # every later equal node as well.
            if original_node is self._target:
                return self.new_node
            return updated_node

    class _StatementReplacer(_cst_lib.CSTTransformer):
        """Replace a single small statement inside a simple statement line.

        ``old_stmt``/``new_stmt`` may be given either as small statements
        or as ``SimpleStatementLine`` wrappers (what ``parse_module(...).body``
        yields); wrappers are unwrapped so the comparison is made between
        small statements.  Leading comments and trailing whitespace of the
        matched line are kept.  As with ``_NodeReplacer``, only the first
        match is replaced and ``found`` is True only for a unique match.
        """

        def __init__(self, old_stmt: _cst_lib.CSTNode, new_stmt: _cst_lib.CSTNode):
            super().__init__()
            self.old_stmt = self._unwrap_single(old_stmt)
            self.new_stmt = new_stmt
            self.found = False
            self.match_count = 0
            self._new_body = self._small_statements(new_stmt)

        @staticmethod
        def _unwrap_single(node: _cst_lib.CSTNode) -> _cst_lib.CSTNode:
            """Return the sole small statement of a one-statement line."""
            if isinstance(node, _cst_lib.SimpleStatementLine) and len(node.body) == 1:
                return node.body[0]
            return node

        @staticmethod
        def _small_statements(node: _cst_lib.CSTNode):
            """Return the small statements *node* stands for, or None."""
            if isinstance(node, _cst_lib.SimpleStatementLine):
                return list(node.body)
            if isinstance(node, _cst_lib.BaseSmallStatement):
                return [node]
            # Compound statements cannot live inside a SimpleStatementLine.
            return None

        def leave_SimpleStatementLine(
            self, original_node: _cst_lib.SimpleStatementLine, updated_node: _cst_lib.SimpleStatementLine
        ):
            if self._new_body is None or len(original_node.body) != 1:
                return updated_node
            if not original_node.body[0].deep_equals(self.old_stmt):
                return updated_node
            self.match_count += 1
            self.found = self.match_count == 1
            if self.match_count > 1:
                # Ambiguous: keep later matches untouched.
                return updated_node
            return updated_node.with_changes(body=self._new_body)

except ImportError:
    # libcst not installed: AST editing will fall back to text replacement.
    _NodeReplacer = None  # type: ignore
    _StatementReplacer = None  # type: ignore
581
+
582
+ # ── Rename symbol (AST-aware) ────────────────────────────────────────────
583
+
584
async def _tool_rename_symbol(
    self, file_path: str, old_name: str, new_name: str,
    symbol_type: str = "variable"
) -> ToolResult:
    """Rename a Python symbol in one file using libcst.

    Args:
        file_path: ``.py``/``.pyi`` file to modify; resolved through
            ``_resolve_path``.
        old_name: Current identifier.
        new_name: Replacement identifier.  Must be a valid, non-reserved
            Python identifier.
        symbol_type: ``"function"`` or ``"class"`` select the matching
            renamer; any other value is treated as a variable rename.

    Returns:
        A ``ToolResult`` reporting the number of renamed occurrences, or a
        failed result when validation, parsing or writing fails, libcst is
        missing, or the symbol does not occur.

    The renaming itself is delegated to the ``_FunctionRenamer``,
    ``_ClassRenamer`` and ``_VariableRenamer`` transformers, which are
    defined outside this block; which occurrences they touch is decided
    there.
    """
    import keyword

    if old_name == new_name:
        return ToolResult(success=False, output="", error="Names are identical.")
    if not old_name.isidentifier() or not new_name.isidentifier():
        return ToolResult(success=False, output="", error="Names must be valid Python identifiers.")
    # str.isidentifier() accepts reserved words such as "class"; renaming
    # to one would write a file that no longer parses.
    for candidate in (old_name, new_name):
        if keyword.iskeyword(candidate):
            return ToolResult(
                success=False,
                output="",
                error=f"'{candidate}' is a reserved Python keyword and cannot be used as a symbol name.",
            )

    path = self._resolve_path(file_path)
    if not path.exists():
        return ToolResult(success=False, output="", error=f"File not found: {path}")
    if path.suffix not in (".py", ".pyi"):
        return ToolResult(success=False, output="", error="rename_symbol only works with Python files.")

    try:
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot read file: {e}")

    try:
        import libcst as cst
    except ImportError:
        return ToolResult(success=False, output="", error="libcst not installed. Run: pip install libcst")

    try:
        tree = cst.parse_module(content)
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot parse Python file: {e}")

    # Choose the renamer that matches the symbol type.
    if symbol_type == "function":
        renamer = _FunctionRenamer(old_name, new_name)
    elif symbol_type == "class":
        renamer = _ClassRenamer(old_name, new_name)
    else:
        renamer = _VariableRenamer(old_name, new_name)

    new_tree = tree.visit(renamer)
    if not renamer.changes:
        return ToolResult(success=False, output="", error=f"Symbol '{old_name}' not found in file.")

    new_content = new_tree.code
    old_content = content

    try:
        with open(path, "w", encoding="utf-8") as f:
            f.write(new_content)
        self._notify_edit(str(path), old_content)
        return ToolResult(
            success=True,
            output=f"Renamed {renamer.changes} occurrence(s) of '{old_name}' → '{new_name}' in {path}",
        )
    except Exception as e:
        return ToolResult(success=False, output="", error=f"Cannot write file: {e}")
643
+
644
+ # ── Shell tool ───────────────────────────────────────────────────────────
645
+
646
async def _tool_run_shell(
    self, command: str, timeout: int = 120
) -> ToolResult:
    """Execute a shell command in the workspace directory.

    Args:
        command: Command line passed to the system shell.
        timeout: Seconds to wait for completion before the command is
            killed and reported as timed out.

    Returns:
        A ``ToolResult`` containing stdout, followed by a ``[stderr]``
        section and an ``[exit code: N]`` line when applicable.  The
        result is successful only when the exit code is 0.

    Commands containing any entry of ``config.blocked_commands``
    (case-insensitive substring match) are rejected without being run.
    A first word missing from ``config.allowed_commands`` is only logged.
    """
    cmd_lower = command.lower().strip()

    # Hard block: refuse commands containing a blocked pattern.
    for blocked in self.config.blocked_commands:
        if blocked.lower() in cmd_lower:
            return ToolResult(
                success=False,
                output="",
                error=f"Blocked command pattern detected: {blocked}",
            )

    # Soft check only: log when the first word is not whitelisted.
    first_word = command.strip().split()[0] if command.strip() else ""
    if first_word and first_word not in self.config.allowed_commands:
        logger.debug("Command '%s' not in allowed_commands whitelist, proceeding", first_word)

    try:
        proc = await asyncio.create_subprocess_shell(
            command,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=str(self.workspace),
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Command failed: {e}"
        )

    try:
        stdout_bytes, stderr_bytes = await asyncio.wait_for(
            proc.communicate(), timeout=timeout
        )
    except asyncio.TimeoutError:
        # wait_for() only cancels communicate(); without an explicit kill
        # the child keeps running after the tool has reported a timeout.
        if proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass  # exited between the check and the kill
            await proc.wait()  # reap the process so it does not linger
        return ToolResult(
            success=False,
            output="",
            error=f"Command timed out after {timeout}s",
        )
    except Exception as e:
        return ToolResult(
            success=False, output="", error=f"Command failed: {e}"
        )

    output = stdout_bytes.decode("utf-8", errors="replace")
    if stderr_bytes:
        output += f"\n[stderr]\n{stderr_bytes.decode('utf-8', errors='replace')}"
    returncode = proc.returncode or 0
    if returncode != 0:
        output += f"\n[exit code: {returncode}]"
    return ToolResult(
        success=returncode == 0,
        output=output.strip() or "(no output)",
    )
697
+
698
+ # ── Search tools ─────────────────────────────────────────────────────────
699
+
700
async def _tool_grep(
    self,
    pattern: str,
    path: str | None = None,
    glob: str | None = None,
    case_sensitive: bool = False,
) -> ToolResult:
    """Search file contents with a regular expression.

    Args:
        pattern: Regular expression applied to each line.
        path: Directory to search recursively, or a single file to search.
            Defaults to the workspace root.
        glob: Optional ``fnmatch`` pattern applied to file names during a
            directory search.  It is not applied when *path* names a file.
        case_sensitive: Match case-sensitively when True; the default is a
            case-insensitive search.

    Returns:
        A ``ToolResult`` listing, per file, the matching lines as
        ``<line number>: <text>`` (text cut at 200 characters).  At most
        ``MAX_GREP_PER_FILE`` matches per file are listed and the search
        stops once ``MAX_GREP_RESULTS`` output lines were collected.  A
        failed result is returned for a missing path or invalid regex.

    Directory searches skip hidden directories and those in ``SKIP_DIRS``.
    """
    search_root = self._resolve_path(path or ".")
    if not search_root.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Path not found: {search_root}",
        )

    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        regex = re.compile(pattern, flags)
    except re.error as e:
        return ToolResult(
            success=False, output="", error=f"Invalid regex: {e}"
        )

    def _candidate_files():
        # A file path is searched directly; os.walk() on a file yields
        # nothing, which used to be reported as "no matches".
        if search_root.is_file():
            yield str(search_root)
            return
        for root, dirs, files in os.walk(search_root):
            # Prune in place so os.walk does not descend into these.
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".")
                and d not in self.SKIP_DIRS
            ]
            for fname in files:
                if glob and not fnmatch.fnmatch(fname, glob):
                    continue
                yield os.path.join(root, fname)

    # The filesystem walk and file reads run in a worker thread so the
    # asyncio event loop is not blocked on large codebases.
    def _do_grep():
        results: list[str] = []
        total_matches = 0
        for full_path in _candidate_files():
            try:
                rel_path = os.path.relpath(full_path, self.workspace)
            except ValueError:
                rel_path = full_path

            matches_in_file: list[str] = []
            hits = 0
            try:
                with open(full_path, "r", encoding="utf-8", errors="replace") as fh:
                    # Stream the file so scanning stops at the per-file cap
                    # instead of loading every line first.
                    for line_no, line_text in enumerate(fh, 1):
                        if not regex.search(line_text):
                            continue
                        matches_in_file.append(
                            f" {line_no}: {line_text.rstrip()[:200]}"
                        )
                        hits += 1
                        if hits >= self.MAX_GREP_PER_FILE:
                            matches_in_file.append(" ... (truncated)")
                            break
            except (OSError, ValueError):
                # Unreadable file: skip it entirely, discarding partial hits.
                continue

            if not matches_in_file:
                continue

            total_matches += hits
            results.append(f"{rel_path}:")
            results.extend(matches_in_file)

            if len(results) >= self.MAX_GREP_RESULTS:
                results.append("... (result limit reached)")
                break
        return results, total_matches

    results, total_matches = await self._run_in_thread(_do_grep)

    if not results:
        return ToolResult(
            success=True,
            output=f"No matches found for pattern: {pattern}",
        )
    return ToolResult(
        success=True,
        output=f"Found {total_matches} matches:\n\n" + "\n".join(results),
    )
785
+
786
async def _tool_glob(
    self,
    pattern: str,
    path: str | None = None,
) -> ToolResult:
    """Find files whose path matches a glob pattern.

    Args:
        pattern: Glob pattern. When it contains no ``**`` it is prefixed
            with ``**/`` so that the match is recursive.
        path: Directory to search, resolved through ``self._resolve_path``;
            defaults to ``"."``.

    Returns:
        A ``ToolResult`` listing at most 200 matches (sorted, with sizes),
        or a failure when the directory does not exist.
    """
    search_dir = self._resolve_path(path or ".")
    if not search_dir.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {search_dir}",
        )

    import glob as glob_mod

    # Auto-add **/ prefix for recursive matching if not already present
    if "**" not in pattern:
        pattern = f"**/{pattern}"
    search_pattern = str(search_dir / pattern)
    shown_limit = 200

    def _do_glob():
        # The recursive glob and the per-file stat calls are blocking, so
        # they run in the thread pool like the grep tool's walk does.
        found = sorted(glob_mod.glob(search_pattern, recursive=True))
        listing: list[str] = []
        for m in found[:shown_limit]:
            try:
                rel = os.path.relpath(m, self.workspace)
            except ValueError:
                rel = m
            try:
                size = os.path.getsize(m) if os.path.isfile(m) else 0
            except OSError:
                # File vanished between the glob and the stat call.
                size = 0
            listing.append(f" {rel} ({size:,} bytes)")
        return len(found), listing

    total, output_lines = await self._run_in_thread(_do_glob)

    if not total:
        return ToolResult(
            success=True, output=f"No files matching: {pattern}"
        )

    if total > shown_limit:
        output_lines.append(
            f" ... and {total - shown_limit} more files"
        )

    return ToolResult(
        success=True,
        output=f"Found {total} files matching '{pattern}':\n"
        + "\n".join(output_lines),
    )
834
+
835
async def _tool_list_dir(
    self,
    path: str | None = None,
    recursive: bool = False,
) -> ToolResult:
    """List the contents of a directory.

    Args:
        path: Directory to list, resolved through ``self._resolve_path``;
            defaults to ``"."``.
        recursive: When true, walk the whole tree (hidden directories and
            those named in ``self.SKIP_DIRS`` are pruned) and indent each
            entry by its depth below ``path``.

    Returns:
        A ``ToolResult`` with at most 500 entries, or a failure when the
        path is missing or is not a directory.
    """
    target = self._resolve_path(path or ".")
    if not target.exists():
        return ToolResult(
            success=False,
            output="",
            error=f"Directory not found: {target}",
        )
    if not target.is_dir():
        return ToolResult(
            success=False,
            output="",
            error=f"Not a directory: {target}",
        )

    output_lines = [f"Directory: {target}"]
    entries: list[str] = []

    if recursive:
        for root, dirs, files in os.walk(target):
            dirs[:] = [
                d for d in dirs
                if not d.startswith(".") and d not in self.SKIP_DIRS
            ]
            # Depth below the target. Derived from the relative path rather
            # than str.replace(), which also strips later repetitions of
            # the target text and so under-counts the depth.
            rel_root = os.path.relpath(root, target)
            level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
            indent = " " * level
            if level > 0:
                entries.append(f"{indent}{os.path.basename(root)}/")
            for f in sorted(files):
                fp = os.path.join(root, f)
                try:
                    size = os.path.getsize(fp)
                except OSError:
                    # Broken symlink or file removed mid-walk: keep the
                    # entry, report an unknown size as zero.
                    size = 0
                entries.append(f"{indent} {f} ({size:,}B)")
    else:
        # Directories first, then files, each group sorted by name.
        items = sorted(target.iterdir(), key=lambda x: (not x.is_dir(), x.name))
        for item in items:
            suffix = "/" if item.is_dir() else ""
            size = ""
            if item.is_file():
                size = f" ({item.stat().st_size:,}B)"
            entries.append(f" {item.name}{suffix}{size}")

    output_lines.extend(entries[:500])
    if len(entries) > 500:
        output_lines.append(f" ... and {len(entries) - 500} more entries")

    return ToolResult(success=True, output="\n".join(output_lines))
886
+
887
+
888
+ # ── Web tools ──────────────────────────────────────────────────────────
889
+
890
# HTTP client shared by the web tools. It stays ``None`` until the ``http``
# property is first read, which creates and caches the client.
_http: httpx.Client | None = None
892
+
893
+ async def _run_in_thread(self, func, *args, **kwargs):
894
+ """Run a sync function in a thread pool to avoid blocking the event loop."""
895
+ from functools import partial
896
+ loop = asyncio.get_running_loop()
897
+ return await loop.run_in_executor(None, partial(func, *args, **kwargs))
898
+
899
@property
def http(self) -> httpx.Client:
    """Return the shared ``httpx.Client``, creating it on first access.

    The client uses a 30 second timeout, follows redirects and sends the
    ATA-Coder default headers; it is cached on ``self._http``.
    """
    client = self._http
    if client is not None:
        return client

    default_headers = {
        "User-Agent": (
            "ATA-Coder/2.0 (AI Coding Assistant; "
            "+https://github.com/ata-coder/ata-coder)"
        ),
        "Accept": "text/html,application/xhtml+xml,*/*",
        "Accept-Language": "en-US,zh-CN;q=0.9",
    }
    client = httpx.Client(
        timeout=httpx.Timeout(30.0),
        follow_redirects=True,
        headers=default_headers,
    )
    self._http = client
    return client
915
+
916
async def _tool_web_search(
    self,
    query: str,
    max_results: int = 10,
) -> ToolResult:
    """Search the web with tiered fallback: Bing → Baidu → Google.

    All three use web scraping (no API key required).
    Set ATA_CODER_SEARCH_BACKEND to force a single backend:
    "bing" / "baidu" / "google" / "duckduckgo"
    The value is matched case-insensitively against a ``_search_<name>``
    method on this object; a name with no such method is reported as an
    unsupported backend.
    NOTE(review): no ``_search_duckduckgo`` method is visible in this part
    of the file, so confirm that "duckduckgo" is actually supported.

    Args:
        query: Search terms.
        max_results: Number of results to show, clamped to 1..20.

    Returns:
        The formatted results of the first backend that returns any, or a
        failed ``ToolResult`` listing why each backend was rejected.
    """
    import os
    max_results = min(max(max_results, 1), 20)
    # Normalise so that "Bing" or " bing " still select _search_bing.
    forced = os.environ.get("ATA_CODER_SEARCH_BACKEND", "").strip().lower()

    errors: list[str] = []

    # Build fallback chain: respect forced backend, otherwise tiered
    if forced:
        candidate = getattr(self, f"_search_{forced}", None)
        # Only a callable can serve as a backend; anything else is
        # reported below as unsupported.
        chain = [(forced, candidate if callable(candidate) else None)]
    else:
        chain = [
            ("Bing", self._search_bing),
            ("Baidu", self._search_baidu),
            ("Google", self._search_google),
        ]

    for name, searcher in chain:
        if searcher is None:
            errors.append(f"{name}: unsupported backend")
            continue
        try:
            # Run sync search in thread pool to avoid blocking event loop
            results = await self._run_in_thread(searcher, query)
            if results:
                return self._format_search_results(query, results, max_results, name)
            errors.append(f"{name} returned no results")
        except httpx.TimeoutException:
            errors.append(f"{name} timed out")
        except httpx.HTTPStatusError as e:
            errors.append(f"{name} HTTP {e.response.status_code}")
        except Exception as e:
            # Best-effort fallback: record the failure and try the next
            # backend rather than aborting the whole search.
            errors.append(f"{name}: {e}")

    return ToolResult(
        success=False, output="",
        error=f"Search failed: {'; '.join(errors)}"
    )
964
+
965
+ def _search_bing(self, query: str) -> list[dict[str, str]]:
966
+ """Search Bing (web scraping, no API key)."""
967
+ import urllib.parse
968
+ url = f"https://www.bing.com/search?q={urllib.parse.quote(query)}&setlang=en"
969
+ headers = {
970
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
971
+ "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
972
+ "Accept-Language": "en-US,en;q=0.9",
973
+ }
974
+ resp = self.http.get(url, headers=headers)
975
+ resp.raise_for_status()
976
+
977
+ results: list[dict[str, str]] = []
978
+ # Bing results are in <li class="b_algo"> blocks
979
+ blocks = re.findall(
980
+ r'<li[^>]*class="[^"]*b_algo[^"]*"[^>]*>(.*?)</li>',
981
+ resp.text, re.DOTALL | re.IGNORECASE,
982
+ )
983
+ for block in blocks:
984
+ # Title + link in <h2><a href="...">title</a></h2>
985
+ m = re.search(r'<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
986
+ if not m:
987
+ continue
988
+ href = html.unescape(m.group(1).strip())
989
+ title = re.sub(r'<[^>]+>', '', m.group(2)).strip()
990
+ if not title or not href.startswith("http"):
991
+ continue
992
+ # Snippet in <p> or <div class="b_caption">
993
+ snippet = ""
994
+ sm = re.search(
995
+ r'<(?:p|div)[^>]*class="[^"]*(?:b_caption|b_lineclamp)[^"]*"[^>]*>(.*?)</(?:p|div)>',
996
+ block, re.DOTALL | re.IGNORECASE,
997
+ )
998
+ if sm:
999
+ snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
1000
+ snippet = html.unescape(snippet)
1001
+ results.append({"title": title, "url": href, "snippet": snippet})
1002
+
1003
+ return results
1004
+
1005
+ def _search_baidu(self, query: str) -> list[dict[str, str]]:
1006
+ """Search Baidu (web scraping, no API key)."""
1007
+ import urllib.parse
1008
+ url = f"https://www.baidu.com/s?wd={urllib.parse.quote(query)}&ie=utf-8"
1009
+ headers = {
1010
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
1011
+ "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
1012
+ "Accept-Language": "zh-CN,zh;q=0.9",
1013
+ }
1014
+ resp = self.http.get(url, headers=headers)
1015
+ resp.raise_for_status()
1016
+
1017
+ results: list[dict[str, str]] = []
1018
+ # Baidu results: <div class="result c-container"> or <div class="c-container">
1019
+ blocks = re.findall(
1020
+ r'<div[^>]*class="[^"]*(?:result|c-container)[^"]*"[^>]*>(.*?)</div>\s*(?=<div[^>]*class="[^"]*(?:result|c-container)|$)',
1021
+ resp.text, re.DOTALL | re.IGNORECASE,
1022
+ )
1023
+ if not blocks:
1024
+ # Fallback: match h3 titles with links
1025
+ blocks = re.findall(
1026
+ r'<div[^>]*class="[^"]*c-container[^"]*"[^>]*>(.*?)</div>',
1027
+ resp.text, re.DOTALL | re.IGNORECASE,
1028
+ )
1029
+
1030
+ for block in blocks:
1031
+ m = re.search(r'<a[^>]*href="([^"]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
1032
+ if not m:
1033
+ continue
1034
+ href = html.unescape(m.group(1).strip())
1035
+ title = re.sub(r'<[^>]+>', '', m.group(2)).strip()
1036
+ if not title or not href.startswith("http"):
1037
+ continue
1038
+ snippet = ""
1039
+ sm = re.search(
1040
+ r'<(?:span|div|p)[^>]*class="[^"]*(?:content-right_[^"]*|c-abstract|content)[^"]*"[^>]*>(.*?)</(?:span|div|p)>',
1041
+ block, re.DOTALL | re.IGNORECASE,
1042
+ )
1043
+ if sm:
1044
+ snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
1045
+ snippet = html.unescape(snippet)
1046
+ results.append({"title": title, "url": href, "snippet": snippet})
1047
+
1048
+ return results
1049
+
1050
+ def _search_google(self, query: str) -> list[dict[str, str]]:
1051
+ """Search Google (web scraping, no API key)."""
1052
+ import urllib.parse
1053
+ url = f"https://www.google.com/search?q={urllib.parse.quote(query)}&hl=en"
1054
+ headers = {
1055
+ "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
1056
+ "(KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
1057
+ "Accept-Language": "en-US,en;q=0.9",
1058
+ }
1059
+ resp = self.http.get(url, headers=headers)
1060
+ resp.raise_for_status()
1061
+
1062
+ results: list[dict[str, str]] = []
1063
+ # Google results are in <div class="g"> or <div data-sokoban-container>
1064
+ blocks = re.findall(
1065
+ r'<(?:div|li)[^>]*\b(?:class="g\b|data-sokoban-container)[^>]*>(.*?)</(?:div|li)>',
1066
+ resp.text, re.DOTALL | re.IGNORECASE,
1067
+ )
1068
+ for block in blocks:
1069
+ # Title + link: <h3>...<a href="...">title</a></h3>
1070
+ m = re.search(r'<a[^>]*href="(/url\?q=|)([^"&]*)"[^>]*>(.*?)</a>', block, re.DOTALL)
1071
+ if not m:
1072
+ continue
1073
+ href = html.unescape(m.group(2).strip())
1074
+ if not href.startswith("http"):
1075
+ href = "https://www.google.com" + m.group(1) + m.group(2)
1076
+ title = re.sub(r'<[^>]+>', '', m.group(3)).strip()
1077
+ if not title:
1078
+ continue
1079
+ # Snippet: <span class="aCOpRe"> or various other classes
1080
+ snippet = ""
1081
+ sm = re.search(
1082
+ r'<(?:span|div)[^>]*\b(?:class="[^"]*(?:\baCOpRe\b|st\b)[^"]*")[^>]*>(.*?)</(?:span|div)>',
1083
+ block, re.DOTALL | re.IGNORECASE,
1084
+ )
1085
+ if sm:
1086
+ snippet = re.sub(r'<[^>]+>', '', sm.group(1)).strip()
1087
+ snippet = html.unescape(snippet)
1088
+ results.append({"title": title, "url": href, "snippet": snippet})
1089
+
1090
+ return results
1091
+
1092
@staticmethod
def _format_search_results(
    query: str, results: list[dict[str, str]], max_results: int, source: str
) -> ToolResult:
    """Render the first ``max_results`` search hits as a numbered list.

    Each hit contributes its HTML-unescaped title, its URL, an optional
    HTML-unescaped snippet and a blank separator line. ``source`` names
    the backend in the heading.
    """
    rendered = [f"Search results for: {query} (via {source})\n"]
    rank = 0
    for entry in results[:max_results]:
        rank += 1
        rendered.append(f"{rank}. **{html.unescape(entry['title'])}**")
        rendered.append(f" {entry['url']}")
        blurb = entry.get("snippet")
        if blurb:
            rendered.append(f" {html.unescape(blurb)}")
        rendered.append("")
    return ToolResult(success=True, output="\n".join(rendered))
1104
+
1105
+ @staticmethod
1106
+ def _parse_ddg_lite(html_text: str) -> list[dict[str, str]]:
1107
+ """Extract search results from DuckDuckGo Lite HTML."""
1108
+ results: list[dict[str, str]] = []
1109
+
1110
+ # DDG Lite: results are in <a> tags with class="result-link"
1111
+ # and snippets in <td class="result-snippet">
1112
+ link_pattern = re.compile(
1113
+ r'<a[^>]*href="([^"]*)"[^>]*class="[^"]*result-link[^"]*"[^>]*>(.*?)</a>',
1114
+ re.DOTALL | re.IGNORECASE,
1115
+ )
1116
+ snippet_pattern = re.compile(
1117
+ r'<td[^>]*class="[^"]*result-snippet[^"]*"[^>]*>(.*?)</td>',
1118
+ re.DOTALL | re.IGNORECASE,
1119
+ )
1120
+
1121
+ links = link_pattern.findall(html_text)
1122
+ snippets = snippet_pattern.findall(html_text)
1123
+
1124
+ for i, (href, title) in enumerate(links):
1125
+ href = html.unescape(href.strip())
1126
+ title = re.sub(r'<[^>]+>', '', title).strip()
1127
+ if not title:
1128
+ continue
1129
+
1130
+ # Pick corresponding snippet
1131
+ snippet = ""
1132
+ if i < len(snippets):
1133
+ snippet = re.sub(r'<[^>]+>', '', snippets[i])
1134
+ snippet = html.unescape(snippet.strip())
1135
+
1136
+ results.append({
1137
+ "title": title,
1138
+ "url": href,
1139
+ "snippet": snippet[:300],
1140
+ })
1141
+
1142
+ return results
1143
+
1144
async def _tool_web_fetch(self, url: str) -> ToolResult:
    """Download ``url`` and return its readable text.

    Only ``http://`` and ``https://`` URLs are accepted, and only
    ``text/html`` or ``text/plain`` responses are processed. The text is
    extracted with ``self._extract_text`` and cut to 15,000 characters.
    Network and HTTP errors come back as a failed ``ToolResult``.
    """

    def _failure(message: str) -> ToolResult:
        return ToolResult(success=False, output="", error=message)

    if not url.startswith(("http://", "https://")):
        return _failure("Invalid URL: must start with http:// or https://")

    try:
        # The shared client is synchronous, so the GET runs in a worker thread.
        resp = await self._run_in_thread(lambda: self.http.get(url))
        resp.raise_for_status()
    except httpx.TimeoutException:
        return _failure(f"Request timed out: {url}")
    except httpx.HTTPStatusError as e:
        return _failure(f"HTTP {e.response.status_code} for {url}")
    except Exception as e:
        return _failure(f"Fetch failed: {e}")

    content_type = resp.headers.get("content-type", "")
    is_supported = "text/html" in content_type or "text/plain" in content_type
    if not is_supported:
        return _failure(
            f"Cannot process content type: {content_type}. Only text/html and text/plain are supported."
        )

    text = self._extract_text(resp.text, url)

    # Keep the tool output bounded.
    MAX_CHARS = 15_000
    overflow = len(text) - MAX_CHARS
    if overflow > 0:
        notice = f"\n\n... [truncated {overflow:,} chars from {url}]"
        text = text[:MAX_CHARS] + notice

    return ToolResult(
        success=True,
        output=f"Content from: {url}\n\n{text}",
    )
1195
+
1196
+ # ── Sub-agent tools ──────────────────────────────────────────────────
1197
+
1198
async def _tool_spawn_subagent(self, task: str, skill: str = "",
                               model: str = "") -> ToolResult:
    """Start a sub-agent on ``task`` and report its id.

    ``skill`` is forwarded as the sub-agent's skill prompt and an empty
    ``model`` is forwarded as ``None``. Fails when no sub-agent manager is
    attached or when spawning raises ``RuntimeError``.
    """
    manager = self._sub_agent_mgr
    if not manager:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available. "
                  "Ensure agent_controller is used.",
        )

    try:
        # Clawd: SubagentStart
        get_clawd().subagent_start()

        agent_id = manager.spawn(
            task=task,
            skill_prompt=skill,
            model=model or None,
        )
        report = "".join([
            f"Sub-agent spawned: {agent_id}\n",
            "Status: running\n",
            f"Active sub-agents: {manager.active_count}\n\n",
            f"Use collect_subagent('{agent_id}') to retrieve results, ",
            "or list_subagents() to check all statuses.",
        ])
        return ToolResult(success=True, output=report)
    except RuntimeError as e:
        return ToolResult(
            success=False, output="",
            error=f"Cannot spawn sub-agent: {e}",
        )
1231
+
1232
async def _tool_collect_subagent(self, agent_id: str,
                                 timeout: float = 300.0) -> ToolResult:
    """Fetch the outcome of a previously spawned sub-agent.

    ``timeout`` is passed straight to the manager's ``collect`` call.
    Fails when no sub-agent manager is attached; otherwise the result
    mirrors the sub-agent's own success flag.
    """
    manager = self._sub_agent_mgr
    if not manager:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available.",
        )

    # NOTE(review): collect() is called synchronously here; if it waits for
    # the sub-agent it holds up the event loop for up to ``timeout`` seconds.
    outcome = manager.collect(agent_id, timeout=timeout)

    # Clawd: SubagentStop
    get_clawd().subagent_stop()

    if not outcome.success:
        return ToolResult(
            success=False,
            output=f"Sub-agent {agent_id} failed: {outcome.error}",
            error=outcome.error,
        )

    summary = "\n".join([
        f"Sub-agent {agent_id} completed successfully.",
        f"Tool calls: {outcome.tool_call_count}",
        "",
        "Result:",
        outcome.result or "(empty)",
    ])
    return ToolResult(success=True, output=summary)
1260
+
1261
async def _tool_list_subagents(self) -> ToolResult:
    """Report every known sub-agent with its status and tool-call count."""
    manager = self._sub_agent_mgr
    if not manager:
        return ToolResult(
            success=False, output="",
            error="SubAgentManager not available.",
        )

    agents = manager.list_all()
    if not agents:
        return ToolResult(success=True, output="No sub-agents.")

    # Status name -> icon; anything unrecognised gets the question mark.
    icons = {"running": "🔄", "done": "✅",
             "failed": "❌", "cancelled": "⏹️"}
    report = [f"Sub-agents ({len(agents)} total):", ""]
    report.extend(
        f" {icons.get(agent.status, '❓')} {agent.id} — {agent.status} "
        f"(tool_calls={agent.tool_call_count})"
        for agent in agents
    )
    return ToolResult(success=True, output="\n".join(report))
1281
+
1282
async def _tool_mcp_search(self, query: str, type: str = "all") -> ToolResult:
    """Look up MCP tools and/or resources matching ``query``.

    ``type`` selects what is searched: ``"tools"``, ``"resources"`` or
    ``"all"``; each search is limited to 20 hits. Fails when MCP is not
    configured; otherwise succeeds, including when nothing is connected
    or nothing matches.
    """
    mcp = self._mcp
    if not mcp:
        return ToolResult(
            success=False, output="",
            error="MCP not configured. Add MCP servers via --mcp-config.",
        )

    servers = mcp.connected_servers
    if not servers:
        return ToolResult(success=True, output="No MCP servers connected.")

    def _render(label: str, server: str, description: str) -> list[str]:
        # One bullet line, plus the description line when there is one.
        rows = [f" ● {label} @{server}"]
        if description:
            rows.append(f" {description}")
        return rows

    report = [f"MCP search results for '{query}' across {len(servers)} server(s):", ""]
    hit_count = 0

    if type in ("tools", "all"):
        tools = mcp.search_tools(query, limit=20)
        if not tools:
            report.append(" Tools: none found")
        else:
            report.append(f" Tools ({len(tools)}):")
            for tool in tools:
                report += _render(
                    tool.get("name", "?"),
                    tool.get("_mcp_server", "?"),
                    (tool.get("description") or "")[:100],
                )
            hit_count += len(tools)

    if type in ("resources", "all"):
        resources = mcp.search_resources(query, limit=20)
        if not resources:
            report.append("\n Resources: none found")
        else:
            report.append(f"\n Resources ({len(resources)}):")
            for resource in resources:
                # Prefer the human-readable name; fall back to the URI.
                label = resource.get("name", "") or resource.get("uri", "?")
                report += _render(
                    label,
                    resource.get("_mcp_server", "?"),
                    (resource.get("description") or "")[:80],
                )
            hit_count += len(resources)

    if hit_count == 0:
        return ToolResult(
            success=True,
            output=f"No MCP tools or resources found matching '{query}'.\n"
                   f"Connected servers: {', '.join(servers)}.",
        )

    return ToolResult(success=True, output="\n".join(report))
1337
+
1338
async def _tool_analyze_image(self, image_path: str, prompt: str = "Describe this image in detail.") -> ToolResult:
    """Analyze an image using a multimodal vision model.

    Uses the configured vision model, falling back to the main LLM config.
    Configure via ~/.ata_coder/settings.json:
    {"vision": {"model": "...", "api_base": "...", "api_key": "..."}}
    Or env vars: VISION_MODEL, VISION_API_BASE, VISION_API_KEY.

    Args:
        image_path: Path of the image file (.png, .jpg, .jpeg, .gif,
            .webp or .bmp).
        prompt: Instruction sent to the model together with the image.

    Returns:
        A ``ToolResult`` carrying the model's reply, or a failure that
        describes the missing file, unsupported format, missing
        configuration or API error.
    """
    import base64
    import json as _json
    from pathlib import Path
    from urllib.error import HTTPError
    from urllib.request import Request, urlopen

    img_path = Path(image_path)
    if not img_path.exists():
        return ToolResult(
            success=False, output="",
            error=f"Image not found: {image_path}",
        )

    ext = img_path.suffix.lower()
    supported = {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
    if ext not in supported:
        return ToolResult(
            success=False, output="",
            error=f"Unsupported image format: {ext}. Supported: {', '.join(sorted(supported))}",
        )

    try:
        with open(img_path, "rb") as f:
            img_b64 = base64.standard_b64encode(f.read()).decode("ascii")
    except OSError as e:
        return ToolResult(success=False, output="", error=f"Failed to read image: {e}")

    # ── Resolve vision config ──
    # Priority: env var > settings.json > main api config
    from .settings import get_settings
    settings = get_settings()

    # API key: VISION_API_KEY env > settings.json vision.api_key > main api key
    api_key = (
        os.environ.get("VISION_API_KEY", "")
        or settings.vision_api_key
        or os.environ.get("ATA_CODER_API_KEY", "")
        or os.environ.get("OPENAI_API_KEY", "")
        or settings.api_key
    )
    if not api_key:
        return ToolResult(
            success=False, output="",
            error="No API key configured. Set ATA_CODER_API_KEY or add vision.api_key in ~/.ata_coder/settings.json.",
        )

    # API base: VISION_API_BASE env > settings.json vision.api_base > main base_url
    api_base = (
        os.environ.get("VISION_API_BASE", "")
        or settings.vision_api_base
        or os.environ.get("ATA_CODER_BASE_URL", "")
        or os.environ.get("OPENAI_BASE_URL", "")
        or settings.api_base_url
    )
    if not api_base:
        # Without a base the request URL would be just "/chat/completions",
        # which urllib rejects with an unhelpful "unknown url type" error.
        return ToolResult(
            success=False, output="",
            error="No API base URL configured. Set VISION_API_BASE or add vision.api_base in ~/.ata_coder/settings.json.",
        )

    # Model: VISION_MODEL env > settings.json vision.model > main model
    model = (
        os.environ.get("VISION_MODEL", "")
        or settings.vision_model
        or os.environ.get("ATA_CODER_DEFAULT_MODEL", "")
        or os.environ.get("OPENAI_MODEL", "")
        or settings.default_model
    )

    # ".jpg" maps to image/jpeg; every other supported suffix is already
    # the MIME subtype.
    mime = "image/" + ("jpeg" if ext == ".jpg" else ext[1:])
    body = {
        "model": model,
        "messages": [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": f"data:{mime};base64,{img_b64}",
                    "detail": "auto"
                }},
            ]
        }],
        "max_tokens": 2048,
        "temperature": 0.3,
    }

    def _call_vision_api() -> dict:
        # Blocking HTTP round trip (up to 120 s); it runs in a worker
        # thread so the event loop stays responsive.
        req = Request(
            f"{api_base.rstrip('/')}/chat/completions",
            data=_json.dumps(body).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {api_key}",
            },
        )
        with urlopen(req, timeout=120) as resp:
            return _json.loads(resp.read().decode("utf-8"))

    try:
        result = await self._run_in_thread(_call_vision_api)
        # Tolerate a missing or empty "choices" list instead of raising
        # IndexError on [0].
        choices = result.get("choices") or [{}]
        content = (
            choices[0].get("message", {}).get("content")
            or "(no response)"
        )
        usage = result.get("usage", {})
        tokens = usage.get("total_tokens", "?")
        return ToolResult(
            success=True,
            output=f"[Vision: {model} | {tokens} tokens]\n\n{content}",
        )
    except HTTPError as e:
        error_body = e.read().decode("utf-8", errors="replace")[:300]
        return ToolResult(
            success=False, output="",
            error=f"Vision API error {e.code}: {error_body}",
        )
    except Exception as e:
        # Network failure, malformed JSON or an unexpected response shape:
        # report it rather than letting the tool call crash.
        return ToolResult(
            success=False, output="",
            error=f"Vision API call failed: {e}",
        )
1462
+
1463
+ @staticmethod
1464
+ def _extract_text(html_text: str, url: str = "") -> str:
1465
+ """Strip HTML down to readable text."""
1466
+
1467
+ class _TextExtractor(html.parser.HTMLParser):
1468
+ def __init__(self):
1469
+ super().__init__()
1470
+ self.parts: list[str] = []
1471
+ self._skip_count = 0 # counter for nested skip-tags
1472
+ self._skip_tags = {"script", "style", "noscript", "iframe",
1473
+ "nav", "footer", "header", "aside"}
1474
+ self._block_tags = {"div", "p", "h1", "h2", "h3", "h4", "h5",
1475
+ "h6", "li", "tr", "section", "article",
1476
+ "pre", "blockquote", "table", "ul", "ol",
1477
+ "dl", "br", "hr"}
1478
+
1479
+ def handle_starttag(self, tag, attrs):
1480
+ tag = tag.lower()
1481
+ if tag in self._skip_tags:
1482
+ self._skip_count += 1
1483
+ elif tag in self._block_tags:
1484
+ self.parts.append("\n")
1485
+
1486
+ def handle_endtag(self, tag):
1487
+ tag = tag.lower()
1488
+ if tag in self._skip_tags and self._skip_count > 0:
1489
+ self._skip_count -= 1
1490
+ elif tag in self._block_tags:
1491
+ self.parts.append("\n")
1492
+
1493
+ def handle_data(self, data):
1494
+ if self._skip_count == 0:
1495
+ text = data.strip()
1496
+ if text:
1497
+ self.parts.append(text + " ")
1498
+
1499
+ try:
1500
+ extractor = _TextExtractor()
1501
+ extractor.feed(html_text)
1502
+ raw = "".join(extractor.parts)
1503
+ except Exception:
1504
+ # Fallback: regex strip
1505
+ raw = re.sub(r'<script[^>]*>.*?</script>', '', html_text, flags=re.DOTALL | re.IGNORECASE)
1506
+ raw = re.sub(r'<style[^>]*>.*?</style>', '', raw, flags=re.DOTALL | re.IGNORECASE)
1507
+ raw = re.sub(r'<[^>]+>', ' ', raw)
1508
+ raw = html.unescape(raw)
1509
+
1510
+ # Collapse whitespace
1511
+ raw = re.sub(r'[ \t]+', ' ', raw)
1512
+ raw = re.sub(r'\n{3,}', '\n\n', raw)
1513
+ return raw.strip()
1514
+
1515
+
1516
+ # ── Factory ──────────────────────────────────────────────────────────────────
1517
+
1518
def create_tool_executor(workspace_dir: str | None = None) -> ToolExecutor:
    """Build a ``ToolExecutor``, optionally bound to ``workspace_dir``.

    With a non-empty ``workspace_dir`` the executor is given an
    ``AgentConfig`` for that directory; otherwise it is constructed with
    no arguments.
    """
    from .config import AgentConfig

    if not workspace_dir:
        return ToolExecutor()
    return ToolExecutor(AgentConfig(workspace_dir=workspace_dir))
1525
+
1526
+
1527
+ # ── CST Renamers (for rename_symbol tool) ───────────────────────────────────
1528
+ # These use libcst to safely rename symbols without touching strings or comments.
1529
+
1530
+ class _SymbolRenamer:
1531
+ """Base class for CST symbol renamers — shared leave_Name/leave_Call logic."""
1532
+
1533
+ def __init__(self, old_name: str, new_name: str):
1534
+ self.old = old_name
1535
+ self.new = new_name
1536
+ self.changes = 0
1537
+
1538
+ def leave_Name(self, original_node, updated_node):
1539
+ import libcst as cst
1540
+ if isinstance(updated_node, cst.Name) and updated_node.value == self.old:
1541
+ self.changes += 1
1542
+ return updated_node.with_changes(value=self.new)
1543
+ return updated_node
1544
+
1545
+ def leave_Call(self, original_node, updated_node):
1546
+ import libcst as cst
1547
+ if isinstance(updated_node.func, cst.Name) and updated_node.func.value == self.old:
1548
+ self.changes += 1
1549
+ return updated_node.with_changes(func=cst.Name(value=self.new))
1550
+ return updated_node
1551
+
1552
+
1553
class _VariableRenamer(_SymbolRenamer):
    """Variable renamer: adds nothing to the inherited name/call rewriting."""
1555
+
1556
+
1557
class _FunctionRenamer(_SymbolRenamer):
    """Renames function/method definitions in addition to names and calls."""

    def leave_FunctionDef(self, original_node, updated_node):
        """Rewrite a function definition named after the old symbol."""
        import libcst as cst

        if updated_node.name.value != self.old:
            return updated_node
        self.changes += 1
        return updated_node.with_changes(name=cst.Name(value=self.new))
1566
+
1567
+
1568
class _ClassRenamer(_SymbolRenamer):
    """Renames class definitions in addition to names and calls."""

    def leave_ClassDef(self, original_node, updated_node):
        """Rewrite a class definition named after the old symbol."""
        import libcst as cst

        if updated_node.name.value != self.old:
            return updated_node
        self.changes += 1
        return updated_node.with_changes(name=cst.Name(value=self.new))