flyteplugins-codegen 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/codegen/__init__.py +18 -0
- flyteplugins/codegen/auto_coder_agent.py +1088 -0
- flyteplugins/codegen/core/__init__.py +19 -0
- flyteplugins/codegen/core/types.py +337 -0
- flyteplugins/codegen/data/__init__.py +27 -0
- flyteplugins/codegen/data/extraction.py +281 -0
- flyteplugins/codegen/data/schema.py +270 -0
- flyteplugins/codegen/execution/__init__.py +7 -0
- flyteplugins/codegen/execution/agent.py +671 -0
- flyteplugins/codegen/execution/docker.py +206 -0
- flyteplugins/codegen/generation/__init__.py +41 -0
- flyteplugins/codegen/generation/llm.py +1269 -0
- flyteplugins/codegen/generation/prompts.py +136 -0
- flyteplugins_codegen-2.0.6.dist-info/METADATA +441 -0
- flyteplugins_codegen-2.0.6.dist-info/RECORD +17 -0
- flyteplugins_codegen-2.0.6.dist-info/WHEEL +5 -0
- flyteplugins_codegen-2.0.6.dist-info/top_level.txt +1 -0
flyteplugins/codegen/auto_coder_agent.py
@@ -0,0 +1,1088 @@
import datetime
import hashlib
import json
import logging
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import flyte
import litellm
import pandas as pd
from flyte.errors import InvalidPackageError
from flyte.io import File
from flyte.sandbox import ImageConfig
from flyte.syncify import syncify

from flyteplugins.codegen.core.types import CodeGenEvalResult, CodePlan, CodeSolution
from flyteplugins.codegen.data.extraction import extract_data_context, is_dataframe
from flyteplugins.codegen.execution.agent import code_gen_eval_agent
from flyteplugins.codegen.execution.docker import build_image, run_tests
from flyteplugins.codegen.generation.llm import (
    detect_and_track_packages,
    diagnose_and_plan_environment_fix,
    fix_failing_tests,
    generate_code,
    generate_plan,
    generate_tests,
    suggest_replacement_package,
    verify_logic_fixes_applied,
    verify_test_fixes_applied,
)
from flyteplugins.codegen.generation.prompts import (
    DEFAULT_SYSTEM_PROMPT,
    STRUCTURED_OUTPUT_REQUIREMENTS,
    TEST_FRAMEWORKS,
    build_enhanced_prompt,
)

logger = logging.getLogger(__name__)

@dataclass
class AutoCoderAgent:
    """Agent for single-file Python code generation with automatic testing and iteration.

    Generates a single Python script, builds a sandbox image with the required
    dependencies, runs pytest-based tests, and iterates until tests pass.

    Uses Sandbox internally for isolated code execution.

    Args:
        name: Name for the agent (used in image naming and logging).
        model: LLM model to use (required). Must support structured outputs.
            For backend="litellm" (default): e.g. "gpt-4.1", "claude-sonnet-4-20250514".
            For backend="claude": a Claude model ("sonnet", "opus", "haiku").
        system_prompt: Optional system prompt to use for the LLM. If not provided,
            a default prompt with structured output requirements is used.
        api_key: Optional environment variable name for the LLM API key.
        api_base: Optional base URL for the LLM API.
        litellm_params: Optional dict of additional parameters to pass to LiteLLM calls.
        base_packages: Optional list of base packages to install in the sandbox.
        resources: Optional resources for sandbox execution (default: cpu=1, 1Gi).
        image_config: Optional image configuration for sandbox execution.
        max_iterations: Maximum number of generate-test-fix iterations. Defaults to 10.
        max_sample_rows: Optional maximum number of rows to use for sample data. Defaults to 100.
        skip_tests: Optional flag to skip testing. Defaults to False.
        sandbox_retries: Number of Flyte task-level retries for each sandbox execution. Defaults to 0.
        timeout: Timeout in seconds for sandboxes. Defaults to None.
        env_vars: Environment variables to pass to sandboxes.
        secrets: flyte.Secret objects to make available to sandboxes.
        cache: CacheRequest for sandboxes: "auto", "override", or "disable". Defaults to "auto".
        backend: Execution backend: "litellm" (default) or "claude".
        agent_max_turns: Maximum agent turns when backend="claude". Defaults to 50.

    Example::

        from flyte.sandbox import sandbox_environment
        from flyteplugins.codegen import AutoCoderAgent

        agent = AutoCoderAgent(
            model="gpt-4.1",
            base_packages=["pandas"],
            resources=flyte.Resources(cpu=1, memory="1Gi"),
        )

        env = flyte.TaskEnvironment(
            name="my-env",
            depends_on=[sandbox_environment],
        )

        @env.task
        async def my_task(data_file: File) -> float:
            result = await agent.generate.aio(
                prompt="Process CSV data",
                samples={"csv": data_file},
                outputs={"total": float},
            )
            return await result.run.aio()
    """

    model: str
    name: str = "auto-coder"
    system_prompt: Optional[str] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    litellm_params: Optional[dict] = None
    base_packages: Optional[list[str]] = None
    resources: Optional[flyte.Resources] = None
    image_config: Optional[ImageConfig] = None
    max_iterations: int = 10
    max_sample_rows: int = 100
    skip_tests: bool = False
    sandbox_retries: int = 0
    timeout: Optional[int] = None
    env_vars: Optional[dict[str, str]] = None
    secrets: Optional[list] = None
    cache: str = "auto"
    backend: Literal["litellm", "claude"] = "litellm"
    agent_max_turns: int = 50
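
    # Note: switching to backend="claude" routes generate() to
    # code_gen_eval_agent() below, and the agent drives its own
    # generate-test-fix loop for up to agent_max_turns turns. A hypothetical
    # configuration: AutoCoderAgent(model="sonnet", backend="claude").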

    @syncify
    async def generate(
        self,
        prompt: str,
        schema: Optional[str] = None,
        constraints: Optional[list[str]] = None,
        samples: Optional[dict[str, pd.DataFrame | File]] = None,
        inputs: Optional[dict[str, type]] = None,
        outputs: Optional[dict[str, type]] = None,
    ) -> CodeGenEvalResult:
        """Generate and evaluate code in an isolated sandbox.

        Each call is independent, with its own sandbox, packages and execution environment.

        Args:
            prompt: The prompt to generate code from.
            schema: Optional free-form context about data formats, structures or schemas.
                Included verbatim in the LLM prompt. Use for input formats, output schemas,
                database schemas or any structural context the LLM needs to generate code.
            constraints: Optional list of constraints or requirements.
            samples: Optional dict of sample data. Each value is sampled and included in
                the LLM prompt for context, and converted to a File input for the
                sandbox. Values are used as defaults at runtime — override them
                when calling ``result.run()`` or ``result.as_task()``.
                Supported types: File, pd.DataFrame.
            inputs: Optional dict declaring non-sample CLI argument types
                (e.g., ``{"threshold": float, "mode": str}``).
                Sample entries are automatically added as File inputs — don't redeclare them here.
                Supported types: str, int, float, bool, File.
            outputs: Optional dict defining output types (e.g., ``{"result": str, "report": File}``).
                Supported types: str, int, float, bool, datetime, timedelta, File.

        Returns:
            CodeGenEvalResult with solution and execution details.
        """
        language = "python"

        # Input validation
        if inputs:
            supported_input_types = (str, int, float, bool, File)
            for input_key, input_type in inputs.items():
                if input_type not in supported_input_types:
                    supported_names = [t.__name__ for t in supported_input_types]
                    raise ValueError(
                        f"Unsupported input type for '{input_key}': {input_type}. "
                        f"Sandbox only supports: {', '.join(supported_names)}"
                    )

        # Data processing
        sample_files = None
        extracted_data_context = None
        data_schemas = {}
        schema_input_tokens = 0
        schema_output_tokens = 0

        if samples:
            logger.info(f"Processing {len(samples)} sample inputs...")
            inferred_types = {}
            sample_files = {}

            for data_key, value in samples.items():
                if isinstance(value, File):
                    inferred_types[data_key] = File
                    sample_files[data_key] = value
                elif is_dataframe(value):
                    temp_file = Path(tempfile.gettempdir()) / f"{data_key}.csv"
                    value.to_csv(temp_file, index=False)
                    file_obj = await File.from_local(str(temp_file))
                    inferred_types[data_key] = File
                    sample_files[data_key] = file_obj
                else:
                    raise ValueError(
                        f"Unsupported sample type for '{data_key}': {type(value)}. Supported: File, pd.DataFrame."
                    )

            logger.info("Extracting data context (schema, stats, patterns) and inferring Pandera schemas...")
            (
                extracted_data_context,
                data_schemas,
                schema_input_tokens,
                schema_output_tokens,
            ) = await extract_data_context(
                samples,
                self.max_sample_rows,
                constraints=constraints,
                model=self.model,
                litellm_params=self.litellm_params,
            )
            if data_schemas:
                logger.info(f"Inferred Pandera schemas for: {list(data_schemas.keys())}")

            if not inputs:
                inputs = inferred_types
                logger.info(f"Inferred input types: {inputs}")
            else:
                # Merge data-inferred types into user-provided inputs
                for key, typ in inferred_types.items():
                    if key not in inputs:
                        inputs[key] = typ
                logger.info(f"Merged input types: {inputs}")

        schemas_as_code = data_schemas or {}
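
        # For example (hypothetical call): generate(prompt=..., samples={"sales": df})
        # writes the DataFrame to a temporary CSV, uploads it via File.from_local,
        # and registers "sales" as a File input unless the caller declared it.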

        # Output validation
        if outputs:
            supported_types = (
                str,
                int,
                float,
                bool,
                datetime.datetime,
                datetime.timedelta,
                File,
            )
            for output_key, output_type in outputs.items():
                if output_type not in supported_types:
                    supported_names = [t.__name__ for t in supported_types]
                    raise ValueError(
                        f"Unsupported output type for '{output_key}': {output_type}. "
                        f"Sandbox only supports: {', '.join(supported_names)}"
                    )
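
        # Note the asymmetry above: datetime and timedelta are valid output
        # types, but inputs are limited to str, int, float, bool and File.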

        # Agent SDK routing
        if self.backend == "claude":
            if self.skip_tests:
                logger.warning(
                    "skip_tests is not supported with Agent SDK mode. The agent autonomously decides when to test."
                )
            logger.info("Using Claude Agent SDK approach")
            return await code_gen_eval_agent(
                name=self.name,
                model=self.model,
                prompt=prompt,
                schema=schema,
                constraints=constraints,
                inputs=inputs,
                outputs=outputs,
                original_samples=sample_files,
                data_context=extracted_data_context,
                generated_schemas=schemas_as_code or None,
                base_packages=self.base_packages,
                resources=self.resources,
                image_config=self.image_config,
                retries=self.sandbox_retries,
                timeout=self.timeout,
                env_vars=self.env_vars,
                secrets=self.secrets,
                cache=self.cache,
                max_turns=self.agent_max_turns,
            )

        logger.info(
            f"Starting code generation: language={language}, model={self.model}, max_iterations={self.max_iterations}"
        )

        # LiteLLM setup
        if self.api_key:
            litellm.api_key = os.getenv(self.api_key)
        if self.api_base:
            litellm.api_base = self.api_base

        # Build prompts
        base_prompt_text = self.system_prompt or DEFAULT_SYSTEM_PROMPT
        final_system_prompt = f"{base_prompt_text}\n{STRUCTURED_OUTPUT_REQUIREMENTS}"
        enhanced_prompt = build_enhanced_prompt(
            prompt,
            language,
            schema,
            constraints,
            extracted_data_context,
            inputs,
            outputs,
        )
        base_messages = [
            {"role": "system", "content": final_system_prompt},
            {"role": "user", "content": enhanced_prompt},
        ]

        # Generate plan
        logger.info("Generating plan...")
        plan, in_tok, out_tok = await generate_plan(
            self.model,
            prompt,
            language,
            schema,
            constraints,
            extracted_data_context,
            inputs,
            outputs,
            self.litellm_params,
        )
        logger.info(f"Plan created: {plan.description}")
        logger.info(f"Approach: {plan.approach}")

        # Prepare base packages
        base_pkgs = list(self.base_packages or [])
        if not self.skip_tests:
            test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
            for pkg in test_framework_info["packages"]:
                if pkg not in base_pkgs:
                    base_pkgs.append(pkg)

        # Run iteration loop
        session = _CodeGenSession(
            agent=self,
            language=language,
            prompt=prompt,
            schema=schema,
            constraints=constraints,
            inputs=inputs,
            outputs=outputs,
            base_packages=base_pkgs,
            extracted_data_context=extracted_data_context,
            sample_files=sample_files,
            schemas_as_code=schemas_as_code,
            base_messages=base_messages,
            plan=plan,
            skip_tests=self.skip_tests,
            initial_input_tokens=(schema_input_tokens + in_tok) if samples else in_tok,
            initial_output_tokens=((schema_output_tokens + out_tok) if samples else out_tok),
        )
        return await session.run()
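
    # Note: schema-extraction tokens are folded into the session's initial
    # counters only when samples were provided, so the totals reported on the
    # result cover every LLM call made for this generate() invocation.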

class _CodeGenSession:
    """Internal: manages mutable state for a single LiteLLM code generation run.

    Encapsulates the retry loop, code/test generation, image building,
    error diagnosis and reclassification logic.
    """

    def __init__(
        self,
        *,
        agent: AutoCoderAgent,
        language: str,
        prompt: str,
        schema: Optional[str],
        constraints: Optional[list[str]],
        inputs: Optional[dict[str, type]],
        outputs: Optional[dict[str, type]],
        base_packages: list[str],
        extracted_data_context: Optional[str],
        sample_files: Optional[dict[str, File]],
        schemas_as_code: dict[str, str],
        base_messages: list[dict[str, str]],
        plan: CodePlan,
        skip_tests: bool,
        initial_input_tokens: int,
        initial_output_tokens: int,
    ):
        # Agent reference (immutable config)
        self.agent = agent
        self.name = agent.name
        self.model = agent.model
        self.max_iterations = 1 if skip_tests else agent.max_iterations
        self.skip_tests = skip_tests
        self.resources = agent.resources
        self.sandbox_retries = agent.sandbox_retries
        self.timeout = agent.timeout
        self.env_vars = agent.env_vars
        self.secrets = agent.secrets
        self.cache = agent.cache
        self.image_config = agent.image_config
        self.litellm_params = agent.litellm_params

        # Per-call config (immutable)
        self.language = language
        self.prompt = prompt
        self.schema = schema
        self.constraints = constraints
        self.inputs = inputs
        self.outputs = outputs
        self.extracted_data_context = extracted_data_context
        self.sample_files = sample_files
        self.schemas_as_code = schemas_as_code
        self.base_messages = base_messages
        self.plan = plan

        # Package state
        self.base_pkgs = base_packages
        self.detected_packages: list[str] = []
        self.detected_system_packages: list[str] = []
        self.additional_commands: list[str] = []
        self.previously_installed_packages: list[str] = []
        self.previously_installed_system_packages: list[str] = []

        # Image state
        self.current_image: Optional[str] = None
        self.image_name = self._compute_image_name(self.base_pkgs, [])

        # Generation state
        self.solution: Optional[CodeSolution] = None
        self.tests: Optional[str] = None
        self.needs_new_code = True
        self.needs_new_tests = True
        self.needs_rebuild = True
        self.last_packages_snapshot: tuple[set, set] = (set(), set())

        # Error tracking
        self.last_error: Optional[str] = None
        self.last_error_message: Optional[str] = None
        self.last_diagnosis = None
        self.last_result: Optional[CodeGenEvalResult] = None

        # Reclassification tracking
        self.logic_fix_attempts: dict[tuple, int] = {}
        self.test_fix_attempts: dict[tuple, int] = {}
        self.max_logic_attempts = 1
        self.max_test_attempts = 1

        # Token tracking
        self.total_input_tokens = initial_input_tokens
        self.total_output_tokens = initial_output_tokens

        # Add test framework system packages
        if not self.skip_tests:
            test_framework_info = TEST_FRAMEWORKS.get(language, TEST_FRAMEWORKS["python"])
            for pkg in test_framework_info.get("system_packages", []):
                if pkg not in self.detected_system_packages:
                    self.detected_system_packages.append(pkg)

    def _compute_image_name(self, packages: list[str], system_packages: list[str]) -> str:
        spec = {
            "language": self.language,
            "packages": sorted(packages),
            "system_packages": sorted(system_packages),
        }
        config_hash = hashlib.sha256(json.dumps(spec, sort_keys=True).encode()).hexdigest()[:12]
        return f"auto-coder-agent-{self.language}-{config_hash}"
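
    # The image name is content-addressed: a spec such as (hypothetically)
    # {"language": "python", "packages": ["pandas", "pytest"], "system_packages": []}
    # always hashes to the same 12-character suffix, so an unchanged package
    # set reuses the cached image instead of triggering a rebuild.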

    def _track_tokens(self, in_tok: int, out_tok: int):
        self.total_input_tokens += in_tok
        self.total_output_tokens += out_tok

    def _make_result(
        self,
        *,
        success: bool,
        test_output: str,
        exit_code: int,
        attempt: int,
        error: Optional[str] = None,
    ) -> CodeGenEvalResult:
        return CodeGenEvalResult(
            plan=self.plan,
            solution=self.solution or CodeSolution(),
            tests=self.tests,
            success=success,
            output=test_output,
            exit_code=exit_code,
            error=error,
            attempts=attempt,
            conversation_history=self.base_messages,
            detected_packages=self.detected_packages,
            detected_system_packages=self.detected_system_packages,
            image=self.current_image,
            total_input_tokens=self.total_input_tokens,
            total_output_tokens=self.total_output_tokens,
            declared_inputs=self.inputs,
            declared_outputs=self.outputs,
            data_context=self.extracted_data_context,
            original_samples=self.sample_files,
            generated_schemas=self.schemas_as_code or None,
        )

    async def run(self) -> CodeGenEvalResult:
        """Execute the full retry loop."""
        for attempt in range(1, self.max_iterations + 1):
            logger.info(
                f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n"
                f"[ITERATION] Starting attempt {attempt}/{self.max_iterations}\n"
                f"━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
            )
            try:
                result = await self._attempt(attempt)
                if result is not None:
                    return result
            except Exception as e:
                self.last_error = str(e)
                logger.error(f"Error during attempt {attempt}: {self.last_error}")
                self.last_error_message = f"An error occurred: {self.last_error}"
                self.last_result = self._make_result(
                    success=False,
                    test_output="",
                    exit_code=-1,
                    attempt=attempt,
                    error=f"Attempt {attempt} failed: {self.last_error}",
                )
                if attempt == self.max_iterations:
                    return self.last_result
                self.needs_new_code = True
                if self.tests is None:
                    self.needs_new_tests = True

        return self.last_result or self._make_result(
            success=False,
            test_output="",
            exit_code=-1,
            attempt=self.max_iterations,
            error=f"All {self.max_iterations} attempts failed: {self.last_error}",
        )
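
    # Each iteration is a small state machine: code and tests are regenerated
    # only when their needs_new_* flags are set, the image is rebuilt only when
    # the package set changed, and the attempt either returns a result or
    # returns None to continue the loop.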

    async def _attempt(self, attempt: int) -> Optional[CodeGenEvalResult]:
        """Run a single iteration. Returns a result on success, None to continue."""
        # 1. Generate code (only when needed)
        if self.needs_new_code:
            logger.info("Generating code...")
            if not await self._generate_code(attempt):
                return None  # Skip to next iteration

            # Detect and track packages
            (
                self.needs_rebuild,
                self.detected_packages,
                self.detected_system_packages,
                in_tok,
                out_tok,
            ) = await detect_and_track_packages(
                self.model,
                self.solution,
                self.base_pkgs,
                self.detected_packages,
                self.detected_system_packages,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)
            self.needs_new_code = False
            if self.tests is None:
                self.needs_new_tests = True

        # Short-circuit: skip tests if requested
        if self.skip_tests:
            logger.info("skip_tests=True: skipping test generation and execution.")
            # Still build the image (needed for as_task() and run())
            self._update_image_name_if_needed()
            if self.needs_rebuild or self.current_image is None:
                await self._build_image()
            return self._make_result(
                success=True,
                test_output="",
                exit_code=0,
                attempt=attempt,
            )

        # 2. Generate tests (only when needed)
        if self.needs_new_tests:
            logger.info("Generating tests...")
            self.tests, in_tok, out_tok = await generate_tests(
                self.model,
                self.prompt,
                self.plan,
                self.solution,
                self.constraints,
                self.schema,
                self.extracted_data_context,
                self.inputs,
                self.outputs,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)
            self.needs_new_tests = False

        # 3. Update image name if packages changed
        self._update_image_name_if_needed()

        # 4. Build/rebuild image if needed
        if self.needs_rebuild or self.current_image is None:
            await self._build_image()

        # 5. Execute tests
        logger.info("Running tests...")
        run_tests_output = await run_tests.aio(
            code=self.solution.code,
            tests=self.tests,
            image=self.current_image,
            name=self.name,
            resources=self.resources,
            retries=self.sandbox_retries,
            timeout=self.timeout,
            env_vars=self.env_vars,
            secrets=self.secrets,
            cache=self.cache,
            _attempt=attempt,
        )

        tests_passed, test_output, test_exit_code = (
            run_tests_output.tests_passed,
            run_tests_output.output,
            run_tests_output.exit_code,
        )

        # 6. Handle success
        if tests_passed:
            logger.info("Tests passed! Solution successful.")
            logger.info(f"Total tokens: input={self.total_input_tokens}, output={self.total_output_tokens}")
            self.last_diagnosis = None
            return self._make_result(
                success=True,
                test_output=test_output,
                exit_code=int(test_exit_code.strip()),
                attempt=attempt,
            )

        # 7. Handle failure
        return await self._handle_failure(test_output, test_exit_code, attempt)

    async def _generate_code(self, attempt: int) -> bool:
        """Generate code with up to 3 verification attempts. Returns True on success."""
        max_attempts = 3

        for code_attempt in range(1, max_attempts + 1):
            logger.info(f"Generating code (attempt {attempt}, code gen {code_attempt}/{max_attempts})...")

            # Build messages with progressively forceful error context
            messages = self.base_messages.copy()
            if self.last_error_message:
                if code_attempt == 1:
                    messages.append({"role": "user", "content": self.last_error_message})
                elif code_attempt == 2:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                f"{self.last_error_message}\n\n"
                                "CRITICAL: The previous code generation attempt did NOT "
                                "apply all the required fixes.\n"
                                "You MUST apply EVERY SINGLE fix listed above. "
                                "Do not skip any fix.\n"
                                "Apply each fix EXACTLY as specified "
                                "- find the old code and replace it with the new code."
                            ),
                        }
                    )
                else:
                    messages.append(
                        {
                            "role": "user",
                            "content": (
                                f"{self.last_error_message}\n\n"
                                "FINAL ATTEMPT: You have failed to apply the required fixes twice.\n"
                                "This is your last chance. Apply EVERY fix listed above WITHOUT EXCEPTION.\n"
                                "For each fix, you MUST:\n"
                                "1. Find the EXACT old code mentioned\n"
                                "2. Replace it with the EXACT new code mentioned\n"
                                "3. Do NOT change anything else\n\n"
                                "If you fail to apply all fixes this time, the entire task will fail."
                            ),
                        }
                    )

            self.solution, in_tok, out_tok = await generate_code(
                self.model,
                messages,
                self.plan,
                self.litellm_params,
                is_retry=(attempt > 1),
            )
            self._track_tokens(in_tok, out_tok)

            if self.solution.language.lower() != self.language.lower():
                logger.warning(f"Requested {self.language} but LLM generated {self.solution.language}")

            # Verify fixes are applied (if we have a diagnosis with logic failures)
            has_logic_failures = self.last_diagnosis and any(
                f.error_type == "logic" for f in self.last_diagnosis.failures
            )
            if not has_logic_failures:
                return True

            logger.info("Verifying that logic fixes were applied...")
            verification, in_tok, out_tok = await verify_logic_fixes_applied(
                self.model,
                self.last_diagnosis,
                self.solution,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)

            if verification.all_fixes_applied:
                logger.info(f"Verification passed: All fixes applied. {verification.explanation}")
                return True

            logger.warning(f"Verification failed: {verification.explanation}")
            logger.warning(f"Applied: {verification.applied_fixes}")
            logger.warning(f"Missing: {verification.missing_fixes}")

            if code_attempt < max_attempts:
                # Append verification feedback for the next attempt
                missing_msg = "\n\nVERIFICATION FAILED - The following fixes are STILL MISSING:\n"
                for i, fix in enumerate(verification.missing_fixes, 1):
                    missing_msg += f"\n{i}. {fix}"
                missing_msg += "\n\nYou successfully applied these fixes:\n"
                for fix in verification.applied_fixes:
                    missing_msg += f"- {fix}\n"
                missing_msg += (
                    "\nYou MUST now apply the MISSING fixes listed above. "
                    "Do NOT regenerate the entire solution - just apply the missing fixes to your previous code."
                )
                self.last_error_message = (self.last_error_message or "") + missing_msg
            else:
                logger.error(f"Failed to apply all fixes after {max_attempts} attempts. Proceeding anyway...")
                return True  # Proceed anyway

        logger.error("Failed to generate code with all fixes applied. Skipping this iteration.")
        return False

    def _update_image_name_if_needed(self):
        """Check if packages changed and update the image name."""
        current_snapshot = (
            set(self.detected_packages),
            set(self.detected_system_packages),
        )
        if current_snapshot == self.last_packages_snapshot:
            return

        self.needs_rebuild = True
        self.last_packages_snapshot = current_snapshot

        all_packages = self.base_pkgs + self.detected_packages
        new_name = self._compute_image_name(all_packages, self.detected_system_packages)
        if new_name != self.image_name:
            logger.info(f"Image name updated: {self.image_name} -> {new_name}")
            self.image_name = new_name
            self.current_image = None

    async def _build_image(self):
        """Build/rebuild the image with a package retry loop."""
        max_retries = 3
        for _ in range(max_retries):
            try:
                self.current_image = await build_image(
                    self.solution.language,
                    self.base_pkgs,
                    self.detected_packages,
                    self.detected_system_packages,
                    self.previously_installed_packages,
                    self.previously_installed_system_packages,
                    self.additional_commands,
                    self.image_name,
                    self.current_image,
                    self.image_config,
                )
                self.previously_installed_packages = self.detected_packages.copy()
                self.previously_installed_system_packages = self.detected_system_packages.copy()
                self.needs_rebuild = False
                return
            except InvalidPackageError as e:
                bad_package = e.package_name
                logger.warning(f"Invalid system package '{bad_package}', asking LLM for replacement...")
                if bad_package in self.detected_system_packages:
                    self.detected_system_packages.remove(bad_package)
                if bad_package in self.previously_installed_system_packages:
                    self.previously_installed_system_packages.remove(bad_package)

                # Ask the LLM for the correct package name
                solution_code = self.solution.code
                replacement, in_tok, out_tok = await suggest_replacement_package(
                    self.model,
                    bad_package,
                    e.original_error,
                    solution_code,
                    self.litellm_params,
                )
                self.total_input_tokens += in_tok
                self.total_output_tokens += out_tok

                if replacement and replacement not in self.detected_system_packages:
                    self.detected_system_packages.append(replacement)
                    logger.info(f"Replacing '{bad_package}' with '{replacement}'")

                logger.info(f"Retrying with system packages: {self.detected_system_packages}")
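
    # Recovery sketch (hypothetical package names): if the build fails with
    # InvalidPackageError for "libpq", the bad name is dropped, the LLM is
    # asked for a replacement such as "libpq-dev", and the build is retried
    # with the corrected system package list.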

    async def _handle_failure(
        self,
        test_output: str,
        test_exit_code: str,
        attempt: int,
    ) -> Optional[CodeGenEvalResult]:
        """Handle test failure: diagnose, reclassify, fix tests or code. Returns None to continue."""
        # Check if tests actually executed
        tests_executed = (
            " passed" in test_output or " failed" in test_output or "collected 0 items" not in test_output
        ) and "ERROR collecting" not in test_output

        if not tests_executed:
            logger.warning("No tests executed - test file likely has errors. Regenerating tests...")
            self.needs_new_tests = True
            self.needs_new_code = False
            return None

        # Diagnose failures
        (
            primary_error_type,
            self.detected_packages,
            self.detected_system_packages,
            self.additional_commands,
            in_tok,
            out_tok,
            diagnosis,
        ) = await diagnose_and_plan_environment_fix(
            self.model,
            self.solution,
            test_output,
            self.prompt,
            self.plan,
            self.detected_packages,
            self.detected_system_packages,
            self.additional_commands,
            self.litellm_params,
            self.tests,
            self.extracted_data_context,
            self.constraints,
            self.schema,
        )
        self._track_tokens(in_tok, out_tok)
        self.last_diagnosis = diagnosis

        # Apply environment fixes from the diagnosis
        self._apply_environment_fixes(diagnosis)

        # Reclassify repeated errors
        primary_error_type = self._reclassify_errors(diagnosis, primary_error_type)

        # Handle test errors: fix tests
        if primary_error_type == "test_error":
            return await self._handle_test_errors(diagnosis, attempt)

        # Handle logic/environment errors: build a patch message
        return self._handle_logic_env_errors(diagnosis, test_output, test_exit_code, attempt)

    def _apply_environment_fixes(self, diagnosis):
        """Extract and apply environment fixes from the diagnosis."""
        if diagnosis.needs_language_packages:
            added = [
                p
                for p in diagnosis.needs_language_packages
                if p not in self.detected_packages and p not in self.base_pkgs
            ]
            if added:
                self.detected_packages.extend(added)
                logger.info(f"Adding language packages from diagnosis: {added}")
                self.needs_rebuild = True

        if diagnosis.needs_system_packages:
            added = [p for p in diagnosis.needs_system_packages if p not in self.detected_system_packages]
            if added:
                self.detected_system_packages.extend(added)
                logger.info(f"Adding system packages from diagnosis: {added}")
                self.needs_rebuild = True

        if diagnosis.needs_additional_commands:
            logger.info(f"Adding additional commands from diagnosis: {diagnosis.needs_additional_commands}")
            self.additional_commands.extend(diagnosis.needs_additional_commands)
            self.needs_rebuild = True

    def _reclassify_errors(self, diagnosis, primary_error_type: str) -> str:
        """Reclassify repeated errors (test_error <-> logic). Returns the updated primary_error_type."""
        # test_error -> logic (test might be correct, code is wrong)
        reclassified = 0
        for failure in diagnosis.failures:
            if failure.error_type != "test_error":
                continue
            sig = failure.error_message or failure.actual_behavior
            key = (failure.test_name, sig)
            self.test_fix_attempts[key] = self.test_fix_attempts.get(key, 0) + 1

            if self.test_fix_attempts[key] > self.max_test_attempts:
                original = failure.root_cause
                failure.error_type = "logic"
                failure.root_cause = (
                    f"Test failed {self.max_test_attempts + 1} times with same error after test fixes. "
                    f"The test expectations are likely correct. The code logic could be wrong. "
                    f"Original test diagnosis was: {original}"
                )
                failure.suggested_fix = (
                    f"Fix the code logic to match test expectations. "
                    f"Test expects: {failure.expected_behavior}. "
                    f"Code produces: {failure.actual_behavior}. "
                    f"Update the code to produce the expected behavior."
                )
                self.test_fix_attempts.pop(key, None)
                reclassified += 1
                logger.warning(f"Reclassified test_error -> logic for '{failure.test_name}'")

        if reclassified:
            logger.info(f"Reclassified {reclassified} test_error(s) to logic.")

        # logic -> test_error (LLM might misdiagnose test bugs as logic bugs)
        reclassified = 0
        for failure in diagnosis.failures:
            if failure.error_type != "logic":
                continue
            sig = failure.error_message or failure.actual_behavior
            key = (failure.test_name, sig)
            self.logic_fix_attempts[key] = self.logic_fix_attempts.get(key, 0) + 1

            if self.logic_fix_attempts[key] > self.max_logic_attempts:
                original = failure.root_cause
                failure.error_type = "test_error"
                failure.root_cause = (
                    f"Test failed {self.max_logic_attempts + 1} times with same error after logic fixes. "
                    f"Likely the test itself has wrong expected values, not the code. "
                    f"Original diagnosis was: {original}"
                )
                failure.suggested_fix = (
                    f"Fix the test expectations to match actual correct behavior. "
                    f"Code produces: {failure.actual_behavior}. "
                    f"If this is correct, update the test to expect this value instead."
                )
                self.logic_fix_attempts.pop(key, None)
                reclassified += 1
                logger.warning(f"Reclassified logic -> test_error for '{failure.test_name}'")

        if reclassified:
            logger.info(f"Reclassified {reclassified} logic error(s) to test_error.")
            has_test_errors = any(f.error_type == "test_error" for f in diagnosis.failures)
            if has_test_errors:
                diagnosis.failures = [f for f in diagnosis.failures if f.error_type == "test_error"]
                primary_error_type = "test_error"
                logger.info(f"After reclassification: {len(diagnosis.failures)} test_error failure(s).")

        return primary_error_type
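
    # Illustrative scenario (hypothetical): if test_total keeps failing with
    # the same assertion message after a logic fix, the failure flips from
    # "logic" to "test_error", so the next pass rewrites the test expectation
    # instead of patching working code; the reverse flip guards against
    # misdiagnosed test bugs.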

    async def _handle_test_errors(self, diagnosis, attempt: int) -> Optional[CodeGenEvalResult]:
        """Fix failing tests with up to 3 verification attempts. Returns None to continue."""
        logger.info("Diagnosis identified bug in test code. Fixing only failed tests...")
        logger.info(f"Failed tests to fix: {[f.test_name for f in diagnosis.failures]}")

        max_attempts = 3

        for fix_attempt in range(1, max_attempts + 1):
            logger.info(f"Fixing failing tests (attempt {fix_attempt}/{max_attempts})...")

            self.tests, patches, in_tok, out_tok = await fix_failing_tests(
                self.model,
                self.tests,
                diagnosis,
                self.solution,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)

            logger.info("Verifying that test fixes were applied...")
            verification, in_tok, out_tok = await verify_test_fixes_applied(
                self.model,
                diagnosis,
                patches,
                self.litellm_params,
            )
            self._track_tokens(in_tok, out_tok)

            if verification.all_fixes_applied:
                logger.info(f"Verification passed: All test fixes applied. {verification.explanation}")
                self.needs_new_code = False
                self.last_diagnosis = None
                return None  # Continue to the next iteration with fixed tests

            logger.warning(f"Verification failed: {verification.explanation}")

            if fix_attempt < max_attempts:
                missing_msg = "\n\nVERIFICATION FAILED - The following test fixes are STILL MISSING:\n"
                for i, fix in enumerate(verification.missing_fixes, 1):
                    missing_msg += f"\n{i}. {fix}"
                missing_msg += "\n\nYou successfully applied these fixes:\n"
                for fix in verification.applied_fixes:
                    missing_msg += f"- {fix}\n"
                missing_msg += "\nYou MUST now apply the MISSING test fixes listed above."

                forceful = ""
                if fix_attempt == 2:
                    forceful = (
                        "\n\nCRITICAL: The previous test fix attempt did NOT apply all the required fixes. "
                        "You MUST apply EVERY SINGLE fix listed above. Do not skip any fix."
                    )
                elif fix_attempt >= 3:
                    forceful = (
                        "\n\nFINAL ATTEMPT: You have failed to apply the required test fixes twice. "
                        "This is your last chance. Apply EVERY fix listed above WITHOUT EXCEPTION."
                    )

                for failure in diagnosis.failures:
                    failure.suggested_fix = f"{failure.suggested_fix}\n\n{missing_msg}{forceful}"
            else:
                logger.error(f"Failed to apply all test fixes after {max_attempts} attempts. Proceeding anyway...")
                self.needs_new_code = False
                self.last_diagnosis = None
                return None

        return None

    def _handle_logic_env_errors(
        self,
        diagnosis,
        test_output: str,
        test_exit_code: str,
        attempt: int,
    ) -> Optional[CodeGenEvalResult]:
        """Handle logic and/or environment errors. Returns None to continue, or the final result at max retries."""
        failures_info = []
        logic_count = 0
        env_count = 0

        for i, failure in enumerate(diagnosis.failures, 1):
            failures_info.append(
                f"\nTest {i} [{failure.error_type}] - {failure.test_name}\n"
                f"- Expected: {failure.expected_behavior}\n"
                f"- Actual: {failure.actual_behavior}\n"
                f"- Root cause: {failure.root_cause}\n"
                f"- FIX: {failure.suggested_fix}"
            )
            if failure.error_type == "logic":
                logic_count += 1
            elif failure.error_type == "environment":
                env_count += 1

        if logic_count > 0 and env_count > 0:
            logger.info(f"Will fix {env_count} environment error(s) and patch code for {logic_count} logic error(s)")
        elif logic_count > 0:
            logger.info(f"Will patch code for {logic_count} logic error(s)")
        elif env_count > 0:
            logger.info(f"Will fix {env_count} environment error(s)")

        full_code = self.solution.code
        error_msg = (
            "Tests failed. Apply only the specific fixes below to your code.\n\n"
            "Do not regenerate from scratch. PATCH the code by applying ONLY the fixes below.\n"
            "Do NOT make any other changes - keep everything else exactly as is.\n\n"
            "CRITICAL CONSTRAINTS:\n"
            "1. /var/outputs is a PRE-EXISTING directory. NEVER delete, recreate, or modify it. "
            "NEVER use shutil.rmtree or os.makedirs on /var/outputs. Only write files into it using: "
            "open('/var/outputs/<name>', 'w').write(str(value)). Always use the literal path '/var/outputs' "
            "-- never make it configurable or store it in a variable.\n"
            "2. If a part of the code is working correctly, DO NOT change it. Only fix what's broken.\n"
            "3. Apply each fix by finding the exact code quoted and replacing it - nothing more.\n"
            "4. Do NOT regenerate the entire code. Just apply the specific patches mentioned below.\n\n"
            f"Your previous code:\n```{self.solution.language}\n{full_code}\n```\n\n" + "\n".join(failures_info)
        )

        if logic_count > 0:
            logger.info("Tests failed. Will patch code with fixes (keeping same tests)...")
            self.last_error_message = error_msg
        else:
            self.last_error_message = None

        self.last_result = self._make_result(
            success=False,
            test_output=test_output,
            exit_code=int(test_exit_code.strip()),
            attempt=attempt,
            error=error_msg,
        )

        if attempt == self.max_iterations:
            return self.last_result

        # Set flags for the next iteration
        if logic_count > 0:
            self.needs_new_code = True
            self.needs_new_tests = False
        elif env_count > 0:
            logger.info("Only environment errors - skipping code regeneration, will rebuild image with new packages")
            self.needs_new_code = False
            self.needs_new_tests = False

        return None