latch-eval-tools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- latch_eval_tools/__init__.py +64 -0
- latch_eval_tools/answer_extraction.py +35 -0
- latch_eval_tools/cli/__init__.py +0 -0
- latch_eval_tools/cli/eval_lint.py +185 -0
- latch_eval_tools/eval_server.py +570 -0
- latch_eval_tools/faas_utils.py +13 -0
- latch_eval_tools/graders/__init__.py +40 -0
- latch_eval_tools/graders/base.py +29 -0
- latch_eval_tools/graders/distribution.py +102 -0
- latch_eval_tools/graders/label_set.py +75 -0
- latch_eval_tools/graders/marker_gene.py +317 -0
- latch_eval_tools/graders/multiple_choice.py +38 -0
- latch_eval_tools/graders/numeric.py +137 -0
- latch_eval_tools/graders/spatial.py +93 -0
- latch_eval_tools/harness/__init__.py +27 -0
- latch_eval_tools/harness/claudecode.py +212 -0
- latch_eval_tools/harness/minisweagent.py +265 -0
- latch_eval_tools/harness/plotsagent.py +156 -0
- latch_eval_tools/harness/runner.py +191 -0
- latch_eval_tools/harness/utils.py +191 -0
- latch_eval_tools/headless_eval_server.py +727 -0
- latch_eval_tools/linter/__init__.py +25 -0
- latch_eval_tools/linter/explanations.py +331 -0
- latch_eval_tools/linter/runner.py +146 -0
- latch_eval_tools/linter/schema.py +126 -0
- latch_eval_tools/linter/validators.py +595 -0
- latch_eval_tools/types.py +30 -0
- latch_eval_tools/wrapper_entrypoint.py +316 -0
- latch_eval_tools-0.1.0.dist-info/METADATA +118 -0
- latch_eval_tools-0.1.0.dist-info/RECORD +33 -0
- latch_eval_tools-0.1.0.dist-info/WHEEL +4 -0
- latch_eval_tools-0.1.0.dist-info/entry_points.txt +2 -0
- latch_eval_tools-0.1.0.dist-info/licenses/LICENSE +1 -0
|
@@ -0,0 +1,727 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import textwrap
|
|
6
|
+
import time
|
|
7
|
+
import uuid
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import aiohttp
|
|
11
|
+
import websockets
|
|
12
|
+
|
|
13
|
+
from latch_eval_tools.graders import GRADER_REGISTRY
|
|
14
|
+
from latch_eval_tools.answer_extraction import extract_answer_from_conversation
|
|
15
|
+
from latch_eval_tools.types import Eval, EvalResult
|
|
16
|
+
|
|
17
|
+
# The latch-plots-faas runtime exposes helper modules (``utils`` provides
# ``gql_query``, imported right below) that are not an installed package, so
# its runtime/mount directory is put at the front of sys.path first.
faas_runtime_dir = Path(
    os.environ.get("LATCH_PLOTS_FAAS_PATH", "/root/latch-plots-faas")
).joinpath("runtime", "mount")
sys.path.insert(0, str(faas_runtime_dir))
|
|
19
|
+
|
|
20
|
+
from utils import gql_query
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def get_auth_token() -> str:
    """Build the GraphQL ``auth`` value from the local Latch SDK token.

    Reads ``~/.latch/token``, strips surrounding whitespace, and prefixes it
    with the ``Latch-SDK-Token`` scheme.
    """
    token_file = Path.home().joinpath(".latch", "token")
    raw_token = token_file.read_text().strip()
    return "Latch-SDK-Token " + raw_token
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
async def get_workspace_id_from_token() -> str:
    """Return the account id that the local SDK token belongs to.

    Issues the ``GetAccount`` query (``accountInfoCurrentOrRegister``) through
    the faas ``gql_query`` helper and returns the ``id`` field. A response
    without ``data`` surfaces as a ``KeyError``/``TypeError`` from the lookup.
    """
    auth_header = get_auth_token()
    account_query = """
        query GetAccount {
            accountInfoCurrentOrRegister {
                id
            }
        }
    """
    response = await gql_query(auth=auth_header, query=account_query, variables={})
    account_info = response["data"]["accountInfoCurrentOrRegister"]
    return account_info["id"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
async def create_eval_notebook(workspace_id: str, eval_id: str) -> str:
    """Create a throwaway plot notebook for an eval run and return its id.

    The notebook is named ``eval-<eval_id>-<8 hex chars>`` and owned by
    ``workspace_id``. Immediately after creation it is deleted through
    ``deletePlotNotebook``. Per the log message, this hides it from the
    frontend notebook list. The id is still returned and used by the caller
    to start the runtime.
    """
    auth = get_auth_token()
    name_suffix = uuid.uuid4().hex[:8]

    create_mutation = """
        mutation PlotsCreateNotebook($wsId: BigInt!, $displayName: String!) {
            createPlotNotebookInfo(
                input: {
                    plotNotebookInfo: {
                        ownerId: $wsId
                        displayName: $displayName
                    }
                }
            ) {
                plotNotebookInfo {
                    id
                }
            }
        }
    """
    created = await gql_query(
        auth=auth,
        query=create_mutation,
        variables={
            "wsId": workspace_id,
            "displayName": f"eval-{eval_id}-{name_suffix}",
        },
    )
    notebook_id = created["data"]["createPlotNotebookInfo"]["plotNotebookInfo"]["id"]
    print(f"[headless] Created eval notebook: {notebook_id}")

    delete_mutation = """
        mutation DeletePlotNotebook($id: BigInt!) {
            deletePlotNotebook(input: { argNotebookId: $id }) {
                clientMutationId
            }
        }
    """
    await gql_query(auth=auth, query=delete_mutation, variables={"id": notebook_id})
    print(f"[headless] Deleted eval notebook (hidden from frontend list): {notebook_id}")

    return notebook_id
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def _list_agent_sessions(auth: str, notebook_id: str) -> list[dict]:
    """Return the notebook's agent sessions (``id``, ``removedAt``), newest first."""
    resp = await gql_query(
        auth=auth,
        query="""
            query AgentSessionsByNotebook($notebookId: BigInt!) {
                agentSessions(
                    filter: {plotNotebookId: {equalTo: $notebookId}}
                    orderBy: [CREATED_AT_DESC]
                ) {
                    nodes { id removedAt }
                }
            }
        """,
        variables={"notebookId": notebook_id},
    )
    # `or` (rather than `.get(key, default)`) also guards against explicit
    # nulls, e.g. `"data": null` when the GraphQL call reports errors.
    data = resp.get("data") or {}
    sessions = data.get("agentSessions") or {}
    return sessions.get("nodes") or []


async def get_or_create_session(notebook_id: str) -> int:
    """Return the newest non-removed agent session id for a notebook.

    If the notebook has no active session (all have ``removedAt`` set, or none
    exist), one is created with ``createAgentSession`` and its id is read back
    by re-listing the sessions.

    Raises:
        RuntimeError: if no active session is visible after creation.
    """
    auth = get_auth_token()

    active_sessions = [
        node for node in await _list_agent_sessions(auth, notebook_id)
        if node.get("removedAt") is None
    ]
    if active_sessions:
        session_id = int(active_sessions[0]["id"])
        print(f"[headless] Using existing session: {session_id}")
        return session_id

    print(f"[headless] Creating new session for notebook {notebook_id}...")
    await gql_query(
        auth=auth,
        query="""
            mutation CreateAgentSession($notebookId: BigInt!, $metadata: JSON) {
                createAgentSession(
                    input: {agentSession: {plotNotebookId: $notebookId, metadata: $metadata}}
                ) {
                    clientMutationId
                }
            }
        """,
        variables={"notebookId": notebook_id, "metadata": None},
    )

    # The mutation only returns clientMutationId, so re-list and take the
    # newest active session (ordering is CREATED_AT_DESC).
    nodes = [
        node for node in await _list_agent_sessions(auth, notebook_id)
        if node.get("removedAt") is None
    ]
    if not nodes:
        raise RuntimeError("Failed to create session")

    session_id = int(nodes[0]["id"])
    print(f"[headless] Created session: {session_id}")
    return session_id
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class HeadlessEvalServer:
    """Run plots-agent evals against a locally launched runtime, with no browser.

    The runtime (latch-plots-faas, started through this package's
    ``wrapper_entrypoint.py``) runs as a subprocess. This class connects to its
    ``/agent`` WebSocket, sends each eval prompt, and rebuilds the streamed
    assistant messages into ``conversation_history`` and a ``trajectory`` log,
    which are then graded.

    Typical lifecycle: set ``workspace_id`` and ``notebook_id``, then call
    ``start_server()``, ``connect()``, ``run_eval()`` once per case, and
    finally ``stop_server()``.
    """

    def __init__(self, sandbox_dir: Path, port: int = 5000):
        self.sandbox_dir = sandbox_dir
        self.port = port
        self.workspace_id: str | None = None
        self.notebook_id: str | None = None
        self.server_proc = None  # asyncio subprocess running the wrapper
        self.websocket = None  # client connection to ws://localhost:<port>/agent
        self.session_id = None
        self.eval_complete = False
        self.conversation_history: list[dict] = []
        self.current_streaming_message: dict | None = None
        self.current_streaming_blocks: list[dict] = []
        self.trajectory: list[dict] = []
        self.trajectory_session_id: str = ""
        self.current_usage: dict | None = None
        self.turn_number: int = 0
        self.eval_start_time: float = 0
        # Strong references to fire-and-forget tasks. The event loop only keeps
        # weak references, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    def _spawn(self, coro) -> asyncio.Task:
        """Schedule *coro* as a background task and keep it referenced until done."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    async def start_server(self):
        """Launch the runtime wrapper subprocess and wait until it answers /readyz.

        The interpreter is ``$PLOTS_FAAS_PYTHON`` if set, otherwise
        ``$LATCH_PLOTS_FAAS_PATH/.venv/bin/python`` (default root
        ``/root/latch-plots-faas``). The child's stdout and stderr are
        continuously echoed with a ``[server]`` prefix.

        Raises:
            RuntimeError: if ``notebook_id`` is unset, the interpreter or
                wrapper script is missing, or the process exits while starting.
            TimeoutError: if the server never responds (see ``wait_for_ready``).
        """
        print("[headless] Starting runtime server via wrapper...")

        if self.notebook_id is None:
            raise RuntimeError("notebook_id must be set before starting server")

        faas_venv_python_override = os.environ.get("PLOTS_FAAS_PYTHON")
        if faas_venv_python_override:
            faas_venv_python = Path(faas_venv_python_override)
        else:
            faas_dir = Path(os.environ.get("LATCH_PLOTS_FAAS_PATH", "/root/latch-plots-faas"))
            faas_venv_python = faas_dir / ".venv" / "bin" / "python"

        # Use wrapper_entrypoint from this package
        wrapper_script = Path(__file__).parent / "wrapper_entrypoint.py"

        if not faas_venv_python.exists():
            raise RuntimeError(f"latch-plots-faas venv not found at {faas_venv_python}")

        if not wrapper_script.exists():
            raise RuntimeError(f"Wrapper script not found at {wrapper_script}")

        cmd = [
            str(faas_venv_python),
            "-u",
            str(wrapper_script),
            "--sandbox-dir", str(self.sandbox_dir),
            "--port", str(self.port),
            "--notebook-id", str(self.notebook_id),
        ]

        self.server_proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        async def stream_output(stream, prefix=""):
            # The pipes must be drained until EOF. If we stop reading, the
            # child eventually blocks writing to a full pipe.
            while True:
                try:
                    line = await stream.readline()
                except ValueError as e:
                    # Line exceeded the StreamReader limit. readline() has
                    # already discarded that data, so it is safe to keep going.
                    print(f"[server] {prefix}[Error reading output: {e}]", flush=True)
                    continue
                except Exception as e:
                    print(f"[server] {prefix}[Error reading output: {e}]", flush=True)
                    break
                if not line:
                    break
                decoded = line.decode(errors="replace").rstrip()
                if len(decoded) > 1000:
                    decoded = decoded[:1000] + "... [TRUNCATED]"
                print(f"[server] {prefix}{decoded}", flush=True)

        self._spawn(stream_output(self.server_proc.stdout, ""))
        self._spawn(stream_output(self.server_proc.stderr, "[stderr] "))

        await self.wait_for_ready()

    async def wait_for_ready(self, timeout: float = 60.0, poll_interval: float = 1.0):
        """Poll ``/readyz`` until it returns 200 or *timeout* seconds elapse.

        A 500 is treated as "HTTP server up, agent not ready yet". If the
        deadline passes after at least one 500, this returns anyway.

        Raises:
            RuntimeError: if the server process exits while polling.
            TimeoutError: if the server never responded at all.
        """
        print("[headless] Waiting for server to be ready...")
        # monotonic: the deadline must not move if the wall clock is adjusted.
        deadline = time.monotonic() + timeout
        server_responded = False

        while time.monotonic() < deadline:
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.get(f"http://localhost:{self.port}/readyz", timeout=aiohttp.ClientTimeout(total=5)) as resp:
                        if resp.status == 200:
                            print("[headless] Server is ready!")
                            return
                        if resp.status == 500 and not server_responded:
                            print("[headless] Server responding, waiting for agent...")
                            server_responded = True
            except Exception:
                # Connection refused / timeouts are expected while the runtime boots.
                pass

            if self.server_proc.returncode is not None:
                raise RuntimeError(f"Server process exited unexpectedly with code {self.server_proc.returncode}")

            await asyncio.sleep(poll_interval)

        if server_responded:
            print("[headless] Server responding but agent not ready, proceeding anyway")
            return

        raise TimeoutError("Server did not become ready in time")

    async def connect(self):
        """Open the ``/agent`` WebSocket, send ``init``, and wait for ``ready``.

        Connection attempts are retried up to 5 times, 2 s apart. The init
        message carries the notebook id, the agent session id (from
        ``get_or_create_session``), and a synthetic browser ``local_storage``
        that embeds the local SDK token.

        Raises:
            RuntimeError: if the agent replies with ``agent_error``.
        """
        print("[headless] Waiting for agent to be ready...")
        await asyncio.sleep(3)

        print("[headless] Connecting to /agent WebSocket...")

        for attempt in range(5):
            try:
                self.websocket = await websockets.connect(
                    f"ws://localhost:{self.port}/agent",
                    max_size=10 * 1024 * 1024,
                )
                break
            except (websockets.exceptions.InvalidStatus, OSError) as e:
                # OSError covers e.g. connection refused while the endpoint comes up.
                print(f"[headless] WebSocket connection attempt {attempt + 1} failed: {e}")
                if attempt < 4:
                    await asyncio.sleep(2)
                else:
                    raise

        self.session_id = await get_or_create_session(self.notebook_id)

        sdk_token = (Path.home() / ".latch" / "token").read_text().strip()
        local_storage = {
            "plots.is_agent_controlled": "yes",
            "plots.is_eval_harness": "yes",
            "viewAccountId": self.workspace_id,
            "latch.authData": json.dumps({
                "status": "done",
                "auth0Data": {
                    "idToken": sdk_token,
                    "idTokenPayload": {
                        "sub": "agent-session",
                        "latch.bio/tos_ok": "true",
                    },
                },
            }),
        }

        init_msg = {
            "type": "init",
            "notebook_id": self.notebook_id,
            "session_id": self.session_id,
            "local_storage": local_storage,
        }

        await self.websocket.send(json.dumps(init_msg))
        print(f"[headless] Sent init message for notebook {self.notebook_id} with session_id {self.session_id}")

        while True:
            msg_str = await self.websocket.recv()
            msg = json.loads(msg_str)
            msg_type = msg.get("type")

            if msg_type == "agent_status" and msg.get("status") == "ready":
                print(f"[headless] Agent is ready! session_id={self.session_id}")
                break
            elif msg_type == "agent_error":
                raise RuntimeError(f"Agent error: {msg.get('error')}")

            print(f"[headless] Waiting for agent ready, got: {msg_type}")

    def get_conversation_history(self) -> list[dict]:
        """Return a shallow copy of the locally reconstructed conversation."""
        return list(self.conversation_history)

    def get_trajectory(self) -> list[dict]:
        """Return a shallow copy of the trajectory event log."""
        return list(self.trajectory)

    def init_trajectory(self, eval_id: str):
        """Reset the trajectory and record a ``system``/``init`` event for *eval_id*."""
        self.trajectory_session_id = str(uuid.uuid4())
        self.trajectory = []
        self.turn_number = 0
        self.eval_start_time = time.time()
        self.trajectory.append({
            "type": "system",
            "subtype": "init",
            "timestamp": self.eval_start_time,
            "session_id": self.trajectory_session_id,
            "eval_id": eval_id,
            "notebook_id": self.notebook_id,
            "tools": [
                "create_cell",
                "delete_cell",
                "update_cell",
                "run_cell",
                "execute_code",
                "get_context",
                "request_reactivity_summary",
                "submit_response",
            ],
            # NOTE(review): hard-coded metadata; not derived from the runtime.
            "model": "claude-sonnet-4-20250514",
            "agent": "plots-agent",
            "uuid": str(uuid.uuid4()),
        })

    def add_assistant_to_trajectory(self, message: dict):
        """Append an assistant turn, attaching (then clearing) the latest usage."""
        self.turn_number += 1
        self.trajectory.append({
            "type": "assistant",
            "message": {
                "role": "assistant",
                "content": message.get("content", []),
            },
            "turn": self.turn_number,
            "timestamp": time.time(),
            "elapsed_s": time.time() - self.eval_start_time,
            "usage": self.current_usage,
            "session_id": self.trajectory_session_id,
            "uuid": str(uuid.uuid4()),
        })
        self.current_usage = None

    def add_tool_result_to_trajectory(self, tool_use_id: str, result: str, is_error: bool = False, cell_id: str | None = None):
        """Append a user-role ``tool_result`` event for *tool_use_id*.

        ``tool_use_result`` holds a preview truncated to 500 characters,
        prefixed with ``Error: `` when *is_error* is true.
        """
        content = [{
            "type": "tool_result",
            "tool_use_id": tool_use_id,
            "content": result,
            "is_error": is_error,
        }]
        entry = {
            "type": "user",
            "message": {
                "role": "user",
                "content": content,
            },
            "timestamp": time.time(),
            "elapsed_s": time.time() - self.eval_start_time,
            "session_id": self.trajectory_session_id,
            "uuid": str(uuid.uuid4()),
            "tool_use_result": f"{'Error: ' if is_error else ''}{result[:500]}{'...' if len(result) > 500 else ''}",
        }
        if cell_id:
            entry["cell_id"] = cell_id
        self.trajectory.append(entry)

    def add_to_history(self, msg: dict):
        """Fold one WebSocket message into the conversation history and trajectory.

        - Complete ``anthropic_message`` and ``user_message`` payloads are
          stored as-is.
        - ``agent_stream_*`` events are assembled into an assistant message,
          which is committed on ``agent_stream_complete``.
        - ``kernel_message`` cell results are recorded as tool results.
        """
        msg_type = msg.get("type")
        if msg_type in ("anthropic_message", "user_message"):
            self.conversation_history.append(msg)
        elif msg_type == "agent_stream_start":
            self.current_streaming_message = {
                "type": "anthropic_message",
                "role": "assistant",
                "content": [],
            }
            self.current_streaming_blocks = []
        elif msg_type == "agent_stream_block_start":
            block_type = msg.get("block_type")
            if block_type == "text":
                self.current_streaming_blocks.append({"type": "text", "text": ""})
            elif block_type == "thinking":
                self.current_streaming_blocks.append({"type": "thinking", "thinking": ""})
            elif block_type == "tool_use":
                self.current_streaming_blocks.append({
                    "type": "tool_use",
                    "id": msg.get("block_id"),
                    "name": msg.get("block_name"),
                    "input": {},
                })
            else:
                # Placeholder so later block_index values still address the
                # right blocks. Placeholders are dropped when the stream completes.
                self.current_streaming_blocks.append({"type": block_type})
        elif msg_type == "agent_stream_delta":
            block_index = msg.get("block_index", 0)
            delta = msg.get("delta", "")
            if block_index < len(self.current_streaming_blocks):
                block = self.current_streaming_blocks[block_index]
                if block.get("type") == "text":
                    block["text"] += delta
                elif block.get("type") == "thinking":
                    block["thinking"] += delta
                elif block.get("type") == "tool_use":
                    # Tool input arrives as partial JSON; it is parsed on completion.
                    block["input_raw"] = block.get("input_raw", "") + delta
        elif msg_type == "agent_usage_update":
            self.current_usage = msg.get("usage")
        elif msg_type == "agent_stream_complete":
            if self.current_streaming_message is not None:
                blocks = [
                    block for block in self.current_streaming_blocks
                    if block.get("type") in ("text", "thinking", "tool_use")
                ]
                for block in blocks:
                    if block.get("type") == "tool_use" and "input_raw" in block:
                        try:
                            block["input"] = json.loads(block.pop("input_raw"))
                        except json.JSONDecodeError:
                            block["input"] = {}
                self.current_streaming_message["content"] = blocks
                self.conversation_history.append(self.current_streaming_message)
                self.add_assistant_to_trajectory(self.current_streaming_message)
                print(f"[headless] Built message with {len(blocks)} blocks")
            self.current_streaming_message = None
            self.current_streaming_blocks = []
        elif msg_type == "kernel_message":
            inner_msg = msg.get("message", {})
            if inner_msg.get("type") == "cell_result":
                cell_id = inner_msg.get("cell_id", "")
                has_exception = inner_msg.get("has_exception", False)
                logs = inner_msg.get("logs", "")
                exception = inner_msg.get("exception")
                result_str = logs if logs else "Cell executed successfully"
                if has_exception and exception:
                    result_str = f"Exception: {exception}\n{logs}"
                tool_use_id = self.find_last_tool_use_id_for_cell(cell_id)
                if tool_use_id:
                    self.add_tool_result_to_trajectory(tool_use_id, result_str, is_error=has_exception, cell_id=cell_id)

    def find_last_tool_use_id_for_cell(self, cell_id: str) -> str | None:
        """Find the tool_use id that a result for *cell_id* should attach to.

        Prefers the most recent tool_use whose input has a matching
        ``cell_id``. This assumes tools such as run_cell pass ``cell_id`` in
        their input (TODO confirm against the agent's tool schema). Otherwise
        it falls back to the first ``create_cell`` call in the most recent
        assistant message that contains one.
        """
        fallback = None
        for entry in reversed(self.trajectory):
            if entry.get("type") != "assistant":
                continue
            for block in entry.get("message", {}).get("content", []):
                if block.get("type") != "tool_use":
                    continue
                tool_input = block.get("input")
                if cell_id and isinstance(tool_input, dict) and tool_input.get("cell_id") == cell_id:
                    return block.get("id")
                if fallback is None and block.get("name") == "create_cell":
                    fallback = block.get("id")
        return fallback

    async def handle_agent_action(self, msg: dict):
        """Answer an ``agent_action`` request with a mocked ``agent_action_response``.

        Cell actions succeed without real execution. ``create_cell`` (with
        ``auto_run``) and ``run_cell`` also schedule a mock ``cell_result``.
        Unknown actions get an error response.
        """
        action = msg.get("action")
        tx_id = msg.get("tx_id")
        params = msg.get("params", {})

        print(f"[headless] Handling action: {action} (tx_id={tx_id})")

        response = {
            "type": "agent_action_response",
            "tx_id": tx_id,
            "status": "success",
        }

        if action == "get_context":
            response["context"] = {
                "cells": [],
                "selected_cells": [],
                "data_tree": {},
            }
        elif action == "create_cell":
            cell_id = f"cell_{uuid.uuid4().hex[:8]}"
            tf_id = f"tf_{uuid.uuid4().hex[:8]}"
            response["cell_id"] = cell_id
            response["tf_id"] = tf_id
            response["title"] = params.get("title", "")
            if params.get("auto_run"):
                self._spawn(self.send_mock_cell_result(cell_id))
        elif action == "delete_cell":
            response["deleted"] = True
        elif action == "update_cell":
            response["updated"] = True
        elif action == "run_cell":
            response["started"] = True
            cell_id = params.get("cell_id")
            if cell_id:
                self._spawn(self.send_mock_cell_result(cell_id))
        else:
            response["status"] = "error"
            response["error"] = f"Unknown action: {action}"

        await self.websocket.send(json.dumps(response))

    async def send_mock_cell_result(self, cell_id: str):
        """After 0.5 s, send a successful, empty ``cell_result`` for *cell_id*."""
        await asyncio.sleep(0.5)
        result_msg = {
            "type": "kernel_message",
            "message": {
                "type": "cell_result",
                "cell_id": cell_id,
                "has_exception": False,
                "outputs": [],
            }
        }
        if self.websocket:
            await self.websocket.send(json.dumps(result_msg))
            print(f"[headless] Sent mock cell result for {cell_id}")

    def clear_history(self):
        """Drop the locally reconstructed conversation (the trajectory is kept)."""
        self.conversation_history.clear()

    def check_for_completion(self) -> bool:
        """Return True once an assistant called ``submit_response`` with ``next_status == "done"``."""
        for payload in self.conversation_history:
            if payload.get("type") == "anthropic_message" and payload.get("role") == "assistant":
                content = payload.get("content", [])
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "tool_use" and block.get("name") == "submit_response":
                        # Parsed tool input may be null or non-dict JSON.
                        tool_input = block.get("input") or {}
                        if isinstance(tool_input, dict) and tool_input.get("next_status") == "done":
                            return True
        return False

    async def clear_agent_history(self):
        """Clear local history and ask the agent to clear its history too."""
        print("[headless] Clearing agent history...")
        self.clear_history()
        await self.websocket.send(json.dumps({"type": "agent_clear_history"}))
        await asyncio.sleep(1)

    @staticmethod
    def _build_initial_query(eval_case: Eval) -> str:
        """Compose the prompt: task text, answer-format instructions, data context."""
        data_context = ""
        if eval_case.data_node:
            data_nodes = eval_case.data_node if isinstance(eval_case.data_node, list) else [eval_case.data_node]
            contextual_data = [
                {
                    "type": "File",
                    "path": node,
                    "id": node.replace("latch:///", "").replace(".csv", "").replace(".h5ad", ""),
                }
                for node in data_nodes
            ]
            data_context = f"\n\nHere is the context of the selected nodes the user would like to use: <ContextualNodeData>{json.dumps(contextual_data)}</ContextualNodeData>"

        # Dedent the fixed template *before* interpolating. Dedenting after
        # substitution is a no-op as soon as the task or data context contains
        # an unindented line, which left the template lines indented.
        instructions = textwrap.dedent("""
            IMPORTANT: When you finish this task, include your answer in your submit_response summary as raw JSON (no markdown code fences) wrapped in <EVAL_ANSWER></EVAL_ANSWER> tags.

            Example format for your summary:
            <EVAL_ANSWER>
            {"field1": value1, "field2": value2}
            </EVAL_ANSWER>

            Do NOT use markdown code fences (```json) inside the EVAL_ANSWER tags - use raw JSON only.
        """).strip()
        return f"{eval_case.task}\n\n{instructions}\n{data_context}".strip()

    @staticmethod
    def _grade(eval_case: Eval, agent_answer) -> dict | None:
        """Grade *agent_answer* with the eval's configured grader.

        Returns the grader-result dict to store on the EvalResult. Returns
        ``None`` for a grader type missing from ``GRADER_REGISTRY``; a warning
        is printed and the result is left unset.
        """
        print("[headless] Running grader...")
        grader_type = eval_case.grader.get("type")
        grader_config = eval_case.grader.get("config", {})

        if agent_answer is None:
            result = {
                "passed": False,
                "metrics": {},
                "reasoning": "Failed to extract answer from conversation history",
                "agent_answer": None,
            }
            print("[headless] Grader result: FAIL (no answer extracted)")
            return result

        if grader_type not in GRADER_REGISTRY:
            print(f"[headless] Warning: Unknown grader type '{grader_type}'")
            return None

        grader = GRADER_REGISTRY[grader_type]()
        grader_result = grader.evaluate(agent_answer, grader_config)
        result = {
            "passed": grader_result.passed,
            "metrics": grader_result.metrics,
            "reasoning": grader_result.reasoning,
            "agent_answer": grader_result.agent_answer,
        }
        print(f"[headless] Grader result: {'PASS' if grader_result.passed else 'FAIL'}")
        print(f"[headless] Grader reasoning:\n{grader_result.reasoning}")
        return result

    async def run_eval(self, eval_case: Eval) -> EvalResult:
        """Send one eval prompt, wait for completion, and build/grade the result.

        The run counts as complete once ``check_for_completion`` sees a
        ``submit_response`` with ``next_status == "done"``, or when the
        WebSocket closes. There is no overall timeout.

        Raises:
            RuntimeError: if the agent sends ``agent_error``.
        """
        print(f"\n{'=' * 70}")
        print(f"Running eval: {eval_case.id}")
        print("=" * 70)

        start_time = time.time()
        self.eval_complete = False
        self.init_trajectory(eval_case.id)

        initial_query = self._build_initial_query(eval_case)

        await self.websocket.send(json.dumps({
            "type": "agent_query",
            "query": initial_query,
            "request_id": f"eval-{eval_case.id}-{uuid.uuid4()}",
        }))

        print("[headless] Query sent, waiting for completion...")

        while not self.eval_complete:
            try:
                msg_str = await asyncio.wait_for(self.websocket.recv(), timeout=5.0)
                msg = json.loads(msg_str)
                msg_type = msg.get("type", "unknown")

                if msg_type != "agent_stream_delta":
                    print(f"[headless] Received: {msg_type}")

                self.add_to_history(msg)

                if msg_type == "agent_error":
                    error_msg = msg.get("error", "Unknown error")
                    print(f"[headless] Agent error received: {error_msg}")
                    raise RuntimeError(f"Agent error: {error_msg}")

                if msg_type in ("agent_history_updated", "agent_stream_complete"):
                    self.eval_complete = self.check_for_completion()
                    if self.eval_complete:
                        print("[headless] Detected completion via submit_response")

            except asyncio.TimeoutError:
                # Quiet period: re-check in case completion arrived earlier.
                self.eval_complete = self.check_for_completion()
            except websockets.exceptions.ConnectionClosed:
                print("[headless] WebSocket connection closed")
                break

        duration_ms = (time.time() - start_time) * 1000

        conversation_history = self.get_conversation_history()
        print(f"[headless] Retrieved {len(conversation_history)} messages from local history")

        trajectory = self.get_trajectory()
        print(f"[headless] Captured {len(trajectory)} trajectory events")

        agent_answer = extract_answer_from_conversation(conversation_history)
        if agent_answer is not None:
            print(f"[headless] Extracted answer: {json.dumps(agent_answer)[:200]}...")
        else:
            print("[headless] No answer extracted from conversation")

        eval_result = EvalResult(
            eval_id=eval_case.id,
            conversation_history=conversation_history,
            trajectory=trajectory,
            duration_ms=duration_ms,
            agent_answer=agent_answer,
        )

        if eval_case.grader:
            grader_result = self._grade(eval_case, agent_answer)
            if grader_result is not None:
                eval_result.grader_result = grader_result

        print(f"\n[headless] Eval completed in {duration_ms / 1000:.2f}s")
        print(f"[headless] Total conversation turns: {len(conversation_history)}")

        return eval_result

    async def stop_server(self):
        """Close the WebSocket and terminate the runtime (kill after 5 s). Idempotent."""
        print("[headless] Stopping server...")

        if self.websocket:
            try:
                await self.websocket.close()
            except Exception:
                pass  # best effort: the peer may already be gone
            self.websocket = None

        if self.server_proc:
            proc = self.server_proc
            try:
                if proc.returncode is None:
                    proc.terminate()
                    try:
                        await asyncio.wait_for(proc.wait(), timeout=5)
                    # asyncio.TimeoutError: on Python < 3.11, wait_for raises
                    # this rather than the builtin TimeoutError.
                    except asyncio.TimeoutError:
                        proc.kill()
                        await proc.wait()
            except ProcessLookupError:
                # The process exited between the returncode check and a signal.
                pass
            self.server_proc = None

        print("[headless] Server stopped")
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
async def run_eval_batch_headless(eval_cases: list[Eval], sandbox_dir: Path) -> list[EvalResult]:
    """Run *eval_cases* sequentially against one headless runtime.

    Resolves the workspace from the SDK token and creates a single eval
    notebook, named after the first case (or "batch" if there are none).
    It then starts the runtime and connects once, and runs each case after
    clearing agent history. The runtime is always stopped afterwards, even if
    startup or an eval raises. The exception still propagates, so partial
    results are not returned.

    Returns:
        One EvalResult per case, in input order.
    """
    results: list[EvalResult] = []

    server = HeadlessEvalServer(sandbox_dir)

    server.workspace_id = await get_workspace_id_from_token()
    print(f"[headless] Using workspace: {server.workspace_id}")

    first_eval_id = eval_cases[0].id if eval_cases else "batch"
    server.notebook_id = await create_eval_notebook(server.workspace_id, first_eval_id)

    try:
        await server.start_server()
        await server.connect()

        for i, eval_case in enumerate(eval_cases, start=1):
            print(f"\n[headless] Running eval {i}/{len(eval_cases)}")

            await server.clear_agent_history()

            result = await server.run_eval(eval_case)
            results.append(result)
    finally:
        # Without this, an agent_error or a startup failure would leak the
        # runtime subprocess and the WebSocket.
        await server.stop_server()

    return results
|