inspect-ai 0.3.82__py3-none-any.whl → 0.3.83__py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (180)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_display/textual/app.py +14 -3
  3. inspect_ai/_display/textual/display.py +4 -0
  4. inspect_ai/_display/textual/widgets/samples.py +9 -3
  5. inspect_ai/_display/textual/widgets/task_detail.py +3 -4
  6. inspect_ai/_display/textual/widgets/tasks.py +17 -1
  7. inspect_ai/_display/textual/widgets/vscode.py +44 -0
  8. inspect_ai/_eval/eval.py +36 -24
  9. inspect_ai/_eval/evalset.py +17 -18
  10. inspect_ai/_eval/loader.py +34 -11
  11. inspect_ai/_eval/run.py +8 -13
  12. inspect_ai/_eval/score.py +13 -3
  13. inspect_ai/_eval/task/generate.py +8 -9
  14. inspect_ai/_eval/task/log.py +2 -0
  15. inspect_ai/_eval/task/task.py +23 -9
  16. inspect_ai/_util/file.py +13 -0
  17. inspect_ai/_util/json.py +2 -1
  18. inspect_ai/_util/registry.py +1 -0
  19. inspect_ai/_util/vscode.py +37 -0
  20. inspect_ai/_view/www/App.css +6 -0
  21. inspect_ai/_view/www/dist/assets/index.css +304 -128
  22. inspect_ai/_view/www/dist/assets/index.js +47495 -27519
  23. inspect_ai/_view/www/log-schema.json +124 -31
  24. inspect_ai/_view/www/package.json +3 -0
  25. inspect_ai/_view/www/src/App.tsx +12 -0
  26. inspect_ai/_view/www/src/appearance/icons.ts +1 -0
  27. inspect_ai/_view/www/src/components/Card.tsx +6 -4
  28. inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
  29. inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
  30. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
  31. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
  32. inspect_ai/_view/www/src/components/Modal.module.css +38 -0
  33. inspect_ai/_view/www/src/components/Modal.tsx +77 -0
  34. inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
  35. inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
  36. inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
  37. inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
  38. inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
  39. inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
  40. inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
  41. inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
  42. inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
  43. inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
  44. inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
  45. inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
  46. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
  47. inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
  48. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
  49. inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
  50. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
  51. inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
  52. inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
  53. inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
  54. inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
  55. inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
  56. inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
  57. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
  58. inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
  59. inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
  60. inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
  61. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
  62. inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
  63. inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
  64. inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
  65. inspect_ai/_view/www/src/state/hooks.ts +5 -3
  66. inspect_ai/_view/www/src/state/logPolling.ts +5 -1
  67. inspect_ai/_view/www/src/state/logSlice.ts +10 -0
  68. inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
  69. inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
  70. inspect_ai/_view/www/src/types/log.d.ts +34 -26
  71. inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
  72. inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
  73. inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
  74. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
  75. inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
  76. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
  77. inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
  78. inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
  79. inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
  80. inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
  81. inspect_ai/_view/www/yarn.lock +94 -1
  82. inspect_ai/agent/__init__.py +36 -0
  83. inspect_ai/agent/_agent.py +268 -0
  84. inspect_ai/agent/_as_solver.py +72 -0
  85. inspect_ai/agent/_as_tool.py +122 -0
  86. inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
  87. inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
  88. inspect_ai/agent/_filter.py +46 -0
  89. inspect_ai/agent/_handoff.py +93 -0
  90. inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
  91. inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
  92. inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
  93. inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
  94. inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
  95. inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
  96. inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
  97. inspect_ai/agent/_react.py +241 -0
  98. inspect_ai/agent/_run.py +36 -0
  99. inspect_ai/agent/_types.py +81 -0
  100. inspect_ai/log/_log.py +11 -2
  101. inspect_ai/log/_transcript.py +13 -9
  102. inspect_ai/model/__init__.py +7 -1
  103. inspect_ai/model/_call_tools.py +256 -52
  104. inspect_ai/model/_chat_message.py +7 -4
  105. inspect_ai/model/_conversation.py +13 -62
  106. inspect_ai/model/_display.py +85 -0
  107. inspect_ai/model/_model.py +113 -14
  108. inspect_ai/model/_model_output.py +14 -9
  109. inspect_ai/model/_openai.py +16 -4
  110. inspect_ai/model/_openai_computer_use.py +162 -0
  111. inspect_ai/model/_openai_responses.py +319 -165
  112. inspect_ai/model/_providers/anthropic.py +20 -21
  113. inspect_ai/model/_providers/azureai.py +24 -13
  114. inspect_ai/model/_providers/bedrock.py +1 -7
  115. inspect_ai/model/_providers/cloudflare.py +3 -3
  116. inspect_ai/model/_providers/goodfire.py +2 -6
  117. inspect_ai/model/_providers/google.py +11 -10
  118. inspect_ai/model/_providers/groq.py +6 -3
  119. inspect_ai/model/_providers/hf.py +7 -3
  120. inspect_ai/model/_providers/mistral.py +7 -10
  121. inspect_ai/model/_providers/openai.py +47 -17
  122. inspect_ai/model/_providers/openai_o1.py +11 -4
  123. inspect_ai/model/_providers/openai_responses.py +12 -14
  124. inspect_ai/model/_providers/providers.py +2 -2
  125. inspect_ai/model/_providers/together.py +12 -2
  126. inspect_ai/model/_providers/util/chatapi.py +7 -2
  127. inspect_ai/model/_providers/util/hf_handler.py +4 -2
  128. inspect_ai/model/_providers/util/llama31.py +4 -2
  129. inspect_ai/model/_providers/vertex.py +11 -9
  130. inspect_ai/model/_providers/vllm.py +4 -4
  131. inspect_ai/scorer/__init__.py +2 -0
  132. inspect_ai/scorer/_metrics/__init__.py +2 -0
  133. inspect_ai/scorer/_metrics/grouped.py +84 -0
  134. inspect_ai/scorer/_score.py +26 -6
  135. inspect_ai/solver/__init__.py +2 -2
  136. inspect_ai/solver/_basic_agent.py +22 -9
  137. inspect_ai/solver/_bridge.py +31 -0
  138. inspect_ai/solver/_chain.py +20 -12
  139. inspect_ai/solver/_fork.py +5 -1
  140. inspect_ai/solver/_human_agent.py +52 -0
  141. inspect_ai/solver/_prompt.py +3 -1
  142. inspect_ai/solver/_run.py +59 -0
  143. inspect_ai/solver/_solver.py +14 -4
  144. inspect_ai/solver/_task_state.py +5 -3
  145. inspect_ai/tool/_tool_call.py +15 -8
  146. inspect_ai/tool/_tool_def.py +17 -12
  147. inspect_ai/tool/_tool_support_helpers.py +2 -2
  148. inspect_ai/tool/_tool_with.py +14 -11
  149. inspect_ai/tool/_tools/_bash_session.py +11 -2
  150. inspect_ai/tool/_tools/_computer/_common.py +18 -2
  151. inspect_ai/tool/_tools/_computer/_computer.py +18 -2
  152. inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
  153. inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
  154. inspect_ai/tool/_tools/_think.py +1 -1
  155. inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
  156. inspect_ai/util/__init__.py +2 -0
  157. inspect_ai/util/_anyio.py +27 -0
  158. inspect_ai/util/_sandbox/__init__.py +2 -1
  159. inspect_ai/util/_sandbox/context.py +32 -7
  160. inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
  161. inspect_ai/util/_sandbox/docker/compose.py +2 -2
  162. inspect_ai/util/_sandbox/docker/docker.py +12 -1
  163. inspect_ai/util/_store_model.py +30 -7
  164. inspect_ai/util/_subprocess.py +13 -3
  165. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/METADATA +1 -1
  166. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/RECORD +179 -153
  167. inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
  168. /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
  169. /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
  170. /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
  171. /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
  172. /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
  173. /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
  174. /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
  175. /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
  176. /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
  177. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/WHEEL +0 -0
  178. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/entry_points.txt +0 -0
  179. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/licenses/LICENSE +0 -0
  180. {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.83.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/anthropic.py
@@ -3,7 +3,7 @@ import os
 import re
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal, NamedTuple, Optional, Tuple, cast
+from typing import Any, Literal, Optional, Tuple, cast
 
 import httpcore
 import httpx
@@ -153,7 +153,7 @@ class AnthropicAPI(ModelAPI):
         self._http_hooks = HttpxHooks(self.client._client)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
     def is_bedrock(self) -> bool:
@@ -639,11 +639,7 @@ def message_tool_choice(
     elif tool_choice == "any":
         return {"type": "any"}
     elif tool_choice == "none":
-        warn_once(
-            logger,
-            'The Anthropic API does not support tool_choice="none" (using "auto" instead)',
-        )
-        return {"type": "auto"}
+        return {"type": "none"}
     else:
         return {"type": "auto"}
 
@@ -723,11 +719,12 @@ async def message_param(message: ChatMessage) -> MessageParam:
 
         # now add tools
         for tool_call in message.tool_calls:
+            internal_name = _internal_name_from_tool_call(tool_call)
             tools_content.append(
                 ToolUseBlockParam(
                     type="tool_use",
                     id=tool_call.id,
-                    name=tool_call.internal_name or tool_call.function,
+                    name=internal_name or tool_call.function,
                     input=tool_call.arguments,
                 )
             )
@@ -774,14 +771,13 @@ async def model_output_from_message(
             content.append(ContentText(type="text", text=content_text))
         elif isinstance(content_block, ToolUseBlock):
             tool_calls = tool_calls or []
-            info = maybe_mapped_call_info(content_block.name, tools)
+            (tool_name, internal_name) = _names_for_tool_call(content_block.name, tools)
             tool_calls.append(
                 ToolCall(
-                    type=info.internal_type,
                     id=content_block.id,
-                    function=info.inspect_name,
-                    internal_name=info.internal_name,
+                    function=tool_name,
                     arguments=content_block.model_dump().get("input", {}),
+                    internal=internal_name,
                 )
             )
         elif isinstance(content_block, RedactedThinkingBlock):
@@ -801,7 +797,7 @@ async def model_output_from_message(
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
-            content=content, tool_calls=tool_calls, source="generate"
+            content=content, tool_calls=tool_calls, model=model, source="generate"
        ),
         stop_reason=message_stop_reason(message),
     )
@@ -831,15 +827,18 @@ async def model_output_from_message(
     )
 
 
-class CallInfo(NamedTuple):
-    internal_name: str | None
-    internal_type: str
-    inspect_name: str
+def _internal_name_from_tool_call(tool_call: ToolCall) -> str | None:
+    assert isinstance(tool_call.internal, str | None), (
+        f"ToolCall internal must be `str | None`: {tool_call.internal}"
+    )
+    return tool_call.internal
 
 
-def maybe_mapped_call_info(tool_called: str, tools: list[ToolInfo]) -> CallInfo:
+def _names_for_tool_call(
+    tool_called: str, tools: list[ToolInfo]
+) -> tuple[str, str | None]:
     """
-    Return call info - potentially transformed by native tool mappings.
+    Return the name of the tool to call and potentially an internal name.
 
     Anthropic prescribes names for their native tools - `computer`, `bash`, and
     `str_replace_editor`. For a variety of reasons, Inspect's tool names to not
@@ -854,11 +853,11 @@ def maybe_mapped_call_info(tool_called: str, tools: list[ToolInfo]) -> CallInfo:
 
     return next(
         (
-            CallInfo(entry[0], entry[1], entry[2])
+            (entry[2], entry[0])
             for entry in mappings
             if entry[0] == tool_called and any(tool.name == entry[2] for tool in tools)
         ),
-        CallInfo(None, "function", tool_called),
+        (tool_called, None),
     )
 
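Net effect of the anthropic.py hunks: the CallInfo NamedTuple is gone, replaced by a plain (tool_name, internal_name) pair, and the provider-native tool name now rides along on the renamed ToolCall.internal field (formerly internal_name, with internal_type dropped entirely). A minimal standalone sketch of that round trip; the ToolCall stand-in and the mapping entries below are illustrative, not the package's actual definitions:

from dataclasses import dataclass, field
from typing import Any

# Stand-in for inspect_ai's ToolCall with only the fields these hunks touch;
# the real class lives in inspect_ai/tool/_tool_call.py.
@dataclass
class ToolCall:
    id: str
    function: str                 # Inspect's name for the tool
    arguments: dict[str, Any] = field(default_factory=dict)
    internal: str | None = None   # provider-native name, when it differs

# Hypothetical mapping table shaped like the hunk's `mappings` entries:
# (native_name, internal_type, inspect_name).
MAPPINGS = [
    ("bash", "bash_20250124", "bash_session"),
    ("str_replace_editor", "text_editor_20250124", "text_editor"),
]

def names_for_tool_call(tool_called: str, available: set[str]) -> tuple[str, str | None]:
    # mirrors _names_for_tool_call: map a native name back to the Inspect tool
    # name, remembering the native name as the tool call's `internal` name
    for native, _type, inspect_name in MAPPINGS:
        if native == tool_called and inspect_name in available:
            return (inspect_name, native)
    return (tool_called, None)

# parsing output: Claude called "bash", the task registered "bash_session"
name, internal = names_for_tool_call("bash", {"bash_session"})
call = ToolCall(id="toolu_1", function=name, arguments={"command": "ls"}, internal=internal)
assert (call.function, call.internal) == ("bash_session", "bash")

# replaying history: prefer the stored native name, as message_param now does
assert (call.internal or call.function) == "bash"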
inspect_ai/model/_providers/azureai.py
@@ -129,11 +129,6 @@ class AzureAIAPI(ModelAPI):
         self.endpoint_url = endpoint_url
         self.model_args = model_args
 
-    @override
-    async def close(self) -> None:
-        # client is created/destroyed each time in generate()
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -143,9 +138,9 @@ class AzureAIAPI(ModelAPI):
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
-            handler: ChatAPIHandler | None = Llama31Handler()
+            handler: ChatAPIHandler | None = Llama31Handler(self.model_name)
         elif self.emulate_tools:
-            handler = Llama31Handler()
+            handler = Llama31Handler(self.model_name)
         else:
             handler = None
 
@@ -190,7 +185,9 @@ class AzureAIAPI(ModelAPI):
         response: ChatCompletions = await client.complete(**request)
         return ModelOutput(
             model=response.model,
-            choices=chat_completion_choices(response.choices, tools, handler),
+            choices=chat_completion_choices(
+                response.model, response.choices, tools, handler
+            ),
             usage=ModelUsage(
                 input_tokens=response.usage.prompt_tokens,
                 output_tokens=response.usage.completion_tokens,
@@ -368,24 +365,37 @@ def chat_tool_choice(
 
 
 def chat_completion_choices(
-    choices: list[ChatChoice], tools: list[ToolInfo], handler: ChatAPIHandler | None
+    model: str,
+    choices: list[ChatChoice],
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> list[ChatCompletionChoice]:
     choices = copy(choices)
     choices.sort(key=lambda c: c.index)
-    return [chat_complection_choice(choice, tools, handler) for choice in choices]
+    return [
+        chat_complection_choice(model, choice, tools, handler) for choice in choices
+    ]
 
 
 def chat_complection_choice(
-    choice: ChatChoice, tools: list[ToolInfo], handler: ChatAPIHandler | None
+    model: str,
+    choice: ChatChoice,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> ChatCompletionChoice:
     return ChatCompletionChoice(
-        message=chat_completion_assistant_message(choice.message, tools, handler),
+        message=chat_completion_assistant_message(
+            model, choice.message, tools, handler
+        ),
         stop_reason=chat_completion_stop_reason(choice.finish_reason),
     )
 
 
 def chat_completion_assistant_message(
-    response: ChatResponseMessage, tools: list[ToolInfo], handler: ChatAPIHandler | None
+    model: str,
+    response: ChatResponseMessage,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> ChatMessageAssistant:
     if handler:
         return handler.parse_assistant_response(response.content, tools)
@@ -397,6 +407,7 @@ def chat_completion_assistant_message(
         ]
         if response.tool_calls is not None
         else None,
+        model=model,
     )
 
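The azureai.py hunks show the pattern that repeats through the rest of this diff: the model id is threaded down the choice-construction helpers so that every ChatMessageAssistant records which model produced it (a new model field on the message class, per the _chat_message.py entry in the file list). A compressed sketch of the shape, using stand-in types rather than the real inspect_ai classes:

from dataclasses import dataclass

# stand-ins for inspect_ai's message/choice types
@dataclass
class ChatMessageAssistant:
    content: str
    model: str | None = None   # new in 0.3.83: the producing model
    source: str | None = None

@dataclass
class ChatCompletionChoice:
    message: ChatMessageAssistant
    stop_reason: str

def chat_completion_choice(model: str, content: str, finish_reason: str) -> ChatCompletionChoice:
    # `model` is now the first parameter at every level of the helper chain,
    # so the assistant message can carry it (azureai, bedrock, goodfire,
    # google, groq, hf, and mistral all get the same treatment below)
    return ChatCompletionChoice(
        message=ChatMessageAssistant(content=content, model=model, source="generate"),
        stop_reason=finish_reason,
    )

choice = chat_completion_choice("my-azure-deployment", "hello", "stop")
assert choice.message.model == "my-azure-deployment"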
inspect_ai/model/_providers/bedrock.py
@@ -269,11 +269,6 @@ class BedrockAPI(ModelAPI):
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
 
-    @override
-    async def close(self) -> None:
-        # client is created/destroyed each time in generate()
-        pass
-
     @override
     def connection_key(self) -> str:
         return self.model_name
@@ -454,7 +449,6 @@ def model_output_from_response(
             tool_calls.append(
                 ToolCall(
                     id=c.toolUse.toolUseId,
-                    type="function",
                     function=c.toolUse.name,
                     arguments=cast(dict[str, Any], c.toolUse.input or {}),
                 )
@@ -465,7 +459,7 @@ def model_output_from_response(
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
-            content=content, tool_calls=tool_calls, source="generate"
+            content=content, tool_calls=tool_calls, model=model, source="generate"
         ),
         stop_reason=message_stop_reason(response.stopReason),
     )
inspect_ai/model/_providers/cloudflare.py
@@ -59,7 +59,7 @@ class CloudFlareAPI(ModelAPI):
         self.model_args = model_args
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.aclose()
 
     async def generate(
@@ -141,6 +141,6 @@ class CloudFlareAPI(ModelAPI):
 
     def chat_api_handler(self) -> ChatAPIHandler:
         if "llama" in self.model_name.lower():
-            return Llama31Handler()
+            return Llama31Handler(self.model_name)
         else:
-            return ChatAPIHandler()
+            return ChatAPIHandler(self.model_name)
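Related to the same change, ChatAPIHandler and Llama31Handler are now constructed with the model name, so handler-parsed (tool-emulation) responses can also stamp the message. A rough sketch of the revised constructor contract; the base-class body here is assumed, not copied from util/chatapi.py:

class ChatAPIHandler:
    def __init__(self, model: str) -> None:
        # 0.3.83: handlers remember which model they parse responses for
        self.model = model

    def parse_assistant_response(self, response: str) -> dict[str, str]:
        # the real handler returns a ChatMessageAssistant; a dict stands in here
        return {"content": response, "model": self.model, "source": "generate"}

class Llama31Handler(ChatAPIHandler):
    pass  # inherits the model-aware constructor

handler = Llama31Handler("@cf/meta/llama-3.1-8b-instruct")
assert handler.parse_assistant_response("hi")["model"] == "@cf/meta/llama-3.1-8b-instruct"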
inspect_ai/model/_providers/goodfire.py
@@ -115,11 +115,6 @@ class GoodfireAPI(ModelAPI):
         # Initialize variant directly with model name
         self.variant = Variant(self.model_name)  # type: ignore
 
-    @override
-    async def close(self) -> None:
-        # httpx.AsyncClient is created on each generate()
-        pass
-
     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
         """Convert an Inspect message to a Goodfire message format.
 
@@ -232,7 +227,8 @@ class GoodfireAPI(ModelAPI):
             choices=[
                 ChatCompletionChoice(
                     message=ChatMessageAssistant(
-                        content=response_dict["choices"][0]["message"]["content"]
+                        content=response_dict["choices"][0]["message"]["content"],
+                        model=self.model_name,
                     ),
                     stop_reason="stop",
                 )
inspect_ai/model/_providers/google.py
@@ -183,11 +183,6 @@ class GoogleGenAIAPI(ModelAPI):
         # save model args
         self.model_args = model_args
 
-    @override
-    async def close(self) -> None:
-        # GenerativeModel uses a cached/shared client so there is no 'close'
-        pass
-
     def is_vertex(self) -> bool:
         return self.service == "vertex"
 
@@ -257,9 +252,10 @@ class GoogleGenAIAPI(ModelAPI):
         except ClientError as ex:
             return self.handle_client_error(ex), model_call()
 
+        model_name = response.model_version or self.model_name
         output = ModelOutput(
-            model=self.model_name,
-            choices=completion_choices_from_candidates(response),
+            model=model_name,
+            choices=completion_choices_from_candidates(model_name, response),
             usage=usage_metadata_to_model_usage(response.usage_metadata),
         )
 
@@ -546,7 +542,9 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
     )
 
 
-def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
+def completion_choice_from_candidate(
+    model: str, candidate: Candidate
+) -> ChatCompletionChoice:
     # content can be None when the finish_reason is SAFETY
     if candidate.content is None:
         content = ""
@@ -572,7 +570,6 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
             if part.function_call:
                 tool_calls.append(
                     ToolCall(
-                        type="function",
                         id=part.function_call.name,
                         function=part.function_call.name,
                         arguments=part.function_call.args,
@@ -596,6 +593,7 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
         message=ChatMessageAssistant(
             content=choice_content,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
+            model=model,
             source="generate",
         ),
         stop_reason=stop_reason,
@@ -624,19 +622,22 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
 
 
 def completion_choices_from_candidates(
+    model: str,
     response: GenerateContentResponse,
 ) -> list[ChatCompletionChoice]:
     candidates = response.candidates
     if candidates:
         candidates_list = sorted(candidates, key=lambda c: c.index)
         return [
-            completion_choice_from_candidate(candidate) for candidate in candidates_list
+            completion_choice_from_candidate(model, candidate)
+            for candidate in candidates_list
         ]
     elif response.prompt_feedback:
         return [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
                     content=prompt_feedback_to_content(response.prompt_feedback),
+                    model=model,
                     source="generate",
                 ),
                 stop_reason="content_filter",
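One Google-specific detail above: the service can answer with a more specific model version than was requested, so 0.3.83 prefers response.model_version and falls back to the configured name. A tiny sketch of that resolution with a stand-in response object:

from types import SimpleNamespace

def resolve_model_name(response: object, configured: str) -> str:
    # prefer the concrete version the API reports (e.g. "gemini-1.5-pro-002"),
    # falling back to the name the eval was configured with
    return getattr(response, "model_version", None) or configured

assert resolve_model_name(SimpleNamespace(model_version="gemini-1.5-pro-002"), "gemini-1.5-pro") == "gemini-1.5-pro-002"
assert resolve_model_name(SimpleNamespace(model_version=None), "gemini-1.5-pro") == "gemini-1.5-pro"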
inspect_ai/model/_providers/groq.py
@@ -93,7 +93,7 @@ class GroqAPI(ModelAPI):
         self._http_hooks = HttpxHooks(self.client._client)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
     async def generate(
@@ -203,7 +203,7 @@ class GroqAPI(ModelAPI):
         choices.sort(key=lambda c: c.index)
         return [
             ChatCompletionChoice(
-                message=chat_message_assistant(choice.message, tools),
+                message=chat_message_assistant(self.model_name, choice.message, tools),
                 stop_reason=as_stop_reason(choice.finish_reason),
             )
             for choice in choices
@@ -323,7 +323,9 @@ def chat_tool_calls(message: Any, tools: list[ToolInfo]) -> Optional[List[ToolCall]]:
     return None
 
 
-def chat_message_assistant(message: Any, tools: list[ToolInfo]) -> ChatMessageAssistant:
+def chat_message_assistant(
+    model: str, message: Any, tools: list[ToolInfo]
+) -> ChatMessageAssistant:
     reasoning = getattr(message, "reasoning", None)
     if reasoning is not None:
         content: str | list[Content] = [
@@ -335,6 +337,7 @@ def chat_message_assistant(message: Any, tools: list[ToolInfo]) -> ChatMessageAssistant:
 
     return ChatMessageAssistant(
         content=content,
+        model=model,
         source="generate",
         tool_calls=chat_tool_calls(message, tools),
     )
inspect_ai/model/_providers/hf.py
@@ -123,7 +123,7 @@ class HuggingFaceAPI(ModelAPI):
         self.tokenizer.padding_side = "left"
 
     @override
-    async def close(self) -> None:
+    def close(self) -> None:
         self.model = None
         self.tokenizer = None
         gc.collect()
@@ -205,7 +205,9 @@ class HuggingFaceAPI(ModelAPI):
 
         # construct choice
         choice = ChatCompletionChoice(
-            message=ChatMessageAssistant(content=response.output, source="generate"),
+            message=ChatMessageAssistant(
+                content=response.output, model=self.model_name, source="generate"
+            ),
             logprobs=(
                 Logprobs(content=final_logprobs) if final_logprobs is not None else None
             ),
@@ -338,7 +340,9 @@ def chat_completion_assistant_message(
     if handler:
         return handler.parse_assistant_response(response.output, tools)
     else:
-        return ChatMessageAssistant(content=response.output, source="generate")
+        return ChatMessageAssistant(
+            content=response.output, model=model_name, source="generate"
+        )
 
 
 def set_random_seeds(seed: int | None = None) -> None:
inspect_ai/model/_providers/mistral.py
@@ -135,11 +135,6 @@ class MistralAPI(ModelAPI):
     def is_azure(self) -> bool:
         return self.service == "azure"
 
-    @override
-    async def close(self) -> None:
-        # client is created and destroyed in generate
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -448,13 +443,11 @@ def chat_tool_call(tool_call: MistralToolCall, tools: list[ToolInfo]) -> ToolCall:
             id, tool_call.function.name, tool_call.function.arguments, tools
         )
     else:
-        return ToolCall(
-            id, tool_call.function.name, tool_call.function.arguments, type="function"
-        )
+        return ToolCall(id, tool_call.function.name, tool_call.function.arguments)
 
 
 def completion_choice(
-    choice: MistralChatCompletionChoice, tools: list[ToolInfo]
+    model: str, choice: MistralChatCompletionChoice, tools: list[ToolInfo]
 ) -> ChatCompletionChoice:
     message = choice.message
     if message:
@@ -465,6 +458,7 @@ def completion_choice(
             tool_calls=chat_tool_calls(message.tool_calls, tools)
             if message.tool_calls
            else None,
+            model=model,
             source="generate",
         ),
         stop_reason=(
@@ -511,7 +505,10 @@ def completion_choices_from_response(
     if response.choices is None:
         return []
     else:
-        return [completion_choice(choice, tools) for choice in response.choices]
+        return [
+            completion_choice(response.model, choice, tools)
+            for choice in response.choices
+        ]
 
 
 def choice_stop_reason(choice: MistralChatCompletionChoice) -> StopReason:
inspect_ai/model/_providers/openai.py
@@ -33,6 +33,7 @@ from .._model_call import ModelCall
 from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .._openai import (
     OpenAIResponseError,
+    is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
@@ -45,10 +46,7 @@ from .._openai import (
     openai_media_filter,
 )
 from .openai_o1 import generate_o1
-from .util import (
-    environment_prerequisite_error,
-    model_base_url,
-)
+from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
 
@@ -77,9 +75,6 @@ class OpenAIAPI(ModelAPI):
         else:
             self.service = None
 
-        # note whether we are forcing the responses_api
-        self.responses_api = True if responses_api else False
-
         # call super
         super().__init__(
             model_name=model_name,
@@ -89,6 +84,11 @@ class OpenAIAPI(ModelAPI):
             config=config,
         )
 
+        # note whether we are forcing the responses_api
+        self.responses_api = (
+            responses_api or self.is_o1_pro() or self.is_computer_use_preview()
+        )
+
         # resolve api_key
         if not self.api_key:
             if self.service == "azure":
@@ -128,10 +128,14 @@ class OpenAIAPI(ModelAPI):
             )
 
             # resolve version
-            api_version = os.environ.get(
-                "AZUREAI_OPENAI_API_VERSION",
-                os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
-            )
+            if model_args.get("api_version") is not None:
+                # use slightly complicated logic to allow for "api_version" to be removed
+                api_version = model_args.pop("api_version")
+            else:
+                api_version = os.environ.get(
+                    "AZUREAI_OPENAI_API_VERSION",
+                    os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+                )
 
             self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
                 api_key=self.api_key,
@@ -166,13 +170,33 @@ class OpenAIAPI(ModelAPI):
     def is_o1_preview(self) -> bool:
         return is_o1_preview(self.model_name)
 
+    def is_computer_use_preview(self) -> bool:
+        return is_computer_use_preview(self.model_name)
+
     def is_gpt(self) -> bool:
         return is_gpt(self.model_name)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
+    @override
+    def emulate_reasoning_history(self) -> bool:
+        return not self.responses_api
+
+    @override
+    def tool_result_images(self) -> bool:
+        # o1-pro, o1, and computer_use_preview support image inputs (but we're not strictly supporting o1)
+        return self.is_o1_pro() or self.is_computer_use_preview()
+
+    @override
+    def disable_computer_screenshot_truncation(self) -> bool:
+        # Because ComputerCallOutput has a required output field of type
+        # ResponseComputerToolCallOutputScreenshot, we must have an image in
+        # order to provide a valid tool call response. Therefore, we cannot
+        # support image truncation.
+        return True
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -188,7 +212,7 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
-        elif self.is_o1_pro() or self.responses_api:
+        elif self.responses_api:
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
@@ -344,10 +368,7 @@ class OpenAIAPI(ModelAPI):
             params["top_p"] = config.top_p
         if config.num_choices is not None:
             params["n"] = config.num_choices
-        if config.logprobs is not None:
-            params["logprobs"] = config.logprobs
-        if config.top_logprobs is not None:
-            params["top_logprobs"] = config.top_logprobs
+        params = self.set_logprobs_params(params, config)
         if tools and config.parallel_tool_calls is not None and not self.is_o_series():
             params["parallel_tool_calls"] = config.parallel_tool_calls
         if (
@@ -372,6 +393,15 @@ class OpenAIAPI(ModelAPI):
 
         return params
 
+    def set_logprobs_params(
+        self, params: dict[str, Any], config: GenerateConfig
+    ) -> dict[str, Any]:
+        if config.logprobs is not None:
+            params["logprobs"] = config.logprobs
+        if config.top_logprobs is not None:
+            params["top_logprobs"] = config.top_logprobs
+        return params
+
 
 class OpenAIAsyncHttpxClient(httpx.AsyncClient):
     """Custom async client that deals better with long running Async requests.
inspect_ai/model/_providers/openai_o1.py
@@ -40,7 +40,7 @@ async def generate_o1(
     **params: Any,
 ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
-    handler = O1PreviewChatAPIHandler()
+    handler = O1PreviewChatAPIHandler(model)
 
     # call model
     request = dict(
@@ -155,6 +155,9 @@ TOOL_CALL = "tool_call"
 
 
 class O1PreviewChatAPIHandler(ChatAPIHandler):
+    def __init__(self, model: str) -> None:
+        self.model = model
+
     @override
     def input_with_tools(
         self, input: list[ChatMessage], tools: list[ToolInfo]
@@ -234,12 +237,17 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
 
             # return the message
             return ChatMessageAssistant(
-                content=content, tool_calls=tool_calls, source="generate"
+                content=content,
+                tool_calls=tool_calls,
+                model=self.model,
+                source="generate",
             )
 
         # otherwise this is just an ordinary assistant message
         else:
-            return ChatMessageAssistant(content=response, source="generate")
+            return ChatMessageAssistant(
+                content=response, model=self.model, source="generate"
+            )
 
     @override
     def assistant_message(self, message: ChatMessageAssistant) -> ChatAPIMessage:
@@ -328,6 +336,5 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
         id="unknown",
         function="unknown",
         arguments={},
-        type="function",
         parse_error=parse_error,
     )