inspect-ai 0.3.96__py3-none-any.whl → 0.3.97__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- inspect_ai/_eval/eval.py +10 -2
- inspect_ai/_eval/task/util.py +32 -3
- inspect_ai/_util/registry.py +7 -0
- inspect_ai/_util/timer.py +13 -0
- inspect_ai/_view/www/dist/assets/index.css +275 -195
- inspect_ai/_view/www/dist/assets/index.js +8568 -7376
- inspect_ai/_view/www/src/app/App.css +1 -0
- inspect_ai/_view/www/src/app/App.tsx +27 -10
- inspect_ai/_view/www/src/app/appearance/icons.ts +5 -0
- inspect_ai/_view/www/src/app/content/RecordTree.module.css +22 -0
- inspect_ai/_view/www/src/app/content/RecordTree.tsx +370 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.module.css +5 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +32 -19
- inspect_ai/_view/www/src/app/content/record_processors/store.ts +101 -0
- inspect_ai/_view/www/src/app/content/record_processors/types.ts +3 -0
- inspect_ai/_view/www/src/app/content/types.ts +5 -0
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +1 -0
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +35 -28
- inspect_ai/_view/www/src/app/log-view/LogViewLayout.tsx +1 -8
- inspect_ai/_view/www/src/app/log-view/navbar/PrimaryBar.tsx +2 -4
- inspect_ai/_view/www/src/app/log-view/navbar/ResultsPanel.tsx +13 -3
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.module.css +15 -0
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.tsx +14 -10
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +9 -3
- inspect_ai/_view/www/src/app/log-view/tabs/JsonTab.tsx +1 -3
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +8 -2
- inspect_ai/_view/www/src/app/log-view/types.ts +1 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.module.css +7 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +5 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +13 -8
- inspect_ai/_view/www/src/app/routing/navigationHooks.ts +63 -8
- inspect_ai/_view/www/src/app/routing/url.ts +45 -0
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.tsx +15 -8
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +3 -0
- inspect_ai/_view/www/src/app/samples/SampleDialog.tsx +16 -5
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +68 -31
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.module.css +12 -7
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +17 -5
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.module.css +9 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +48 -18
- inspect_ai/_view/www/src/app/samples/chat/ChatView.tsx +0 -1
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.module.css +4 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +41 -1
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +7 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.module.css +0 -3
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolInput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolOutput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +5 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/PassFailScoreDescriptor.tsx +11 -6
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +7 -0
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +5 -18
- inspect_ai/_view/www/src/app/samples/sample-tools/SortFilter.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresGrid.tsx +18 -5
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.module.css +0 -6
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.tsx +4 -1
- inspect_ai/_view/www/src/app/samples/transcript/ApprovalEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +13 -6
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/LoggerEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +11 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +14 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +13 -8
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +25 -16
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +7 -5
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +11 -28
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +12 -20
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +12 -31
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +25 -29
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +297 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +0 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +43 -25
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +43 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +109 -43
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +19 -8
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +128 -60
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +14 -4
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +6 -4
- inspect_ai/_view/www/src/app/types.ts +12 -1
- inspect_ai/_view/www/src/components/Card.css +6 -3
- inspect_ai/_view/www/src/components/Card.tsx +15 -2
- inspect_ai/_view/www/src/components/CopyButton.tsx +4 -6
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +20 -14
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +17 -22
- inspect_ai/_view/www/src/components/LargeModal.tsx +5 -1
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +25 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.css +4 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +2 -2
- inspect_ai/_view/www/src/components/TabSet.module.css +6 -1
- inspect_ai/_view/www/src/components/TabSet.tsx +8 -2
- inspect_ai/_view/www/src/state/hooks.ts +83 -13
- inspect_ai/_view/www/src/state/logPolling.ts +2 -2
- inspect_ai/_view/www/src/state/logSlice.ts +1 -2
- inspect_ai/_view/www/src/state/logsSlice.ts +9 -9
- inspect_ai/_view/www/src/state/samplePolling.ts +1 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +134 -7
- inspect_ai/_view/www/src/state/scoring.ts +1 -1
- inspect_ai/_view/www/src/state/scrolling.ts +39 -6
- inspect_ai/_view/www/src/state/store.ts +5 -0
- inspect_ai/_view/www/src/state/store_filter.ts +47 -44
- inspect_ai/_view/www/src/utils/debugging.ts +95 -0
- inspect_ai/_view/www/src/utils/format.ts +2 -2
- inspect_ai/_view/www/src/utils/json.ts +29 -0
- inspect_ai/agent/__init__.py +2 -1
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_react.py +184 -48
- inspect_ai/agent/_types.py +14 -1
- inspect_ai/analysis/beta/__init__.py +0 -2
- inspect_ai/analysis/beta/_dataframe/columns.py +11 -16
- inspect_ai/analysis/beta/_dataframe/evals/table.py +65 -40
- inspect_ai/analysis/beta/_dataframe/events/table.py +24 -36
- inspect_ai/analysis/beta/_dataframe/messages/table.py +24 -15
- inspect_ai/analysis/beta/_dataframe/progress.py +35 -5
- inspect_ai/analysis/beta/_dataframe/record.py +13 -9
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +1 -1
- inspect_ai/analysis/beta/_dataframe/samples/table.py +156 -46
- inspect_ai/analysis/beta/_dataframe/util.py +14 -12
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_providers/anthropic.py +18 -5
- inspect_ai/model/_providers/azureai.py +7 -2
- inspect_ai/model/_providers/util/llama31.py +3 -3
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/RECORD +131 -126
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/WHEEL +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.module.css +0 -48
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +0 -276
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/top_level.txt +0 -0
inspect_ai/agent/_react.py
CHANGED
@@ -22,6 +22,7 @@ from ._agent import Agent, AgentState, agent, agent_with
 from ._filter import MessageFilter
 from ._handoff import has_handoff
 from ._types import (
+    DEFAULT_CONTINUE_PROMOT_NO_SUBMIT,
     DEFAULT_CONTINUE_PROMPT,
     AgentAttempts,
     AgentContinue,
@@ -41,7 +42,7 @@ def react(
     tools: Sequence[Tool | ToolDef | ToolSource] | None = None,
     model: str | Model | Agent | None = None,
     attempts: int | AgentAttempts = 1,
-    submit: AgentSubmit = AgentSubmit(),
+    submit: AgentSubmit | bool | None = None,
     on_continue: str | AgentContinue | None = None,
     truncation: Literal["auto", "disabled"] | MessageFilter = "disabled",
 ) -> Agent:
@@ -73,14 +74,16 @@ def react(
         tools: Tools available for the agent.
         model: Model to use for agent (defaults to currently evaluated model).
         attempts: Configure agent to make multiple attempts.
-        submit: Configure submit tool used by the agent.
-        on_continue: Message to play back to the model to urge it to continue when
-            it stops calling tools. Use the placeholder {submit} to refer to the
-            submit tool within the message. Alternatively, an async function to call
-            to determine whether the loop should continue and what message to play
-            back. Note that this function is called on _every_ iteration of the loop
-            so if you only want to send a message back when the model fails to call
-            tools you need to code that behavior explicitly.
+        submit: Use a submit tool for reporting the final answer. Defaults to `True`
+            which uses the default submit behavior. Pass an `AgentSubmit` to
+            customize the behavior or pass `False` to disable the submit tool.
+        on_continue: Message to play back to the model to urge it to continue
+            when it stops calling tools. Use the placeholder {submit} to refer to
+            the submit tool within the message. Alternatively, an async function
+            to call to determine whether the loop should continue and what message
+            to play back. Note that this function is called on _every_ iteration of
+            the loop so if you only want to send a message back when the model fails
+            to call tools you need to code that behavior explicitly.
         truncation: Truncate the conversation history in the event of a context
             window overflow. Defaults to "disabled" which does no truncation. Pass
             "auto" to use `trim_messages()` to reduce the context size. Pass a
@@ -89,6 +92,29 @@ def react(
     Returns:
         ReAct agent.
     """
+    # if there is no submit tool then delegate to react_no_submit
+    if submit is False:
+        # if the user passes a `str` for on_continue this won't do anything
+        if isinstance(on_continue, str):
+            raise ValueError(
+                "Passing a string to on_continue with no submit tool is not permitted, "
+                + "because in this case the agent will always terminate when no tool "
+                + "calls are made."
+            )
+
+        return react_no_submit(
+            name=name,
+            description=description,
+            prompt=prompt,
+            tools=tools,
+            model=model,
+            on_continue=on_continue,
+            truncation=truncation,
+        )
+
+    # if submit is True or None then use default AgentSubmit
+    if submit is True or submit is None:
+        submit = AgentSubmit()
 
     # default submit tool
     @tool(name="submit")
@@ -115,19 +141,7 @@ def react(
     tools.append(submit_tool)
 
     # resolve prompt / system message
-    prompt = AgentPrompt(prompt) if isinstance(prompt, str) else prompt
-    if prompt:
-        prompt_lines: list[str] = []
-        if prompt.instructions:
-            prompt_lines.append(prompt.instructions)
-        if prompt.handoff_prompt and has_handoff(tools):
-            prompt_lines.append(prompt.handoff_prompt)
-        if prompt.assistant_prompt:
-            prompt_lines.append(prompt.assistant_prompt)
-        prompt_content = "\n\n".join(prompt_lines).format(submit=submit_tool.name)
-        system_message: ChatMessage | None = ChatMessageSystem(content=prompt_content)
-    else:
-        system_message = None
+    system_message = _prompt_to_system_message(prompt, tools, submit_tool.name)
 
     # resolve attempts
     attempts = AgentAttempts(attempts) if isinstance(attempts, int) else attempts
@@ -150,12 +164,7 @@ def react(
                 state.messages.insert(0, system_message)
 
             # resolve overflow handling
-            if truncation == "auto":
-                overflow = cast(MessageFilter | None, trim_messages)
-            elif truncation == "disabled":
-                overflow = None
-            else:
-                overflow = truncation
+            overflow = _resolve_overflow(truncation)
 
             # track attempts
             attempt_count = 0
@@ -168,20 +177,11 @@ def react(
 
                 # check for context window overflow
                 if state.output.stop_reason == "model_length":
-                    from inspect_ai.log._transcript import transcript
-
-                    if overflow is not None:
-                        previous_messages = state.messages[:-1]
-                        state.messages = await overflow(previous_messages)
-                        if len(state.messages) < len(previous_messages):
-                            transcript().info(
-                                "Agent exceeded model context window, truncating messages and continuing."
-                            )
-                            continue
-
-                    # no overflow policy or overflow didn't reduce conversation length
-                    transcript().info("Agent terminated: model context window exceeded")
-                    break
+                    state, handled = await _handle_overflow(state, overflow)
+                    if handled:
+                        continue
+                    else:
+                        break
 
                 # resolve tool calls (if any)
                 if state.output.message.tool_calls:
@@ -233,9 +233,7 @@ def react(
 
                 # call the on_continue hook (if any)
                 if callable(on_continue):
-                    if not is_callable_coroutine(on_continue):
-                        raise ValueError("The on_continue function must be async.")
-                    do_continue = await cast(AgentContinue, on_continue)(state)
+                    do_continue = await _call_on_continue(on_continue, state)
                 if do_continue is True:
                     # if there were no tool calls we need to send back a user message
                     if not state.output.message.tool_calls:
@@ -274,10 +272,133 @@ def react(
             state.messages = _remove_submit_tool(state.messages, submit_tool.name)
             return state
 
-    if name is not None or description is not None:
-        return agent_with(execute, name=name, description=description)
+    return _resolve_agent(execute, name, description)
+
+
+def react_no_submit(
+    *,
+    name: str | None,
+    description: str | None,
+    prompt: str | AgentPrompt | None,
+    tools: Sequence[Tool | ToolDef | ToolSource] | None,
+    model: str | Model | Agent | None,
+    on_continue: AgentContinue | None,
+    truncation: Literal["auto", "disabled"] | MessageFilter,
+) -> Agent:
+    # resolve tools
+    tools = list(tools) if tools is not None else []
+
+    # resolve prompt / system message
+    system_message = _prompt_to_system_message(prompt, tools, None)
+
+    async def execute(state: AgentState) -> AgentState:
+        async with mcp_connection(tools):
+            # prepend system message if we have one
+            if system_message:
+                state.messages.insert(0, system_message)
+
+            # resolve overflow handling
+            overflow = _resolve_overflow(truncation)
+
+            # main loop
+            while True:
+                # generate output and append assistant message
+                state = await _agent_generate(model, state, tools)
+
+                # check for context window overflow
+                if state.output.stop_reason == "model_length":
+                    state, handled = await _handle_overflow(state, overflow)
+                    if handled:
+                        continue
+                    else:
+                        break
+
+                # resolve tool calls (if any)
+                if state.output.message.tool_calls:
+                    # call tool functions
+                    messages, output = await execute_tools(state.messages, tools)
+                    state.messages.extend(messages)
+                    if output:
+                        state.output = output
+
+                # call the on_continue hook (if any)
+                if on_continue:
+                    do_continue = await _call_on_continue(on_continue, state)
+                    if do_continue is True:
+                        do_continue = DEFAULT_CONTINUE_PROMOT_NO_SUBMIT
+                    if do_continue:
+                        state.messages.append(ChatMessageUser(content=do_continue))
+                    else:
+                        break
+                elif not state.output.message.tool_calls:
+                    break
+
+            return state
+
+    return _resolve_agent(execute, name, description)
+
+
+def _prompt_to_system_message(
+    prompt: str | AgentPrompt | None,
+    tools: list[Tool | ToolDef | ToolSource],
+    submit_tool: str | None,
+) -> ChatMessage | None:
+    prompt = AgentPrompt(prompt) if isinstance(prompt, str) else prompt
+    if prompt:
+        prompt_lines: list[str] = []
+        if prompt.instructions:
+            prompt_lines.append(prompt.instructions)
+        if prompt.handoff_prompt and has_handoff(tools):
+            prompt_lines.append(prompt.handoff_prompt)
+        if prompt.assistant_prompt:
+            if (
+                submit_tool
+                and ("{submit}" not in prompt.assistant_prompt)
+                and prompt.submit_prompt
+            ):
+                assistant_prompt = f"{prompt.assistant_prompt}\n{prompt.submit_prompt}"
+            else:
+                assistant_prompt = prompt.assistant_prompt
+            prompt_lines.append(assistant_prompt)
+        prompt_content = "\n\n".join(prompt_lines).format(
+            submit=submit_tool or "submit"
+        )
+        system_message: ChatMessage | None = ChatMessageSystem(content=prompt_content)
     else:
-        return execute
+        system_message = None
+    return system_message
+
+
+def _resolve_overflow(
+    truncation: Literal["auto", "disabled"] | MessageFilter,
+) -> MessageFilter | None:
+    # resolve overflow handling
+    if truncation == "auto":
+        overflow = cast(MessageFilter | None, trim_messages)
+    elif truncation == "disabled":
+        overflow = None
+    else:
+        overflow = truncation
+    return overflow
+
+
+async def _handle_overflow(
+    state: AgentState, overflow: MessageFilter | None
+) -> tuple[AgentState, bool]:
+    from inspect_ai.log._transcript import transcript
+
+    if overflow is not None:
+        previous_messages = state.messages[:-1]
+        state.messages = await overflow(previous_messages)
+        if len(state.messages) < len(previous_messages):
+            transcript().info(
+                "Agent exceeded model context window, truncating messages and continuing."
+            )
            return state, True
+
+    # no overflow policy or overflow didn't reduce conversation length
+    transcript().info("Agent terminated: model context window exceeded")
+    return state, False
 
 
 async def _agent_generate(
@@ -319,6 +440,21 @@ def _model_generate(model: str | Model | None) -> Agent:
     return generate
 
 
+async def _call_on_continue(
+    on_continue: AgentContinue, state: AgentState
+) -> str | bool:
+    if not is_callable_coroutine(on_continue):
+        raise ValueError("The on_continue function must be async.")
+    return await on_continue(state)
+
+
+def _resolve_agent(agent: Agent, name: str | None, description: str | None) -> Agent:
+    if name is not None or description is not None:
+        return agent_with(agent, name=name, description=description)
+    else:
+        return agent
+
+
 def _remove_submit_tool(
     messages: list[ChatMessage], submit_name: str
 ) -> list[ChatMessage]:
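Taken together, `react()` now accepts `submit: AgentSubmit | bool | None` and delegates to `react_no_submit()` when the submit tool is disabled. A minimal sketch of using the new surface, based only on the signature above (the agent name, turn limit, and prompt text are illustrative, not part of the API):

```python
from inspect_ai.agent import AgentState, react

async def keep_going(state: AgentState) -> bool | str:
    # on_continue runs on every iteration: True continues (with the default
    # continue message), False stops, a str continues with that message
    assistant_turns = sum(1 for m in state.messages if m.role == "assistant")
    return assistant_turns < 20

researcher = react(
    name="researcher",
    description="Researches a topic and reports back.",
    prompt="You are a research assistant.",
    submit=False,            # new in 0.3.97: run the ReAct loop without a submit tool
    on_continue=keep_going,  # must be async; a plain str raises ValueError here
)
```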
inspect_ai/agent/_types.py
CHANGED
@@ -22,7 +22,9 @@ see the result of tool calls right after sending the message. If you need
 to perform multiple actions, you can always send more messages with additional
 tool calls. Do some reasoning before your actions, describing what tool calls
 you are going to use and how they fit into your plan.
+"""
 
+DEFAULT_SUBMIT_PROMPT = """
 When you have completed the task and have an answer, call the {submit}()
 tool to report it.
 """
@@ -38,7 +40,14 @@ class AgentPrompt(NamedTuple):
     """Prompt used when there are additional handoff agents active."""
 
     assistant_prompt: str | None = DEFAULT_ASSISTANT_PROMPT
-    """Prompt for assistant (covers tool use, submit tool, CoT, etc.)."""
+    """Prompt for assistant (covers tool use, CoT, etc.)."""
+
+    submit_prompt: str | None = DEFAULT_SUBMIT_PROMPT
+    """Prompt to tell the model about the submit tool.
+
+    This prompt is not used if the `assistant_prompt` contains a
+    {submit} placeholder.
+    """
 
 
 DEFAULT_CONTINUE_PROMPT = """
@@ -46,6 +55,10 @@ Please proceed to the next step using your best judgement. If you believe you
 have completed the task, please call the `{submit}()` tool with your final answer.
 """
 
+DEFAULT_CONTINUE_PROMOT_NO_SUBMIT = """
+Please proceed to the next step using your best judgement.
+"""
+
 
 AgentContinue: TypeAlias = Callable[[AgentState], Awaitable[bool | str]]
 """Function called to determine whether the agent should continue.
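With the new `submit_prompt` field, submit-tool instructions no longer live inside `assistant_prompt`. A sketch of customizing it via the `AgentPrompt` export (the prompt strings below are illustrative):

```python
from inspect_ai.agent import AgentPrompt, react

custom_prompt = AgentPrompt(
    instructions="Solve the task step by step.",
    # appended after assistant_prompt when a submit tool is active and
    # assistant_prompt itself contains no {submit} placeholder
    submit_prompt="When you are done, report the answer with the {submit}() tool.",
)

solver = react(prompt=custom_prompt)
```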
inspect_ai/analysis/beta/_dataframe/columns.py
CHANGED
@@ -7,6 +7,8 @@ from jsonpath_ng import JSONPath  # type: ignore
 from jsonpath_ng.ext import parse  # type: ignore
 from pydantic import JsonValue
 
+from inspect_ai.log._log import EvalLog
+
 from .validate import jsonpath_in_schema
 
 ColumnType: TypeAlias = int | float | bool | str | date | time | datetime | None
@@ -122,24 +124,17 @@
     path: str | None
     """Path to select column value. """
 
-    error: Exception
-    """Underlying error."""
+    error: Exception
+    """Underlying error."""
+
+    log: EvalLog
+    """Eval log where the error occurred.
+
+    Use log.location to determine the path where the log was read from.
+    """
 
     def __str__(self) -> str:
         msg = f"Error reading column '{self.column}'"
         if self.path:
             msg = f"{msg} from path '{self.path}'"
-        return f"{msg}: {self.error}"
-
-
-class ColumnErrors(dict[str, list[ColumnError]]):
-    """Dictionary of column errors keyed by log file."""
-
-    def __str__(self) -> str:
-        lines: list[str] = [""]
-        for file, errors in self.items():
-            lines.append(file)
-            for error in errors:
-                lines.append(f"  - {error}")
-            lines.append("")
-        return "\n".join(lines)
+        return f"{msg}: {self.error} (log: {self.log.location})"
inspect_ai/analysis/beta/_dataframe/evals/table.py
CHANGED
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, Literal, overload
+from logging import getLogger
+from typing import TYPE_CHECKING, Callable, Literal, Sequence, overload
 
-from inspect_ai.
-from inspect_ai.analysis.beta._dataframe.progress import import_progress
+from inspect_ai.analysis.beta._dataframe.progress import import_progress, no_progress
 from inspect_ai.log._file import (
     list_eval_logs,
     read_eval_log,
 )
+from inspect_ai.log._log import EvalLog
 
-from ..columns import Column, ColumnErrors, ColumnType
+from ..columns import Column, ColumnError, ColumnType
 from ..record import import_record, resolve_duplicate_columns
 from ..util import (
     LogPaths,
@@ -21,6 +22,8 @@ from ..util import (
 )
 from .columns import EvalColumns, EvalId
 
+logger = getLogger(__name__)
+
 if TYPE_CHECKING:
     import pandas as pd
 
@@ -31,24 +34,27 @@ EVAL_SUFFIX = "_eval"
 @overload
 def evals_df(
     logs: LogPaths = list_eval_logs(),
-    columns: list[Column] = EvalColumns,
+    columns: Sequence[Column] = EvalColumns,
     strict: Literal[True] = True,
+    quiet: bool = False,
 ) -> "pd.DataFrame": ...
 
 
 @overload
 def evals_df(
     logs: LogPaths = list_eval_logs(),
-    columns: list[Column] = EvalColumns,
+    columns: Sequence[Column] = EvalColumns,
     strict: Literal[False] = False,
-) -> tuple["pd.DataFrame", ColumnErrors]: ...
+    quiet: bool = False,
+) -> tuple["pd.DataFrame", Sequence[ColumnError]]: ...
 
 
 def evals_df(
     logs: LogPaths = list_eval_logs(),
-    columns: list[Column] = EvalColumns,
+    columns: Sequence[Column] = EvalColumns,
     strict: bool = True,
-) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
+    quiet: bool = False,
+) -> "pd.DataFrame" | tuple["pd.DataFrame", Sequence[ColumnError]]:
     """Read a dataframe containing evals.
 
     Args:
@@ -58,6 +64,7 @@ def evals_df(
         columns: Specification for what columns to read from log files.
         strict: Raise import errors immediately. Defaults to `True`.
            If `False` then a tuple of `DataFrame` and errors is returned.
+        quiet: If `True`, do not show any output or progress. Defaults to `False`.
 
     Returns:
         For `strict`, a Pandas `DataFrame` with information for the specified logs.
@@ -69,70 +76,86 @@ def evals_df(
     # resolve logs
     log_paths = resolve_logs(logs)
 
-    with import_progress("reading logs", total=len(log_paths)) as (p, task_id):
+    # establish progress
+    progress_cm = (
+        import_progress("reading logs", total=len(log_paths))
+        if not quiet
+        else no_progress()
+    )
+
+    with progress_cm as p:
         if strict:
-            evals_table, _ = _read_evals_df(
-                log_paths, columns, True, lambda: p.update(task_id, advance=1)
-            )
+            evals_table, _, _ = _read_evals_df(log_paths, columns, True, p.update)
             return evals_table
         else:
-            evals_table, all_errors, _ = _read_evals_df(
-                log_paths, columns, False, lambda: p.update(task_id, advance=1)
+            evals_table, _, all_errors, _ = _read_evals_df(
+                log_paths, columns, False, p.update
             )
             return evals_table, all_errors
 
 
 @overload
 def _read_evals_df(
-    log_paths: list[str],
-    columns: list[Column],
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
     strict: Literal[True],
     progress: Callable[[], None],
-) -> tuple["pd.DataFrame", int]: ...
+) -> tuple["pd.DataFrame", Sequence[EvalLog], int]: ...
 
 
 @overload
 def _read_evals_df(
-    log_paths: list[str],
-    columns: list[Column],
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
     strict: Literal[False],
     progress: Callable[[], None],
-) -> tuple["pd.DataFrame", ColumnErrors, int]: ...
+) -> tuple["pd.DataFrame", Sequence[EvalLog], Sequence[ColumnError], int]: ...
 
 
 def _read_evals_df(
-    log_paths: list[str],
-    columns: list[Column],
+    log_paths: Sequence[str],
+    columns: Sequence[Column],
     strict: bool,
     progress: Callable[[], None],
-) -> tuple["pd.DataFrame", int] | tuple["pd.DataFrame", ColumnErrors, int]:
+) -> (
+    tuple["pd.DataFrame", Sequence[EvalLog], int]
+    | tuple["pd.DataFrame", Sequence[EvalLog], Sequence[ColumnError], int]
+):
     verify_prerequisites()
 
     # resolve duplicate columns
     columns = resolve_duplicate_columns(columns)
 
     # accumulate errors for strict=False
-    all_errors = ColumnErrors()
+    all_errors: list[ColumnError] = []
 
     # ensure eval_id
-    ensure_eval_id(columns)
+    columns = ensure_eval_id(columns)
 
     # read logs
     total_samples = 0
+    eval_ids: set[str] = set()
+    eval_logs: list[EvalLog] = []
     records: list[dict[str, ColumnType]] = []
     for log_path in log_paths:
         log = read_eval_log(log_path, header_only=True)
         if strict:
-            record = import_record(log, columns, strict=True)
+            record = import_record(log, log, columns, strict=True)
         else:
-            record, errors = import_record(log, columns, strict=False)
-            all_errors
-        records.append(record)
-        total_samples += (
-            len(log.eval.dataset.sample_ids)
-            if log.eval.dataset.sample_ids is not None
-            else (log.eval.dataset.samples or 100)
-        )
+            record, errors = import_record(log, log, columns, strict=False)
+            all_errors.extend(errors)
+
+        # don't add duplicate ids
+        eval_id = str(record.get(EVAL_ID, ""))
+        if eval_id not in eval_ids:
+            eval_ids.add(eval_id)
+            eval_logs.append(log)
+            records.append(record)
+            total_samples += (
+                len(log.eval.dataset.sample_ids)
+                if log.eval.dataset.sample_ids is not None
+                else (log.eval.dataset.samples or 100)
+            )
         progress()
 
     # return table (+errors if strict=False)
@@ -140,18 +163,20 @@ def _read_evals_df(
     evals_table = reorder_evals_df_columns(evals_table, columns)
 
     if strict:
-        return evals_table, total_samples
+        return evals_table, eval_logs, total_samples
     else:
-        return evals_table, all_errors, total_samples
+        return evals_table, eval_logs, all_errors, total_samples
 
 
-def ensure_eval_id(columns: list[Column]) -> None:
+def ensure_eval_id(columns: Sequence[Column]) -> Sequence[Column]:
     if not any([column.name == EVAL_ID for column in columns]):
-        columns.extend(EvalId)
+        return list(columns) + EvalId
+    else:
+        return columns
 
 
 def reorder_evals_df_columns(
-    df: "pd.DataFrame", eval_columns: list[Column]
+    df: "pd.DataFrame", eval_columns: Sequence[Column]
 ) -> "pd.DataFrame":
     actual_columns = list(df.columns)
     ordered_columns: list[str] = []