inspect-ai 0.3.95__py3-none-any.whl → 0.3.97__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- inspect_ai/_eval/eval.py +10 -2
- inspect_ai/_eval/task/util.py +32 -3
- inspect_ai/_util/local_server.py +16 -0
- inspect_ai/_util/registry.py +7 -0
- inspect_ai/_util/timer.py +13 -0
- inspect_ai/_view/www/dist/assets/index.css +275 -195
- inspect_ai/_view/www/dist/assets/index.js +8568 -7376
- inspect_ai/_view/www/src/app/App.css +1 -0
- inspect_ai/_view/www/src/app/App.tsx +27 -10
- inspect_ai/_view/www/src/app/appearance/icons.ts +5 -0
- inspect_ai/_view/www/src/app/content/RecordTree.module.css +22 -0
- inspect_ai/_view/www/src/app/content/RecordTree.tsx +370 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.module.css +5 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +32 -19
- inspect_ai/_view/www/src/app/content/record_processors/store.ts +101 -0
- inspect_ai/_view/www/src/app/content/record_processors/types.ts +3 -0
- inspect_ai/_view/www/src/app/content/types.ts +5 -0
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +1 -0
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +35 -28
- inspect_ai/_view/www/src/app/log-view/LogViewLayout.tsx +1 -8
- inspect_ai/_view/www/src/app/log-view/navbar/PrimaryBar.tsx +2 -4
- inspect_ai/_view/www/src/app/log-view/navbar/ResultsPanel.tsx +13 -3
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.module.css +15 -0
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.tsx +14 -10
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +9 -3
- inspect_ai/_view/www/src/app/log-view/tabs/JsonTab.tsx +1 -3
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +8 -2
- inspect_ai/_view/www/src/app/log-view/types.ts +1 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.module.css +7 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +5 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +13 -8
- inspect_ai/_view/www/src/app/routing/navigationHooks.ts +63 -8
- inspect_ai/_view/www/src/app/routing/url.ts +45 -0
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.tsx +15 -8
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +3 -0
- inspect_ai/_view/www/src/app/samples/SampleDialog.tsx +16 -5
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +68 -31
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.module.css +12 -7
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +17 -5
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.module.css +9 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +48 -18
- inspect_ai/_view/www/src/app/samples/chat/ChatView.tsx +0 -1
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.module.css +4 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +41 -1
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +7 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.module.css +0 -3
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolInput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolOutput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +5 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/PassFailScoreDescriptor.tsx +11 -6
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +7 -0
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +5 -18
- inspect_ai/_view/www/src/app/samples/sample-tools/SortFilter.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresGrid.tsx +18 -5
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.module.css +0 -6
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.tsx +4 -1
- inspect_ai/_view/www/src/app/samples/transcript/ApprovalEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +13 -6
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/LoggerEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +11 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +14 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +13 -8
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +25 -16
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +7 -5
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +11 -28
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +12 -20
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +12 -31
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +25 -29
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +297 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +0 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +43 -25
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +43 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +109 -43
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +19 -8
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +128 -60
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +14 -4
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +6 -4
- inspect_ai/_view/www/src/app/types.ts +12 -1
- inspect_ai/_view/www/src/components/Card.css +6 -3
- inspect_ai/_view/www/src/components/Card.tsx +15 -2
- inspect_ai/_view/www/src/components/CopyButton.tsx +4 -6
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +20 -14
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +17 -22
- inspect_ai/_view/www/src/components/LargeModal.tsx +5 -1
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +25 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.css +4 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +2 -2
- inspect_ai/_view/www/src/components/TabSet.module.css +6 -1
- inspect_ai/_view/www/src/components/TabSet.tsx +8 -2
- inspect_ai/_view/www/src/state/hooks.ts +83 -13
- inspect_ai/_view/www/src/state/logPolling.ts +2 -2
- inspect_ai/_view/www/src/state/logSlice.ts +1 -2
- inspect_ai/_view/www/src/state/logsSlice.ts +9 -9
- inspect_ai/_view/www/src/state/samplePolling.ts +1 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +134 -7
- inspect_ai/_view/www/src/state/scoring.ts +1 -1
- inspect_ai/_view/www/src/state/scrolling.ts +39 -6
- inspect_ai/_view/www/src/state/store.ts +5 -0
- inspect_ai/_view/www/src/state/store_filter.ts +47 -44
- inspect_ai/_view/www/src/utils/debugging.ts +95 -0
- inspect_ai/_view/www/src/utils/format.ts +2 -2
- inspect_ai/_view/www/src/utils/json.ts +29 -0
- inspect_ai/agent/__init__.py +2 -1
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_react.py +184 -48
- inspect_ai/agent/_types.py +15 -2
- inspect_ai/analysis/beta/__init__.py +11 -3
- inspect_ai/analysis/beta/_dataframe/columns.py +11 -16
- inspect_ai/analysis/beta/_dataframe/evals/table.py +101 -39
- inspect_ai/analysis/beta/_dataframe/events/columns.py +50 -0
- inspect_ai/analysis/beta/_dataframe/events/extract.py +26 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +77 -3
- inspect_ai/analysis/beta/_dataframe/extract.py +44 -25
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +1 -1
- inspect_ai/analysis/beta/_dataframe/messages/table.py +30 -29
- inspect_ai/analysis/beta/_dataframe/progress.py +56 -0
- inspect_ai/analysis/beta/_dataframe/record.py +13 -9
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +8 -4
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +5 -33
- inspect_ai/analysis/beta/_dataframe/samples/table.py +211 -60
- inspect_ai/analysis/beta/_dataframe/util.py +33 -28
- inspect_ai/log/_file.py +9 -2
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_providers/anthropic.py +18 -5
- inspect_ai/model/_providers/azureai.py +7 -2
- inspect_ai/model/_providers/util/llama31.py +3 -3
- inspect_ai/solver/_task_state.py +1 -1
- inspect_ai/tool/_mcp/_sandbox.py +17 -14
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/RECORD +140 -133
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/WHEEL +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.module.css +0 -48
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +0 -276
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.95.dist-info → inspect_ai-0.3.97.dist-info}/top_level.txt +0 -0
inspect_ai/agent/_react.py
CHANGED
@@ -22,6 +22,7 @@ from ._agent import Agent, AgentState, agent, agent_with
 from ._filter import MessageFilter
 from ._handoff import has_handoff
 from ._types import (
+    DEFAULT_CONTINUE_PROMOT_NO_SUBMIT,
     DEFAULT_CONTINUE_PROMPT,
     AgentAttempts,
     AgentContinue,
@@ -41,7 +42,7 @@ def react(
     tools: Sequence[Tool | ToolDef | ToolSource] | None = None,
     model: str | Model | Agent | None = None,
     attempts: int | AgentAttempts = 1,
-    submit: AgentSubmit =
+    submit: AgentSubmit | bool | None = None,
     on_continue: str | AgentContinue | None = None,
     truncation: Literal["auto", "disabled"] | MessageFilter = "disabled",
 ) -> Agent:
@@ -73,14 +74,16 @@ def react(
         tools: Tools available for the agent.
         model: Model to use for agent (defaults to currently evaluated model).
         attempts: Configure agent to make multiple attempts.
-        submit:
-
-
-
-
-
-
-            that
+        submit: Use a submit tool for reporting the final answer. Defaults to `True`
+            which uses the default submit behavior. Pass an `AgentSubmit` to
+            customize the behavior or pass `False` to disable the submit tool.
+        on_continue: Message to play back to the model to urge it to continue
+            when it stops calling tools. Use the placeholder {submit} to refer to
+            the submit tool within the message. Alternatively, an async function
+            to call to determine whether the loop should continue and what message
+            to play back. Note that this function is called on _every_ iteration of
+            the loop so if you only want to send a message back when the model fails
+            to call tools you need to code that behavior explicitly.
         truncation: Truncate the conversation history in the event of a context
             window overflow. Defaults to "disabled" which does no truncation. Pass
            "auto" to use `trim_messages()` to reduce the context size. Pass a
@@ -89,6 +92,29 @@ def react(
     Returns:
         ReAct agent.
     """
+    # if there is no submit tool then delegate to react_no_submit
+    if submit is False:
+        # if the user passes a `str` for on_continue this won't do anything
+        if isinstance(on_continue, str):
+            raise ValueError(
+                "Passing a string to on_continue with no submit tool is not permitted, "
+                + "because in this case the agent will always terminate when no tool "
+                + "calls are made."
+            )
+
+        return react_no_submit(
+            name=name,
+            description=description,
+            prompt=prompt,
+            tools=tools,
+            model=model,
+            on_continue=on_continue,
+            truncation=truncation,
+        )
+
+    # if submit is True or None then use default AgentSubmit
+    if submit is True or submit is None:
+        submit = AgentSubmit()

     # default submit tool
     @tool(name="submit")
@@ -115,19 +141,7 @@ def react(
     tools.append(submit_tool)

     # resolve prompt / system message
-
-    if prompt:
-        prompt_lines: list[str] = []
-        if prompt.instructions:
-            prompt_lines.append(prompt.instructions)
-        if prompt.handoff_prompt and has_handoff(tools):
-            prompt_lines.append(prompt.handoff_prompt)
-        if prompt.assistant_prompt:
-            prompt_lines.append(prompt.assistant_prompt)
-        prompt_content = "\n\n".join(prompt_lines).format(submit=submit_tool.name)
-        system_message: ChatMessage | None = ChatMessageSystem(content=prompt_content)
-    else:
-        system_message = None
+    system_message = _prompt_to_system_message(prompt, tools, submit_tool.name)

     # resolve attempts
     attempts = AgentAttempts(attempts) if isinstance(attempts, int) else attempts
@@ -150,12 +164,7 @@ def react(
                 state.messages.insert(0, system_message)

             # resolve overflow handling
-
-                overflow = cast(MessageFilter | None, trim_messages)
-            elif truncation == "disabled":
-                overflow = None
-            else:
-                overflow = truncation
+            overflow = _resolve_overflow(truncation)

             # track attempts
             attempt_count = 0
@@ -168,20 +177,11 @@ def react(

                 # check for context window overflow
                 if state.output.stop_reason == "model_length":
-
-
-
-
-
-                    if len(state.messages) < len(previous_messages):
-                        transcript().info(
-                            "Agent exceeded model context window, truncating messages and continuing."
-                        )
-                        continue
-
-                    # no overflow policy or overflow didn't reduce conversation length
-                    transcript().info("Agent terminated: model context window exceeded")
-                    break
+                    state, handled = await _handle_overflow(state, overflow)
+                    if handled:
+                        continue
+                    else:
+                        break

                 # resolve tool calls (if any)
                 if state.output.message.tool_calls:
@@ -233,9 +233,7 @@ def react(

                 # call the on_continue hook (if any)
                 if callable(on_continue):
-
-                        raise ValueError("The on_continue function must be async.")
-                    do_continue = await cast(AgentContinue, on_continue)(state)
+                    do_continue = await _call_on_continue(on_continue, state)
                     if do_continue is True:
                         # if there were no tool calls we need to send back a user message
                         if not state.output.message.tool_calls:
@@ -274,10 +272,133 @@ def react(
         state.messages = _remove_submit_tool(state.messages, submit_tool.name)
         return state

-
-
+    return _resolve_agent(execute, name, description)
+
+
+def react_no_submit(
+    *,
+    name: str | None,
+    description: str | None,
+    prompt: str | AgentPrompt | None,
+    tools: Sequence[Tool | ToolDef | ToolSource] | None,
+    model: str | Model | Agent | None,
+    on_continue: AgentContinue | None,
+    truncation: Literal["auto", "disabled"] | MessageFilter,
+) -> Agent:
+    # resolve tools
+    tools = list(tools) if tools is not None else []
+
+    # resolve prompt / system message
+    system_message = _prompt_to_system_message(prompt, tools, None)
+
+    async def execute(state: AgentState) -> AgentState:
+        async with mcp_connection(tools):
+            # prepend system message if we have one
+            if system_message:
+                state.messages.insert(0, system_message)
+
+            # resolve overflow handling
+            overflow = _resolve_overflow(truncation)
+
+            # main loop
+            while True:
+                # generate output and append assistant message
+                state = await _agent_generate(model, state, tools)
+
+                # check for context window overflow
+                if state.output.stop_reason == "model_length":
+                    state, handled = await _handle_overflow(state, overflow)
+                    if handled:
+                        continue
+                    else:
+                        break
+
+                # resolve tool calls (if any)
+                if state.output.message.tool_calls:
+                    # call tool functions
+                    messages, output = await execute_tools(state.messages, tools)
+                    state.messages.extend(messages)
+                    if output:
+                        state.output = output
+
+                # call the on_continue hook (if any)
+                if on_continue:
+                    do_continue = await _call_on_continue(on_continue, state)
+                    if do_continue is True:
+                        do_continue = DEFAULT_CONTINUE_PROMOT_NO_SUBMIT
+                    if do_continue:
+                        state.messages.append(ChatMessageUser(content=do_continue))
+                    else:
+                        break
+                elif not state.output.message.tool_calls:
+                    break
+
+        return state
+
+    return _resolve_agent(execute, name, description)
+
+
+def _prompt_to_system_message(
+    prompt: str | AgentPrompt | None,
+    tools: list[Tool | ToolDef | ToolSource],
+    submit_tool: str | None,
+) -> ChatMessage | None:
+    prompt = AgentPrompt(prompt) if isinstance(prompt, str) else prompt
+    if prompt:
+        prompt_lines: list[str] = []
+        if prompt.instructions:
+            prompt_lines.append(prompt.instructions)
+        if prompt.handoff_prompt and has_handoff(tools):
+            prompt_lines.append(prompt.handoff_prompt)
+        if prompt.assistant_prompt:
+            if (
+                submit_tool
+                and ("{submit}" not in prompt.assistant_prompt)
+                and prompt.submit_prompt
+            ):
+                assistant_prompt = f"{prompt.assistant_prompt}\n{prompt.submit_prompt}"
+            else:
+                assistant_prompt = prompt.assistant_prompt
+            prompt_lines.append(assistant_prompt)
+        prompt_content = "\n\n".join(prompt_lines).format(
+            submit=submit_tool or "submit"
+        )
+        system_message: ChatMessage | None = ChatMessageSystem(content=prompt_content)
     else:
-
+        system_message = None
+    return system_message
+
+
+def _resolve_overflow(
+    truncation: Literal["auto", "disabled"] | MessageFilter,
+) -> MessageFilter | None:
+    # resolve overflow handling
+    if truncation == "auto":
+        overflow = cast(MessageFilter | None, trim_messages)
+    elif truncation == "disabled":
+        overflow = None
+    else:
+        overflow = truncation
+    return overflow
+
+
+async def _handle_overflow(
+    state: AgentState, overflow: MessageFilter | None
+) -> tuple[AgentState, bool]:
+    from inspect_ai.log._transcript import transcript
+
+    if overflow is not None:
+        previous_messages = state.messages[:-1]
+        state.messages = await overflow(previous_messages)
+        if len(state.messages) < len(previous_messages):
+            transcript().info(
+                "Agent exceeded model context window, truncating messages and continuing."
+            )
+            return state, True
+
+    # no overflow policy or overflow didn't reduce conversation length
+    transcript().info("Agent terminated: model context window exceeded")
+    return state, False


 async def _agent_generate(
@@ -319,6 +440,21 @@ def _model_generate(model: str | Model | None) -> Agent:
     return generate


+async def _call_on_continue(
+    on_continue: AgentContinue, state: AgentState
+) -> str | bool:
+    if not is_callable_coroutine(on_continue):
+        raise ValueError("The on_continue function must be async.")
+    return await on_continue(state)
+
+
+def _resolve_agent(agent: Agent, name: str | None, description: str | None) -> Agent:
+    if name is not None or description is not None:
+        return agent_with(agent, name=name, description=description)
+    else:
+        return agent
+
+
 def _remove_submit_tool(
     messages: list[ChatMessage], submit_name: str
 ) -> list[ChatMessage]:
inspect_ai/agent/_types.py
CHANGED
@@ -22,7 +22,9 @@ see the result of tool calls right after sending the message. If you need
 to perform multiple actions, you can always send more messages with additional
 tool calls. Do some reasoning before your actions, describing what tool calls
 you are going to use and how they fit into your plan.
+"""

+DEFAULT_SUBMIT_PROMPT = """
 When you have completed the task and have an answer, call the {submit}()
 tool to report it.
 """
@@ -38,12 +40,23 @@ class AgentPrompt(NamedTuple):
     """Prompt used when there are additional handoff agents active."""

     assistant_prompt: str | None = DEFAULT_ASSISTANT_PROMPT
-    """Prompt for assistant (covers tool use,
+    """Prompt for assistant (covers tool use, CoT, etc.)."""
+
+    submit_prompt: str | None = DEFAULT_SUBMIT_PROMPT
+    """Prompt to tell the model about the submit tool.
+
+    This prompt is not used if the `assistant_prompt` contains a
+    {submit} placeholder.
+    """


 DEFAULT_CONTINUE_PROMPT = """
 Please proceed to the next step using your best judgement. If you believe you
-have completed the task, please call the `{submit}()` tool.
+have completed the task, please call the `{submit}()` tool with your final answer.
+"""
+
+DEFAULT_CONTINUE_PROMOT_NO_SUBMIT = """
+Please proceed to the next step using your best judgement.
 """

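`AgentPrompt` now carries a `submit_prompt` that is appended to `assistant_prompt` only when the latter has no `{submit}` placeholder. A short sketch of customizing it, assuming the class shape shown above; the prompt strings are placeholders:

```python
from inspect_ai.agent import AgentPrompt, react

# assistant_prompt has no {submit} placeholder, so submit_prompt would be
# appended to it when the submit tool is enabled (per the docstring above)
prompt = AgentPrompt(
    instructions="Solve the task step by step.",
    assistant_prompt="Reason briefly before each tool call.",
    submit_prompt="When finished, call the {submit}() tool with your answer.",
)

agent = react(prompt=prompt)
```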
inspect_ai/analysis/beta/__init__.py
CHANGED
@@ -1,7 +1,6 @@
 from ._dataframe.columns import (
     Column,
     ColumnError,
-    ColumnErrors,
     ColumnType,
 )
 from ._dataframe.evals.columns import (
@@ -15,7 +14,13 @@ from ._dataframe.evals.columns import (
     EvalTask,
 )
 from ._dataframe.evals.table import evals_df
-from ._dataframe.events.columns import
+from ._dataframe.events.columns import (
+    EventColumn,
+    EventInfo,
+    EventTiming,
+    ModelEventColumns,
+    ToolEventColumns,
+)
 from ._dataframe.events.table import events_df
 from ._dataframe.messages.columns import (
     MessageColumn,
@@ -50,8 +55,11 @@ __all__ = [
     "MessageFilter",
     "events_df",
     "EventColumn",
+    "EventInfo",
+    "EventTiming",
+    "ModelEventColumns",
+    "ToolEventColumns",
     "Column",
     "ColumnType",
     "ColumnError",
-    "ColumnErrors",
 ]
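The package now exports the event column groups alongside `events_df`. A hedged sketch of importing them; it assumes `events_df()` accepts a `columns` sequence the same way `evals_df()` does later in this diff, which this hunk does not show:

```python
from inspect_ai.analysis.beta import EventTiming, ModelEventColumns, events_df

# model events with the model-specific column group (assumed signature)
model_events = events_df("./logs", columns=ModelEventColumns)

# timing columns only, for every event (assumed signature)
timings = events_df("./logs", columns=EventTiming)
```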
inspect_ai/analysis/beta/_dataframe/columns.py
CHANGED
@@ -7,6 +7,8 @@ from jsonpath_ng import JSONPath # type: ignore
 from jsonpath_ng.ext import parse  # type: ignore
 from pydantic import JsonValue

+from inspect_ai.log._log import EvalLog
+
 from .validate import jsonpath_in_schema

 ColumnType: TypeAlias = int | float | bool | str | date | time | datetime | None
@@ -122,24 +124,17 @@ class ColumnError:
     path: str | None
     """Path to select column value. """

-
-    """
+    error: Exception
+    """Underlying error."""
+
+    log: EvalLog
+    """Eval log where the error occurred.
+
+    Use log.location to determine the path where the log was read from.
+    """

     def __str__(self) -> str:
         msg = f"Error reading column '{self.column}'"
         if self.path:
             msg = f"{msg} from path '{self.path}'"
-        return f"{msg}: {self.
-
-
-class ColumnErrors(dict[str, list[ColumnError]]):
-    """Dictionary of column errors keyed by log file."""
-
-    def __str__(self) -> str:
-        lines: list[str] = [""]
-        for file, errors in self.items():
-            lines.append(file)
-            for error in errors:
-                lines.append(f"  - {error}")
-            lines.append("")
-        return "\n".join(lines)
+        return f"{msg}: {self.error} (log: {self.log.location})"
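With `ColumnErrors` removed, non-strict reads surface a flat sequence of `ColumnError` values, each carrying the `EvalLog` it came from. A sketch of consuming them, assuming the `evals_df()` signature shown in the next file; the `./logs` path is illustrative:

```python
from inspect_ai.analysis.beta import evals_df

# strict=False now returns (DataFrame, Sequence[ColumnError]) rather than raising
df, errors = evals_df("./logs", strict=False)
for err in errors:
    # err.log.location identifies the log file the failing column came from
    print(f"{err.column}: {err.error} ({err.log.location})")
```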
inspect_ai/analysis/beta/_dataframe/evals/table.py
CHANGED
@@ -1,14 +1,16 @@
 from __future__ import annotations

-from
+from logging import getLogger
+from typing import TYPE_CHECKING, Callable, Literal, Sequence, overload

-from inspect_ai.
-from inspect_ai._util.path import pretty_path
+from inspect_ai.analysis.beta._dataframe.progress import import_progress, no_progress
 from inspect_ai.log._file import (
+    list_eval_logs,
     read_eval_log,
 )
+from inspect_ai.log._log import EvalLog

-from ..columns import Column,
+from ..columns import Column, ColumnError, ColumnType
 from ..record import import_record, resolve_duplicate_columns
 from ..util import (
     LogPaths,
@@ -20,6 +22,8 @@ from ..util import (
 )
 from .columns import EvalColumns, EvalId

+logger = getLogger(__name__)
+
 if TYPE_CHECKING:
     import pandas as pd

@@ -29,41 +33,38 @@ EVAL_SUFFIX = "_eval"

 @overload
 def evals_df(
-    logs: LogPaths,
-    columns:
-    recursive: bool = True,
-    reverse: bool = False,
+    logs: LogPaths = list_eval_logs(),
+    columns: Sequence[Column] = EvalColumns,
     strict: Literal[True] = True,
+    quiet: bool = False,
 ) -> "pd.DataFrame": ...


 @overload
 def evals_df(
-    logs: LogPaths,
-    columns:
-    recursive: bool = True,
-    reverse: bool = False,
+    logs: LogPaths = list_eval_logs(),
+    columns: Sequence[Column] = EvalColumns,
     strict: Literal[False] = False,
-
+    quiet: bool = False,
+) -> tuple["pd.DataFrame", Sequence[ColumnError]]: ...


 def evals_df(
-    logs: LogPaths,
-    columns:
-    recursive: bool = True,
-    reverse: bool = False,
+    logs: LogPaths = list_eval_logs(),
+    columns: Sequence[Column] = EvalColumns,
     strict: bool = True,
-
+    quiet: bool = False,
+) -> "pd.DataFrame" | tuple["pd.DataFrame", Sequence[ColumnError]]:
     """Read a dataframe containing evals.

     Args:
         logs: One or more paths to log files or log directories.
+            Defaults to the contents of the currently active log directory
+            (e.g. ./logs or INSPECT_LOG_DIR).
         columns: Specification for what columns to read from log files.
-        recursive: Include recursive contents of directories (defaults to `True`)
-        reverse: Reverse the order of the dataframe (by default, items
-            are ordered from oldest to newest).
         strict: Raise import errors immediately. Defaults to `True`.
             If `False` then a tuple of `DataFrame` and errors is returned.
+        quiet: If `True`, do not show any output or progress. Defaults to `False`.

     Returns:
         For `strict`, a Pandas `DataFrame` with information for the specified logs.
@@ -73,48 +74,109 @@ def evals_df(
     verify_prerequisites()

     # resolve logs
-    log_paths = resolve_logs(logs
+    log_paths = resolve_logs(logs)
+
+    # establish progress
+    progress_cm = (
+        import_progress("reading logs", total=len(log_paths))
+        if not quiet
+        else no_progress()
+    )
+
+    with progress_cm as p:
+        if strict:
+            evals_table, _, _ = _read_evals_df(log_paths, columns, True, p.update)
+            return evals_table
+        else:
+            evals_table, _, all_errors, _ = _read_evals_df(
+                log_paths, columns, False, p.update
+            )
+            return evals_table, all_errors
+
+
+@overload
+def _read_evals_df(
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
+    strict: Literal[True],
+    progress: Callable[[], None],
+) -> tuple["pd.DataFrame", Sequence[EvalLog], int]: ...
+
+
+@overload
+def _read_evals_df(
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
+    strict: Literal[False],
+    progress: Callable[[], None],
+) -> tuple["pd.DataFrame", Sequence[EvalLog], Sequence[ColumnError], int]: ...
+
+
+def _read_evals_df(
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
+    strict: bool,
+    progress: Callable[[], None],
+) -> (
+    tuple["pd.DataFrame", Sequence[EvalLog], int]
+    | tuple["pd.DataFrame", Sequence[EvalLog], Sequence[ColumnError], int]
+):
+    verify_prerequisites()

     # resolve duplicate columns
     columns = resolve_duplicate_columns(columns)

     # accumulate errors for strict=False
-    all_errors =
+    all_errors: list[ColumnError] = []

     # ensure eval_id
-    ensure_eval_id(columns)
+    columns = ensure_eval_id(columns)

     # read logs
+    total_samples = 0
+    eval_ids: set[str] = set()
+    eval_logs: list[EvalLog] = []
     records: list[dict[str, ColumnType]] = []
-
-
-
-
-
-
-
-
+    for log_path in log_paths:
+        log = read_eval_log(log_path, header_only=True)
+        if strict:
+            record = import_record(log, log, columns, strict=True)
+        else:
+            record, errors = import_record(log, log, columns, strict=False)
+            all_errors.extend(errors)
+
+        # don't add duplicate ids
+        eval_id = str(record.get(EVAL_ID, ""))
+        if eval_id not in eval_ids:
+            eval_ids.add(eval_id)
+            eval_logs.append(log)
             records.append(record)
-
-
+            total_samples += (
+                len(log.eval.dataset.sample_ids)
+                if log.eval.dataset.sample_ids is not None
+                else (log.eval.dataset.samples or 100)
+            )
+        progress()

     # return table (+errors if strict=False)
     evals_table = records_to_pandas(records)
     evals_table = reorder_evals_df_columns(evals_table, columns)

     if strict:
-        return evals_table
+        return evals_table, eval_logs, total_samples
     else:
-        return evals_table, all_errors
+        return evals_table, eval_logs, all_errors, total_samples


-def ensure_eval_id(columns:
+def ensure_eval_id(columns: Sequence[Column]) -> Sequence[Column]:
     if not any([column.name == EVAL_ID for column in columns]):
-        columns
+        return list(columns) + EvalId
+    else:
+        return columns


 def reorder_evals_df_columns(
-    df: "pd.DataFrame", eval_columns:
+    df: "pd.DataFrame", eval_columns: Sequence[Column]
 ) -> "pd.DataFrame":
     actual_columns = list(df.columns)
     ordered_columns: list[str] = []
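A brief usage sketch of the updated `evals_df()` signature from the hunk above: `logs` now defaults to the active log directory, the `recursive`/`reverse` arguments are gone, and `quiet=True` suppresses the progress display; the directory name is illustrative:

```python
from inspect_ai.analysis.beta import evals_df

# read log headers from ./logs (or INSPECT_LOG_DIR) without progress output
df = evals_df(quiet=True)

# or point at a specific log directory (illustrative path)
df = evals_df("logs/my-eval-run", quiet=True)
```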
|