inspect-ai 0.3.82__py3-none-any.whl → 0.3.84__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +3 -4
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +48 -0
- inspect_ai/_eval/eval.py +36 -24
- inspect_ai/_eval/evalset.py +17 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +8 -13
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +2 -0
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/file.py +13 -0
- inspect_ai/_util/json.py +2 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/www/App.css +6 -0
- inspect_ai/_view/www/dist/assets/index.css +304 -128
- inspect_ai/_view/www/dist/assets/index.js +47495 -27519
- inspect_ai/_view/www/log-schema.json +124 -31
- inspect_ai/_view/www/package.json +3 -0
- inspect_ai/_view/www/src/App.tsx +12 -0
- inspect_ai/_view/www/src/appearance/icons.ts +1 -0
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
- inspect_ai/_view/www/src/state/hooks.ts +5 -3
- inspect_ai/_view/www/src/state/logPolling.ts +5 -1
- inspect_ai/_view/www/src/state/logSlice.ts +10 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
- inspect_ai/_view/www/src/types/log.d.ts +34 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
- inspect_ai/_view/www/yarn.lock +94 -1
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_log.py +11 -2
- inspect_ai/log/_transcript.py +13 -9
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +256 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_model.py +113 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +2 -2
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/RECORD +179 -153
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/top_level.txt +0 -0
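The headline change in this release is the new inspect_ai/agent package (Agent, as_solver, as_tool, handoff, react, run), with the bridge and human-agent solvers relocated under it. A minimal sketch of the decorator-based agent API these files imply — the body of execute is illustrative, not taken from the diff:

    from inspect_ai.agent import Agent, AgentState, agent
    from inspect_ai.model import get_model

    @agent
    def critic() -> Agent:
        async def execute(state: AgentState) -> AgentState:
            # append a model turn to the agent's conversation state
            state.output = await get_model().generate(state.messages)
            state.messages.append(state.output.message)
            return state

        return execute

Registering via @agent is what lets the loader and eval() changes further down resolve agents anywhere a solver is expected.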
inspect_ai/__init__.py
CHANGED
@@ -10,7 +10,8 @@ from inspect_ai._eval.score import score, score_async
 from inspect_ai._eval.task import Epochs, Task, TaskInfo, task_with
 from inspect_ai._eval.task.tasks import Tasks
 from inspect_ai._util.constants import PKG_NAME
-from inspect_ai.
+from inspect_ai.agent._human.agent import human_cli
+from inspect_ai.solver._human_agent import human_agent
 
 __version__ = importlib_version(PKG_NAME)
 
inspect_ai/_display/textual/app.py
CHANGED
@@ -58,10 +58,12 @@ class TaskScreenResult(Generic[TR]):
         value: TR | BaseException,
         tasks: list[TaskWithResult],
         output: list[str],
+        warnings: list[str],
     ) -> None:
         self.value = value
         self.tasks = tasks
         self.output = output
+        self.warnings = warnings
 
 
 class TaskScreenApp(App[TR]):
@@ -86,6 +88,7 @@ class TaskScreenApp(App[TR]):
         self._worker: Worker[TR] | None = None
         self._error: BaseException | None = None
         self._output: list[str] = []
+        self._warnings: list[str] = []
 
         # task screen
         self._total_tasks = 0
@@ -120,7 +123,12 @@ class TaskScreenApp(App[TR]):
             value = CancelledError()
 
         # return result w/ output
-        return TaskScreenResult(
+        return TaskScreenResult(
+            value=value,
+            tasks=self._app_tasks,
+            output=self._output,
+            warnings=self._warnings,
+        )
 
     async def on_load(self) -> None:
         # events used to synchronise loading
@@ -349,8 +357,11 @@ class TaskScreenApp(App[TR]):
         if text.endswith("\n"):
             text = text[:-1]
 
-        # track output (for printing at the end)
-
+        # track output and warnings (for printing at the end)
+        if "WARNING" in text:
+            self._warnings.append(text)
+        else:
+            self._output.append(text)
 
         # write to console view
         self.query_one(ConsoleView).write_ansi(text)
inspect_ai/_display/textual/display.py
CHANGED
@@ -42,6 +42,10 @@ class TextualDisplay(Display):
         # print tasks
         rich.print(tasks_results(result.tasks))
 
+        # print warnings
+        if result.warnings:
+            print("\n".join(result.warnings))
+
         # raise error as required
         if isinstance(result.value, BaseException):
             raise result.value
inspect_ai/_display/textual/widgets/samples.py
CHANGED
@@ -17,7 +17,7 @@ from textual.widgets import (
     OptionList,
     Static,
 )
-from textual.widgets.option_list import Option
+from textual.widgets.option_list import Option, OptionDoesNotExist
 
 from inspect_ai._display.textual.widgets.port_mappings import get_url
 from inspect_ai._util.format import format_progress_time
@@ -124,7 +124,7 @@ class SamplesList(OptionList):
     def set_samples(self, samples: list[ActiveSample]) -> None:
         # check for a highlighted sample (make sure we don't remove it)
         highlighted_id = (
-            self.
+            self.get_id_at_index(self.highlighted)
             if self.highlighted is not None
             else None
         )
@@ -179,12 +179,18 @@ class SamplesList(OptionList):
         self.scroll_to_highlight()
 
     def sample_for_highlighted(self, highlighted: int) -> ActiveSample | None:
-        highlighted_id = self.
+        highlighted_id = self.get_id_at_index(highlighted)
         if highlighted_id is not None:
             return sample_for_id(self.samples, highlighted_id)
         else:
             return None
 
+    def get_id_at_index(self, index: int) -> str | None:
+        try:
+            return self.get_option_at_index(index).id
+        except OptionDoesNotExist:
+            return None
+
 
 class SampleVNC(Horizontal):
     DEFAULT_CSS = """
inspect_ai/_display/textual/widgets/task_detail.py
CHANGED
@@ -221,12 +221,11 @@ class TaskMetrics(Widget):
         self.recompute_grid()
 
     def on_mount(self) -> None:
-        self.recompute_grid()
+        self.recompute_grid(True)
 
-    def recompute_grid(self) -> None:
-        if not self.is_mounted:
+    def recompute_grid(self, force: bool = False) -> None:
+        if not self.is_mounted and not force:
             return
-
         grid = self.query_one(f"#{self.grid_id()}")
 
         grid.remove_children()
inspect_ai/_display/textual/widgets/tasks.py
CHANGED
@@ -17,6 +17,11 @@ from inspect_ai._display.core.results import task_metric
 from inspect_ai._display.textual.widgets.clock import Clock
 from inspect_ai._display.textual.widgets.task_detail import TaskDetail
 from inspect_ai._display.textual.widgets.toggle import Toggle
+from inspect_ai._display.textual.widgets.vscode import conditional_vscode_link
+from inspect_ai._util.file import to_uri
+from inspect_ai._util.vscode import (
+    VSCodeCommand,
+)
 
 from ...core.display import (
     Progress,
@@ -151,7 +156,7 @@ class TaskProgressView(Widget):
     height: auto;
     width: 1fr;
     layout: grid;
-    grid-size:
+    grid-size: 9 2;
    grid-columns: auto auto auto auto 1fr auto auto auto;
    grid-rows: auto auto;
    grid-gutter: 0 1;
@@ -200,6 +205,15 @@ class TaskProgressView(Widget):
 
         self.sample_count_width: int = sample_count_width
         self.display_metrics = display_metrics
+        self.view_log_link = conditional_vscode_link(
+            "[View Log]",
+            VSCodeCommand(
+                command="inspect.openLogViewer",
+                args=[to_uri(task.profile.log_location)]
+                if task.profile.log_location
+                else [],
+            ),
+        )
 
     metrics: reactive[list[TaskDisplayMetric] | None] = reactive(None)
     metrics_width: reactive[int | None] = reactive(None)
@@ -222,6 +236,8 @@ class TaskProgressView(Widget):
         yield self.count_display
         yield self.metrics_display
         yield Clock()
+        yield self.view_log_link
+
         yield self.task_detail
 
     @on(Toggle.Toggled)
inspect_ai/_display/textual/widgets/vscode.py
ADDED
@@ -0,0 +1,48 @@
+from textual.widget import Widget
+from textual.widgets import Link, Static
+
+from inspect_ai._util.vscode import (
+    VSCodeCommand,
+    can_execute_vscode_command,
+    execute_vscode_commands,
+)
+
+
+def conditional_vscode_link(text: str, command: VSCodeCommand) -> Widget:
+    if can_execute_vscode_command(command.command):
+        vscode_link = VSCodeLink(text)
+        vscode_link.commands = [command]
+        return vscode_link
+    else:
+        return Static()
+
+
+class VSCodeLink(Link):
+    def __init__(
+        self,
+        text: str,
+        *,
+        url: str | None = None,
+        tooltip: str | None = None,
+        name: str | None = None,
+        id: str | None = None,
+        classes: str | None = None,
+        disabled: bool = False,
+    ) -> None:
+        super().__init__(
+            text,
+            url=url,
+            tooltip=tooltip,
+            name=name,
+            id=id,
+            classes=classes,
+            disabled=disabled,
+        )
+        self.commands: list[VSCodeCommand] = []
+
+    def on_click(self) -> None:
+        execute_vscode_commands(self.commands)
+
+    def action_open_link(self) -> None:
+        # Workaround to prevent the default action of opening the link in a browser
+        return None
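A usage sketch for the new widget, mirroring how tasks.py wires it up earlier in this diff (the command name comes from the diff; the URI is illustrative):

    from inspect_ai._display.textual.widgets.vscode import conditional_vscode_link
    from inspect_ai._util.vscode import VSCodeCommand

    # renders a clickable VSCodeLink when running where the VS Code command
    # can execute, and an empty Static placeholder otherwise
    link = conditional_vscode_link(
        "[View Log]",
        VSCodeCommand(
            command="inspect.openLogViewer",
            args=["file:///tmp/logs/task.eval"],  # hypothetical log URI
        ),
    )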
inspect_ai/_eval/eval.py
CHANGED
@@ -2,9 +2,11 @@ import logging
 import os
 import sys
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
+from inspect_ai.agent._agent import Agent, is_agent
+from inspect_ai.agent._as_solver import as_solver
 
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
@@ -71,7 +73,7 @@ def eval(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] |
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     trace: bool | None = None,
@@ -246,7 +248,7 @@ async def eval_async(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] |
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
@@ -353,13 +355,10 @@ async def eval_async(
 
     try:
         # intialise eval
-        model, approval
-            tasks=tasks,
+        model, approval = eval_init(
             model=model,
             model_base_url=model_base_url,
             model_args=model_args,
-            task_args=task_args,
-            sandbox=sandbox,
             approval=approval,
             max_subprocesses=max_subprocesses,
             log_level=log_level,
@@ -367,6 +366,11 @@ async def eval_async(
             **kwargs,
         )
 
+        # resolve tasks
+        resolved_tasks = eval_resolve_tasks(
+            tasks, task_args, model, GenerateConfig(**kwargs), sandbox
+        )
+
         # warn and return empty string if we resolved no tasks
         if len(resolved_tasks) == 0:
             log.warning("No inspect tasks were found at the specified paths.")
@@ -412,7 +416,12 @@ async def eval_async(
         )
 
         # resolve solver
-
+        if isinstance(solver, list):
+            solver = chain(solver)
+        elif is_agent(solver):
+            solver = as_solver(solver)
+        else:
+            solver = cast(Solver | SolverSpec | None, solver)
 
         # ensure consistency of limit and sample_id
         if sample_id is not None and limit is not None:
@@ -724,7 +733,7 @@ async def eval_retry_async(
     # context to reconstruct ephemeral Task instances)
     task: str | None
     task_id = eval_log.eval.task_id
-    task_name = eval_log.eval.task
+    task_name = eval_log.eval.task_registry_name or eval_log.eval.task
     task_file = eval_log.eval.task_file
     if task_file:
         if not Path(task_file).exists():
@@ -846,24 +855,20 @@ async def eval_retry_async(
 
 
 def eval_init(
-    tasks: Tasks,
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
-    task_args: dict[str, Any] | str = dict(),
-    sandbox: SandboxEnvironmentType | None = None,
     approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
     max_subprocesses: int | None = None,
     log_level: str | None = None,
     log_level_transcript: str | None = None,
     **kwargs: Unpack[GenerateConfigArgs],
-) -> tuple[list[Model], list[ApprovalPolicy] | None
+) -> tuple[list[Model], list[ApprovalPolicy] | None]:
     # init eval context
     init_eval_context(log_level, log_level_transcript, max_subprocesses)
 
     # resolve model and task args
     model_args = resolve_args(model_args)
-    task_args = resolve_args(task_args)
 
     # resolve model args from environment if not specified
     if len(model_args) == 0:
@@ -876,21 +881,28 @@ def eval_init(
     generate_config = GenerateConfig(**kwargs)
     models = resolve_models(model, model_base_url, model_args, generate_config)
 
-    # resolve tasks (set active model to resolve uses of the
-    # 'default' model in tools, solvers, and scorers)
-
-    with task_display().suspend_task_app():
-        resolved_tasks: list[ResolvedTask] = []
-        for m in models:
-            init_active_model(m, generate_config)
-            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox))
-
     # resolve approval
     if isinstance(approval, str | ApprovalPolicyConfig):
         approval = approval_policies_from_config(approval)
     init_tool_approval(approval)
 
-    return models, approval
+    return models, approval
+
+
+def eval_resolve_tasks(
+    tasks: Tasks,
+    task_args: dict[str, Any] | str,
+    models: list[Model],
+    config: GenerateConfig,
+    sandbox: SandboxEnvironmentType | None,
+) -> list[ResolvedTask]:
+    task_args = resolve_args(task_args)
+    with task_display().suspend_task_app():
+        resolved_tasks: list[ResolvedTask] = []
+        for m in models:
+            init_active_model(m, config)
+            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox))
+    return resolved_tasks
 
 
 def init_eval_display(
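Since solver now also accepts an Agent, the is_agent()/as_solver() branch makes agents directly usable with eval(). A hedged sketch (the task, prompt, and model are illustrative; react comes from the new agent/_react.py module):

    from inspect_ai import Task, eval
    from inspect_ai.agent import react
    from inspect_ai.dataset import Sample

    task = Task(dataset=[Sample(input="What is 2 + 2?", target="4")])

    # previously this required wrapping the agent yourself; eval() now
    # routes an Agent through as_solver() internally
    eval(task, solver=react(prompt="You are a careful assistant."), model="openai/gpt-4o")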
inspect_ai/_eval/evalset.py
CHANGED
@@ -1,6 +1,5 @@
 import hashlib
 import logging
-from copy import deepcopy
 from typing import Any, Literal, NamedTuple, Set, cast
 
 import rich
@@ -18,6 +17,7 @@ from typing_extensions import Unpack
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.file import basename, filesystem
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
+from inspect_ai.agent._agent import Agent
 from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log import EvalLog
 from inspect_ai.log._bundle import bundle_log_dir
@@ -37,7 +37,7 @@ from inspect_ai.solver._solver import Solver, SolverSpec
 from inspect_ai.util import DisplayType, SandboxEnvironmentType
 from inspect_ai.util._display import display_type_initialized, init_display_type
 
-from .eval import eval, eval_init
+from .eval import eval, eval_init, eval_resolve_tasks
 from .loader import resolve_task_args
 from .task import Epochs
 from .task.resolved import ResolvedTask
@@ -66,7 +66,7 @@ def eval_set(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] |
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     trace: bool | None = None,
@@ -247,29 +247,21 @@ def eval_set(
     if display == "conversation":
         raise RuntimeError("eval_set cannot be used with conversation display.")
 
-    #
-    models, _
-        tasks=tasks,
+    # initialize eval
+    models, _ = eval_init(
         model=model,
         model_base_url=model_base_url,
         model_args=model_args,
-        task_args=task_args,
-        sandbox=sandbox,
         max_subprocesses=max_subprocesses,
         log_level=log_level,
         log_level_transcript=log_level_transcript,
         **kwargs,
     )
 
-    # ensure log_dir
+    # ensure log_dir
     fs = filesystem(log_dir)
     fs.mkdir(log_dir, exist_ok=True)
 
-    # validate that:
-    # (1) All tasks have a unique identifier
-    # (2) All logs have identifiers that map to tasks
-    validate_eval_set_prerequisites(resolved_tasks, list_all_eval_logs(log_dir))
-
     # resolve some parameters
     retry_connections = retry_connections or 0.5
     retry_cleanup = retry_cleanup is not False
@@ -310,11 +302,21 @@ def eval_set(
     # - tasks with a successful log (they'll just be returned)
     # - tasks with failed logs (they'll be retried)
     def try_eval() -> list[EvalLog]:
+        # resolve tasks
+        resolved_tasks = eval_resolve_tasks(
+            tasks, task_args, models, GenerateConfig(**kwargs), sandbox
+        )
+
         # list all logs currently in the log directory (update manifest if there are some)
         all_logs = list_all_eval_logs(log_dir)
         if len(all_logs) > 0:
             write_log_dir_manifest(log_dir)
 
+        # validate that:
+        # (1) All tasks have a unique identifier
+        # (2) All logs have identifiers that map to tasks
+        validate_eval_set_prerequisites(resolved_tasks, all_logs)
+
         # see which tasks are yet to run (to complete successfully we need
         # a successful eval for every [task_file/]task_name/model combination)
         # for those that haven't run, schedule them into models => tasks groups
@@ -419,13 +421,10 @@ def as_previous_tasks(
             # want to bring this back but we'd need to resolve the
             # directory issues.
 
-            # deepcopy so the same instance is not run twice
-            prev_task = deepcopy(task.task)
-
             previous_tasks.append(
                 PreviousTask(
                     id=log.header.eval.task_id,
-                    task=
+                    task=task.task,
                     task_args=resolve_task_args(task.task),
                     model=task.model,
                     log=read_eval_log(log.info),
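The net effect of these changes: task resolution now happens inside try_eval(), so every retry attempt re-resolves fresh Task instances instead of deep-copying them. An illustrative call (paths are hypothetical):

    from inspect_ai import eval_set

    # each retry attempt re-resolves the tasks before scheduling work
    success, logs = eval_set(
        tasks="tasks/security.py",
        log_dir="logs/security-run-1",
        retry_attempts=3,
    )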
inspect_ai/_eval/loader.py
CHANGED
@@ -26,6 +26,8 @@ from inspect_ai._util.registry import (
     registry_lookup,
     registry_params,
 )
+from inspect_ai.agent._agent import Agent
+from inspect_ai.agent._as_solver import as_solver
 from inspect_ai.model import Model
 from inspect_ai.scorer._scorer import Scorer, ScorerSpec, scorer_create
 from inspect_ai.solver._bridge import bridge
@@ -421,20 +423,32 @@ def solver_from_spec(spec: SolverSpec) -> Solver:
     if solver_file is None:
         if solver_name is None:
             raise ValueError(f"Unable to resolve solver name from {spec.solver}")
-
+        elif registry_lookup("solver", solver_name) is not None:
+            return cast(Solver, registry_create("solver", solver_name, **spec.args))
+        elif registry_lookup("agent", solver_name) is not None:
+            agent = cast(Agent, registry_create("agent", solver_name, **spec.args))
+            return as_solver(agent)
+        else:
+            raise ValueError(
+                f"Unkonwn solver {solver_name} (not registered as a @solver or @agent)"
+            )
 
     # we do have a solver file
     else:
         # load the module and parse decorators
         solver_module = load_module(solver_file)
-
+        solver_decorators = parse_decorators(solver_file, "solver")
+        agent_decorators = parse_decorators(solver_file, "agent")
 
         # if there is no solver_name see if we can discover it
         if solver_name is None:
-            if len(
+            if len(solver_decorators) == 1:
                 # decorator based solver
-                solver_name =
-            elif len(
+                solver_name = solver_decorators[0][0]
+            elif len(agent_decorators) == 1:
+                # decorator based agent
+                solver_name = agent_decorators[0][0]
+            elif len(solver_decorators) == 0 and len(agent_decorators) == 0:
                 # see if we can find an agent based solver
                 functions = [
                     function
@@ -454,26 +468,35 @@ def solver_from_spec(spec: SolverSpec) -> Solver:
 
             elif len(agent_functions) == 0:
                 raise PrerequisiteError(
-                    f"The source file {pretty_solver_file} does not contain any @solver
+                    f"The source file {pretty_solver_file} does not contain any @solver, @agent or bridged agent functions."
                 )
             else:
                 raise PrerequisiteError(
-                    f"The source file {pretty_solver_file} has more than one agent function (qualify which agent using e.g. '{solver_file.name}@agent_fn')"
+                    f"The source file {pretty_solver_file} has more than one bridged agent function (qualify which agent using e.g. '{solver_file.name}@agent_fn')"
                 )
-
+        elif len(solver_decorators) > 1:
             raise PrerequisiteError(
                 f"The source file {pretty_solver_file} has more than one @solver function (qualify which solver using e.g. '{solver_file.name}y@solver_fn')"
             )
+        else:
+            raise PrerequisiteError(
+                f"The source file {pretty_solver_file} has more than one @agent function (qualify which agent using e.g. '{solver_file.name}y@agent_fn')"
+            )
 
         # create decorator based solvers using the registry
-        if any(solver[0] == solver_name for solver in
+        if any(solver[0] == solver_name for solver in solver_decorators):
             return cast(Solver, registry_create("solver", solver_name, **spec.args))
 
-        # create
+        # create decorator based agents using the registry
+        elif any(agent[0] == solver_name for agent in agent_decorators):
+            agent = cast(Agent, registry_create("agent", solver_name, **spec.args))
+            return as_solver(agent)
+
+        # create bridge based solvers by calling the function and wrapping it in bridge()
         else:
            agent_fn = getattr(solver_module, solver_name, None)
            if inspect.isfunction(agent_fn):
-               return bridge
+               return bridge(agent_fn(**spec.args))
            elif agent_fn is not None:
                raise PrerequisiteError(
                    f"The object {solver_name} in file {pretty_solver_file} is not a Python function."
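With these loader changes, a registered @agent resolves anywhere a solver spec is accepted (the registry fallback above), and a source file whose only decorated function is an @agent can be named directly. A hedged sketch of the registry path, using hypothetical names and args:

    from inspect_ai._eval.loader import solver_from_spec
    from inspect_ai.solver import SolverSpec

    # looks up "my_agent" in the solver registry first, then the agent
    # registry, wrapping a found agent via as_solver()
    solver = solver_from_spec(SolverSpec("my_agent", {"attempts": 2}))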
inspect_ai/_eval/run.py
CHANGED
@@ -1,4 +1,3 @@
-import functools
 import logging
 import os
 import sys
@@ -20,7 +19,6 @@ from inspect_ai._display.core.active import (
     init_task_screen,
 )
 from inspect_ai._display.core.display import TaskSpec
-from inspect_ai._util._async import tg_collect
 from inspect_ai._util.error import PrerequisiteError, exception_message
 from inspect_ai._util.path import chdir
 from inspect_ai._util.registry import registry_unqualified_name
@@ -195,6 +193,7 @@ async def eval_run(
             task_name=task.name,
             task_version=task.version,
             task_file=resolved_task.task_file,
+            task_registry_name=resolved_task.task.registry_name,
             task_id=resolved_task.id if resolved_task.id else uuid(),
             run_id=run_id,
             solver=eval_solver_spec,
@@ -359,17 +358,13 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalL
                 "Run Task",
                 f"task: {task_options.task.name} ({task_options.model})",
             ):
-
-
-
-
-
-
-
-                else:
-                    result = tg_results[0]
-                    results.append(result)
+                async with anyio.create_task_group() as tg:
+
+                    async def run_task() -> None:
+                        result = await task_run(task_options)
+                        results.append(result)
+
+                    tg.start_soon(run_task)
 
         except Exception as ex:
             # errors generally don't escape from tasks (the exception being if an error
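run_multiple replaces the functools/tg_collect machinery with a plain task group; the async with block does not exit until every task started in it completes. A self-contained sketch of that pattern (names are illustrative):

    import anyio

    async def run_all(items: list[str]) -> list[str]:
        results: list[str] = []

        async def run_one(item: str) -> None:
            results.append(item.upper())  # stand-in for task_run(options)

        async with anyio.create_task_group() as tg:
            for item in items:
                tg.start_soon(run_one, item)
        # all run_one() calls have completed by this point
        return results

    print(anyio.run(run_all, ["a", "b"]))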
inspect_ai/_eval/score.py
CHANGED
@@ -7,8 +7,8 @@ import anyio
 
 from inspect_ai._display import display
 from inspect_ai._eval.loader import scorer_from_spec
-from inspect_ai._util._async import tg_collect
-from inspect_ai._util.platform import platform_init
+from inspect_ai._util._async import configured_async_backend, run_coroutine, tg_collect
+from inspect_ai._util.platform import platform_init, running_in_notebook
 from inspect_ai._util.registry import registry_create, registry_unqualified_name
 from inspect_ai.log import (
     EvalLog,
@@ -56,7 +56,17 @@ def score(
     # resolve scorers into a list
     scorers = [scorers] if isinstance(scorers, Scorer) else scorers
 
-
+    if running_in_notebook():
+        return run_coroutine(score_async(log, scorers, epochs_reducer, action))
+    else:
+        return anyio.run(
+            score_async,
+            log,
+            scorers,
+            epochs_reducer,
+            action,
+            backend=configured_async_backend(),
+        )
 
 
 async def score_async(
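The effect of this change: score() now detects a running notebook event loop and uses run_coroutine() instead of anyio.run(), which raises when a loop is already running. Illustrative usage (the log path is hypothetical):

    from inspect_ai import score
    from inspect_ai.log import read_eval_log
    from inspect_ai.scorer import match

    log = read_eval_log("logs/2025-04-01T12-00-00-task.eval")
    scored = score(log, scorers=match())  # now also works inside Jupyter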
inspect_ai/_eval/task/generate.py
CHANGED
@@ -1,12 +1,8 @@
 from typing import Literal
 
-from inspect_ai.model import (
-    CachePolicy,
-    GenerateConfig,
-    Model,
-    call_tools,
-)
+from inspect_ai.model import CachePolicy, GenerateConfig, Model
 from inspect_ai.model._cache import epoch
+from inspect_ai.model._call_tools import execute_tools
 from inspect_ai.solver import TaskState
 from inspect_ai.solver._limit import SampleLimitExceededError
 from inspect_ai.tool import ToolFunction
@@ -48,10 +44,13 @@ async def task_generate(
 
         # resolve tool calls if necessary
         if tool_calls != "none" and message.tool_calls:
-            # call tools and
-
-
+            # call tools and update messages and output
+            messages, output = await execute_tools(
+                state.messages, state.tools, config.max_tool_output
             )
+            state.messages.extend(messages)
+            if output is not None:
+                state.output = output
 
             # check for completed or only executing a single tool call
             if state.completed or tool_calls == "single":
inspect_ai/_eval/task/log.py
CHANGED
@@ -57,6 +57,7 @@ class TaskLogger:
         task_name: str,
         task_version: int,
         task_file: str | None,
+        task_registry_name: str | None,
         task_id: str | None,
         run_id: str,
         solver: SolverSpec | None,
@@ -131,6 +132,7 @@ class TaskLogger:
             task_id=task_id if task_id else uuid(),
             task_version=task_version,
             task_file=task_file,
+            task_registry_name=task_registry_name,
             task_attribs=task_attribs,
             task_args=task_args,
             solver=solver.solver if solver else None,
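The new task_registry_name field, threaded from eval_run() above into the log header, preserves a task's registry identity even when its display name is overridden, which is what eval_retry's task_registry_name-or-task fallback relies on. A hedged illustration (the task body and name are placeholders):

    from inspect_ai import Task, task, task_with
    from inspect_ai.dataset import Sample

    @task
    def my_task() -> Task:
        return Task(dataset=[Sample(input="hello")])

    # the display name changes, but the log can still record the registry
    # name "my_task", letting eval_retry() re-instantiate the original task
    variant = task_with(my_task(), name="my_task_variant")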
|