flyteplugins-codegen 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/codegen/__init__.py +18 -0
- flyteplugins/codegen/auto_coder_agent.py +1088 -0
- flyteplugins/codegen/core/__init__.py +19 -0
- flyteplugins/codegen/core/types.py +337 -0
- flyteplugins/codegen/data/__init__.py +27 -0
- flyteplugins/codegen/data/extraction.py +281 -0
- flyteplugins/codegen/data/schema.py +270 -0
- flyteplugins/codegen/execution/__init__.py +7 -0
- flyteplugins/codegen/execution/agent.py +671 -0
- flyteplugins/codegen/execution/docker.py +206 -0
- flyteplugins/codegen/generation/__init__.py +41 -0
- flyteplugins/codegen/generation/llm.py +1269 -0
- flyteplugins/codegen/generation/prompts.py +136 -0
- flyteplugins_codegen-2.0.6.dist-info/METADATA +441 -0
- flyteplugins_codegen-2.0.6.dist-info/RECORD +17 -0
- flyteplugins_codegen-2.0.6.dist-info/WHEEL +5 -0
- flyteplugins_codegen-2.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1269 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import flyte
|
|
9
|
+
import flyte.errors
|
|
10
|
+
import litellm
|
|
11
|
+
|
|
12
|
+
from flyteplugins.codegen.core.types import (
|
|
13
|
+
CodePlan,
|
|
14
|
+
CodeSolution,
|
|
15
|
+
ErrorDiagnosis,
|
|
16
|
+
FixVerification,
|
|
17
|
+
TestFixResponse,
|
|
18
|
+
TestFunctionPatch,
|
|
19
|
+
_PackageDetectionResponse,
|
|
20
|
+
_PackageReplacementResponse,
|
|
21
|
+
_TestCodeResponse,
|
|
22
|
+
)
|
|
23
|
+
from flyteplugins.codegen.data.schema import extract_token_usage
|
|
24
|
+
from flyteplugins.codegen.generation.prompts import (
|
|
25
|
+
PACKAGE_MANAGER_MAP,
|
|
26
|
+
TEST_FRAMEWORKS,
|
|
27
|
+
build_enhanced_prompt,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
_PYTHON_STDLIB = sys.stdlib_module_names
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def strip_code_fences(code: str) -> str:
    """Strip markdown code fences from LLM output.

    Handles a complete fenced block (```lang ... ```) and, as a robustness
    fallback, an opening fence whose closing ``` was lost (e.g. truncated
    model output). Unfenced input is returned stripped of surrounding
    whitespace.

    Args:
        code: Raw text returned by the model.

    Returns:
        The text with any surrounding markdown fences removed.
    """
    stripped = code.strip()
    # Match ```python ... ``` or ``` ... ```
    match = re.match(r"^```(?:\w+)?\s*\n(.*?)```\s*$", stripped, re.DOTALL)
    if match:
        return match.group(1).strip()
    # Robustness fix: a truncated response may carry an opening fence but no
    # closing one. Previously that fence leaked into the returned "code".
    # Only strip when the first line is purely a fence marker (```[lang]).
    if stripped.startswith("```"):
        first_newline = stripped.find("\n")
        if first_newline != -1 and re.fullmatch(r"```\w*\s*", stripped[:first_newline]):
            return stripped[first_newline + 1 :].strip()
    return stripped
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def filter_stdlib(packages: list[str]) -> list[str]:
    """Drop entries that name Python standard-library modules.

    Only the top-level module name (text before the first dot) is compared,
    case-insensitively, against the interpreter's stdlib module list.
    """
    third_party = []
    for pkg in packages:
        top_level = pkg.split(".", 1)[0].lower()
        if top_level not in _PYTHON_STDLIB:
            third_party.append(pkg)
    return third_party
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@flyte.trace
async def generate_plan(
    model: str,
    prompt: str,
    language: str,
    schema: Optional[str],
    constraints: Optional[list[str]],
    data_samples: Optional[str],
    inputs: Optional[dict[str, type]],
    outputs: Optional[dict[str, type]],
    litellm_params: Optional[dict],
) -> tuple[CodePlan, int, int]:
    """Generate a structured plan for the code solution before writing code.

    Args:
        model: LiteLLM model identifier used for the planning call.
        prompt: User task description.
        language: Target programming language for the eventual solution.
        schema: Optional data schema text folded into the prompt.
        constraints: Optional constraint list folded into the prompt.
        data_samples: Optional sample data folded into the prompt.
        inputs: Optional mapping of CLI input names to types.
        outputs: Optional mapping of expected output names to types.
        litellm_params: Extra LiteLLM kwargs; may override any default
            except ``response_format``, which is always forced to ``CodePlan``.

    Returns:
        Tuple of (plan, input_tokens, output_tokens)

    Raises:
        flyte.errors.RuntimeUserError: If the model rejects the
            ``response_format`` parameter (no structured-output support).
    """
    base_prompt = build_enhanced_prompt(prompt, language, schema, constraints, data_samples, inputs, outputs)

    # Add planning-specific instructions
    planning_prompt = f"""You are planning a {language} solution for the following task:

{base_prompt}

Create a detailed plan including:
1. Overall description of the solution
2. High-level approach and algorithm

The solution will be implemented as a single {language} file.
Focus on clarity and completeness. The plan will guide code generation."""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": planning_prompt}],
        "max_tokens": 1000,
        "temperature": 0.3,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = CodePlan

    try:
        response = await litellm.acompletion(**params)
    except Exception as e:
        # Check if it's an unsupported params error
        # (matched by exception class name / message text, since the concrete
        # exception type depends on the litellm provider in use)
        if "UnsupportedParamsError" in type(e).__name__ or "does not support parameters" in str(e):
            raise flyte.errors.RuntimeUserError(
                f"Model '{model}' does not support structured outputs (response_format parameter). "
                f"Please use a model that supports structured outputs like: gpt-4.1, "
                f"claude-3-5-sonnet, or similar models."
            ) from e
        raise

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            plan_dict = json.loads(content)
        except json.JSONDecodeError:
            # Unparseable plan: degrade gracefully by stuffing the raw text
            # (truncated to 500 chars) into the description field.
            logger.warning("Failed to parse plan JSON, using fallback")
            return (
                CodePlan(description=content[:500], approach=""),
                input_tokens,
                output_tokens,
            )
        # NOTE(review): CodePlan(**plan_dict) may raise a validation error if
        # the model returns unexpected keys — presumably acceptable upstream;
        # confirm callers handle it.
        return CodePlan(**plan_dict), input_tokens, output_tokens
    # Non-string content: litellm already returned a parsed object
    # (presumably a CodePlan instance — confirm against the litellm version).
    return content, input_tokens, output_tokens
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@flyte.trace
async def generate_code(
    model: str,
    conversation: list[dict],
    plan: Optional[CodePlan],
    litellm_params: Optional[dict],
    is_retry: bool = False,
) -> tuple[CodeSolution, int, int]:
    """Generate code with structured output using Pydantic model.

    Args:
        model: LLM model to use
        conversation: Message history
        plan: Optional plan to guide code generation (only used for initial generation)
        litellm_params: LiteLLM parameters
        is_retry: If True, skip plan context (for debugging/fixing code)

    Returns:
        Tuple of (CodeSolution, input_tokens, output_tokens)
    """
    # Add plan context only for initial generation (not retries)
    # (shallow copy: message dicts are shared with the caller, but only the
    # list itself is mutated below)
    messages = conversation.copy()
    if plan and not is_retry:
        plan_context = f"""
Plan for implementation:
- Description: {plan.description}
- Approach: {plan.approach}

Follow this plan when generating the code."""

        # Insert plan context before the last user message
        if messages and messages[-1]["role"] == "user":
            messages.insert(-1, {"role": "system", "content": plan_context})
        else:
            messages.append({"role": "system", "content": plan_context})

    # Build params with defaults
    params = {
        "model": model,
        "messages": messages,
        "max_tokens": 2000,
        "temperature": 0.7,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = CodeSolution

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            solution_dict = json.loads(content)
        except json.JSONDecodeError:
            # Unparseable JSON: salvage by treating the whole response as code.
            logger.warning("Failed to parse code solution JSON, extracting code from raw response")
            return (
                CodeSolution(code=strip_code_fences(content)),
                input_tokens,
                output_tokens,
            )
        solution = CodeSolution(**solution_dict)
    else:
        # Already-parsed structured response (presumably a CodeSolution —
        # confirm against the litellm version in use).
        solution = content

    # Strip code fences the model may have included
    # (mutates the solution object in place before returning it)
    solution.code = strip_code_fences(solution.code)
    return solution, input_tokens, output_tokens
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@flyte.trace
async def detect_required_packages(
    model: str,
    code: str,
    language: str,
    litellm_params: Optional[dict],
) -> tuple[list[str], int, int]:
    """Use LLM to detect required packages from code.

    Args:
        model: LiteLLM model identifier.
        code: Source code to scan for third-party dependencies.
        language: Language of ``code``; selects the package-manager wording
            via ``PACKAGE_MANAGER_MAP``.
        litellm_params: Extra LiteLLM kwargs; may override any default
            except ``response_format``.

    Returns:
        Tuple of (packages, input_tokens, output_tokens)
    """
    # Nothing to scan — skip the LLM call entirely.
    if not code.strip():
        return [], 0, 0

    package_type = PACKAGE_MANAGER_MAP.get(language.lower(), "package names for the package manager")

    detection_prompt = f"""Given this {language} code:

```{language}
{code}
```

List the {package_type} needed to install the dependencies used in this code.
For standard library / built-in modules, don't include them.
Only include third-party packages that need to be installed."""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": detection_prompt}],
        "max_tokens": 200,
        "temperature": 0.1,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = _PackageDetectionResponse

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
        except json.JSONDecodeError:
            # Degrade to "no packages" rather than failing the pipeline.
            logger.warning("Failed to parse package detection JSON, returning empty list")
            return [], input_tokens, output_tokens
        packages = result_dict.get("packages", [])
    else:
        # Already-parsed structured response (presumably
        # _PackageDetectionResponse — confirm against the litellm version).
        packages = content.packages

    # The prompt asks the model to exclude stdlib modules, but filter again
    # locally as a safety net.
    return filter_stdlib(packages), input_tokens, output_tokens
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
async def suggest_replacement_package(
    model: str,
    bad_package: str,
    original_error: str,
    solution_code: str,
    litellm_params: Optional[dict] = None,
) -> tuple[Optional[str], int, int]:
    """Ask the LLM for the correct Debian package name to replace a bad one.

    Args:
        model: LiteLLM model identifier.
        bad_package: The apt package name that failed to install.
        original_error: Installer error text, shown to the model for context.
        solution_code: Code that needs the package (truncated to 1000 chars).
        litellm_params: Extra LiteLLM kwargs; may override any default
            except ``response_format``.

    Returns:
        Tuple of (replacement_package_name or None, input_tokens, output_tokens)
    """
    prompt = f"""The Debian/Ubuntu apt package "{bad_package}" does not exist.

Error: {original_error}

The code that needs this package:
```python
{solution_code[:1000]}
```

What is the correct Debian/Ubuntu apt package name that should be used instead?
Set replacement to the correct package name or null if no system package is needed."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.0,
    }
    params.update(litellm_params or {})
    params["response_format"] = _PackageReplacementResponse

    try:
        response = await litellm.acompletion(**params)
        input_tokens, output_tokens = extract_token_usage(response)

        content = response.choices[0].message.content
        if isinstance(content, str):
            try:
                result = json.loads(content)
                replacement = result.get("replacement")
            except json.JSONDecodeError:
                return None, input_tokens, output_tokens
        else:
            replacement = content.replacement

        if replacement:
            logger.info(f"LLM suggested replacing '{bad_package}' with '{replacement}'")
        # BUG FIX: the original fell through to `return None, 0, 0` when the
        # model returned no replacement, silently discarding the real token
        # usage of a call that did happen. Report the tokens on every
        # successful round trip; normalize empty strings to None.
        return (replacement or None), input_tokens, output_tokens
    except Exception as e:
        # Best-effort helper: any failure (network, provider, parsing the
        # response object) degrades to "no suggestion" with zero tokens.
        logger.warning(f"Failed to get package replacement suggestion: {e}")
        return None, 0, 0
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
async def detect_and_track_packages(
    model: str,
    solution: CodeSolution,
    base_pkgs: list[str],
    detected_packages: list[str],
    detected_system_packages: list[str],
    litellm_params: Optional[dict] = None,
) -> tuple[bool, list[str], list[str], int, int]:
    """Record any new language/system packages a solution requires.

    System packages are taken directly from the structured solution; language
    packages are inferred from the code via an LLM call. ``detected_packages``
    is extended in place with newly discovered packages.

    Returns:
        (needs_rebuild, updated_detected_packages, updated_detected_system_packages,
        input_tokens, output_tokens)
    """
    total_input_tokens = 0
    total_output_tokens = 0
    needs_rebuild = False

    # System packages come straight from the solution's structured metadata;
    # any change relative to what we already track forces a rebuild.
    if solution.system_packages and solution.system_packages != detected_system_packages:
        detected_system_packages = solution.system_packages.copy()
        logger.info(f"Detected system packages: {detected_system_packages}")
        needs_rebuild = True

    # Language-level packages are inferred from the code itself.
    if solution.code.strip():
        new_packages, prompt_toks, completion_toks = await detect_required_packages(
            model, solution.code, solution.language, litellm_params
        )
        total_input_tokens += prompt_toks
        total_output_tokens += completion_toks

        added_packages = []
        for pkg in new_packages:
            if pkg in base_pkgs or pkg in detected_packages:
                continue
            detected_packages.append(pkg)
            added_packages.append(pkg)

        if added_packages:
            logger.info(f"Detected new {solution.language} packages: {added_packages}")
            needs_rebuild = True

    return (
        needs_rebuild,
        detected_packages,
        detected_system_packages,
        total_input_tokens,
        total_output_tokens,
    )
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
@flyte.trace
async def generate_tests(
    model: str,
    prompt: str,
    plan: CodePlan,
    solution: CodeSolution,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
    data_samples: Optional[str] = None,
    inputs: Optional[dict[str, type]] = None,
    outputs: Optional[dict[str, type]] = None,
    litellm_params: Optional[dict] = None,
) -> tuple[str, int, int]:
    """
    Generate test code to validate the solution.

    First asks for plain-text tests; if that output is empty or fails to
    compile, retries once with a structured (JSON) response format.

    Args:
        model: LiteLLM model identifier.
        prompt: Original task description.
        plan: Plan that guided the solution.
        solution: The generated solution to test.
        constraints: Optional constraint list folded into the prompt.
        schema: Optional data schema text folded into the prompt.
        data_samples: Optional sample data the tests must use verbatim.
        inputs: Optional mapping of CLI argument names to types.
        outputs: Optional mapping of expected output names to types.
        litellm_params: Extra LiteLLM kwargs merged over the defaults.

    Returns:
        (test_code, input_tokens, output_tokens)

    Raises:
        RuntimeError: If even the structured fallback yields invalid code.
    """

    def _validate_python(code: str) -> None:
        """Raise if code is invalid or likely truncated."""
        if not code.strip():
            raise RuntimeError("Generated empty test code")

        try:
            compile(code, "test_code", "exec")
        except SyntaxError as e:
            raise RuntimeError(f"Generated test code is invalid or truncated: {e}") from e

    full_code = solution.code
    language = solution.language.lower()
    # Unknown languages fall back to the Python test framework entry.
    test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
    test_framework = test_framework_info["name"]

    import_instruction = (
        "Import functions/classes from solution module (e.g., "
        "'from solution import function_name'). "
        "The code file is named solution.py and is located at "
        "/var/inputs/solution.py."
    )

    # Assemble the test-generation prompt incrementally: task + plan first,
    # then optional sections, then the solution code and the instructions.
    test_prompt = f"""
You are generating TESTS.

Task:
{prompt}

Plan:
- Description: {plan.description}
- Approach: {plan.approach}
"""

    if schema:
        test_prompt += f"\nSchema:\n{schema}\n"

    if constraints:
        test_prompt += "\nConstraints:\n"
        for i, c in enumerate(constraints, 1):
            test_prompt += f"{i}. {c}\n"

    if data_samples:
        test_prompt += f"""
Data:
Use EXACTLY this format if referenced.
Do NOT invent additional samples.

{data_samples}
"""

    if inputs:
        test_prompt += f"\nCLI Arguments: {list(inputs.keys())}"
    if outputs:
        test_prompt += f"\nExpected outputs: {list(outputs.keys())}"

    test_prompt += f"""
Solution code:
```{solution.language}
{full_code}
```"""

    test_prompt += f"""
Test generation instructions:

Generate comprehensive but concise tests that maximize coverage through
multiple complementary cases using {test_framework}, not a single large test.

Requirements:
1. Use {test_framework} framework
2. Test the FULL execution path end-to-end (not just isolated functions)
3. Use provided data AS-IS
4. {import_instruction}
5. Follow {test_framework} best practices
6. /var/outputs is a pre-existing directory — NEVER delete or recreate it.
The solution writes output files there (one file per output),
and tests should READ from /var/outputs to verify correctness

IMPORTANT: Do NOT wrap output in code fences. Return ONLY valid Python test code."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": test_prompt}],
        "max_tokens": 5000,
        "temperature": 0.4,
    }
    params.update(litellm_params or {})

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content

    if not isinstance(content, str):
        content = str(content)

    test_code = strip_code_fences(content)

    # Fast path: plain-text output that compiles is returned directly.
    try:
        _validate_python(test_code)
        return test_code, input_tokens, output_tokens
    except RuntimeError:
        logger.warning("Plain text test generation invalid, retrying with structured output")

    # Structured JSON fallback (optional)
    # Reuses the same prompt/params, adding only the response_format.
    params_structured = dict(params)
    params_structured["response_format"] = _TestCodeResponse

    response = await litellm.acompletion(**params_structured)
    input_tokens2, output_tokens2 = extract_token_usage(response)

    content = response.choices[0].message.content

    if isinstance(content, str):
        try:
            data = json.loads(content)
            test_code = data["test_code"]
        except Exception as e:
            # Covers both JSON decode errors and a missing "test_code" key.
            raise RuntimeError("Structured output invalid; cannot recover test code") from e
    else:
        # Already-parsed structured response (presumably _TestCodeResponse).
        test_code = content.test_code

    test_code = strip_code_fences(test_code)
    _validate_python(test_code)

    # Token totals include both the failed plain attempt and the fallback.
    return (
        test_code,
        input_tokens + input_tokens2,
        output_tokens + output_tokens2,
    )
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def _strip_parametrize_suffix(name: str) -> str:
|
|
518
|
+
"""Strip pytest parametrize suffix like [0-True-10-3] from a test name."""
|
|
519
|
+
return re.sub(r"\[.*?\]$", "", name)
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def apply_test_patches(original_code: str, patches: list[TestFunctionPatch]) -> str:
    """Apply test function patches to the original test file using AST-based replacement.

    Parses the original file to find function boundaries by name, then replaces
    the source lines for each patched function. Preserves imports, fixtures,
    and all non-patched code exactly.

    Args:
        original_code: The full original test file source.
        patches: List of patches, each containing a function name and its fixed source.

    Returns:
        The patched test file as a string.

    Raises:
        SyntaxError: Propagated from ``ast.parse`` if ``original_code`` itself
            does not parse.
    """
    tree = ast.parse(original_code)
    lines = original_code.splitlines(keepends=True)

    # Ensure last line has a newline for consistent joining
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"

    # Build map of function_name -> (start_line_0indexed, end_line_0indexed_exclusive)
    # We need to handle decorated functions: start from the first decorator line
    func_map: dict[str, tuple[int, int]] = {}
    all_top_level = list(ast.iter_child_nodes(tree))

    for idx, node in enumerate(all_top_level):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Start line: use decorator_list if present, otherwise the def line
            if node.decorator_list:
                start = node.decorator_list[0].lineno - 1  # 0-indexed
            else:
                start = node.lineno - 1

            # End line: next top-level node's start, or end of file
            # NOTE(review): comments between two functions are not AST nodes,
            # so they fall inside the preceding function's range and are
            # dropped if that function is patched — presumably acceptable.
            if idx + 1 < len(all_top_level):
                next_node = all_top_level[idx + 1]
                if hasattr(next_node, "decorator_list") and next_node.decorator_list:
                    end = next_node.decorator_list[0].lineno - 1
                else:
                    end = next_node.lineno - 1
            else:
                end = len(lines)

            # Trim trailing blank lines from this function's range
            while end > start and lines[end - 1].strip() == "":
                end -= 1

            func_map[node.name] = (start, end)

    # Apply patches in reverse order (so line numbers stay valid)
    # Normalize patch names by stripping parametrize suffixes
    # If LLM returns multiple patches for the same base function, keep the last one
    patches_by_name: dict[str, TestFunctionPatch] = {}
    for p in patches:
        base_name = _strip_parametrize_suffix(p.test_name)
        if base_name in patches_by_name:
            logger.warning(
                f"Multiple patches for '{base_name}' (from '{p.test_name}'). "
                f"Keeping last patch — ensure it includes ALL parametrize fixes."
            )
        patches_by_name[base_name] = p
    # Collect the (start, end, replacement) triples for functions that
    # actually exist in the file; patches for unknown names are ignored.
    replacements = []
    for name, (start, end) in func_map.items():
        if name in patches_by_name:
            replacements.append((start, end, patches_by_name[name].fixed_code))

    # Sort by start line descending so we can splice without shifting
    replacements.sort(key=lambda r: r[0], reverse=True)

    for start, end, fixed_code in replacements:
        # Ensure fixed code ends with newline
        fixed = fixed_code.rstrip("\n") + "\n"
        # Replace the lines
        lines[start:end] = [fixed]

    result = "".join(lines)

    # Validate the patched code compiles
    # (best-effort check: a syntax error is logged but the broken result is
    # still returned, leaving the caller's retry loop to handle it)
    try:
        compile(result, "patched_test", "exec")
    except SyntaxError as e:
        logger.warning(f"Patched test code has syntax error: {e}. Returning as-is.")

    return result
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
@flyte.trace
async def fix_failing_tests(
    model: str,
    test_code: str,
    diagnosis: ErrorDiagnosis,
    solution: CodeSolution,
    litellm_params: Optional[dict] = None,
) -> tuple[str, list[TestFunctionPatch], int, int]:
    """Fix only the failing tests by returning patches for individual functions.

    Instead of asking the LLM to reproduce the entire test file, asks for only
    the fixed test functions and splices them into the original file.

    Args:
        model: LiteLLM model identifier.
        test_code: Complete test file
        diagnosis: Structured diagnosis of test failures
        solution: The solution being tested
        litellm_params: Extra LiteLLM kwargs; may override any default
            except ``response_format``.

    Returns:
        Tuple of (fixed_test_code, patches, input_tokens, output_tokens)
    """
    full_code = solution.code

    # Build diagnosis information from individual failures
    diagnosis_info = []
    for i, failure in enumerate(diagnosis.failures, 1):
        diagnosis_info.append(
            f"""
Test {i}: {failure.test_name}
- Expected: {failure.expected_behavior}
- Actual: {failure.actual_behavior}
- Root cause: {failure.root_cause}
- Suggested fix: {failure.suggested_fix}"""
        )

    diagnosis_section = "\n".join(diagnosis_info)

    # Extract only the failing test functions to include in the prompt
    # Strip parametrize suffixes so we match the actual function names in the AST
    failing_names = {_strip_parametrize_suffix(f.test_name) for f in diagnosis.failures}
    try:
        tree = ast.parse(test_code)
        test_lines = test_code.splitlines(keepends=True)
        all_nodes = list(ast.iter_child_nodes(tree))
        failing_snippets = []
        # Same boundary logic as apply_test_patches: a function's source runs
        # from its first decorator (or def line) to the next top-level node,
        # minus trailing blank lines.
        for idx, node in enumerate(all_nodes):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.name in failing_names:
                    if node.decorator_list:
                        start = node.decorator_list[0].lineno - 1
                    else:
                        start = node.lineno - 1
                    if idx + 1 < len(all_nodes):
                        next_node = all_nodes[idx + 1]
                        if hasattr(next_node, "decorator_list") and next_node.decorator_list:
                            end = next_node.decorator_list[0].lineno - 1
                        else:
                            end = next_node.lineno - 1
                    else:
                        end = len(test_lines)
                    while end > start and test_lines[end - 1].strip() == "":
                        end -= 1
                    failing_snippets.append("".join(test_lines[start:end]))
        failing_code_section = "\n\n".join(failing_snippets)
    except SyntaxError:
        # If we can't parse, fall back to sending the whole test file
        failing_code_section = test_code

    fix_prompt = f"""You are performing MINIMAL PATCH REPAIR on pytest test functions.

YOUR PRIMARY TASK: Apply the suggested fix for each failing test EXACTLY as described below.

═══════════════════════════════════════
DIAGNOSED FAILURES AND REQUIRED FIXES
═══════════════════════════════════════
{diagnosis_section}

For each failure above, the "Suggested fix" tells you EXACTLY what code change to make.
You MUST apply that specific change — do not interpret, simplify, or find an alternative approach.

═══════════════════════════════════════
FAILING TEST FUNCTIONS (ORIGINAL CODE)
═══════════════════════════════════════
```{solution.language}
{failing_code_section}
```

SOLUTION CODE (for reference only — do NOT modify):
The solution is saved as /var/inputs/solution.py. Tests import from it via `from solution import ...`.
```{solution.language}
{full_code}
```

RULES:
1. For each suggested fix, locate the exact code it refers to and apply the substitution.
2. Preserve ALL existing assertions, decorators, and test structure.
3. Add any necessary imports at the top of the function body if needed.
4. DO NOT rewrite, simplify, or restructure the test.
5. DO NOT change anything that the diagnosis does not mention.

OUTPUT FORMAT:
Return ONLY patches for functions that require modification:
- test_name: BASE function name (no [param] suffix)
- fixed_code: COMPLETE function including decorators"""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": fix_prompt}],
        "max_tokens": 2000,
        "temperature": 0.3,
    }
    params.update(litellm_params or {})
    params["response_format"] = TestFixResponse

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
            patches = [TestFunctionPatch(**p) for p in result_dict["patches"]]
        except (json.JSONDecodeError, KeyError) as exc:
            logger.warning(f"Failed to parse test fix patches JSON: {exc}")
            # Fallback: treat as a single patch with the raw content
            # (the same raw text is applied to every diagnosed test name)
            patches = []
            raw = strip_code_fences(content)
            for failure in diagnosis.failures:
                patches.append(TestFunctionPatch(test_name=failure.test_name, fixed_code=raw))
    else:
        # Already-parsed structured response (presumably TestFixResponse).
        patches = content.patches

    # Strip code fences from each patch's fixed_code
    for patch in patches:
        patch.fixed_code = strip_code_fences(patch.fixed_code)

    # Apply patches to original test code
    fixed_test_code = apply_test_patches(test_code, patches)

    return fixed_test_code, patches, input_tokens, output_tokens
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def extract_error_messages_from_pytest(output: str) -> dict[str, str]:
    """Extract the final error message for each failed test from pytest output.

    Scans the output for pytest failure headers (``_____ test_name _____``)
    and, within each failure section, keeps the last ``E``-prefixed line that
    looks like an exception, so each test maps to its final error.

    Args:
        output: Pytest output string

    Returns:
        Dict mapping test name to error message (e.g., "RecursionError: maximum recursion depth exceeded")
    """
    error_messages: dict[str, str] = {}
    current_test = None

    # Parse pytest output line by line
    for line in output.split("\n"):
        # Match test failure header: _____ test_name _____
        test_header_match = re.match(r"^_{5,}\s+(.+?)\s+_{5,}$", line)
        if test_header_match:
            current_test = test_header_match.group(1).strip()
            # Remove parametrize suffix like [530.00-33-3]
            current_test = re.sub(r"\[.*?\]$", "", current_test)
            continue

        # Match error line: starts with "E " followed by exception
        if current_test and line.startswith("E "):
            # Strip exactly the 2-char "E " prefix that was checked above;
            # stripping 4 chars would truncate messages starting within the
            # first 4 columns (e.g. "E  Failed: x" -> "ailed: x").
            error_line = line[2:].strip()
            # Only capture if it looks like an exception (contains "Error" or "Exception")
            if "Error" in error_line or "Exception" in error_line or "Failed" in error_line:
                # Later "E " lines overwrite earlier ones: we want the final error
                error_messages[current_test] = error_line

    return error_messages
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
@flyte.trace
async def diagnose_error(
    model: str,
    solution: CodeSolution,
    output: str,
    prompt: str,
    plan: CodePlan,
    litellm_params: Optional[dict],
    test_code: Optional[str] = None,
    data_samples: Optional[str] = None,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
) -> tuple[ErrorDiagnosis, int, int]:
    """Performs structured analysis to determine if the error is due to:
    - Environment issues (missing packages, dependencies)
    - Code logic issues (bugs, incorrect algorithms)
    - Test code issues (wrong expected values, incorrect test logic)

    Args:
        model: LLM model identifier passed to litellm.
        solution: Current candidate solution whose code is being diagnosed.
        output: Raw pytest output from the failed run.
        prompt: Original task description.
        plan: Code plan (description/approach) used to generate the solution.
        litellm_params: Extra litellm params; may override defaults except response_format.
        test_code: Generated test code, included verbatim in the prompt.
        data_samples: Optional data samples section for the prompt.
        constraints: Optional numbered constraints for the prompt.
        schema: Optional schema section for the prompt.

    Returns:
        Tuple of (ErrorDiagnosis, input_tokens, output_tokens)
    """
    # Extract error messages from pytest output for accurate tracking
    error_messages = extract_error_messages_from_pytest(output)

    full_code = solution.code

    # Build error messages section for prompt
    error_messages_section = ""
    if error_messages:
        error_messages_section = (
            "\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
            "EXTRACTED ERROR MESSAGES\n"
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
        )
        for test_name, error_msg in error_messages.items():
            error_messages_section += f"\n{test_name}: {error_msg}"

    diagnosis_prompt = f"""
You are a STRICT FAILURE ANALYSIS SYSTEM.

Your goal is to determine whether failures are caused by:
A) Solution code bugs
B) Incorrect tests (tests generated by LLM)
C) Environment/setup issues

- Tests are NOT authoritative.
- Many tests contain hallucinated expected values.
- Only blame solution code if you can PROVE a violation of the specification.

TASK: {prompt}

PLAN: {plan.description} | {plan.approach}

SOLUTION CODE (saved as /var/inputs/solution.py, tests import via `from solution import ...`):
```{solution.language}
{full_code}
```"""

    if data_samples:
        diagnosis_prompt += f"""

DATA SAMPLES:
```
{data_samples}
```"""

    if constraints:
        diagnosis_prompt += "\n\nCONSTRAINTS:"
        for i, constraint in enumerate(constraints, 1):
            diagnosis_prompt += f"\n{i}. {constraint}"

    if schema:
        diagnosis_prompt += f"""

SCHEMA:
```
{schema}
```"""

    # NOTE(review): this section is appended even when test_code is None, so the
    # prompt may literally contain "TEST CODE:\nNone" — confirm callers always
    # pass test_code when diagnosing test failures.
    diagnosis_prompt += f"""

TEST CODE:
{test_code}

TEST OUTPUT:
{output}
{error_messages_section}

MANDATORY DECISION PROCEDURE:

For EACH failing test, follow these steps IN ORDER:

STEP 0 — Execution evidence
Determine whether the solution produced observable effects.
If none are observed, consider invocation or execution issues.

STEP 1 — Environment failure
If the error indicates missing packages, import errors, file system issues,
or system configuration problems -> classify as "environment".

STEP 2 — Test contradicts spec/data/schema
Classify as "test_error" if ANY of the following are true:

- Expected value is not derivable from inputs, spec, or schema
- Test invents constants not present in problem description
- Test assumes behavior not specified
- Test contradicts constraints or data samples
- Test fails before solution logic executes
- Assertion checks wrong field/format/order
- Multiple valid outputs exist but test expects one specific value

STEP 3 — Uncertain root cause
If you CANNOT prove the solution is wrong -> classify as "test_error".

STEP 4 — Proven logic error
Classify as "logic" if you can demonstrate:

- Exact specification requirement violated
- Specific incorrect algorithm or implementation
- Output contradicts constraints despite valid test

Logic errors REQUIRE proof from the specification. Absence of proof -> NOT a logic error.

OUTPUT FORMAT: Return JSON with:

1. failures: list of objects:
- test_name
- error_message (copy exactly)
- expected_behavior
- actual_behavior
- root_cause (must cite evidence)
- suggested_fix ("Replace `old` with `new`")
- error_type: "test_error" | "logic" | "environment"
2. Collect from environment errors: needs_system_packages, needs_language_packages and needs_additional_commands

CRITICAL RULE:
- If ANY test is classified as "test_error", RETURN ONLY test_error failures. Do NOT include logic diagnoses.
- /var/outputs is a pre-existing directory and exists for the entire run. NEVER delete/recreate it.
Only write files into it. Never make it configurable.
"""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": diagnosis_prompt}],
        "max_tokens": 2000,  # Increased for detailed test-by-test analysis
        "temperature": 0.3,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = ErrorDiagnosis

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        # Parse AND validate inside the try: ErrorDiagnosis(**result_dict) can
        # raise on valid JSON with wrong/missing keys, which previously escaped
        # this handler (pydantic ValidationError is a ValueError subclass).
        try:
            result_dict = json.loads(content)
            return ErrorDiagnosis(**result_dict), input_tokens, output_tokens
        except (json.JSONDecodeError, TypeError, ValueError) as exc:
            logger.warning(f"Failed to parse diagnosis JSON, returning empty diagnosis: {exc}")
            return ErrorDiagnosis(failures=[]), input_tokens, output_tokens
    # Structured-output providers may already return a parsed model instance
    return content, input_tokens, output_tokens
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
@flyte.trace
async def verify_test_fixes_applied(
    model: str,
    diagnosis: ErrorDiagnosis,
    patches: list[TestFunctionPatch],
    litellm_params: Optional[dict] = None,
) -> tuple[FixVerification, int, int]:
    """Verify that the suggested test fixes from diagnosis are present in the patches.

    First checks structurally that every failing function has a matching patch
    (accounting for parametrize suffixes). Then sends diagnosis + patches to LLM
    to verify fix content.

    Args:
        model: LLM model identifier passed to litellm.
        diagnosis: Diagnosis whose failures describe the required test fixes.
        patches: Test-function patches to verify against the diagnosis.
        litellm_params: Extra litellm params; may override defaults except response_format.

    Returns:
        Tuple of (FixVerification, input_tokens, output_tokens)
    """
    # Structural check: every failing test function must have a matching patch
    patch_base_names = {_strip_parametrize_suffix(p.test_name) for p in patches}
    failing_base_names = {_strip_parametrize_suffix(f.test_name) for f in diagnosis.failures}
    missing_functions = failing_base_names - patch_base_names
    if missing_functions:
        # sorted() makes the report deterministic (set iteration order is not)
        missing_sorted = sorted(missing_functions)
        return (
            FixVerification(
                all_fixes_applied=False,
                applied_fixes=sorted(failing_base_names & patch_base_names),
                missing_fixes=missing_sorted,
                explanation=f"No patches returned for functions: {', '.join(missing_sorted)}",
            ),
            0,
            0,
        )

    # Build verification prompt with only diagnosis + patches (no full files)
    fixes_to_check = []
    for i, failure in enumerate(diagnosis.failures, 1):
        fixes_to_check.append(
            f"""Fix {i} (for test: {failure.test_name}):
- Error: {failure.error_message}
- Root cause: {failure.root_cause}
- Required fix: {failure.suggested_fix}
- Verification: The old code/pattern must be REMOVED and the new code/pattern must be PRESENT"""
        )

    fixes_section = "\n\n".join(fixes_to_check)

    patches_section = []
    for patch in patches:
        patches_section.append(f"### {patch.test_name}\n```python\n{patch.fixed_code}\n```")
    patches_text = "\n\n".join(patches_section) if patches_section else "(no patches returned)"

    verify_prompt = f"""You are a CODE DIFF REVIEWER. \
Your job is to verify that specific code changes were actually made.

For each required fix below, check whether the EXACT code change described
in "Required fix" is present in the patched function. Do NOT accept
alternative approaches that "address the same issue" — the specific
code transformation must be visible.

═══════════════════════════════════════
REQUIRED FIXES
═══════════════════════════════════════
{fixes_section}

═══════════════════════════════════════
PATCHED TEST FUNCTIONS
═══════════════════════════════════════
{patches_text}

VERIFICATION RULES:
1. For each required fix, find the patched function for that test.
2. Check if the OLD code/pattern mentioned in the fix is GONE from the patch.
3. Check if the NEW code/pattern mentioned in the fix is PRESENT in the patch.
4. If the fix says "Replace X with Y", then X must NOT appear and Y MUST appear.
5. A fix is NOT applied if the patch uses a different approach to solve the same problem.
6. Set all_fixes_applied to true ONLY if EVERY fix passes checks 2, 3, and 4."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": verify_prompt}],
        "max_tokens": 1000,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Always set response_format last so litellm_params cannot override it
    params["response_format"] = FixVerification

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        # except Exception covers both bad JSON and FixVerification construction
        # errors (the old (json.JSONDecodeError, Exception) tuple was redundant)
        try:
            result_dict = json.loads(content)
            verification = FixVerification(**result_dict)
        except Exception:
            logger.warning("Failed to parse test fix verification, assuming fixes not applied")
            verification = FixVerification(
                all_fixes_applied=False,
                applied_fixes=[],
                missing_fixes=["parse_error"],
                explanation="Failed to parse verification response",
            )
    else:
        # Structured-output providers may already return a parsed model instance
        verification = content

    return verification, input_tokens, output_tokens
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
@flyte.trace
async def verify_logic_fixes_applied(
    model: str,
    diagnosis: ErrorDiagnosis,
    new_solution: CodeSolution,
    litellm_params: Optional[dict] = None,
) -> tuple[FixVerification, int, int]:
    """Verify that the suggested fixes from diagnosis are present in new code.

    Only verifies "logic" failures - environment and test_error are handled differently.

    Args:
        model: LLM model identifier passed to litellm.
        diagnosis: Diagnosis whose logic failures describe the required code fixes.
        new_solution: Regenerated solution whose code should contain the fixes.
        litellm_params: Extra litellm params; may override defaults except response_format.

    Returns:
        Tuple of (FixVerification, input_tokens, output_tokens)
    """
    # Only check logic failures (environment and test_error don't modify solution code)
    logic_failures = [f for f in diagnosis.failures if f.error_type == "logic"]

    if not logic_failures:
        # No logic fixes to verify - trivially satisfied, no LLM call needed
        return (
            FixVerification(
                all_fixes_applied=True,
                applied_fixes=[],
                missing_fixes=[],
                explanation="No logic fixes to verify (only environment/test errors)",
            ),
            0,
            0,
        )

    new_code = new_solution.code

    # Build verification prompt - only for logic failures
    fixes_to_check = []
    for i, failure in enumerate(logic_failures, 1):
        fixes_to_check.append(
            f"""Fix {i} (for test: {failure.test_name}):
- Root cause: {failure.root_cause}
- Required fix: {failure.suggested_fix}"""
        )

    fixes_section = "\n\n".join(fixes_to_check)

    verify_prompt = f"""You must verify that ALL the required fixes below are present in the new code.

Required fixes:
{fixes_section}

New code to verify:
```{new_solution.language}
{new_code}
```

Check each fix carefully:
1. For each fix, determine if the required change is present in the new code
2. If a fix says "Replace X with Y", verify that X is gone and Y is present
3. List which fixes are applied and which are missing
4. Set all_fixes_applied to true ONLY if every single fix is present

Be strict - if even one fix is missing or partially applied, set all_fixes_applied to false."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": verify_prompt}],
        "max_tokens": 1000,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Always set response_format last so litellm_params cannot override it
    params["response_format"] = FixVerification

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        # except Exception covers both bad JSON and FixVerification construction
        # errors (the old (json.JSONDecodeError, Exception) tuple was redundant)
        try:
            result_dict = json.loads(content)
            verification = FixVerification(**result_dict)
        except Exception:
            logger.warning("Failed to parse logic fix verification, assuming fixes not applied")
            verification = FixVerification(
                all_fixes_applied=False,
                applied_fixes=[],
                missing_fixes=["parse_error"],
                explanation="Failed to parse verification response",
            )
    else:
        # Structured-output providers may already return a parsed model instance
        verification = content

    return verification, input_tokens, output_tokens
|
|
1154
|
+
|
|
1155
|
+
|
|
1156
|
+
async def diagnose_and_plan_environment_fix(
    model: str,
    solution: CodeSolution,
    code_output: str,
    prompt: str,
    plan: CodePlan,
    detected_packages: list[str],
    detected_system_packages: list[str],
    additional_commands: list[str],
    litellm_params: Optional[dict] = None,
    test_code: Optional[str] = None,
    data_samples: Optional[str] = None,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
) -> tuple[bool | str, list[str], list[str], list[str], int, int, ErrorDiagnosis]:
    """Diagnose error and plan environment fix (don't execute yet).

    Runs diagnose_error, then filters and classifies the failures to decide
    whether tests must be regenerated first or environment/logic fixes can
    proceed together.

    Returns:
        Tuple of (primary_error_type, updated_detected_packages, updated_detected_system_packages,
                  updated_additional_commands, input_tokens, output_tokens, diagnosis)

        where primary_error_type is either:
        - "test_error" (str): Test code has bugs, must fix tests first
        - False (bool): Environment and/or logic errors, handle together
    """
    diagnosis, in_tok, out_tok = await diagnose_error(
        model,
        solution,
        code_output,
        prompt,
        plan,
        litellm_params,
        test_code,
        data_samples,
        constraints,
        schema,
    )

    # Important: If any test errors exist, only keep test_error failures
    # Discard logic/environment failures - they're unreliable when tests are broken
    test_error_failures = [f for f in diagnosis.failures if f.error_type == "test_error"]

    if test_error_failures:
        logger.warning(
            f"Found {len(test_error_failures)} test_error failure(s). "
            f"Discarding {len(diagnosis.failures) - len(test_error_failures)} logic/environment failures "
            f"(unreliable when tests are broken)"
        )
        # Replace failures with only test_error failures (mutates diagnosis in place)
        diagnosis.failures = test_error_failures

    # Count error types across all failures (after filtering).
    # .get() tolerates an unexpected error_type from the LLM instead of KeyError.
    error_type_counts = {"environment": 0, "test_error": 0, "logic": 0}
    for failure in diagnosis.failures:
        error_type_counts[failure.error_type] = error_type_counts.get(failure.error_type, 0) + 1

    logger.info(f"Number of failures: {len(diagnosis.failures)}")
    logger.info(
        f"Breakdown: {error_type_counts['environment']} environment, "
        f"{error_type_counts['test_error']} test_error, "
        f"{error_type_counts['logic']} logic"
    )

    # markup=False keeps rich-style log handlers from interpreting brackets
    for i, failure in enumerate(diagnosis.failures, 1):
        logger.info(
            f"Test {i} [{failure.error_type}] - {failure.test_name}",
            extra={"markup": False},
        )
        logger.info(f"Root cause: {failure.root_cause}", extra={"markup": False})
        logger.info(f"Fix: {failure.suggested_fix}", extra={"markup": False})

    # Determine actions based on all groups
    has_environment_errors = error_type_counts["environment"] > 0
    has_test_errors = error_type_counts["test_error"] > 0
    has_logic_errors = error_type_counts["logic"] > 0

    if has_test_errors:
        # Test code has bugs - must regenerate tests first
        logger.warning(
            f"{error_type_counts['test_error']} test(s) have test bugs. "
            f"Will regenerate tests first, then handle other errors."
        )
        return (
            "test_error",
            detected_packages,
            detected_system_packages,
            additional_commands,
            in_tok,
            out_tok,
            diagnosis,
        )

    # No test errors - we can fix environment + logic together in one go
    if has_environment_errors and has_logic_errors:
        logger.info(
            f"Will fix {error_type_counts['environment']} environment error(s) "
            f"and patch code for {error_type_counts['logic']} logic error(s) in one iteration"
        )
    elif has_environment_errors:
        logger.info(f"Will fix {error_type_counts['environment']} environment error(s)")
    elif has_logic_errors:
        logger.info(f"Will patch code for {error_type_counts['logic']} logic error(s)")

    # Return flag indicating we should handle both environment and logic
    # False means "not test_error", so we'll proceed to environment + code fix
    return (
        False,
        detected_packages,
        detected_system_packages,
        additional_commands,
        in_tok,
        out_tok,
        diagnosis,
    )