inspect-ai 0.3.82__py3-none-any.whl → 0.3.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_display/textual/app.py +14 -3
  3. inspect_ai/_display/textual/display.py +4 -0
  4. inspect_ai/_display/textual/widgets/samples.py +9 -3
  5. inspect_ai/_display/textual/widgets/task_detail.py +3 -4
  6. inspect_ai/_display/textual/widgets/tasks.py +17 -1
  7. inspect_ai/_display/textual/widgets/vscode.py +48 -0
  8. inspect_ai/_eval/eval.py +36 -24
  9. inspect_ai/_eval/evalset.py +17 -18
  10. inspect_ai/_eval/loader.py +34 -11
  11. inspect_ai/_eval/run.py +8 -13
  12. inspect_ai/_eval/score.py +13 -3
  13. inspect_ai/_eval/task/generate.py +8 -9
  14. inspect_ai/_eval/task/log.py +2 -0
  15. inspect_ai/_eval/task/task.py +23 -9
  16. inspect_ai/_util/file.py +13 -0
  17. inspect_ai/_util/json.py +2 -1
  18. inspect_ai/_util/registry.py +1 -0
  19. inspect_ai/_util/vscode.py +37 -0
  20. inspect_ai/_view/www/App.css +6 -0
  21. inspect_ai/_view/www/dist/assets/index.css +304 -128
  22. inspect_ai/_view/www/dist/assets/index.js +47495 -27519
  23. inspect_ai/_view/www/log-schema.json +124 -31
  24. inspect_ai/_view/www/package.json +3 -0
  25. inspect_ai/_view/www/src/App.tsx +12 -0
  26. inspect_ai/_view/www/src/appearance/icons.ts +1 -0
  27. inspect_ai/_view/www/src/components/Card.tsx +6 -4
  28. inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
  29. inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
  30. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
  31. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
  32. inspect_ai/_view/www/src/components/Modal.module.css +38 -0
  33. inspect_ai/_view/www/src/components/Modal.tsx +77 -0
  34. inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
  35. inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
  36. inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
  37. inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
  38. inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
  39. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
  40. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
  41. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
  42. inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
  43. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
  44. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
  45. inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
  46. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
  47. inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
  48. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
  49. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
  50. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
  51. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
  52. inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
  53. inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
  54. inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
  55. inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
  56. inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
  57. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
  58. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
  59. inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
  60. inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
  61. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  62. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
  63. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
  64. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
  65. inspect_ai/_view/www/src/state/hooks.ts +5 -3
  66. inspect_ai/_view/www/src/state/logPolling.ts +5 -1
  67. inspect_ai/_view/www/src/state/logSlice.ts +10 -0
  68. inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
  69. inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
  70. inspect_ai/_view/www/src/types/log.d.ts +34 -26
  71. inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
  72. inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
  73. inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
  74. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
  75. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
  76. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
  77. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
  78. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
  79. inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
  80. inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
  81. inspect_ai/_view/www/yarn.lock +94 -1
  82. inspect_ai/agent/__init__.py +36 -0
  83. inspect_ai/agent/_agent.py +268 -0
  84. inspect_ai/agent/_as_solver.py +72 -0
  85. inspect_ai/agent/_as_tool.py +122 -0
  86. inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
  87. inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
  88. inspect_ai/agent/_filter.py +46 -0
  89. inspect_ai/agent/_handoff.py +93 -0
  90. inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
  91. inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
  92. inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
  93. inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
  94. inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
  95. inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
  96. inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
  97. inspect_ai/agent/_react.py +241 -0
  98. inspect_ai/agent/_run.py +36 -0
  99. inspect_ai/agent/_types.py +81 -0
  100. inspect_ai/log/_log.py +11 -2
  101. inspect_ai/log/_transcript.py +13 -9
  102. inspect_ai/model/__init__.py +7 -1
  103. inspect_ai/model/_call_tools.py +256 -52
  104. inspect_ai/model/_chat_message.py +7 -4
  105. inspect_ai/model/_conversation.py +13 -62
  106. inspect_ai/model/_display.py +85 -0
  107. inspect_ai/model/_model.py +113 -14
  108. inspect_ai/model/_model_output.py +14 -9
  109. inspect_ai/model/_openai.py +16 -4
  110. inspect_ai/model/_openai_computer_use.py +162 -0
  111. inspect_ai/model/_openai_responses.py +319 -165
  112. inspect_ai/model/_providers/anthropic.py +20 -21
  113. inspect_ai/model/_providers/azureai.py +24 -13
  114. inspect_ai/model/_providers/bedrock.py +1 -7
  115. inspect_ai/model/_providers/cloudflare.py +3 -3
  116. inspect_ai/model/_providers/goodfire.py +2 -6
  117. inspect_ai/model/_providers/google.py +11 -10
  118. inspect_ai/model/_providers/groq.py +6 -3
  119. inspect_ai/model/_providers/hf.py +7 -3
  120. inspect_ai/model/_providers/mistral.py +7 -10
  121. inspect_ai/model/_providers/openai.py +47 -17
  122. inspect_ai/model/_providers/openai_o1.py +11 -4
  123. inspect_ai/model/_providers/openai_responses.py +12 -14
  124. inspect_ai/model/_providers/providers.py +2 -2
  125. inspect_ai/model/_providers/together.py +12 -2
  126. inspect_ai/model/_providers/util/chatapi.py +7 -2
  127. inspect_ai/model/_providers/util/hf_handler.py +4 -2
  128. inspect_ai/model/_providers/util/llama31.py +4 -2
  129. inspect_ai/model/_providers/vertex.py +11 -9
  130. inspect_ai/model/_providers/vllm.py +4 -4
  131. inspect_ai/scorer/__init__.py +2 -0
  132. inspect_ai/scorer/_metrics/__init__.py +2 -0
  133. inspect_ai/scorer/_metrics/grouped.py +84 -0
  134. inspect_ai/scorer/_score.py +26 -6
  135. inspect_ai/solver/__init__.py +2 -2
  136. inspect_ai/solver/_basic_agent.py +22 -9
  137. inspect_ai/solver/_bridge.py +31 -0
  138. inspect_ai/solver/_chain.py +20 -12
  139. inspect_ai/solver/_fork.py +5 -1
  140. inspect_ai/solver/_human_agent.py +52 -0
  141. inspect_ai/solver/_prompt.py +3 -1
  142. inspect_ai/solver/_run.py +59 -0
  143. inspect_ai/solver/_solver.py +14 -4
  144. inspect_ai/solver/_task_state.py +5 -3
  145. inspect_ai/tool/_tool_call.py +15 -8
  146. inspect_ai/tool/_tool_def.py +17 -12
  147. inspect_ai/tool/_tool_support_helpers.py +2 -2
  148. inspect_ai/tool/_tool_with.py +14 -11
  149. inspect_ai/tool/_tools/_bash_session.py +11 -2
  150. inspect_ai/tool/_tools/_computer/_common.py +18 -2
  151. inspect_ai/tool/_tools/_computer/_computer.py +18 -2
  152. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
  153. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
  154. inspect_ai/tool/_tools/_think.py +1 -1
  155. inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
  156. inspect_ai/util/__init__.py +2 -0
  157. inspect_ai/util/_anyio.py +27 -0
  158. inspect_ai/util/_sandbox/__init__.py +2 -1
  159. inspect_ai/util/_sandbox/context.py +32 -7
  160. inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
  161. inspect_ai/util/_sandbox/docker/compose.py +2 -2
  162. inspect_ai/util/_sandbox/docker/docker.py +12 -1
  163. inspect_ai/util/_store_model.py +30 -7
  164. inspect_ai/util/_subprocess.py +13 -3
  165. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/METADATA +1 -1
  166. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/RECORD +179 -153
  167. inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
  168. /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
  169. /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
  170. /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
  171. /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
  172. /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
  173. /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
  174. /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
  175. /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
  176. /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
  177. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/WHEEL +0 -0
  178. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/entry_points.txt +0 -0
  179. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/licenses/LICENSE +0 -0
  180. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/top_level.txt +0 -0
inspect_ai/__init__.py CHANGED
@@ -10,7 +10,8 @@ from inspect_ai._eval.score import score, score_async
 from inspect_ai._eval.task import Epochs, Task, TaskInfo, task_with
 from inspect_ai._eval.task.tasks import Tasks
 from inspect_ai._util.constants import PKG_NAME
-from inspect_ai.solver._human_agent.agent import human_agent
+from inspect_ai.agent._human.agent import human_cli
+from inspect_ai.solver._human_agent import human_agent
 
 __version__ = importlib_version(PKG_NAME)
 
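Both imports above are live in 0.3.84: the legacy solver-based human_agent stays available from the package root, alongside the new agent-based human_cli. A minimal sketch, assuming only what the import lines above establish:

    from inspect_ai import human_agent, human_cli

    # human_agent() is the retained @solver-based API; human_cli() is the
    # new implementation that lives in the inspect_ai.agent package.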
inspect_ai/_display/textual/app.py CHANGED
@@ -58,10 +58,12 @@ class TaskScreenResult(Generic[TR]):
         value: TR | BaseException,
         tasks: list[TaskWithResult],
         output: list[str],
+        warnings: list[str],
     ) -> None:
         self.value = value
         self.tasks = tasks
         self.output = output
+        self.warnings = warnings
 
 
 class TaskScreenApp(App[TR]):
@@ -86,6 +88,7 @@ class TaskScreenApp(App[TR]):
         self._worker: Worker[TR] | None = None
         self._error: BaseException | None = None
         self._output: list[str] = []
+        self._warnings: list[str] = []
 
         # task screen
         self._total_tasks = 0
@@ -120,7 +123,12 @@ class TaskScreenApp(App[TR]):
             value = CancelledError()
 
         # return result w/ output
-        return TaskScreenResult(value=value, tasks=self._app_tasks, output=self._output)
+        return TaskScreenResult(
+            value=value,
+            tasks=self._app_tasks,
+            output=self._output,
+            warnings=self._warnings,
+        )
 
     async def on_load(self) -> None:
         # events used to synchronise loading
@@ -349,8 +357,11 @@ class TaskScreenApp(App[TR]):
         if text.endswith("\n"):
             text = text[:-1]
 
-        # track output (for printing at the end)
-        self._output.append(text)
+        # track output and warnings (for printing at the end)
+        if "WARNING" in text:
+            self._warnings.append(text)
+        else:
+            self._output.append(text)
 
         # write to console view
         self.query_one(ConsoleView).write_ansi(text)
inspect_ai/_display/textual/display.py CHANGED
@@ -42,6 +42,10 @@ class TextualDisplay(Display):
         # print tasks
         rich.print(tasks_results(result.tasks))
 
+        # print warnings
+        if result.warnings:
+            print("\n".join(result.warnings))
+
         # raise error as required
         if isinstance(result.value, BaseException):
             raise result.value
inspect_ai/_display/textual/widgets/samples.py CHANGED
@@ -17,7 +17,7 @@ from textual.widgets import (
     OptionList,
     Static,
 )
-from textual.widgets.option_list import Option
+from textual.widgets.option_list import Option, OptionDoesNotExist
 
 from inspect_ai._display.textual.widgets.port_mappings import get_url
 from inspect_ai._util.format import format_progress_time
@@ -124,7 +124,7 @@ class SamplesList(OptionList):
     def set_samples(self, samples: list[ActiveSample]) -> None:
         # check for a highlighted sample (make sure we don't remove it)
         highlighted_id = (
-            self.get_option_at_index(self.highlighted).id
+            self.get_id_at_index(self.highlighted)
            if self.highlighted is not None
            else None
        )
@@ -179,12 +179,18 @@ class SamplesList(OptionList):
         self.scroll_to_highlight()
 
     def sample_for_highlighted(self, highlighted: int) -> ActiveSample | None:
-        highlighted_id = self.get_option_at_index(highlighted).id
+        highlighted_id = self.get_id_at_index(highlighted)
         if highlighted_id is not None:
             return sample_for_id(self.samples, highlighted_id)
         else:
             return None
 
+    def get_id_at_index(self, index: int) -> str | None:
+        try:
+            return self.get_option_at_index(index).id
+        except OptionDoesNotExist:
+            return None
+
 
 class SampleVNC(Horizontal):
     DEFAULT_CSS = """
inspect_ai/_display/textual/widgets/task_detail.py CHANGED
@@ -221,12 +221,11 @@ class TaskMetrics(Widget):
         self.recompute_grid()
 
     def on_mount(self) -> None:
-        self.recompute_grid()
+        self.recompute_grid(True)
 
-    def recompute_grid(self) -> None:
-        if not self.is_mounted:
+    def recompute_grid(self, force: bool = False) -> None:
+        if not self.is_mounted and not force:
             return
-
         grid = self.query_one(f"#{self.grid_id()}")
 
         grid.remove_children()
inspect_ai/_display/textual/widgets/tasks.py CHANGED
@@ -17,6 +17,11 @@ from inspect_ai._display.core.results import task_metric
 from inspect_ai._display.textual.widgets.clock import Clock
 from inspect_ai._display.textual.widgets.task_detail import TaskDetail
 from inspect_ai._display.textual.widgets.toggle import Toggle
+from inspect_ai._display.textual.widgets.vscode import conditional_vscode_link
+from inspect_ai._util.file import to_uri
+from inspect_ai._util.vscode import (
+    VSCodeCommand,
+)
 
 from ...core.display import (
     Progress,
@@ -151,7 +156,7 @@ class TaskProgressView(Widget):
         height: auto;
         width: 1fr;
         layout: grid;
-        grid-size: 8 2;
+        grid-size: 9 2;
         grid-columns: auto auto auto auto 1fr auto auto auto;
         grid-rows: auto auto;
         grid-gutter: 0 1;
@@ -200,6 +205,15 @@ class TaskProgressView(Widget):
 
         self.sample_count_width: int = sample_count_width
         self.display_metrics = display_metrics
+        self.view_log_link = conditional_vscode_link(
+            "[View Log]",
+            VSCodeCommand(
+                command="inspect.openLogViewer",
+                args=[to_uri(task.profile.log_location)]
+                if task.profile.log_location
+                else [],
+            ),
+        )
 
     metrics: reactive[list[TaskDisplayMetric] | None] = reactive(None)
     metrics_width: reactive[int | None] = reactive(None)
@@ -222,6 +236,8 @@ class TaskProgressView(Widget):
         yield self.count_display
         yield self.metrics_display
         yield Clock()
+        yield self.view_log_link
+
         yield self.task_detail
 
     @on(Toggle.Toggled)
inspect_ai/_display/textual/widgets/vscode.py ADDED
@@ -0,0 +1,48 @@
+from textual.widget import Widget
+from textual.widgets import Link, Static
+
+from inspect_ai._util.vscode import (
+    VSCodeCommand,
+    can_execute_vscode_command,
+    execute_vscode_commands,
+)
+
+
+def conditional_vscode_link(text: str, command: VSCodeCommand) -> Widget:
+    if can_execute_vscode_command(command.command):
+        vscode_link = VSCodeLink(text)
+        vscode_link.commands = [command]
+        return vscode_link
+    else:
+        return Static()
+
+
+class VSCodeLink(Link):
+    def __init__(
+        self,
+        text: str,
+        *,
+        url: str | None = None,
+        tooltip: str | None = None,
+        name: str | None = None,
+        id: str | None = None,
+        classes: str | None = None,
+        disabled: bool = False,
+    ) -> None:
+        super().__init__(
+            text,
+            url=url,
+            tooltip=tooltip,
+            name=name,
+            id=id,
+            classes=classes,
+            disabled=disabled,
+        )
+        self.commands: list[VSCodeCommand] = []
+
+    def on_click(self) -> None:
+        execute_vscode_commands(self.commands)
+
+    def action_open_link(self) -> None:
+        # Workaround to prevent the default action of opening the link in a browser
+        return None
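A brief usage sketch for the new widget, mirroring the call added in the tasks.py hunk above (hedged: the log URI argument here is hypothetical):

    from inspect_ai._display.textual.widgets.vscode import conditional_vscode_link
    from inspect_ai._util.vscode import VSCodeCommand

    # returns a clickable VSCodeLink when the command can be executed
    # (i.e. running under VS Code), otherwise an empty Static() placeholder
    link = conditional_vscode_link(
        "[View Log]",
        VSCodeCommand(command="inspect.openLogViewer", args=["file:///logs/task.eval"]),
    )

Overriding action_open_link() is what keeps Textual from also opening the link in a browser; the click is routed to execute_vscode_commands() instead.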
inspect_ai/_eval/eval.py CHANGED
@@ -2,9 +2,11 @@ import logging
 import os
 import sys
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal, cast
 
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
+from inspect_ai.agent._agent import Agent, is_agent
+from inspect_ai.agent._as_solver import as_solver
 
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup
@@ -71,7 +73,7 @@ def eval(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] | SolverSpec | None = None,
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     trace: bool | None = None,
@@ -246,7 +248,7 @@ async def eval_async(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] | SolverSpec | None = None,
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
@@ -353,13 +355,10 @@ async def eval_async(
 
     try:
         # intialise eval
-        model, approval, resolved_tasks = eval_init(
-            tasks=tasks,
+        model, approval = eval_init(
             model=model,
             model_base_url=model_base_url,
             model_args=model_args,
-            task_args=task_args,
-            sandbox=sandbox,
             approval=approval,
             max_subprocesses=max_subprocesses,
             log_level=log_level,
@@ -367,6 +366,11 @@ async def eval_async(
             **kwargs,
         )
 
+        # resolve tasks
+        resolved_tasks = eval_resolve_tasks(
+            tasks, task_args, model, GenerateConfig(**kwargs), sandbox
+        )
+
         # warn and return empty string if we resolved no tasks
         if len(resolved_tasks) == 0:
             log.warning("No inspect tasks were found at the specified paths.")
@@ -412,7 +416,12 @@ async def eval_async(
         )
 
         # resolve solver
-        solver = chain(solver) if isinstance(solver, list) else solver
+        if isinstance(solver, list):
+            solver = chain(solver)
+        elif is_agent(solver):
+            solver = as_solver(solver)
+        else:
+            solver = cast(Solver | SolverSpec | None, solver)
 
         # ensure consistency of limit and sample_id
         if sample_id is not None and limit is not None:
@@ -724,7 +733,7 @@ async def eval_retry_async(
     # context to reconstruct ephemeral Task instances)
     task: str | None
     task_id = eval_log.eval.task_id
-    task_name = eval_log.eval.task
+    task_name = eval_log.eval.task_registry_name or eval_log.eval.task
     task_file = eval_log.eval.task_file
     if task_file:
         if not Path(task_file).exists():
@@ -846,24 +855,20 @@ async def eval_retry_async(
 
 
 def eval_init(
-    tasks: Tasks,
     model: str | Model | list[str] | list[Model] | None | NotGiven = NOT_GIVEN,
     model_base_url: str | None = None,
     model_args: dict[str, Any] | str = dict(),
-    task_args: dict[str, Any] | str = dict(),
-    sandbox: SandboxEnvironmentType | None = None,
     approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None,
     max_subprocesses: int | None = None,
     log_level: str | None = None,
     log_level_transcript: str | None = None,
     **kwargs: Unpack[GenerateConfigArgs],
-) -> tuple[list[Model], list[ApprovalPolicy] | None, list[ResolvedTask]]:
+) -> tuple[list[Model], list[ApprovalPolicy] | None]:
     # init eval context
     init_eval_context(log_level, log_level_transcript, max_subprocesses)
 
     # resolve model and task args
     model_args = resolve_args(model_args)
-    task_args = resolve_args(task_args)
 
     # resolve model args from environment if not specified
     if len(model_args) == 0:
@@ -876,21 +881,28 @@ def eval_init(
     generate_config = GenerateConfig(**kwargs)
     models = resolve_models(model, model_base_url, model_args, generate_config)
 
-    # resolve tasks (set active model to resolve uses of the
-    # 'default' model in tools, solvers, and scorers)
-
-    with task_display().suspend_task_app():
-        resolved_tasks: list[ResolvedTask] = []
-        for m in models:
-            init_active_model(m, generate_config)
-            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox))
-
     # resolve approval
     if isinstance(approval, str | ApprovalPolicyConfig):
         approval = approval_policies_from_config(approval)
     init_tool_approval(approval)
 
-    return models, approval, resolved_tasks
+    return models, approval
+
+
+def eval_resolve_tasks(
+    tasks: Tasks,
+    task_args: dict[str, Any] | str,
+    models: list[Model],
+    config: GenerateConfig,
+    sandbox: SandboxEnvironmentType | None,
+) -> list[ResolvedTask]:
+    task_args = resolve_args(task_args)
+    with task_display().suspend_task_app():
+        resolved_tasks: list[ResolvedTask] = []
+        for m in models:
+            init_active_model(m, config)
+            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox))
+    return resolved_tasks
 
 
 def init_eval_display(
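Net effect of the eval.py changes: task resolution moves out of eval_init() into the separate eval_resolve_tasks() step, and eval()/eval_async() accept an Agent anywhere a Solver was previously allowed (wrapped via as_solver(), per the resolve-solver hunk). A minimal sketch of the new Agent-as-solver path, assuming the public inspect_ai.agent API introduced in this release; the agent and task bodies are hypothetical:

    from inspect_ai import Task, eval, task
    from inspect_ai.agent import Agent, AgentState, agent
    from inspect_ai.dataset import Sample

    @agent
    def noop() -> Agent:
        async def execute(state: AgentState) -> AgentState:
            # hypothetical agent: return the conversation unchanged
            return state

        return execute

    @task
    def demo() -> Task:
        return Task(dataset=[Sample(input="Say hello.")])

    # the Agent is converted with as_solver() internally
    logs = eval(demo(), solver=noop(), model="openai/gpt-4o")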
inspect_ai/_eval/evalset.py CHANGED
@@ -1,6 +1,5 @@
 import hashlib
 import logging
-from copy import deepcopy
 from typing import Any, Literal, NamedTuple, Set, cast
 
 import rich
@@ -18,6 +17,7 @@ from typing_extensions import Unpack
 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.file import basename, filesystem
 from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
+from inspect_ai.agent._agent import Agent
 from inspect_ai.approval._policy import ApprovalPolicy
 from inspect_ai.log import EvalLog
 from inspect_ai.log._bundle import bundle_log_dir
@@ -37,7 +37,7 @@ from inspect_ai.solver._solver import Solver, SolverSpec
 from inspect_ai.util import DisplayType, SandboxEnvironmentType
 from inspect_ai.util._display import display_type_initialized, init_display_type
 
-from .eval import eval, eval_init
+from .eval import eval, eval_init, eval_resolve_tasks
 from .loader import resolve_task_args
 from .task import Epochs
 from .task.resolved import ResolvedTask
@@ -66,7 +66,7 @@ def eval_set(
     task_args: dict[str, Any] | str = dict(),
     sandbox: SandboxEnvironmentType | None = None,
     sandbox_cleanup: bool | None = None,
-    solver: Solver | list[Solver] | SolverSpec | None = None,
+    solver: Solver | SolverSpec | Agent | list[Solver] | None = None,
     tags: list[str] | None = None,
     metadata: dict[str, Any] | None = None,
     trace: bool | None = None,
@@ -247,29 +247,21 @@ def eval_set(
     if display == "conversation":
         raise RuntimeError("eval_set cannot be used with conversation display.")
 
-    # resolve tasks
-    models, _, resolved_tasks = eval_init(
-        tasks=tasks,
+    # initialize eval
+    models, _ = eval_init(
         model=model,
         model_base_url=model_base_url,
         model_args=model_args,
-        task_args=task_args,
-        sandbox=sandbox,
         max_subprocesses=max_subprocesses,
         log_level=log_level,
         log_level_transcript=log_level_transcript,
         **kwargs,
     )
 
-    # ensure log_dir and list all logs
+    # ensure log_dir
     fs = filesystem(log_dir)
     fs.mkdir(log_dir, exist_ok=True)
 
-    # validate that:
-    # (1) All tasks have a unique identifier
-    # (2) All logs have identifiers that map to tasks
-    validate_eval_set_prerequisites(resolved_tasks, list_all_eval_logs(log_dir))
-
     # resolve some parameters
     retry_connections = retry_connections or 0.5
     retry_cleanup = retry_cleanup is not False
@@ -310,11 +302,21 @@ def eval_set(
     # - tasks with a successful log (they'll just be returned)
     # - tasks with failed logs (they'll be retried)
     def try_eval() -> list[EvalLog]:
+        # resolve tasks
+        resolved_tasks = eval_resolve_tasks(
+            tasks, task_args, models, GenerateConfig(**kwargs), sandbox
+        )
+
         # list all logs currently in the log directory (update manifest if there are some)
         all_logs = list_all_eval_logs(log_dir)
         if len(all_logs) > 0:
             write_log_dir_manifest(log_dir)
 
+        # validate that:
+        # (1) All tasks have a unique identifier
+        # (2) All logs have identifiers that map to tasks
+        validate_eval_set_prerequisites(resolved_tasks, all_logs)
+
         # see which tasks are yet to run (to complete successfully we need
         # a successful eval for every [task_file/]task_name/model combination)
         # for those that haven't run, schedule them into models => tasks groups
@@ -419,13 +421,10 @@ def as_previous_tasks(
         # want to bring this back but we'd need to resolve the
         # directory issues.
 
-        # deepcopy so the same instance is not run twice
-        prev_task = deepcopy(task.task)
-
         previous_tasks.append(
             PreviousTask(
                 id=log.header.eval.task_id,
-                task=prev_task,
+                task=task.task,
                 task_args=resolve_task_args(task.task),
                 model=task.model,
                 log=read_eval_log(log.info),
inspect_ai/_eval/loader.py CHANGED
@@ -26,6 +26,8 @@ from inspect_ai._util.registry import (
     registry_lookup,
     registry_params,
 )
+from inspect_ai.agent._agent import Agent
+from inspect_ai.agent._as_solver import as_solver
 from inspect_ai.model import Model
 from inspect_ai.scorer._scorer import Scorer, ScorerSpec, scorer_create
 from inspect_ai.solver._bridge import bridge
@@ -421,20 +423,32 @@ def solver_from_spec(spec: SolverSpec) -> Solver:
     if solver_file is None:
         if solver_name is None:
             raise ValueError(f"Unable to resolve solver name from {spec.solver}")
-        return cast(Solver, registry_create("solver", solver_name, **spec.args))
+        elif registry_lookup("solver", solver_name) is not None:
+            return cast(Solver, registry_create("solver", solver_name, **spec.args))
+        elif registry_lookup("agent", solver_name) is not None:
+            agent = cast(Agent, registry_create("agent", solver_name, **spec.args))
+            return as_solver(agent)
+        else:
+            raise ValueError(
+                f"Unkonwn solver {solver_name} (not registered as a @solver or @agent)"
+            )
 
     # we do have a solver file
     else:
         # load the module and parse decorators
         solver_module = load_module(solver_file)
-        decorators = parse_decorators(solver_file, "solver")
+        solver_decorators = parse_decorators(solver_file, "solver")
+        agent_decorators = parse_decorators(solver_file, "agent")
 
         # if there is no solver_name see if we can discover it
         if solver_name is None:
-            if len(decorators) == 1:
+            if len(solver_decorators) == 1:
                 # decorator based solver
-                solver_name = decorators[0][0]
-            elif len(decorators) == 0:
+                solver_name = solver_decorators[0][0]
+            elif len(agent_decorators) == 1:
+                # decorator based agent
+                solver_name = agent_decorators[0][0]
+            elif len(solver_decorators) == 0 and len(agent_decorators) == 0:
                 # see if we can find an agent based solver
                 functions = [
                     function
@@ -454,26 +468,35 @@ def solver_from_spec(spec: SolverSpec) -> Solver:
 
             elif len(agent_functions) == 0:
                 raise PrerequisiteError(
-                    f"The source file {pretty_solver_file} does not contain any @solver functions or agent functions."
+                    f"The source file {pretty_solver_file} does not contain any @solver, @agent or bridged agent functions."
                 )
             else:
                 raise PrerequisiteError(
-                    f"The source file {pretty_solver_file} has more than one agent function (qualify which agent using e.g. '{solver_file.name}@agent_fn')"
+                    f"The source file {pretty_solver_file} has more than one bridged agent function (qualify which agent using e.g. '{solver_file.name}@agent_fn')"
                )
-        else:
+        elif len(solver_decorators) > 1:
             raise PrerequisiteError(
                 f"The source file {pretty_solver_file} has more than one @solver function (qualify which solver using e.g. '{solver_file.name}y@solver_fn')"
             )
+        else:
+            raise PrerequisiteError(
+                f"The source file {pretty_solver_file} has more than one @agent function (qualify which agent using e.g. '{solver_file.name}y@agent_fn')"
+            )
 
         # create decorator based solvers using the registry
-        if any(solver[0] == solver_name for solver in decorators):
+        if any(solver[0] == solver_name for solver in solver_decorators):
             return cast(Solver, registry_create("solver", solver_name, **spec.args))
 
-        # create agent based solvers by calling the function and wrapping it in bridge()
+        # create decorator based agents using the registry
+        elif any(agent[0] == solver_name for agent in agent_decorators):
+            agent = cast(Agent, registry_create("agent", solver_name, **spec.args))
+            return as_solver(agent)
+
+        # create bridge based solvers by calling the function and wrapping it in bridge()
        else:
            agent_fn = getattr(solver_module, solver_name, None)
            if inspect.isfunction(agent_fn):
-                return bridge.bridge(agent_fn(**spec.args))
+                return bridge(agent_fn(**spec.args))
            elif agent_fn is not None:
                raise PrerequisiteError(
                    f"The object {solver_name} in file {pretty_solver_file} is not a Python function."
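solver_from_spec() now resolves a named solver against both the solver and agent registries, and file-based specs can point at @agent functions as well as @solver functions (bare functions still go through bridge()). A hedged sketch of referencing an @agent by qualified name; the file and function names are hypothetical:

    from inspect_ai import eval
    from inspect_ai.solver._solver import SolverSpec

    # "agents.py@my_agent" names an @agent; the loader wraps it with as_solver()
    logs = eval("demo.py", solver=SolverSpec("agents.py@my_agent"))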
inspect_ai/_eval/run.py CHANGED
@@ -1,4 +1,3 @@
-import functools
 import logging
 import os
 import sys
@@ -20,7 +19,6 @@ from inspect_ai._display.core.active import (
     init_task_screen,
 )
 from inspect_ai._display.core.display import TaskSpec
-from inspect_ai._util._async import tg_collect
 from inspect_ai._util.error import PrerequisiteError, exception_message
 from inspect_ai._util.path import chdir
 from inspect_ai._util.registry import registry_unqualified_name
@@ -195,6 +193,7 @@ async def eval_run(
         task_name=task.name,
         task_version=task.version,
         task_file=resolved_task.task_file,
+        task_registry_name=resolved_task.task.registry_name,
         task_id=resolved_task.id if resolved_task.id else uuid(),
         run_id=run_id,
         solver=eval_solver_spec,
@@ -359,17 +358,13 @@ async def run_multiple(tasks: list[TaskRunOptions], parallel: int) -> list[EvalLog]:
                 "Run Task",
                 f"task: {task_options.task.name} ({task_options.model})",
             ):
-                tg_results = await tg_collect(
-                    [functools.partial(task_run, task_options)]
-                )
-                # check for empty results list (indicates cancellation)
-                if len(tg_results) == 0:
-                    # task was cancelled, break out of the worker loop
-                    result = None
-
-                else:
-                    result = tg_results[0]
-                    results.append(result)
+                async with anyio.create_task_group() as tg:
+
+                    async def run_task() -> None:
+                        result = await task_run(task_options)
+                        results.append(result)
+
+                    tg.start_soon(run_task)
 
         except Exception as ex:
             # errors generally don't escape from tasks (the exception being if an error
inspect_ai/_eval/score.py CHANGED
@@ -7,8 +7,8 @@ import anyio
 
 from inspect_ai._display import display
 from inspect_ai._eval.loader import scorer_from_spec
-from inspect_ai._util._async import tg_collect
-from inspect_ai._util.platform import platform_init
+from inspect_ai._util._async import configured_async_backend, run_coroutine, tg_collect
+from inspect_ai._util.platform import platform_init, running_in_notebook
 from inspect_ai._util.registry import registry_create, registry_unqualified_name
 from inspect_ai.log import (
     EvalLog,
@@ -56,7 +56,17 @@ def score(
     # resolve scorers into a list
     scorers = [scorers] if isinstance(scorers, Scorer) else scorers
 
-    return anyio.run(score_async, log, scorers, epochs_reducer, action)
+    if running_in_notebook():
+        return run_coroutine(score_async(log, scorers, epochs_reducer, action))
+    else:
+        return anyio.run(
+            score_async,
+            log,
+            scorers,
+            epochs_reducer,
+            action,
+            backend=configured_async_backend(),
+        )
 
 
 async def score_async(
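score() is now notebook-safe: running_in_notebook() routes through run_coroutine() on the live event loop, while scripts continue through anyio.run() pinned to the configured backend. A minimal sketch (the log path is hypothetical):

    from inspect_ai import score
    from inspect_ai.log import read_eval_log
    from inspect_ai.scorer import match

    log = read_eval_log("logs/demo.eval")
    scored = score(log, match())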
inspect_ai/_eval/task/generate.py CHANGED
@@ -1,12 +1,8 @@
 from typing import Literal
 
-from inspect_ai.model import (
-    CachePolicy,
-    GenerateConfig,
-    Model,
-    call_tools,
-)
+from inspect_ai.model import CachePolicy, GenerateConfig, Model
 from inspect_ai.model._cache import epoch
+from inspect_ai.model._call_tools import execute_tools
 from inspect_ai.solver import TaskState
 from inspect_ai.solver._limit import SampleLimitExceededError
 from inspect_ai.tool import ToolFunction
@@ -48,10 +44,13 @@ async def task_generate(
 
     # resolve tool calls if necessary
     if tool_calls != "none" and message.tool_calls:
-        # call tools and append messages to state
-        state.messages.extend(
-            await call_tools(message, state.tools, config.max_tool_output)
+        # call tools and update messages and output
+        messages, output = await execute_tools(
+            state.messages, state.tools, config.max_tool_output
         )
+        state.messages.extend(messages)
+        if output is not None:
+            state.output = output
 
         # check for completed or only executing a single tool call
         if state.completed or tool_calls == "single":
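execute_tools() supersedes call_tools() here: it takes the full message history rather than a single assistant message, and returns the new messages together with an optional model output produced during tool execution (plausibly to support the new agent handoff flow; that rationale is an inference, not stated in the diff). A migration sketch based only on the two call sites above:

    # before (0.3.82)
    state.messages.extend(
        await call_tools(message, state.tools, config.max_tool_output)
    )

    # after (0.3.84)
    messages, output = await execute_tools(
        state.messages, state.tools, config.max_tool_output
    )
    state.messages.extend(messages)
    if output is not None:
        state.output = output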
inspect_ai/_eval/task/log.py CHANGED
@@ -57,6 +57,7 @@ class TaskLogger:
         task_name: str,
         task_version: int,
         task_file: str | None,
+        task_registry_name: str | None,
         task_id: str | None,
         run_id: str,
         solver: SolverSpec | None,
@@ -131,6 +132,7 @@ class TaskLogger:
             task_id=task_id if task_id else uuid(),
             task_version=task_version,
             task_file=task_file,
+            task_registry_name=task_registry_name,
             task_attribs=task_attribs,
             task_args=task_args,
             solver=solver.solver if solver else None,