latch-eval-tools 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- latch_eval_tools/__init__.py +64 -0
- latch_eval_tools/answer_extraction.py +35 -0
- latch_eval_tools/cli/__init__.py +0 -0
- latch_eval_tools/cli/eval_lint.py +185 -0
- latch_eval_tools/eval_server.py +570 -0
- latch_eval_tools/faas_utils.py +13 -0
- latch_eval_tools/graders/__init__.py +40 -0
- latch_eval_tools/graders/base.py +29 -0
- latch_eval_tools/graders/distribution.py +102 -0
- latch_eval_tools/graders/label_set.py +75 -0
- latch_eval_tools/graders/marker_gene.py +317 -0
- latch_eval_tools/graders/multiple_choice.py +38 -0
- latch_eval_tools/graders/numeric.py +137 -0
- latch_eval_tools/graders/spatial.py +93 -0
- latch_eval_tools/harness/__init__.py +27 -0
- latch_eval_tools/harness/claudecode.py +212 -0
- latch_eval_tools/harness/minisweagent.py +265 -0
- latch_eval_tools/harness/plotsagent.py +156 -0
- latch_eval_tools/harness/runner.py +191 -0
- latch_eval_tools/harness/utils.py +191 -0
- latch_eval_tools/headless_eval_server.py +727 -0
- latch_eval_tools/linter/__init__.py +25 -0
- latch_eval_tools/linter/explanations.py +331 -0
- latch_eval_tools/linter/runner.py +146 -0
- latch_eval_tools/linter/schema.py +126 -0
- latch_eval_tools/linter/validators.py +595 -0
- latch_eval_tools/types.py +30 -0
- latch_eval_tools/wrapper_entrypoint.py +316 -0
- latch_eval_tools-0.1.0.dist-info/METADATA +118 -0
- latch_eval_tools-0.1.0.dist-info/RECORD +33 -0
- latch_eval_tools-0.1.0.dist-info/WHEEL +4 -0
- latch_eval_tools-0.1.0.dist-info/entry_points.txt +2 -0
- latch_eval_tools-0.1.0.dist-info/licenses/LICENSE +1 -0

latch_eval_tools/eval_server.py
@@ -0,0 +1,570 @@
+from latch_eval_tools import faas_utils
+
+import argparse
+import asyncio
+import json
+import os
+import shutil
+import socket
+import sys
+import textwrap
+import time
+import uuid
+from pathlib import Path
+
+import websockets
+import websockets.server
+from latch_eval_tools.types import Eval, EvalResult
+from latch_eval_tools.graders import GRADER_REGISTRY
+from latch_eval_tools.answer_extraction import extract_answer_from_conversation
+from latch_eval_tools.headless_eval_server import run_eval_batch_headless
+
+faas_runtime_dir = Path(os.environ.get("LATCH_PLOTS_FAAS_PATH", "/root/latch-plots-faas")) / "runtime" / "mount"
+sys.path.insert(0, str(faas_runtime_dir))
+
+from socketio import SocketIo
+from utils import gql_query
+
+
+def get_auth_token() -> str:
+    return f"Latch-SDK-Token {(Path.home() / '.latch' / 'token').read_text()}"
+
+
+class EvalServer:
+    sandbox_dir: Path
+    current_eval_case: Eval | None
+    agent_proc: asyncio.subprocess.Process | None
+    agent_sock: socket.socket | None
+    agent_conn: SocketIo | None
+    websocket: websockets.server.WebSocketServerProtocol | None
+    session_id: int | None
+    eval_complete: bool
+
+    def __init__(self, sandbox_dir: Path):
+        self.sandbox_dir = sandbox_dir
+        self.current_eval_case = None
+        self.agent_proc = None
+        self.agent_sock = None
+        self.agent_conn = None
+        self.websocket = None
+        self.session_id = None
+        self.eval_complete = False
+
+    async def start_agent(self):
+        print("[eval] Starting agent")
+
+        sock_a, sock_agent = socket.socketpair(family=socket.AF_UNIX)
+        sock_a.setblocking(False)
+        sock_agent_fd = sock_agent.detach()
+
+        self.agent_sock = sock_a
+        self.agent_conn = await SocketIo.from_socket(sock_a)
+
+        agent_path = faas_runtime_dir / "agent.py"
+
+        self.agent_proc = await asyncio.create_subprocess_exec(
+            sys.executable,
+            "-u",
+            str(agent_path),
+            str(sock_agent_fd),
+            pass_fds=[sock_agent_fd],
+            stdin=asyncio.subprocess.DEVNULL,
+            stdout=asyncio.subprocess.PIPE,
+            stderr=asyncio.subprocess.PIPE,
+            env={
+                **os.environ,
+                "LATCH_SANDBOX_ROOT": str(self.sandbox_dir),
+                "PYTHONUNBUFFERED": "1",
+                "AGENT_DEBUG": "1",
+            },
+            preexec_fn=lambda: os.nice(5),
+            limit=1024 * 1024,
+        )
+
+        async def stream_output(stream, prefix=""):
+            while True:
+                try:
+                    line = await stream.readline()
+                    if not line:
+                        break
+                    decoded = line.decode().rstrip()
+                    if len(decoded) > 1000:
+                        decoded = decoded[:1000] + "... [TRUNCATED]"
+                    print(f"[agent stream] {prefix}{decoded}", flush=True)
+                except (ValueError, asyncio.LimitOverrunError) as e:
+                    if "limit" in str(e).lower():
+                        chunk = await stream.read(8192)
+                        if not chunk:
+                            break
+                        print(f"[agent] {prefix}[Large output truncated: {len(chunk)} bytes]", flush=True)
+                    else:
+                        raise
+                except Exception as e:
+                    print(f"[agent] {prefix}[Error reading output: {e}]", flush=True)
+                    break
+
+        asyncio.create_task(stream_output(self.agent_proc.stdout, ""))
+        asyncio.create_task(stream_output(self.agent_proc.stderr, "[stderr] "))
+
+        msg = await self.agent_conn.recv()
+        if msg.get("type") == "ready":
+            print("[eval] Agent subprocess started and ready")
+
+    async def stop_agent(self):
+        if self.agent_proc:
+            print("[eval] Stopping agent")
+            try:
+                self.agent_proc.terminate()
+                await asyncio.wait_for(self.agent_proc.wait(), timeout=2)
+            except TimeoutError:
+                self.agent_proc.kill()
+                await self.agent_proc.wait()
+
+        if self.agent_sock:
+            try:
+                self.agent_sock.close()
+            except Exception:
+                pass
+
+        self.agent_proc = None
+        self.agent_sock = None
+        self.agent_conn = None
+
+    def clear_notebook_context(self):
+        context_dir = faas_runtime_dir / "agent_config" / "context" / "notebook_context"
+        if context_dir.exists():
+            for file in context_dir.iterdir():
+                if file.is_file() and file.name != ".gitkeep":
+                    file.unlink()
+            print("[eval] Cleared notebook context files")
+
+    async def initialize_agent_session(self, websocket):
+        print("[eval] Waiting for console init to get session_id...")
+        init_msg = await websocket.recv()
+        console_init = json.loads(init_msg)
+        if console_init.get("type") == "init":
+            self.session_id = int(console_init.get("session_id"))
+
+        await self.agent_conn.send({"type": "init", "session_id": self.session_id, "eval_mode": True})
+
+        while True:
+            msg = await self.agent_conn.recv()
+            if msg.get("type") == "agent_status" and msg.get("status") == "ready":
+                print("[eval] Agent initialized and ready")
+                break
+            print(f"[eval] Skipping init message: {msg.get('type')}")
+
+        self.clear_notebook_context()
+
+    async def handle_agent_message(self, msg: dict):
+        msg_type = msg.get("type")
+
+        if msg_type == "agent_history_updated":
+            await self.check_for_completion()
+
+    async def check_for_completion(self):
+        history = await self.fetch_full_conversation_history()
+        for payload in history:
+            if payload.get("type") == "anthropic_message" and payload.get("role") == "assistant":
+                content = payload.get("content", [])
+                for block in content:
+                    if isinstance(block, dict) and block.get("type") == "tool_use" and block.get("name") == "submit_response":
+                        tool_input = block.get("input", {})
+                        if tool_input.get("next_status") == "done":
+                            self.eval_complete = True
+                            return
+
+    async def fetch_full_conversation_history(self) -> list[dict]:
+        try:
+            resp = await gql_query(
+                auth=get_auth_token(),
+                query="""
+                    query AgentHistory($sessionId: BigInt!) {
+                        agentHistories(condition: {sessionId: $sessionId, removed: false}, orderBy: ID_ASC) {
+                            nodes { id payload }
+                        }
+                    }
+                """,
+                variables={"sessionId": str(self.session_id)},
+            )
+            nodes = resp.get("data", {}).get("agentHistories", {}).get("nodes", [])
+            return [node.get("payload", {}) for node in nodes]
+        except Exception as e:
+            print(f"[eval] Error fetching conversation history: {e}")
+            return []
+
+    async def reset_for_next_test(self):
+        print("[eval] Clearing agent history for next test...")
+        await self.agent_conn.send({"type": "agent_clear_history"})
+        self.clear_notebook_context()
+        self.eval_complete = False
+        self.current_eval_case = None
+        print("[eval] Reset complete")
+
+    async def keep_forwarding(self):
+        async def forward_agent_to_console():
+            while True:
+                msg = await self.agent_conn.recv()
+                await self.websocket.send(json.dumps(msg))
+
+        async def forward_console_to_agent():
+            async for message in self.websocket:
+                msg = json.loads(message)
+                await self.agent_conn.send(msg)
+
+        forward_task = asyncio.create_task(forward_agent_to_console())
+        receive_task = asyncio.create_task(forward_console_to_agent())
+
+        try:
+            await asyncio.gather(forward_task, receive_task)
+        except asyncio.CancelledError:
+            forward_task.cancel()
+            receive_task.cancel()
+
+    async def run_eval(self, eval_case: Eval) -> EvalResult:
+        print(f"\n{'=' * 70}")
+        print(f"Running eval: {eval_case.id}")
+        print("=" * 70)
+
+        if not self.websocket:
+            raise RuntimeError("websocket must be set before calling run_eval()")
+        if not self.agent_proc:
+            raise RuntimeError("agent must be started before calling run_eval()")
+
+        start_time = time.time()
+
+        self.current_eval_case = eval_case
+        self.eval_complete = False
+
+        data_context = ""
+        if eval_case.data_node:
+            data_nodes = eval_case.data_node if isinstance(eval_case.data_node, list) else [eval_case.data_node]
+            contextual_data = []
+            for node in data_nodes:
+                contextual_data.append({
+                    "type": "File",
+                    "path": node,
+                    "id": node.replace("latch:///", "").replace(".csv", "").replace(".h5ad", ""),
+                })
+            data_context = f"\n\nHere is the context of the selected nodes the user would like to use: <ContextualNodeData>{json.dumps(contextual_data)}</ContextualNodeData>"
+
+        initial_query = textwrap.dedent(f"""
+            {eval_case.task}
+
+            IMPORTANT: When you finish this task, include your answer in your submit_response summary as raw JSON (no markdown code fences) wrapped in <EVAL_ANSWER></EVAL_ANSWER> tags.
+
+            Example format for your summary:
+            <EVAL_ANSWER>
+            {{"field1": value1, "field2": value2}}
+            </EVAL_ANSWER>
+
+            Do NOT use markdown code fences (```json) inside the EVAL_ANSWER tags - use raw JSON only.
+            {data_context}
+        """).strip()
+
+        async def forward_agent_to_console():
+            try:
+                while True:
+                    msg = await self.agent_conn.recv()
+                    msg_type = msg.get("type", "unknown")
+                    if msg_type != "agent_stream_delta":
+                        print(f"[eval] agent→console: {msg_type}")
+                    await self.handle_agent_message(msg)
+                    await self.websocket.send(json.dumps(msg))
+            except Exception as e:
+                print(f"[eval] Agent forwarding ended: {e}")
+
+        async def forward_console_to_agent():
+            try:
+                async for message in self.websocket:
+                    msg = json.loads(message)
+                    msg_type = msg.get("type")
+                    print(f"[eval] console→agent: {msg_type}")
+                    await self.agent_conn.send(msg)
+            except Exception as e:
+                print(f"[eval] Console forwarding ended: {e}")
+
+        forward_task = asyncio.create_task(forward_agent_to_console())
+        receive_task = asyncio.create_task(forward_console_to_agent())
+
+        print("[eval] Resetting kernel state...")
+        await self.websocket.send(json.dumps({
+            "type": "agent_action",
+            "action": "reset_kernel_globals",
+            "params": {},
+            "tx_id": str(uuid.uuid4()),
+        }))
+
+        await self.agent_conn.send({
+            "type": "agent_query",
+            "query": initial_query,
+            "request_id": f"eval-init-{self.session_id}"
+        })
+
+        while not self.eval_complete:
+            if forward_task.done() or receive_task.done():
+                print("[eval] One of the forwarding tasks completed unexpectedly")
+                if forward_task.done():
+                    try:
+                        forward_task.result()
+                    except Exception as e:
+                        print(f"[eval] Forward task error: {e}")
+                    if self.websocket:
+                        forward_task = asyncio.create_task(forward_agent_to_console())
+                if receive_task.done():
+                    try:
+                        receive_task.result()
+                    except Exception as e:
+                        print(f"[eval] Receive task error: {e}")
+                    if self.websocket:
+                        receive_task = asyncio.create_task(forward_console_to_agent())
+                break
+            await asyncio.sleep(1)
+
+        print("[eval] Eval complete, stopping forwarding tasks...")
+        receive_task.cancel()
+        try:
+            await asyncio.wait_for(receive_task, timeout=0.1)
+        except (TimeoutError, asyncio.CancelledError):
+            pass
+
+        forward_task.cancel()
+        try:
+            await asyncio.wait_for(forward_task, timeout=0.1)
+        except (TimeoutError, asyncio.CancelledError):
+            pass
+
+        duration_ms = (time.time() - start_time) * 1000
+
+        print(f"[eval] Fetching full conversation history from database...")
+        conversation_history = await self.fetch_full_conversation_history()
+        print(f"[eval] Retrieved {len(conversation_history)} messages from database")
+
+        agent_answer = extract_answer_from_conversation(conversation_history)
+        if agent_answer is not None:
+            print(f"[eval] Extracted answer: {json.dumps(agent_answer)[:200]}...")
+        else:
+            print("[eval] No answer extracted from conversation")
+
+        eval_result = EvalResult(
+            eval_id=eval_case.id,
+            conversation_history=conversation_history,
+            duration_ms=duration_ms,
+            agent_answer=agent_answer,
+        )
+
+        if eval_case.grader:
+            print("[eval] Running binary grader...")
+            grader_type = eval_case.grader.get("type")
+            grader_config = eval_case.grader.get("config", {})
+
+            if agent_answer is None:
+                eval_result.grader_result = {
+                    "passed": False,
+                    "metrics": {},
+                    "reasoning": "Failed to extract answer from conversation history",
+                    "agent_answer": None
+                }
+                print("[eval] Grader result: FAIL (no answer extracted)")
+            elif grader_type in GRADER_REGISTRY:
+                grader_cls = GRADER_REGISTRY[grader_type]
+                grader = grader_cls()
+                grader_result = grader.evaluate(agent_answer, grader_config)
+
+                eval_result.grader_result = {
+                    "passed": grader_result.passed,
+                    "metrics": grader_result.metrics,
+                    "reasoning": grader_result.reasoning,
+                    "agent_answer": grader_result.agent_answer
+                }
+
+                print(f"[eval] Grader result: {'PASS' if grader_result.passed else 'FAIL'}")
+                print(f"[eval] Grader reasoning:\n{grader_result.reasoning}")
+            else:
+                print(f"[eval] Warning: Unknown grader type '{grader_type}'")
+
+        print(f"\n[eval] Eval completed in {duration_ms / 1000:.2f}s")
+        print(f"[eval] Total conversation turns: {len(conversation_history)}")
+
+        return eval_result
+
+
+async def run_with_websocket_server(port: int, connection_handler, done_event: asyncio.Event):
+    async with websockets.serve(
+        connection_handler,
+        "localhost",
+        port,
+        max_size=10 * 1024 * 1024
+    ):
+        print(f"[eval] WebSocket server listening on ws://localhost:{port}/agent")
+        print("[eval] Waiting for a running plot notebook to connect to the local agent.")
+
+        await done_event.wait()
+
+
+async def run_eval_batch(eval_cases: list[Eval], port: int, sandbox_dir: Path, interactive: bool = False) -> list[EvalResult]:
+    server = EvalServer(sandbox_dir)
+    done_event = asyncio.Event()
+    results: list[EvalResult] = []
+
+    async def connection_handler(websocket):
+        if websocket.path == "/agent":
+            if server.agent_proc is not None:
+                print(f"[eval] Console reconnected, updating websocket")
+                server.websocket = websocket
+                try:
+                    await done_event.wait()
+                except Exception:
+                    pass
+                return
+
+            num_evals = len(eval_cases)
+            print(f"[eval] Console connected ({'single eval' if num_evals == 1 else f'batch of {num_evals} evals'})")
+
+            server.websocket = websocket
+            await server.start_agent()
+
+            await server.initialize_agent_session(websocket)
+
+            for eval_case in eval_cases:
+                await server.reset_for_next_test()
+                result = await server.run_eval(eval_case)
+                results.append(result)
+
+            if interactive:
+                print("\n[eval] Interactive mode - agent still running. Press Ctrl+C to exit.")
+                await server.keep_forwarding()
+
+            await server.stop_agent()
+            done_event.set()
+        else:
+            print(f"[eval] Unknown path: {websocket.request.path}")
+            await websocket.close()
+
+    await run_with_websocket_server(port, connection_handler, done_event)
+    return results


+def create_sandbox(sandbox_dir: Path, eval_case: Eval):
+    if sandbox_dir.exists():
+        print(f"[eval] Removing existing sandbox at {sandbox_dir}")
+        shutil.rmtree(sandbox_dir)
+
+    sandbox_dir.mkdir(parents=True, exist_ok=True)
+    print(f"[eval] Created fresh sandbox at {sandbox_dir}")
+
+    user_token = Path.home() / ".latch" / "token"
+    if user_token.exists():
+        (sandbox_dir / "token").write_text(user_token.read_text())
+    else:
+        (sandbox_dir / "token").write_text("local-dev-token")
+
+    (sandbox_dir / "session-id").write_text(f"eval-{eval_case.id}")
+    (sandbox_dir / "nucleus-url").write_text("https://nucleus.latch.bio")
+    (sandbox_dir / "id").write_text("0")
+
+
+def write_results(results: list[EvalResult], output_path: Path):
+    evals = []
+    for r in results:
+        entry = {
+            "eval_id": r.eval_id,
+            "duration_ms": r.duration_ms,
+            "passed": r.grader_result.get("passed") if r.grader_result else None,
+            "reasoning": r.grader_result.get("reasoning") if r.grader_result else None,
+            "agent_answer": r.agent_answer,
+        }
+        evals.append(entry)
+
+    passed = sum(1 for e in evals if e["passed"] is True)
+    total = len(evals)
+    accuracy = passed / total if total > 0 else 0
+
+    output = {
+        "accuracy": accuracy,
+        "passed": passed,
+        "total": total,
+        "evals": evals,
+    }
+
+    output_path.write_text(json.dumps(output, indent=2))
+    print(f"[eval] Results written to {output_path}")
+    print(f"[eval] Accuracy: {passed}/{total} ({accuracy:.1%})")
+
+    workspaces_dir = output_path.parent / "workspaces"
+    workspaces_dir.mkdir(parents=True, exist_ok=True)
+
+    for r in results:
+        eval_dir = workspaces_dir / r.eval_id
+        eval_dir.mkdir(parents=True, exist_ok=True)
+
+        (eval_dir / "trajectory.json").write_text(json.dumps(r.trajectory, indent=2))
+
+        agent_log_lines = []
+        for event in r.trajectory:
+            agent_log_lines.append(json.dumps(event))
+        (eval_dir / "agent_output.log").write_text("\n".join(agent_log_lines))
+
+        if r.agent_answer is not None:
+            (eval_dir / "eval_answer.json").write_text(json.dumps(r.agent_answer, indent=2))
+
+        result_data = {
+            "eval": r.eval_id,
+            "model": "anthropic/claude-sonnet-4",
+            "agent": "plots-agent",
+            "passed": r.grader_result.get("passed") if r.grader_result else None,
+            "duration_s": r.duration_ms / 1000,
+            "agent_answer": r.agent_answer,
+            "grader_result": r.grader_result,
+        }
+        (eval_dir / "_result.json").write_text(json.dumps(result_data, indent=2))
+
+        print(f"[eval] Wrote trajectory for {r.eval_id} to {eval_dir}")
+
+
+async def main():
+    parser = argparse.ArgumentParser(description="Run agent eval server")
+    parser.add_argument("--eval", help="Eval file or directory to run")
+    parser.add_argument("--output", "-o", help="Output file for results (default: results.json)")
+    parser.add_argument("--interactive", "-i", action="store_true", help="Keep agent running after eval for interaction")
+    parser.add_argument("--headless", action="store_true", help="Run in headless mode with temporary notebook")
+    args = parser.parse_args()
+
+    sandbox_dir = Path.cwd() / "sandboxes" / "batch"
+    eval_path = Path(args.eval)
+    output_path = Path(args.output) if args.output else Path("results.json")
+
+    eval_cases = []
+    if eval_path.is_dir():
+        print(f"[eval] Loading test cases from directory: {eval_path}")
+        for json_file in sorted(eval_path.rglob("*.json")):
+            with open(json_file) as f:
+                test_data = json.load(f)
+            eval_cases.append(Eval(**test_data))
+        print(f"[eval] Found {len(eval_cases)} test cases")
+    else:
+        with open(eval_path) as f:
+            test_data = json.load(f)
+        eval_cases.append(Eval(**test_data))
+
+    if not eval_cases:
+        print("[eval] No test cases found")
+        return
+
+    create_sandbox(sandbox_dir, eval_cases[0])
+
+    try:
+        if args.headless:
+            results = await run_eval_batch_headless(eval_cases, sandbox_dir)
+        else:
+            results = await run_eval_batch(eval_cases, 8765, sandbox_dir, interactive=args.interactive)
+        print(f"\n[eval] Batch complete: {len(results)}/{len(eval_cases)} evals completed")
+        write_results(results, output_path)
+    except KeyboardInterrupt:
+        print("\n[eval] Interrupted by user")
+    except Exception:
+        print("[eval] Error during eval execution")
+        raise
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
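
For orientation: main() loads each case with Eval(**test_data), and run_eval() only reads eval_case.id, .task, .data_node, and .grader, so a case file shaped like the sketch below would satisfy this module. The authoritative schema is the Eval type in latch_eval_tools/types.py (not shown here); the file path, data node, and grader config below are illustrative assumptions only.

    # Hypothetical eval case, matching the fields read by run_eval()/main() above.
    import json
    from pathlib import Path

    case = {
        "id": "pbmc-cell-count",              # used for logging and workspace directories
        "task": "Load the dataset and report the number of cells.",
        "data_node": "latch:///pbmc3k.h5ad",  # str or list[str]; becomes <ContextualNodeData>
        "grader": {                           # optional; "type" must be a GRADER_REGISTRY key
            "type": "numeric_tolerance",
            "config": {},                     # grader-specific; defined in graders/numeric.py
        },
    }
    Path("evals").mkdir(exist_ok=True)
    Path("evals/pbmc_cell_count.json").write_text(json.dumps(case, indent=2))

    # Assuming the module is invoked directly, a batch run would then look like:
    #   python -m latch_eval_tools.eval_server --eval evals/ --output results.json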

latch_eval_tools/faas_utils.py
@@ -0,0 +1,13 @@
+import os
+
+os.environ.setdefault("auto_reload", "false")
+os.environ.setdefault("logging_mode", "console")
+os.environ.setdefault("domain", "latch.bio")
+os.environ.setdefault("DD_VERSION", "eval")
+os.environ.setdefault("DD_SERVICE", "latch-plots-eval")
+os.environ.setdefault("DD_ENV", "eval")
+os.environ.setdefault("DD_AGENT_HOST", "localhost")
+os.environ.setdefault("DD_TRACE_ENABLED", "false")
+os.environ.setdefault("DD_PROFILING_ENABLED", "false")
+os.environ.setdefault("DD_RUNTIME_METRICS_ENABLED", "false")
+os.environ.setdefault("OTEL_SDK_DISABLED", "true")
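
faas_utils defines no functions; importing it only seeds environment defaults that keep Datadog and OpenTelemetry instrumentation in the faas runtime disabled, which is presumably why eval_server.py imports it on its very first line, before anything else can read those variables at import time. A minimal sketch of relying on that ordering:

    import os

    from latch_eval_tools import faas_utils  # noqa: F401 - imported only for its side effects

    # setdefault means values already present in the environment win over these defaults.
    print(os.environ.get("DD_TRACE_ENABLED"))   # "false" unless the caller set it beforehand
    print(os.environ.get("OTEL_SDK_DISABLED"))  # "true" unless the caller set it beforehand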

latch_eval_tools/graders/__init__.py
@@ -0,0 +1,40 @@
+from .base import BinaryGrader, GraderResult, get_nested_value
+from .numeric import NumericToleranceGrader
+from .marker_gene import MarkerGenePrecisionRecallGrader, MarkerGeneSeparationGrader
+from .label_set import LabelSetJaccardGrader
+from .distribution import DistributionComparisonGrader
+from .spatial import SpatialAdjacencyGrader
+from .multiple_choice import MultipleChoiceGrader
+
+GRADER_REGISTRY = {
+    "numeric_tolerance": NumericToleranceGrader,
+    "label_set_jaccard": LabelSetJaccardGrader,
+    "jaccard_label_set": LabelSetJaccardGrader,
+    "distribution_comparison": DistributionComparisonGrader,
+    "marker_gene_precision_recall": MarkerGenePrecisionRecallGrader,
+    "marker_gene_separation": MarkerGeneSeparationGrader,
+    "spatial_adjacency": SpatialAdjacencyGrader,
+    "multiple_choice": MultipleChoiceGrader,
+}
+
+
+def get_grader(grader_type: str) -> BinaryGrader:
+    if grader_type not in GRADER_REGISTRY:
+        raise ValueError(f"Unknown grader type: {grader_type}. Available: {list(GRADER_REGISTRY.keys())}")
+    return GRADER_REGISTRY[grader_type]()
+
+
+__all__ = [
+    "BinaryGrader",
+    "GraderResult",
+    "get_nested_value",
+    "NumericToleranceGrader",
+    "MarkerGenePrecisionRecallGrader",
+    "MarkerGeneSeparationGrader",
+    "LabelSetJaccardGrader",
+    "DistributionComparisonGrader",
+    "SpatialAdjacencyGrader",
+    "MultipleChoiceGrader",
+    "GRADER_REGISTRY",
+    "get_grader",
+]
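
A usage sketch for the registry: get_grader() maps the string stored in an eval case's grader["type"] to a grader instance, and evaluate() returns the GraderResult that run_eval() serializes. The config keys below are placeholders; each grader defines its own config schema in its module (e.g. graders/numeric.py, not shown here).

    from latch_eval_tools.graders import get_grader

    grader = get_grader("numeric_tolerance")  # raises ValueError for unknown types
    # "expected" and "tolerance" are illustrative keys only, not the documented schema.
    result = grader.evaluate({"n_cells": 2711}, {"expected": 2700, "tolerance": 0.05})
    print(result.passed, result.metrics, result.reasoning)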

latch_eval_tools/graders/base.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class GraderResult:
+    passed: bool
+    metrics: dict
+    reasoning: str
+    agent_answer: dict | None
+
+
+def get_nested_value(obj: dict, key: str) -> tuple[any, bool]:
+    if "." not in key:
+        return obj.get(key), key in obj
+    parts = key.split(".")
+    current = obj
+    for part in parts:
+        if not isinstance(current, dict) or part not in current:
+            return None, False
+        current = current[part]
+    return current, True
+
+
+class BinaryGrader:
+    def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
+        raise NotImplementedError
+
+    def evaluate(self, agent_answer: dict, config: dict) -> GraderResult:
+        return self.evaluate_answer(agent_answer, config)
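
Because BinaryGrader.evaluate() simply delegates to evaluate_answer(), a new grader only needs to implement that one method and return a GraderResult. A hypothetical subclass (not part of the package) using get_nested_value for dotted keys:

    from latch_eval_tools.graders.base import BinaryGrader, GraderResult, get_nested_value

    class ExactMatchGrader(BinaryGrader):
        """Passes when a (possibly nested) key in the agent's answer equals an expected value."""

        def evaluate_answer(self, agent_answer: dict, config: dict) -> GraderResult:
            key = config.get("key", "answer")
            expected = config.get("expected")
            value, found = get_nested_value(agent_answer, key)
            return GraderResult(
                passed=found and value == expected,
                metrics={"found": found, "value": value},
                reasoning=f"{key}={value!r}, expected {expected!r}",
                agent_answer=agent_answer,
            )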