flyteplugins-codegen 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1269 @@
1
+ import ast
2
+ import json
3
+ import logging
4
+ import re
5
+ import sys
6
+ from typing import Optional
7
+
8
+ import flyte
9
+ import flyte.errors
10
+ import litellm
11
+
12
+ from flyteplugins.codegen.core.types import (
13
+ CodePlan,
14
+ CodeSolution,
15
+ ErrorDiagnosis,
16
+ FixVerification,
17
+ TestFixResponse,
18
+ TestFunctionPatch,
19
+ _PackageDetectionResponse,
20
+ _PackageReplacementResponse,
21
+ _TestCodeResponse,
22
+ )
23
+ from flyteplugins.codegen.data.schema import extract_token_usage
24
+ from flyteplugins.codegen.generation.prompts import (
25
+ PACKAGE_MANAGER_MAP,
26
+ TEST_FRAMEWORKS,
27
+ build_enhanced_prompt,
28
+ )
29
+
30
# Module-level logger, named after this module per stdlib logging convention.
logger = logging.getLogger(__name__)

# Frozenset of top-level standard-library module names (available since
# Python 3.10); used by filter_stdlib() to drop stdlib modules from
# LLM-detected dependency lists.
_PYTHON_STDLIB = sys.stdlib_module_names
33
+
34
+
35
def strip_code_fences(code: str) -> str:
    """Strip markdown code fences from LLM output."""
    text = code.strip()
    # A fenced block looks like ```lang\n...\n``` — the language tag is optional,
    # and DOTALL lets the body span multiple lines.
    fenced = re.match(r"^```(?:\w+)?\s*\n(.*?)```\s*$", text, re.DOTALL)
    if fenced is None:
        return text
    return fenced.group(1).strip()
43
+
44
+
45
def filter_stdlib(packages: list[str]) -> list[str]:
    """Remove Python standard library modules from a package list."""
    third_party = []
    for package in packages:
        # Compare only the top-level module name (e.g. "os.path" -> "os"),
        # case-insensitively, against the stdlib module set.
        top_level = package.split(".")[0].lower()
        if top_level not in _PYTHON_STDLIB:
            third_party.append(package)
    return third_party
48
+
49
+
50
@flyte.trace
async def generate_plan(
    model: str,
    prompt: str,
    language: str,
    schema: Optional[str],
    constraints: Optional[list[str]],
    data_samples: Optional[str],
    inputs: Optional[dict[str, type]],
    outputs: Optional[dict[str, type]],
    litellm_params: Optional[dict],
) -> tuple[CodePlan, int, int]:
    """Generate a structured plan for the code solution before writing code.

    Args:
        model: LLM model identifier passed to litellm.
        prompt: The user's task description.
        language: Target implementation language (interpolated into the prompt).
        schema: Optional data schema to include in the prompt.
        constraints: Optional list of constraints to include in the prompt.
        data_samples: Optional sample data to include in the prompt.
        inputs: Optional mapping of input names to types.
        outputs: Optional mapping of output names to types.
        litellm_params: Extra kwargs merged into the litellm call; may override
            defaults but never ``response_format``.

    Returns:
        Tuple of (plan, input_tokens, output_tokens)

    Raises:
        flyte.errors.RuntimeUserError: if the model rejects structured outputs.
    """
    base_prompt = build_enhanced_prompt(prompt, language, schema, constraints, data_samples, inputs, outputs)

    # Add planning-specific instructions
    planning_prompt = f"""You are planning a {language} solution for the following task:

{base_prompt}

Create a detailed plan including:
1. Overall description of the solution
2. High-level approach and algorithm

The solution will be implemented as a single {language} file.
Focus on clarity and completeness. The plan will guide code generation."""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": planning_prompt}],
        "max_tokens": 1000,
        "temperature": 0.3,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = CodePlan

    try:
        response = await litellm.acompletion(**params)
    except Exception as e:
        # Check if it's an unsupported params error
        # (matched by exception class name or message text, since litellm's
        # exception types vary by provider).
        if "UnsupportedParamsError" in type(e).__name__ or "does not support parameters" in str(e):
            raise flyte.errors.RuntimeUserError(
                f"Model '{model}' does not support structured outputs (response_format parameter). "
                f"Please use a model that supports structured outputs like: gpt-4.1, "
                f"claude-3-5-sonnet, or similar models."
            ) from e
        raise

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            plan_dict = json.loads(content)
        except json.JSONDecodeError:
            logger.warning("Failed to parse plan JSON, using fallback")
            # Fallback: stuff (a truncated slice of) the raw text into the
            # description so the caller still receives a usable CodePlan.
            return (
                CodePlan(description=content[:500], approach=""),
                input_tokens,
                output_tokens,
            )
        return CodePlan(**plan_dict), input_tokens, output_tokens
    # Non-string content is returned as-is — presumably the provider already
    # parsed it into a CodePlan instance; TODO confirm against litellm docs.
    return content, input_tokens, output_tokens
123
+
124
+
125
@flyte.trace
async def generate_code(
    model: str,
    conversation: list[dict],
    plan: Optional[CodePlan],
    litellm_params: Optional[dict],
    is_retry: bool = False,
) -> tuple[CodeSolution, int, int]:
    """Generate code with structured output using Pydantic model.

    Args:
        model: LLM model to use
        conversation: Message history
        plan: Optional plan to guide code generation (only used for initial generation)
        litellm_params: LiteLLM parameters
        is_retry: If True, skip plan context (for debugging/fixing code)

    Returns:
        Tuple of (CodeSolution, input_tokens, output_tokens)
    """
    # Add plan context only for initial generation (not retries).
    # Copy so the caller's conversation list is not mutated.
    messages = conversation.copy()
    if plan and not is_retry:
        plan_context = f"""
Plan for implementation:
- Description: {plan.description}
- Approach: {plan.approach}

Follow this plan when generating the code."""

        # Insert plan context before the last user message
        if messages and messages[-1]["role"] == "user":
            messages.insert(-1, {"role": "system", "content": plan_context})
        else:
            messages.append({"role": "system", "content": plan_context})

    # Build params with defaults
    params = {
        "model": model,
        "messages": messages,
        "max_tokens": 2000,
        "temperature": 0.7,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = CodeSolution

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            solution_dict = json.loads(content)
        except json.JSONDecodeError:
            # Fallback: the model ignored structured output; salvage the raw
            # text as the code body after stripping any markdown fences.
            logger.warning("Failed to parse code solution JSON, extracting code from raw response")
            return (
                CodeSolution(code=strip_code_fences(content)),
                input_tokens,
                output_tokens,
            )
        solution = CodeSolution(**solution_dict)
    else:
        # Presumably already a parsed CodeSolution from the provider — the
        # .code attribute access below relies on that; TODO confirm.
        solution = content

    # Strip code fences the model may have included
    solution.code = strip_code_fences(solution.code)
    return solution, input_tokens, output_tokens
198
+
199
+
200
@flyte.trace
async def detect_required_packages(
    model: str,
    code: str,
    language: str,
    litellm_params: Optional[dict],
) -> tuple[list[str], int, int]:
    """Use LLM to detect required packages from code.

    Args:
        model: LLM model identifier.
        code: Source code to scan for third-party dependencies.
        language: Language of ``code``; selects wording via PACKAGE_MANAGER_MAP.
        litellm_params: Extra kwargs merged into the litellm call.

    Returns:
        Tuple of (packages, input_tokens, output_tokens)
    """
    # Nothing to analyze — skip the LLM call entirely.
    if not code.strip():
        return [], 0, 0

    package_type = PACKAGE_MANAGER_MAP.get(language.lower(), "package names for the package manager")

    detection_prompt = f"""Given this {language} code:

```{language}
{code}
```

List the {package_type} needed to install the dependencies used in this code.
For standard library / built-in modules, don't include them.
Only include third-party packages that need to be installed."""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": detection_prompt}],
        "max_tokens": 200,
        "temperature": 0.1,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = _PackageDetectionResponse

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
        except json.JSONDecodeError:
            logger.warning("Failed to parse package detection JSON, returning empty list")
            return [], input_tokens, output_tokens
        packages = result_dict.get("packages", [])
    else:
        # Presumably an already-parsed _PackageDetectionResponse; TODO confirm.
        packages = content.packages

    # Defensively drop stdlib modules the model may still have listed.
    return filter_stdlib(packages), input_tokens, output_tokens
258
+
259
+
260
async def suggest_replacement_package(
    model: str,
    bad_package: str,
    original_error: str,
    solution_code: str,
    litellm_params: Optional[dict] = None,
) -> tuple[Optional[str], int, int]:
    """Ask the LLM for the correct Debian package name to replace a bad one.

    Best-effort: any failure in the LLM call is logged and swallowed, and the
    caller receives ``(None, 0, 0)``.

    Args:
        model: LLM model identifier.
        bad_package: The apt package name that failed to install.
        original_error: The installer error text, included in the prompt.
        solution_code: Code that motivated the package (truncated to 1000 chars).
        litellm_params: Extra kwargs merged into the litellm call.

    Returns:
        Tuple of (replacement_package_name or None, input_tokens, output_tokens)
    """
    prompt = f"""The Debian/Ubuntu apt package "{bad_package}" does not exist.

Error: {original_error}

The code that needs this package:
```python
{solution_code[:1000]}
```

What is the correct Debian/Ubuntu apt package name that should be used instead?
Set replacement to the correct package name or null if no system package is needed."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 50,
        "temperature": 0.0,
    }
    params.update(litellm_params or {})
    # response_format is set last so litellm_params cannot override it.
    params["response_format"] = _PackageReplacementResponse

    try:
        response = await litellm.acompletion(**params)
        input_tokens, output_tokens = extract_token_usage(response)

        content = response.choices[0].message.content
        if isinstance(content, str):
            try:
                result = json.loads(content)
                replacement = result.get("replacement")
            except json.JSONDecodeError:
                return None, input_tokens, output_tokens
        else:
            replacement = content.replacement

        if replacement:
            logger.info(f"LLM suggested replacing '{bad_package}' with '{replacement}'")
            return replacement, input_tokens, output_tokens
        # BUG FIX: previously a falsy replacement fell through to the shared
        # `return None, 0, 0` below, discarding the real token counts even
        # though the LLM call succeeded and consumed tokens.
        return None, input_tokens, output_tokens
    except Exception as e:
        logger.warning(f"Failed to get package replacement suggestion: {e}")
        # Call failed before usage could be extracted; report zero tokens.
        return None, 0, 0
313
+
314
+
315
async def detect_and_track_packages(
    model: str,
    solution: CodeSolution,
    base_pkgs: list[str],
    detected_packages: list[str],
    detected_system_packages: list[str],
    litellm_params: Optional[dict] = None,
) -> tuple[bool, list[str], list[str], int, int]:
    """Detect packages from solution and track them.

    Note: ``detected_packages`` is extended in place; a rebuild is requested
    whenever either package list gains new entries.

    Returns:
        (needs_rebuild, updated_detected_packages, updated_detected_system_packages, input_tokens, output_tokens)
    """
    rebuild_required = False
    prompt_tokens_used = 0
    completion_tokens_used = 0

    # System packages come directly from the LLM's structured solution output.
    llm_system_packages = solution.system_packages
    if llm_system_packages and llm_system_packages != detected_system_packages:
        detected_system_packages = llm_system_packages.copy()
        logger.info(f"Detected system packages: {detected_system_packages}")
        rebuild_required = True

    # Language-level packages are inferred from the code itself (skip empty code).
    if solution.code.strip():
        new_packages, prompt_tok, completion_tok = await detect_required_packages(
            model, solution.code, solution.language, litellm_params
        )
        prompt_tokens_used += prompt_tok
        completion_tokens_used += completion_tok

        newly_added = []
        for candidate in new_packages:
            # Checking detected_packages per iteration also de-duplicates
            # repeated entries within new_packages itself.
            if candidate not in detected_packages and candidate not in base_pkgs:
                detected_packages.append(candidate)
                newly_added.append(candidate)

        if newly_added:
            logger.info(f"Detected new {solution.language} packages: {newly_added}")
            rebuild_required = True

    return (
        rebuild_required,
        detected_packages,
        detected_system_packages,
        prompt_tokens_used,
        completion_tokens_used,
    )
364
+
365
+
366
@flyte.trace
async def generate_tests(
    model: str,
    prompt: str,
    plan: CodePlan,
    solution: CodeSolution,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
    data_samples: Optional[str] = None,
    inputs: Optional[dict[str, type]] = None,
    outputs: Optional[dict[str, type]] = None,
    litellm_params: Optional[dict] = None,
) -> tuple[str, int, int]:
    """
    Generate test code to validate the solution.

    First asks for plain-text tests; if the result is empty or fails to
    compile, retries once with structured output (_TestCodeResponse).

    Returns:
        (test_code, input_tokens, output_tokens)

    Raises:
        RuntimeError: if even the structured fallback yields invalid test code.
    """

    def _validate_python(code: str) -> None:
        """Raise if code is invalid or likely truncated."""
        if not code.strip():
            raise RuntimeError("Generated empty test code")

        try:
            # compile() catches syntax errors and truncation without executing.
            compile(code, "test_code", "exec")
        except SyntaxError as e:
            raise RuntimeError(f"Generated test code is invalid or truncated: {e}") from e

    full_code = solution.code
    language = solution.language.lower()
    # Fall back to the Python test framework entry for unknown languages.
    test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
    test_framework = test_framework_info["name"]

    import_instruction = (
        "Import functions/classes from solution module (e.g., "
        "'from solution import function_name'). "
        "The code file is named solution.py and is located at "
        "/var/inputs/solution.py."
    )

    # Assemble the prompt incrementally: task/plan first, then optional
    # schema/constraints/data sections, then solution code and instructions.
    test_prompt = f"""
You are generating TESTS.

Task:
{prompt}

Plan:
- Description: {plan.description}
- Approach: {plan.approach}
"""

    if schema:
        test_prompt += f"\nSchema:\n{schema}\n"

    if constraints:
        test_prompt += "\nConstraints:\n"
        for i, c in enumerate(constraints, 1):
            test_prompt += f"{i}. {c}\n"

    if data_samples:
        test_prompt += f"""
Data:
Use EXACTLY this format if referenced.
Do NOT invent additional samples.

{data_samples}
"""

    if inputs:
        test_prompt += f"\nCLI Arguments: {list(inputs.keys())}"
    if outputs:
        test_prompt += f"\nExpected outputs: {list(outputs.keys())}"

    test_prompt += f"""
Solution code:
```{solution.language}
{full_code}
```"""

    test_prompt += f"""
Test generation instructions:

Generate comprehensive but concise tests that maximize coverage through
multiple complementary cases using {test_framework}, not a single large test.

Requirements:
1. Use {test_framework} framework
2. Test the FULL execution path end-to-end (not just isolated functions)
3. Use provided data AS-IS
4. {import_instruction}
5. Follow {test_framework} best practices
6. /var/outputs is a pre-existing directory — NEVER delete or recreate it.
   The solution writes output files there (one file per output),
   and tests should READ from /var/outputs to verify correctness

IMPORTANT: Do NOT wrap output in code fences. Return ONLY valid Python test code."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": test_prompt}],
        "max_tokens": 5000,
        "temperature": 0.4,
    }
    params.update(litellm_params or {})

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content

    if not isinstance(content, str):
        content = str(content)

    test_code = strip_code_fences(content)

    try:
        _validate_python(test_code)
        return test_code, input_tokens, output_tokens
    except RuntimeError:
        # Fall through to the structured retry below.
        logger.warning("Plain text test generation invalid, retrying with structured output")

    # Structured JSON fallback (optional)
    params_structured = dict(params)
    params_structured["response_format"] = _TestCodeResponse

    response = await litellm.acompletion(**params_structured)
    input_tokens2, output_tokens2 = extract_token_usage(response)

    content = response.choices[0].message.content

    if isinstance(content, str):
        try:
            data = json.loads(content)
            test_code = data["test_code"]
        except Exception as e:
            raise RuntimeError("Structured output invalid; cannot recover test code") from e
    else:
        # Presumably an already-parsed _TestCodeResponse; TODO confirm.
        test_code = content.test_code

    test_code = strip_code_fences(test_code)
    _validate_python(test_code)

    # Report combined usage across both attempts.
    return (
        test_code,
        input_tokens + input_tokens2,
        output_tokens + output_tokens2,
    )
515
+
516
+
517
+ def _strip_parametrize_suffix(name: str) -> str:
518
+ """Strip pytest parametrize suffix like [0-True-10-3] from a test name."""
519
+ return re.sub(r"\[.*?\]$", "", name)
520
+
521
+
522
def apply_test_patches(original_code: str, patches: list[TestFunctionPatch]) -> str:
    """Apply test function patches to the original test file using AST-based replacement.

    Parses the original file to find function boundaries by name, then replaces
    the source lines for each patched function. Preserves imports, fixtures,
    and all non-patched code exactly. Only top-level functions are considered;
    methods inside classes are not matched.

    Args:
        original_code: The full original test file source.
        patches: List of patches, each containing a function name and its fixed source.

    Returns:
        The patched test file as a string.

    Raises:
        SyntaxError: if ``original_code`` itself does not parse.
    """
    tree = ast.parse(original_code)
    lines = original_code.splitlines(keepends=True)

    # Ensure last line has a newline for consistent joining
    if lines and not lines[-1].endswith("\n"):
        lines[-1] += "\n"

    # Build map of function_name -> (start_line_0indexed, end_line_0indexed_exclusive)
    # We need to handle decorated functions: start from the first decorator line
    func_map: dict[str, tuple[int, int]] = {}
    all_top_level = list(ast.iter_child_nodes(tree))

    for idx, node in enumerate(all_top_level):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            # Start line: use decorator_list if present, otherwise the def line
            # (a decorated function's node.lineno points at the `def`, not the
            # decorator, so the decorator line must be taken explicitly).
            if node.decorator_list:
                start = node.decorator_list[0].lineno - 1  # 0-indexed
            else:
                start = node.lineno - 1

            # End line: next top-level node's start, or end of file
            if idx + 1 < len(all_top_level):
                next_node = all_top_level[idx + 1]
                if hasattr(next_node, "decorator_list") and next_node.decorator_list:
                    end = next_node.decorator_list[0].lineno - 1
                else:
                    end = next_node.lineno - 1
            else:
                end = len(lines)

            # Trim trailing blank lines from this function's range
            while end > start and lines[end - 1].strip() == "":
                end -= 1

            func_map[node.name] = (start, end)

    # Apply patches in reverse order (so line numbers stay valid)
    # Normalize patch names by stripping parametrize suffixes
    # If LLM returns multiple patches for the same base function, keep the last one
    patches_by_name: dict[str, TestFunctionPatch] = {}
    for p in patches:
        base_name = _strip_parametrize_suffix(p.test_name)
        if base_name in patches_by_name:
            logger.warning(
                f"Multiple patches for '{base_name}' (from '{p.test_name}'). "
                f"Keeping last patch — ensure it includes ALL parametrize fixes."
            )
        patches_by_name[base_name] = p

    # Collect (start, end, replacement) triples for functions that have a patch;
    # patches whose names match no top-level function are silently ignored.
    replacements = []
    for name, (start, end) in func_map.items():
        if name in patches_by_name:
            replacements.append((start, end, patches_by_name[name].fixed_code))

    # Sort by start line descending so we can splice without shifting
    replacements.sort(key=lambda r: r[0], reverse=True)

    for start, end, fixed_code in replacements:
        # Ensure fixed code ends with newline
        fixed = fixed_code.rstrip("\n") + "\n"
        # Replace the lines
        lines[start:end] = [fixed]

    result = "".join(lines)

    # Validate the patched code compiles; on failure we still return the
    # result and only warn, so downstream execution surfaces the real error.
    try:
        compile(result, "patched_test", "exec")
    except SyntaxError as e:
        logger.warning(f"Patched test code has syntax error: {e}. Returning as-is.")

    return result
607
+
608
+
609
@flyte.trace
async def fix_failing_tests(
    model: str,
    test_code: str,
    diagnosis: ErrorDiagnosis,
    solution: CodeSolution,
    litellm_params: Optional[dict] = None,
) -> tuple[str, list[TestFunctionPatch], int, int]:
    """Fix only the failing tests by returning patches for individual functions.

    Instead of asking the LLM to reproduce the entire test file, asks for only
    the fixed test functions and splices them into the original file.

    Args:
        model: LLM model identifier.
        test_code: Complete test file
        diagnosis: Structured diagnosis of test failures
        solution: The solution being tested
        litellm_params: Extra kwargs merged into the litellm call.

    Returns:
        Tuple of (fixed_test_code, patches, input_tokens, output_tokens)
    """
    full_code = solution.code

    # Build diagnosis information from individual failures
    diagnosis_info = []
    for i, failure in enumerate(diagnosis.failures, 1):
        diagnosis_info.append(
            f"""
Test {i}: {failure.test_name}
- Expected: {failure.expected_behavior}
- Actual: {failure.actual_behavior}
- Root cause: {failure.root_cause}
- Suggested fix: {failure.suggested_fix}"""
        )

    diagnosis_section = "\n".join(diagnosis_info)

    # Extract only the failing test functions to include in the prompt
    # Strip parametrize suffixes so we match the actual function names in the AST
    failing_names = {_strip_parametrize_suffix(f.test_name) for f in diagnosis.failures}
    try:
        tree = ast.parse(test_code)
        test_lines = test_code.splitlines(keepends=True)
        all_nodes = list(ast.iter_child_nodes(tree))
        failing_snippets = []
        for idx, node in enumerate(all_nodes):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.name in failing_names:
                    # Same boundary logic as apply_test_patches: include the
                    # decorator line, end at the next top-level node, and trim
                    # trailing blank lines.
                    if node.decorator_list:
                        start = node.decorator_list[0].lineno - 1
                    else:
                        start = node.lineno - 1
                    if idx + 1 < len(all_nodes):
                        next_node = all_nodes[idx + 1]
                        if hasattr(next_node, "decorator_list") and next_node.decorator_list:
                            end = next_node.decorator_list[0].lineno - 1
                        else:
                            end = next_node.lineno - 1
                    else:
                        end = len(test_lines)
                    while end > start and test_lines[end - 1].strip() == "":
                        end -= 1
                    failing_snippets.append("".join(test_lines[start:end]))
        failing_code_section = "\n\n".join(failing_snippets)
    except SyntaxError:
        # If we can't parse, fall back to sending the whole test file
        failing_code_section = test_code

    fix_prompt = f"""You are performing MINIMAL PATCH REPAIR on pytest test functions.

YOUR PRIMARY TASK: Apply the suggested fix for each failing test EXACTLY as described below.

═══════════════════════════════════════
DIAGNOSED FAILURES AND REQUIRED FIXES
═══════════════════════════════════════
{diagnosis_section}

For each failure above, the "Suggested fix" tells you EXACTLY what code change to make.
You MUST apply that specific change — do not interpret, simplify, or find an alternative approach.

═══════════════════════════════════════
FAILING TEST FUNCTIONS (ORIGINAL CODE)
═══════════════════════════════════════
```{solution.language}
{failing_code_section}
```

SOLUTION CODE (for reference only — do NOT modify):
The solution is saved as /var/inputs/solution.py. Tests import from it via `from solution import ...`.
```{solution.language}
{full_code}
```

RULES:
1. For each suggested fix, locate the exact code it refers to and apply the substitution.
2. Preserve ALL existing assertions, decorators, and test structure.
3. Add any necessary imports at the top of the function body if needed.
4. DO NOT rewrite, simplify, or restructure the test.
5. DO NOT change anything that the diagnosis does not mention.

OUTPUT FORMAT:
Return ONLY patches for functions that require modification:
- test_name: BASE function name (no [param] suffix)
- fixed_code: COMPLETE function including decorators"""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": fix_prompt}],
        "max_tokens": 2000,
        "temperature": 0.3,
    }
    params.update(litellm_params or {})
    params["response_format"] = TestFixResponse

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
            patches = [TestFunctionPatch(**p) for p in result_dict["patches"]]
        except (json.JSONDecodeError, KeyError) as exc:
            logger.warning(f"Failed to parse test fix patches JSON: {exc}")
            # Fallback: treat as a single patch with the raw content
            # (applied to every diagnosed failure).
            patches = []
            raw = strip_code_fences(content)
            for failure in diagnosis.failures:
                patches.append(TestFunctionPatch(test_name=failure.test_name, fixed_code=raw))
    else:
        # Presumably an already-parsed TestFixResponse; TODO confirm.
        patches = content.patches

    # Strip code fences from each patch's fixed_code
    for patch in patches:
        patch.fixed_code = strip_code_fences(patch.fixed_code)

    # Apply patches to original test code
    fixed_test_code = apply_test_patches(test_code, patches)

    return fixed_test_code, patches, input_tokens, output_tokens
749
+
750
+
751
+ def extract_error_messages_from_pytest(output: str) -> dict[str, str]:
752
+ """Extract the final error message for each failed test from pytest output.
753
+
754
+ Args:
755
+ output: Pytest output string
756
+
757
+ Returns:
758
+ Dict mapping test name to error message (e.g., "RecursionError: maximum recursion depth exceeded")
759
+ """
760
+ error_messages = {}
761
+ current_test = None
762
+
763
+ # Parse pytest output line by line
764
+ lines = output.split("\n")
765
+
766
+ for line in lines:
767
+ # Match test failure header: _____ test_name _____
768
+ test_header_match = re.match(r"^_{5,}\s+(.+?)\s+_{5,}$", line)
769
+ if test_header_match:
770
+ current_test = test_header_match.group(1).strip()
771
+ # Remove parametrize suffix like [530.00-33-3]
772
+ current_test = re.sub(r"\[.*?\]$", "", current_test)
773
+ continue
774
+
775
+ # Match error line: starts with "E " followed by exception
776
+ if current_test and line.startswith("E "):
777
+ error_line = line[4:].strip() # Remove "E " prefix
778
+ # Only capture if it looks like an exception (contains "Error" or "Exception")
779
+ if "Error" in error_line or "Exception" in error_line or "Failed" in error_line:
780
+ # Extract just the exception type and message, not the full line
781
+ error_messages[current_test] = error_line
782
+
783
+ return error_messages
784
+
785
+
786
@flyte.trace
async def diagnose_error(
    model: str,
    solution: CodeSolution,
    output: str,
    prompt: str,
    plan: CodePlan,
    litellm_params: Optional[dict],
    test_code: Optional[str] = None,
    data_samples: Optional[str] = None,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
) -> tuple[ErrorDiagnosis, int, int]:
    """Performs structured analysis to determine if the error is due to:
    - Environment issues (missing packages, dependencies)
    - Code logic issues (bugs, incorrect algorithms)
    - Test code issues (wrong expected values, incorrect test logic)

    Args:
        model: LLM model identifier.
        solution: The solution whose test run failed.
        output: Raw pytest output from the failed run.
        prompt: Original task description.
        plan: Plan the solution was generated from.
        litellm_params: Extra kwargs merged into the litellm call.
        test_code: The test file contents, included verbatim in the prompt.
        data_samples: Optional sample data section.
        constraints: Optional constraints section.
        schema: Optional schema section.

    Returns:
        Tuple of (ErrorDiagnosis, input_tokens, output_tokens)
    """
    # Extract error messages from pytest output for accurate tracking
    error_messages = extract_error_messages_from_pytest(output)

    full_code = solution.code

    # Build error messages section for prompt (omitted when nothing was parsed)
    error_messages_section = ""
    if error_messages:
        error_messages_section = (
            "\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
            "EXTRACTED ERROR MESSAGES\n"
            "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
        )
        for test_name, error_msg in error_messages.items():
            error_messages_section += f"\n{test_name}: {error_msg}"

    diagnosis_prompt = f"""
You are a STRICT FAILURE ANALYSIS SYSTEM.

Your goal is to determine whether failures are caused by:
A) Solution code bugs
B) Incorrect tests (tests generated by LLM)
C) Environment/setup issues

- Tests are NOT authoritative.
- Many tests contain hallucinated expected values.
- Only blame solution code if you can PROVE a violation of the specification.

TASK: {prompt}

PLAN: {plan.description} | {plan.approach}

SOLUTION CODE (saved as /var/inputs/solution.py, tests import via `from solution import ...`):
```{solution.language}
{full_code}
```"""

    if data_samples:
        diagnosis_prompt += f"""

DATA SAMPLES:
```
{data_samples}
```"""

    if constraints:
        diagnosis_prompt += "\n\nCONSTRAINTS:"
        for i, constraint in enumerate(constraints, 1):
            diagnosis_prompt += f"\n{i}. {constraint}"

    if schema:
        diagnosis_prompt += f"""

SCHEMA:
```
{schema}
```"""

    diagnosis_prompt += f"""

TEST CODE:
{test_code}

TEST OUTPUT:
{output}
{error_messages_section}

MANDATORY DECISION PROCEDURE:

For EACH failing test, follow these steps IN ORDER:

STEP 0 — Execution evidence
Determine whether the solution produced observable effects.
If none are observed, consider invocation or execution issues.

STEP 1 — Environment failure
If the error indicates missing packages, import errors, file system issues,
or system configuration problems -> classify as "environment".

STEP 2 — Test contradicts spec/data/schema
Classify as "test_error" if ANY of the following are true:

- Expected value is not derivable from inputs, spec, or schema
- Test invents constants not present in problem description
- Test assumes behavior not specified
- Test contradicts constraints or data samples
- Test fails before solution logic executes
- Assertion checks wrong field/format/order
- Multiple valid outputs exist but test expects one specific value

STEP 3 — Uncertain root cause
If you CANNOT prove the solution is wrong -> classify as "test_error".

STEP 4 — Proven logic error
Classify as "logic" if you can demonstrate:

- Exact specification requirement violated
- Specific incorrect algorithm or implementation
- Output contradicts constraints despite valid test

Logic errors REQUIRE proof from the specification. Absence of proof -> NOT a logic error.

OUTPUT FORMAT: Return JSON with:

1. failures: list of objects:
- test_name
- error_message (copy exactly)
- expected_behavior
- actual_behavior
- root_cause (must cite evidence)
- suggested_fix ("Replace `old` with `new`")
- error_type: "test_error" | "logic" | "environment"
2. Collect from environment errors: needs_system_packages, needs_language_packages and needs_additional_commands

CRITICAL RULE:
- If ANY test is classified as "test_error", RETURN ONLY test_error failures. Do NOT include logic diagnoses.
- /var/outputs is a pre-existing directory and exists for the entire run. NEVER delete/recreate it.
  Only write files into it. Never make it configurable.
"""

    # Build params with defaults
    params = {
        "model": model,
        "messages": [{"role": "user", "content": diagnosis_prompt}],
        "max_tokens": 2000,  # Increased for detailed test-by-test analysis
        "temperature": 0.3,
    }

    # Merge litellm_params (can override anything except response_format)
    params.update(litellm_params or {})

    # Always set response_format last
    params["response_format"] = ErrorDiagnosis

    response = await litellm.acompletion(**params)

    # Extract token usage
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
        except json.JSONDecodeError:
            logger.warning("Failed to parse diagnosis JSON, returning empty diagnosis")
            return ErrorDiagnosis(failures=[]), input_tokens, output_tokens
        return ErrorDiagnosis(**result_dict), input_tokens, output_tokens
    # Presumably an already-parsed ErrorDiagnosis from the provider; TODO confirm.
    return content, input_tokens, output_tokens
955
+
956
+
957
@flyte.trace
async def verify_test_fixes_applied(
    model: str,
    diagnosis: ErrorDiagnosis,
    patches: list[TestFunctionPatch],
    litellm_params: Optional[dict] = None,
) -> tuple[FixVerification, int, int]:
    """Verify that the suggested test fixes from diagnosis are present in the patches.

    First checks structurally that every failing function has a matching patch
    (accounting for parametrize suffixes). Then sends diagnosis + patches to the
    LLM to verify fix content.

    Args:
        model: LiteLLM model identifier used for the verification call.
        diagnosis: Diagnosis whose failures carry the required test fixes.
        patches: Patched test functions produced by the fix step.
        litellm_params: Optional extra params merged into the LiteLLM call.
            They may override anything except ``response_format``, which is
            always set last.

    Returns:
        Tuple of (FixVerification, input_tokens, output_tokens).
    """
    # Structural check: every failing test function must have a matching patch.
    # Parametrized test ids (e.g. test_foo[case1]) are reduced to their base
    # function name before comparison.
    patch_base_names = {_strip_parametrize_suffix(p.test_name) for p in patches}
    failing_base_names = {_strip_parametrize_suffix(f.test_name) for f in diagnosis.failures}
    missing_functions = failing_base_names - patch_base_names
    if missing_functions:
        # Short-circuit without spending tokens: a missing patch can never pass.
        return (
            FixVerification(
                all_fixes_applied=False,
                applied_fixes=[n for n in failing_base_names if n in patch_base_names],
                missing_fixes=list(missing_functions),
                explanation=f"No patches returned for functions: {', '.join(missing_functions)}",
            ),
            0,
            0,
        )

    # Build verification prompt with only diagnosis + patches (no full files).
    fixes_to_check = []
    for i, failure in enumerate(diagnosis.failures, 1):
        fixes_to_check.append(
            f"""Fix {i} (for test: {failure.test_name}):
- Error: {failure.error_message}
- Root cause: {failure.root_cause}
- Required fix: {failure.suggested_fix}
- Verification: The old code/pattern must be REMOVED and the new code/pattern must be PRESENT"""
        )

    fixes_section = "\n\n".join(fixes_to_check)

    patches_section = []
    for patch in patches:
        patches_section.append(f"### {patch.test_name}\n```python\n{patch.fixed_code}\n```")
    patches_text = "\n\n".join(patches_section) if patches_section else "(no patches returned)"

    verify_prompt = f"""You are a CODE DIFF REVIEWER. \
Your job is to verify that specific code changes were actually made.

For each required fix below, check whether the EXACT code change described
in "Required fix" is present in the patched function. Do NOT accept
alternative approaches that "address the same issue" — the specific
code transformation must be visible.

═══════════════════════════════════════
REQUIRED FIXES
═══════════════════════════════════════
{fixes_section}

═══════════════════════════════════════
PATCHED TEST FUNCTIONS
═══════════════════════════════════════
{patches_text}

VERIFICATION RULES:
1. For each required fix, find the patched function for that test.
2. Check if the OLD code/pattern mentioned in the fix is GONE from the patch.
3. Check if the NEW code/pattern mentioned in the fix is PRESENT in the patch.
4. If the fix says "Replace X with Y", then X must NOT appear and Y MUST appear.
5. A fix is NOT applied if the patch uses a different approach to solve the same problem.
6. Set all_fixes_applied to true ONLY if EVERY fix passes checks 2, 3, and 4."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": verify_prompt}],
        "max_tokens": 1000,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Always set response_format last so callers cannot override it.
    params["response_format"] = FixVerification

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
            verification = FixVerification(**result_dict)
        except Exception:
            # Deliberately broad: covers json.JSONDecodeError for malformed JSON
            # AND validation errors raised by FixVerification(**result_dict).
            # (Previously written as the redundant tuple
            # `(json.JSONDecodeError, Exception)` — Exception already subsumes it.)
            logger.warning("Failed to parse test fix verification, assuming fixes not applied")
            verification = FixVerification(
                all_fixes_applied=False,
                applied_fixes=[],
                missing_fixes=["parse_error"],
                explanation="Failed to parse verification response",
            )
    else:
        # Some providers return the parsed response object directly.
        verification = content

    return verification, input_tokens, output_tokens
1062
+
1063
+
1064
@flyte.trace
async def verify_logic_fixes_applied(
    model: str,
    diagnosis: ErrorDiagnosis,
    new_solution: CodeSolution,
    litellm_params: Optional[dict] = None,
) -> tuple[FixVerification, int, int]:
    """Verify that the suggested fixes from diagnosis are present in new code.

    Only verifies "logic" failures — environment and test_error failures are
    handled by other paths and do not modify solution code.

    Args:
        model: LiteLLM model identifier used for the verification call.
        diagnosis: Diagnosis whose "logic" failures carry the required fixes.
        new_solution: Regenerated solution whose code should contain the fixes.
        litellm_params: Optional extra params merged into the LiteLLM call.
            They may override anything except ``response_format``, which is
            always set last.

    Returns:
        Tuple of (FixVerification, input_tokens, output_tokens).
    """
    # Only check logic failures (environment and test_error don't modify solution code).
    logic_failures = [f for f in diagnosis.failures if f.error_type == "logic"]

    if not logic_failures:
        # Nothing to verify; report success without an LLM round-trip.
        return (
            FixVerification(
                all_fixes_applied=True,
                applied_fixes=[],
                missing_fixes=[],
                explanation="No logic fixes to verify (only environment/test errors)",
            ),
            0,
            0,
        )

    new_code = new_solution.code

    # Build verification prompt - only for logic failures.
    fixes_to_check = []
    for i, failure in enumerate(logic_failures, 1):
        fixes_to_check.append(
            f"""Fix {i} (for test: {failure.test_name}):
- Root cause: {failure.root_cause}
- Required fix: {failure.suggested_fix}"""
        )

    fixes_section = "\n\n".join(fixes_to_check)

    verify_prompt = f"""You must verify that ALL the required fixes below are present in the new code.

Required fixes:
{fixes_section}

New code to verify:
```{new_solution.language}
{new_code}
```

Check each fix carefully:
1. For each fix, determine if the required change is present in the new code
2. If a fix says "Replace X with Y", verify that X is gone and Y is present
3. List which fixes are applied and which are missing
4. Set all_fixes_applied to true ONLY if every single fix is present

Be strict - if even one fix is missing or partially applied, set all_fixes_applied to false."""

    params = {
        "model": model,
        "messages": [{"role": "user", "content": verify_prompt}],
        "max_tokens": 1000,
        "temperature": 0.1,
    }
    params.update(litellm_params or {})
    # Always set response_format last so callers cannot override it.
    params["response_format"] = FixVerification

    response = await litellm.acompletion(**params)
    input_tokens, output_tokens = extract_token_usage(response)

    content = response.choices[0].message.content
    if isinstance(content, str):
        try:
            result_dict = json.loads(content)
            verification = FixVerification(**result_dict)
        except Exception:
            # Deliberately broad: covers json.JSONDecodeError for malformed JSON
            # AND validation errors raised by FixVerification(**result_dict).
            # (Previously written as the redundant tuple
            # `(json.JSONDecodeError, Exception)` — Exception already subsumes it.)
            logger.warning("Failed to parse logic fix verification, assuming fixes not applied")
            verification = FixVerification(
                all_fixes_applied=False,
                applied_fixes=[],
                missing_fixes=["parse_error"],
                explanation="Failed to parse verification response",
            )
    else:
        # Some providers return the parsed response object directly.
        verification = content

    return verification, input_tokens, output_tokens
1154
+
1155
+
1156
async def diagnose_and_plan_environment_fix(
    model: str,
    solution: CodeSolution,
    code_output: str,
    prompt: str,
    plan: CodePlan,
    detected_packages: list[str],
    detected_system_packages: list[str],
    additional_commands: list[str],
    litellm_params: Optional[dict] = None,
    test_code: Optional[str] = None,
    data_samples: Optional[str] = None,
    constraints: Optional[list[str]] = None,
    schema: Optional[str] = None,
) -> tuple[bool | str, list[str], list[str], list[str], int, int, ErrorDiagnosis]:
    """Diagnose error and plan environment fix (don't execute yet).

    Args:
        model: LiteLLM model identifier used for diagnosis.
        solution: Current code solution whose test run failed.
        code_output: Raw output of the failing run, fed to the diagnoser.
        prompt: Original problem prompt.
        plan: Code plan the solution was generated from.
        detected_packages: Language packages detected so far (passed through).
        detected_system_packages: System packages detected so far (passed through).
        additional_commands: Extra setup commands detected so far (passed through).
        litellm_params: Optional extra params for the LLM call.
        test_code: Test source, if available, for diagnosis context.
        data_samples: Data samples, if available, for diagnosis context.
        constraints: Problem constraints, if available, for diagnosis context.
        schema: Data schema, if available, for diagnosis context.

    Returns:
        Tuple of (primary_error_type, updated_detected_packages,
        updated_detected_system_packages, updated_additional_commands,
        input_tokens, output_tokens, diagnosis)

        where primary_error_type is either:
        - "test_error" (str): Test code has bugs, must fix tests first
        - False (bool): Environment and/or logic errors, handle together
    """
    diagnosis, in_tok, out_tok = await diagnose_error(
        model,
        solution,
        code_output,
        prompt,
        plan,
        litellm_params,
        test_code,
        data_samples,
        constraints,
        schema,
    )

    # Important: If any test errors exist, only keep test_error failures.
    # Discard logic/environment failures - they're unreliable when tests are broken.
    test_error_failures = [f for f in diagnosis.failures if f.error_type == "test_error"]

    if test_error_failures:
        logger.warning(
            f"Found {len(test_error_failures)} test_error failure(s). "
            f"Discarding {len(diagnosis.failures) - len(test_error_failures)} logic/environment failures "
            f"(unreliable when tests are broken)"
        )
        # Replace failures with only test_error failures.
        diagnosis.failures = test_error_failures

    # Count error types across all failures (after filtering).
    # error_type comes from LLM-parsed JSON, so use .get() to tolerate values
    # outside the three expected categories instead of raising KeyError.
    error_type_counts = {"environment": 0, "test_error": 0, "logic": 0}
    for failure in diagnosis.failures:
        error_type_counts[failure.error_type] = error_type_counts.get(failure.error_type, 0) + 1

    logger.info(f"Number of failures: {len(diagnosis.failures)}")
    logger.info(
        f"Breakdown: {error_type_counts['environment']} environment, "
        f"{error_type_counts['test_error']} test_error, "
        f"{error_type_counts['logic']} logic"
    )

    for i, failure in enumerate(diagnosis.failures, 1):
        # markup=False: failure text may contain characters a rich handler
        # would otherwise interpret as markup.
        logger.info(
            f"Test {i} [{failure.error_type}] - {failure.test_name}",
            extra={"markup": False},
        )
        logger.info(f"Root cause: {failure.root_cause}", extra={"markup": False})
        logger.info(f"Fix: {failure.suggested_fix}", extra={"markup": False})

    # Determine actions based on all groups.
    has_environment_errors = error_type_counts["environment"] > 0
    has_test_errors = error_type_counts["test_error"] > 0
    has_logic_errors = error_type_counts["logic"] > 0

    if has_test_errors:
        # Test code has bugs - must regenerate tests first.
        logger.warning(
            f"{error_type_counts['test_error']} test(s) have test bugs. "
            f"Will regenerate tests first, then handle other errors."
        )
        return (
            "test_error",
            detected_packages,
            detected_system_packages,
            additional_commands,
            in_tok,
            out_tok,
            diagnosis,
        )

    # No test errors - we can fix environment + logic together in one go.
    if has_environment_errors and has_logic_errors:
        logger.info(
            f"Will fix {error_type_counts['environment']} environment error(s) "
            f"and patch code for {error_type_counts['logic']} logic error(s) in one iteration"
        )
    elif has_environment_errors:
        logger.info(f"Will fix {error_type_counts['environment']} environment error(s)")
    elif has_logic_errors:
        logger.info(f"Will patch code for {error_type_counts['logic']} logic error(s)")

    # Return flag indicating we should handle both environment and logic.
    # False means "not test_error", so we'll proceed to environment + code fix.
    return (
        False,
        detected_packages,
        detected_system_packages,
        additional_commands,
        in_tok,
        out_tok,
        diagnosis,
    )