inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161) hide show
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +7 -3
  3. inspect_ai/_cli/eval.py +17 -2
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +4 -3
  6. inspect_ai/_display/core/config.py +3 -3
  7. inspect_ai/_display/core/panel.py +7 -3
  8. inspect_ai/_display/plain/__init__.py +0 -0
  9. inspect_ai/_display/plain/display.py +203 -0
  10. inspect_ai/_display/rich/display.py +4 -9
  11. inspect_ai/_display/textual/app.py +4 -1
  12. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  13. inspect_ai/_display/textual/widgets/samples.py +119 -16
  14. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  15. inspect_ai/_eval/eval.py +32 -20
  16. inspect_ai/_eval/evalset.py +7 -5
  17. inspect_ai/_eval/score.py +1 -0
  18. inspect_ai/_eval/task/__init__.py +2 -2
  19. inspect_ai/_eval/task/images.py +40 -25
  20. inspect_ai/_eval/task/results.py +50 -22
  21. inspect_ai/_eval/task/run.py +180 -124
  22. inspect_ai/_eval/task/sandbox.py +10 -5
  23. inspect_ai/_eval/task/task.py +140 -25
  24. inspect_ai/_util/constants.py +2 -0
  25. inspect_ai/_util/content.py +23 -1
  26. inspect_ai/_util/images.py +20 -17
  27. inspect_ai/_util/kvstore.py +73 -0
  28. inspect_ai/_util/notgiven.py +18 -0
  29. inspect_ai/_util/port_names.py +61 -0
  30. inspect_ai/_util/text.py +23 -0
  31. inspect_ai/_util/thread.py +5 -0
  32. inspect_ai/_view/www/App.css +31 -1
  33. inspect_ai/_view/www/dist/assets/index.css +31 -1
  34. inspect_ai/_view/www/dist/assets/index.js +25375 -1846
  35. inspect_ai/_view/www/log-schema.json +129 -15
  36. inspect_ai/_view/www/package.json +2 -0
  37. inspect_ai/_view/www/src/App.mjs +8 -10
  38. inspect_ai/_view/www/src/Types.mjs +0 -1
  39. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  40. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  41. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  42. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  43. inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
  44. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  45. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  46. inspect_ai/_view/www/src/index.js +75 -2
  47. inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
  48. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
  49. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  50. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  51. inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
  52. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  53. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
  54. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
  55. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  56. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +62 -27
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/Json.mjs +12 -6
  76. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
  77. inspect_ai/_view/www/vite.config.js +7 -0
  78. inspect_ai/_view/www/yarn.lock +116 -0
  79. inspect_ai/approval/_human/__init__.py +0 -0
  80. inspect_ai/approval/_human/util.py +2 -2
  81. inspect_ai/approval/_policy.py +12 -6
  82. inspect_ai/dataset/_sources/csv.py +2 -1
  83. inspect_ai/dataset/_sources/json.py +2 -1
  84. inspect_ai/dataset/_sources/util.py +15 -7
  85. inspect_ai/log/_condense.py +11 -1
  86. inspect_ai/log/_log.py +3 -6
  87. inspect_ai/log/_recorders/eval.py +19 -8
  88. inspect_ai/log/_samples.py +26 -5
  89. inspect_ai/log/_transcript.py +32 -2
  90. inspect_ai/model/__init__.py +10 -2
  91. inspect_ai/model/_call_tools.py +59 -12
  92. inspect_ai/model/_chat_message.py +2 -4
  93. inspect_ai/model/_conversation.py +61 -0
  94. inspect_ai/model/_generate_config.py +10 -4
  95. inspect_ai/model/_model.py +117 -18
  96. inspect_ai/model/_model_output.py +7 -2
  97. inspect_ai/model/_providers/anthropic.py +109 -51
  98. inspect_ai/model/_providers/azureai.py +26 -24
  99. inspect_ai/model/_providers/bedrock.py +43 -44
  100. inspect_ai/model/_providers/google.py +121 -58
  101. inspect_ai/model/_providers/groq.py +7 -5
  102. inspect_ai/model/_providers/hf.py +11 -6
  103. inspect_ai/model/_providers/mistral.py +17 -20
  104. inspect_ai/model/_providers/openai.py +32 -21
  105. inspect_ai/model/_providers/openai_o1.py +9 -8
  106. inspect_ai/model/_providers/providers.py +1 -1
  107. inspect_ai/model/_providers/together.py +8 -8
  108. inspect_ai/model/_providers/vertex.py +18 -8
  109. inspect_ai/scorer/__init__.py +13 -2
  110. inspect_ai/scorer/_metrics/__init__.py +2 -2
  111. inspect_ai/scorer/_metrics/std.py +3 -3
  112. inspect_ai/scorer/_reducer/reducer.py +1 -1
  113. inspect_ai/scorer/_scorer.py +2 -2
  114. inspect_ai/solver/__init__.py +2 -5
  115. inspect_ai/solver/_prompt.py +35 -5
  116. inspect_ai/solver/_task_state.py +80 -38
  117. inspect_ai/tool/__init__.py +11 -1
  118. inspect_ai/tool/_tool.py +21 -3
  119. inspect_ai/tool/_tool_call.py +10 -0
  120. inspect_ai/tool/_tool_def.py +16 -5
  121. inspect_ai/tool/_tool_with.py +21 -4
  122. inspect_ai/tool/beta/__init__.py +5 -0
  123. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  124. inspect_ai/tool/beta/_computer/_common.py +133 -0
  125. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  126. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  127. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  128. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  129. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  130. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  131. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  134. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  135. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  136. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  137. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  138. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  139. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  144. inspect_ai/util/__init__.py +2 -3
  145. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  146. inspect_ai/util/_display.py +14 -4
  147. inspect_ai/util/_limit.py +26 -0
  148. inspect_ai/util/_sandbox/context.py +12 -13
  149. inspect_ai/util/_sandbox/docker/compose.py +24 -11
  150. inspect_ai/util/_sandbox/docker/docker.py +84 -14
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/environment.py +27 -1
  153. inspect_ai/util/_sandbox/local.py +1 -0
  154. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
  155. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
  156. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  157. inspect_ai/model/_trace.py +0 -48
  158. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
  159. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
  160. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
  161. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
@@ -23,9 +23,15 @@ from vertexai.generative_models import ( # type: ignore
23
23
  )
24
24
  from vertexai.generative_models import Content as VertexContent
25
25
 
26
- from inspect_ai._util.constants import BASE_64_DATA_REMOVED
27
- from inspect_ai._util.content import Content, ContentText
28
- from inspect_ai._util.images import image_as_data
26
+ from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
27
+ from inspect_ai._util.content import (
28
+ Content,
29
+ ContentAudio,
30
+ ContentImage,
31
+ ContentText,
32
+ ContentVideo,
33
+ )
34
+ from inspect_ai._util.images import file_as_data
29
35
  from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
30
36
 
31
37
  from .._chat_message import (
@@ -244,9 +250,6 @@ def consective_tool_message_reducer(
244
250
  return messages
245
251
 
246
252
 
247
- NO_CONTENT = "(no content)"
248
-
249
-
250
253
  async def content_dict(
251
254
  message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
252
255
  ) -> VertexContent:
@@ -308,9 +311,16 @@ async def content_part(content: Content | str) -> Part:
308
311
  return Part.from_text(content or NO_CONTENT)
309
312
  elif isinstance(content, ContentText):
310
313
  return Part.from_text(content.text or NO_CONTENT)
311
- else:
312
- image_bytes, mime_type = await image_as_data(content.image)
314
+ elif isinstance(content, ContentImage):
315
+ image_bytes, mime_type = await file_as_data(content.image)
313
316
  return Part.from_image(image=Image.from_bytes(data=image_bytes))
317
+ else:
318
+ if isinstance(content, ContentAudio):
319
+ file = content.audio
320
+ elif isinstance(content, ContentVideo):
321
+ file = content.video
322
+ file_bytes, mime_type = await file_as_data(file)
323
+ return Part.from_data(file_bytes, mime_type)
314
324
 
315
325
 
316
326
  def prepend_system_messages(
@@ -1,3 +1,5 @@
1
+ from inspect_ai._util.deprecation import relocated_module_attribute
2
+
1
3
  from ._answer import AnswerPattern, answer
2
4
  from ._choice import choice
3
5
  from ._classification import exact, f1
@@ -16,7 +18,7 @@ from ._metric import (
16
18
  )
17
19
  from ._metrics.accuracy import accuracy
18
20
  from ._metrics.mean import mean
19
- from ._metrics.std import bootstrap_std, std, stderr
21
+ from ._metrics.std import bootstrap_stderr, std, stderr
20
22
  from ._model import model_graded_fact, model_graded_qa
21
23
  from ._multi import multi_scorer
22
24
  from ._pattern import pattern
@@ -50,7 +52,7 @@ __all__ = [
50
52
  "Target",
51
53
  "scorer",
52
54
  "accuracy",
53
- "bootstrap_std",
55
+ "bootstrap_stderr",
54
56
  "std",
55
57
  "stderr",
56
58
  "mean",
@@ -76,3 +78,12 @@ __all__ = [
76
78
  "at_least",
77
79
  "pass_at",
78
80
  ]
81
+ _BOOTSTRAP_RENAME_VERSION = "0.3.58"
82
+ _REMOVED_IN = "0.4"
83
+
84
+ relocated_module_attribute(
85
+ "bootstrap_std",
86
+ "inspect_ai.scorer.bootstrap_stderr",
87
+ _BOOTSTRAP_RENAME_VERSION,
88
+ _REMOVED_IN,
89
+ )
@@ -1,12 +1,12 @@
1
1
  from .accuracy import accuracy
2
2
  from .mean import mean, var
3
- from .std import bootstrap_std, std, stderr
3
+ from .std import bootstrap_stderr, std, stderr
4
4
 
5
5
  __all__ = [
6
6
  "accuracy",
7
7
  "mean",
8
8
  "var",
9
- "bootstrap_std",
9
+ "bootstrap_stderr",
10
10
  "std",
11
11
  "stderr",
12
12
  ]
@@ -15,10 +15,10 @@ logger = getLogger(__name__)
15
15
 
16
16
 
17
17
  @metric
18
- def bootstrap_std(
18
+ def bootstrap_stderr(
19
19
  num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
20
20
  ) -> Metric:
21
- """Standard deviation of a bootstrapped estimate of the mean.
21
+ """Standard error of the mean using bootstrap.
22
22
 
23
23
  Args:
24
24
  num_samples (int): Number of bootstrap samples to take.
@@ -31,7 +31,7 @@ def bootstrap_std(
31
31
  0 if the Value is a complex object (list or dict).
32
32
 
33
33
  Returns:
34
- bootstrap_std metric
34
+ bootstrap_stderr metric
35
35
  """
36
36
 
37
37
  def metric(scores: list[Score]) -> float:
@@ -111,7 +111,7 @@ def pass_at(
111
111
  if total - correct < k:
112
112
  return 1.0
113
113
  else:
114
- return 1.0 - cast(
114
+ return 1.0 - cast( # type: ignore[redundant-cast]
115
115
  float,
116
116
  np.prod(1.0 - k / np.arange(total - correct + 1, total + 1)).item(),
117
117
  )
@@ -151,8 +151,8 @@ def scorer_metrics(
151
151
  return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
152
152
 
153
153
 
154
- def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
155
- base_name = registry_unqualified_name(scorer)
154
+ def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
155
+ base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
156
156
  scorer_name = base_name
157
157
  count = 1
158
158
  while scorer_name in already_used_names:
@@ -7,11 +7,7 @@ from ._fork import fork
7
7
  from ._human_agent.agent import human_agent
8
8
  from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
9
9
  from ._plan import Plan, plan
10
- from ._prompt import (
11
- chain_of_thought,
12
- prompt_template,
13
- system_message,
14
- )
10
+ from ._prompt import chain_of_thought, prompt_template, system_message, user_message
15
11
  from ._solver import Generate, Solver, SolverSpec, generate, solver
16
12
  from ._task_state import Choice, Choices, TaskState
17
13
  from ._use_tools import use_tools
@@ -26,6 +22,7 @@ __all__ = [
26
22
  "chain_of_thought",
27
23
  "multiple_choice",
28
24
  "system_message",
25
+ "user_message",
29
26
  "self_critique",
30
27
  "use_tools",
31
28
  "plan",
@@ -2,6 +2,7 @@ from typing import Any
2
2
 
3
3
  from inspect_ai._util.dict import omit
4
4
  from inspect_ai.model import ChatMessageSystem
5
+ from inspect_ai.model._chat_message import ChatMessageUser
5
6
  from inspect_ai.util import resource
6
7
 
7
8
  from ._solver import Generate, Solver, solver
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
15
16
 
16
17
  Prompt template containing a `{prompt}` placeholder and any
17
18
  number of additional `params`. All values contained in sample
18
- `metadata` are also automatically included in the `params`.
19
+ `metadata` and `store` are also automatically included in the
20
+ `params`.
19
21
 
20
22
  Args:
21
23
  template: (str): Template for prompt.
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
29
31
 
30
32
  async def solve(state: TaskState, generate: Generate) -> TaskState:
31
33
  prompt = state.user_prompt
32
- kwargs = omit(state.metadata, ["prompt"]) | params
34
+ kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
33
35
  prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
34
36
  return state
35
37
 
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
41
43
  """Solver which inserts a system message into the conversation.
42
44
 
43
45
  System message template containing any number of optional `params`.
44
- for substitution. All values contained in sample `metadata` are also
45
- automatically included in the `params`.
46
+ for substitution using the `str.format()` method. All values
47
+ contained in sample `metadata` and `store` are also automatically
48
+ included in the `params`.
46
49
 
47
50
  The new message will go after other system messages (if there
48
51
  are none it will be inserted at the beginning of the conversation).
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
58
61
  content = resource(template)
59
62
 
60
63
  async def solve(state: TaskState, generate: Generate) -> TaskState:
61
- kwargs = state.metadata | params
64
+ kwargs = state.metadata | state.store._data | params
62
65
  append_system_message(
63
66
  state.messages, ChatMessageSystem(content=content.format(**kwargs))
64
67
  )
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
67
70
  return solve
68
71
 
69
72
 
73
+ @solver
74
+ def user_message(template: str, **params: Any) -> Solver:
75
+ """Solver which inserts a user message into the conversation.
76
+
77
+ User message template containing any number of optional `params`.
78
+ for substitution using the `str.format()` method. All values
79
+ contained in sample `metadata` and `store` are also automatically
80
+ included in the `params`.
81
+
82
+ Args:
83
+ template (str): Template for user message.
84
+ **params (dict[str,Any]): Parameters to fill into the template.
85
+
86
+ Returns:
87
+ A solver that inserts the parameterised user message.
88
+ """
89
+ # read template
90
+ content = resource(template)
91
+
92
+ async def solve(state: TaskState, generate: Generate) -> TaskState:
93
+ kwargs = state.metadata | state.store._data | params
94
+ state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
95
+ return state
96
+
97
+ return solve
98
+
99
+
70
100
  DEFAULT_COT_TEMPLATE = r"""
71
101
  {prompt}
72
102
 
@@ -2,8 +2,9 @@ from collections.abc import Sequence
2
2
  from contextvars import ContextVar
3
3
  from copy import deepcopy
4
4
  from dataclasses import dataclass
5
+ from itertools import tee
5
6
  from random import Random
6
- from typing import Any, Type, Union, cast, overload
7
+ from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
7
8
 
8
9
  from pydantic_core import to_jsonable_python
9
10
 
@@ -15,9 +16,13 @@ from inspect_ai.model import (
15
16
  ModelOutput,
16
17
  )
17
18
  from inspect_ai.model._call_tools import tools_info
19
+ from inspect_ai.model._chat_message import ChatMessageBase
18
20
  from inspect_ai.model._model import sample_total_tokens
21
+ from inspect_ai.scorer._metric import Score
22
+ from inspect_ai.scorer._target import Target
19
23
  from inspect_ai.tool import Tool, ToolChoice
20
24
  from inspect_ai.tool._tool_def import ToolDef
25
+ from inspect_ai.util._limit import SampleLimitExceededError
21
26
  from inspect_ai.util._store import Store, store_jsonable
22
27
  from inspect_ai.util._store_model import SMT
23
28
 
@@ -136,6 +141,7 @@ class TaskState:
136
141
  epoch: int,
137
142
  input: str | list[ChatMessage],
138
143
  messages: list[ChatMessage],
144
+ target: Target = Target(""),
139
145
  choices: list[str] | None = [],
140
146
  output: ModelOutput | None = None,
141
147
  message_limit: int | None = None,
@@ -161,10 +167,13 @@ class TaskState:
161
167
  or `input_text` only
162
168
  """
163
169
 
170
+ self.target = target
171
+ """The scoring target for this `Sample`."""
172
+
164
173
  self.metadata = metadata
165
174
  """Metadata from the `Sample` for this `TaskState`"""
166
175
 
167
- self.messages = messages
176
+ self._messages: list[ChatMessage] = ChatMessageList(messages)
168
177
  """
169
178
  Chat conversation history for sample.
170
179
 
@@ -189,9 +198,7 @@ class TaskState:
189
198
  """
190
199
 
191
200
  self._message_limit = message_limit
192
- self._message_limit_exceeded = False
193
201
  self._token_limit = token_limit
194
- self._token_limit_exceeded = False
195
202
  self._completed = completed
196
203
 
197
204
  """Store for shared data"""
@@ -202,6 +209,9 @@ class TaskState:
202
209
  else:
203
210
  self.choices = Choices([])
204
211
 
212
+ self.scores: dict[str, Score] | None = None
213
+ """Scores yielded by running task."""
214
+
205
215
  @property
206
216
  def model(self) -> ModelName:
207
217
  """Name of model being evaluated."""
@@ -254,6 +264,16 @@ class TaskState:
254
264
  else:
255
265
  raise ValueError("user_prompt requested from TaskState but none available")
256
266
 
267
+ @property
268
+ def messages(self) -> list[ChatMessage]:
269
+ """Messages in chat history"""
270
+ return self._messages
271
+
272
+ @messages.setter
273
+ def messages(self, messages: list[ChatMessage]) -> None:
274
+ """Set messages in chat history."""
275
+ self._messages = ChatMessageList(messages)
276
+
257
277
  @property
258
278
  def max_messages(self) -> int | None:
259
279
  """Deprecated (use message_limit)."""
@@ -300,40 +320,7 @@ class TaskState:
300
320
  @property
301
321
  def completed(self) -> bool:
302
322
  """Is the task completed."""
303
- # update messages
304
- from inspect_ai.log._samples import set_active_sample_total_messages
305
- from inspect_ai.log._transcript import SampleLimitEvent, transcript
306
-
307
- set_active_sample_total_messages(len(self.messages))
308
-
309
- if self._completed:
310
- return True
311
- elif self.message_limit and len(self.messages) >= self.message_limit:
312
- # log if this is the first time we hit this
313
- if not self._message_limit_exceeded:
314
- self._message_limit_exceeded = True
315
- transcript()._event(
316
- SampleLimitEvent(
317
- type="message",
318
- message=f"Sample completed: exceeded message limit ({self.message_limit})",
319
- limit=self.message_limit,
320
- )
321
- )
322
- return True
323
- elif self.token_limit and self.token_usage >= self.token_limit:
324
- # log if this is the first time we hit this
325
- if not self._token_limit_exceeded:
326
- self._token_limit_exceeded = True
327
- transcript()._event(
328
- SampleLimitEvent(
329
- type="token",
330
- message=f"Sample completed: exceeded token limit ({self.token_limit:,})",
331
- limit=self.token_limit,
332
- )
333
- )
334
- return True
335
- else:
336
- return False
323
+ return self._completed
337
324
 
338
325
  @completed.setter
339
326
  def completed(self, completed: bool) -> None:
@@ -413,3 +400,58 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
413
400
  def sample_jsonable(sample: Sample) -> dict[str, Any]:
414
401
  jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
415
402
  return cast(dict[str, Any], deepcopy(jsonable))
403
+
404
+
405
+ class ChatMessageList(list[ChatMessage]):
406
+ def __init__(self, iterable: Iterable[ChatMessage]):
407
+ items, length = self._iterable_length(iterable)
408
+ self._check_size(length)
409
+ super().__init__(items)
410
+
411
+ def _check_size(self, additional_items: int = 1) -> None:
412
+ from inspect_ai.log._samples import active_sample_message_limit
413
+
414
+ messages_limit = active_sample_message_limit()
415
+ if messages_limit is not None:
416
+ messages = len(self) + additional_items
417
+ if messages > messages_limit:
418
+ raise SampleLimitExceededError(
419
+ "message", value=messages, limit=messages_limit
420
+ )
421
+
422
+ def append(self, item: ChatMessage) -> None:
423
+ self._check_size()
424
+ super().append(item)
425
+
426
+ def extend(self, items: Iterable[ChatMessage]) -> None:
427
+ items, length = self._iterable_length(items)
428
+ self._check_size(length)
429
+ super().extend(items)
430
+
431
+ def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
432
+ self._check_size()
433
+ super().insert(index, item)
434
+
435
+ @overload
436
+ def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
437
+
438
+ @overload
439
+ def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
440
+
441
+ def __setitem__(
442
+ self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
443
+ ) -> None:
444
+ if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
445
+ item, length = self._iterable_length(item)
446
+ size_change = length - len(self[index])
447
+ if size_change > 0:
448
+ self._check_size(size_change)
449
+
450
+ super().__setitem__(index, item) # type: ignore[assignment,index]
451
+
452
+ def _iterable_length(
453
+ self, items: Iterable[ChatMessage]
454
+ ) -> tuple[Iterable[ChatMessage], int]:
455
+ items, counter = tee(items)
456
+ length = sum(1 for _ in counter)
457
+ return items, length
@@ -1,4 +1,10 @@
1
- from inspect_ai._util.content import Content, ContentImage, ContentText
1
+ from inspect_ai._util.content import (
2
+ Content,
3
+ ContentAudio,
4
+ ContentImage,
5
+ ContentText,
6
+ ContentVideo,
7
+ )
2
8
  from inspect_ai._util.deprecation import relocated_module_attribute
3
9
 
4
10
  from ._tool import Tool, ToolError, ToolResult, tool
@@ -6,6 +12,7 @@ from ._tool_call import (
6
12
  ToolCall,
7
13
  ToolCallContent,
8
14
  ToolCallError,
15
+ ToolCallModelInput,
9
16
  ToolCallView,
10
17
  ToolCallViewer,
11
18
  )
@@ -30,10 +37,13 @@ __all__ = [
30
37
  "ToolError",
31
38
  "ToolResult",
32
39
  "Content",
40
+ "ContentAudio",
33
41
  "ContentImage",
34
42
  "ContentText",
43
+ "ContentVideo",
35
44
  "ToolCall",
36
45
  "ToolCallContent",
46
+ "ToolCallModelInput",
37
47
  "ToolCallView",
38
48
  "ToolCallViewer",
39
49
  "ToolChoice",
inspect_ai/tool/_tool.py CHANGED
@@ -11,7 +11,12 @@ from typing import (
11
11
  runtime_checkable,
12
12
  )
13
13
 
14
- from inspect_ai._util.content import ContentImage, ContentText
14
+ from inspect_ai._util.content import (
15
+ ContentAudio,
16
+ ContentImage,
17
+ ContentText,
18
+ ContentVideo,
19
+ )
15
20
  from inspect_ai._util.registry import (
16
21
  RegistryInfo,
17
22
  registry_add,
@@ -19,7 +24,7 @@ from inspect_ai._util.registry import (
19
24
  registry_tag,
20
25
  )
21
26
 
22
- from ._tool_call import ToolCallViewer
27
+ from ._tool_call import ToolCallModelInput, ToolCallViewer
23
28
 
24
29
  logger = getLogger(__name__)
25
30
 
@@ -31,7 +36,9 @@ ToolResult = (
31
36
  | bool
32
37
  | ContentText
33
38
  | ContentImage
34
- | list[ContentText | ContentImage]
39
+ | ContentAudio
40
+ | ContentVideo
41
+ | list[ContentText | ContentImage | ContentAudio | ContentVideo]
35
42
  )
36
43
 
37
44
 
@@ -105,6 +112,7 @@ def tool(
105
112
  *,
106
113
  name: str | None = None,
107
114
  viewer: ToolCallViewer | None = None,
115
+ model_input: ToolCallModelInput | None = None,
108
116
  parallel: bool = True,
109
117
  prompt: str | None = None,
110
118
  ) -> Callable[[Callable[P, Tool]], Callable[P, Tool]]: ...
@@ -115,6 +123,7 @@ def tool(
115
123
  *,
116
124
  name: str | None = None,
117
125
  viewer: ToolCallViewer | None = None,
126
+ model_input: ToolCallModelInput | None = None,
118
127
  parallel: bool = True,
119
128
  prompt: str | None = None,
120
129
  ) -> Callable[P, Tool] | Callable[[Callable[P, Tool]], Callable[P, Tool]]:
@@ -128,6 +137,8 @@ def tool(
128
137
  will be used as the name of the tool.
129
138
  viewer (ToolCallViewer | None): Provide a custom view
130
139
  of tool call and context.
140
+ model_input (ToolCallModelInput | None): Provide a custom
141
+ function for playing back tool results as model input.
131
142
  parallel (bool):
132
143
  Does this tool support parallel execution?
133
144
  (defaults to True).
@@ -169,6 +180,9 @@ def tool(
169
180
  TOOL_PROMPT: prompt,
170
181
  TOOL_PARALLEL: parallel,
171
182
  TOOL_VIEWER: viewer,
183
+ TOOL_MODEL_INPUT: (
184
+ model_input or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
185
+ ),
172
186
  },
173
187
  ),
174
188
  *args,
@@ -188,3 +202,7 @@ def tool(
188
202
  TOOL_PROMPT = "prompt"
189
203
  TOOL_PARALLEL = "parallel"
190
204
  TOOL_VIEWER = "viewer"
205
+ TOOL_MODEL_INPUT = "model_input"
206
+
207
+
208
+ TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
@@ -3,6 +3,8 @@ from typing import Any, Callable, Literal
3
3
 
4
4
  from pydantic import BaseModel, Field
5
5
 
6
+ from inspect_ai._util.content import Content
7
+
6
8
 
7
9
  class ToolCallContent(BaseModel):
8
10
  """Content to include in tool call view."""
@@ -71,3 +73,11 @@ class ToolCallError:
71
73
 
72
74
  ToolCallViewer = Callable[[ToolCall], ToolCallView]
73
75
  """Custom view renderer for tool calls."""
76
+
77
+
78
+ ToolCallModelInput = Callable[[int, int, str | list[Content]], str | list[Content]]
79
+ """Determine how tool call results are played back as model input.
80
+
81
+ The first argument is an index into the total number of tool results
82
+ for this tool in the message history, the second is the total number.
83
+ """
@@ -13,8 +13,8 @@ from inspect_ai._util.registry import (
13
13
  set_registry_params,
14
14
  )
15
15
 
16
- from ._tool import TOOL_PARALLEL, TOOL_PROMPT, TOOL_VIEWER, Tool
17
- from ._tool_call import ToolCallViewer
16
+ from ._tool import TOOL_MODEL_INPUT, TOOL_PARALLEL, TOOL_PROMPT, TOOL_VIEWER, Tool
17
+ from ._tool_call import ToolCallModelInput, ToolCallViewer
18
18
  from ._tool_description import (
19
19
  ToolDescription,
20
20
  set_tool_description,
@@ -33,6 +33,7 @@ class ToolDef:
33
33
  parameters: dict[str, str] | ToolParams | None = None,
34
34
  parallel: bool | None = None,
35
35
  viewer: ToolCallViewer | None = None,
36
+ model_input: ToolCallModelInput | None = None,
36
37
  ) -> None:
37
38
  """Tool definition.
38
39
 
@@ -46,6 +47,8 @@ class ToolDef:
46
47
  parallel (bool | None): Does the tool support parallel execution
47
48
  (defaults to True if not specified)
48
49
  viewer (ToolCallViewer | None): Optional tool call viewer implementation.
50
+ model_input (ToolCallModelInput | None): Optional function that determines how
51
+ tool call results are played back as model input.
49
52
 
50
53
  Returns:
51
54
  Tool definition.
@@ -68,6 +71,7 @@ class ToolDef:
68
71
  parameters = parameters if parameters is not None else tdef.parameters
69
72
  self.parallel = parallel if parallel is not None else tdef.parallel
70
73
  self.viewer = viewer or tdef.viewer
74
+ self.model_input = model_input or tdef.model_input
71
75
 
72
76
  # if its not a tool then extract tool_info if all fields have not
73
77
  # been provided explicitly
@@ -97,6 +101,7 @@ class ToolDef:
97
101
  # behavioral attributes
98
102
  self.parallel = parallel is not False
99
103
  self.viewer = viewer
104
+ self.model_input = model_input
100
105
 
101
106
  tool: Callable[..., Any]
102
107
  """Callable to execute tool."""
@@ -116,6 +121,9 @@ class ToolDef:
116
121
  viewer: ToolCallViewer | None
117
122
  """Custom viewer for tool call"""
118
123
 
124
+ model_input: ToolCallModelInput | None
125
+ """Custom model input presenter for tool calls."""
126
+
119
127
  def as_tool(self) -> Tool:
120
128
  """Convert a ToolDef to a Tool."""
121
129
  tool = self.tool
@@ -159,11 +167,12 @@ class ToolDefFields(NamedTuple):
159
167
  parameters: ToolParams
160
168
  parallel: bool
161
169
  viewer: ToolCallViewer | None
170
+ model_input: ToolCallModelInput | None
162
171
 
163
172
 
164
173
  def tool_def_fields(tool: Tool) -> ToolDefFields:
165
174
  # get tool_info
166
- name, prompt, parallel, viewer = tool_registry_info(tool)
175
+ name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
167
176
  tool_info = parse_tool_info(tool)
168
177
 
169
178
  # if there is a description then append any prompt to the
@@ -213,15 +222,17 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
213
222
  parameters=tool_info.parameters,
214
223
  parallel=parallel,
215
224
  viewer=viewer,
225
+ model_input=model_input,
216
226
  )
217
227
 
218
228
 
219
229
  def tool_registry_info(
220
230
  tool: Tool,
221
- ) -> tuple[str, str | None, bool, ToolCallViewer | None]:
231
+ ) -> tuple[str, str | None, bool, ToolCallViewer | None, ToolCallModelInput | None]:
222
232
  info = registry_info(tool)
223
233
  name = info.name.split("/")[-1]
224
234
  prompt = info.metadata.get(TOOL_PROMPT, None)
225
235
  parallel = info.metadata.get(TOOL_PARALLEL, True)
226
236
  viewer = info.metadata.get(TOOL_VIEWER, None)
227
- return name, prompt, parallel, viewer
237
+ model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
238
+ return name, prompt, parallel, viewer, model_input