flyteplugins-codegen 2.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyteplugins/codegen/__init__.py +18 -0
- flyteplugins/codegen/auto_coder_agent.py +1088 -0
- flyteplugins/codegen/core/__init__.py +19 -0
- flyteplugins/codegen/core/types.py +337 -0
- flyteplugins/codegen/data/__init__.py +27 -0
- flyteplugins/codegen/data/extraction.py +281 -0
- flyteplugins/codegen/data/schema.py +270 -0
- flyteplugins/codegen/execution/__init__.py +7 -0
- flyteplugins/codegen/execution/agent.py +671 -0
- flyteplugins/codegen/execution/docker.py +206 -0
- flyteplugins/codegen/generation/__init__.py +41 -0
- flyteplugins/codegen/generation/llm.py +1269 -0
- flyteplugins/codegen/generation/prompts.py +136 -0
- flyteplugins_codegen-2.0.6.dist-info/METADATA +441 -0
- flyteplugins_codegen-2.0.6.dist-info/RECORD +17 -0
- flyteplugins_codegen-2.0.6.dist-info/WHEEL +5 -0
- flyteplugins_codegen-2.0.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,671 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import shlex
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import flyte
|
|
8
|
+
from flyte.errors import InvalidPackageError
|
|
9
|
+
from flyte.io import File
|
|
10
|
+
from flyte.sandbox import ImageConfig
|
|
11
|
+
|
|
12
|
+
from flyteplugins.codegen.core.types import CodeGenEvalResult, CodePlan, CodeSolution
|
|
13
|
+
from flyteplugins.codegen.execution.docker import build_image, run_tests
|
|
14
|
+
from flyteplugins.codegen.generation.prompts import build_enhanced_prompt
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
SAFE_PREFIXES = {
|
|
20
|
+
"ls",
|
|
21
|
+
"pwd",
|
|
22
|
+
"cat",
|
|
23
|
+
"head",
|
|
24
|
+
"tail",
|
|
25
|
+
"grep",
|
|
26
|
+
"wc",
|
|
27
|
+
"mkdir",
|
|
28
|
+
"touch",
|
|
29
|
+
"rm",
|
|
30
|
+
"mv",
|
|
31
|
+
"cp",
|
|
32
|
+
"echo",
|
|
33
|
+
"sed",
|
|
34
|
+
"awk",
|
|
35
|
+
"find",
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _classify_bash_command(cmd: str) -> str:
|
|
40
|
+
try:
|
|
41
|
+
tokens = shlex.split(cmd)
|
|
42
|
+
except Exception:
|
|
43
|
+
return "deny"
|
|
44
|
+
|
|
45
|
+
if not tokens:
|
|
46
|
+
return "deny"
|
|
47
|
+
|
|
48
|
+
prog = tokens[0]
|
|
49
|
+
|
|
50
|
+
if prog == "pytest" or "pytest" in tokens:
|
|
51
|
+
return "pytest"
|
|
52
|
+
|
|
53
|
+
# Allow safe workspace ops
|
|
54
|
+
if prog in SAFE_PREFIXES:
|
|
55
|
+
return "allow"
|
|
56
|
+
|
|
57
|
+
# Deny others
|
|
58
|
+
return "deny"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
async def code_gen_eval_agent(
    name: str,
    model: str,
    prompt: str,
    schema: Optional[str] = None,
    constraints: Optional[list[str]] = None,
    inputs: Optional[dict[str, type]] = None,
    outputs: Optional[dict[str, type]] = None,
    original_samples: Optional[dict[str, File]] = None,
    data_context: Optional[str] = None,
    generated_schemas: Optional[dict[str, str]] = None,
    base_packages: Optional[list[str]] = None,
    resources: Optional[flyte.Resources] = None,
    image_config: Optional[ImageConfig] = None,
    retries: int = 0,
    timeout: Optional[int] = None,
    env_vars: Optional[dict[str, str]] = None,
    secrets: Optional[list] = None,
    cache: str = "auto",
    max_turns: int = 50,
    language: str = "python",
) -> CodeGenEvalResult:
    """Generate single-file Python code using Claude Agent SDK.

    Runs an autonomous Claude agent that generates a single Python script,
    writes tests, builds sandbox images, and iterates until tests pass.

    Args:
        name: Unique name for this task. Used for workspace isolation, sandbox image names, etc.
        model: Claude model to use (e.g. "sonnet", "opus", "haiku").
        prompt: Task description
        schema: Optional external schema definition (e.g., target database schema)
        constraints: Optional constraints
        inputs: Optional input types
        outputs: Optional output types
        original_samples: Optional sample data files (defaults for result.run()/as_task())
        data_context: Optional extracted data context string
        generated_schemas: Optional Pandera schemas as Python code strings
        base_packages: Optional base packages to always include
        resources: Optional resources for sandbox execution
        image_config: Optional image configuration (registry, python_version, etc.)
        retries: Number of retries for sandbox execution (agent iterations)
        timeout: Timeout for sandbox execution in seconds
        env_vars: Optional environment variables to set in the sandbox
        secrets: Optional secrets to make available in the sandbox
        cache: Caching behavior for sandbox execution ("auto", "override", "disable")
        max_turns: Maximum number of agent turns before stopping (default: 50)
        language: Programming language for code generation (default: "python")

    Returns:
        CodeGenEvalResult with generated solution

    Raises:
        RuntimeError: If the agent finishes without producing solution.py.
    """
    from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient, HookMatcher

    logger.info(f"Starting Agent SDK code generation for task: {name}")

    # Deterministic per-task hash: keys both the local workspace directory and
    # every named-remote checkpoint, so retries of the same task find prior state.
    _task_hash = hashlib.sha256(name.encode()).hexdigest()[:8]
    workspace = Path(f"/tmp/codegen-{_task_hash}")
    workspace.mkdir(parents=True, exist_ok=True)  # noqa: ASYNC240

    tool_calls: list[str] = []

    # Mutable dicts for closure mutation (nested hooks below write into these).
    _exec_state = {"test_count": 0}  # incremented only when sandbox actually runs
    _sandbox_state = {
        "packages": [],
        "system_packages": [],
        "image": None,
    }

    base_pkgs = (base_packages or []) + ["pytest"]

    # Convert inputs/outputs to type-name dicts for prompt building
    inputs_for_prompt = (
        {k: t.__name__ if hasattr(t, "__name__") else str(t) for k, t in inputs.items()} if inputs else None
    )
    outputs_for_prompt = (
        {k: t.__name__ if hasattr(t, "__name__") else str(t) for k, t in outputs.items()} if outputs else None
    )

    def _build_tool_detail(tool_name: str, raw_input: dict) -> str:
        """Return a short human-readable detail string for a tool invocation."""
        if tool_name == "Bash":
            return raw_input.get("command", "")
        elif tool_name in ("Write", "Read", "Edit"):
            return raw_input.get("file_path", "")
        return ""

    def _read_package_file(filename: str) -> list[str]:
        """Read a package file (one package per line) from the output directory."""
        path = workspace / filename
        if not path.exists():
            return []
        return [line.strip() for line in path.read_text().split("\n") if line.strip()]

    async def _run_tests_in_sandbox() -> tuple[str, int]:
        """Build image from packages.txt/system_packages.txt and run tests in sandbox.

        Returns (test_output, exit_code); exit_code is -1 when the sandbox
        reported no parseable exit code.
        """
        packages = _read_package_file("packages.txt")
        system_packages = _read_package_file("system_packages.txt")

        # Read solution and tests from agent output
        solution_content = (workspace / "solution.py").read_text()
        tests_content = (workspace / "tests.py").read_text()

        detected = [p for p in packages if p not in base_pkgs]
        all_packages = base_pkgs + detected

        built_image = await build_image(
            language=language,
            base_pkgs=base_pkgs,
            detected_packages=detected,
            detected_system_packages=system_packages,
            previously_installed_packages=[],
            previously_installed_system_packages=[],
            additional_commands=[],
            image_name=f"sandbox-{name}",
            current_image=None,
            image_config=image_config,
        )

        _exec_state["test_count"] += 1
        run_tests_output = await run_tests.aio(
            code=solution_content,
            tests=tests_content,
            image=built_image,
            name=f"sandbox-{name}",
            resources=resources,
            retries=retries,
            timeout=timeout,
            env_vars=env_vars,
            secrets=secrets,
            cache=cache,
            _attempt=_exec_state["test_count"],
        )

        test_exit_code, test_output = (
            run_tests_output.exit_code,
            run_tests_output.output,
        )

        # exit_code arrives as a string from the sandbox; empty means "unknown".
        exit_code = int(test_exit_code.strip()) if test_exit_code.strip() else -1

        # Update sandbox state
        _sandbox_state["image"] = built_image
        _sandbox_state["packages"] = all_packages
        _sandbox_state["system_packages"] = system_packages

        # Checkpoint image reference so it can be restored on retry/cache-hit paths
        # where _run_tests_in_sandbox is never called (e.g. continue: False).
        try:
            img_ref_file = workspace / ".image_ref"
            img_ref_file.write_text(str(built_image))
            await File.from_local(
                str(img_ref_file),
                remote_destination=File.named_remote(f"{_task_hash}-image_ref").path,
            )
        except Exception as e:
            logger.warning(f"Failed to checkpoint image ref: {e}")

        return test_output, exit_code

    # Files checkpointed to named remote storage after each Write/Edit so a
    # retried attempt can resume instead of regenerating from scratch.
    _CHECKPOINT_FILES = (
        "solution.py",
        "tests.py",
        "packages.txt",
        "system_packages.txt",
    )

    async def on_user_prompt_submit(
        input_data: Optional[dict],
        tool_use_id: Optional[str],
        context: Optional[dict],
    ) -> dict:
        """Restore workspace from named-remote checkpoints before the agent starts.

        File.named_remote() produces the same deterministic remote path for a given
        name within a task execution, so uploads from attempt N are visible on attempt
        N+1. On first run nothing has been uploaded yet, so download fails silently and
        agent starts fresh. On retry, workspace is restored, agent resumes from
        the last known state instead of regenerating from scratch.
        """
        restored: list[str] = []
        for filename in _CHECKPOINT_FILES:
            # FIX: the remote key must embed the per-file name; a constant key
            # would make every checkpoint file collide on one remote path.
            remote = File.named_remote(f"{_task_hash}-{filename}")
            logger.info(f"Path: {remote.path}, exists: {await remote.exists()}")

            if await remote.exists():
                logger.info(f"Restoring {filename} from checkpoint remote storage")
                await remote.download(str(workspace / filename))
                restored.append(filename)

        if "solution.py" not in restored:
            return {}

        exit_code_val = -1
        remote_exit = File.named_remote(f"{_task_hash}-exit_code")
        if await remote_exit.exists():
            await remote_exit.download(str(workspace / "exit_code"))
            exit_code_val = int((workspace / "exit_code").read_text().strip())

        remote_img = File.named_remote(f"{_task_hash}-image_ref")
        if await remote_img.exists():
            img_ref_file = workspace / ".image_ref"
            await remote_img.download(str(img_ref_file))
            _sandbox_state["image"] = img_ref_file.read_text().strip()

        remote_result = File.named_remote(f"{_task_hash}-result")
        if await remote_result.exists():
            await remote_result.download(str(workspace / "result"))

        if exit_code_val == 0:
            logger.info("Tests already passed from prior run. Skipping agent execution and returning cached results.")
            return {
                "continue": False,
                "stopReason": "Tests already passed from prior run.",
            }

        existing = [f for f in ("solution.py", "tests.py", "packages.txt") if f in restored]
        ctx = f"Existing workspace files: {', '.join(existing)}."

        logger.info(f"Restored workspace from checkpoints: {ctx}")

        if exit_code_val != -1:
            ctx += f" Last test exit_code: {exit_code_val}."
        return {
            "hookSpecificOutput": {
                "hookEventName": "UserPromptSubmit",
                "additionalContext": ctx,
            }
        }

    async def on_pre_tool_use(
        input_data: Optional[dict[str, str | object]],
        tool_use_id: Optional[str],
        context: Optional[dict],
    ) -> dict:
        """PreToolUse: checkpoint writes to named remote; run sandbox directly."""
        if not input_data:
            return {}

        tool_name = input_data.get("tool_name", "unknown")
        raw_input = input_data.get("tool_input") or {}
        detail = _build_tool_detail(tool_name, raw_input)
        cmd = raw_input.get("command", "")
        action = _classify_bash_command(cmd) if cmd else "allow"

        tool_calls.append(f"{tool_name}: {detail}" if detail else tool_name)

        if action == "deny":
            return {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": "This command is not allowed in the sandbox.",
                }
            }

        # If a checkpoint already exists in named remote for this file (written by a
        # previous attempt's PostToolUse), restore it to workspace and deny the write.
        # PostToolUse handles checkpointing after each successful write.
        if tool_name == "Write":
            file_path = raw_input.get("file_path", "")
            filename = Path(file_path).name if file_path else ""
            if filename in _CHECKPOINT_FILES:
                # FIX: per-file remote key (was a constant placeholder).
                remote = File.named_remote(f"{_task_hash}-{filename}")
                if await remote.exists():
                    await remote.download(file_path)
                    return {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "deny",
                            "permissionDecisionReason": (
                                "File already exists from a previous attempt. Use the existing file content."
                            ),
                        }
                    }

        # Sandbox runs as a Flyte container task, cached by Flyte's task cache.
        if tool_name == "Bash" and action == "pytest":
            has_required_files = (workspace / "solution.py").exists() and (workspace / "tests.py").exists()
            if has_required_files:
                try:
                    test_output, exit_code = await _run_tests_in_sandbox()

                    logger.info(f"Sandbox test execution: exit_code={exit_code}")
                    (workspace / "result").write_text(test_output)
                    (workspace / "exit_code").write_text(str(exit_code))

                    try:
                        await File.from_local(
                            str(workspace / "exit_code"),
                            remote_destination=File.named_remote(f"{_task_hash}-exit_code").path,
                        )
                    except Exception as e:
                        logger.warning(f"Failed to checkpoint exit_code: {e}")

                    try:
                        await File.from_local(
                            str(workspace / "result"),
                            remote_destination=File.named_remote(f"{_task_hash}-result").path,
                        )
                    except Exception as e:
                        logger.warning(f"Failed to checkpoint result: {e}")

                    # Deny the raw pytest command: tests already ran in the
                    # sandbox, and the deny reason points the agent at the output.
                    return {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "deny",
                            "permissionDecisionReason": (
                                f"Tests executed in isolated sandbox (exit_code={exit_code}). "
                                f"Results written to {workspace}/result. "
                                f"Read it for the test output and ignore any warnings."
                            ),
                        }
                    }
                except InvalidPackageError as e:
                    logger.warning(f"Invalid system package: {e.package_name}")
                    return {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "deny",
                            "permissionDecisionReason": (
                                f"Image build failed: system package '{e.package_name}' does not exist "
                                f"in apt repositories. Remove it from {workspace}/system_packages.txt and try again."
                            ),
                        }
                    }
                except Exception as e:
                    logger.warning(f"Sandbox execution failed: {e}")
                    return {
                        "hookSpecificOutput": {
                            "hookEventName": "PreToolUse",
                            "permissionDecision": "deny",
                            "permissionDecisionReason": "Test execution failed in the sandbox.",
                        }
                    }

        return {}

    @flyte.trace
    async def trace_post_tool_use(tool_name: str, detail: str, file_path: str) -> dict:
        """Checkpoint Write/Edit results to named remote.

        Upload is inside the trace so sequence-based replay on retries dedups re-uploads:
        same file at the same sequence position results in cache hit, so body will be skipped.
        New writes/edits are at new positions, cache miss, body runs, and uploads.
        """
        filename = Path(file_path).name if file_path else ""
        local = workspace / filename
        if tool_name in ("Write", "Edit") and filename in _CHECKPOINT_FILES and local.exists():
            try:
                # FIX: per-file remote key (was a constant placeholder).
                remote_path = File.named_remote(f"{_task_hash}-{filename}").path
                await File.from_local(
                    str(local),
                    remote_destination=remote_path,
                )
                logger.info(f"Checkpointed {filename} to named remote {remote_path} after {tool_name} tool use.")
            except Exception as e:
                logger.warning(f"Failed to checkpoint {filename}: {e}")
        return {}

    async def on_post_tool_use(
        input_data: Optional[dict[str, str | object]],
        tool_use_id: Optional[str],
        context: Optional[dict],
    ) -> dict:
        """PostToolUse: checkpoint Write/Edit results via trace."""
        if not input_data:
            return {}

        tool_name = input_data.get("tool_name", "unknown")
        raw_input = input_data.get("tool_input") or {}
        detail = _build_tool_detail(tool_name, raw_input)
        file_path = raw_input.get("file_path", "") if tool_name in ("Write", "Edit") else ""

        await trace_post_tool_use(tool_name, detail, file_path)
        return {}

    @flyte.trace
    async def trace_stop(
        tool_calls_count: int,
        tool_calls_summary: str,
        test_execution_count: int,
    ) -> dict[str, int | str]:
        # Body intentionally empty: the trace decorator records the arguments.
        return {}

    @flyte.trace
    async def trace_post_tool_use_failure(
        tool_name: str,
        error: str,
        is_interrupt: bool,
    ) -> dict[str, str | bool]:
        # Body intentionally empty: the trace decorator records the arguments.
        return {}

    async def on_post_tool_use_failure(
        input_data: Optional[dict[str, str | bool | object]],
        tool_use_id: Optional[str],
        context: Optional[dict],
    ) -> dict:
        """PostToolUseFailure: record tool errors."""
        # FIX: guard against None input_data (signature allows it, and every
        # sibling hook guards) before dereferencing .get().
        if not input_data:
            return {}
        await trace_post_tool_use_failure(
            tool_name=str(input_data.get("tool_name", "")),
            error=str(input_data.get("error", "")),
            is_interrupt=bool(input_data.get("is_interrupt", False)),
        )
        return {}

    async def on_stop(
        input_data: Optional[dict[str, str | bool]],
        tool_use_id: Optional[str],
        context: Optional[dict],
    ) -> dict:
        """Stop: checkpoint workspace files and record a summary of the agent run."""
        await trace_stop(
            tool_calls_count=len(tool_calls),
            tool_calls_summary=", ".join(tool_calls[-20:]),  # last 20 to stay bounded
            test_execution_count=_exec_state["test_count"],
        )
        return {}

    # Build the task description from user prompt + schema + constraints + data
    base_prompt = build_enhanced_prompt(
        prompt,
        language,
        schema or None,
        constraints,
        data_context or None,
        inputs_for_prompt,
        outputs_for_prompt,
    )

    # System prompt: role + workspace rules + test workflow
    system_prompt = f"""
You are an expert {language.capitalize()} code generation agent.
Your job is to write a working {language.capitalize()} solution, comprehensive tests, and iterate until all tests pass.

There are two separate environments you must understand:

## 1. AGENT WORKSPACE (where you write files)

Path: {workspace}
This is your working directory. Write all your files here:
- {workspace}/solution.py — your complete solution code
- {workspace}/tests.py — pytest-based tests
- {workspace}/packages.txt — pip dependencies (one per line, no stdlib)
- {workspace}/system_packages.txt — apt dependencies (only if needed for native libs)

## 2. SANDBOX RUNTIME (where tests and solution.py EXECUTE)

Tests and the solution run inside an isolated sandbox. The sandbox has these paths:
- /var/inputs/solution.py — your solution code is placed here automatically.
- /var/inputs/ — READ-ONLY. Input data files are also mounted here (e.g. /var/inputs/csv_data).
- /var/outputs/ — Write output files here. PRE-CREATED. NEVER delete or recreate it.
  Write outputs like: open('/var/outputs/<name>', 'w').write(str(value))

CRITICAL: {workspace} does not exist inside the sandbox. Never reference {workspace} in solution.py or tests.py.
The solution code lives at /var/inputs/solution.py inside the sandbox.

## SOLUTION RULES
- Include all imports at the top.
- Must be a runnable script with an `if __name__ == '__main__':` block.
- Use argparse with optional (--prefixed) arguments for all inputs.
  Example: parser.add_argument('--csv_data', required=True). Do not use positional arguments.
- Read input files from paths passed via argparse (at runtime these will be /var/inputs/<name>).
- Write all outputs to /var/outputs/<name>. Always use the literal path '/var/outputs'.
- Do not run solution.py directly. Only validate via pytest.

## TEST RULES
- Use pytest. Import from solution module: `from solution import ...`
- Test the full execution path end-to-end.
- If you need to run solution.py as a subprocess, use: `{language} /var/inputs/solution.py --arg value`
- Tests run in the sandbox: create test input files under /var/inputs/, run the solution, then verify /var/outputs/.

## WORKFLOW
1. Write {workspace}/solution.py (code references /var/inputs and /var/outputs)
2. Write {workspace}/tests.py (tests verify /var/outputs after running solution)
3. Write {workspace}/packages.txt (and system_packages.txt if needed)
4. Run: pytest {workspace}/tests.py -v --tb=short
   Tests run in an isolated sandbox with your packages installed.
   If the command is denied, read {workspace}/result and {workspace}/exit_code for output.
5. If tests fail: read {workspace}/result, fix code/packages, re-run. Repeat until exit code is 0."""

    # User query: the actual task + data context
    user_query = f"""Generate a {language.capitalize()} solution for the following task:

{base_prompt}"""

    user_query += "\n\nStart by creating the solution code, then tests, then packages.txt, then run the tests."

    logger.info("Running Agent SDK...")

    async def restrict_to_workspace(
        tool_name: str,
        input_data: dict,
        context: object,
    ) -> object:
        """Deny file operations outside the workspace directory."""
        from claude_agent_sdk.types import PermissionResultAllow, PermissionResultDeny

        if tool_name in ("Write", "Edit", "Read"):
            file_path = input_data.get("file_path", "")
            if file_path:
                try:
                    # relative_to raises ValueError when file_path escapes workspace.
                    Path(file_path).resolve().relative_to(workspace.resolve())  # noqa: ASYNC240
                except ValueError:
                    return PermissionResultDeny(
                        message=(
                            f"Access outside workspace is not allowed: {file_path}. "
                            f"Only paths under {workspace} are permitted."
                        )
                    )
        return PermissionResultAllow(updated_input=input_data)

    try:
        options = ClaudeAgentOptions(
            model=model,
            system_prompt=system_prompt,
            allowed_tools=["Bash", "Read", "Write", "Edit"],
            cwd=str(workspace),
            permission_mode="acceptEdits",
            max_turns=max_turns,
            can_use_tool=restrict_to_workspace,
            hooks={
                "UserPromptSubmit": [HookMatcher(hooks=[on_user_prompt_submit])],
                "PreToolUse": [HookMatcher(hooks=[on_pre_tool_use])],
                "PostToolUse": [HookMatcher(hooks=[on_post_tool_use])],
                "PostToolUseFailure": [HookMatcher(hooks=[on_post_tool_use_failure])],
                "Stop": [HookMatcher(hooks=[on_stop])],
            },
        )

        async with ClaudeSDKClient(options=options) as client:
            await client.query(user_query)

            async for message in client.receive_response():
                # Log agent messages for debugging
                if hasattr(message, "type"):
                    logger.debug(f"Agent message: type={message.type}")
                if hasattr(message, "content"):
                    content = str(message.content)
                    if len(content) > 200:
                        content = content[:200] + "..."
                    logger.debug(f"Agent content: {content}")

        # Log what files the agent created
        existing_files = list(workspace.iterdir()) if workspace.exists() else []  # noqa: ASYNC240
        logger.info(f"Agent output directory contents: {[f.name for f in existing_files]}")

        # Read outputs
        solution_file = workspace / "solution.py"
        tests_file = workspace / "tests.py"
        result_file = workspace / "result"
        exit_code_file = workspace / "exit_code"

        # Agent must create solution.py at minimum
        if not solution_file.exists():
            raise RuntimeError(
                f"Agent did not create solution.py. "
                f"Files in output dir: {[f.name for f in existing_files]}. "
                f"Tool calls made: {len(tool_calls)} ({', '.join(tool_calls[:10])})"
            )

        # Read agent-created files (with defaults for optional ones)
        solution_content = solution_file.read_text()
        tests = tests_file.read_text() if tests_file.exists() else ""
        detected_packages = _read_package_file("packages.txt")
        detected_system_packages = _read_package_file("system_packages.txt")

        # Test-execution files are created by _run_tests_in_sandbox, not the agent
        test_output = result_file.read_text() if result_file.exists() else ""
        exit_code_text = exit_code_file.read_text().strip() if exit_code_file.exists() else ""
        exit_code = int(exit_code_text) if exit_code_text else -1

        success = exit_code == 0

        logger.info(f"Agent SDK completed: success={success}, exit_code={exit_code}")
        logger.info(f"Tool calls: {len(tool_calls)} total")

        plan = CodePlan(
            description="Agent SDK autonomous generation",
            approach="Agent explored, generated, tested, and fixed autonomously",
        )

        return CodeGenEvalResult(
            plan=plan,
            solution=CodeSolution(
                language=language,
                code=solution_content,
                system_packages=detected_system_packages,
            ),
            tests=tests,
            success=success,
            output=test_output,
            exit_code=exit_code,
            error=test_output if not success else None,
            attempts=1,
            conversation_history=[],
            detected_packages=detected_packages,
            detected_system_packages=detected_system_packages,
            image=_sandbox_state["image"] or None,
            total_input_tokens=0,
            total_output_tokens=0,
            declared_inputs=inputs,
            declared_outputs=outputs,
            data_context=data_context,
            original_samples=original_samples,
            generated_schemas=generated_schemas,
        )

    except Exception as e:
        logger.error(f"Agent SDK generation failed: {e}")
        raise
|