pdd_cli-0.0.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of pdd-cli might be problematic.
- pdd/__init__.py +0 -0
- pdd/auto_deps_main.py +98 -0
- pdd/auto_include.py +175 -0
- pdd/auto_update.py +73 -0
- pdd/bug_main.py +99 -0
- pdd/bug_to_unit_test.py +159 -0
- pdd/change.py +141 -0
- pdd/change_main.py +240 -0
- pdd/cli.py +607 -0
- pdd/cmd_test_main.py +155 -0
- pdd/code_generator.py +117 -0
- pdd/code_generator_main.py +66 -0
- pdd/comment_line.py +35 -0
- pdd/conflicts_in_prompts.py +143 -0
- pdd/conflicts_main.py +90 -0
- pdd/construct_paths.py +251 -0
- pdd/context_generator.py +133 -0
- pdd/context_generator_main.py +73 -0
- pdd/continue_generation.py +140 -0
- pdd/crash_main.py +127 -0
- pdd/data/language_format.csv +61 -0
- pdd/data/llm_model.csv +15 -0
- pdd/detect_change.py +142 -0
- pdd/detect_change_main.py +100 -0
- pdd/find_section.py +28 -0
- pdd/fix_code_loop.py +212 -0
- pdd/fix_code_module_errors.py +143 -0
- pdd/fix_error_loop.py +216 -0
- pdd/fix_errors_from_unit_tests.py +240 -0
- pdd/fix_main.py +138 -0
- pdd/generate_output_paths.py +194 -0
- pdd/generate_test.py +140 -0
- pdd/get_comment.py +55 -0
- pdd/get_extension.py +52 -0
- pdd/get_language.py +41 -0
- pdd/git_update.py +84 -0
- pdd/increase_tests.py +93 -0
- pdd/insert_includes.py +150 -0
- pdd/llm_invoke.py +304 -0
- pdd/load_prompt_template.py +59 -0
- pdd/pdd_completion.fish +72 -0
- pdd/pdd_completion.sh +141 -0
- pdd/pdd_completion.zsh +418 -0
- pdd/postprocess.py +121 -0
- pdd/postprocess_0.py +52 -0
- pdd/preprocess.py +199 -0
- pdd/preprocess_main.py +72 -0
- pdd/process_csv_change.py +182 -0
- pdd/prompts/auto_include_LLM.prompt +230 -0
- pdd/prompts/bug_to_unit_test_LLM.prompt +17 -0
- pdd/prompts/change_LLM.prompt +34 -0
- pdd/prompts/conflict_LLM.prompt +23 -0
- pdd/prompts/continue_generation_LLM.prompt +3 -0
- pdd/prompts/detect_change_LLM.prompt +65 -0
- pdd/prompts/example_generator_LLM.prompt +10 -0
- pdd/prompts/extract_auto_include_LLM.prompt +6 -0
- pdd/prompts/extract_code_LLM.prompt +22 -0
- pdd/prompts/extract_conflict_LLM.prompt +19 -0
- pdd/prompts/extract_detect_change_LLM.prompt +19 -0
- pdd/prompts/extract_program_code_fix_LLM.prompt +16 -0
- pdd/prompts/extract_prompt_change_LLM.prompt +7 -0
- pdd/prompts/extract_prompt_split_LLM.prompt +9 -0
- pdd/prompts/extract_prompt_update_LLM.prompt +8 -0
- pdd/prompts/extract_promptline_LLM.prompt +11 -0
- pdd/prompts/extract_unit_code_fix_LLM.prompt +332 -0
- pdd/prompts/extract_xml_LLM.prompt +7 -0
- pdd/prompts/fix_code_module_errors_LLM.prompt +17 -0
- pdd/prompts/fix_errors_from_unit_tests_LLM.prompt +62 -0
- pdd/prompts/generate_test_LLM.prompt +12 -0
- pdd/prompts/increase_tests_LLM.prompt +16 -0
- pdd/prompts/insert_includes_LLM.prompt +30 -0
- pdd/prompts/split_LLM.prompt +94 -0
- pdd/prompts/summarize_file_LLM.prompt +11 -0
- pdd/prompts/trace_LLM.prompt +30 -0
- pdd/prompts/trim_results_LLM.prompt +83 -0
- pdd/prompts/trim_results_start_LLM.prompt +45 -0
- pdd/prompts/unfinished_prompt_LLM.prompt +18 -0
- pdd/prompts/update_prompt_LLM.prompt +19 -0
- pdd/prompts/xml_convertor_LLM.prompt +54 -0
- pdd/split.py +119 -0
- pdd/split_main.py +103 -0
- pdd/summarize_directory.py +212 -0
- pdd/trace.py +135 -0
- pdd/trace_main.py +108 -0
- pdd/track_cost.py +102 -0
- pdd/unfinished_prompt.py +114 -0
- pdd/update_main.py +96 -0
- pdd/update_prompt.py +115 -0
- pdd/xml_tagger.py +122 -0
- pdd_cli-0.0.2.dist-info/LICENSE +7 -0
- pdd_cli-0.0.2.dist-info/METADATA +225 -0
- pdd_cli-0.0.2.dist-info/RECORD +95 -0
- pdd_cli-0.0.2.dist-info/WHEEL +5 -0
- pdd_cli-0.0.2.dist-info/entry_points.txt +2 -0
- pdd_cli-0.0.2.dist-info/top_level.txt +1 -0
pdd/fix_code_module_errors.py
ADDED

@@ -0,0 +1,143 @@
from typing import Tuple
from pydantic import BaseModel, Field, ValidationError
from rich import print
from rich.markdown import Markdown
from .load_prompt_template import load_prompt_template
from .llm_invoke import llm_invoke
import json

class CodeFix(BaseModel):
    update_program: bool = Field(description="Indicates if the program needs updating")
    update_code: bool = Field(description="Indicates if the code module needs updating")
    fixed_program: str = Field(description="The fixed program code")
    fixed_code: str = Field(description="The fixed code module")

def validate_inputs(
    program: str,
    prompt: str,
    code: str,
    errors: str,
    strength: float
) -> None:
    """Validate input parameters."""
    if not all([program, prompt, code, errors]):
        raise ValueError("All string inputs (program, prompt, code, errors) must be non-empty")

    if not isinstance(strength, (int, float)):
        raise ValueError("Strength must be a number")

    if not 0 <= strength <= 1:
        raise ValueError("Strength must be between 0 and 1")

def fix_code_module_errors(
    program: str,
    prompt: str,
    code: str,
    errors: str,
    strength: float,
    temperature: float = 0,
    verbose: bool = False
) -> Tuple[bool, bool, str, str, float, str]:
    """
    Fix errors in a code module that caused a program to crash and/or have errors.
    """
    try:
        # Validate inputs
        validate_inputs(program, prompt, code, errors, strength)

        # Step 1: Load prompt templates
        fix_prompt = load_prompt_template("fix_code_module_errors_LLM")
        extract_prompt = load_prompt_template("extract_program_code_fix_LLM")

        if not all([fix_prompt, extract_prompt]):
            raise ValueError("Failed to load one or more prompt templates")

        total_cost = 0
        model_name = ""

        # Step 2: First LLM invoke for error analysis
        input_json = {
            "program": program,
            "prompt": prompt,
            "code": code,
            "errors": errors
        }

        if verbose:
            print("[blue]Running initial error analysis...[/blue]")

        first_response = llm_invoke(
            prompt=fix_prompt,
            input_json=input_json,
            strength=strength,
            temperature=temperature,
            verbose=verbose
        )

        total_cost += first_response.get('cost', 0)
        model_name = first_response.get('model_name', '')

        if verbose:
            print("[green]Error analysis complete[/green]")
            print(Markdown(first_response['result']))
            print(f"[yellow]Current cost: ${total_cost:.6f}[/yellow]")

        # Step 4: Second LLM invoke for code extraction
        extract_input = {
            "program_code_fix": first_response['result'],
            "program": program,
            "code": code
        }

        if verbose:
            print("[blue]Extracting code fixes...[/blue]")

        second_response = llm_invoke(
            prompt=extract_prompt,
            input_json=extract_input,
            strength=0.89,  # Fixed strength as specified
            temperature=temperature,
            verbose=verbose,
            output_pydantic=CodeFix
        )

        total_cost += second_response.get('cost', 0)

        # Step 5: Extract values from Pydantic result
        result = second_response['result']

        if isinstance(result, str):
            try:
                result_dict = json.loads(result)
            except json.JSONDecodeError:
                result_dict = {"result": result}
            result = CodeFix.model_validate(result_dict)
        elif isinstance(result, dict):
            result = CodeFix.model_validate(result)
        elif not isinstance(result, CodeFix):
            result = CodeFix.model_validate({"result": str(result)})

        if verbose:
            print("[green]Code extraction complete[/green]")
            print(f"[yellow]Total cost: ${total_cost:.6f}[/yellow]")
            print(f"[blue]Model used: {model_name}[/blue]")

        # Step 7: Return results
        return (
            result.update_program,
            result.update_code,
            result.fixed_program,
            result.fixed_code,
            total_cost,
            model_name
        )

    except ValueError as ve:
        print(f"[red]Value Error: {str(ve)}[/red]")
        raise
    except ValidationError:
        print("[red]Validation Error: Invalid result format[/red]")
        raise
    except Exception as e:
        print(f"[red]Unexpected error: {str(e)}[/red]")
        raise
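For orientation, the function above returns a six-element tuple (two update flags, the two fixed sources, the accumulated cost, and the model name). A call might look like the sketch below; it is illustrative only, with hypothetical file names, prompt text, and error string rather than anything shipped in the package:

    from pdd.fix_code_module_errors import fix_code_module_errors

    # Hypothetical driver: the program that crashed, the prompt that generated
    # the module, the module source, and the captured traceback.
    update_program, update_code, fixed_program, fixed_code, cost, model = fix_code_module_errors(
        program=open("run_parser.py").read(),    # hypothetical file
        prompt="Write a CSV parsing module",     # hypothetical prompt text
        code=open("csv_parser.py").read(),       # hypothetical file
        errors="TypeError: 'NoneType' object is not iterable",
        strength=0.7,
        verbose=True,
    )
    if update_code:
        with open("csv_parser.py", "w") as f:
            f.write(fixed_code)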
pdd/fix_error_loop.py
ADDED

@@ -0,0 +1,216 @@
import os
import re
import shutil
import subprocess
from dataclasses import dataclass
from typing import Tuple, Optional
from rich import print as rprint

from .fix_errors_from_unit_tests import fix_errors_from_unit_tests

@dataclass
class IterationResult:
    fails: int
    errors: int
    iteration: int
    total_fails_and_errors: int

    def is_better_than(self, other: Optional['IterationResult']) -> bool:
        if other is None:
            return True
        if self.total_fails_and_errors < other.total_fails_and_errors:
            return True
        if self.total_fails_and_errors == other.total_fails_and_errors:
            return self.errors < other.errors  # Prioritize fewer errors
        return False

def extract_test_results(pytest_output: str) -> Tuple[int, int]:
    """Extract the number of fails and errors from pytest output.

    Args:
        pytest_output (str): The complete pytest output text

    Returns:
        Tuple[int, int]: Number of fails and errors respectively
    """
    fails = errors = 0

    # First try to match the summary line
    summary_match = re.search(r'=+ (\d+) failed[\,\s]', pytest_output)
    if summary_match:
        fails = int(summary_match.group(1))
    else:
        # Fallback to looking for any "X failed" pattern
        fail_match = re.search(r'(\d+)\s+failed', pytest_output)
        if fail_match:
            fails = int(fail_match.group(1))

    # Look for error patterns
    error_match = re.search(r'(\d+)\s+error', pytest_output)
    if error_match:
        errors = int(error_match.group(1))

    return fails, errors

def create_backup_files(unit_test_file: str, code_file: str, fails: int,
                        errors: int, iteration: int) -> Tuple[str, str]:
    """Create backup files with iteration information in the filename."""
    unit_test_backup = f"{os.path.splitext(unit_test_file)[0]}_{fails}_{errors}_{iteration}.py"
    code_backup = f"{os.path.splitext(code_file)[0]}_{fails}_{errors}_{iteration}.py"

    shutil.copy2(unit_test_file, unit_test_backup)
    shutil.copy2(code_file, code_backup)

    return unit_test_backup, code_backup

def fix_error_loop(
    unit_test_file: str,
    code_file: str,
    prompt: str,
    verification_program: str,
    strength: float,
    temperature: float,
    max_attempts: int,
    budget: float,
    error_log_file: str = "error_log.txt",
    verbose: bool = False
) -> Tuple[bool, str, str, int, float, str]:
    """
    Attempt to fix errors in a unit test and its corresponding code file through multiple iterations.
    """
    # Input validation
    if not all([os.path.exists(f) for f in [unit_test_file, code_file, verification_program]]):
        raise FileNotFoundError("One or more input files do not exist")
    if not (0 <= strength <= 1 and 0 <= temperature <= 1):
        raise ValueError("Strength and temperature must be between 0 and 1")

    # Step 1: Remove existing error log file if it exists
    try:
        if os.path.exists(error_log_file):
            os.remove(error_log_file)
    except FileNotFoundError:
        pass  # File doesn't exist, which is fine

    # Step 2: Initialize variables
    attempt_count = 0
    total_cost = 0.0
    best_iteration: Optional[IterationResult] = None
    model_name = ""

    while attempt_count < max_attempts:
        rprint(f"[bold yellow]Attempt {attempt_count + 1}[/bold yellow]")

        # Increment attempt counter first
        attempt_count += 1

        # Step 3a: Run pytest
        with open(error_log_file, 'a') as f:
            result = subprocess.run(['python', '-m', 'pytest', '-vv', '--no-cov', unit_test_file],
                                    capture_output=True, text=True)
            f.write("\n****************************************************************************************************\n")
            f.write("\nAttempt " + str(attempt_count) + ":\n")
            f.write("\n****************************************************************************************************\n")
            f.write(result.stdout + result.stderr)

        # Extract test results
        fails, errors = extract_test_results(result.stdout)
        current_iteration = IterationResult(fails, errors, attempt_count, fails + errors)

        # Step 3b: Check if tests pass
        if fails == 0 and errors == 0:
            break

        # Step 3c: Handle test failures
        with open(error_log_file, 'r') as f:
            error_content = f.read()
        rprint(f"[bold red]Test output (attempt {attempt_count}):[/bold red]")
        rprint(error_content.replace('[', '\\[').replace(']', '\\]'))

        # Create backups
        backup_unit_test, backup_code = create_backup_files(
            unit_test_file, code_file, fails, errors, attempt_count
        )

        # Read current files
        with open(unit_test_file, 'r') as f:
            current_unit_test = f.read()
        with open(code_file, 'r') as f:
            current_code = f.read()

        # Try to fix errors
        update_unit_test, update_code, fixed_unit_test, fixed_code, iteration_cost, model_name = (
            fix_errors_from_unit_tests(
                current_unit_test, current_code, prompt, error_content,
                error_log_file, strength, temperature,
                verbose=verbose
            )
        )

        total_cost += iteration_cost
        if total_cost > budget:
            rprint("[bold red]Budget exceeded![/bold red]")
            break

        if not (update_unit_test or update_code):
            rprint("[bold yellow]No changes needed or possible.[/bold yellow]")
            break

        # Update files if needed
        if update_unit_test:
            with open(unit_test_file, 'w') as f:
                f.write(fixed_unit_test)

        if update_code:
            with open(code_file, 'w') as f:
                f.write(fixed_code)

        # Run verification
        rprint("[bold yellow]Running Verification.[/bold yellow]")
        verification_result = subprocess.run(['python', verification_program],
                                             capture_output=True, text=True)

        if verification_result.returncode != 0:
            rprint("[bold red]Verification failed! Restoring previous code.[/bold red]")
            shutil.copy2(backup_code, code_file)
            with open(error_log_file, 'a') as f:
                f.write("****************************************************************************************************\n")
                f.write("\nVerification program failed! Here is the output and errors from the verification program that was running the code under test:\n" + verification_result.stdout + verification_result.stderr)
                f.write("****************************************************************************************************\n")
                f.write(f"\nRestoring previous working code.\n")
            continue

        # Update best iteration if current is better
        if current_iteration.is_better_than(best_iteration):
            best_iteration = current_iteration

        # Check budget after increment
        if total_cost > budget:
            rprint("[bold red]Budget exceeded![/bold red]")
            break

    # Step 4: Final test run
    with open(error_log_file, 'a') as f:
        final_result = subprocess.run(['python', '-m', 'pytest', '-vv', unit_test_file],
                                      capture_output=True, text=True)
        f.write("\nFinal test run:\n" + final_result.stdout + final_result.stderr)
    rprint("[bold]Final test output:[/bold]")
    rprint(final_result.stdout.replace('[', '\\[').replace(']', '\\]'))

    # Step 5: Restore best iteration if needed
    final_fails, final_errors = extract_test_results(final_result.stdout)
    if best_iteration and (final_fails + final_errors) > best_iteration.total_fails_and_errors:
        rprint(f"[bold yellow]Restoring best iteration: {best_iteration.iteration} [/bold yellow]")
        best_unit_test = f"{os.path.splitext(unit_test_file)[0]}_{best_iteration.fails}_{best_iteration.errors}_{best_iteration.iteration}.py"
        best_code = f"{os.path.splitext(code_file)[0]}_{best_iteration.fails}_{best_iteration.errors}_{best_iteration.iteration}.py"
        shutil.copy2(best_unit_test, unit_test_file)
        shutil.copy2(best_code, code_file)

    # Step 6: Return results
    with open(unit_test_file, 'r') as f:
        final_unit_test = f.read()
    with open(code_file, 'r') as f:
        final_code = f.read()

    success = final_fails == 0 and final_errors == 0

    return success, final_unit_test, final_code, attempt_count, total_cost, model_name
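Two points of reference for the module above, both illustrative rather than taken from the package. First, the regexes in extract_test_results parse a pytest summary line such as this hypothetical one:

    sample = "=== 2 failed, 5 passed, 1 error in 0.42s ==="
    fails, errors = extract_test_results(sample)   # returns (2, 1)

Second, the fix_error_loop signature suggests a driver along these lines (file names, budget, and settings are placeholders):

    success, final_test, final_code, attempts, cost, model = fix_error_loop(
        unit_test_file="tests/test_csv_parser.py",   # hypothetical path
        code_file="csv_parser.py",                   # hypothetical path
        prompt="Write a CSV parsing module",         # hypothetical prompt text
        verification_program="run_parser.py",        # hypothetical path
        strength=0.7,
        temperature=0.0,
        max_attempts=5,
        budget=1.00,   # stop once cumulative LLM cost exceeds this amount
        verbose=True,
    )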
pdd/fix_errors_from_unit_tests.py
ADDED

@@ -0,0 +1,240 @@
import os
import tempfile  # Added missing import
from datetime import datetime
from typing import Tuple, Optional
from pydantic import BaseModel, Field, ValidationError
from rich import print as rprint
from rich.markdown import Markdown
from rich.console import Console
from rich.panel import Panel
from tempfile import NamedTemporaryFile

from .preprocess import preprocess
from .load_prompt_template import load_prompt_template
from .llm_invoke import llm_invoke

console = Console()

class CodeFix(BaseModel):
    update_unit_test: bool = Field(description="Whether the unit test needs to be updated")
    update_code: bool = Field(description="Whether the code needs to be updated")
    fixed_unit_test: str = Field(description="The fixed unit test code")
    fixed_code: str = Field(description="The fixed code under test")

def validate_inputs(strength: float, temperature: float) -> None:
    """Validate strength and temperature parameters."""
    if not 0 <= strength <= 1:
        raise ValueError("Strength must be between 0 and 1")
    if not 0 <= temperature <= 1:
        raise ValueError("Temperature must be between 0 and 1")

def write_to_error_file(file_path: str, content: str) -> None:
    """Write content to error file with timestamp and separator."""
    try:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        separator = f"\n{'='*80}\n{timestamp}\n{'='*80}\n"

        # Ensure parent directory exists
        parent_dir = os.path.dirname(file_path)
        use_fallback = False

        if parent_dir:
            try:
                os.makedirs(parent_dir, exist_ok=True)
            except Exception as e:
                console.print(f"[yellow]Warning: Could not create directory {parent_dir}: {str(e)}[/yellow]")
                # Fallback to system temp directory
                use_fallback = True
                parent_dir = None

        # Use atomic write with temporary file
        try:
            # First read existing content if file exists
            existing_content = ""
            if os.path.exists(file_path):
                try:
                    with open(file_path, 'r') as f:
                        existing_content = f.read()
                except Exception as e:
                    console.print(f"[yellow]Warning: Could not read existing file {file_path}: {str(e)}[/yellow]")

            # Write both existing and new content to temp file
            with NamedTemporaryFile(mode='w', dir=parent_dir, delete=False) as tmp_file:
                if existing_content:
                    tmp_file.write(existing_content)
                tmp_file.write(f"{separator}{content}\n")
                tmp_path = tmp_file.name

            # Only attempt atomic move if not using fallback
            if not use_fallback:
                try:
                    os.replace(tmp_path, file_path)
                except Exception as e:
                    console.print(f"[yellow]Warning: Could not move file to {file_path}: {str(e)}[/yellow]")
                    use_fallback = True

            if use_fallback:
                # Write to fallback location in system temp directory
                fallback_path = os.path.join(tempfile.gettempdir(), os.path.basename(file_path))
                try:
                    os.replace(tmp_path, fallback_path)
                    console.print(f"[yellow]Warning: Using fallback location: {fallback_path}[/yellow]")
                except Exception as e:
                    console.print(f"[red]Error writing to fallback location: {str(e)}[/red]")
                    try:
                        os.unlink(tmp_path)
                    except:
                        pass
                    raise
        except Exception as e:
            console.print(f"[red]Error writing to error file: {str(e)}[/red]")
            try:
                os.unlink(tmp_path)
            except:
                pass
            raise
    except Exception as e:
        console.print(f"[red]Error in write_to_error_file: {str(e)}[/red]")
        raise

def fix_errors_from_unit_tests(
    unit_test: str,
    code: str,
    prompt: str,
    error: str,
    error_file: str,
    strength: float,
    temperature: float,
    verbose: bool = False
) -> Tuple[bool, bool, str, str, float, str]:
    """
    Fix errors in unit tests using LLM models and log the process.

    Args:
        unit_test (str): The unit test code
        code (str): The code under test
        prompt (str): The prompt that generated the code
        error (str): The error message
        error_file (str): Path to error log file
        strength (float): LLM model strength (0-1)
        temperature (float): LLM temperature (0-1)
        verbose (bool): Whether to print detailed output

    Returns:
        Tuple containing update flags, fixed code/tests, total cost, and model name
    """
    # Input validation
    if not all([unit_test, code, prompt, error, error_file]):
        raise ValueError("All input parameters must be non-empty")

    validate_inputs(strength, temperature)

    total_cost = 0.0
    model_name = ""

    try:
        # Step 1: Load prompt templates
        fix_errors_prompt = load_prompt_template("fix_errors_from_unit_tests_LLM")
        extract_fix_prompt = load_prompt_template("extract_unit_code_fix_LLM")

        if not fix_errors_prompt or not extract_fix_prompt:
            raise ValueError("Failed to load prompt templates")

        # Step 2: Read error file content
        existing_errors = ""
        try:
            if os.path.exists(error_file):
                with open(error_file, 'r', encoding='utf-8') as f:
                    existing_errors = f.read()
        except Exception as e:
            if verbose:
                console.print(f"[yellow]Warning: Could not read error file: {str(e)}[/yellow]")

        # Step 3: Run first prompt through llm_invoke
        processed_prompt = preprocess(
            prompt,
            recursive=False,
            double_curly_brackets=True,
            exclude_keys=['unit_test', 'code', 'unit_test_fix']
        )

        if verbose:
            console.print(Panel("[bold green]Running fix_errors_from_unit_tests...[/bold green]"))

        response1 = llm_invoke(
            prompt=fix_errors_prompt,
            input_json={
                "unit_test": unit_test,
                "code": code,
                "prompt": processed_prompt,
                "errors": error
            },
            strength=strength,
            temperature=temperature,
            verbose=verbose
        )

        total_cost += response1['cost']
        model_name = response1['model_name']
        result1 = response1['result']

        # Step 4: Pretty print results and log to error file
        if verbose:
            console.print(Markdown(result1))
            console.print(f"Cost of first run: ${response1['cost']:.6f}")

        write_to_error_file(error_file, f"Model: {model_name}\nResult:\n{result1}")

        # Step 5: Preprocess extract_fix prompt
        processed_extract_prompt = preprocess(
            extract_fix_prompt,
            recursive=False,
            double_curly_brackets=True,
            exclude_keys=['unit_test', 'code', 'unit_test_fix']
        )

        # Step 6: Run second prompt through llm_invoke with fixed strength
        if verbose:
            console.print(Panel("[bold green]Running extract_unit_code_fix...[/bold green]"))

        response2 = llm_invoke(
            prompt=processed_extract_prompt,
            input_json={
                "unit_test_fix": result1,
                "unit_test": unit_test,
                "code": code
            },
            strength=0.895,  # Fixed strength as per requirements
            temperature=temperature,
            output_pydantic=CodeFix,
            verbose=verbose
        )

        total_cost += response2['cost']
        result2: CodeFix = response2['result']

        if verbose:
            console.print(f"Total cost: ${total_cost:.6f}")
            console.print(f"Model used: {model_name}")

        return (
            result2.update_unit_test,
            result2.update_code,
            result2.fixed_unit_test,
            result2.fixed_code,
            total_cost,
            model_name
        )

    except ValidationError as e:
        error_msg = f"Validation error in fix_errors_from_unit_tests: {str(e)}"
        if verbose:
            console.print(f"[bold red]{error_msg}[/bold red]")
        write_to_error_file(error_file, error_msg)
        return False, False, "", "", 0.0, ""
    except Exception as e:
        error_msg = f"Error in fix_errors_from_unit_tests: {str(e)}"
        if verbose:
            console.print(f"[bold red]{error_msg}[/bold red]")
        write_to_error_file(error_file, error_msg)
        return False, False, "", "", 0.0, ""
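A closing note on the extraction contract: the module calls CodeFix.model_validate, which is the pydantic v2 API, so the structured result returned by the second llm_invoke call is expected to carry the four fields defined above. A minimal, hypothetical validation example (not code from the package):

    raw = {
        "update_unit_test": True,
        "update_code": False,
        "fixed_unit_test": "def test_parse_empty():\n    assert parse('') == []",
        "fixed_code": "",
    }
    fix = CodeFix.model_validate(raw)      # raises ValidationError on a bad payload
    assert fix.update_unit_test and not fix.update_code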