latch-eval-tools 0.1.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- latch_eval_tools/__init__.py +64 -0
- latch_eval_tools/answer_extraction.py +35 -0
- latch_eval_tools/cli/__init__.py +0 -0
- latch_eval_tools/cli/eval_lint.py +185 -0
- latch_eval_tools/eval_server.py +570 -0
- latch_eval_tools/faas_utils.py +13 -0
- latch_eval_tools/graders/__init__.py +40 -0
- latch_eval_tools/graders/base.py +29 -0
- latch_eval_tools/graders/distribution.py +102 -0
- latch_eval_tools/graders/label_set.py +75 -0
- latch_eval_tools/graders/marker_gene.py +317 -0
- latch_eval_tools/graders/multiple_choice.py +38 -0
- latch_eval_tools/graders/numeric.py +137 -0
- latch_eval_tools/graders/spatial.py +93 -0
- latch_eval_tools/harness/__init__.py +27 -0
- latch_eval_tools/harness/claudecode.py +212 -0
- latch_eval_tools/harness/minisweagent.py +265 -0
- latch_eval_tools/harness/plotsagent.py +156 -0
- latch_eval_tools/harness/runner.py +191 -0
- latch_eval_tools/harness/utils.py +191 -0
- latch_eval_tools/headless_eval_server.py +727 -0
- latch_eval_tools/linter/__init__.py +25 -0
- latch_eval_tools/linter/explanations.py +331 -0
- latch_eval_tools/linter/runner.py +146 -0
- latch_eval_tools/linter/schema.py +126 -0
- latch_eval_tools/linter/validators.py +595 -0
- latch_eval_tools/types.py +30 -0
- latch_eval_tools/wrapper_entrypoint.py +316 -0
- latch_eval_tools-0.1.0.dist-info/METADATA +118 -0
- latch_eval_tools-0.1.0.dist-info/RECORD +33 -0
- latch_eval_tools-0.1.0.dist-info/WHEEL +4 -0
- latch_eval_tools-0.1.0.dist-info/entry_points.txt +2 -0
- latch_eval_tools-0.1.0.dist-info/licenses/LICENSE +1 -0
latch_eval_tools/graders/spatial.py
@@ -0,0 +1,93 @@
from .base import BinaryGrader, GraderResult


class SpatialAdjacencyGrader(BinaryGrader):
    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
        scoring = config.get("scoring", {})
        thresholds = scoring.get("pass_thresholds", {})

        max_median_ic_to_pc = thresholds.get("max_median_ic_to_pc_um", 25.0)
        max_p90_ic_to_pc = thresholds.get("max_p90_ic_to_pc_um", 80.0)
        min_pct_within_15um = thresholds.get("min_pct_ic_within_15um", 60.0)
        min_pct_mixed_within_55um = thresholds.get("min_pct_ic_mixed_within_55um", 60.0)

        required_fields = [
            "median_ic_to_pc_um",
            "p90_ic_to_pc_um",
            "pct_ic_within_15um",
            "pct_ic_mixed_within_55um",
            "adjacency_pass"
        ]

        missing = [f for f in required_fields if f not in agent_answer]
        if missing:
            return GraderResult(
                passed=False,
                metrics={},
                reasoning=f"Agent answer missing required fields: {missing}",
                agent_answer=agent_answer
            )

        median_ic_to_pc = agent_answer["median_ic_to_pc_um"]
        p90_ic_to_pc = agent_answer["p90_ic_to_pc_um"]
        pct_within_15um = agent_answer["pct_ic_within_15um"]
        pct_mixed_within_55um = agent_answer["pct_ic_mixed_within_55um"]
        adjacency_pass = agent_answer["adjacency_pass"]

        median_pass = median_ic_to_pc <= max_median_ic_to_pc
        p90_pass = p90_ic_to_pc <= max_p90_ic_to_pc
        within_15um_pass = pct_within_15um >= min_pct_within_15um
        mixed_55um_pass = pct_mixed_within_55um >= min_pct_mixed_within_55um

        passed = median_pass and p90_pass and within_15um_pass and mixed_55um_pass and adjacency_pass

        metrics = {
            "median_ic_to_pc_um": median_ic_to_pc,
            "p90_ic_to_pc_um": p90_ic_to_pc,
            "pct_ic_within_15um": pct_within_15um,
            "pct_ic_mixed_within_55um": pct_mixed_within_55um,
            "adjacency_pass": adjacency_pass,
            "max_median_threshold": max_median_ic_to_pc,
            "max_p90_threshold": max_p90_ic_to_pc,
            "min_pct_15um_threshold": min_pct_within_15um,
            "min_pct_55um_threshold": min_pct_mixed_within_55um,
            "median_pass": median_pass,
            "p90_pass": p90_pass,
            "within_15um_pass": within_15um_pass,
            "mixed_55um_pass": mixed_55um_pass,
        }

        lines = [
            f"Spatial Adjacency Analysis: {'PASS' if passed else 'FAIL'}",
            "",
            "IC->PC Distance Metrics:",
            f" {'+' if median_pass else 'x'} Median distance: {median_ic_to_pc:.2f} um (threshold: <={max_median_ic_to_pc:.2f} um)",
            f" {'+' if p90_pass else 'x'} 90th percentile: {p90_ic_to_pc:.2f} um (threshold: <={max_p90_ic_to_pc:.2f} um)",
            "",
            "IC Proximity to PC:",
            f" {'+' if within_15um_pass else 'x'} IC within 15 um: {pct_within_15um:.1f}% (threshold: >={min_pct_within_15um:.1f}%)",
            f" {'+' if mixed_55um_pass else 'x'} IC with PC within 55 um: {pct_mixed_within_55um:.1f}% (threshold: >={min_pct_mixed_within_55um:.1f}%)",
            "",
            f"Agent adjacency assessment: {'+' if adjacency_pass else 'x'} {adjacency_pass}",
        ]

        if not passed:
            failures = []
            if not median_pass:
                failures.append(f"Median {median_ic_to_pc:.2f} > {max_median_ic_to_pc:.2f} um")
            if not p90_pass:
                failures.append(f"P90 {p90_ic_to_pc:.2f} > {max_p90_ic_to_pc:.2f} um")
            if not within_15um_pass:
                failures.append(f"Within 15 um {pct_within_15um:.1f}% < {min_pct_within_15um:.1f}%")
            if not mixed_55um_pass:
                failures.append(f"Within 55 um {pct_mixed_within_55um:.1f}% < {min_pct_mixed_within_55um:.1f}%")
            if not adjacency_pass:
                failures.append("Agent marked adjacency_pass as false")
            lines.append(f"\nFailure: {'; '.join(failures)}")

        return GraderResult(
            passed=passed,
            metrics=metrics,
            reasoning="\n".join(lines),
            agent_answer=agent_answer
        )
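Because the grader only inspects the answer and config dictionaries, it can be smoke-tested directly. A minimal sketch with made-up metric values, assuming BinaryGrader can be instantiated without arguments and that GraderResult exposes its fields as attributes (neither is shown in this hunk):

from latch_eval_tools.graders.spatial import SpatialAdjacencyGrader

# Hypothetical agent output; the keys mirror required_fields above.
agent_answer = {
    "median_ic_to_pc_um": 18.4,
    "p90_ic_to_pc_um": 62.0,
    "pct_ic_within_15um": 71.5,
    "pct_ic_mixed_within_55um": 83.0,
    "adjacency_pass": True,
}
# Thresholds omitted from the config fall back to the defaults above (25.0, 80.0, 60.0, 60.0).
config = {"scoring": {"pass_thresholds": {"max_median_ic_to_pc_um": 20.0}}}

result = SpatialAdjacencyGrader().evaluate_answer(agent_answer, config)
print(result.passed)     # True: 18.4 <= 20.0 and every other check passes
print(result.reasoning)  # multi-line PASS/FAIL report assembled in `lines`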
latch_eval_tools/harness/__init__.py
@@ -0,0 +1,27 @@
from latch_eval_tools.harness.runner import EvalRunner
from latch_eval_tools.harness.utils import (
    get_project_root,
    get_cache_dir,
    download_single_dataset,
    download_data,
    batch_download_datasets,
    setup_workspace,
    cleanup_workspace,
)
from latch_eval_tools.harness.minisweagent import run_minisweagent_task
from latch_eval_tools.harness.claudecode import run_claudecode_task
from latch_eval_tools.harness.plotsagent import run_plotsagent_task

__all__ = [
    "EvalRunner",
    "get_project_root",
    "get_cache_dir",
    "download_single_dataset",
    "download_data",
    "batch_download_datasets",
    "setup_workspace",
    "cleanup_workspace",
    "run_minisweagent_task",
    "run_claudecode_task",
    "run_plotsagent_task",
]
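Everything in __all__ is re-exported, so callers can import the harness entry points from the package namespace rather than the per-module paths. Illustrative only; the utility signatures are not shown in this diff:

from latch_eval_tools.harness import (
    EvalRunner,
    setup_workspace,
    run_claudecode_task,
    cleanup_workspace,
)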
latch_eval_tools/harness/claudecode.py
@@ -0,0 +1,212 @@
import json
import os
import re
import subprocess
import time
from pathlib import Path

EVAL_TIMEOUT = 600

MODEL_MAP = {
    "anthropic/claude-opus-4-5": "opus",
    "anthropic/claude-sonnet-4-5": "sonnet",
    "anthropic/claude-sonnet-4": "claude-sonnet-4-20250514",
    "anthropic/claude-opus-4": "claude-opus-4-20250514",
    "anthropic/claude-haiku-3-5": "haiku",
}


def run_claudecode_task(
    task_prompt: str,
    work_dir: Path,
    model_name: str | None = None,
    eval_timeout: int = EVAL_TIMEOUT,
) -> dict:
    """Run Claude Code agent on a task.

    Args:
        task_prompt: Task description for the agent
        work_dir: Working directory for the agent
        model_name: Optional model name (e.g., "anthropic/claude-sonnet-4")
        eval_timeout: Timeout for entire evaluation (seconds)

    Returns:
        dict with keys "answer" (parsed JSON or None) and "metadata"
    """
    agent_log_file = work_dir / "agent_output.log"
    if agent_log_file.exists():
        agent_log_file.unlink()

    enhanced_prompt = _enhance_prompt_with_local_files(task_prompt, work_dir)
    enhanced_prompt += """

CRITICAL: You must write eval_answer.json BEFORE signaling completion.
Correct order: 1) Perform analysis 2) Write eval_answer.json with your answer 3) Exit"""

    cmd = ["claude", "--print", "--dangerously-skip-permissions", "--verbose", "--output-format", "stream-json"]

    if model_name:
        claude_model = MODEL_MAP.get(model_name, model_name)
        cmd.extend(["--model", claude_model])

    run_as_claude_user = os.geteuid() == 0
    if run_as_claude_user:
        import pwd
        import shutil
        import stat
        try:
            pwd.getpwnam("claude")
            home_dir = Path.home()
            current_mode = home_dir.stat().st_mode
            home_dir.chmod(current_mode | stat.S_IXOTH)
            eval_cache_dir = home_dir / ".eval_cache"
            if eval_cache_dir.exists():
                shutil.chown(eval_cache_dir, user="claude", group="claude")
                for item in eval_cache_dir.rglob("*"):
                    try:
                        shutil.chown(item, user="claude", group="claude")
                    except PermissionError:
                        pass
        except KeyError:
            run_as_claude_user = False

    env = os.environ.copy()

    start_time = time.time()
    timed_out = False
    claude_result = None
    trajectory = []

    try:
        if run_as_claude_user:
            env_vars = [f"{k}={v}" for k, v in env.items() if k.endswith("_API_KEY")]
            cmd = ["runuser", "-u", "claude", "--", "env"] + env_vars + cmd

        process = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            cwd=str(work_dir),
            env=env,
            text=True,
        )

        try:
            stdout, stderr = process.communicate(
                input=enhanced_prompt,
                timeout=eval_timeout
            )

            with open(agent_log_file, 'w') as log_file:
                log_file.write(stdout)
                if stderr:
                    log_file.write(f"\n\nSTDERR:\n{stderr}")

            for line in stdout.strip().split('\n'):
                if line:
                    try:
                        event = json.loads(line)
                        trajectory.append(event)
                        if event.get("type") == "result":
                            claude_result = event
                    except json.JSONDecodeError:
                        pass

        except subprocess.TimeoutExpired:
            timed_out = True
            process.kill()
            stdout, stderr = process.communicate()
            with open(agent_log_file, 'w') as log_file:
                log_file.write(stdout)
                log_file.write(f"\n\nAgent timed out after {eval_timeout} seconds")

    except Exception as e:
        with open(agent_log_file, 'a') as f:
            f.write(f"\nError running claude: {e}")

    duration = time.time() - start_time
    print(f"Agent output saved to: {agent_log_file}")

    if trajectory:
        trajectory_file = work_dir / "trajectory.json"
        trajectory_file.write_text(json.dumps(trajectory, indent=2))
        print(f"Trajectory saved to: {trajectory_file}")

    eval_answer_file = work_dir / "eval_answer.json"
    agent_answer = None
    error_details = None

    if not eval_answer_file.exists():
        log_tail = ""
        if agent_log_file.exists():
            log_content = agent_log_file.read_text()
            log_tail = log_content[-1000:]

        error_msg = "Agent timed out" if timed_out else "Agent did not create eval_answer.json"
        error_details = {
            "error": error_msg,
            "timed_out": timed_out,
            "log_tail": log_tail
        }
        print(f"\nWarning: {error_msg}")
    else:
        try:
            agent_answer = json.loads(eval_answer_file.read_text())
        except json.JSONDecodeError as e:
            error_details = {
                "error": f"Failed to parse eval_answer.json: {e}",
                "file_contents": eval_answer_file.read_text()[:500]
            }
            print(f"\nWarning: Failed to parse eval_answer.json: {e}")

    metadata = {
        "duration_s": round(duration, 2),
        "model": model_name,
    }
    if claude_result:
        metadata["total_cost"] = claude_result.get("total_cost_usd")
        metadata["n_turns"] = claude_result.get("num_turns")
        metadata["session_id"] = claude_result.get("session_id")
        metadata["usage"] = claude_result.get("usage")
    if timed_out:
        metadata["timed_out"] = True
        metadata["eval_timeout_seconds"] = eval_timeout
    if error_details:
        metadata["error_details"] = error_details

    return {"answer": agent_answer, "metadata": metadata}


def _enhance_prompt_with_local_files(task_prompt: str, work_dir: Path) -> str:
    """Extract <ContextualNodeData> and add local file list to prompt."""
    contextual_data_match = re.search(
        r'<ContextualNodeData>(.*?)</ContextualNodeData>',
        task_prompt,
        re.DOTALL
    )

    if not contextual_data_match:
        return task_prompt

    try:
        contextual_data = json.loads(contextual_data_match.group(1))
    except json.JSONDecodeError:
        return task_prompt

    local_files = []
    for item in contextual_data:
        if 'local_path' in item:
            local_files.append(item['local_path'])

    if not local_files:
        return task_prompt

    file_list = "\n".join([f"- {f}" for f in local_files])
    enhancement = f"\n\nThe following data files are available in your current working directory:\n{file_list}\n\nUse these local filenames to access the data.\n"

    parts = task_prompt.split('<ContextualNodeData>')
    if len(parts) == 2:
        return parts[0] + enhancement + '<ContextualNodeData>' + parts[1]

    return task_prompt
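run_claudecode_task shells out to the claude CLI with stream-json output, mirrors the stream into agent_output.log, and treats eval_answer.json in work_dir as the agent's answer. A usage sketch under stated assumptions: the claude CLI is installed and authenticated, and the prompt and directory below are placeholders.

from pathlib import Path
from latch_eval_tools.harness.claudecode import run_claudecode_task

work_dir = Path("/tmp/eval_run")  # hypothetical scratch directory
work_dir.mkdir(parents=True, exist_ok=True)

result = run_claudecode_task(
    task_prompt="Count the rows in data.csv and write {\"n_rows\": <int>} to eval_answer.json.",
    work_dir=work_dir,
    model_name="anthropic/claude-sonnet-4",  # mapped to a CLI model id via MODEL_MAP
    eval_timeout=600,
)

print(result["answer"])                       # parsed eval_answer.json, or None
print(result["metadata"].get("total_cost"))   # present only when a "result" event was seen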
latch_eval_tools/harness/minisweagent.py
@@ -0,0 +1,265 @@
import io
import json
import os
import re
import signal
import sys
from pathlib import Path

OPERATION_TIMEOUT = 300
EVAL_TIMEOUT = 600


class AgentTimeoutError(Exception):
    pass


def _timeout_handler(signum, frame):
    raise AgentTimeoutError("Agent exceeded time limit")


class StreamingLogFile:
    def __init__(self, file_path):
        self.file_path = file_path
        self.buffer = io.StringIO()

    def write(self, data):
        self.buffer.write(data)
        with open(self.file_path, 'a') as f:
            f.write(data)
            f.flush()

    def flush(self):
        pass

    def getvalue(self):
        return self.buffer.getvalue()


def _patch_agent_for_progress(log_file, agent_class):
    original_add_message = agent_class.add_message

    def patched_add_message(self, role, content, **kwargs):
        original_add_message(self, role, content, **kwargs)

        with open(log_file, 'a') as f:
            if role == "assistant":
                step_num = len([m for m in self.messages if m.get("role") == "assistant"])
                f.write(f"\n[Step {step_num}]\n")
                f.write(f"Assistant: {content}\n")
            elif role == "user" and len(self.messages) > 2:
                f.write(f"Observation: {content}\n")
            f.flush()

    agent_class.add_message = patched_add_message


def run_minisweagent_task(
    task_prompt: str,
    work_dir: Path,
    model_name: str | None = None,
    agent_config: dict | None = None,
    operation_timeout: int = OPERATION_TIMEOUT,
    eval_timeout: int = EVAL_TIMEOUT,
) -> dict:
    """Run MiniSWE agent on a task.

    Args:
        task_prompt: Task description for the agent
        work_dir: Working directory for the agent
        model_name: Optional model name (e.g., "anthropic/claude-sonnet-4")
        agent_config: Optional agent configuration dict
        operation_timeout: Timeout for individual operations (seconds)
        eval_timeout: Timeout for entire evaluation (seconds)

    Returns:
        dict with keys "answer" (parsed JSON or None) and "metadata"
    """
    from minisweagent.agents.default import DefaultAgent, AgentConfig, FormatError
    from minisweagent.environments.local import LocalEnvironment
    from minisweagent.models import get_model
    import re

    class FlexibleAgent(DefaultAgent):
        def parse_action(self, response: dict) -> dict:
            content = response["content"]
            actions = re.findall(r"```(?:bash|sh|shell)?\s*\n(.*?)\n?```", content, re.DOTALL)
            if len(actions) == 1:
                return {"action": actions[0].strip(), **response}
            raise FormatError(self.render_template(self.config.format_error_template, actions=actions))

        def has_finished(self, output: dict):
            from minisweagent.agents.default import Submitted
            full_output = output.get("output", "")
            for marker in ["COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT", "MINI_SWE_AGENT_FINAL_OUTPUT"]:
                if marker in full_output:
                    idx = full_output.find(marker)
                    rest = full_output[idx + len(marker):].strip()
                    raise Submitted(rest)

    original_dir = os.getcwd()

    agent_log_file = work_dir / "agent_output.log"
    _patch_agent_for_progress(agent_log_file, FlexibleAgent)
    if agent_log_file.exists():
        agent_log_file.unlink()

    captured_output = StreamingLogFile(agent_log_file)
    original_stdout = sys.stdout
    original_stderr = sys.stderr

    class TeeOutput:
        def __init__(self, *streams):
            self.streams = streams

        def write(self, data):
            for stream in self.streams:
                stream.write(data)
                if hasattr(stream, 'flush'):
                    stream.flush()

        def flush(self):
            for stream in self.streams:
                if hasattr(stream, 'flush'):
                    stream.flush()

    agent = None
    timed_out = False
    try:
        os.chdir(str(work_dir))

        sys.stdout = TeeOutput(original_stdout, captured_output)
        sys.stderr = TeeOutput(original_stderr, captured_output)

        enhanced_prompt = _enhance_prompt_with_local_files(task_prompt, work_dir)

        enhanced_prompt += """

CRITICAL INSTRUCTIONS:
1. Do NOT wrap your code in try/except blocks. Let errors propagate so you can see them and fix them in subsequent steps.
2. You must write eval_answer.json BEFORE printing the completion signal.
3. Correct order: Perform analysis -> Write eval_answer.json -> Print 'COMPLETE_TASK_AND_SUBMIT_FINAL_OUTPUT' as your FINAL line of output."""

        if model_name:
            os.environ['MSWEA_MODEL_NAME'] = model_name

        model = get_model()
        env = LocalEnvironment(timeout=operation_timeout)

        if agent_config:
            agent = FlexibleAgent(model, env, step_limit=100, **agent_config)
        else:
            agent = FlexibleAgent(model, env, step_limit=100)

        old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
        signal.alarm(eval_timeout)

        try:
            agent.run(enhanced_prompt)
        except AgentTimeoutError:
            timed_out = True
            print(f"\nAgent timed out after {eval_timeout} seconds")
        except Exception as e:
            if "Submitted" in str(type(e).__name__):
                pass
            else:
                raise
        finally:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, old_handler)

            sys.stdout = original_stdout
            sys.stderr = original_stderr

        print(f"Agent output saved to: {agent_log_file}")

        if hasattr(agent, "messages"):
            trajectory_file = work_dir / "trajectory.json"
            trajectory_data = {
                "messages": agent.messages,
                "actions": getattr(agent, "actions", [])
            }
            trajectory_file.write_text(json.dumps(trajectory_data, indent=2))
            print(f"Agent trajectory saved to: {trajectory_file}")
            print(f" Total message exchanges: {len(agent.messages)}")

        eval_answer_file = work_dir / "eval_answer.json"
        agent_answer = None
        error_details = None

        if not eval_answer_file.exists():
            agent_log_file = work_dir / "agent_output.log"
            log_tail = ""
            if agent_log_file.exists():
                log_content = agent_log_file.read_text()
                log_tail = log_content[-1000:]

            trajectory_info = ""
            if hasattr(agent, "messages"):
                trajectory_info = f"Agent had {len(agent.messages)} message exchanges."

            error_msg = "Agent timed out" if timed_out else "Agent did not create eval_answer.json"
            error_details = {
                "error": error_msg,
                "timed_out": timed_out,
                "trajectory_info": trajectory_info,
                "log_tail": log_tail
            }
            print(f"\nWarning: {error_msg}. {trajectory_info}")
        else:
            try:
                agent_answer = json.loads(eval_answer_file.read_text())
            except json.JSONDecodeError as e:
                error_details = {
                    "error": f"Failed to parse eval_answer.json: {e}",
                    "file_contents": eval_answer_file.read_text()[:500]
                }
                print(f"\nWarning: Failed to parse eval_answer.json: {e}")

        metadata = {}
        if hasattr(agent, "model"):
            metadata["total_cost"] = getattr(agent.model, "cost", None)
            metadata["n_steps"] = getattr(agent.model, "n_calls", None)
        if hasattr(agent, "messages"):
            metadata["n_messages"] = len(agent.messages)
        if timed_out:
            metadata["timed_out"] = True
            metadata["eval_timeout_seconds"] = eval_timeout
        if error_details:
            metadata["error_details"] = error_details

        return {"answer": agent_answer, "metadata": metadata}

    finally:
        os.chdir(original_dir)


def _enhance_prompt_with_local_files(task_prompt: str, work_dir: Path) -> str:
    """Extract <ContextualNodeData> and add local file list to prompt."""
    contextual_data_match = re.search(r'<ContextualNodeData>(.*?)</ContextualNodeData>', task_prompt, re.DOTALL)

    if not contextual_data_match:
        return task_prompt

    try:
        contextual_data = json.loads(contextual_data_match.group(1))
    except json.JSONDecodeError:
        return task_prompt

    local_files = []
    for item in contextual_data:
        if 'local_path' in item:
            local_files.append(item['local_path'])

    if not local_files:
        return task_prompt

    file_list = "\n".join([f"- {f}" for f in local_files])

    enhancement = f"\n\nThe following data files are available in your current working directory:\n{file_list}\n\nUse these local filenames to access the data.\n"

    parts = task_prompt.split('<ContextualNodeData>')
    if len(parts) == 2:
        return parts[0] + enhancement + '<ContextualNodeData>' + parts[1]

    return task_prompt
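run_minisweagent_task exposes the same answer/metadata contract through the mini-swe-agent library, adding a per-command operation_timeout and enforcing the overall eval_timeout with SIGALRM (so it is Unix-only and must run on the main thread). A sketch under the same assumptions: the prompt and directory are placeholders, and the minisweagent package plus model credentials are configured separately.

from pathlib import Path
from latch_eval_tools.harness.minisweagent import run_minisweagent_task

work_dir = Path("/tmp/eval_run")  # hypothetical scratch directory; must exist before the run
work_dir.mkdir(parents=True, exist_ok=True)

result = run_minisweagent_task(
    task_prompt="Compute the mean of the 'score' column in data.csv and write it to eval_answer.json.",
    work_dir=work_dir,
    model_name="anthropic/claude-sonnet-4",  # exported as MSWEA_MODEL_NAME before get_model()
    operation_timeout=300,                   # per shell command
    eval_timeout=600,                        # whole run, enforced via SIGALRM
)

if result["answer"] is None:
    print(result["metadata"].get("error_details"))
else:
    print(result["answer"], result["metadata"].get("n_steps"))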