inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. inspect_ai/_cli/common.py +3 -1
  2. inspect_ai/_cli/eval.py +15 -9
  3. inspect_ai/_display/core/active.py +4 -1
  4. inspect_ai/_display/core/config.py +3 -3
  5. inspect_ai/_display/core/panel.py +7 -3
  6. inspect_ai/_display/plain/__init__.py +0 -0
  7. inspect_ai/_display/plain/display.py +203 -0
  8. inspect_ai/_display/rich/display.py +0 -5
  9. inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
  10. inspect_ai/_display/textual/widgets/samples.py +79 -12
  11. inspect_ai/_display/textual/widgets/sandbox.py +37 -0
  12. inspect_ai/_eval/eval.py +10 -1
  13. inspect_ai/_eval/loader.py +79 -19
  14. inspect_ai/_eval/registry.py +6 -0
  15. inspect_ai/_eval/score.py +3 -1
  16. inspect_ai/_eval/task/results.py +51 -22
  17. inspect_ai/_eval/task/run.py +47 -13
  18. inspect_ai/_eval/task/sandbox.py +10 -5
  19. inspect_ai/_util/constants.py +1 -0
  20. inspect_ai/_util/port_names.py +61 -0
  21. inspect_ai/_util/text.py +23 -0
  22. inspect_ai/_view/www/App.css +31 -1
  23. inspect_ai/_view/www/dist/assets/index.css +31 -1
  24. inspect_ai/_view/www/dist/assets/index.js +25498 -2044
  25. inspect_ai/_view/www/log-schema.json +32 -2
  26. inspect_ai/_view/www/package.json +2 -0
  27. inspect_ai/_view/www/src/App.mjs +14 -16
  28. inspect_ai/_view/www/src/Types.mjs +1 -2
  29. inspect_ai/_view/www/src/api/Types.ts +133 -0
  30. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  31. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  32. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  33. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  34. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  35. inspect_ai/_view/www/src/api/index.ts +51 -0
  36. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  37. inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
  38. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  39. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
  40. inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
  41. inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
  42. inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
  43. inspect_ai/_view/www/src/index.js +77 -4
  44. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
  46. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
  47. inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
  48. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
  49. inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
  50. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  51. inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
  52. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
  53. inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
  54. inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
  55. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
  56. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  57. inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
  58. inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
  59. inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
  60. inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
  61. inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
  62. inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
  63. inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
  64. inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
  65. inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
  66. inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
  67. inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
  68. inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
  69. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
  70. inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
  71. inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
  72. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
  73. inspect_ai/_view/www/src/types/log.d.ts +13 -2
  74. inspect_ai/_view/www/src/utils/Format.mjs +10 -3
  75. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
  76. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  77. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
  78. inspect_ai/_view/www/vite.config.js +7 -0
  79. inspect_ai/_view/www/yarn.lock +116 -0
  80. inspect_ai/approval/_human/__init__.py +0 -0
  81. inspect_ai/approval/_human/manager.py +1 -1
  82. inspect_ai/approval/_policy.py +12 -6
  83. inspect_ai/log/_log.py +1 -1
  84. inspect_ai/log/_samples.py +16 -0
  85. inspect_ai/log/_transcript.py +4 -1
  86. inspect_ai/model/_call_tools.py +59 -0
  87. inspect_ai/model/_conversation.py +16 -7
  88. inspect_ai/model/_generate_config.py +12 -12
  89. inspect_ai/model/_model.py +117 -18
  90. inspect_ai/model/_model_output.py +22 -2
  91. inspect_ai/model/_openai.py +383 -0
  92. inspect_ai/model/_providers/anthropic.py +152 -55
  93. inspect_ai/model/_providers/azureai.py +21 -21
  94. inspect_ai/model/_providers/bedrock.py +37 -40
  95. inspect_ai/model/_providers/goodfire.py +248 -0
  96. inspect_ai/model/_providers/google.py +46 -54
  97. inspect_ai/model/_providers/groq.py +7 -3
  98. inspect_ai/model/_providers/hf.py +6 -0
  99. inspect_ai/model/_providers/mistral.py +13 -12
  100. inspect_ai/model/_providers/openai.py +51 -218
  101. inspect_ai/model/_providers/openai_o1.py +11 -12
  102. inspect_ai/model/_providers/providers.py +23 -1
  103. inspect_ai/model/_providers/together.py +12 -12
  104. inspect_ai/model/_providers/util/__init__.py +2 -3
  105. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  106. inspect_ai/model/_providers/util/llama31.py +1 -1
  107. inspect_ai/model/_providers/util/util.py +0 -76
  108. inspect_ai/model/_providers/vertex.py +1 -4
  109. inspect_ai/scorer/_metric.py +3 -0
  110. inspect_ai/scorer/_reducer/reducer.py +1 -1
  111. inspect_ai/scorer/_scorer.py +4 -3
  112. inspect_ai/solver/__init__.py +4 -5
  113. inspect_ai/solver/_basic_agent.py +1 -1
  114. inspect_ai/solver/_bridge/__init__.py +3 -0
  115. inspect_ai/solver/_bridge/bridge.py +100 -0
  116. inspect_ai/solver/_bridge/patch.py +170 -0
  117. inspect_ai/solver/_prompt.py +35 -5
  118. inspect_ai/solver/_solver.py +6 -0
  119. inspect_ai/solver/_task_state.py +80 -38
  120. inspect_ai/tool/__init__.py +2 -0
  121. inspect_ai/tool/_tool.py +12 -1
  122. inspect_ai/tool/_tool_call.py +10 -0
  123. inspect_ai/tool/_tool_def.py +16 -5
  124. inspect_ai/tool/_tool_with.py +21 -4
  125. inspect_ai/tool/beta/__init__.py +5 -0
  126. inspect_ai/tool/beta/_computer/__init__.py +3 -0
  127. inspect_ai/tool/beta/_computer/_common.py +133 -0
  128. inspect_ai/tool/beta/_computer/_computer.py +155 -0
  129. inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
  130. inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
  131. inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
  132. inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
  133. inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
  134. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
  135. inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
  136. inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
  137. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
  138. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
  139. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
  140. inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
  141. inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
  142. inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
  143. inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
  144. inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
  145. inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
  146. inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
  147. inspect_ai/util/__init__.py +2 -0
  148. inspect_ai/util/_display.py +5 -0
  149. inspect_ai/util/_limit.py +26 -0
  150. inspect_ai/util/_sandbox/docker/docker.py +64 -1
  151. inspect_ai/util/_sandbox/docker/internal.py +3 -1
  152. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  153. inspect_ai/util/_sandbox/environment.py +14 -0
  154. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  155. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
  156. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  157. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  158. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  159. inspect_ai/_view/www/src/api/index.mjs +0 -49
  160. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  161. inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
  162. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  163. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  164. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  165. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  166. {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0

inspect_ai/model/_providers/openai.py
@@ -1,4 +1,3 @@
-import json
 import os
 from logging import getLogger
 from typing import Any
@@ -15,51 +14,39 @@ from openai import (
 from openai._types import NOT_GIVEN
 from openai.types.chat import (
     ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionContentPartImageParam,
-    ChatCompletionContentPartInputAudioParam,
-    ChatCompletionContentPartParam,
-    ChatCompletionContentPartTextParam,
-    ChatCompletionDeveloperMessageParam,
-    ChatCompletionMessage,
-    ChatCompletionMessageParam,
-    ChatCompletionMessageToolCallParam,
-    ChatCompletionNamedToolChoiceParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolChoiceOptionParam,
-    ChatCompletionToolMessageParam,
-    ChatCompletionToolParam,
-    ChatCompletionUserMessageParam,
 )
-from openai.types.shared_params.function_definition import FunctionDefinition
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
-from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import is_http_url
-from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.tool import ToolChoice, ToolInfo
 
-from .._chat_message import ChatMessage, ChatMessageAssistant
+from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import (
     ChatCompletionChoice,
-    Logprobs,
     ModelOutput,
     ModelUsage,
     StopReason,
 )
+from .._openai import (
+    is_o1,
+    is_o1_full,
+    is_o1_mini,
+    is_o1_preview,
+    openai_chat_messages,
+    openai_chat_tool_choice,
+    openai_chat_tools,
+)
 from .openai_o1 import generate_o1
 from .util import (
-    as_stop_reason,
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )
 
 logger = getLogger(__name__)
@@ -87,20 +74,22 @@ class OpenAIAPI(ModelAPI):
             config=config,
         )
 
-        # pull out azure model_arg
-        AZURE_MODEL_ARG = "azure"
-        is_azure = False
-        if AZURE_MODEL_ARG in model_args:
-            is_azure = model_args.get(AZURE_MODEL_ARG, False)
-            del model_args[AZURE_MODEL_ARG]
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
 
         # resolve api_key
         if not self.api_key:
            self.api_key = os.environ.get(
                AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
            )
-            if self.api_key:
-                is_azure = True
+            # backward compatibility for when env vars determined service
+            if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
+                self.service = "azure"
         else:
             self.api_key = os.environ.get(OPENAI_API_KEY, None)
             if not self.api_key:
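
Of note in the constructor change above: the Azure/OpenAI split is no longer driven by an `azure` model_arg but by a service prefix on the model name. A minimal sketch of the parsing behavior (the model and deployment names below are illustrative, not from the diff):

```python
# Minimal sketch of the service-prefix parsing introduced above; example
# names ("azure/my-deployment", "gpt-4o") are illustrative only.
def split_service(model_name: str) -> tuple[str | None, str]:
    parts = model_name.split("/")
    if len(parts) > 1:
        # first path segment selects the service, the rest is the model
        return parts[0], "/".join(parts[1:])
    return None, model_name

assert split_service("azure/my-deployment") == ("azure", "my-deployment")
assert split_service("gpt-4o") == (None, "gpt-4o")
```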
@@ -113,7 +102,7 @@ class OpenAIAPI(ModelAPI):
                 )
 
         # azure client
-        if is_azure:
+        if self.is_azure():
             # resolve base_url
             base_url = model_base_url(
                 base_url,
@@ -148,17 +137,20 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )
 
+    def is_azure(self) -> bool:
+        return self.service == "azure"
+
     def is_o1(self) -> bool:
-        return self.model_name.startswith("o1")
+        return is_o1(self.model_name)
 
     def is_o1_full(self) -> bool:
-        return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+        return is_o1_full(self.model_name)
 
     def is_o1_mini(self) -> bool:
-        return self.model_name.startswith("o1-mini")
+        return is_o1_mini(self.model_name)
 
     def is_o1_preview(self) -> bool:
-        return self.model_name.startswith("o1-preview")
+        return is_o1_preview(self.model_name)
 
     async def generate(
         self,
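
The o1 detection helpers are now imported from the new shared `_openai` module rather than implemented per class. Their bodies are not shown in this diff; a plausible sketch mirroring the removed per-class implementations:

```python
# Plausible sketch of the shared helpers now imported from .._openai,
# mirroring the removed implementations above; the actual bodies in the
# new _openai.py are not shown in this diff.
def is_o1(name: str) -> bool:
    return name.startswith("o1")

def is_o1_mini(name: str) -> bool:
    return name.startswith("o1-mini")

def is_o1_preview(name: str) -> bool:
    return name.startswith("o1-preview")

def is_o1_full(name: str) -> bool:
    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)
```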
@@ -166,7 +158,7 @@ class OpenAIAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # short-circuit to call o1- models that are text only
         if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
@@ -198,9 +190,11 @@ class OpenAIAPI(ModelAPI):
 
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await as_openai_chat_messages(input, self.is_o1_full()),
-            tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
-            tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
+            messages=await openai_chat_messages(input, self.model_name),
+            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+            tool_choice=openai_chat_tool_choice(tool_choice)
+            if len(tools) > 0
+            else NOT_GIVEN,
             **self.completion_params(config, len(tools) > 0),
         )
 
@@ -237,7 +231,7 @@ class OpenAIAPI(ModelAPI):
         self, response: ChatCompletion, tools: list[ToolInfo]
     ) -> list[ChatCompletionChoice]:
         # adding this as a method so we can override from other classes (e.g together)
-        return chat_choices_from_response(response, tools)
+        return chat_choices_from_openai(response, tools)
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -307,184 +301,23 @@ class OpenAIAPI(ModelAPI):
         return params
 
     # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-        if e.status_code == 400:
-            # extract message
-            if isinstance(e.body, dict) and "message" in e.body.keys():
-                content = str(e.body.get("message"))
-            else:
-                content = e.message
+    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
+        # extract message
+        if isinstance(e.body, dict) and "message" in e.body.keys():
+            content = str(e.body.get("message"))
+        else:
+            content = e.message
 
-            # narrow stop_reason
-            if e.code == "context_length_exceeded":
-                stop_reason: StopReason = "model_length"
-            elif e.code == "invalid_prompt":
-                stop_reason = "content_filter"
-            else:
-                stop_reason = "unknown"
+        # narrow stop_reason
+        stop_reason: StopReason | None = None
+        if e.code == "context_length_exceeded":
+            stop_reason = "model_length"
+        elif e.code == "invalid_prompt":
+            stop_reason = "content_filter"
 
+        if stop_reason:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
-            raise e
-
-
-async def as_openai_chat_messages(
-    messages: list[ChatMessage], o1_full: bool
-) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message, o1_full) for message in messages]
-
-
-async def openai_chat_message(
-    message: ChatMessage, o1_full: bool
-) -> ChatCompletionMessageParam:
-    if message.role == "system":
-        if o1_full:
-            return ChatCompletionDeveloperMessageParam(
-                role="developer", content=message.text
-            )
-        else:
-            return ChatCompletionSystemMessageParam(
-                role=message.role, content=message.text
-            )
-    elif message.role == "user":
-        return ChatCompletionUserMessageParam(
-            role=message.role,
-            content=(
-                message.content
-                if isinstance(message.content, str)
-                else [
-                    await as_chat_completion_part(content)
-                    for content in message.content
-                ]
-            ),
-        )
-    elif message.role == "assistant":
-        if message.tool_calls:
-            return ChatCompletionAssistantMessageParam(
-                role=message.role,
-                content=message.text,
-                tool_calls=[chat_tool_call(call) for call in message.tool_calls],
-            )
-        else:
-            return ChatCompletionAssistantMessageParam(
-                role=message.role, content=message.text
-            )
-    elif message.role == "tool":
-        return ChatCompletionToolMessageParam(
-            role=message.role,
-            content=(
-                f"Error: {message.error.message}" if message.error else message.text
-            ),
-            tool_call_id=str(message.tool_call_id),
-        )
-    else:
-        raise ValueError(f"Unexpected message role {message.role}")
-
-
-def chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCallParam:
-    return ChatCompletionMessageToolCallParam(
-        id=tool_call.id,
-        function=dict(
-            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
-        ),
-        type=tool_call.type,
-    )
-
-
-def chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
-    return [chat_tool_param(tool) for tool in tools]
-
-
-def chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
-    function = FunctionDefinition(
-        name=tool.name,
-        description=tool.description,
-        parameters=tool.parameters.model_dump(exclude_none=True),
-    )
-    return ChatCompletionToolParam(type="function", function=function)
-
-
-def chat_tool_choice(tool_choice: ToolChoice) -> ChatCompletionToolChoiceOptionParam:
-    if isinstance(tool_choice, ToolFunction):
-        return ChatCompletionNamedToolChoiceParam(
-            type="function", function=dict(name=tool_choice.name)
-        )
-    # openai supports 'any' via the 'required' keyword
-    elif tool_choice == "any":
-        return "required"
-    else:
-        return tool_choice
-
-
-def chat_tool_calls(
-    message: ChatCompletionMessage, tools: list[ToolInfo]
-) -> list[ToolCall] | None:
-    if message.tool_calls:
-        return [
-            parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
-            for call in message.tool_calls
-        ]
-    else:
-        return None
-
-
-def chat_choices_from_response(
-    response: ChatCompletion, tools: list[ToolInfo]
-) -> list[ChatCompletionChoice]:
-    choices = list(response.choices)
-    choices.sort(key=lambda c: c.index)
-    return [
-        ChatCompletionChoice(
-            message=chat_message_assistant(choice.message, tools),
-            stop_reason=as_stop_reason(choice.finish_reason),
-            logprobs=(
-                Logprobs(**choice.logprobs.model_dump())
-                if choice.logprobs is not None
-                else None
-            ),
-        )
-        for choice in choices
-    ]
-
-
-def chat_message_assistant(
-    message: ChatCompletionMessage, tools: list[ToolInfo]
-) -> ChatMessageAssistant:
-    return ChatMessageAssistant(
-        content=message.content or "",
-        source="generate",
-        tool_calls=chat_tool_calls(message, tools),
-    )
-
-
-async def as_chat_completion_part(
-    content: Content,
-) -> ChatCompletionContentPartParam:
-    if content.type == "text":
-        return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    elif content.type == "image":
-        # API takes URL or base64 encoded file. If it's a remote file or
-        # data URL leave it alone, otherwise encode it
-        image_url = content.image
-        detail = content.detail
-
-        if not is_http_url(image_url):
-            image_url = await file_as_data_uri(image_url)
-
-        return ChatCompletionContentPartImageParam(
-            type="image_url",
-            image_url=dict(url=image_url, detail=detail),
-        )
-    elif content.type == "audio":
-        audio_data = await file_as_data_uri(content.audio)
-
-        return ChatCompletionContentPartInputAudioParam(
-            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
-        )
-
-    else:
-        raise RuntimeError(
-            "Video content is not currently supported by Open AI chat models."
-        )
+            return e
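
The net effect of the `handle_bad_request` rewrite: recognized bad-request codes become a `ModelOutput` with a narrowed stop reason, and anything else is now returned rather than raised, leaving the decision to the caller. A hypothetical caller-side sketch (the real consumer lives in `_model.py`, whose changes are not shown in this excerpt):

```python
from inspect_ai.model._model_output import ModelOutput

# Hypothetical consumer of the new ModelOutput | Exception contract; the
# actual handling in _model.py is not part of this diff excerpt.
def resolve_output(result: ModelOutput | Exception) -> ModelOutput:
    if isinstance(result, Exception):
        raise result  # unrecognized bad request: surface to the caller
    return result  # recognized code, e.g. stop_reason == "model_length"
```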
inspect_ai/model/_providers/openai_o1.py
@@ -24,15 +24,13 @@ from inspect_ai.model import (
 )
 from inspect_ai.tool import ToolCall, ToolInfo
 
+from .._call_tools import parse_tool_call, tool_parse_error_message
 from .._model_call import ModelCall
-from .._model_output import ModelUsage, StopReason
+from .._model_output import ModelUsage, StopReason, as_stop_reason
 from .._providers.util import (
     ChatAPIHandler,
     ChatAPIMessage,
-    as_stop_reason,
     chat_api_input,
-    parse_tool_call,
-    tool_parse_error_message,
 )
 
 logger = getLogger(__name__)
@@ -44,7 +42,7 @@ async def generate_o1(
     input: list[ChatMessage],
     tools: list[ToolInfo],
     **params: Any,
-) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()
 
@@ -82,17 +80,18 @@ async def generate_o1(
     ), model_call()
 
 
-def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
+def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
     if ex.code == "context_length_exceeded":
-        stop_reason: StopReason = "model_length"
+        stop_reason: StopReason | None = "model_length"
     elif ex.code == "invalid_prompt":
         stop_reason = "content_filter"
-    else:
-        stop_reason = "unknown"
 
-    return ModelOutput.from_content(
-        model=model, content=str(ex), stop_reason=stop_reason
-    )
+    if stop_reason:
+        return ModelOutput.from_content(
+            model=model, content=str(ex), stop_reason=stop_reason
+        )
+    else:
+        return ex
 
 
 def chat_messages(
inspect_ai/model/_providers/providers.py
@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
 def google() -> type[ModelAPI]:
     FEATURE = "Google API"
     PACKAGE = "google-generativeai"
-    MIN_VERSION = "0.8.3"
+    MIN_VERSION = "0.8.4"
 
     # workaround log spam
     # https://github.com/ray-project/ray/issues/24917
@@ -239,6 +239,28 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi("goodfire")
+def goodfire() -> type[ModelAPI]:
+    """Get the Goodfire API provider."""
+    FEATURE = "Goodfire API"
+    PACKAGE = "goodfire"
+    MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility
+
+    # verify we have the package
+    try:
+        import goodfire  # noqa: F401
+    except ImportError:
+        raise pip_dependency_error(FEATURE, [PACKAGE])
+
+    # verify version
+    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
+
+    # in the clear
+    from .goodfire import GoodfireAPI
+
+    return GoodfireAPI
+
+
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
inspect_ai/model/_providers/together.py
@@ -24,13 +24,13 @@ from .._model_output import (
     ModelOutput,
     ModelUsage,
     StopReason,
+    as_stop_reason,
 )
+from .._openai import chat_message_assistant_from_openai
 from .openai import (
     OpenAIAPI,
-    chat_message_assistant,
 )
 from .util import (
-    as_stop_reason,
     chat_api_input,
     chat_api_request,
     environment_prerequisite_error,
@@ -68,7 +68,7 @@ def chat_choices_from_response_together(
         logprobs_models.append(Logprobs(content=logprobs_sequence))
     return [
         ChatCompletionChoice(
-            message=chat_message_assistant(choice.message, tools),
+            message=chat_message_assistant_from_openai(choice.message, tools),
             stop_reason=as_stop_reason(choice.finish_reason),
             logprobs=logprobs,
         )
@@ -99,22 +99,22 @@ class TogetherAIAPI(OpenAIAPI):
 
     # Together uses a default of 512 so we bump it up
     @override
-    def max_tokens(self) -> int:
+    def max_tokens(self) -> int | None:
         return DEFAULT_MAX_TOKENS
 
     @override
-    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
-        if ex.status_code == 400 and "max_new_tokens" in ex.message:
-            response = ex.response.json()
-            if "error" in response and "message" in response.get("error"):
-                content = response.get("error").get("message")
-            else:
-                content = str(response)
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        response = ex.response.json()
+        if "error" in response and "message" in response.get("error"):
+            content = response.get("error").get("message")
+        else:
+            content = str(response)
+        if "max_new_tokens" in ex.message:
             return ModelOutput.from_content(
                 model=self.model_name, content=content, stop_reason="model_length"
             )
         else:
-            raise ex
+            return ex
 
     # Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
     def _chat_choices_from_response(
inspect_ai/model/_providers/util/__init__.py
@@ -1,3 +1,5 @@
+from ..._call_tools import parse_tool_call, tool_parse_error_message
+from ..._model_output import as_stop_reason
 from .chatapi import (
     ChatAPIHandler,
     ChatAPIMessage,
@@ -8,11 +10,8 @@ from .chatapi import (
 from .hf_handler import HFHandler
 from .llama31 import Llama31Handler
 from .util import (
-    as_stop_reason,
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
-    tool_parse_error_message,
 )
 
 __all__ = [
inspect_ai/model/_providers/util/hf_handler.py
@@ -8,9 +8,9 @@ from typing_extensions import override
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import ToolInfo
 
+from ..._call_tools import parse_tool_call, tool_parse_error_message
 from ..._chat_message import ChatMessageAssistant
 from .chatapi import ChatAPIHandler
-from .util import parse_tool_call, tool_parse_error_message
 
 logger = getLogger(__name__)
 
inspect_ai/model/_providers/util/llama31.py
@@ -9,6 +9,7 @@ from typing_extensions import override
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_info import ToolInfo
 
+from ..._call_tools import parse_tool_call, tool_parse_error_message
 from ..._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -16,7 +17,6 @@ from ..._chat_message import (
     ChatMessageTool,
 )
 from .chatapi import ChatAPIHandler, ChatAPIMessage
-from .util import parse_tool_call, tool_parse_error_message
 
 logger = getLogger(__name__)
 
inspect_ai/model/_providers/util/util.py
@@ -1,34 +1,11 @@
-import json
 import os
 from logging import getLogger
-from typing import Any
-
-import yaml
 
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai.tool._tool_call import ToolCall
-from inspect_ai.tool._tool_info import ToolInfo
-
-from ..._model_output import StopReason
 
 logger = getLogger(__name__)
 
 
-def as_stop_reason(reason: str | None) -> StopReason:
-    """Encode common reason strings into standard StopReason."""
-    match reason:
-        case "stop" | "eos":
-            return "stop"
-        case "length":
-            return "max_tokens"
-        case "tool_calls" | "function_call":
-            return "tool_calls"
-        case "content_filter" | "model_length" | "max_tokens":
-            return reason
-        case _:
-            return "unknown"
-
-
 def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
     if base_url:
         return base_url
@@ -44,59 +21,6 @@ def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | Non
     return os.getenv("INSPECT_EVAL_MODEL_BASE_URL", None)
 
 
-def tool_parse_error_message(arguments: str, ex: Exception) -> str:
-    return f"Error parsing the following tool call arguments:\n\n{arguments}\n\nError details: {ex}"
-
-
-def parse_tool_call(
-    id: str, function: str, arguments: str, tools: list[ToolInfo]
-) -> ToolCall:
-    error: str | None = None
-    arguments_dict: dict[str, Any] = {}
-
-    def report_parse_error(ex: Exception) -> None:
-        nonlocal error
-        error = tool_parse_error_message(arguments, ex)
-        logger.info(error)
-
-    # if the arguments is a dict, then handle it with a plain json.loads
-    arguments = arguments.strip()
-    if arguments.startswith("{"):
-        try:
-            arguments_dict = json.loads(arguments)
-        except json.JSONDecodeError as ex:
-            report_parse_error(ex)
-
-    # otherwise parse it as yaml (which will pickup unquoted strings, numbers, and true/false)
-    # and then create a dict that maps it to the first function argument
-    else:
-        tool_info = next(
-            (
-                tool
-                for tool in tools
-                if tool.name == function and len(tool.parameters.properties) > 0
-            ),
-            None,
-        )
-        if tool_info:
-            param_names = list(tool_info.parameters.properties.keys())
-            try:
-                value = yaml.safe_load(arguments)
-                arguments_dict[param_names[0]] = value
-            except yaml.error.YAMLError:
-                # If the yaml parser fails, we treat it as a string argument.
-                arguments_dict[param_names[0]] = arguments
-
-    # return ToolCall with error payload
-    return ToolCall(
-        id=id,
-        function=function,
-        arguments=arguments_dict,
-        type="function",
-        parse_error=error,
-    )
-
-
 def environment_prerequisite_error(
     client: str, env_vars: str | list[str]
 ) -> PrerequisiteError:
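
These helpers were relocated rather than deleted: per the import changes elsewhere in this diff, `as_stop_reason` now lives in `_model_output`, and `parse_tool_call`/`tool_parse_error_message` in `_call_tools`. Updated imports for any code that used them:

```python
# New homes for the removed helpers (paths taken from the import hunks above)
from inspect_ai.model._model_output import as_stop_reason
from inspect_ai.model._call_tools import parse_tool_call, tool_parse_error_message

# behavior is unchanged, e.g. per the removed implementation:
assert as_stop_reason("length") == "max_tokens"
```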
inspect_ai/model/_providers/vertex.py
@@ -23,7 +23,7 @@ from vertexai.generative_models import (  # type: ignore
 )
 from vertexai.generative_models import Content as VertexContent
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
 from inspect_ai._util.content import (
     Content,
     ContentAudio,
@@ -250,9 +250,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> VertexContent:
inspect_ai/scorer/_metric.py
@@ -125,6 +125,9 @@ class SampleScore(BaseModel):
     sample_id: str | int | None = Field(default=None)
     """A sample id"""
 
+    scorer: str | None = Field(default=None)
+    """Registry name of scorer that created this score."""
+
 
 ValueToFloat = Callable[[Value], float]
 """Function used by metrics to translate from a Score value to a float value."""