inspect-ai 0.3.82__py3-none-any.whl → 0.3.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_display/textual/app.py +14 -3
  3. inspect_ai/_display/textual/display.py +4 -0
  4. inspect_ai/_display/textual/widgets/samples.py +9 -3
  5. inspect_ai/_display/textual/widgets/task_detail.py +3 -4
  6. inspect_ai/_display/textual/widgets/tasks.py +17 -1
  7. inspect_ai/_display/textual/widgets/vscode.py +44 -0
  8. inspect_ai/_eval/eval.py +36 -24
  9. inspect_ai/_eval/evalset.py +17 -18
  10. inspect_ai/_eval/loader.py +34 -11
  11. inspect_ai/_eval/run.py +8 -13
  12. inspect_ai/_eval/score.py +13 -3
  13. inspect_ai/_eval/task/generate.py +8 -9
  14. inspect_ai/_eval/task/log.py +2 -0
  15. inspect_ai/_eval/task/task.py +23 -9
  16. inspect_ai/_util/file.py +13 -0
  17. inspect_ai/_util/json.py +2 -1
  18. inspect_ai/_util/registry.py +1 -0
  19. inspect_ai/_util/vscode.py +37 -0
  20. inspect_ai/_view/www/App.css +6 -0
  21. inspect_ai/_view/www/dist/assets/index.css +304 -128
  22. inspect_ai/_view/www/dist/assets/index.js +47495 -27519
  23. inspect_ai/_view/www/log-schema.json +124 -31
  24. inspect_ai/_view/www/package.json +3 -0
  25. inspect_ai/_view/www/src/App.tsx +12 -0
  26. inspect_ai/_view/www/src/appearance/icons.ts +1 -0
  27. inspect_ai/_view/www/src/components/Card.tsx +6 -4
  28. inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
  29. inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
  30. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
  31. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
  32. inspect_ai/_view/www/src/components/Modal.module.css +38 -0
  33. inspect_ai/_view/www/src/components/Modal.tsx +77 -0
  34. inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
  35. inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
  36. inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
  37. inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
  38. inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
  39. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
  40. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
  41. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
  42. inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
  43. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
  44. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
  45. inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
  46. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
  47. inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
  48. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
  49. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
  50. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
  51. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
  52. inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
  53. inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
  54. inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
  55. inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
  56. inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
  57. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
  58. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
  59. inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
  60. inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
  61. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  62. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
  63. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
  64. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
  65. inspect_ai/_view/www/src/state/hooks.ts +5 -3
  66. inspect_ai/_view/www/src/state/logPolling.ts +5 -1
  67. inspect_ai/_view/www/src/state/logSlice.ts +10 -0
  68. inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
  69. inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
  70. inspect_ai/_view/www/src/types/log.d.ts +34 -26
  71. inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
  72. inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
  73. inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
  74. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
  75. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
  76. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
  77. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
  78. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
  79. inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
  80. inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
  81. inspect_ai/_view/www/yarn.lock +94 -1
  82. inspect_ai/agent/__init__.py +36 -0
  83. inspect_ai/agent/_agent.py +268 -0
  84. inspect_ai/agent/_as_solver.py +72 -0
  85. inspect_ai/agent/_as_tool.py +122 -0
  86. inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
  87. inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
  88. inspect_ai/agent/_filter.py +46 -0
  89. inspect_ai/agent/_handoff.py +93 -0
  90. inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
  91. inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
  92. inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
  93. inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
  94. inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
  95. inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
  96. inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
  97. inspect_ai/agent/_react.py +241 -0
  98. inspect_ai/agent/_run.py +36 -0
  99. inspect_ai/agent/_types.py +81 -0
  100. inspect_ai/log/_log.py +11 -2
  101. inspect_ai/log/_transcript.py +13 -9
  102. inspect_ai/model/__init__.py +7 -1
  103. inspect_ai/model/_call_tools.py +256 -52
  104. inspect_ai/model/_chat_message.py +7 -4
  105. inspect_ai/model/_conversation.py +13 -62
  106. inspect_ai/model/_display.py +85 -0
  107. inspect_ai/model/_model.py +113 -14
  108. inspect_ai/model/_model_output.py +14 -9
  109. inspect_ai/model/_openai.py +16 -4
  110. inspect_ai/model/_openai_computer_use.py +162 -0
  111. inspect_ai/model/_openai_responses.py +319 -165
  112. inspect_ai/model/_providers/anthropic.py +20 -21
  113. inspect_ai/model/_providers/azureai.py +24 -13
  114. inspect_ai/model/_providers/bedrock.py +1 -7
  115. inspect_ai/model/_providers/cloudflare.py +3 -3
  116. inspect_ai/model/_providers/goodfire.py +2 -6
  117. inspect_ai/model/_providers/google.py +11 -10
  118. inspect_ai/model/_providers/groq.py +6 -3
  119. inspect_ai/model/_providers/hf.py +7 -3
  120. inspect_ai/model/_providers/mistral.py +7 -10
  121. inspect_ai/model/_providers/openai.py +47 -17
  122. inspect_ai/model/_providers/openai_o1.py +11 -4
  123. inspect_ai/model/_providers/openai_responses.py +12 -14
  124. inspect_ai/model/_providers/providers.py +2 -2
  125. inspect_ai/model/_providers/together.py +12 -2
  126. inspect_ai/model/_providers/util/chatapi.py +7 -2
  127. inspect_ai/model/_providers/util/hf_handler.py +4 -2
  128. inspect_ai/model/_providers/util/llama31.py +4 -2
  129. inspect_ai/model/_providers/vertex.py +11 -9
  130. inspect_ai/model/_providers/vllm.py +4 -4
  131. inspect_ai/scorer/__init__.py +2 -0
  132. inspect_ai/scorer/_metrics/__init__.py +2 -0
  133. inspect_ai/scorer/_metrics/grouped.py +84 -0
  134. inspect_ai/scorer/_score.py +26 -6
  135. inspect_ai/solver/__init__.py +2 -2
  136. inspect_ai/solver/_basic_agent.py +22 -9
  137. inspect_ai/solver/_bridge.py +31 -0
  138. inspect_ai/solver/_chain.py +20 -12
  139. inspect_ai/solver/_fork.py +5 -1
  140. inspect_ai/solver/_human_agent.py +52 -0
  141. inspect_ai/solver/_prompt.py +3 -1
  142. inspect_ai/solver/_run.py +59 -0
  143. inspect_ai/solver/_solver.py +14 -4
  144. inspect_ai/solver/_task_state.py +5 -3
  145. inspect_ai/tool/_tool_call.py +15 -8
  146. inspect_ai/tool/_tool_def.py +17 -12
  147. inspect_ai/tool/_tool_support_helpers.py +2 -2
  148. inspect_ai/tool/_tool_with.py +14 -11
  149. inspect_ai/tool/_tools/_bash_session.py +11 -2
  150. inspect_ai/tool/_tools/_computer/_common.py +18 -2
  151. inspect_ai/tool/_tools/_computer/_computer.py +18 -2
  152. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
  153. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
  154. inspect_ai/tool/_tools/_think.py +1 -1
  155. inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
  156. inspect_ai/util/__init__.py +2 -0
  157. inspect_ai/util/_anyio.py +27 -0
  158. inspect_ai/util/_sandbox/__init__.py +2 -1
  159. inspect_ai/util/_sandbox/context.py +32 -7
  160. inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
  161. inspect_ai/util/_sandbox/docker/compose.py +2 -2
  162. inspect_ai/util/_sandbox/docker/docker.py +12 -1
  163. inspect_ai/util/_store_model.py +30 -7
  164. inspect_ai/util/_subprocess.py +13 -3
  165. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/METADATA +1 -1
  166. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/RECORD +179 -153
  167. inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
  168. /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
  169. /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
  170. /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
  171. /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
  172. /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
  173. /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
  174. /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
  175. /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
  176. /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
  177. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/WHEEL +0 -0
  178. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/entry_points.txt +0 -0
  179. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/licenses/LICENSE +0 -0
  180. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/top_level.txt +0 -0
@@ -6,7 +6,7 @@ import logging
6
6
  import os
7
7
  import time
8
8
  from contextvars import ContextVar
9
- from copy import deepcopy
9
+ from copy import copy, deepcopy
10
10
  from datetime import datetime
11
11
  from types import TracebackType
12
12
  from typing import Any, AsyncIterator, Callable, Literal, Type, cast
@@ -45,11 +45,17 @@ from inspect_ai._util.retry import report_http_retry
45
45
  from inspect_ai._util.trace import trace_action
46
46
  from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
47
47
  from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
48
+ from inspect_ai.tool._tool_call import ToolCallModelInputHints
48
49
  from inspect_ai.tool._tool_def import ToolDef, tool_defs
49
50
  from inspect_ai.util import concurrency
50
51
 
51
52
  from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
52
- from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
53
+ from ._call_tools import (
54
+ disable_parallel_tools,
55
+ execute_tools,
56
+ tool_call_view,
57
+ tools_info,
58
+ )
53
59
  from ._chat_message import (
54
60
  ChatMessage,
55
61
  ChatMessageAssistant,
@@ -57,7 +63,10 @@ from ._chat_message import (
57
63
  ChatMessageTool,
58
64
  ChatMessageUser,
59
65
  )
60
- from ._conversation import conversation_assistant_error, conversation_assistant_message
66
+ from ._display import (
67
+ display_conversation_assistant,
68
+ display_conversation_assistant_error,
69
+ )
61
70
  from ._generate_config import (
62
71
  GenerateConfig,
63
72
  active_generate_config,
@@ -123,9 +132,20 @@ class ModelAPI(abc.ABC):
123
132
  # set any explicitly specified api key
124
133
  self.api_key = api_key
125
134
 
126
- async def close(self) -> None:
127
- """Close method for closing any client allocated for the model."""
128
- pass
135
+ async def aclose(self) -> None:
136
+ """Async close method for closing any client allocated for the model."""
137
+ self.close()
138
+
139
+ def close(self) -> None:
140
+ """Sync close method for closing any client allocated for the model."""
141
+ # if this is is called and aclose is implemented by a subclass then
142
+ # raise a runtime error (as this model reuqires async close)
143
+ aclose_method = getattr(self.__class__, "aclose")
144
+ base_aclose_method = getattr(ModelAPI, "aclose")
145
+ if aclose_method != base_aclose_method:
146
+ raise RuntimeError(
147
+ f"{self.__class__.__name__} models require an async close / context manager."
148
+ )
129
149
 
130
150
  @abc.abstractmethod
131
151
  async def generate(
@@ -201,6 +221,10 @@ class ModelAPI(abc.ABC):
201
221
  """Tool results can contain images"""
202
222
  return False
203
223
 
224
+ def disable_computer_screenshot_truncation(self) -> bool:
225
+ """Some models do not support truncation of computer screenshots."""
226
+ return False
227
+
204
228
  def emulate_reasoning_history(self) -> bool:
205
229
  """Chat message assistant messages with reasoning should playback reasoning with emulation (.e.g. <think> tags)"""
206
230
  return True
@@ -255,10 +279,23 @@ class Model:
255
279
  # get hit before score() or eval() so we activate nest_asyncio
256
280
  platform_init()
257
281
 
258
- async def __aenter__(self: "Model") -> "Model":
282
+ def __enter__(self: "Model") -> "Model":
259
283
  self._context_bound = True
260
284
  return self
261
285
 
286
+ async def __aenter__(self: "Model") -> "Model":
287
+ return self.__enter__()
288
+
289
+ def __exit__(
290
+ self,
291
+ exc_type: type[BaseException] | None,
292
+ exc: BaseException | None,
293
+ exc_tb: TracebackType | None,
294
+ ) -> None:
295
+ if not self._closed:
296
+ self.api.close()
297
+ self._closed = True
298
+
262
299
  async def __aexit__(
263
300
  self,
264
301
  exc_type: type[BaseException] | None,
@@ -266,7 +303,7 @@ class Model:
266
303
  exc_tb: TracebackType | None,
267
304
  ) -> None:
268
305
  if not self._closed:
269
- await self.api.close()
306
+ await self.api.aclose()
270
307
  self._closed = True
271
308
 
272
309
  @property
@@ -373,6 +410,55 @@ class Model:
373
410
  # return output
374
411
  return output
375
412
 
413
+ async def generate_loop(
414
+ self,
415
+ input: str | list[ChatMessage],
416
+ tools: list[Tool] | list[ToolDef] | list[Tool | ToolDef] = [],
417
+ config: GenerateConfig = GenerateConfig(),
418
+ cache: bool | CachePolicy = False,
419
+ ) -> tuple[list[ChatMessage], ModelOutput]:
420
+ """Generate output from the model, looping as long as the model calls tools.
421
+
422
+ Similar to `generate()`, but runs in a loop resolving model tool calls.
423
+ The loop terminates when the model stops calling tools. The final `ModelOutput`
424
+ as well the message list for the conversation are returned as a tuple.
425
+
426
+ Args:
427
+ input: Chat message input (if a `str` is passed it is converted
428
+ to a `ChatMessageUser`).
429
+ tools: Tools available for the model to call.
430
+ config: Model configuration.
431
+ cache: Caching behavior for generate responses (defaults to no caching).
432
+
433
+ Returns:
434
+ Tuple of list[ChatMessage], ModelOutput
435
+ """
436
+ # initialise messages
437
+ input = [ChatMessageUser(content=input)] if isinstance(input, str) else input
438
+ messages = copy(input)
439
+ while True:
440
+ # call model
441
+ output = await self.generate(
442
+ input=messages,
443
+ tools=tools, # type:ignore[arg-type]
444
+ config=config,
445
+ cache=cache,
446
+ )
447
+
448
+ # append to new messages
449
+ messages.append(output.message)
450
+
451
+ # make tool calls or terminate if there are none
452
+ if output.message.tool_calls:
453
+ tools_messages, tools_output = await execute_tools(
454
+ messages, tools, config.max_tool_output
455
+ )
456
+ messages.extend(tools_messages)
457
+ if tools_output is not None:
458
+ output = tools_output
459
+ else:
460
+ return messages[len(input) :], output
461
+
376
462
  async def _generate(
377
463
  self,
378
464
  input: list[ChatMessage],
@@ -414,7 +500,13 @@ class Model:
414
500
  input = resolve_reasoning_history(input, config, self.api)
415
501
 
416
502
  # apply any tool model_input handlers
417
- input = resolve_tool_model_input(tdefs, input)
503
+ input = resolve_tool_model_input(
504
+ tdefs,
505
+ input,
506
+ ToolCallModelInputHints(
507
+ disable_computer_screenshot_truncation=self.api.disable_computer_screenshot_truncation()
508
+ ),
509
+ )
418
510
 
419
511
  # break tool image content out into user messages if the model doesn't
420
512
  # support tools returning images
@@ -664,10 +756,10 @@ class Model:
664
756
  # trace
665
757
  if isinstance(result, ModelOutput):
666
758
  if result.choices:
667
- conversation_assistant_message(input, result.choices[0].message)
759
+ display_conversation_assistant(input, result.choices[0].message)
668
760
  event.output = result
669
761
  else:
670
- conversation_assistant_error(result)
762
+ display_conversation_assistant_error(result)
671
763
  event.error = repr(result)
672
764
 
673
765
  event.call = updated_call
@@ -1034,7 +1126,7 @@ def resolve_reasoning_history(
1034
1126
 
1035
1127
 
1036
1128
  def resolve_tool_model_input(
1037
- tdefs: list[ToolDef], messages: list[ChatMessage]
1129
+ tdefs: list[ToolDef], messages: list[ChatMessage], hints: ToolCallModelInputHints
1038
1130
  ) -> list[ChatMessage]:
1039
1131
  # filter on tooldefs that have a model input handler
1040
1132
  tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
@@ -1060,7 +1152,7 @@ def resolve_tool_model_input(
1060
1152
  # call the function for each tool, passing the index, total, and content
1061
1153
  for index, message in enumerate(tdef_tool_messages):
1062
1154
  message.content = tdef.model_input(
1063
- index, len(tool_messages), message.content
1155
+ index, len(tool_messages), message.content, hints
1064
1156
  )
1065
1157
 
1066
1158
  # return modified messages
@@ -1116,7 +1208,7 @@ def tool_result_images_reducer(
1116
1208
  content=edited_tool_message_content,
1117
1209
  tool_call_id=message.tool_call_id,
1118
1210
  function=message.function,
1119
- internal_name=message.internal_name,
1211
+ internal=message.internal,
1120
1212
  )
1121
1213
  ],
1122
1214
  pending_content + new_user_message_content,
@@ -1219,6 +1311,13 @@ def consecutive_message_reducer(
1219
1311
  def combine_messages(
1220
1312
  a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage]
1221
1313
  ) -> ChatMessage:
1314
+ # TODO: Although unlikely to happen based on the current call sites, these
1315
+ # fabricated messages drop interesting fields from the source messages -
1316
+ # such as `internal_name`, `tool_calls`, etc.
1317
+ # To be more specific, since all `ChatMessageXxx` fields other than `id` and
1318
+ # `content` have default values, it's more the case that they're reset to
1319
+ # default values rather than dropped.
1320
+
1222
1321
  if isinstance(a.content, str) and isinstance(b.content, str):
1223
1322
  return message_type(id=a.id, content=f"{a.content}\n{b.content}")
1224
1323
  elif isinstance(a.content, list) and isinstance(b.content, list):
@@ -1,7 +1,7 @@
1
1
  import uuid
2
2
  from typing import Any, Literal, Type
3
3
 
4
- from pydantic import BaseModel, Field, model_validator
4
+ from pydantic import BaseModel, Field, JsonValue, model_validator
5
5
 
6
6
  from inspect_ai.tool._tool_call import ToolCall
7
7
 
@@ -123,6 +123,10 @@ class ModelOutput(BaseModel):
123
123
  error: str | None = Field(default=None)
124
124
  """Error message in the case of content moderation refusals."""
125
125
 
126
+ @property
127
+ def empty(self) -> bool:
128
+ return len(self.choices) == 0
129
+
126
130
  @property
127
131
  def stop_reason(self) -> StopReason:
128
132
  """First message stop reason."""
@@ -153,7 +157,8 @@ class ModelOutput(BaseModel):
153
157
  else:
154
158
  self.choices.append(
155
159
  ChatCompletionChoice(
156
- message=ChatMessageAssistant(content=completion), stop_reason="stop"
160
+ message=ChatMessageAssistant(content=completion, model=self.model),
161
+ stop_reason="stop",
157
162
  )
158
163
  )
159
164
 
@@ -176,7 +181,9 @@ class ModelOutput(BaseModel):
176
181
  model=model,
177
182
  choices=[
178
183
  ChatCompletionChoice(
179
- message=ChatMessageAssistant(content=content, source="generate"),
184
+ message=ChatMessageAssistant(
185
+ content=content, model=model, source="generate"
186
+ ),
180
187
  stop_reason=stop_reason,
181
188
  )
182
189
  ],
@@ -188,10 +195,9 @@ class ModelOutput(BaseModel):
188
195
  model: str,
189
196
  tool_name: str,
190
197
  tool_arguments: dict[str, Any],
191
- internal_tool_name: str | None = None,
198
+ internal: JsonValue | None = None,
192
199
  tool_call_id: str | None = None,
193
200
  content: str | None = None,
194
- type: str = "function",
195
201
  ) -> "ModelOutput":
196
202
  """
197
203
  Returns a ModelOutput for requesting a tool call.
@@ -199,8 +205,7 @@ class ModelOutput(BaseModel):
199
205
  Args:
200
206
  model: model name
201
207
  tool_name: The name of the tool.
202
- internal_tool_name: The model's internal name for the tool (if any).
203
- type: The model's type for the tool. e.g. "function", "computer_use_preview"
208
+ internal: The model's internal info for the tool (if any).
204
209
  tool_arguments: The arguments passed to the tool.
205
210
  tool_call_id: Optional ID for the tool call. Defaults to a random UUID.
206
211
  content: Optional content to include in the message. Defaults to "tool call for tool {tool_name}".
@@ -220,14 +225,14 @@ class ModelOutput(BaseModel):
220
225
  ChatCompletionChoice(
221
226
  message=ChatMessageAssistant(
222
227
  content=content,
228
+ model=model,
223
229
  source="generate",
224
230
  tool_calls=[
225
231
  ToolCall(
226
232
  id=tool_call_id,
227
233
  function=tool_name,
228
- internal_name=internal_tool_name,
234
+ internal=internal,
229
235
  arguments=tool_arguments,
230
- type=type,
231
236
  )
232
237
  ],
233
238
  ),
@@ -83,6 +83,10 @@ def is_o1_preview(name: str) -> bool:
83
83
  return "o1-preview" in name
84
84
 
85
85
 
86
+ def is_computer_use_preview(name: str) -> bool:
87
+ return "computer-use-preview" in name
88
+
89
+
86
90
  def is_gpt(name: str) -> bool:
87
91
  return "gpt" in name
88
92
 
@@ -100,13 +104,12 @@ def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
100
104
  def openai_chat_tool_call_param(
101
105
  tool_call: ToolCall,
102
106
  ) -> ChatCompletionMessageToolCallParam:
103
- assert tool_call.type == "function", f"Unexpected tool call type {tool_call.type}"
104
107
  return ChatCompletionMessageToolCallParam(
105
108
  id=tool_call.id,
106
109
  function=dict(
107
110
  name=tool_call.function, arguments=json.dumps(tool_call.arguments)
108
111
  ),
109
- type="function", # Type narrowing couldn't figure it out
112
+ type="function",
110
113
  )
111
114
 
112
115
 
@@ -308,6 +311,7 @@ def chat_tool_calls_from_openai(
308
311
 
309
312
 
310
313
  def chat_messages_from_openai(
314
+ model: str,
311
315
  messages: list[ChatCompletionMessageParam],
312
316
  ) -> list[ChatMessage]:
313
317
  # track tool names by id
@@ -386,6 +390,8 @@ def chat_messages_from_openai(
386
390
  ChatMessageAssistant(
387
391
  content=content,
388
392
  tool_calls=tool_calls or None,
393
+ model=model,
394
+ source="generate",
389
395
  )
390
396
  )
391
397
  elif message["role"] == "tool":
@@ -464,7 +470,7 @@ def content_from_openai(
464
470
 
465
471
 
466
472
  def chat_message_assistant_from_openai(
467
- message: ChatCompletionMessage, tools: list[ToolInfo]
473
+ model: str, message: ChatCompletionMessage, tools: list[ToolInfo]
468
474
  ) -> ChatMessageAssistant:
469
475
  refusal = getattr(message, "refusal", None)
470
476
  reasoning = getattr(message, "reasoning_content", None) or getattr(
@@ -484,6 +490,7 @@ def chat_message_assistant_from_openai(
484
490
 
485
491
  return ChatMessageAssistant(
486
492
  content=content,
493
+ model=model,
487
494
  source="generate",
488
495
  tool_calls=chat_tool_calls_from_openai(message, tools),
489
496
  )
@@ -496,7 +503,9 @@ def chat_choices_from_openai(
496
503
  choices.sort(key=lambda c: c.index)
497
504
  return [
498
505
  ChatCompletionChoice(
499
- message=chat_message_assistant_from_openai(choice.message, tools),
506
+ message=chat_message_assistant_from_openai(
507
+ response.model, choice.message, tools
508
+ ),
500
509
  stop_reason=as_stop_reason(choice.finish_reason),
501
510
  logprobs=(
502
511
  Logprobs(**choice.logprobs.model_dump())
@@ -538,6 +547,9 @@ def openai_handle_bad_request(
538
547
 
539
548
  def openai_media_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
540
549
  # remove images from raw api call
550
+ if key == "output" and isinstance(value, dict) and "image_url" in value:
551
+ value = copy(value)
552
+ value.update(image_url=BASE_64_DATA_REMOVED)
541
553
  if key == "image_url" and isinstance(value, dict) and "url" in value:
542
554
  url = str(value.get("url"))
543
555
  if url.startswith("data:"):
@@ -0,0 +1,162 @@
1
+ from openai.types.responses import (
2
+ ComputerToolParam,
3
+ ResponseComputerToolCall,
4
+ ResponseComputerToolCallOutputScreenshotParam,
5
+ )
6
+ from openai.types.responses.response_input_item_param import ComputerCallOutput
7
+
8
+ from inspect_ai._util.content import Content, ContentImage
9
+ from inspect_ai.model._chat_message import ChatMessageTool
10
+ from inspect_ai.tool._tool_call import ToolCall
11
+ from inspect_ai.tool._tool_info import ToolInfo
12
+
13
+
14
+ def tool_call_from_openai_computer_tool_call(
15
+ output: ResponseComputerToolCall,
16
+ ) -> ToolCall:
17
+ return ToolCall(
18
+ id=output.call_id,
19
+ function="computer",
20
+ arguments=_parse_computer_tool_call_arguments(output),
21
+ internal=output.model_dump(),
22
+ )
23
+
24
+
25
+ def maybe_computer_use_preview_tool(tool: ToolInfo) -> ComputerToolParam | None:
26
+ # check for compatible 'computer' tool
27
+ return (
28
+ ComputerToolParam(
29
+ type="computer_use_preview",
30
+ # The OpenAI model is ahead of the sdk — "ubuntu" -> "linux"
31
+ environment="linux", # type: ignore
32
+ # Note: The dimensions passed here for display_width and display_height should
33
+ # match the dimensions of screenshots returned by the tool.
34
+ # Those dimensions will always be one of the values in MAX_SCALING_TARGETS
35
+ # in _x11_client.py.
36
+ # TODO: enhance this code to calculate the dimensions based on the scaled screen
37
+ # size used by the container.
38
+ display_width=1366,
39
+ display_height=768,
40
+ )
41
+ if tool.name == "computer"
42
+ and (
43
+ sorted(tool.parameters.properties.keys())
44
+ == sorted(
45
+ [
46
+ "action",
47
+ "coordinate",
48
+ "duration",
49
+ "scroll_amount",
50
+ "scroll_direction",
51
+ "start_coordinate",
52
+ "text",
53
+ ]
54
+ )
55
+ )
56
+ else None
57
+ )
58
+
59
+
60
+ def computer_call_output(
61
+ message: ChatMessageTool,
62
+ # internal is passed in despite being within message to avoid an extra
63
+ # validation step
64
+ internal: ResponseComputerToolCall,
65
+ ) -> ComputerCallOutput:
66
+ return ComputerCallOutput(
67
+ call_id=internal.call_id,
68
+ type="computer_call_output",
69
+ output=ResponseComputerToolCallOutputScreenshotParam(
70
+ type="computer_screenshot",
71
+ image_url=_content_image(message.content),
72
+ ),
73
+ )
74
+
75
+
76
+ def _parse_computer_tool_call_arguments(
77
+ output: ResponseComputerToolCall,
78
+ ) -> dict[str, object]:
79
+ action = output.action
80
+
81
+ if action.type == "click":
82
+ coordinate = [action.x, action.y]
83
+ match action.button:
84
+ case "left":
85
+ return {"action": "left_click", "coordinate": coordinate}
86
+ case "right":
87
+ return {"action": "right_click", "coordinate": coordinate}
88
+ case "wheel":
89
+ return {"action": "middle_click", "coordinate": coordinate}
90
+ case "back":
91
+ return {"action": "back_click", "coordinate": coordinate}
92
+ case "forward":
93
+ return {"action": "forward_click", "coordinate": coordinate}
94
+ elif action.type == "double_click":
95
+ return {"action": "double_click", "coordinate": [action.x, action.y]}
96
+ elif action.type == "drag":
97
+ # TODO: For now, we go directly from the first to the last coordinate in
98
+ # the path. Ultimately, we'll need to extend the tool to support all of
99
+ # the intermediate coordinates in the path.
100
+ path = action.path
101
+ assert len(path) >= 2
102
+ start = path[0]
103
+ end = path[-1]
104
+ return {
105
+ "action": "left_click_drag",
106
+ "start_coordinate": [start.x, start.y],
107
+ "coordinate": [end.x, end.y],
108
+ }
109
+ elif action.type == "keypress":
110
+ # TODO: This mapping logic is copied from their example, but seems incomplete
111
+ mapping = {
112
+ "ENTER": "Return",
113
+ "LEFT": "Left",
114
+ "RIGHT": "Right",
115
+ "UP": "Up",
116
+ "DOWN": "Down",
117
+ "ESC": "Escape",
118
+ "SPACE": "space",
119
+ "BACKSPACE": "BackSpace",
120
+ "TAB": "Tab",
121
+ }
122
+ return {
123
+ "action": "key",
124
+ "text": "+".join([mapping.get(key, key) for key in action.keys]),
125
+ }
126
+ elif action.type == "move":
127
+ return {"action": "mouse_move", "coordinate": [action.x, action.y]}
128
+ elif action.type == "screenshot":
129
+ return {"action": "screenshot"}
130
+ elif action.type == "scroll":
131
+ # TODO: OpenAI spec's with x/y distances. Their example code treats the
132
+ # unit of measurement as a "click" of the scroll wheel. Since it's not
133
+ # really a thing to scroll both horizontally and vertically at the same
134
+ # time, we'll just pick one of the potentially two directions and
135
+ # scroll along that dimension.
136
+ (scroll_direction, scroll_amount) = (
137
+ ("right" if action.scroll_x > 0 else "left", abs(action.scroll_x))
138
+ if action.scroll_x
139
+ else ("down" if action.scroll_y > 0 else "up", abs(action.scroll_y))
140
+ )
141
+ return {
142
+ "action": "scroll",
143
+ "coordinate": [action.x, action.y],
144
+ "scroll_direction": scroll_direction,
145
+ "scroll_amount": scroll_amount,
146
+ }
147
+ elif action.type == "type":
148
+ return {"action": "type", "text": action.text}
149
+ elif action.type == "wait":
150
+ return {"action": "wait", "duration": 1}
151
+
152
+ assert False, f"Unexpected action type: {action.type}"
153
+
154
+
155
+ def _content_image(input: str | list[Content]) -> str:
156
+ result = (
157
+ next((item.image for item in input if isinstance(item, ContentImage)), None)
158
+ if isinstance(input, list)
159
+ else None
160
+ )
161
+ assert result, "Must find image in content"
162
+ return result