ata-coder 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ata_coder/__init__.py +1 -0
- ata_coder/agent.py +874 -0
- ata_coder/agent_compact.py +190 -0
- ata_coder/agent_controller.py +218 -0
- ata_coder/agent_extension.py +69 -0
- ata_coder/agent_routing.py +105 -0
- ata_coder/agent_subsystems.py +72 -0
- ata_coder/agent_tools.py +318 -0
- ata_coder/agent_undo.py +63 -0
- ata_coder/anthropic_client.py +465 -0
- ata_coder/change_tracker.py +368 -0
- ata_coder/clawd_integration.py +574 -0
- ata_coder/commands/__init__.py +128 -0
- ata_coder/commands/_core.py +184 -0
- ata_coder/commands/_safety.py +95 -0
- ata_coder/commands/_settings.py +241 -0
- ata_coder/commands/_workflow.py +451 -0
- ata_coder/commands.py +974 -0
- ata_coder/config.py +257 -0
- ata_coder/core/__init__.py +35 -0
- ata_coder/core/events.py +73 -0
- ata_coder/core/queue.py +85 -0
- ata_coder/core/state.py +17 -0
- ata_coder/event_queue.py +5 -0
- ata_coder/extension.py +654 -0
- ata_coder/extensions/__init__.py +1 -0
- ata_coder/extensions/hello_skill.py +47 -0
- ata_coder/fool_proof.py +295 -0
- ata_coder/git_workflow.py +371 -0
- ata_coder/gui.py +511 -0
- ata_coder/llm_client.py +543 -0
- ata_coder/main.py +814 -0
- ata_coder/mcp_client.py +1095 -0
- ata_coder/memory.py +539 -0
- ata_coder/model_registry.py +134 -0
- ata_coder/model_router.py +105 -0
- ata_coder/permissions.py +274 -0
- ata_coder/privilege.py +464 -0
- ata_coder/project.py +273 -0
- ata_coder/prompt_template.py +423 -0
- ata_coder/prompts/auto-mode.md +7 -0
- ata_coder/prompts/coding-rules.md +40 -0
- ata_coder/prompts/execution-guardrails.md +14 -0
- ata_coder/prompts/memory-system.md +24 -0
- ata_coder/prompts/output-style.md +23 -0
- ata_coder/prompts/safety.md +17 -0
- ata_coder/prompts/slash-commands.md +24 -0
- ata_coder/prompts/sub-agents.md +38 -0
- ata_coder/prompts/system-reminders.md +17 -0
- ata_coder/prompts/system.md +105 -0
- ata_coder/prompts/tool-policy.md +46 -0
- ata_coder/repl_theme.py +99 -0
- ata_coder/repl_tracker.py +89 -0
- ata_coder/repl_ui.py +1214 -0
- ata_coder/safety_guard.py +434 -0
- ata_coder/self_correct.py +346 -0
- ata_coder/server.py +882 -0
- ata_coder/server_session.py +159 -0
- ata_coder/server_shell.py +129 -0
- ata_coder/session.py +431 -0
- ata_coder/settings.py +439 -0
- ata_coder/setup_wizard.py +136 -0
- ata_coder/skill_extension.py +92 -0
- ata_coder/skills/architect/SKILL.md +42 -0
- ata_coder/skills/code-reviewer/SKILL.md +37 -0
- ata_coder/skills/codecraft/SKILL.md +452 -0
- ata_coder/skills/debugger/SKILL.md +45 -0
- ata_coder/skills/doc-writer/SKILL.md +36 -0
- ata_coder/skills/general-coder/SKILL.md +76 -0
- ata_coder/skills/math-calculator/README.md +40 -0
- ata_coder/skills/math-calculator/SKILL.md +59 -0
- ata_coder/skills/math-calculator/handler.py +103 -0
- ata_coder/skills/math-calculator/prompts/system.md +8 -0
- ata_coder/skills/math-calculator/requirements.txt +2 -0
- ata_coder/skills/math-calculator/resources/constants.json +8 -0
- ata_coder/skills/math-calculator/tests/test_handler.py +53 -0
- ata_coder/skills/security-auditor/SKILL.md +40 -0
- ata_coder/skills/test-writer/SKILL.md +36 -0
- ata_coder/skills/weather-skill/README.md +45 -0
- ata_coder/skills/weather-skill/handler.py +76 -0
- ata_coder/skills/weather-skill/manifest.json +48 -0
- ata_coder/skills/weather-skill/prompts/system_prompt.txt +9 -0
- ata_coder/skills/weather-skill/prompts/user_prompt_template.txt +3 -0
- ata_coder/skills/weather-skill/requirements.txt +1 -0
- ata_coder/skills/weather-skill/resources/city_list.json +17 -0
- ata_coder/skills/weather-skill/resources/error_messages.json +7 -0
- ata_coder/skills/weather-skill/tests/test_handler.py +28 -0
- ata_coder/skills/weather-skill/weather_utils.py +50 -0
- ata_coder/skills.py +1014 -0
- ata_coder/sub_agent.py +273 -0
- ata_coder/sub_agent_manager.py +203 -0
- ata_coder/system_prompt_builder.py +146 -0
- ata_coder/task_planner.py +391 -0
- ata_coder/terminal.py +318 -0
- ata_coder/test_runner.py +219 -0
- ata_coder/thread_supervisor.py +195 -0
- ata_coder/tool_defs.py +335 -0
- ata_coder/tools/__init__.py +11 -0
- ata_coder/tools/definitions.py +335 -0
- ata_coder/tools/executor.py +1036 -0
- ata_coder/tools/result.py +26 -0
- ata_coder/tools/subagent.py +332 -0
- ata_coder/tools/web.py +361 -0
- ata_coder/tools.py +1576 -0
- ata_coder/types.py +92 -0
- ata_coder/utils.py +113 -0
- ata_coder/web/css/style.css +180 -0
- ata_coder/web/index.html +84 -0
- ata_coder/web/js/app.js +489 -0
- ata_coder/web/package-lock.json +25 -0
- ata_coder/web/package.json +10 -0
- ata_coder/web/tsconfig.json +13 -0
- ata_coder-2.4.2.dist-info/METADATA +799 -0
- ata_coder-2.4.2.dist-info/RECORD +118 -0
- ata_coder-2.4.2.dist-info/WHEEL +5 -0
- ata_coder-2.4.2.dist-info/entry_points.txt +2 -0
- ata_coder-2.4.2.dist-info/licenses/LICENSE +21 -0
- ata_coder-2.4.2.dist-info/top_level.txt +1 -0
ata_coder/terminal.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Terminal color & formatting — unified across all outputs.
|
|
3
|
+
|
|
4
|
+
Auto-detects capabilities and provides a single API for colored output.
|
|
5
|
+
Works on Windows (via colorama), Linux, macOS (native ANSI).
|
|
6
|
+
Falls back gracefully when color is not available.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
import sys
|
|
11
|
+
|
|
12
|
+
# ── Try to import color libraries ────────────────────────────────────────────
|
|
13
|
+
|
|
14
|
+
HAS_RICH = False
|
|
15
|
+
HAS_COLORAMA = False
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from rich.console import Console
|
|
19
|
+
from rich.text import Text
|
|
20
|
+
from rich.style import Style
|
|
21
|
+
from rich.theme import Theme
|
|
22
|
+
HAS_RICH = True
|
|
23
|
+
except ImportError:
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
from colorama import init, Fore, Back, Style as CStyle
|
|
28
|
+
init()
|
|
29
|
+
HAS_COLORAMA = True
|
|
30
|
+
except ImportError:
|
|
31
|
+
pass
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
35
|
+
# Color registry
|
|
36
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
37
|
+
|
|
38
|
+
class Ansi:
    """Raw ANSI SGR escape sequences, usable without any third-party library."""

    # Text attributes
    RESET = "\033[0m"
    BOLD = "\033[1m"
    DIM = "\033[2m"
    ITALIC = "\033[3m"
    UNDERLINE = "\033[4m"

    # Standard foreground colours (SGR 30-37)
    BLACK = "\033[30m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"

    # Bright foreground colours (SGR 90-97); GRAY is "bright black"
    GRAY = "\033[90m"
    BRIGHT_RED = "\033[91m"
    BRIGHT_GREEN = "\033[92m"
    BRIGHT_YELLOW = "\033[93m"
    BRIGHT_BLUE = "\033[94m"
    BRIGHT_MAGENTA = "\033[95m"
    BRIGHT_CYAN = "\033[96m"
    BRIGHT_WHITE = "\033[97m"
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
69
|
+
# Semantic color tokens
|
|
70
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
71
|
+
|
|
72
|
+
# Maps semantic names → ANSI codes (can be overridden by Rich theme)
|
|
73
|
+
ANSI_THEME = {
    # -- Outcome of an operation --------------------------------------------
    "ok": Ansi.GREEN,
    "fail": Ansi.RED,
    "warn": Ansi.YELLOW,
    "info": Ansi.CYAN,
    "debug": Ansi.DIM,

    # -- Risk level ---------------------------------------------------------
    "critical": Ansi.BRIGHT_RED + Ansi.BOLD,
    "danger": Ansi.BRIGHT_RED,
    "caution": Ansi.BRIGHT_YELLOW,
    "safe": Ansi.GREEN,

    # -- Kind of entity being displayed -------------------------------------
    "tool": Ansi.CYAN,
    "file": Ansi.BLUE,
    "cmd": Ansi.MAGENTA,
    "model": Ansi.BRIGHT_MAGENTA,
    "skill": Ansi.YELLOW,
    "memory": Ansi.GREEN,
    "git": Ansi.BRIGHT_RED,

    # -- Generic UI chrome --------------------------------------------------
    "prompt": Ansi.BRIGHT_CYAN + Ansi.BOLD,
    "heading": Ansi.BOLD + Ansi.BRIGHT_CYAN,
    "border": Ansi.GRAY,
    "dim": Ansi.DIM,
    "bold": Ansi.BOLD,
    "reset": Ansi.RESET,

    # -- Unified diff rendering ---------------------------------------------
    "diff_add": Ansi.GREEN,
    "diff_del": Ansi.RED,
    "diff_hdr": Ansi.CYAN + Ansi.BOLD,
    "diff_ctx": Ansi.DIM,

    # -- Token budget usage bands -------------------------------------------
    "token_low": Ansi.GREEN,
    "token_mid": Ansi.YELLOW,
    "token_high": Ansi.BRIGHT_RED,
}
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
118
|
+
# Terminal capabilities
|
|
119
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
120
|
+
|
|
121
|
+
def _detect_color_support() -> bool:
    """Decide whether ANSI colour sequences should be emitted.

    Checks, in order:

    1. ``NO_COLOR`` set to a non-empty value -> colour off.
    2. ``FORCE_COLOR`` set to a non-empty value -> colour on.
    3. stdout is not an interactive terminal -> colour only when ``CI`` or
       ``GITHUB_ACTIONS`` is set.
    4. ``TERM=dumb`` -> colour off.
    5. On Windows -> colour when colorama was initialised, or when running
       under Windows Terminal (``WT_SESSION``) or ``TERM=xterm-256color``.
    6. Any other interactive terminal -> colour on.

    Returns:
        True if styled output should be produced, False otherwise.
    """
    if os.environ.get("NO_COLOR"):
        return False
    if os.environ.get("FORCE_COLOR"):
        return True

    # sys.stdout may be None or replaced by an object without isatty();
    # treat both as "not interactive" instead of raising at import time.
    isatty = getattr(sys.stdout, "isatty", None)
    try:
        interactive = bool(isatty()) if callable(isatty) else False
    except (ValueError, OSError):
        # isatty() raises ValueError on a closed stream.
        interactive = False

    if not interactive:
        # Piped output: only colourise for CI systems.
        return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))

    if os.environ.get("TERM") == "dumb":
        return False

    if sys.platform == "win32":
        # colorama's init() translates ANSI sequences for the Windows console,
        # so colour is safe whenever it was imported successfully.
        return (
            HAS_COLORAMA
            or "WT_SESSION" in os.environ
            or os.environ.get("TERM") == "xterm-256color"
        )
    return True


_COLOR_ENABLED = _detect_color_support()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def color_enabled() -> bool:
    """Return True when styled (ANSI) output is currently switched on."""
    return _COLOR_ENABLED


def enable_color():
    """Switch styled output on, overriding auto-detection."""
    global _COLOR_ENABLED, _rich_console
    _COLOR_ENABLED = True
    # The cached Rich console was built with the previous setting; drop it so
    # the next get_rich_console() call rebuilds it with the new one.
    _rich_console = None


def disable_color():
    """Switch styled output off, overriding auto-detection."""
    global _COLOR_ENABLED, _rich_console
    _COLOR_ENABLED = False
    # See enable_color(): the cached console must not outlive the old setting.
    _rich_console = None
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
156
|
+
# Public API
|
|
157
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
158
|
+
|
|
159
|
+
def style(text: str, token: str = "") -> str:
    """Wrap *text* in the ANSI sequence registered for *token*.

    Args:
        text: The string to decorate.
        token: A key of ``ANSI_THEME`` (for example ``"ok"`` or ``"heading"``).

    Returns:
        ``text`` followed by a reset sequence and preceded by the token's
        escape code, or ``text`` unchanged when colour is disabled or the
        token is unknown or maps to an empty code.
    """
    sequence = ANSI_THEME.get(token, "") if _COLOR_ENABLED else ""
    return f"{sequence}{text}{Ansi.RESET}" if sequence else text
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def ok(text: str) -> str:
    """Style *text* with the ``ok`` token."""
    return style(text, "ok")


def fail(text: str) -> str:
    """Style *text* with the ``fail`` token."""
    return style(text, "fail")


def warn(text: str) -> str:
    """Style *text* with the ``warn`` token."""
    return style(text, "warn")


def info(text: str) -> str:
    """Style *text* with the ``info`` token."""
    return style(text, "info")


def dim(text: str) -> str:
    """Style *text* with the ``dim`` token."""
    return style(text, "dim")


def bold(text: str) -> str:
    """Style *text* with the ``bold`` token."""
    return style(text, "bold")


def heading(text: str) -> str:
    """Style *text* with the ``heading`` token."""
    return style(text, "heading")


def tool(text: str) -> str:
    """Style *text* with the ``tool`` token."""
    return style(text, "tool")


def file(text: str) -> str:
    """Style *text* with the ``file`` token."""
    return style(text, "file")


def cmd(text: str) -> str:
    """Style *text* with the ``cmd`` token."""
    return style(text, "cmd")
|
|
179
|
+
|
|
180
|
+
def diff_add(text: str) -> str:
    """Style *text* with the ``diff_add`` token."""
    return style(text, "diff_add")


def diff_del(text: str) -> str:
    """Style *text* with the ``diff_del`` token."""
    return style(text, "diff_del")


def diff_hdr(text: str) -> str:
    """Style *text* with the ``diff_hdr`` token."""
    return style(text, "diff_hdr")
|
|
183
|
+
|
|
184
|
+
def critical(text: str) -> str:
    """Style *text* with the ``critical`` token."""
    return style(text, "critical")


def danger(text: str) -> str:
    """Style *text* with the ``danger`` token."""
    return style(text, "danger")


def safe(text: str) -> str:
    """Style *text* with the ``safe`` token."""
    return style(text, "safe")
|
|
187
|
+
|
|
188
|
+
def token_bar(pct: float, width: int = 20) -> str:
    """Render a token usage bar of exactly *width* cells.

    Args:
        pct: Usage percentage. Values outside 0-100 are accepted; the filled
            part of the bar is clamped to the bar width.
        width: Number of cells in the bar.

    Returns:
        With colour disabled, just the bar characters. With colour enabled,
        the bar coloured green below 50%, yellow below 80% and bright red
        otherwise, followed by the rounded percentage.
    """
    # Clamp so an out-of-range percentage can never make the bar longer than
    # `width` (a negative multiplier silently yields an empty string).
    filled = max(0, min(width, int(pct / 100 * width)))
    empty = width - filled

    if not _COLOR_ENABLED:
        return "█" * filled + "░" * empty

    if pct < 50:
        color = Ansi.GREEN
    elif pct < 80:
        color = Ansi.YELLOW
    else:
        color = Ansi.BRIGHT_RED
    return f"{color}{'█' * filled}{Ansi.DIM}{'░' * empty}{Ansi.RESET} {pct:.0f}%"
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
204
|
+
# Rich console (when available)
|
|
205
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
206
|
+
|
|
207
|
+
# Lazily created singleton; stays None until first requested (or forever when
# Rich is not installed).
_rich_console: "Console | None" = None


def get_rich_console() -> "Console | None":
    """Return the shared Rich ``Console``, creating it on first use.

    Returns:
        The cached console, or ``None`` when Rich could not be imported.
        A newly created console is built with ``force_terminal`` set to the
        colour flag's value at that moment.
    """
    global _rich_console
    if HAS_RICH and _rich_console is None:
        _rich_console = Console(force_terminal=_COLOR_ENABLED)
    return _rich_console if HAS_RICH else None
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def rich_print(*args, **kwargs):
    """Print via Rich when it is available, otherwise via the built-in ``print``.

    Args:
        *args: Objects to print.
        **kwargs: Keyword options. With Rich they are forwarded unchanged to
            ``Console.print``. In the plain fallback only the options that
            ``print`` understands (``sep``, ``end``, ``file``, ``flush``) are
            forwarded; Rich-only options such as ``style`` are dropped so the
            same call works whether or not Rich is installed.
    """
    console = get_rich_console()
    if console:
        console.print(*args, **kwargs)
        return

    # Built-in print() raises TypeError on unknown keywords, so filter them.
    plain_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key in ("sep", "end", "file", "flush")
    }
    print(*args, **plain_kwargs)
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
230
|
+
# Convenience printers
|
|
231
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
232
|
+
|
|
233
|
+
def print_ok(msg: str):
    """Print a success line: an ``[OK]`` badge followed by the dimmed message."""
    badge = ok('[OK]')
    print(f" {badge} {dim(msg)}")


def print_fail(msg: str):
    """Print a failure line: a ``[FAIL]`` badge followed by the message."""
    badge = fail('[FAIL]')
    print(f" {badge} {msg}")


def print_warn(msg: str):
    """Print a warning line: a ``[WARN]`` badge followed by the message."""
    badge = warn('[WARN]')
    print(f" {badge} {msg}")


def print_info(msg: str):
    """Print an informational line: an ``[i]`` badge and the dimmed message."""
    badge = info('[i]')
    print(f" {badge} {dim(msg)}")


def print_tool(name: str, args: str = ""):
    """Print a tool name, followed by its dimmed arguments when given."""
    if args:
        line = f" {tool(name)} {dim(args)}"
    else:
        line = f" {tool(name)}"
    print(line)


def print_file(path: str):
    """Print a file path using the ``file`` style."""
    styled_path = file(path)
    print(f" {styled_path}")


def print_heading(text: str):
    """Print a heading preceded by a blank line."""
    styled_heading = heading(text)
    print(f"\n{styled_heading}")


def print_separator(char: str = "─", width: int = 60):
    """Print a dimmed horizontal rule made of *char* repeated *width* times."""
    rule = char * width
    print(dim(rule))
|
|
256
|
+
|
|
257
|
+
def print_diff(old: str, new: str, filepath: str = ""):
    """Print a coloured unified diff between two texts.

    Args:
        old: Original text.
        new: Updated text.
        filepath: Path shown in the ``a/`` and ``b/`` file header lines.

    Nothing is printed when the texts are identical, because the underlying
    diff generator yields no lines in that case.
    """
    import difflib

    diff = difflib.unified_diff(
        old.splitlines(keepends=True),
        new.splitlines(keepends=True),
        fromfile=f"a/{filepath}", tofile=f"b/{filepath}",
    )

    # The "---"/"+++" file headers only occur before the first "@@" hunk
    # header. Tracking that position keeps a removed line whose content starts
    # with "--" (rendered as "---...") from being mistaken for a file header.
    in_hunk = False
    for raw in diff:
        line = raw.rstrip("\n")
        if line.startswith("@@"):
            in_hunk = True
            print(diff_hdr(line))
        elif not in_hunk:
            print(dim(line))
        elif line.startswith("+"):
            print(diff_add(line))
        elif line.startswith("-"):
            print(diff_del(line))
        else:
            print(dim(line))
|
|
277
|
+
|
|
278
|
+
def print_banner(title: str, width: int = 60):
    """Print *title* centred inside a double-line box.

    Args:
        title: Text to display.
        width: Total box width, borders included. When the title is longer
            than the interior, the box grows to fit it so the borders stay
            aligned.
    """
    # Interior width; never narrower than the title so padding stays >= 0.
    inner = max(width - 2, len(title))
    left = (inner - len(title)) // 2
    right = inner - left - len(title)

    print()
    print(bold("╔" + "═" * inner + "╗"))
    print(bold("║") + " " * left + title + " " * right + bold("║"))
    print(bold("╚" + "═" * inner + "╝"))
    print()
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
289
|
+
# Status line
|
|
290
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
291
|
+
|
|
292
|
+
def status_line(tokens: int = 0, max_tokens: int = 0, tools: int = 0,
                max_tools: int = 0, cost: float = 0, elapsed: float = 0,
                git_status: str = "", dangerous: bool = False) -> str:
    """Build a one-line status summary; falsy fields are omitted.

    Args:
        tokens: Tokens used so far.
        max_tokens: Token limit; 0 means the limit is unknown.
        tools: Tool calls made so far.
        max_tools: Tool call limit; 0 means the limit is unknown.
        cost: Accumulated cost, shown with four decimals after a ``$``.
        elapsed: Elapsed time in seconds, shown rounded.
        git_status: Free-form git status text.
        dangerous: When True, a ``[DANGEROUS]`` marker is put first.

    Returns:
        The populated segments joined with ``" | "``; an empty string when
        nothing is set.
    """
    parts = []
    if dangerous:
        parts.append(danger("[DANGEROUS]"))

    if tokens:
        if max_tokens:
            # Clamp to 0-100 so the colour band and bar stay meaningful.
            pct = max(0, min(100, tokens / max_tokens * 100))
            color = "token_low" if pct < 50 else ("token_mid" if pct < 80 else "token_high")
            parts.append(f"tokens: {style(f'{tokens:,}/{max_tokens:,}', color)}")
            parts.append(token_bar(pct, 12))
        else:
            # No known limit: show the raw count rather than "n/0" and a 0% bar.
            parts.append(f"tokens: {style(f'{tokens:,}', 'token_low')}")

    if tools:
        usage = f"{tools}/{max_tools}" if max_tools else f"{tools}"
        parts.append(f"tools: {dim(usage)}")

    if cost:
        parts.append(f"cost: {ok(f'${cost:.4f}')}")

    if elapsed:
        parts.append(f"time: {dim(f'{elapsed:.0f}s')}")

    if git_status:
        parts.append(f"git: {style(git_status, 'git')}")

    return " | ".join(parts)
|
ata_coder/test_runner.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test runner + auto-fix loop. Detects test framework, runs tests,
|
|
3
|
+
parses failures, feeds errors back to agent for automatic fixing.
|
|
4
|
+
|
|
5
|
+
Supports: pytest, unittest, jest, vitest, mocha, go test, cargo test, phpunit, rspec.
|
|
6
|
+
|
|
7
|
+
Commands (added to registry):
|
|
8
|
+
/test — Run tests in current project
|
|
9
|
+
/test-fix — Run tests, auto-fix failures up to 3 times
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
import os
|
|
14
|
+
import re
|
|
15
|
+
import subprocess
|
|
16
|
+
import time
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
25
|
+
# Test framework detection
|
|
26
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
27
|
+
|
|
28
|
+
FRAMEWORK_DETECTORS = [
|
|
29
|
+
("pytest", ["pytest.ini", "pyproject.toml", "conftest.py", "tox.ini"],
|
|
30
|
+
"python -m pytest -v --tb=short 2>&1"),
|
|
31
|
+
("unittest", ["test_*.py", "*_test.py"],
|
|
32
|
+
"python -m unittest discover -v 2>&1"),
|
|
33
|
+
("jest", ["jest.config.js", "jest.config.ts", "jest.config.mjs"],
|
|
34
|
+
"npx jest --verbose 2>&1"),
|
|
35
|
+
("vitest", ["vitest.config.js", "vitest.config.ts"],
|
|
36
|
+
"npx vitest --reporter verbose 2>&1"),
|
|
37
|
+
("mocha", [".mocharc.js", ".mocharc.json", ".mocharc.yml"],
|
|
38
|
+
"npx mocha --reporter spec 2>&1"),
|
|
39
|
+
("go test", ["go.mod"], "go test ./... -v 2>&1"),
|
|
40
|
+
("cargo test", ["Cargo.toml"], "cargo test 2>&1"),
|
|
41
|
+
("phpunit", ["phpunit.xml", "phpunit.xml.dist"], "phpunit 2>&1"),
|
|
42
|
+
("rspec", ["spec/", ".rspec"], "bundle exec rspec 2>&1"),
|
|
43
|
+
]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def detect_framework(workspace: str | Path) -> tuple[str, str] | None:
|
|
47
|
+
"""Detect test framework and return (name, command). Only scans 2 levels deep."""
|
|
48
|
+
root = Path(workspace)
|
|
49
|
+
all_files = set()
|
|
50
|
+
# Only scan root + 2 levels deep to avoid slow rglob in large projects
|
|
51
|
+
for depth in range(3):
|
|
52
|
+
pattern = "*/" * depth + "*"
|
|
53
|
+
for entry in root.glob(pattern):
|
|
54
|
+
if entry.is_file():
|
|
55
|
+
name = entry.name
|
|
56
|
+
rel = str(entry.relative_to(root))
|
|
57
|
+
all_files.add(name)
|
|
58
|
+
all_files.add(rel)
|
|
59
|
+
|
|
60
|
+
for name, indicators, cmd in FRAMEWORK_DETECTORS:
|
|
61
|
+
for ind in indicators:
|
|
62
|
+
if "*" in ind:
|
|
63
|
+
import fnmatch
|
|
64
|
+
if any(fnmatch.fnmatch(f, ind) for f in all_files):
|
|
65
|
+
return name, cmd
|
|
66
|
+
elif ind in all_files or any(ind in f for f in all_files):
|
|
67
|
+
return name, cmd
|
|
68
|
+
return None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
72
|
+
# Test result parsing
|
|
73
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
74
|
+
|
|
75
|
+
@dataclass
class TestResult:
    """Structured outcome of one test run."""

    # Not a dataclass field (no annotation): tells pytest this class is not a
    # test case even though its name starts with "Test".
    __test__ = False

    framework: str  # name used to select the output parser
    passed: int = 0
    failed: int = 0
    errors: int = 0
    skipped: int = 0
    duration: float = 0.0  # wall-clock seconds
    output: str = ""  # raw combined output of the run
    failures: list[str] = field(default_factory=list)  # excerpts, <= 500 chars each

    @property
    def ok(self) -> bool:
        """True when no test failed and no error was recorded."""
        return self.failed == 0 and self.errors == 0


def parse_results(framework: str, output: str) -> TestResult:
    """Parse raw test output into a ``TestResult``.

    Args:
        framework: ``"pytest"``, ``"jest"``, ``"vitest"``, ``"mocha"``,
            ``"go test"`` or ``"cargo test"`` select a dedicated parser; any
            other value uses a generic keyword count.
        output: Combined stdout/stderr of the test command.

    Returns:
        A ``TestResult`` holding the counts and failure excerpts found.
        Counts stay 0 when the expected summary text is absent. The parsing
        is heuristic and relies on each tool's usual text output.
    """
    result = TestResult(framework=framework, output=output)

    def first_int(pattern: str, text: str) -> int:
        """Return the first captured integer for *pattern* in *text*, else 0."""
        found = re.search(pattern, text)
        return int(found.group(1)) if found else 0

    def excerpts(pattern: str, flags: int = 0) -> list[str]:
        """Return each captured failure block, stripped and capped at 500 chars."""
        return [chunk.strip()[:500] for chunk in re.findall(pattern, output, flags)]

    if framework == "pytest":
        result.passed = first_int(r"(\d+) passed", output)
        result.failed = first_int(r"(\d+) failed", output)
        result.errors = first_int(r"(\d+) error", output)
        result.skipped = first_int(r"(\d+) skipped", output)
        result.failures = excerpts(r"FAILED.*?\n(.*?)(?:\n_+|\n=+|\Z)", re.DOTALL)

    elif framework in ("jest", "vitest", "mocha"):
        # Jest prints "Tests: 1 failed, 2 passed, 3 total" and Vitest
        # "Tests  1 failed | 2 passed (3)". The counts appear in varying order
        # and a zero count is omitted, so look each one up independently.
        summary = re.search(r"^\s*Tests:?\s+(.+)$", output, re.MULTILINE)
        if summary:
            counts = summary.group(1)
            result.passed = first_int(r"(\d+) passed", counts)
            result.failed = first_int(r"(\d+) failed", counts)
            result.skipped = first_int(r"(\d+) skipped", counts)
        else:
            # Mocha's spec reporter uses "N passing" / "N failing" / "N pending".
            result.passed = first_int(r"(\d+) passing", output)
            result.failed = first_int(r"(\d+) failing", output)
            result.skipped = first_int(r"(\d+) pending", output)
        result.failures = excerpts(r"●.*?\n(.*?)(?:\n\n|\Z)", re.DOTALL)

    elif framework == "go test":
        # Count per-test verdict lines rather than every "PASS"/"FAIL"
        # substring, which also appears in package summaries and log text.
        result.passed = len(re.findall(r"^\s*--- PASS:", output, re.MULTILINE))
        result.failed = len(re.findall(r"^\s*--- FAIL:", output, re.MULTILINE))
        result.skipped = len(re.findall(r"^\s*--- SKIP:", output, re.MULTILINE))
        if result.failed == 0:
            # A package can fail without any test verdict (e.g. build failure).
            result.errors = len(re.findall(r"^FAIL\b", output, re.MULTILINE))
        result.failures = excerpts(r"--- FAIL.*?\n(.*?)(?:\n---|\Z)", re.DOTALL)

    elif framework == "cargo test":
        # One "test result:" line is printed per test binary; add them all up.
        totals = re.findall(r"test result:.*?(\d+) passed.*?(\d+) failed", output)
        result.passed = sum(int(passed) for passed, _ in totals)
        result.failed = sum(int(failed) for _, failed in totals)
        result.failures = excerpts(r"thread '.*?' panicked.*?:\n(.*)")

    else:
        # Generic: "FAIL" occurrences count as failures; "ok" occurrences that
        # are not part of "not ok" count as passes.
        result.failed = output.count("FAIL")
        result.passed = max(0, output.count("ok") - output.count("not ok"))

    return result
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
133
|
+
# Runner
|
|
134
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
135
|
+
|
|
136
|
+
def run_tests(workspace: str | Path, command: str | None = None) -> TestResult | None:
    """Run the project's tests and return the parsed result.

    Args:
        workspace: Directory the test command is run in.
        command: Shell command to run. When omitted, the framework and its
            command are auto-detected from the workspace.

    Returns:
        A ``TestResult``, or ``None`` when no command was given and no
        framework could be detected. A run that timed out, could not be
        started, or exited non-zero without any parsed failure is reported
        with ``errors`` set to 1, so it never counts as passing.
    """
    root = Path(workspace)
    if command:
        # Explicit command: infer the parser from the command text.
        known = ("pytest", "vitest", "jest", "mocha", "go test", "cargo test")
        fw = next((name for name in known if name in command), "generic")
    else:
        detected = detect_framework(root)
        if not detected:
            return None
        # Keep the detected name so the matching parser is used.
        fw, command = detected

    logger.info("Running: %s", command)
    returncode = None  # stays None when the process never finished
    start = time.time()
    try:
        # NOTE: shell=True is needed because the built-in commands contain
        # shell redirections ("2>&1"); `command` must come from a trusted source.
        proc = subprocess.run(
            command, shell=True, capture_output=True, text=True,
            timeout=120, cwd=str(root),
        )
        output = proc.stdout + "\n" + proc.stderr
        returncode = proc.returncode
    except subprocess.TimeoutExpired:
        logger.warning("Test command timed out: %s", command)
        output = "Test run timed out after 120s"
    except Exception as e:
        # Best effort: report the launch problem as the run's output.
        logger.warning("Test command could not be run: %s", e)
        output = str(e)

    elapsed = time.time() - start

    result = parse_results(fw, output)
    result.duration = elapsed
    if returncode != 0 and result.ok:
        # The parser found no failure, yet the run did not finish cleanly
        # (timeout, launch error or non-zero exit): never report it as passing.
        result.errors = 1
    return result
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
174
|
+
# Auto-fix loop
|
|
175
|
+
# ═══════════════════════════════════════════════════════════════════════════════
|
|
176
|
+
|
|
177
|
+
def auto_fix_loop(
|
|
178
|
+
workspace: str | Path,
|
|
179
|
+
agent,
|
|
180
|
+
max_retries: int = 3,
|
|
181
|
+
command: str | None = None,
|
|
182
|
+
) -> tuple[bool, str]:
|
|
183
|
+
"""
|
|
184
|
+
Run tests → feed failures to agent → fix → repeat until pass or max retries.
|
|
185
|
+
Returns (passed, summary).
|
|
186
|
+
"""
|
|
187
|
+
results = []
|
|
188
|
+
for attempt in range(max_retries):
|
|
189
|
+
print(f"\n[Test attempt {attempt + 1}/{max_retries}]")
|
|
190
|
+
result = run_tests(workspace, command)
|
|
191
|
+
if result is None:
|
|
192
|
+
return False, "No test framework detected."
|
|
193
|
+
|
|
194
|
+
results.append(result)
|
|
195
|
+
|
|
196
|
+
if result.ok:
|
|
197
|
+
return True, f"All {result.passed} tests passed in {result.duration:.1f}s"
|
|
198
|
+
|
|
199
|
+
if attempt == max_retries - 1:
|
|
200
|
+
break
|
|
201
|
+
|
|
202
|
+
# Feed failures to agent
|
|
203
|
+
failure_text = "\n\n".join(result.failures[:3])
|
|
204
|
+
if not failure_text:
|
|
205
|
+
failure_text = result.output[-1000:]
|
|
206
|
+
|
|
207
|
+
task = (
|
|
208
|
+
f"The tests failed. Here are the failures:\n\n"
|
|
209
|
+
f"```\n{failure_text}\n```\n\n"
|
|
210
|
+
f"Read the relevant source files, fix the issues, and make the tests pass. "
|
|
211
|
+
f"Be minimal — only fix what's broken."
|
|
212
|
+
)
|
|
213
|
+
print(f" Failures: {result.failed} failed, {result.errors} errors")
|
|
214
|
+
print(f" Asking agent to fix...")
|
|
215
|
+
agent.run(task, stream=True)
|
|
216
|
+
|
|
217
|
+
# All retries exhausted
|
|
218
|
+
summary = f"Failed after {max_retries} attempts. Last: {results[-1].failed} failures, {results[-1].errors} errors."
|
|
219
|
+
return False, summary
|