inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (166)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -9
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +79 -12
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/eval.py +10 -1
  13. inspect_ai/_eval/loader.py +79 -19
  14. inspect_ai/_eval/registry.py +6 -0
  15. inspect_ai/_eval/score.py +3 -1
  16. inspect_ai/_eval/task/results.py +51 -22
  17. inspect_ai/_eval/task/run.py +47 -13
  18. inspect_ai/_eval/task/sandbox.py +10 -5
  19. inspect_ai/_util/constants.py +1 -0
  20. inspect_ai/_util/port_names.py +61 -0
  21. inspect_ai/_util/text.py +23 -0
  22. inspect_ai/_view/www/App.css +31 -1
  23. inspect_ai/_view/www/dist/assets/index.css +31 -1
  24. inspect_ai/_view/www/dist/assets/index.js +25498 -2044
  25. inspect_ai/_view/www/log-schema.json +32 -2
  26. inspect_ai/_view/www/package.json +2 -0
  27. inspect_ai/_view/www/src/App.mjs +14 -16
  28. inspect_ai/_view/www/src/Types.mjs +1 -2
  29. inspect_ai/_view/www/src/api/Types.ts +133 -0
  30. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  31. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  32. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  33. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  34. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  35. inspect_ai/_view/www/src/api/index.ts +51 -0
  36. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  37. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  38. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  39. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  40. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  41. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  42. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  43. inspect_ai/_view/www/src/index.js +77 -4
  44. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
  46. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
  47. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  48. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  49. inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
  50. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  51. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  52. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
  53. inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
  54. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  55. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  56. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
  76. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  77. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
  78. inspect_ai/_view/www/vite.config.js +7 -0
  79. inspect_ai/_view/www/yarn.lock +116 -0
  80. inspect_ai/approval/_human/__init__.py +0 -0
  81. inspect_ai/approval/_human/manager.py +1 -1
  82. inspect_ai/approval/_policy.py +12 -6
  83. inspect_ai/log/_log.py +1 -1
  84. inspect_ai/log/_samples.py +16 -0
  85. inspect_ai/log/_transcript.py +4 -1
  86. inspect_ai/model/_call_tools.py +59 -0
  87. inspect_ai/model/_conversation.py +16 -7
  88. inspect_ai/model/_generate_config.py +12 -12
  89. inspect_ai/model/_model.py +117 -18
  90. inspect_ai/model/_model_output.py +22 -2
  91. inspect_ai/model/_openai.py +383 -0
  92. inspect_ai/model/_providers/anthropic.py +152 -55
  93. inspect_ai/model/_providers/azureai.py +21 -21
  94. inspect_ai/model/_providers/bedrock.py +37 -40
  95. inspect_ai/model/_providers/goodfire.py +248 -0
  96. inspect_ai/model/_providers/google.py +46 -54
  97. inspect_ai/model/_providers/groq.py +7 -3
  98. inspect_ai/model/_providers/hf.py +6 -0
  99. inspect_ai/model/_providers/mistral.py +13 -12
  100. inspect_ai/model/_providers/openai.py +51 -218
  101. inspect_ai/model/_providers/openai_o1.py +11 -12
  102. inspect_ai/model/_providers/providers.py +23 -1
  103. inspect_ai/model/_providers/together.py +12 -12
  104. inspect_ai/model/_providers/util/__init__.py +2 -3
  105. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  106. inspect_ai/model/_providers/util/llama31.py +1 -1
  107. inspect_ai/model/_providers/util/util.py +0 -76
  108. inspect_ai/model/_providers/vertex.py +1 -4
  109. inspect_ai/scorer/_metric.py +3 -0
  110. inspect_ai/scorer/_reducer/reducer.py +1 -1
  111. inspect_ai/scorer/_scorer.py +4 -3
  112. inspect_ai/solver/__init__.py +4 -5
  113. inspect_ai/solver/_basic_agent.py +1 -1
  114. inspect_ai/solver/_bridge/__init__.py +3 -0
  115. inspect_ai/solver/_bridge/bridge.py +100 -0
  116. inspect_ai/solver/_bridge/patch.py +170 -0
  117. inspect_ai/solver/_prompt.py +35 -5
  118. inspect_ai/solver/_solver.py +6 -0
  119. inspect_ai/solver/_task_state.py +80 -38
  120. inspect_ai/tool/__init__.py +2 -0
  121. inspect_ai/tool/_tool.py +12 -1
  122. inspect_ai/tool/_tool_call.py +10 -0
  123. inspect_ai/tool/_tool_def.py +16 -5
  124. inspect_ai/tool/_tool_with.py +21 -4
  125. inspect_ai/tool/beta/__init__.py +5 -0
  126. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  127. inspect_ai/tool/beta/_computer/_common.py +133 -0
  128. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  129. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  130. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  131. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  134. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  135. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  136. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  137. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  138. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  139. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  144. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  145. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  146. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  147. inspect_ai/util/__init__.py +2 -0
  148. inspect_ai/util/_display.py +5 -0
  149. inspect_ai/util/_limit.py +26 -0
  150. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  153. inspect_ai/util/_sandbox/environment.py +14 -0
  154. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  155. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
  156. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  157. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  158. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  159. inspect_ai/_view/www/src/api/index.mjs +0 -49
  160. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  161. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  162. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  163. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  164. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  165. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  166. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py

@@ -33,6 +33,7 @@ from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import SampleLimitExceededError
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
@@ -43,7 +44,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
-from ._conversation import conversation_assistant_message
+from ._conversation import conversation_assistant_error, conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -116,7 +117,7 @@ class ModelAPI(abc.ABC):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         """Generate output from the model.
 
         Args:
@@ -165,7 +166,7 @@
         return False
 
     def tool_result_images(self) -> bool:
-        """Tool results can containe images"""
+        """Tool results can contain images"""
         return False
 
 
@@ -222,11 +223,17 @@ class Model:
         Returns:
            ModelOutput
         """
+        # if we are the default model then enforce message limit if it
+        # exists (raise an exception if it is exceeded)
+        is_active_model = self == active_model()
+        if is_active_model:
+            handle_sample_message_limit(input)
+
         # base config for this model
         base_config = self.config
 
         # if we are the active_model then merge active generate config
-        if self == active_model():
+        if is_active_model:
             base_config = base_config.merge(active_generate_config())
 
         # merge passed config
@@ -296,6 +303,9 @@
             tools = []
             tool_choice = "none"
 
+        # apply any tool model_input handlers
+        input = resolve_tool_model_input(tdefs, input)
+
         # break tool image content out into user messages if the model doesn't
         # support tools returning images
         if not self.api.tool_result_images():
@@ -389,6 +399,17 @@
                 output = result
                 call = None
 
+            # raise error
+            if isinstance(output, Exception):
+                complete(output, call)
+
+                # Wrap the error in a runtime error which will show the
+                # request which caused the error
+                error = repr(output)
+                request = json.dumps(call.request, indent=2) if call is not None else ""
+                error_message = f"{error}\n\nRequest:\n{request}"
+                raise RuntimeError(error_message)
+
             # update output with time elapsed
             output.time = time_elapsed
 
@@ -464,7 +485,7 @@
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput, ModelCall | None], None]:
+    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -484,13 +505,16 @@
 
         # callable that can be used to update the interaction w/ output
         def complete(
-            updated_output: ModelOutput, updated_call: ModelCall | None
+            result: ModelOutput | Exception, updated_call: ModelCall | None
         ) -> None:
             # trace
-            conversation_assistant_message(input, updated_output.choices[0].message)
+            if isinstance(result, ModelOutput):
+                conversation_assistant_message(input, result.choices[0].message)
+                event.output = result
+            else:
+                conversation_assistant_error(result)
+                event.error = repr(result)
 
-            # update event
-            event.output = updated_output
             event.call = updated_call
             event.pending = None
 
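Note: with the changes above, ModelAPI.generate() may now return a tuple that pairs an Exception with the ModelCall that produced it, and the transcript complete() callback records either the output or the error. The sketch below illustrates a provider written against this widened contract; HypotheticalAPI and call_provider() are invented names, and ModelCall.create()/ModelOutput.from_content() are assumed to be the existing inspect_ai.model helpers.

# Illustrative only: a ModelAPI.generate() that returns either a ModelOutput
# or an (Exception, ModelCall) pair, matching the widened signature above.
from inspect_ai.model import (
    ChatMessage,
    GenerateConfig,
    ModelAPI,
    ModelCall,
    ModelOutput,
)
from inspect_ai.tool import ToolChoice, ToolInfo


async def call_provider(request: dict) -> str:
    # stand-in for a real HTTP call to a provider (always fails in this sketch)
    raise ConnectionError("provider unreachable")


class HypotheticalAPI(ModelAPI):
    async def generate(
        self,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        request = {"messages": [message.text for message in input]}
        call = ModelCall.create(request=request, response={})
        try:
            response_text = await call_provider(request)
        except Exception as ex:
            # returning (exception, call) lets Model.generate() log the event
            # and raise a RuntimeError that includes the request JSON
            return ex, call
        return ModelOutput.from_content(self.model_name, response_text), call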
 
@@ -703,6 +727,40 @@ def simple_input_messages(
     return messages
 
 
+def resolve_tool_model_input(
+    tdefs: list[ToolDef], messages: list[ChatMessage]
+) -> list[ChatMessage]:
+    # filter on tooldefs that have a model input handler
+    tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
+
+    # bail if there are no handlers
+    if len(tdefs) == 0:
+        return messages
+
+    # don't mutate the original messages
+    messages = deepcopy(messages)
+
+    # extract tool messages
+    tool_messages = [
+        message for message in messages if isinstance(message, ChatMessageTool)
+    ]
+    # run model_input handlers over all tool_messages with the same function name
+    for tdef in tdefs:
+        assert tdef.model_input
+        # filter messages down to just this tool
+        tdef_tool_messages = [
+            message for message in tool_messages if message.function == tdef.name
+        ]
+        # call the function for each tool, passing the index, total, and content
+        for index, message in enumerate(tdef_tool_messages):
+            message.content = tdef.model_input(
+                index, len(tool_messages), message.content
+            )
+
+    # return modified messages
+    return messages
+
+
 def tool_result_images_as_user_message(
     messages: list[ChatMessage],
 ) -> list[ChatMessage]:
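Note: resolve_tool_model_input() above calls each tool's model_input handler with the result's index among that tool's messages, the total count of tool messages, and the result content, so a tool can control how its prior outputs are re-presented to the model. The handler below is a sketch of that call shape only; the truncation policy is invented, and the Content imports follow the private inspect_ai._util.content path used elsewhere in this diff.

# Illustrative model_input handler matching the (index, total, content) call
# made by resolve_tool_model_input(); the truncation policy is made up.
from inspect_ai._util.content import Content


def truncate_long_output(
    index: int, total: int, content: str | list[Content]
) -> str | list[Content]:
    # index/total locate this result among the conversation's tool messages;
    # this sketch ignores them and simply keeps string content short
    if isinstance(content, str) and len(content) > 2000:
        return content[:2000] + "\n(truncated)"
    return content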
@@ -713,16 +771,21 @@ def tool_result_images_reducer(
     messages: list[ChatMessage],
     message: ChatMessage,
 ) -> list[ChatMessage]:
-    # append the message
-    messages.append(message)
-
     # if there are tool result images, pull them out into a ChatUserMessage
     if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
+        tool_message = ChatMessageTool(
+            content=message.content.copy(),
+            tool_call_id=message.tool_call_id,
+            function=message.function,
+        )
+        assert isinstance(tool_message.content, list)
+        messages.append(tool_message)
+
         user_content: list[Content] = []
-        for i in range(0, len(message.content)):
-            if isinstance(message.content[i], ContentImage):
+        for i in range(0, len(tool_message.content)):
+            if isinstance(tool_message.content[i], ContentImage):
                 user_content.append(message.content[i])
-                message.content[i] = ContentText(
+                tool_message.content[i] = ContentText(
                     text="Image content is in the message below."
                 )
         if len(user_content) > 0:
@@ -730,6 +793,9 @@
                 ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
             )
 
+    else:
+        messages.append(message)
+
     # return messages
     return messages
 
@@ -813,6 +879,24 @@ def active_model() -> Model | None:
 active_model_context_var: ContextVar[Model] = ContextVar("active_model")
 
 
+def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_message_limit,
+        set_active_sample_total_messages,
+    )
+
+    total_messages = 1 if isinstance(input, str) else len(input)
+    message_limit = active_sample_message_limit()
+    if message_limit is not None:
+        if total_messages >= message_limit:
+            raise SampleLimitExceededError(
+                "message", value=total_messages, limit=message_limit
+            )
+
+    # set total messages
+    set_active_sample_total_messages(total_messages)
+
+
 def init_model_usage() -> None:
     model_usage_context_var.set({})
 
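Note: handle_sample_message_limit() above raises SampleLimitExceededError("message", ...) before a generate call once the conversation reaches the active sample's message limit, and record_model_usage() further down does the same for tokens after usage is recorded. Below is a hedged sketch of catching the new error around a manual generate loop in a solver; the solver itself is illustrative, and the import follows the private inspect_ai.util._limit path used in this diff.

# Illustrative solver that stops cleanly when a sample limit is exceeded.
from inspect_ai.solver import Generate, TaskState, solver
from inspect_ai.util._limit import SampleLimitExceededError


@solver
def generate_until_limit():
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        try:
            for _ in range(100):  # arbitrary upper bound for the sketch
                state = await generate(state)
        except SampleLimitExceededError:
            # message or token limit reached: return what we have so far
            pass
        return state

    return solve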
 
@@ -822,13 +906,28 @@ def init_sample_model_usage() -> None:
 
 
 def record_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_token_limit,
+        set_active_sample_total_tokens,
+    )
+
+    # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
 
-    # update active sample
-    from inspect_ai.log._samples import set_active_sample_total_tokens
+    # compute total tokens
+    total_tokens = sample_total_tokens()
 
-    set_active_sample_total_tokens(sample_total_tokens())
+    # update active sample
+    set_active_sample_total_tokens(total_tokens)
+
+    # check for token limit overflow and raise
+    token_limit = active_sample_token_limit()
+    if token_limit is not None:
+        if total_tokens > token_limit:
+            raise SampleLimitExceededError(
+                "token", value=total_tokens, limit=token_limit
+            )
 
 
 def set_model_usage(

inspect_ai/model/_model_output.py

@@ -26,9 +26,14 @@ class ModelUsage(BaseModel):
 
 
 StopReason = Literal[
-    "stop", "max_tokens", "model_length", "tool_calls", "content_filter", "unknown"
+    "stop",
+    "max_tokens",
+    "model_length",
+    "tool_calls",
+    "content_filter",
+    "unknown",
 ]
-"""Reason that the model stopped generating."""
+"""Reason that the model stopped or failed to generate."""
 
 
 class TopLogprob(BaseModel):
@@ -209,3 +214,18 @@ class ModelOutput(BaseModel):
                 )
             ],
         )
+
+
+def as_stop_reason(reason: str | None) -> StopReason:
+    """Encode common reason strings into standard StopReason."""
+    match reason:
+        case "stop" | "eos":
+            return "stop"
+        case "length":
+            return "max_tokens"
+        case "tool_calls" | "function_call":
+            return "tool_calls"
+        case "content_filter" | "model_length" | "max_tokens":
+            return reason
+        case _:
+            return "unknown"
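Note: as_stop_reason() above normalizes provider-specific finish reasons into the StopReason literal. Its mapping, stated directly from the match statement (the import uses the private inspect_ai.model._model_output path where this diff defines it):

# Each assertion follows directly from the match statement above.
from inspect_ai.model._model_output import as_stop_reason

assert as_stop_reason("eos") == "stop"
assert as_stop_reason("length") == "max_tokens"
assert as_stop_reason("function_call") == "tool_calls"
assert as_stop_reason("model_length") == "model_length"
assert as_stop_reason(None) == "unknown"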
inspect_ai/model/_openai.py (new file)

@@ -0,0 +1,383 @@
+import json
+from typing import Literal
+
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
+    ChatCompletionContentPartParam,
+    ChatCompletionContentPartRefusalParam,
+    ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
+    ChatCompletionMessage,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCall,
+    ChatCompletionMessageToolCallParam,
+    ChatCompletionNamedToolChoiceParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionToolParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
+from openai.types.chat.chat_completion_message_tool_call import Function
+from openai.types.completion_usage import CompletionUsage
+from openai.types.shared_params.function_definition import FunctionDefinition
+
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
+from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
+from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
+
+from ._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from ._model_output import ModelUsage, StopReason, as_stop_reason
+
+
+def is_o1(name: str) -> bool:
+    return name.startswith("o1")
+
+
+def is_o1_full(name: str) -> bool:
+    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)
+
+
+def is_o1_mini(name: str) -> bool:
+    return name.startswith("o1-mini")
+
+
+def is_o1_preview(name: str) -> bool:
+    return name.startswith("o1-preview")
+
+
+def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
+    return ChatCompletionMessageToolCall(
+        type="function",
+        id=tool_call.id,
+        function=Function(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+    )
+
+
+def openai_chat_tool_call_param(
+    tool_call: ToolCall,
+) -> ChatCompletionMessageToolCallParam:
+    return ChatCompletionMessageToolCallParam(
+        id=tool_call.id,
+        function=dict(
+            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
+        ),
+        type=tool_call.type,
+    )
+
+
+async def openai_chat_completion_part(
+    content: Content,
+) -> ChatCompletionContentPartParam:
+    if content.type == "text":
+        return ChatCompletionContentPartTextParam(type="text", text=content.text)
+    elif content.type == "image":
+        # API takes URL or base64 encoded file. If it's a remote file or
+        # data URL leave it alone, otherwise encode it
+        image_url = content.image
+        detail = content.detail
+
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
+
+        return ChatCompletionContentPartImageParam(
+            type="image_url",
+            image_url=dict(url=image_url, detail=detail),
+        )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
+
+
+async def openai_chat_message(
+    message: ChatMessage, model: str
+) -> ChatCompletionMessageParam:
+    if message.role == "system":
+        if is_o1(model):
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "user":
+        return ChatCompletionUserMessageParam(
+            role=message.role,
+            content=(
+                message.content
+                if isinstance(message.content, str)
+                else [
+                    await openai_chat_completion_part(content)
+                    for content in message.content
+                ]
+            ),
+        )
+    elif message.role == "assistant":
+        if message.tool_calls:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role,
+                content=message.text,
+                tool_calls=[
+                    openai_chat_tool_call_param(call) for call in message.tool_calls
+                ],
+            )
+        else:
+            return ChatCompletionAssistantMessageParam(
+                role=message.role, content=message.text
+            )
+    elif message.role == "tool":
+        return ChatCompletionToolMessageParam(
+            role=message.role,
+            content=(
+                f"Error: {message.error.message}" if message.error else message.text
+            ),
+            tool_call_id=str(message.tool_call_id),
+        )
+    else:
+        raise ValueError(f"Unexpected message role {message.role}")
+
+
+async def openai_chat_messages(
+    messages: list[ChatMessage], model: str
+) -> list[ChatCompletionMessageParam]:
+    return [await openai_chat_message(message, model) for message in messages]
+
+
+def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]:
+    oai_choices: list[Choice] = []
+
+    for index, choice in enumerate(choices):
+        if isinstance(choice.message.content, str):
+            content = choice.message.content
+        else:
+            content = "\n".join(
+                [c.text for c in choice.message.content if c.type == "text"]
+            )
+        if choice.message.tool_calls:
+            tool_calls = [openai_chat_tool_call(tc) for tc in choice.message.tool_calls]
+        else:
+            tool_calls = None
+        message = ChatCompletionMessage(
+            role="assistant", content=content, tool_calls=tool_calls
+        )
+        oai_choices.append(
+            Choice(
+                finish_reason=openai_finish_reason(choice.stop_reason),
+                index=index,
+                message=message,
+                logprobs=ChoiceLogprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None,
+            )
+        )
+
+    return oai_choices
+
+
+def openai_completion_usage(usage: ModelUsage) -> CompletionUsage:
+    return CompletionUsage(
+        completion_tokens=usage.output_tokens,
+        prompt_tokens=usage.input_tokens,
+        total_tokens=usage.total_tokens,
+    )
+
+
+def openai_finish_reason(
+    stop_reason: StopReason,
+) -> Literal["stop", "length", "tool_calls", "content_filter", "function_call"]:
+    match stop_reason:
+        case "stop" | "tool_calls" | "content_filter":
+            return stop_reason
+        case "model_length":
+            return "length"
+        case _:
+            return "stop"
+
+
+def openai_chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
+    function = FunctionDefinition(
+        name=tool.name,
+        description=tool.description,
+        parameters=tool.parameters.model_dump(exclude_none=True),
+    )
+    return ChatCompletionToolParam(type="function", function=function)
+
+
+def openai_chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
+    return [openai_chat_tool_param(tool) for tool in tools]
+
+
+def openai_chat_tool_choice(
+    tool_choice: ToolChoice,
+) -> ChatCompletionToolChoiceOptionParam:
+    if isinstance(tool_choice, ToolFunction):
+        return ChatCompletionNamedToolChoiceParam(
+            type="function", function=dict(name=tool_choice.name)
+        )
+    # openai supports 'any' via the 'required' keyword
+    elif tool_choice == "any":
+        return "required"
+    else:
+        return tool_choice
+
+
+def chat_tool_calls_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> list[ToolCall] | None:
+    if message.tool_calls:
+        return [
+            parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
+            for call in message.tool_calls
+        ]
+    else:
+        return None
+
+
+def chat_messages_from_openai(
+    messages: list[ChatCompletionMessageParam],
+) -> list[ChatMessage]:
+    # track tool names by id
+    tool_names: dict[str, str] = {}
+
+    chat_messages: list[ChatMessage] = []
+
+    for message in messages:
+        if message["role"] == "system" or message["role"] == "developer":
+            sys_content = message["content"]
+            if isinstance(sys_content, str):
+                chat_messages.append(ChatMessageSystem(content=sys_content))
+            else:
+                chat_messages.append(
+                    ChatMessageSystem(
+                        content=[content_from_openai(c) for c in sys_content]
+                    )
+                )
+        elif message["role"] == "user":
+            user_content = message["content"]
+            if isinstance(user_content, str):
+                chat_messages.append(ChatMessageUser(content=user_content))
+            else:
+                chat_messages.append(
+                    ChatMessageUser(
+                        content=[content_from_openai(c) for c in user_content]
+                    )
+                )
+        elif message["role"] == "assistant":
+            # resolve content
+            asst_content = message["content"]
+            if isinstance(asst_content, str):
+                content: str | list[Content] = asst_content
+            elif asst_content is None:
+                content = message.get("refusal", None) or ""
+            else:
+                content = [content_from_openai(c) for c in asst_content]
+
+            # return message
+            if "tool_calls" in message:
+                tool_calls: list[ToolCall] = []
+                for tc in message["tool_calls"]:
+                    tool_calls.append(tool_call_from_openai(tc))
+                    tool_names[tc["id"]] = tc["function"]["name"]
+
+            else:
+                tool_calls = []
+            chat_messages.append(
+                ChatMessageAssistant(content=content, tool_calls=tool_calls or None)
+            )
+        elif message["role"] == "tool":
+            tool_content = message.get("content", None) or ""
+            if isinstance(tool_content, str):
+                content = tool_content
+            else:
+                content = [content_from_openai(c) for c in tool_content]
+            chat_messages.append(
+                ChatMessageTool(
+                    content=content,
+                    tool_call_id=message["tool_call_id"],
+                    function=tool_names.get(message["tool_call_id"], ""),
+                )
+            )
+        else:
+            raise ValueError(f"Unexpected message param type: {type(message)}")
+
+    return chat_messages
+
+
+def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> ToolCall:
+    return parse_tool_call(
+        tool_call["id"],
+        tool_call["function"]["name"],
+        tool_call["function"]["arguments"],
+    )
+
+
+def content_from_openai(
+    content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
+) -> Content:
+    if content["type"] == "text":
+        return ContentText(text=content["text"])
+    elif content["type"] == "image_url":
+        return ContentImage(
+            image=content["image_url"]["url"], detail=content["image_url"]["detail"]
+        )
+    elif content["type"] == "input_audio":
+        return ContentAudio(
+            audio=content["input_audio"]["data"],
+            format=content["input_audio"]["format"],
+        )
+    elif content["type"] == "refusal":
+        return ContentText(text=content["refusal"])
+
+
+def chat_message_assistant_from_openai(
+    message: ChatCompletionMessage, tools: list[ToolInfo]
+) -> ChatMessageAssistant:
+    refusal = getattr(message, "refusal", None)
+    return ChatMessageAssistant(
+        content=refusal or message.content or "",
+        source="generate",
+        tool_calls=chat_tool_calls_from_openai(message, tools),
+    )
+
+
+def chat_choices_from_openai(
+    response: ChatCompletion, tools: list[ToolInfo]
+) -> list[ChatCompletionChoice]:
+    choices = list(response.choices)
+    choices.sort(key=lambda c: c.index)
+    return [
+        ChatCompletionChoice(
+            message=chat_message_assistant_from_openai(choice.message, tools),
+            stop_reason=as_stop_reason(choice.finish_reason),
+            logprobs=(
+                Logprobs(**choice.logprobs.model_dump())
+                if choice.logprobs is not None
+                else None
+            ),
+        )
+        for choice in choices
+    ]
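Note: the new inspect_ai/model/_openai.py module centralizes conversion between Inspect chat messages and OpenAI chat-completion types; the much smaller openai.py provider and the new solver bridge in this release appear to build on it. The sketch below round-trips a short conversation through the helpers defined above, importing from the private module exactly as this diff adds it.

# Sketch: Inspect messages -> OpenAI chat params -> Inspect messages.
import asyncio

from inspect_ai.model import ChatMessageSystem, ChatMessageUser
from inspect_ai.model._openai import chat_messages_from_openai, openai_chat_messages


async def main() -> None:
    messages = [
        ChatMessageSystem(content="You are a terse assistant."),
        ChatMessageUser(content="What is 2 + 2?"),
    ]
    # Inspect -> OpenAI params (the model name decides system vs. developer role)
    params = await openai_chat_messages(messages, model="gpt-4o")
    # OpenAI params -> Inspect messages
    roundtrip = chat_messages_from_openai(params)
    assert [m.role for m in roundtrip] == ["system", "user"]


asyncio.run(main())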