unitsauce 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- unitsauce/__init__.py +0 -0
- unitsauce/analysis.py +179 -0
- unitsauce/fixer.py +219 -0
- unitsauce/github.py +92 -0
- unitsauce/main.py +82 -0
- unitsauce/models.py +58 -0
- unitsauce/output.py +114 -0
- unitsauce/prompts.py +53 -0
- unitsauce/utils.py +25 -0
- unitsauce-0.1.0.dist-info/METADATA +48 -0
- unitsauce-0.1.0.dist-info/RECORD +14 -0
- unitsauce-0.1.0.dist-info/WHEEL +5 -0
- unitsauce-0.1.0.dist-info/entry_points.txt +2 -0
- unitsauce-0.1.0.dist-info/top_level.txt +1 -0
unitsauce/__init__.py
ADDED
|
File without changes
|
unitsauce/analysis.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import difflib
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
import re
|
|
7
|
+
import subprocess
|
|
8
|
+
|
|
9
|
+
from rich.syntax import Syntax
|
|
10
|
+
from rich.panel import Panel
|
|
11
|
+
from rich.spinner import Spinner
|
|
12
|
+
from rich.live import Live
|
|
13
|
+
from .utils import console
|
|
14
|
+
|
|
15
|
+
def show_diff(original, new, file_name):
    """Render the unified diff between two texts and hand back the raw diff.

    When the texts differ, the diff is printed to the shared console inside a
    green "Changes" panel. The headers are labelled ``before/<file_name>`` and
    ``after/<file_name>``.

    Returns:
        The unified diff as one string, or ``""`` when both texts are identical.
    """
    before_lines = original.splitlines(keepends=True)
    after_lines = new.splitlines(keepends=True)
    diff_text = "".join(
        difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile=f"before/{file_name}",
            tofile=f"after/{file_name}",
        )
    )
    if not diff_text:
        return diff_text

    highlighted = Syntax(diff_text, "diff", theme="monokai", line_numbers=True)
    console.print(Panel(highlighted, title="Changes", border_style="green"))
    return diff_text
|
|
28
|
+
|
|
29
|
+
def changed_lines(diff):
    """Return the new-file line numbers of the lines added by a unified diff.

    The ``@@ -a,b +c,d @@`` hunk headers are parsed for their line counts, so
    each hunk's body is recognised exactly. This keeps the following from
    being mistaken for file lines:

    - file headers (``---`` / ``+++``);
    - the "No newline at end of file" marker;
    - added or removed lines whose content itself starts with ``++`` or ``--``.

    Args:
        diff: unified diff text, e.g. the output of ``git diff``.

    Returns:
        Line numbers (1-based, in the new version of the file) of every ``+`` line.
    """
    hunk_header = re.compile(r"^@@ -\d+(?:,(\d+))? \+(\d+)(?:,(\d+))? @@")
    lines = []
    new_ln = 0
    old_left = new_left = 0  # lines still expected in the current hunk

    for line in diff.splitlines():
        if old_left <= 0 and new_left <= 0:
            # Between hunks: only a hunk header is meaningful here.
            m = hunk_header.match(line)
            if m:
                old_left = int(m.group(1) or 1)  # omitted count means 1
                new_ln = int(m.group(2)) - 1
                new_left = int(m.group(3) or 1)
            continue

        if line.startswith("\\"):
            # "\ No newline at end of file" describes the previous line; it is
            # not a line of either file.
            continue
        if line.startswith("+"):
            new_ln += 1
            new_left -= 1
            lines.append(new_ln)
        elif line.startswith("-"):
            old_left -= 1
        else:
            # Context line (" " prefix, or "" if trailing whitespace was stripped).
            new_ln += 1
            old_left -= 1
            new_left -= 1

    return lines
|
|
50
|
+
|
|
51
|
+
def get_failing_tests(path):
    """Read ``<path>/report.json`` (pytest-json-report) and list failed tests.

    Args:
        path: project directory (``str`` or ``Path``) that holds ``report.json``.

    Returns:
        One dict per test whose ``outcome`` is ``"failed"``, with these keys:

        - ``file``: the nodeid part before the first ``::``;
        - ``function``: the part after the last ``::``;
        - ``error``: the crash message, falling back to the stage's
          ``longrepr`` and then to ``""`` when no crash entry exists.

    Raises:
        FileNotFoundError: if ``report.json`` does not exist.
    """
    report_file = os.path.join(path, "report.json")
    with open(report_file, encoding="utf-8") as f:
        report = json.load(f)

    failures = []
    for test in report.get("tests", []):
        if test.get("outcome") != "failed":
            continue
        parts = test["nodeid"].split("::")
        failures.append({
            "file": parts[0],
            "function": parts[-1],
            "error": _failure_message(test),
        })

    return failures


def _failure_message(test):
    """Return the best available error text for one failed report entry."""
    # A failure without a structured crash (e.g. a plain-string longrepr)
    # has no "crash" key, so fall back instead of raising KeyError.
    call = test.get("call") or {}
    crash = call.get("crash") or {}
    return crash.get("message") or call.get("longrepr") or ""
|
|
65
|
+
|
|
66
|
+
def get_git_diff(path):
    """List the paths changed between ``HEAD~1`` and the working tree.

    Git runs with ``cwd=path``. Anything under the top-level ``tests``
    directory is excluded through a pathspec.

    Returns:
        Changed paths, one entry per line of ``git diff --name-only`` output.

    Raises:
        subprocess.CalledProcessError: when git exits non-zero (``check=True``).
    """
    command = [
        "git", "diff", "--name-only", "HEAD~1",
        "--", ".", ":(exclude)tests",
    ]
    completed = subprocess.run(command, cwd=path, capture_output=True, text=True, check=True)
    names = completed.stdout.splitlines()
    return names
|
|
76
|
+
|
|
77
|
+
def get_single_file_diff(path, changed_file_path):
    """Return ``git diff HEAD~1`` for a single pathspec, run inside ``path``.

    The pathspec is interpreted by git relative to ``path``.

    Raises:
        subprocess.CalledProcessError: when git exits non-zero (``check=True``).
    """
    return subprocess.run(
        ["git", "diff", "HEAD~1", "--", changed_file_path],
        cwd=path,
        capture_output=True,
        text=True,
        check=True,
    ).stdout
|
|
87
|
+
|
|
88
|
+
def index_file_functions(source):
    """Index every function definition in ``source``.

    Both ``def`` and ``async def`` are included, at any nesting level (methods
    and nested functions too), in ``ast.walk`` order.

    Returns:
        A list of dicts with keys:

        - ``name``;
        - ``start``: line of the ``def`` keyword, 1-based, after any decorators;
        - ``end``: last line of the body, inclusive;
        - ``node``: the ``ast`` node.

    Raises:
        SyntaxError: if ``source`` is not valid Python.
    """
    tree = ast.parse(source)
    return [
        {
            "name": node.name,
            "start": node.lineno,
            "end": node.end_lineno,
            "node": node,
        }
        for node in ast.walk(tree)
        # async functions were previously skipped, so changes to them were
        # never gathered or replaced.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]
|
|
102
|
+
|
|
103
|
+
def extract_function_source(code: str, func):
    """Cut the lines ``func["start"]``..``func["end"]`` (1-based, inclusive) out of ``code``.

    ``func`` is an entry produced by ``index_file_functions``. The slice is
    re-joined with ``"\\n"`` and has no trailing newline.
    """
    first, last = func["start"], func["end"]
    selected = code.splitlines()[first - 1:last]
    return "\n".join(selected)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def split_functions_raw(code):
    """Split code into functions, keeping raw text with comments.

    The indentation of the first ``def`` / ``async def`` seen is treated as
    the "top level". This lets a model reply containing bare methods (indented
    ``def foo(self)``) be split as well as module-level functions. Nested defs
    at deeper indentation stay part of their enclosing function.

    Other rules:

    - Decorator lines at the top level belong to the *following* function.
    - Trailing blank lines are trimmed from each function's text.
    - Text before the first definition is dropped.

    Returns:
        Dict mapping function name to its raw source text, in order of appearance.
    """
    functions = {}
    current_name = None
    current_lines = []
    pending = []        # decorator lines waiting for the def they decorate
    base_indent = None  # indentation of the first def seen

    def _finish(name, body):
        # Blank lines between functions are separators, not part of the body;
        # keeping them would grow the file on every splice.
        while body and not body[-1].strip():
            body.pop()
        functions[name] = "\n".join(body)

    for line in code.splitlines():
        stripped = line.lstrip()
        indent = len(line) - len(stripped)
        at_base = base_indent is None or indent == base_indent

        if at_base and stripped.startswith(("def ", "async def ")):
            if current_name:
                _finish(current_name, current_lines)
            base_indent = indent
            header = stripped.removeprefix("async ").removeprefix("def ")
            current_name = header.split("(")[0].strip()
            current_lines = pending + [line]
            pending = []
        elif at_base and stripped.startswith("@"):
            pending.append(line)
        elif current_name:
            # Decorators not followed by a def stay with the current function.
            current_lines.extend(pending)
            pending = []
            current_lines.append(line)

    if current_name:
        _finish(current_name, current_lines + pending)

    return functions
|
|
128
|
+
|
|
129
|
+
def read_file_content(file, path, is_file_path=False):
    """Locate ``file`` and return ``(location, text)``.

    Args:
        file: a path, or a name/relative path to search for under ``path``.
        path: project directory to search in.
        is_file_path: if True, ``file`` is opened as given and returned as-is.

    Lookup order:

    1. ``path / file`` when it is an existing file. Pytest nodeids and git
       paths are usually relative to the project root, so this exact match
       is preferred.
    2. The first regular file matched by ``Path(path).rglob(file)``.

    Files are read as UTF-8, ignoring undecodable bytes.

    Returns:
        ``(Path, str)`` on success, or ``(None, None)`` if nothing matches.
    """
    if is_file_path:
        with open(file, 'r', encoding='utf-8', errors='ignore') as f:
            return file, f.read()

    root = Path(path)
    direct = root / file
    if direct.is_file():
        file_path = direct
    else:
        # Skip directories whose names happen to match the pattern.
        file_path = next((p for p in root.rglob(file) if p.is_file()), None)

    if not file_path:
        return None, None

    with open(file_path, encoding='utf-8', errors='ignore') as open_file:
        return file_path, open_file.read()
|
|
143
|
+
|
|
144
|
+
def gather_context(diff, function_code):
    """Collect the source of every function touched by an added line in ``diff``.

    A function counts as touched when any added line number reported by
    ``changed_lines(diff)`` falls within its ``start``..``end`` range in
    ``function_code``.

    Returns:
        List of function source strings, in ``index_file_functions`` order.
    """
    touched = changed_lines(diff)
    return [
        extract_function_source(function_code, func)
        for func in index_file_functions(function_code)
        if any(func["start"] <= ln <= func["end"] for ln in touched)
    ]
|
|
155
|
+
|
|
156
|
+
def run_tests(path):
    """Run the project's pytest suite and write ``report.json`` into ``path``.

    Pytest runs in ``path`` with ``--tb=short`` and pytest-json-report enabled.
    A spinner is shown while it runs, and output is captured, not streamed.

    Returns:
        The ``subprocess.CompletedProcess`` of the pytest run.

    Raises:
        FileNotFoundError: if ``path`` does not exist. Without this check no
            result was produced and callers failed later with a confusing error.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Project path does not exist: {path}")

    with Live(Spinner("dots", text="Running tests..."), console=console):
        result = subprocess.run(
            ["python", "-m", "pytest", "--tb=short", "--json-report", "--json-report-file=report.json"],
            cwd=path,
            capture_output=True,
            text=True
        )
    console.print()
    return result
|
|
167
|
+
|
|
168
|
+
def run_single_test(path, test_file, test_function):
    """Run one test (``<test_file>::<test_function>``) verbosely inside ``path``.

    Returns:
        ``(True, "")`` if pytest exits with status 0. Otherwise ``(False, output)``,
        where ``output`` is the captured stdout followed by stderr.
    """
    test_id = f"{test_file}::{test_function}"
    result = subprocess.run(
        ["python", "-m", "pytest", test_id, "-v"],
        cwd=path,
        capture_output=True,
        text=True
    )
    if result.returncode == 0:
        return True, ""
    # pytest writes its failure report to stdout; returning only stderr
    # (usually empty) left callers with nothing to compare against.
    return False, (result.stdout or "") + (result.stderr or "")
|
unitsauce/fixer.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import shutil
|
|
4
|
+
from .analysis import gather_context, get_single_file_diff, index_file_functions, read_file_content, run_single_test, run_tests, show_diff, split_functions_raw
|
|
5
|
+
from anthropic import Anthropic
|
|
6
|
+
from dotenv import load_dotenv
|
|
7
|
+
|
|
8
|
+
from .models import FixContext, FixResult, VerifyContext
|
|
9
|
+
from .prompts import fix_code_prompt, fix_test_prompt
|
|
10
|
+
from rich.spinner import Spinner
|
|
11
|
+
from rich.live import Live
|
|
12
|
+
from .utils import backup_file, console
|
|
13
|
+
|
|
14
|
+
# Load settings such as ANTHROPIC_API_KEY from a .env file. Per python-dotenv's
# docs, find_dotenv() without usecwd=True starts searching from the calling
# module's directory (inside site-packages for an installed package), so first
# look for a .env starting from the current working directory. load_dotenv()
# does not override variables that are already set, so the first file wins.
from dotenv import find_dotenv

load_dotenv(find_dotenv(usecwd=True))
load_dotenv()

# Module-level Anthropic client used by call_llm().
client = Anthropic(
    api_key=os.getenv("ANTHROPIC_API_KEY")
)
|
|
19
|
+
|
|
20
|
+
def call_llm(fix_prompt, functions, test_code, error_message, diff):
    """Send the fix prompt to Claude and return the code block from its reply.

    Args:
        fix_prompt: template with ``{function_code}``, ``{test_code}``,
            ``{error_message}`` and ``{diff}`` placeholders.
        functions: source of the relevant functions. Either one string or a
            list of strings, as returned by ``gather_context``.
        test_code: text of the failing test's file.
        error_message: crash message of the failing test.
        diff: git diff of the source file.

    Returns:
        The contents of the first ```` ```python ```` fenced block in the reply.
        If there is none, the first ```` ```py ```` or unlabeled block is used.
        Returns ``None`` when no fenced block exists.
    """
    # Formatting a list directly would embed its repr (quotes and escaped
    # newlines), so join the function sources into plain text first.
    if isinstance(functions, (list, tuple)):
        function_code = "\n\n".join(functions)
    else:
        function_code = functions

    prompt = fix_prompt.format(
        function_code=function_code,
        test_code=test_code,
        error_message=error_message,
        diff=diff,
    )

    with Live(Spinner("dots", text="Generating solution..."), console=console):
        response = client.messages.create(
            max_tokens=8192*2,
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model="claude-opus-4-5-20251101",
        )
    console.print()

    reply = "".join(getattr(block, "text", "") for block in response.content)

    match = re.search(r'```python(.*?)```', reply, re.DOTALL)
    if match is None:
        # Fallback for replies fenced as ```py or with no language tag.
        match = re.search(r'```(?:py)?[ \t]*\n(.*?)```', reply, re.DOTALL)

    return match.group(1) if match else None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def apply_fix(file_path, generated_code):
    """Splice the functions in ``generated_code`` into ``file_path``.

    Process:

    1. Back up the file with ``backup_file``.
    2. Split the generated code with ``split_functions_raw``.
    3. Replace each function that exists in the file under the same name.
       Generated functions that do not exist in the file are ignored.
    4. Check that the patched text parses, and only then write it. The file's
       trailing newline is preserved.

    Returns:
        ``{"success": True, "backup": <backup Path>}`` after writing.
        ``{"success": False, "backup": None}`` when nothing was written: the
        original file is copied back and the backup removed. This happens when
        the file cannot be parsed, no generated function matched, or the
        patched result is not valid Python.
    """
    import ast
    from pathlib import Path

    file_path = Path(file_path)
    backup = backup_file(file_path)

    def _abort(message):
        # Nothing usable to write: put the original back and drop the backup.
        console.print(message)
        shutil.copy2(backup, file_path)
        backup.unlink()
        return {"success": False, "backup": None}

    source = file_path.read_text(encoding="utf-8")
    try:
        file_funcs = {f["name"]: f for f in index_file_functions(source)}
    except SyntaxError as e:
        return _abort(f"[red]Could not parse {file_path}: {e}[/red]")

    lines = source.splitlines()
    raw_funcs = split_functions_raw(generated_code)

    # Replace bottom-up so the line numbers of earlier functions stay valid.
    ordered = sorted(
        raw_funcs.items(),
        key=lambda item: file_funcs.get(item[0], {}).get("start", 0),
        reverse=True,
    )
    replaced = 0
    for name, raw_text in ordered:
        old = file_funcs.get(name)
        if old is None:
            continue
        start = old["start"] - 1
        decorators = old["node"].decorator_list
        # ast's lineno points at the ``def`` line. If the model re-emitted the
        # decorators, replace the original decorator lines too so they are
        # not duplicated.
        if decorators and raw_text.lstrip().startswith("@"):
            start = min(d.lineno for d in decorators) - 1
        lines[start:old["end"]] = raw_text.splitlines()
        replaced += 1

    if not replaced:
        return _abort("[red]Generated code did not match any function in the file[/red]")

    new_source = "\n".join(lines)
    if source.endswith("\n"):
        new_source += "\n"

    # Validate the *patched* file; previously the generated code was never checked.
    try:
        ast.parse(new_source)
    except SyntaxError as e:
        return _abort(f"[red]Claude returned invalid code: {e}[/red]")

    file_path.write_text(new_source, encoding="utf-8")
    return {"success": True, "backup": backup}
|
|
69
|
+
|
|
70
|
+
def _same_failure(original_error, test_output):
    """Return True if ``test_output`` still contains the first line of ``original_error``."""
    if not original_error or not original_error.strip():
        return False
    first_line = original_error.strip().splitlines()[0]
    return first_line in (test_output or "")


def verify_fix(ctx: VerifyContext):
    """Check whether the fix already written to ``ctx.file_path`` works.

    Outcomes:

    - Target test passes, then the full suite passes: the backup is deleted
      and ``{"fixed": True, "diff": <diff>}`` is returned. The diff compares
      the backup with the patched file.
    - Target test passes, but the full suite fails: the patched file and the
      backup are both left on disk and ``{"fixed": False, "diff": ""}`` is
      returned.
    - Target test still fails with the original error: the original file is
      restored from the backup.
    - Target test fails with a different error: the changes are kept and the
      backup is deleted.

    Always returns a dict with ``fixed`` and ``diff`` keys.
    """
    from pathlib import Path

    test_passed, test_output = run_single_test(ctx.repo_path, ctx.test_file, ctx.test_function)

    if test_passed:
        # Diff the whole file (backup vs. patched). Diffing the full original
        # against only the generated snippet made every other line look deleted.
        before = Path(ctx.backup_path).read_text(encoding="utf-8", errors="replace")
        after = Path(ctx.file_path).read_text(encoding="utf-8", errors="replace")
        diff = show_diff(before, after, Path(ctx.file_path).name)

        result = run_tests(ctx.repo_path)
        if result.returncode == 0:
            ctx.backup_path.unlink()
            return {"fixed": True, "diff": diff}
        # NOTE(review): main() handles failures one at a time. While other
        # tests still fail, this full-suite run fails too, so a fix that works
        # is reported as not fixed. Confirm this is intended.
        return {"fixed": False, "diff": ""}

    # The crash message is short while the test output is the full pytest
    # report, so check containment rather than equality.
    if _same_failure(ctx.original_error_message, test_output):
        shutil.copy2(ctx.backup_path, ctx.file_path)
        ctx.backup_path.unlink()
        console.print("[red]Fix didn't work, restored original[/red]")
    else:
        ctx.backup_path.unlink()
        console.print("[yellow]Different error now - keeping changes[/yellow]")
    return {"fixed": False, "diff": ""}
|
|
94
|
+
|
|
95
|
+
def fix(ctx: FixContext):
    """Run one generate -> apply -> verify cycle for a failing test.

    Steps:

    1. Diff ``ctx.function_name`` (the file to diff) against ``HEAD~1``.
    2. Collect the touched functions from ``ctx.function_code``.
    3. Ask the LLM for a fix and splice it into ``ctx.file_path``.
    4. Verify the result with ``verify_fix``.

    Returns:
        ``{"fixed": bool, "diff": str}`` on every path, so callers can index
        the result directly. Earlier versions returned a bare ``False`` on
        early exits.
    """
    diff_target = ctx.function_name
    # Paths from read_file_content are built from the user-supplied project
    # path, so they are relative to *our* working directory (or absolute),
    # while git resolves pathspecs relative to its cwd, repo_path.
    if os.path.exists(diff_target):
        diff_target = os.path.relpath(diff_target, ctx.repo_path)
    diff = get_single_file_diff(ctx.repo_path, diff_target)

    affected = gather_context(diff, ctx.function_code)

    generated_code = call_llm(ctx.prompt, affected, ctx.test_code, ctx.error_message, diff)

    if generated_code is None:
        console.print("[red]LLM returned no code block[/red]")
        return {"fixed": False, "diff": ""}

    applied = apply_fix(ctx.file_path, generated_code)
    if not applied["success"]:
        return {"fixed": False, "diff": ""}

    verify_ctx = VerifyContext(
        repo_path=ctx.repo_path,
        file_path=ctx.file_path,
        test_file=ctx.test_file,
        test_code=ctx.test_code,
        test_function=ctx.test_function,
        original_function_code=ctx.function_code,
        generated_code=generated_code,
        backup_path=applied["backup"],
        original_error_message=ctx.error_message,
        fix_type=ctx.fix_type
    )
    return verify_fix(verify_ctx)
|
|
123
|
+
|
|
124
|
+
def try_fix_test(failure, test_file_path, test_code, source_file, source_code, path, fix_type):
    """Attempt to fix the test file.

    The LLM is asked, via ``fix_test_prompt``, to update the failing test so
    it matches the changed source file. The result is written to
    ``test_file_path``.

    Returns:
        The dict returned by ``fix()``.
    """
    context = FixContext(
        prompt=fix_test_prompt,
        # fix() diffs this path to show the model what changed. It must be the
        # source file: the test *function name* is not a path git knows, so
        # the diff (and the gathered source context) came back empty.
        function_name=source_file,
        file_path=test_file_path,
        function_code=source_code,
        test_code=test_code,
        error_message=failure['error'],
        repo_path=path,
        test_file=failure['file'],
        test_function=failure['function'],
        fix_type=fix_type
    )
    return fix(context)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def try_fix_code(failure, test_code, source_file, source_code, path, fix_type):
    """Ask the LLM to repair ``source_file`` so the failing test passes.

    The fix uses ``fix_code_prompt``. ``source_file`` serves both as the path
    to diff and as the file that receives the generated code.

    Returns:
        Whatever ``fix()`` returns for the assembled context.
    """
    return fix(
        FixContext(
            prompt=fix_code_prompt,
            error_message=failure['error'],
            function_name=source_file,
            function_code=source_code,
            file_path=source_file,
            test_code=test_code,
            test_file=failure['file'],
            test_function=failure['function'],
            repo_path=path,
            fix_type=fix_type,
        )
    )
|
|
156
|
+
|
|
157
|
+
def attempt_fix(failure, changed_files, path, mode):
    """Try to make one failing test pass and describe the outcome.

    Args:
        failure: dict with ``file``, ``function`` and ``error`` keys, as built
            by ``get_failing_tests``.
        changed_files: changed source paths as reported by ``git diff --name-only``.
        path: project directory the tests run in.
        mode: one of
            - ``'test'``: rewrite the test;
            - ``'code'``: rewrite the source;
            - ``'auto'``: try the test first, then the source.

    The source file whose basename matches the test name without its
    ``test_`` prefix is tried first; the other changed files follow.

    Returns:
        A ``FixResult``. ``fixed`` is True for the first attempt that verified,
        and ``fix_type`` / ``file_changed`` describe which file was updated.
    """
    test_file_path, test_code = read_file_content(failure['file'], path)

    def _result(fixed, fix_type, file_changed, diff=""):
        return FixResult(
            test_file=failure['file'],
            test_function=failure['function'],
            error_message=failure['error'],
            fixed=fixed,
            fix_type=fix_type,
            file_changed=str(file_changed),
            diff=diff
        )

    if test_file_path is None:
        # Without the test file there is nothing to show the model or patch.
        console.print(f"[red]Could not locate test file {failure['file']}[/red]")
        return _result(False, mode, test_file_path)

    # tests/test_calc.py -> calc.py. Compare basenames because changed_files
    # holds repo-relative paths such as src/calc.py.
    guessed_name = os.path.basename(failure['file']).removeprefix("test_")
    # Stable sort: the guessed source file first, the rest in git's order.
    files_to_try = sorted(changed_files, key=lambda f: os.path.basename(f) != guessed_name)

    for source_file in files_to_try:
        source_path, source_code = read_file_content(source_file, path)
        if not source_path:
            continue

        if mode in ('test', 'auto'):
            result = try_fix_test(failure, test_file_path, test_code, source_path, source_code, path, mode)
            if result and result.get("fixed"):
                return _result(True, 'test', test_file_path, result["diff"])

        if mode in ('code', 'auto'):
            # try_fix_code takes no test_file_path argument.
            result = try_fix_code(failure, test_code, source_path, source_code, path, 'code')
            if result and result.get("fixed"):
                return _result(True, 'code', source_path, result["diff"])

    return _result(False, mode, test_file_path)
|
unitsauce/github.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import httpx
|
|
5
|
+
|
|
6
|
+
def check_if_pull_request():
    """Describe the pull request when running under a GitHub ``pull_request`` event.

    The event payload is read from the file named by ``GITHUB_EVENT_PATH``.

    Returns:
        ``{"number": ..., "repo": ..., "sha": ...}``, or ``None`` in either case:
        - ``GITHUB_EVENT_NAME`` is not ``"pull_request"``;
        - the event file is missing.
    """
    if os.getenv("GITHUB_EVENT_NAME") != "pull_request":
        return None

    event_path = os.getenv("GITHUB_EVENT_PATH")
    if not (event_path and os.path.exists(event_path)):
        return None

    with open(event_path) as handle:
        pull_request = json.load(handle).get("pull_request", {})

    head = pull_request.get("head", {})
    return {
        "number": pull_request.get("number"),
        "repo": os.getenv("GITHUB_REPOSITORY"),
        "sha": head.get("sha"),
    }
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def post_pr_comment(repo, pr_number, body):
    """Post a comment to a PR. Returns True if successful.

    Authenticates with the ``GITHUB_TOKEN`` environment variable.

    Returns:
        True if GitHub answered ``201``. False in any of these cases:
        - the token, ``repo`` or ``pr_number`` is missing;
        - the HTTP client raised an error (network failure, timeout);
        - GitHub answered with any other status code.
    """
    token = os.getenv("GITHUB_TOKEN")

    if not token or not repo or pr_number is None:
        return False

    url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
    headers = {
        "Authorization": f"Bearer {token}",
        "Accept": "application/vnd.github+json"
    }

    try:
        response = httpx.post(url, json={"body": body}, headers=headers, timeout=30.0)
    except httpx.HTTPError:
        # main() ignores the return value, so this is best-effort. A network
        # problem should not crash the run after the fixes are done.
        return False
    return response.status_code == 201
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def format_pr_comment(result):
    """Format fix result as PR comment markdown.

    ``result`` is a ``FixResult``. For fixed results the change is shown as
    a ``diff`` block built from ``result.diff``. ``FixResult`` has no
    ``generated_code`` field, so the old reference raised AttributeError.
    The error text is cut to its first 100 characters.
    """
    status = "✅ Fixed" if result.fixed else "❌ Could not fix"

    comment = f"## {status}: `{result.test_file}::{result.test_function}`\n\n"
    comment += f"**Error:** `{result.error_message[:100]}`\n\n"

    if result.fixed:
        comment += f"**Fixed by:** Updating `{result.fix_type}` in `{result.file_changed}`\n\n"
        comment += "**Apply this change:**\n\n"
        comment += f"```diff\n{result.diff}\n```\n"

    return comment
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def format_pr_comment_summary(results):
    """Build one Markdown PR comment that covers every fix result.

    Layout:

    1. A header with total and fixed counts, followed by a ``---`` rule.
    2. One section per result: the error (first 100 characters), then either
       the applied diff or a "could not fix" note.
    3. Each section ends with its own ``---`` rule.
    """
    fixed_count = sum(1 for r in results if r.fixed)
    sections = [
        "## 🔧 UnitSauce Analysis\n\n",
        f"Found **{len(results)}** failing test(s), fixed **{fixed_count}**.\n\n",
        "---\n\n",
    ]

    for result in results:
        error_line = f"**Error:** `{result.error_message[:100]}`\n\n"
        if result.fixed:
            sections += [
                f"### ✅ `{result.test_file}::{result.test_function}`\n\n",
                error_line,
                f"**Fixed by:** Updating `{result.fix_type}` in `{result.file_changed}`\n\n",
                f"```diff\n{result.diff}\n```\n\n",
            ]
        else:
            sections += [
                f"### ❌ `{result.test_file}::{result.test_function}`\n\n",
                error_line,
                "Could not auto-fix this failure.\n\n",
            ]
        sections.append("---\n\n")

    return "".join(sections)
|
unitsauce/main.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
|
|
2
|
+
import argparse
|
|
3
|
+
import sys
|
|
4
|
+
from dotenv import load_dotenv
|
|
5
|
+
from unitsauce.github import check_if_pull_request, format_pr_comment, format_pr_comment_summary, post_pr_comment
|
|
6
|
+
from unitsauce.output import format_result, format_summary
|
|
7
|
+
|
|
8
|
+
from .fixer import attempt_fix
|
|
9
|
+
from .analysis import get_failing_tests, get_git_diff, read_file_content, run_tests
|
|
10
|
+
from .utils import print_header, console
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Load environment variables (e.g. ANTHROPIC_API_KEY, GITHUB_TOKEN) from a .env
# file at import time, before main() runs.
# NOTE(review): with no arguments python-dotenv locates the .env relative to the
# calling module rather than the working directory. Confirm this finds the
# user's project .env when unitsauce is pip-installed.
load_dotenv()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main():
    """Entry point for the ``unitsauce`` command.

    Steps:

    1. Run the project's tests.
    2. Attempt a fix for every failing test.
    3. Report the results in the requested output format.
    4. When running for a GitHub pull request, post a summary comment.

    Returns:
        ``1`` if pytest did not produce ``report.json``; otherwise ``None``.
    """
    parser = argparse.ArgumentParser(
        prog='unitsauce',
        description='AI-powered test failure analysis and fix suggestions. Analyzes git diffs, identifies bugs causing test failures, and generates fixes.',
        epilog='Examples:\n unitsauce ./my-project\n unitsauce ./my-project --mode code\n unitsauce ./my-project --mode test --output markdown',
        formatter_class=argparse.RawDescriptionHelpFormatter  # Preserves newlines in epilog
    )
    parser.add_argument("path")
    parser.add_argument('--mode', choices=['auto', 'code', 'test'], default='auto', help='Fix mode (default: auto)')
    parser.add_argument('--output', choices=['console', 'markdown', 'json'], default='console', help='Output format (default: console)')

    # argparse already exits with a usage message when ``path`` is missing,
    # so no manual sys.argv length check is needed.
    args = parser.parse_args()

    print_header()

    path = args.path
    run_tests(path)
    try:
        failures = get_failing_tests(path)
    except FileNotFoundError:
        console.print("[red]pytest did not write report.json - is pytest-json-report installed?[/red]")
        return 1

    if not failures:
        console.print("[green]All tests pass![/green]")
        return

    console.print(f"[yellow]Found {len(failures)} failing test(s)[/yellow]\n")
    changed_files = [f for f in get_git_diff(path) if f.endswith('.py')]

    results = []
    markdown_output = ""

    for failure in failures:
        console.print(f"[red]FAILING:[/red] {failure['file']}::{failure['function']}")
        console.print(f"[red]ERROR:[/red] {failure['error']}\n")

        result = attempt_fix(failure, changed_files, path, args.mode)
        results.append(result)

        if args.output == 'console':
            format_result(result, 'console')
        elif args.output == 'markdown':
            markdown_output += format_result(result, 'markdown') + "\n"

    # Post a single summary comment when running for a GitHub pull request.
    pr = check_if_pull_request()
    if pr:
        comment = format_pr_comment_summary(results)
        post_pr_comment(pr['repo'], pr['number'], comment)

    if args.output == 'console':
        format_summary(results, 'console')
    elif args.output == 'markdown':
        markdown_output += format_summary(results, 'markdown')
        print(markdown_output)
    elif args.output == 'json':
        print(format_summary(results, 'json'))
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0),
    # matching how console-script entry points call it.
    raise SystemExit(main())
|
unitsauce/models.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class FixContext:
    """Inputs for one ``fix()`` attempt, built by ``try_fix_test`` / ``try_fix_code``."""

    # --- LLM / reasoning ---
    # Prompt template; call_llm formats it with function_code, test_code,
    # error_message and diff.
    prompt: str
    # Crash message of the failing test, as read from report.json.
    error_message: str

    # --- Code under fix ---
    # Despite the name, fix() passes this to get_single_file_diff as the git
    # pathspec to diff. At runtime it may be a Path rather than a str.
    function_name: str
    # Full text of the changed source file; gather_context picks functions from it.
    function_code: str
    # File the generated code is spliced into (the test file or the source file).
    file_path: Path

    # --- Test context ---
    # Full text of the failing test's file.
    test_code: str
    # Part of the pytest nodeid before "::". At runtime this is a str.
    test_file: Path
    # Name of the failing test function.
    test_function: str

    # --- Repo / execution ---
    # Project directory (the CLI path argument), used as cwd for git and pytest.
    repo_path: Path
    # The CLI --mode value, or the kind of fix being attempted.
    fix_type: str
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True)
class VerifyContext:
    """Immutable inputs for ``verify_fix`` once a fix has been written to disk."""

    # --- Repo / execution ---
    # Project directory used as cwd for the pytest runs.
    repo_path: Path
    # File that apply_fix modified.
    file_path: Path
    # Copied from FixContext.fix_type.
    fix_type: str

    # --- Test ---
    # Test file part of the pytest nodeid, passed to run_single_test.
    test_file: Path
    test_function: str
    # Full text of the test file before the fix.
    test_code: str

    # --- Code ---
    # Copied from FixContext.function_code (the whole source file's text).
    original_function_code: str
    # Code block extracted from the LLM reply.
    generated_code: str
    # Backup created by apply_fix; used to restore the file or deleted.
    backup_path: Path

    # --- Failure ---
    # Crash message of the original failure, compared against the new test output.
    original_error_message: Optional[str] = None
|
|
47
|
+
|
|
48
|
+
@dataclass
class FixResult:
    """Outcome of one ``attempt_fix`` call, consumed by output.py and github.py."""

    # Failing test identity (nodeid file part and function name).
    test_file: str
    test_function: str

    # Original crash message of the failing test.
    error_message: str
    # True only when the fix was verified by the test and the full suite.
    fixed: bool
    # Which kind of change was made (e.g. 'test'), or the mode when unfixed.
    fix_type: str

    # Unified diff text of the applied change; "" when nothing was fixed.
    diff: str
    # str() of the path of the file reported as changed.
    file_changed: str
|
unitsauce/output.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from rich.panel import Panel
|
|
3
|
+
from rich.table import Table
|
|
4
|
+
from rich.syntax import Syntax
|
|
5
|
+
from .utils import console
|
|
6
|
+
|
|
7
|
+
def format_result(result, format_type):
    """Dispatch one fix result to the formatter for ``format_type``.

    Behaviour by format:

    - ``'console'``: prints the result and returns ``None``.
    - ``'markdown'``: returns a Markdown string.
    - ``'json'``: returns a plain dict.
    - any other value: returns ``None``.
    """
    if format_type not in ('console', 'markdown', 'json'):
        return None
    if format_type == 'console':
        return _format_console(result)
    if format_type == 'markdown':
        return _format_markdown(result)
    return _format_json(result)
|
|
14
|
+
|
|
15
|
+
def format_summary(results, format_type):
    """Render the end-of-run summary for all results in the chosen format.

    Behaviour by format:

    - ``'console'``: prints a table and returns ``None``.
    - ``'markdown'``: returns a Markdown table string.
    - ``'json'``: returns a JSON string.
    - any other value: returns ``None``.
    """
    if format_type not in ('console', 'markdown', 'json'):
        return None
    if format_type == 'console':
        return _format_console_summary(results)
    if format_type == 'markdown':
        return _format_markdown_summary(results)
    return _format_json_summary(results)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _format_console(result):
    """Format a single fix result for console output.

    Prints to the shared rich console and returns ``None``. Text taken from
    the result (test ids, error messages, paths) is escaped. Without escaping,
    parametrized ids such as ``test_x[a-b]`` or messages containing ``[...]``
    would be parsed as rich markup and partly disappear.
    """
    from rich.markup import escape

    status = "[green]✓ FIXED[/green]" if result.fixed else "[red]✗ NOT FIXED[/red]"

    console.print(f"\n{status} [bold]{escape(result.test_file)}::{escape(result.test_function)}[/bold]")

    console.print(f"[dim]Error:[/dim] {escape(result.error_message[:100])}...")

    if result.fixed:
        console.print(
            f"[dim]Fixed by:[/dim] Updating [cyan]{escape(result.fix_type)}[/cyan] "
            f"in [cyan]{escape(result.file_changed)}[/cyan]"
        )

        if result.diff:
            syntax = Syntax(result.diff, "diff", theme="monokai", line_numbers=True)
            console.print(Panel(syntax, title="Changes", border_style="green"))
|
|
44
|
+
|
|
45
|
+
def _format_console_summary(results):
    """Print a Fixed / Failed / Total count table to the shared console."""
    total = len(results)
    fixed = len([r for r in results if r.fixed])

    table = Table(title="Summary")
    table.add_column("Status", style="bold")
    table.add_column("Count", justify="right")
    rows = (
        ("[green]Fixed[/green]", fixed),
        ("[red]Failed[/red]", total - fixed),
        ("Total", total),
    )
    for label, count in rows:
        table.add_row(label, str(count))

    console.print("\n")
    console.print(table)
|
|
59
|
+
|
|
60
|
+
def _format_markdown(result):
    """Render one fix result as a Markdown section.

    The section contains a status heading and the first 100 characters of the
    error followed by ``...``. For fixed results it adds which file was
    updated and, when present, the diff in a ``diff`` code fence.
    """
    status = "✅ FIXED" if result.fixed else "❌ NOT FIXED"
    parts = [
        f"## {status}: `{result.test_file}::{result.test_function}`\n\n",
        f"**Error:** `{result.error_message[:100]}...`\n\n",
    ]
    if result.fixed:
        parts.append(f"**Fixed by:** Updating `{result.fix_type}` in `{result.file_changed}`\n\n")
        if result.diff:
            parts.append(f"```diff\n{result.diff}\n```\n")
    return "".join(parts)
|
|
75
|
+
|
|
76
|
+
def _format_markdown_summary(results):
    """Render the Fixed / Failed / Total counts as a Markdown table."""
    total = len(results)
    fixed = sum(1 for r in results if r.fixed)
    rows = [("✅ Fixed", fixed), ("❌ Failed", total - fixed), ("Total", total)]

    lines = ["## Summary", "", "| Status | Count |", "|--------|-------|"]
    lines += [f"| {label} | {count} |" for label, count in rows]
    return "\n".join(lines) + "\n"
|
|
89
|
+
|
|
90
|
+
def _format_json(result):
    """Convert one fix result into a JSON-serialisable dict.

    Key order is fixed, because it shows up in the dumped summary.
    """
    keys = (
        "test_file",
        "test_function",
        "error_message",
        "fixed",
        "fix_type",
        "file_changed",
        "diff",
    )
    return {key: getattr(result, key) for key in keys}
|
|
100
|
+
|
|
101
|
+
def _format_json_summary(results):
    """Serialise every result plus a count summary as indented JSON text."""
    fixed_count = sum(1 for r in results if r.fixed)
    summary = {
        "total": len(results),
        "fixed": fixed_count,
        "failed": len(results) - fixed_count,
    }
    payload = {"failures": [_format_json(r) for r in results], "summary": summary}
    return json.dumps(payload, indent=2)
|
unitsauce/prompts.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# Prompt for repairing source code. call_llm formats it with {function_code},
# {diff}, {test_code} and {error_message}. It then extracts the reply's
# ```python fenced block, so the prompt asks for exactly that fence.
fix_code_prompt = """
You are an expert Python developer fixing a bug in code.

IMPORTANT RULES:
- Return ONLY the fixed function(s), nothing else
- Preserve ALL comments - they explain important logic
- Keep the same code structure and style
- Make the MINIMAL change needed to fix the issue
- Do NOT reformat or rewrite working code
- Do NOT change variable order or logic flow unless that's the bug
- Return all fixed functions in a SINGLE ```python code block, not separate blocks.

Here are the functions that were modified and may contain the bug:
{function_code}

Here is the git diff showing what changed:
{diff}

Here is the failing test:
{test_code}

Here is the error:
{error_message}

Fix the bug with minimal changes. Preserve all comments.
"""
|
|
27
|
+
|
|
28
|
+
# Prompt for updating a test to match intentional source changes. It uses the
# same placeholders as fix_code_prompt, and also asks for the ```python fence
# that call_llm extracts.
fix_test_prompt = """
You are an expert Python developer updating a test to match intentional code changes.

IMPORTANT RULES:
- Return ONLY the fixed test function(s), nothing else
- Preserve ALL comments - they explain important logic
- Keep the same test structure and style
- Make the MINIMAL change needed to match new behavior
- Do NOT reformat or rewrite working test code
- Update assertions to match the new expected values
- Return all fixed functions in a SINGLE ```python code block, not separate blocks.

Here are the source functions that were intentionally changed:
{function_code}

Here is the git diff showing what changed:
{diff}

Here is the failing test that needs updating:
{test_code}

Here is the error:
{error_message}

Update the test to match the new code behavior. Preserve all comments.
"""
|
unitsauce/utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
import shutil
|
|
3
|
+
import art
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
|
|
6
|
+
# Single rich Console shared by the other unitsauce modules (imported as
# ``from .utils import console``) for terminal output.
console = Console()
|
|
7
|
+
|
|
8
|
+
def print_header():
    """Print the ASCII-art "UnitSauce" banner inside a Markdown code fence, then the tagline."""
    banner = art.text2art("UnitSauce").rstrip()
    for chunk in ("```", banner, "```", "*AI-powered test fixer*"):
        print(chunk)
|
|
15
|
+
|
|
16
|
+
def backup_file(file_path):
    """Copy ``file_path`` alongside itself with ``.bak`` appended to its suffix.

    For example, ``mod.py`` is copied to ``mod.py.bak``. Metadata is preserved
    via ``shutil.copy2`` and an existing backup is overwritten.

    Returns:
        The ``Path`` of the backup copy.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    original = Path(file_path)
    if original.exists():
        backup_path = original.with_suffix(original.suffix + ".bak")
        shutil.copy2(original, backup_path)
        return backup_path
    raise FileNotFoundError(original)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: unitsauce
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: AI-powered test failure analysis and fix suggestions
|
|
5
|
+
Author: Zan Starasinic
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/zanstarasinic/unitsauce
|
|
8
|
+
Requires-Python: >=3.10
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
Requires-Dist: anthropic
|
|
11
|
+
Requires-Dist: python-dotenv
|
|
12
|
+
Requires-Dist: art
|
|
13
|
+
Requires-Dist: rich
|
|
14
|
+
Requires-Dist: pytest
|
|
15
|
+
Requires-Dist: pytest-json-report
|
|
16
|
+
|
|
17
|
+
# UnitSauce
|
|
18
|
+
|
|
19
|
+
AI-powered test fixer for Python projects.
|
|
20
|
+
|
|
21
|
+
## What it does
|
|
22
|
+
|
|
23
|
+
Analyzes failing tests, identifies bugs in recent code changes, and suggests fixes using Claude.
|
|
24
|
+
|
|
25
|
+
## Install
|
|
26
|
+
```bash
|
|
27
|
+
pip install unitsauce
|
|
28
|
+
cp .env.example .env
|
|
29
|
+
# Add your Anthropic API key to .env
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## Usage
|
|
33
|
+
```bash
|
|
34
|
+
unitsauce /path/to/your/project
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
The tool will:
|
|
38
|
+
1. Run pytest and detect failures
|
|
39
|
+
2. Analyze git diff to find recent changes
|
|
40
|
+
3. Fix the code or update the test, depending on `--mode` (`auto`, `code`, or `test`)
|
|
41
|
+
4. Generate and apply a fix
|
|
42
|
+
5. Verify the fix works
|
|
43
|
+
|
|
44
|
+
## Requirements
|
|
45
|
+
|
|
46
|
+
- Python 3.10+
|
|
47
|
+
- Anthropic API key
|
|
48
|
+
- Project must be a git repository
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
unitsauce/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
unitsauce/analysis.py,sha256=Zz0d-aOr2yI2YG9OoOz7SRNl7Wz-_9meJfwyUKdS_i0,5158
|
|
3
|
+
unitsauce/fixer.py,sha256=3SHjqLbIEPnvzmt_6JwnrfXtVGtwiUJK-zt-e_pe_78,8023
|
|
4
|
+
unitsauce/github.py,sha256=9JfO-a1v9H96_dZUXA1YacVoM_UZZm4NV0ebjU8_1Dg,2746
|
|
5
|
+
unitsauce/main.py,sha256=QLOVhxkmrjYW4IoglR3TChncyXpdoD8_wehAMtAjkOw,2760
|
|
6
|
+
unitsauce/models.py,sha256=l40OcAAoFrysma83A1vpbUi_5uD2Yh5TmGfGzd_AgWc,1001
|
|
7
|
+
unitsauce/output.py,sha256=wLnFEBq0o2Qms0FMOF6Fu7MLWTTBCw-zqFzdMupc7Aw,3447
|
|
8
|
+
unitsauce/prompts.py,sha256=2xaOEkYrXB-XawdJGoJOryaZjGX9YHQPF0bapnfsS8U,1708
|
|
9
|
+
unitsauce/utils.py,sha256=-XnwUGgsydBalq4GtSV0GJ63PeiPkf6jZBCs2BykxWw,476
|
|
10
|
+
unitsauce-0.1.0.dist-info/METADATA,sha256=NFBhV4L4CSi2qwJXoNJwb4AhXq0pRUioPykhzxkVgoc,1062
|
|
11
|
+
unitsauce-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
12
|
+
unitsauce-0.1.0.dist-info/entry_points.txt,sha256=QFdjrO3Rj2jkNqfwe0aY8JT5gYbHOWPPkvwxAsVynu8,50
|
|
13
|
+
unitsauce-0.1.0.dist-info/top_level.txt,sha256=44KP-G6eKDbMYc6daSbp1vcEDez-Q99MEWNqGL3Tx5E,10
|
|
14
|
+
unitsauce-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
unitsauce
|