flyteplugins-codegen 2.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1088 @@
1
+ import datetime
2
+ import hashlib
3
+ import json
4
+ import logging
5
+ import os
6
+ import tempfile
7
+ from dataclasses import dataclass
8
+ from pathlib import Path
9
+ from typing import Literal, Optional
10
+
11
+ import flyte
12
+ import litellm
13
+ import pandas as pd
14
+ from flyte.errors import InvalidPackageError
15
+ from flyte.io import File
16
+ from flyte.sandbox import ImageConfig
17
+ from flyte.syncify import syncify
18
+
19
+ from flyteplugins.codegen.core.types import CodeGenEvalResult, CodePlan, CodeSolution
20
+ from flyteplugins.codegen.data.extraction import extract_data_context, is_dataframe
21
+ from flyteplugins.codegen.execution.agent import code_gen_eval_agent
22
+ from flyteplugins.codegen.execution.docker import build_image, run_tests
23
+ from flyteplugins.codegen.generation.llm import (
24
+ detect_and_track_packages,
25
+ diagnose_and_plan_environment_fix,
26
+ fix_failing_tests,
27
+ generate_code,
28
+ generate_plan,
29
+ generate_tests,
30
+ suggest_replacement_package,
31
+ verify_logic_fixes_applied,
32
+ verify_test_fixes_applied,
33
+ )
34
+ from flyteplugins.codegen.generation.prompts import (
35
+ DEFAULT_SYSTEM_PROMPT,
36
+ STRUCTURED_OUTPUT_REQUIREMENTS,
37
+ TEST_FRAMEWORKS,
38
+ build_enhanced_prompt,
39
+ )
40
+
41
# Module-level logger, namespaced to this module per stdlib convention.
logger = logging.getLogger(__name__)
42
+
43
+
44
@dataclass
class AutoCoderAgent:
    """Agent for single-file Python code generation with automatic testing and iteration.

    Generates a single Python script, builds a sandbox image with the required
    dependencies, runs pytest-based tests, and iterates until tests pass.

    Uses Sandbox internally for isolated code execution.

    Args:
        name: Name for the agent (used in image naming and logging).
        model: LLM model to use (required). Must support structured outputs.
            For backend="litellm" (default): e.g. "gpt-4.1", "claude-sonnet-4-20250514".
            For backend="claude": a Claude model ("sonnet", "opus", "haiku").
        system_prompt: Optional system prompt to use for LLM. If not provided,
            a default prompt with structured output requirements is used.
        api_key: Optional environment variable name for LLM API key.
        api_base: Optional base URL for LLM API.
        litellm_params: Optional dict of additional parameters to pass to LiteLLM calls.
        base_packages: Optional list of base packages to install in the sandbox.
        resources: Optional resources for sandbox execution (default: cpu=1, 1Gi).
        image_config: Optional image configuration for sandbox execution.
        max_iterations: Maximum number of generate-test-fix iterations. Defaults to 10.
        max_sample_rows: Optional maximum number of rows to use for sample data. Defaults to 100.
        skip_tests: Optional flag to skip testing. Defaults to False.
        sandbox_retries: Number of Flyte task-level retries for each sandbox execution. Defaults to 0.
        timeout: Timeout in seconds for sandboxes. Defaults to None.
        env_vars: Environment variables to pass to sandboxes.
        secrets: flyte.Secret objects to make available to sandboxes.
        cache: CacheRequest for sandboxes: "auto", "override", or "disable". Defaults to "auto".
        backend: Execution backend: "litellm" (default) or "claude".
        agent_max_turns: Maximum agent turns when backend="claude". Defaults to 50.

    Example::

        from flyte.sandbox import sandbox_environment
        from flyteplugins.codegen import AutoCoderAgent

        agent = AutoCoderAgent(
            model="gpt-4.1",
            base_packages=["pandas"],
            resources=flyte.Resources(cpu=1, memory="1Gi"),
        )

        env = flyte.TaskEnvironment(
            name="my-env",
            depends_on=[sandbox_environment],
        )

        @env.task
        async def my_task(data_file: File) -> float:
            result = await agent.generate.aio(
                prompt="Process CSV data",
                samples={"csv": data_file},
                outputs={"total": float},
            )
            return await result.run.aio()
    """

    model: str
    name: str = "auto-coder"
    system_prompt: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    litellm_params: Optional[dict] = None
    base_packages: Optional[list[str]] = None
    resources: Optional[flyte.Resources] = None
    image_config: Optional[ImageConfig] = None
    max_iterations: int = 10
    max_sample_rows: int = 100
    skip_tests: bool = False
    sandbox_retries: int = 0
    timeout: Optional[int] = None
    env_vars: Optional[dict[str, str]] = None
    secrets: Optional[list] = None
    cache: str = "auto"
    backend: Literal["litellm", "claude"] = "litellm"
    agent_max_turns: int = 50

    @syncify
    async def generate(
        self,
        prompt: str,
        schema: Optional[str] = None,
        constraints: Optional[list[str]] = None,
        samples: Optional[dict[str, pd.DataFrame | File]] = None,
        inputs: Optional[dict[str, type]] = None,
        outputs: Optional[dict[str, type]] = None,
    ) -> CodeGenEvalResult:
        """Generate and evaluate code in an isolated sandbox.

        Each call is independent with its own sandbox, packages and execution environment.

        Args:
            prompt: The prompt to generate code from.
            schema: Optional free-form context about data formats, structures or schemas.
                Included verbatim in the LLM prompt. Use for input formats, output schemas,
                database schemas or any structural context the LLM needs to generate code.
            constraints: Optional list of constraints or requirements.
            samples: Optional dict of sample data. Each value is sampled and included in
                the LLM prompt for context, and converted to a File input for the
                sandbox. Values are used as defaults at runtime — override them
                when calling ``result.run()`` or ``result.as_task()``.
                Supported types: File, pd.DataFrame.
            inputs: Optional dict declaring non-sample CLI argument types
                (e.g., ``{"threshold": float, "mode": str}``).
                Sample entries are automatically added as File inputs — don't redeclare them here.
                Supported types: str, int, float, bool, File.
            outputs: Optional dict defining output types (e.g., ``{"result": str, "report": File}``).
                Supported types: str, int, float, bool, datetime, timedelta, File.

        Returns:
            CodeGenEvalResult with solution and execution details.
        """
        language = "python"

        # Input validation: reject any declared input type the sandbox CLI cannot marshal.
        if inputs:
            supported_input_types = (str, int, float, bool, File)
            for input_key, input_type in inputs.items():
                if input_type not in supported_input_types:
                    supported_names = [t.__name__ for t in supported_input_types]
                    raise ValueError(
                        f"Unsupported input type for '{input_key}': {input_type}. "
                        f"Sandbox only supports: {', '.join(supported_names)}"
                    )

        # Data processing
        sample_files = None
        extracted_data_context = None
        data_schemas = {}
        schema_input_tokens = 0
        schema_output_tokens = 0

        if samples:
            logger.info(f"Processing {len(samples)} sample inputs...")
            inferred_types = {}
            sample_files = {}

            # Normalize every sample to a File; DataFrames are written to a temp CSV first.
            for data_key, value in samples.items():
                if isinstance(value, File):
                    inferred_types[data_key] = File
                    sample_files[data_key] = value
                elif is_dataframe(value):
                    # NOTE(review): predictable temp path keyed only by data_key —
                    # concurrent calls with the same key could collide; confirm acceptable.
                    temp_file = Path(tempfile.gettempdir()) / f"{data_key}.csv"
                    value.to_csv(temp_file, index=False)
                    file_obj = await File.from_local(str(temp_file))
                    inferred_types[data_key] = File
                    sample_files[data_key] = file_obj
                else:
                    raise ValueError(
                        f"Unsupported sample type for '{data_key}': {type(value)}. Supported: File, pd.DataFrame."
                    )

            logger.info("Extracting data context (schema, stats, patterns) and inferring Pandera schemas...")
            (
                extracted_data_context,
                data_schemas,
                schema_input_tokens,
                schema_output_tokens,
            ) = await extract_data_context(
                samples,
                self.max_sample_rows,
                constraints=constraints,
                model=self.model,
                litellm_params=self.litellm_params,
            )
            if data_schemas:
                logger.info(f"Inferred Pandera schemas for: {list(data_schemas.keys())}")

            if not inputs:
                inputs = inferred_types
                logger.info(f"Inferred input types: {inputs}")
            else:
                # Merge data-inferred types into user-provided inputs
                # NOTE(review): this mutates the caller's `inputs` dict in place.
                for key, typ in inferred_types.items():
                    if key not in inputs:
                        inputs[key] = typ
                logger.info(f"Merged input types: {inputs}")

        schemas_as_code = data_schemas or {}

        # Output validation: only types the sandbox can serialize back to the caller.
        if outputs:
            supported_types = (
                str,
                int,
                float,
                bool,
                datetime.datetime,
                datetime.timedelta,
                File,
            )
            for output_key, output_type in outputs.items():
                if output_type not in supported_types:
                    supported_names = [t.__name__ for t in supported_types]
                    raise ValueError(
                        f"Unsupported output type for '{output_key}': {output_type}. "
                        f"Sandbox only supports: {', '.join(supported_names)}"
                    )

        # Agent SDK routing: the "claude" backend delegates the whole loop to the agent.
        if self.backend == "claude":
            if self.skip_tests:
                logger.warning(
                    "skip_tests is not supported with Agent SDK mode. The agent autonomously decides when to test."
                )
            logger.info("Using Claude Agent SDK approach")
            return await code_gen_eval_agent(
                name=self.name,
                model=self.model,
                prompt=prompt,
                schema=schema,
                constraints=constraints,
                inputs=inputs,
                outputs=outputs,
                original_samples=sample_files,
                data_context=extracted_data_context,
                generated_schemas=schemas_as_code or None,
                base_packages=self.base_packages,
                resources=self.resources,
                image_config=self.image_config,
                retries=self.sandbox_retries,
                timeout=self.timeout,
                env_vars=self.env_vars,
                secrets=self.secrets,
                cache=self.cache,
                max_turns=self.agent_max_turns,
            )

        logger.info(
            f"Starting code generation: language={language}, model={self.model}, max_iterations={self.max_iterations}"
        )

        # LiteLLM setup: api_key holds an env var NAME, not the key itself.
        if self.api_key:
            litellm.api_key = os.getenv(self.api_key)
        if self.api_base:
            litellm.api_base = self.api_base

        # Build prompts
        base_prompt_text = self.system_prompt or DEFAULT_SYSTEM_PROMPT
        final_system_prompt = f"{base_prompt_text}\n{STRUCTURED_OUTPUT_REQUIREMENTS}"
        enhanced_prompt = build_enhanced_prompt(
            prompt,
            language,
            schema,
            constraints,
            extracted_data_context,
            inputs,
            outputs,
        )
        base_messages = [
            {"role": "system", "content": final_system_prompt},
            {"role": "user", "content": enhanced_prompt},
        ]

        # Generate plan
        logger.info("Generating plan...")
        plan, in_tok, out_tok = await generate_plan(
            self.model,
            prompt,
            language,
            schema,
            constraints,
            extracted_data_context,
            inputs,
            outputs,
            self.litellm_params,
        )
        logger.info(f"Plan created: {plan.description}")
        logger.info(f"Approach: {plan.approach}")

        # Prepare base packages: ensure the test framework's packages are present
        # unless tests are being skipped entirely.
        base_pkgs = list(self.base_packages or [])
        if not self.skip_tests:
            test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
            for pkg in test_framework_info["packages"]:
                if pkg not in base_pkgs:
                    base_pkgs.append(pkg)

        # Run iteration loop: all mutable state lives in the per-call session object.
        session = _CodeGenSession(
            agent=self,
            language=language,
            prompt=prompt,
            schema=schema,
            constraints=constraints,
            inputs=inputs,
            outputs=outputs,
            base_packages=base_pkgs,
            extracted_data_context=extracted_data_context,
            sample_files=sample_files,
            schemas_as_code=schemas_as_code,
            base_messages=base_messages,
            plan=plan,
            skip_tests=self.skip_tests,
            # Schema-extraction tokens only exist when samples were provided.
            initial_input_tokens=(schema_input_tokens + in_tok) if samples else in_tok,
            initial_output_tokens=((schema_output_tokens + out_tok) if samples else out_tok),
        )
        return await session.run()
345
+
346
+
347
+ class _CodeGenSession:
348
+ """Internal: manages mutable state for a single LiteLLM code generation run.
349
+
350
+ Encapsulates the retry loop, code/test generation, image building,
351
+ error diagnosis and reclassification logic.
352
+ """
353
+
354
    def __init__(
        self,
        *,
        agent: AutoCoderAgent,
        language: str,
        prompt: str,
        schema: Optional[str],
        constraints: Optional[list[str]],
        inputs: Optional[dict[str, type]],
        outputs: Optional[dict[str, type]],
        base_packages: list[str],
        extracted_data_context: Optional[str],
        sample_files: Optional[dict[str, File]],
        schemas_as_code: dict[str, str],
        base_messages: list[dict[str, str]],
        plan: CodePlan,
        skip_tests: bool,
        initial_input_tokens: int,
        initial_output_tokens: int,
    ):
        """Snapshot agent config and per-call inputs, and initialize loop state.

        All arguments are keyword-only. ``initial_*_tokens`` seed the running
        token counters with usage already spent on schema extraction and
        planning before the session starts.
        """
        # Agent reference (immutable config)
        self.agent = agent
        self.name = agent.name
        self.model = agent.model
        # With skip_tests there is nothing to iterate on, so a single pass suffices.
        self.max_iterations = 1 if skip_tests else agent.max_iterations
        self.skip_tests = skip_tests
        self.resources = agent.resources
        self.sandbox_retries = agent.sandbox_retries
        self.timeout = agent.timeout
        self.env_vars = agent.env_vars
        self.secrets = agent.secrets
        self.cache = agent.cache
        self.image_config = agent.image_config
        self.litellm_params = agent.litellm_params

        # Per-call config (immutable)
        self.language = language
        self.prompt = prompt
        self.schema = schema
        self.constraints = constraints
        self.inputs = inputs
        self.outputs = outputs
        self.extracted_data_context = extracted_data_context
        self.sample_files = sample_files
        self.schemas_as_code = schemas_as_code
        self.base_messages = base_messages
        self.plan = plan

        # Package state: detected_* grow as the LLM diagnoses missing deps;
        # previously_installed_* track what the current image already contains.
        self.base_pkgs = base_packages
        self.detected_packages: list[str] = []
        self.detected_system_packages: list[str] = []
        self.additional_commands: list[str] = []
        self.previously_installed_packages: list[str] = []
        self.previously_installed_system_packages: list[str] = []

        # Image state
        self.current_image: Optional[str] = None
        self.image_name = self._compute_image_name(self.base_pkgs, [])

        # Generation state: the needs_* flags drive which steps each iteration re-runs.
        self.solution: Optional[CodeSolution] = None
        self.tests: Optional[str] = None
        self.needs_new_code = True
        self.needs_new_tests = True
        self.needs_rebuild = True
        self.last_packages_snapshot: tuple[set, set] = (set(), set())

        # Error tracking
        self.last_error: Optional[str] = None
        self.last_error_message: Optional[str] = None
        self.last_diagnosis = None
        self.last_result: Optional[CodeGenEvalResult] = None

        # Reclassification tracking: counts of fix attempts keyed by
        # (test_name, error signature), used to flip test_error <-> logic.
        self.logic_fix_attempts: dict[tuple, int] = {}
        self.test_fix_attempts: dict[tuple, int] = {}
        self.max_logic_attempts = 1
        self.max_test_attempts = 1

        # Token tracking (seeded with pre-session usage)
        self.total_input_tokens = initial_input_tokens
        self.total_output_tokens = initial_output_tokens

        # Add test framework system packages
        if not self.skip_tests:
            test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
            for pkg in test_framework_info.get("system_packages", []):
                if pkg not in self.detected_system_packages:
                    self.detected_system_packages.append(pkg)
444
+
445
+ def _compute_image_name(self, packages: list[str], system_packages: list[str]) -> str:
446
+ spec = {
447
+ "language": self.language,
448
+ "packages": sorted(packages),
449
+ "system_packages": sorted(system_packages),
450
+ }
451
+ config_hash = hashlib.sha256(json.dumps(spec, sort_keys=True).encode()).hexdigest()[:12]
452
+ return f"auto-coder-agent-{self.language}-{config_hash}"
453
+
454
+ def _track_tokens(self, in_tok: int, out_tok: int):
455
+ self.total_input_tokens += in_tok
456
+ self.total_output_tokens += out_tok
457
+
458
    def _make_result(
        self,
        *,
        success: bool,
        test_output: str,
        exit_code: int,
        attempt: int,
        error: Optional[str] = None,
    ) -> CodeGenEvalResult:
        """Assemble a CodeGenEvalResult from current session state.

        Args:
            success: Whether the run is being reported as successful.
            test_output: Raw output from the test run ("" when tests were skipped).
            exit_code: Test process exit code (-1 for internal failures).
            attempt: Which iteration produced this result.
            error: Optional human-readable error description.
        """
        return CodeGenEvalResult(
            plan=self.plan,
            # An empty CodeSolution stands in when generation never produced one.
            solution=self.solution or CodeSolution(),
            tests=self.tests,
            success=success,
            output=test_output,
            exit_code=exit_code,
            error=error,
            attempts=attempt,
            conversation_history=self.base_messages,
            detected_packages=self.detected_packages,
            detected_system_packages=self.detected_system_packages,
            image=self.current_image,
            total_input_tokens=self.total_input_tokens,
            total_output_tokens=self.total_output_tokens,
            declared_inputs=self.inputs,
            declared_outputs=self.outputs,
            data_context=self.extracted_data_context,
            original_samples=self.sample_files,
            generated_schemas=self.schemas_as_code or None,
        )
488
+
489
    async def run(self) -> CodeGenEvalResult:
        """Execute the full retry loop.

        Runs up to ``max_iterations`` attempts. ``_attempt`` returns a result to
        stop (success or terminal failure) or None to continue iterating.
        Exceptions from an attempt are captured, recorded as the last result,
        and fed back into the next attempt as error context.
        """
        for attempt in range(1, self.max_iterations + 1):
            logger.info(
                f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
                f"[ITERATION] Starting attempt {attempt}/{self.max_iterations}\n"
                f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
            )
            try:
                result = await self._attempt(attempt)
                if result is not None:
                    return result
            except Exception as e:
                self.last_error = str(e)
                logger.error(f"Error during attempt {attempt}: {self.last_error}")
                # The error text is appended to the next attempt's prompt context.
                self.last_error_message = f"An error occurred: {self.last_error}"
                self.last_result = self._make_result(
                    success=False,
                    test_output="",
                    exit_code=-1,
                    attempt=attempt,
                    error=f"Attempt {attempt} failed: {self.last_error}",
                )
                if attempt == self.max_iterations:
                    return self.last_result
                # Force regeneration of code (and tests, if none exist yet) next time.
                self.needs_new_code = True
                if self.tests is None:
                    self.needs_new_tests = True

        # Loop exhausted without an explicit result from _attempt.
        return self.last_result or self._make_result(
            success=False,
            test_output="",
            exit_code=-1,
            attempt=self.max_iterations,
            error=f"All {self.max_iterations} attempts failed: {self.last_error}",
        )
525
+
526
    async def _attempt(self, attempt: int) -> Optional[CodeGenEvalResult]:
        """Run a single iteration. Returns result on success, None to continue.

        Pipeline: (re)generate code -> detect packages -> (re)generate tests ->
        recompute image name -> rebuild image if needed -> run tests ->
        return success result or delegate failure handling.
        """
        # 1. Generate code (only when needed)
        if self.needs_new_code:
            logger.info("Generating code...")
            if not await self._generate_code(attempt):
                return None  # Skip to next iteration

            # Detect and track packages
            (
                self.needs_rebuild,
                self.detected_packages,
                self.detected_system_packages,
                in_tok,
                out_tok,
            ) = await detect_and_track_packages(
                self.model,
                self.solution,
                self.base_pkgs,
                self.detected_packages,
                self.detected_system_packages,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)
            self.needs_new_code = False
            # New code invalidates tests only if none exist yet.
            if self.tests is None:
                self.needs_new_tests = True

        # Short-circuit: skip tests if requested
        if self.skip_tests:
            logger.info("skip_tests=True: skipping test generation and execution.")
            # Still build the image (needed for as_task() and run())
            self._update_image_name_if_needed()
            if self.needs_rebuild or self.current_image is None:
                await self._build_image()
            return self._make_result(
                success=True,
                test_output="",
                exit_code=0,
                attempt=attempt,
            )

        # 2. Generate tests (only when needed)
        if self.needs_new_tests:
            logger.info("Generating tests...")
            self.tests, in_tok, out_tok = await generate_tests(
                self.model,
                self.prompt,
                self.plan,
                self.solution,
                self.constraints,
                self.schema,
                self.extracted_data_context,
                self.inputs,
                self.outputs,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)
            self.needs_new_tests = False

        # 3. Update image name if packages changed
        self._update_image_name_if_needed()

        # 4. Build/rebuild image if needed
        if self.needs_rebuild or self.current_image is None:
            await self._build_image()

        # 5. Execute tests
        logger.info("Running tests...")
        run_tests_output = await run_tests.aio(
            code=self.solution.code,
            tests=self.tests,
            image=self.current_image,
            name=self.name,
            resources=self.resources,
            retries=self.sandbox_retries,
            timeout=self.timeout,
            env_vars=self.env_vars,
            secrets=self.secrets,
            cache=self.cache,
            _attempt=attempt,
        )

        tests_passed, test_output, test_exit_code = (
            run_tests_output.tests_passed,
            run_tests_output.output,
            run_tests_output.exit_code,
        )

        # 6. Handle success
        if tests_passed:
            logger.info("Tests passed! Solution successful.")
            logger.info(f"Total tokens: input={self.total_input_tokens}, output={self.total_output_tokens}")
            self.last_diagnosis = None
            return self._make_result(
                success=True,
                test_output=test_output,
                # exit_code arrives as a string from the sandbox — TODO confirm
                # it is always numeric here.
                exit_code=int(test_exit_code.strip()),
                attempt=attempt,
            )

        # 7. Handle failure
        return await self._handle_failure(test_output, test_exit_code, attempt)
629
+
630
    async def _generate_code(self, attempt: int) -> bool:
        """Generate code with up to 3 verification attempts. Returns True on success.

        Each inner attempt appends progressively more forceful instructions when
        a previous diagnosis demanded fixes, then asks the LLM to verify the
        logic fixes were actually applied to the regenerated code.
        """
        max_attempts = 3

        for code_attempt in range(1, max_attempts + 1):
            logger.info(f"Generating code (attempt {attempt}, code gen {code_attempt}/{max_attempts})...")

            # Build messages with progressively forceful error context
            messages = self.base_messages.copy()
            if self.last_error_message:
                if code_attempt == 1:
                    messages.append({"role": "user", "content": self.last_error_message})
                elif code_attempt == 2:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                f"{self.last_error_message}\n\n"
                                "CRITICAL: The previous code generation attempt did NOT "
                                "apply all the required fixes.\n"
                                "You MUST apply EVERY SINGLE fix listed above. "
                                "Do not skip any fix.\n"
                                "Apply each fix EXACTLY as specified "
                                "- find the old code and replace it with the new code."
                            ),
                        }
                    )
                else:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                f"{self.last_error_message}\n\n"
                                "FINAL ATTEMPT: You have failed to apply the required fixes twice.\n"
                                "This is your last chance. Apply EVERY fix listed above WITHOUT EXCEPTION.\n"
                                "For each fix, you MUST:\n"
                                "1. Find the EXACT old code mentioned\n"
                                "2. Replace it with the EXACT new code mentioned\n"
                                "3. Do NOT change anything else\n\n"
                                "If you fail to apply all fixes this time, the entire task will fail."
                            ),
                        }
                    )

            self.solution, in_tok, out_tok = await generate_code(
                self.model,
                messages,
                self.plan,
                self.litellm_params,
                is_retry=(attempt > 1),
            )
            self._track_tokens(in_tok, out_tok)

            if self.solution.language.lower() != self.language.lower():
                logger.warning(f"Requested {self.language} but LLM generated {self.solution.language}")

            # Verify fixes are applied (if we have a diagnosis with logic failures)
            has_logic_failures = self.last_diagnosis and any(
                f.error_type == "logic" for f in self.last_diagnosis.failures
            )
            if not has_logic_failures:
                return True

            logger.info("Verifying that logic fixes were applied...")
            verification, in_tok, out_tok = await verify_logic_fixes_applied(
                self.model,
                self.last_diagnosis,
                self.solution,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)

            if verification.all_fixes_applied:
                logger.info(f"Verification passed: All fixes applied. {verification.explanation}")
                return True

            logger.warning(f"Verification failed: {verification.explanation}")
            logger.warning(f"Applied: {verification.applied_fixes}")
            logger.warning(f"Missing: {verification.missing_fixes}")

            if code_attempt < max_attempts:
                # Append verification feedback for next attempt
                missing_msg = "\n\nVERIFICATION FAILED - The following fixes are STILL MISSING:\n"
                for i, fix in enumerate(verification.missing_fixes, 1):
                    missing_msg += f"\n{i}. {fix}"
                missing_msg += "\n\nYou successfully applied these fixes:\n"
                for fix in verification.applied_fixes:
                    missing_msg += f"- {fix}\n"
                missing_msg += (
                    "\nYou MUST now apply the MISSING fixes listed above. "
                    "Do NOT regenerate the entire solution - just apply the missing fixes to your previous code."
                )
                self.last_error_message = (self.last_error_message or "") + missing_msg
            else:
                logger.error(f"Failed to apply all fixes after {max_attempts} attempts. Proceeding anyway...")
                return True  # Proceed anyway

        # NOTE(review): appears unreachable — the final attempt above returns True
        # ("proceed anyway") instead of falling through. Kept as a defensive guard.
        logger.error("Failed to generate code with all fixes applied. Skipping this iteration.")
        return False
729
+
730
+ def _update_image_name_if_needed(self):
731
+ """Check if packages changed and update image name."""
732
+ current_snapshot = (
733
+ set(self.detected_packages),
734
+ set(self.detected_system_packages),
735
+ )
736
+ if current_snapshot == self.last_packages_snapshot:
737
+ return
738
+
739
+ self.needs_rebuild = True
740
+ self.last_packages_snapshot = current_snapshot
741
+
742
+ all_packages = self.base_pkgs + self.detected_packages
743
+ new_name = self._compute_image_name(all_packages, self.detected_system_packages)
744
+ if new_name != self.image_name:
745
+ logger.info(f"Image name updated: {self.image_name} -> {new_name}")
746
+ self.image_name = new_name
747
+ self.current_image = None
748
+
749
+ async def _build_image(self):
750
+ """Build/rebuild image with package retry loop."""
751
+ max_retries = 3
752
+ for _ in range(max_retries):
753
+ try:
754
+ self.current_image = await build_image(
755
+ self.solution.language,
756
+ self.base_pkgs,
757
+ self.detected_packages,
758
+ self.detected_system_packages,
759
+ self.previously_installed_packages,
760
+ self.previously_installed_system_packages,
761
+ self.additional_commands,
762
+ self.image_name,
763
+ self.current_image,
764
+ self.image_config,
765
+ )
766
+ self.previously_installed_packages = self.detected_packages.copy()
767
+ self.previously_installed_system_packages = self.detected_system_packages.copy()
768
+ self.needs_rebuild = False
769
+ return
770
+ except InvalidPackageError as e:
771
+ bad_package = e.package_name
772
+ logger.warning(f"Invalid system package '{bad_package}', asking LLM for replacement...")
773
+ if bad_package in self.detected_system_packages:
774
+ self.detected_system_packages.remove(bad_package)
775
+ if bad_package in self.previously_installed_system_packages:
776
+ self.previously_installed_system_packages.remove(bad_package)
777
+
778
+ # Ask LLM for the correct package name
779
+ solution_code = self.solution.code
780
+ replacement, in_tok, out_tok = await suggest_replacement_package(
781
+ self.model,
782
+ bad_package,
783
+ e.original_error,
784
+ solution_code,
785
+ self.litellm_params,
786
+ )
787
+ self.total_input_tokens += in_tok
788
+ self.total_output_tokens += out_tok
789
+
790
+ if replacement and replacement not in self.detected_system_packages:
791
+ self.detected_system_packages.append(replacement)
792
+ logger.info(f"Replacing '{bad_package}' with '{replacement}'")
793
+
794
+ logger.info(f"Retrying with system packages: {self.detected_system_packages}")
795
+
796
    async def _handle_failure(
        self,
        test_output: str,
        test_exit_code: str,
        attempt: int,
    ) -> Optional[CodeGenEvalResult]:
        """Handle test failure: diagnose, reclassify, fix tests or code. Returns None to continue."""
        # Check if tests actually executed.
        # NOTE(review): heuristic match on pytest output text (" passed"/" failed"/
        # "collected 0 items"/"ERROR collecting") — assumes the pytest-style
        # runner's wording; verify against the sandbox test runner's output.
        tests_executed = (
            " passed" in test_output or " failed" in test_output or "collected 0 items" not in test_output
        ) and "ERROR collecting" not in test_output

        if not tests_executed:
            logger.warning("No tests executed - test file likely has errors. Regenerating tests...")
            self.needs_new_tests = True
            self.needs_new_code = False
            return None

        # Diagnose failures: the LLM classifies errors and proposes environment fixes.
        (
            primary_error_type,
            self.detected_packages,
            self.detected_system_packages,
            self.additional_commands,
            in_tok,
            out_tok,
            diagnosis,
        ) = await diagnose_and_plan_environment_fix(
            self.model,
            self.solution,
            test_output,
            self.prompt,
            self.plan,
            self.detected_packages,
            self.detected_system_packages,
            self.additional_commands,
            self.litellm_params,
            self.tests,
            self.extracted_data_context,
            self.constraints,
            self.schema,
        )
        self._track_tokens(in_tok, out_tok)
        self.last_diagnosis = diagnosis

        # Apply environment fixes from diagnosis
        self._apply_environment_fixes(diagnosis)

        # Reclassify repeated errors (test_error <-> logic) based on attempt counts.
        primary_error_type = self._reclassify_errors(diagnosis, primary_error_type)

        # Handle test errors: fix tests
        if primary_error_type == "test_error":
            return await self._handle_test_errors(diagnosis, attempt)

        # Handle logic/environment errors: build patch message
        return self._handle_logic_env_errors(diagnosis, test_output, test_exit_code, attempt)
853
+
854
+ def _apply_environment_fixes(self, diagnosis):
855
+ """Extract and apply environment fixes from diagnosis."""
856
+ if diagnosis.needs_language_packages:
857
+ added = [
858
+ p
859
+ for p in diagnosis.needs_language_packages
860
+ if p not in self.detected_packages and p not in self.base_pkgs
861
+ ]
862
+ if added:
863
+ self.detected_packages.extend(added)
864
+ logger.info(f"Adding language packages from diagnosis: {added}")
865
+ self.needs_rebuild = True
866
+
867
+ if diagnosis.needs_system_packages:
868
+ added = [p for p in diagnosis.needs_system_packages if p not in self.detected_system_packages]
869
+ if added:
870
+ self.detected_system_packages.extend(added)
871
+ logger.info(f"Adding system packages from diagnosis: {added}")
872
+ self.needs_rebuild = True
873
+
874
+ if diagnosis.needs_additional_commands:
875
+ logger.info(f"Adding additional commands from diagnosis: {diagnosis.needs_additional_commands}")
876
+ self.additional_commands.extend(diagnosis.needs_additional_commands)
877
+ self.needs_rebuild = True
878
+
879
    def _reclassify_errors(self, diagnosis, primary_error_type: str) -> str:
        """Reclassify repeated errors (test_error <-> logic). Returns updated primary_error_type.

        Keeps per-(test_name, error-signature) retry counters on ``self`` and flips a
        failure's classification once the same error has survived more than the
        allowed number of fix attempts in its current category:

        * ``test_error`` -> ``logic`` after ``self.max_test_attempts`` test fixes
          (the test expectations are probably right; suspect the code instead).
        * ``logic`` -> ``test_error`` after ``self.max_logic_attempts`` logic fixes
          (the code is probably right; suspect the test's expected values).

        Mutates ``diagnosis.failures`` in place (``error_type``, ``root_cause``,
        ``suggested_fix``, and filtering when any test_error remains afterwards).
        """
        # test_error -> logic (test might be correct, code is wrong)
        reclassified = 0
        for failure in diagnosis.failures:
            if failure.error_type != "test_error":
                continue
            # "Same error" signature: prefer the error message, fall back to behavior.
            sig = failure.error_message or failure.actual_behavior
            key = (failure.test_name, sig)
            self.test_fix_attempts[key] = self.test_fix_attempts.get(key, 0) + 1

            if self.test_fix_attempts[key] > self.max_test_attempts:
                original = failure.root_cause
                failure.error_type = "logic"
                failure.root_cause = (
                    f"Test failed {self.max_test_attempts + 1} times with same error after test fixes. "
                    f"The test expectations are likely correct. The code logic could be wrong. "
                    f"Original test diagnosis was: {original}"
                )
                failure.suggested_fix = (
                    f"Fix the code logic to match test expectations. "
                    f"Test expects: {failure.expected_behavior}. "
                    f"Code produces: {failure.actual_behavior}. "
                    f"Update the code to produce the expected behavior."
                )
                # Reset the counter so the failure gets a fresh budget in its new category.
                self.test_fix_attempts.pop(key, None)
                reclassified += 1
                logger.warning(f"Reclassified test_error -> logic for '{failure.test_name}'")

        if reclassified:
            logger.info(f"Reclassified {reclassified} test_error(s) to logic.")

        # logic -> test_error (LLM might misdiagnose test bugs as logic bugs)
        reclassified = 0
        for failure in diagnosis.failures:
            if failure.error_type != "logic":
                continue
            sig = failure.error_message or failure.actual_behavior
            key = (failure.test_name, sig)
            self.logic_fix_attempts[key] = self.logic_fix_attempts.get(key, 0) + 1

            if self.logic_fix_attempts[key] > self.max_logic_attempts:
                original = failure.root_cause
                failure.error_type = "test_error"
                failure.root_cause = (
                    f"Test failed {self.max_logic_attempts + 1} times with same error after logic fixes. "
                    f"Likely the test itself has wrong expected values, not the code. "
                    f"Original diagnosis was: {original}"
                )
                failure.suggested_fix = (
                    f"Fix the test expectations to match actual correct behavior. "
                    f"Code produces: {failure.actual_behavior}. "
                    f"If this is correct, update the test to expect this value instead."
                )
                self.logic_fix_attempts.pop(key, None)
                reclassified += 1
                logger.warning(f"Reclassified logic -> test_error for '{failure.test_name}'")

        if reclassified:
            logger.info(f"Reclassified {reclassified} logic error(s) to test_error.")
        # If any failure is (now) a test_error, handle only those this round: drop the
        # other failures and steer the caller toward the test-fixing path.
        has_test_errors = any(f.error_type == "test_error" for f in diagnosis.failures)
        if has_test_errors:
            diagnosis.failures = [f for f in diagnosis.failures if f.error_type == "test_error"]
            primary_error_type = "test_error"
            logger.info(f"After reclassification: {len(diagnosis.failures)} test_error failure(s).")

        return primary_error_type
947
+ async def _handle_test_errors(self, diagnosis, attempt: int) -> Optional[CodeGenEvalResult]:
948
+ """Fix failing tests with up to 3 verification attempts. Returns None to continue."""
949
+ logger.info("Diagnosis identified bug in test code. Fixing only failed tests...")
950
+ logger.info(f"Failed tests to fix: {[f.test_name for f in diagnosis.failures]}")
951
+
952
+ max_attempts = 3
953
+
954
+ for fix_attempt in range(1, max_attempts + 1):
955
+ logger.info(f"Fixing failing tests (attempt {fix_attempt}/{max_attempts})...")
956
+
957
+ self.tests, patches, in_tok, out_tok = await fix_failing_tests(
958
+ self.model,
959
+ self.tests,
960
+ diagnosis,
961
+ self.solution,
962
+ self.litellm_params,
963
+ )
964
+ self._track_tokens(in_tok, out_tok)
965
+
966
+ logger.info("Verifying that test fixes were applied...")
967
+ verification, in_tok, out_tok = await verify_test_fixes_applied(
968
+ self.model,
969
+ diagnosis,
970
+ patches,
971
+ self.litellm_params,
972
+ )
973
+ self._track_tokens(in_tok, out_tok)
974
+
975
+ if verification.all_fixes_applied:
976
+ logger.info(f"Verification passed: All test fixes applied. {verification.explanation}")
977
+ self.needs_new_code = False
978
+ self.last_diagnosis = None
979
+ return None # Continue to next iteration with fixed tests
980
+
981
+ logger.warning(f"Verification failed: {verification.explanation}")
982
+
983
+ if fix_attempt < max_attempts:
984
+ missing_msg = "\n\nVERIFICATION FAILED - The following test fixes are STILL MISSING:\n"
985
+ for i, fix in enumerate(verification.missing_fixes, 1):
986
+ missing_msg += f"\n{i}. {fix}"
987
+ missing_msg += "\n\nYou successfully applied these fixes:\n"
988
+ for fix in verification.applied_fixes:
989
+ missing_msg += f"- {fix}\n"
990
+ missing_msg += "\nYou MUST now apply the MISSING test fixes listed above."
991
+
992
+ forceful = ""
993
+ if fix_attempt == 2:
994
+ forceful = (
995
+ "\n\nCRITICAL: The previous test fix attempt did NOT apply all the required fixes. "
996
+ "You MUST apply EVERY SINGLE fix listed above. Do not skip any fix."
997
+ )
998
+ elif fix_attempt >= 3:
999
+ forceful = (
1000
+ "\n\nFINAL ATTEMPT: You have failed to apply the required test fixes twice. "
1001
+ "This is your last chance. Apply EVERY fix listed above WITHOUT EXCEPTION."
1002
+ )
1003
+
1004
+ for failure in diagnosis.failures:
1005
+ failure.suggested_fix = f"{failure.suggested_fix}\n\n{missing_msg}{forceful}"
1006
+ else:
1007
+ logger.error(f"Failed to apply all test fixes after {max_attempts} attempts. Proceeding anyway...")
1008
+ self.needs_new_code = False
1009
+ self.last_diagnosis = None
1010
+ return None
1011
+
1012
+ return None
1013
+
1014
+ def _handle_logic_env_errors(
1015
+ self,
1016
+ diagnosis,
1017
+ test_output: str,
1018
+ test_exit_code: str,
1019
+ attempt: int,
1020
+ ) -> Optional[CodeGenEvalResult]:
1021
+ """Handle logic and/or environment errors. Returns None to continue, result if max retries."""
1022
+ failures_info = []
1023
+ logic_count = 0
1024
+ env_count = 0
1025
+
1026
+ for i, failure in enumerate(diagnosis.failures, 1):
1027
+ failures_info.append(
1028
+ f"\nTest {i} [{failure.error_type}] - {failure.test_name}\n"
1029
+ f"- Expected: {failure.expected_behavior}\n"
1030
+ f"- Actual: {failure.actual_behavior}\n"
1031
+ f"- Root cause: {failure.root_cause}\n"
1032
+ f"- FIX: {failure.suggested_fix}"
1033
+ )
1034
+ if failure.error_type == "logic":
1035
+ logic_count += 1
1036
+ elif failure.error_type == "environment":
1037
+ env_count += 1
1038
+
1039
+ if logic_count > 0 and env_count > 0:
1040
+ logger.info(f"Will fix {env_count} environment error(s) and patch code for {logic_count} logic error(s)")
1041
+ elif logic_count > 0:
1042
+ logger.info(f"Will patch code for {logic_count} logic error(s)")
1043
+ elif env_count > 0:
1044
+ logger.info(f"Will fix {env_count} environment error(s)")
1045
+
1046
+ full_code = self.solution.code
1047
+ error_msg = (
1048
+ "Tests failed. Apply only the specific fixes below to your code.\n\n"
1049
+ "Do not regenerate from scratch. PATCH the code by applying ONLY the fixes below.\n"
1050
+ "Do NOT make any other changes - keep everything else exactly as is.\n\n"
1051
+ "CRITICAL CONSTRAINTS:\n"
1052
+ "1. /var/outputs is a PRE-EXISTING directory. NEVER delete, recreate, or modify it. "
1053
+ "NEVER use shutil.rmtree or os.makedirs on /var/outputs. Only write files into it using: "
1054
+ "open('/var/outputs/<name>', 'w').write(str(value)). Always use the literal path '/var/outputs' "
1055
+ "-- never make it configurable or store it in a variable.\n"
1056
+ "2. If a part of the code is working correctly, DO NOT change it. Only fix what's broken.\n"
1057
+ "3. Apply each fix by finding the exact code quoted and replacing it - nothing more.\n"
1058
+ "4. Do NOT regenerate the entire code. Just apply the specific patches mentioned below.\n\n"
1059
+ f"Your previous code:\n```{self.solution.language}\n{full_code}\n```\n\n" + "\n".join(failures_info)
1060
+ )
1061
+
1062
+ if logic_count > 0:
1063
+ logger.info("Tests failed. Will patch code with fixes (keeping same tests)...")
1064
+ self.last_error_message = error_msg
1065
+ else:
1066
+ self.last_error_message = None
1067
+
1068
+ self.last_result = self._make_result(
1069
+ success=False,
1070
+ test_output=test_output,
1071
+ exit_code=int(test_exit_code.strip()),
1072
+ attempt=attempt,
1073
+ error=error_msg,
1074
+ )
1075
+
1076
+ if attempt == self.max_iterations:
1077
+ return self.last_result
1078
+
1079
+ # Set flags for next iteration
1080
+ if logic_count > 0:
1081
+ self.needs_new_code = True
1082
+ self.needs_new_tests = False
1083
+ elif env_count > 0:
1084
+ logger.info("Only environment errors - skipping code regeneration, will rebuild image with new packages")
1085
+ self.needs_new_code = False
1086
+ self.needs_new_tests = False
1087
+
1088
+ return None