inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (161)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +7 -3
  3. inspect_ai/_cli/eval.py +17 -2
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +4 -3
  6. inspect_ai/_display/core/config.py +3 -3
  7. inspect_ai/_display/core/panel.py +7 -3
  8. inspect_ai/_display/plain/__init__.py +0 -0
  9. inspect_ai/_display/plain/display.py +203 -0
  10. inspect_ai/_display/rich/display.py +4 -9
  11. inspect_ai/_display/textual/app.py +4 -1
  12. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  13. inspect_ai/_display/textual/widgets/samples.py +119 -16
  14. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  15. inspect_ai/_eval/eval.py +32 -20
  16. inspect_ai/_eval/evalset.py +7 -5
  17. inspect_ai/_eval/score.py +1 -0
  18. inspect_ai/_eval/task/__init__.py +2 -2
  19. inspect_ai/_eval/task/images.py +40 -25
  20. inspect_ai/_eval/task/results.py +50 -22
  21. inspect_ai/_eval/task/run.py +180 -124
  22. inspect_ai/_eval/task/sandbox.py +10 -5
  23. inspect_ai/_eval/task/task.py +140 -25
  24. inspect_ai/_util/constants.py +2 -0
  25. inspect_ai/_util/content.py +23 -1
  26. inspect_ai/_util/images.py +20 -17
  27. inspect_ai/_util/kvstore.py +73 -0
  28. inspect_ai/_util/notgiven.py +18 -0
  29. inspect_ai/_util/port_names.py +61 -0
  30. inspect_ai/_util/text.py +23 -0
  31. inspect_ai/_util/thread.py +5 -0
  32. inspect_ai/_view/www/App.css +31 -1
  33. inspect_ai/_view/www/dist/assets/index.css +31 -1
  34. inspect_ai/_view/www/dist/assets/index.js +25375 -1846
  35. inspect_ai/_view/www/log-schema.json +129 -15
  36. inspect_ai/_view/www/package.json +2 -0
  37. inspect_ai/_view/www/src/App.mjs +8 -10
  38. inspect_ai/_view/www/src/Types.mjs +0 -1
  39. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  40. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  41. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  42. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  43. inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
  44. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  45. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  46. inspect_ai/_view/www/src/index.js +75 -2
  47. inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
  48. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
  49. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  50. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  51. inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
  52. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  53. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
  54. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
  55. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  56. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +62 -27
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/Json.mjs +12 -6
  76. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
  77. inspect_ai/_view/www/vite.config.js +7 -0
  78. inspect_ai/_view/www/yarn.lock +116 -0
  79. inspect_ai/approval/_human/__init__.py +0 -0
  80. inspect_ai/approval/_human/util.py +2 -2
  81. inspect_ai/approval/_policy.py +12 -6
  82. inspect_ai/dataset/_sources/csv.py +2 -1
  83. inspect_ai/dataset/_sources/json.py +2 -1
  84. inspect_ai/dataset/_sources/util.py +15 -7
  85. inspect_ai/log/_condense.py +11 -1
  86. inspect_ai/log/_log.py +3 -6
  87. inspect_ai/log/_recorders/eval.py +19 -8
  88. inspect_ai/log/_samples.py +26 -5
  89. inspect_ai/log/_transcript.py +32 -2
  90. inspect_ai/model/__init__.py +10 -2
  91. inspect_ai/model/_call_tools.py +59 -12
  92. inspect_ai/model/_chat_message.py +2 -4
  93. inspect_ai/model/_conversation.py +61 -0
  94. inspect_ai/model/_generate_config.py +10 -4
  95. inspect_ai/model/_model.py +117 -18
  96. inspect_ai/model/_model_output.py +7 -2
  97. inspect_ai/model/_providers/anthropic.py +109 -51
  98. inspect_ai/model/_providers/azureai.py +26 -24
  99. inspect_ai/model/_providers/bedrock.py +43 -44
  100. inspect_ai/model/_providers/google.py +121 -58
  101. inspect_ai/model/_providers/groq.py +7 -5
  102. inspect_ai/model/_providers/hf.py +11 -6
  103. inspect_ai/model/_providers/mistral.py +17 -20
  104. inspect_ai/model/_providers/openai.py +32 -21
  105. inspect_ai/model/_providers/openai_o1.py +9 -8
  106. inspect_ai/model/_providers/providers.py +1 -1
  107. inspect_ai/model/_providers/together.py +8 -8
  108. inspect_ai/model/_providers/vertex.py +18 -8
  109. inspect_ai/scorer/__init__.py +13 -2
  110. inspect_ai/scorer/_metrics/__init__.py +2 -2
  111. inspect_ai/scorer/_metrics/std.py +3 -3
  112. inspect_ai/scorer/_reducer/reducer.py +1 -1
  113. inspect_ai/scorer/_scorer.py +2 -2
  114. inspect_ai/solver/__init__.py +2 -5
  115. inspect_ai/solver/_prompt.py +35 -5
  116. inspect_ai/solver/_task_state.py +80 -38
  117. inspect_ai/tool/__init__.py +11 -1
  118. inspect_ai/tool/_tool.py +21 -3
  119. inspect_ai/tool/_tool_call.py +10 -0
  120. inspect_ai/tool/_tool_def.py +16 -5
  121. inspect_ai/tool/_tool_with.py +21 -4
  122. inspect_ai/tool/beta/__init__.py +5 -0
  123. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  124. inspect_ai/tool/beta/_computer/_common.py +133 -0
  125. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  126. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  127. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  128. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  129. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  130. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  131. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  134. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  135. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  136. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  137. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  138. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  139. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  144. inspect_ai/util/__init__.py +2 -3
  145. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  146. inspect_ai/util/_display.py +14 -4
  147. inspect_ai/util/_limit.py +26 -0
  148. inspect_ai/util/_sandbox/context.py +12 -13
  149. inspect_ai/util/_sandbox/docker/compose.py +24 -11
  150. inspect_ai/util/_sandbox/docker/docker.py +84 -14
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/environment.py +27 -1
  153. inspect_ai/util/_sandbox/local.py +1 -0
  154. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
  155. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
  156. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  157. inspect_ai/model/_trace.py +0 -48
  158. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
  159. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
  160. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
  161. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/log/_transcript.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
@@ -11,7 +12,7 @@ from typing import (
     Union,
 )
 
-from pydantic import BaseModel, Field, JsonValue, field_serializer
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
 
 from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai._util.error import EvalError
@@ -69,7 +70,7 @@ class SampleLimitEvent(BaseEvent):
     event: Literal["sample_limit"] = Field(default="sample_limit")
     """Event type."""
 
-    type: Literal["message", "time", "token", "operator"]
+    type: Literal["message", "time", "token", "operator", "custom"]
     """Type of limit that halted processing"""
 
     message: str
@@ -123,6 +124,9 @@ class ModelEvent(BaseEvent):
     output: ModelOutput
     """Output from model."""
 
+    error: str | None = Field(default=None)
+    """Error which occurred during model call."""
+
     cache: Literal["read", "write"] | None = Field(default=None)
     """Was this a cache read or write."""
 
@@ -176,6 +180,32 @@ class ToolEvent(BaseEvent):
         self.events = events
         self.pending = None
 
+    # mechanism for operator to cancel the tool call
+
+    def set_task(self, task: asyncio.Task[Any]) -> None:
+        """Set the tool task (for possible cancellation)"""
+        self._task = task
+
+    def cancel(self) -> None:
+        """Cancel the tool task."""
+        if self._task:
+            self._cancelled = True
+            self._task.cancel()
+
+    @property
+    def cancelled(self) -> bool:
+        """Was the task cancelled?"""
+        return self._cancelled is True
+
+    _cancelled: bool | None = None
+    """Was this tool call cancelled?"""
+
+    _task: asyncio.Task[Any] | None = None
+    """Handle to task (used for cancellation)"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    """Required so that we can include '_task' as a member."""
+
 
 class ApprovalEvent(BaseEvent):
     """Tool approval."""
inspect_ai/model/__init__.py
@@ -1,6 +1,12 @@
 # ruff: noqa: F401 F403 F405
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._cache import (
@@ -42,8 +48,10 @@ __all__ = [
     "GenerateConfig",
     "GenerateConfigArgs",
     "CachePolicy",
-    "ContentText",
+    "ContentAudio",
     "ContentImage",
+    "ContentText",
+    "ContentVideo",
     "Content",
     "ChatMessage",
     "ChatMessageSystem",
inspect_ai/model/_call_tools.py
@@ -24,11 +24,17 @@ from typing import (
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
 from inspect_ai._util.trace import trace_action
-from inspect_ai.model._trace import trace_tool_mesage
+from inspect_ai.model._conversation import conversation_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
@@ -120,10 +126,14 @@ async def call_tools(
         # massage result, leave list[Content] alone, convert all other
         # types to string as that is what the model APIs accept
         truncated: tuple[int, int] | None = None
-        if isinstance(result, ContentText | ContentImage):
+        if isinstance(
+            result, ContentText | ContentImage | ContentAudio | ContentVideo
+        ):
             content: str | list[Content] = [result]
         elif isinstance(result, list) and (
-            isinstance(result[0], ContentText | ContentImage)
+            isinstance(
+                result[0], ContentText | ContentImage | ContentAudio | ContentVideo
+            )
         ):
             content = result
         else:
@@ -163,6 +173,9 @@
         # call tools
         tool_messages: list[ChatMessageTool] = []
         for call in message.tool_calls:
+            # create the task
+            task = asyncio.create_task(call_tool_task(call))
+
             # create pending tool event and add it to the transcript
             event = ToolEvent(
                 id=call.id,
@@ -171,15 +184,44 @@
                 view=call.view,
                 pending=True,
             )
+            event.set_task(task)
             transcript()._event(event)
 
-            # execute the tool call
-            task = asyncio.create_task(call_tool_task(call))
-            tool_message, result_event = await task
+            # execute the tool call. if the operator cancelled the
+            # tool call then synthesize the appropriate message/event
+            try:
+                tool_message, result_event = await task
+            except asyncio.CancelledError:
+                if event.cancelled:
+                    tool_message = ChatMessageTool(
+                        content="",
+                        function=call.function,
+                        tool_call_id=call.id,
+                        error=ToolCallError(
+                            "timeout", "Command timed out before completing."
+                        ),
+                    )
+                    result_event = ToolEvent(
+                        id=call.id,
+                        function=call.function,
+                        arguments=call.arguments,
+                        result=tool_message.content,
+                        truncated=None,
+                        view=call.view,
+                        error=tool_message.error,
+                        events=[],
+                    )
+                    transcript().info(
+                        f"Tool call '{call.function}' was cancelled by operator."
+                    )
+                else:
+                    raise
+
+            # update return messages
             tool_messages.append(tool_message)
 
-            # trace if we are tracing
-            trace_tool_mesage(tool_message)
+            # print conversation if display is conversation
+            conversation_tool_mesage(tool_message)
 
             # update the event with the results
             event.set_result(
@@ -286,6 +328,10 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, An
     type_hints = get_type_hints(func)
     docstring = inspect.getdoc(func)
 
+    # if the function takes **kwargs: Any then just pass the tool arguments through
+    if "kwargs" in type_hints and type_hints["kwargs"] == Any:
+        return input
+
     # build params
     params: dict[str, Any] = {}
     for param_name, param in signature.parameters.items():
@@ -411,12 +457,13 @@
     # truncate if required
     truncated = truncate_string_to_bytes(output, active_max_output)
     if truncated:
-        truncated_output = dedent(f"""
+        truncated_output = dedent("""
            The output of your call to {tool_name} was too long to be displayed.
            Here is a truncated version:
            <START_TOOL_OUTPUT>
-           {truncated.output}
-           <END_TOOL_OUTPUT>""")
+           {truncated_output}
+           <END_TOOL_OUTPUT>
+           """).format(tool_name=tool_name, truncated_output=truncated.output)
        return TruncatedToolOutput(
            truncated_output, truncated.original_bytes, active_max_output
        )
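
Note: tool_params() now passes arguments through untouched when a tool's execute function is typed **kwargs: Any (per the condition in the hunk above). A minimal sketch of a tool taking advantage of this (tool name and body are illustrative):

    from typing import Any

    from inspect_ai.tool import tool

    @tool
    def passthrough():
        async def execute(**kwargs: Any) -> str:
            """Echo the raw tool arguments (illustrative)."""
            return str(kwargs)

        return execute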
inspect_ai/model/_chat_message.py
@@ -59,10 +59,8 @@ class ChatMessageBase(BaseModel):
         if isinstance(self.content, str):
             self.content = text
         else:
-            all_images = [
-                content for content in self.content if content.type == "image"
-            ]
-            self.content = [ContentText(text=text)] + all_images
+            all_other = [content for content in self.content if content.type != "text"]
+            self.content = [ContentText(text=text)] + all_other
 
 
 class ChatMessageSystem(ChatMessageBase):
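
Note: this generalizes the text setter so all non-text content (now including audio and video) survives, where previously only images did. A hedged sketch of the behavior (ContentAudio field names are inferred, not shown in this diff):

    from inspect_ai.model import ChatMessageUser, ContentAudio, ContentImage, ContentText

    msg = ChatMessageUser(
        content=[
            ContentText(text="old"),
            ContentImage(image="screen.png"),
            ContentAudio(audio="clip.mp3", format="mp3"),  # field names inferred
        ]
    )
    msg.text = "new"
    # msg.content is now [ContentText(text="new"), <image>, <audio>];
    # before this change the audio entry (any non-image content) was dropped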
inspect_ai/model/_conversation.py
@@ -0,0 +1,61 @@
+from rich.console import RenderableType
+from rich.text import Text
+
+from inspect_ai._util.constants import NO_CONTENT
+from inspect_ai._util.rich import lines_display
+from inspect_ai._util.transcript import transcript_markdown
+from inspect_ai.util._conversation import conversation_panel
+from inspect_ai.util._display import display_type
+
+from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
+from ._render import messages_preceding_assistant, render_tool_calls
+
+MESSAGE_TITLE = "Message"
+
+
+def conversation_tool_mesage(message: ChatMessageTool) -> None:
+    if display_type() == "conversation":
+        # truncate output to 100 lines
+        output = (
+            message.error.message.strip() if message.error else message.text.strip()
+        )
+        if output:
+            content = lines_display(output, 100)
+
+            conversation_panel(
+                title=f"Tool Output: {message.function}",
+                content=content,
+            )
+
+
+def conversation_assistant_message(
+    input: list[ChatMessage], message: ChatMessageAssistant
+) -> None:
+    if display_type() == "conversation":
+        # print precding messages that aren't tool or assistant
+        for m in messages_preceding_assistant(input):
+            conversation_panel(
+                title=m.role.capitalize(),
+                content=transcript_markdown(m.text, escape=True),
+            )
+
+        # start with assistant content
+        content: list[RenderableType] = (
+            [transcript_markdown(message.text, escape=True)]
+            if message.text and message.text != NO_CONTENT
+            else []
+        )
+
+        # print tool calls
+        if message.tool_calls:
+            if content:
+                content.append(Text())
+            content.extend(render_tool_calls(message.tool_calls))
+
+        # print the assistant message
+        conversation_panel(title="Assistant", content=content)
+
+
+def conversation_assistant_error(error: Exception) -> None:
+    if display_type() == "conversation":
+        conversation_panel(title="Assistant", content=repr(error))
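
Note: this new module replaces the deleted inspect_ai/model/_trace.py (file 157 above): the trace_* printing helpers become conversation_* helpers gated on display_type() == "conversation", matching the util/{_trace.py → _conversation.py} rename (file 145). Assuming this display type is selected via the existing --display CLI option like the other display types, it would be enabled with:

    inspect eval my_task.py --display conversation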
inspect_ai/model/_generate_config.py
@@ -58,14 +58,17 @@ class GenerateConfigArgs(TypedDict, total=False):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, and TogetherAI only."""
 
     logprobs: bool | None
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, and Huggingface only."""
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, and Huggingface only."""
 
     parallel_tool_calls: bool | None
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
 
+    internal_tools: bool | None
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""
 
@@ -128,14 +131,17 @@ class GenerateConfig(BaseModel):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and vLLM only."""
 
     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None = Field(default=None)
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, Huggingface, and vLLM only."""
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, and vLLM only."""
 
     parallel_tool_calls: bool | None = Field(default=None)
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
 
+    internal_tools: bool | None = Field(default=None)
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None = Field(default=None)
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""
 
inspect_ai/model/_model.py
@@ -33,6 +33,7 @@ from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import SampleLimitExceededError
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
@@ -43,6 +44,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
+from ._conversation import conversation_assistant_error, conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -50,7 +52,6 @@ from ._generate_config import (
 )
 from ._model_call import ModelCall
 from ._model_output import ModelOutput, ModelUsage
-from ._trace import trace_assistant_message
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +117,7 @@
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         """Generate output from the model.
 
         Args:
@@ -165,7 +166,7 @@
         return False
 
     def tool_result_images(self) -> bool:
-        """Tool results can containe images"""
+        """Tool results can contain images"""
        return False
 
 
@@ -222,11 +223,17 @@
         Returns:
            ModelOutput
         """
+        # if we are the default model then enforce message limit if it
+        # exists (raise an exception if it is exceeded)
+        is_active_model = self == active_model()
+        if is_active_model:
+            handle_sample_message_limit(input)
+
         # base config for this model
         base_config = self.config
 
         # if we are the active_model then merge active generate config
-        if self == active_model():
+        if is_active_model:
             base_config = base_config.merge(active_generate_config())
 
         # merge passed config
@@ -296,6 +303,9 @@
             tools = []
             tool_choice = "none"
 
+        # apply any tool model_input handlers
+        input = resolve_tool_model_input(tdefs, input)
+
         # break tool image content out into user messages if the model doesn't
         # support tools returning images
         if not self.api.tool_result_images():
@@ -389,6 +399,17 @@
                 output = result
                 call = None
 
+            # raise error
+            if isinstance(output, Exception):
+                complete(output, call)
+
+                # Wrap the error in a runtime error which will show the
+                # request which caused the error
+                error = repr(output)
+                request = json.dumps(call.request, indent=2) if call is not None else ""
+                error_message = f"{error}\n\nRequest:\n{request}"
+                raise RuntimeError(error_message)
+
             # update output with time elapsed
             output.time = time_elapsed
 
@@ -464,7 +485,7 @@
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput, ModelCall | None], None]:
+    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -484,13 +505,16 @@
 
         # callable that can be used to update the interaction w/ output
         def complete(
-            updated_output: ModelOutput, updated_call: ModelCall | None
+            result: ModelOutput | Exception, updated_call: ModelCall | None
         ) -> None:
             # trace
-            trace_assistant_message(input, updated_output.choices[0].message)
+            if isinstance(result, ModelOutput):
+                conversation_assistant_message(input, result.choices[0].message)
+                event.output = result
+            else:
+                conversation_assistant_error(result)
+                event.error = repr(result)
 
-            # update event
-            event.output = updated_output
             event.call = updated_call
             event.pending = None
 
@@ -703,6 +727,40 @@ def simple_input_messages(
     return messages
 
 
+def resolve_tool_model_input(
+    tdefs: list[ToolDef], messages: list[ChatMessage]
+) -> list[ChatMessage]:
+    # filter on tooldefs that have a model input handler
+    tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
+
+    # bail if there are no handlers
+    if len(tdefs) == 0:
+        return messages
+
+    # don't mutate the original messages
+    messages = deepcopy(messages)
+
+    # extract tool messages
+    tool_messages = [
+        message for message in messages if isinstance(message, ChatMessageTool)
+    ]
+    # run model_input handlers over all tool_messages with the same function name
+    for tdef in tdefs:
+        assert tdef.model_input
+        # filter messages down to just this tool
+        tdef_tool_messages = [
+            message for message in tool_messages if message.function == tdef.name
+        ]
+        # call the function for each tool, passing the index, total, and content
+        for index, message in enumerate(tdef_tool_messages):
+            message.content = tdef.model_input(
+                index, len(tool_messages), message.content
+            )
+
+    # return modified messages
+    return messages
+
+
 def tool_result_images_as_user_message(
     messages: list[ChatMessage],
 ) -> list[ChatMessage]:
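
Note: resolve_tool_model_input() supports the new per-tool model_input hook (see tool/_tool_def.py, +16 -5 above), which rewrites each of a tool's messages before they are sent to the model. A hedged sketch of a handler, inferring the signature from the call site tdef.model_input(index, len(tool_messages), message.content):

    from inspect_ai.model import Content

    def keep_images_in_last_message(
        index: int, total: int, content: str | list[Content]
    ) -> str | list[Content]:
        # hypothetical handler: strip image content from all but the final tool message
        if index < total - 1 and isinstance(content, list):
            return [c for c in content if c.type != "image"]
        return content

    # hypothetical wiring via the ToolDef option added in this release:
    # ToolDef(my_tool, model_input=keep_images_in_last_message)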
@@ -713,16 +771,21 @@ def tool_result_images_reducer(
     messages: list[ChatMessage],
     message: ChatMessage,
 ) -> list[ChatMessage]:
-    # append the message
-    messages.append(message)
-
     # if there are tool result images, pull them out into a ChatUserMessage
     if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
+        tool_message = ChatMessageTool(
+            content=message.content.copy(),
+            tool_call_id=message.tool_call_id,
+            function=message.function,
+        )
+        assert isinstance(tool_message.content, list)
+        messages.append(tool_message)
+
         user_content: list[Content] = []
-        for i in range(0, len(message.content)):
-            if isinstance(message.content[i], ContentImage):
+        for i in range(0, len(tool_message.content)):
+            if isinstance(tool_message.content[i], ContentImage):
                 user_content.append(message.content[i])
-                message.content[i] = ContentText(
+                tool_message.content[i] = ContentText(
                     text="Image content is in the message below."
                 )
         if len(user_content) > 0:
@@ -730,6 +793,9 @@
                 ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
             )
 
+    else:
+        messages.append(message)
+
     # return messages
     return messages
 
@@ -813,6 +879,24 @@ def active_model() -> Model | None:
 active_model_context_var: ContextVar[Model] = ContextVar("active_model")
 
 
+def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_message_limit,
+        set_active_sample_total_messages,
+    )
+
+    total_messages = 1 if isinstance(input, str) else len(input)
+    message_limit = active_sample_message_limit()
+    if message_limit is not None:
+        if total_messages >= message_limit:
+            raise SampleLimitExceededError(
+                "message", value=total_messages, limit=message_limit
+            )
+
+    # set total messages
+    set_active_sample_total_messages(total_messages)
+
+
 def init_model_usage() -> None:
     model_usage_context_var.set({})
 
@@ -822,13 +906,28 @@ def init_sample_model_usage() -> None:
 
 
 def record_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_token_limit,
+        set_active_sample_total_tokens,
+    )
+
+    # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
 
-    # update active sample
-    from inspect_ai.log._samples import set_active_sample_total_tokens
+    # compute total tokens
+    total_tokens = sample_total_tokens()
 
-    set_active_sample_total_tokens(sample_total_tokens())
+    # update active sample
+    set_active_sample_total_tokens(total_tokens)
+
+    # check for token limit overflow and raise
+    token_limit = active_sample_token_limit()
+    if token_limit is not None:
+        if total_tokens > token_limit:
+            raise SampleLimitExceededError(
+                "token", value=total_tokens, limit=token_limit
+            )
 
 
 def set_model_usage(
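
Note: together with handle_sample_message_limit() above, model calls on the active model now raise SampleLimitExceededError (new in inspect_ai/util/_limit.py, +26 above) when a sample's message or token budget is exhausted. A hedged sketch of setting those budgets, assuming Task accepts message_limit/token_limit options per the _eval/task/task.py changes listed above:

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.solver import generate

    @task
    def limited():
        return Task(
            dataset=[Sample(input="Say hello.")],
            solver=generate(),
            message_limit=50,     # raised once the conversation reaches 50 messages
            token_limit=100_000,  # raised when cumulative sample tokens exceed 100k
        )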
inspect_ai/model/_model_output.py
@@ -26,9 +26,14 @@ class ModelUsage(BaseModel):
 
 
 StopReason = Literal[
-    "stop", "max_tokens", "model_length", "tool_calls", "content_filter", "unknown"
+    "stop",
+    "max_tokens",
+    "model_length",
+    "tool_calls",
+    "content_filter",
+    "unknown",
 ]
-"""Reason that the model stopped generating."""
+"""Reason that the model stopped or failed to generate."""
 
 
 class TopLogprob(BaseModel):