inspect-ai 0.3.59__py3-none-any.whl → 0.3.60__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (74)
  1. inspect_ai/_cli/eval.py +0 -7
  2. inspect_ai/_display/textual/widgets/samples.py +1 -1
  3. inspect_ai/_eval/eval.py +10 -1
  4. inspect_ai/_eval/loader.py +79 -19
  5. inspect_ai/_eval/registry.py +6 -0
  6. inspect_ai/_eval/score.py +2 -1
  7. inspect_ai/_eval/task/results.py +6 -5
  8. inspect_ai/_eval/task/run.py +11 -11
  9. inspect_ai/_view/www/dist/assets/index.js +262 -303
  10. inspect_ai/_view/www/src/App.mjs +6 -6
  11. inspect_ai/_view/www/src/Types.mjs +1 -1
  12. inspect_ai/_view/www/src/api/Types.ts +133 -0
  13. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  14. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  15. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  16. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  17. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  18. inspect_ai/_view/www/src/api/index.ts +51 -0
  19. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  20. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  21. inspect_ai/_view/www/src/index.js +2 -2
  22. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  23. inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
  24. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
  25. inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
  26. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  27. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
  28. inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
  29. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  30. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
  31. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  32. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  33. inspect_ai/approval/_human/manager.py +1 -1
  34. inspect_ai/model/_call_tools.py +55 -0
  35. inspect_ai/model/_conversation.py +1 -4
  36. inspect_ai/model/_generate_config.py +2 -8
  37. inspect_ai/model/_model_output.py +15 -0
  38. inspect_ai/model/_openai.py +383 -0
  39. inspect_ai/model/_providers/anthropic.py +52 -11
  40. inspect_ai/model/_providers/azureai.py +1 -1
  41. inspect_ai/model/_providers/goodfire.py +248 -0
  42. inspect_ai/model/_providers/groq.py +7 -3
  43. inspect_ai/model/_providers/hf.py +6 -0
  44. inspect_ai/model/_providers/mistral.py +2 -1
  45. inspect_ai/model/_providers/openai.py +36 -202
  46. inspect_ai/model/_providers/openai_o1.py +2 -4
  47. inspect_ai/model/_providers/providers.py +22 -0
  48. inspect_ai/model/_providers/together.py +4 -4
  49. inspect_ai/model/_providers/util/__init__.py +2 -3
  50. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  51. inspect_ai/model/_providers/util/llama31.py +1 -1
  52. inspect_ai/model/_providers/util/util.py +0 -76
  53. inspect_ai/scorer/_metric.py +3 -0
  54. inspect_ai/scorer/_scorer.py +2 -1
  55. inspect_ai/solver/__init__.py +2 -0
  56. inspect_ai/solver/_basic_agent.py +1 -1
  57. inspect_ai/solver/_bridge/__init__.py +3 -0
  58. inspect_ai/solver/_bridge/bridge.py +100 -0
  59. inspect_ai/solver/_bridge/patch.py +170 -0
  60. inspect_ai/solver/_solver.py +6 -0
  61. inspect_ai/util/_display.py +5 -0
  62. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  63. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
  64. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +68 -63
  65. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  66. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  67. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  68. inspect_ai/_view/www/src/api/index.mjs +0 -49
  69. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  70. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  71. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
  72. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
  73. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
  74. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/goodfire.py (new file)
@@ -0,0 +1,248 @@
+ import os
+ from typing import Any, List, Literal, get_args
+
+ from goodfire import AsyncClient
+ from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+ from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+ from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+ from typing_extensions import override
+
+ from inspect_ai.tool._tool_choice import ToolChoice
+ from inspect_ai.tool._tool_info import ToolInfo
+
+ from .._chat_message import (
+     ChatMessage,
+     ChatMessageAssistant,
+     ChatMessageSystem,
+     ChatMessageTool,
+     ChatMessageUser,
+ )
+ from .._generate_config import GenerateConfig
+ from .._model import ModelAPI
+ from .._model_call import ModelCall
+ from .._model_output import (
+     ChatCompletionChoice,
+     ModelOutput,
+     ModelUsage,
+ )
+ from .util import environment_prerequisite_error, model_base_url
+
+ # Constants
+ GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+ DEFAULT_BASE_URL = "https://api.goodfire.ai"
+ DEFAULT_MAX_TOKENS = 4096
+ DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+ DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+ class GoodfireAPI(ModelAPI):
+     """Goodfire API provider.
+
+     This provider implements the Goodfire API for LLM inference. It supports:
+     - Chat completions with standard message formats
+     - Basic parameter controls (temperature, top_p, etc.)
+     - Usage statistics tracking
+     - Stop reason handling
+
+     Does not currently support:
+     - Tool calls
+     - Feature analysis
+     - Streaming responses
+
+     Known limitations:
+     - Limited role support (system/user/assistant only)
+     - Tool messages converted to user messages
+     """
+
+     client: AsyncClient
+     variant: Variant
+     model_args: dict[str, Any]
+
+     def __init__(
+         self,
+         model_name: str,
+         base_url: str | None = None,
+         api_key: str | None = None,
+         config: GenerateConfig = GenerateConfig(),
+         **model_args: Any,
+     ) -> None:
+         """Initialize the Goodfire API provider.
+
+         Args:
+             model_name: Name of the model to use
+             base_url: Optional custom API base URL
+             api_key: Optional API key (will check env vars if not provided)
+             config: Generation config options
+             **model_args: Additional arguments passed to the API
+         """
+         super().__init__(
+             model_name=model_name,
+             base_url=base_url,
+             api_key=api_key,
+             api_key_vars=[GOODFIRE_API_KEY],
+             config=config,
+         )
+
+         # resolve api_key
+         if not self.api_key:
+             self.api_key = os.environ.get(GOODFIRE_API_KEY)
+         if not self.api_key:
+             raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+         # Validate model name against supported models
+         supported_models = list(get_args(SUPPORTED_MODELS))
+         if self.model_name not in supported_models:
+             raise ValueError(
+                 f"Model {self.model_name} not supported. Supported models: {supported_models}"
+             )
+
+         # Initialize client with minimal configuration
+         base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+         assert isinstance(base_url_val, str) or base_url_val is None
+
+         # Store model args for use in generate
+         self.model_args = model_args
+
+         self.client = AsyncClient(
+             api_key=self.api_key,
+             base_url=base_url_val or DEFAULT_BASE_URL,
+         )
+
+         # Initialize variant directly with model name
+         self.variant = Variant(self.model_name)  # type: ignore
+
+     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+         """Convert an Inspect message to a Goodfire message format.
+
+         Args:
+             message: The message to convert
+
+         Returns:
+             The converted message in Goodfire format
+
+         Raises:
+             ValueError: If the message type is unknown
+         """
+         role: Literal["system", "user", "assistant"] = "user"
+         if isinstance(message, ChatMessageSystem):
+             role = "system"
+         elif isinstance(message, ChatMessageUser):
+             role = "user"
+         elif isinstance(message, ChatMessageAssistant):
+             role = "assistant"
+         elif isinstance(message, ChatMessageTool):
+             role = "user"  # Convert tool messages to user messages
+         else:
+             raise ValueError(f"Unknown message type: {type(message)}")
+
+         content = str(message.content)
+         if isinstance(message, ChatMessageTool):
+             content = f"Tool {message.function}: {content}"
+
+         return GoodfireChatMessage(role=role, content=content)
+
+     def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+         """Handle only errors that need special treatment for retry logic or model limits."""
+         # Handle token/context length errors
+         if isinstance(ex, InvalidRequestException):
+             error_msg = str(ex).lower()
+             if "context length" in error_msg or "max tokens" in error_msg:
+                 return ModelOutput.from_content(
+                     model=self.model_name,
+                     content=str(ex),
+                     stop_reason="model_length",
+                     error=error_msg,
+                 )
+
+         # Let all other errors propagate
+         return ex
+
+     @override
+     def is_rate_limit(self, ex: BaseException) -> bool:
+         """Check if exception is due to rate limiting."""
+         return isinstance(ex, RateLimitException)
+
+     @override
+     def connection_key(self) -> str:
+         """Return key for connection pooling."""
+         return f"goodfire:{self.api_key}"
+
+     @override
+     def max_tokens(self) -> int | None:
+         """Return maximum tokens supported by model."""
+         return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+     async def generate(
+         self,
+         input: List[ChatMessage],
+         tools: List[ToolInfo],
+         tool_choice: ToolChoice,
+         config: GenerateConfig,
+         *,
+         cache: bool = True,
+     ) -> tuple[ModelOutput | Exception, ModelCall]:
+         """Generate output from the model."""
+         # Convert messages and prepare request params
+         messages = [self._to_goodfire_message(msg) for msg in input]
+         # Build request parameters with type hints
+         params: dict[str, Any] = {
+             "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+             "messages": messages,
+             "max_completion_tokens": int(config.max_tokens)
+             if config.max_tokens
+             else DEFAULT_MAX_TOKENS,
+             "stream": False,
+         }
+
+         # Add generation parameters from config if not in model_args
+         if "temperature" not in self.model_args and config.temperature is not None:
+             params["temperature"] = float(config.temperature)
+         elif "temperature" not in self.model_args:
+             params["temperature"] = DEFAULT_TEMPERATURE
+
+         if "top_p" not in self.model_args and config.top_p is not None:
+             params["top_p"] = float(config.top_p)
+         elif "top_p" not in self.model_args:
+             params["top_p"] = DEFAULT_TOP_P
+
+         # Add any additional model args (highest priority)
+         api_params = {
+             k: v
+             for k, v in self.model_args.items()
+             if k not in ["api_key", "base_url", "model_args"]
+         }
+         params.update(api_params)
+
+         try:
+             # Use native async client
+             response = await self.client.chat.completions.create(**params)
+             response_dict = response.model_dump()
+
+             output = ModelOutput(
+                 model=self.model_name,
+                 choices=[
+                     ChatCompletionChoice(
+                         message=ChatMessageAssistant(
+                             content=response_dict["choices"][0]["message"]["content"]
+                         ),
+                         stop_reason="stop",
+                     )
+                 ],
+                 usage=ModelUsage(**response_dict["usage"])
+                 if "usage" in response_dict
+                 else None,
+             )
+             model_call = ModelCall.create(request=params, response=response_dict)
+             return (output, model_call)
+         except Exception as ex:
+             result = self.handle_error(ex)
+             model_call = ModelCall.create(
+                 request=params,
+                 response={},  # Empty response for error case
+             )
+             return (result, model_call)
+
+     @property
+     def name(self) -> str:
+         """Get provider name."""
+         return "goodfire"
inspect_ai/model/_providers/groq.py
@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
  from inspect_ai._util.url import is_http_url
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+ from .._call_tools import parse_tool_call
  from .._chat_message import (
      ChatMessage,
      ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
  from .._generate_config import GenerateConfig
  from .._model import ModelAPI
  from .._model_call import ModelCall
- from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
- from .util import (
+ from .._model_output import (
+     ChatCompletionChoice,
+     ModelOutput,
+     ModelUsage,
      as_stop_reason,
+ )
+ from .util import (
      environment_prerequisite_error,
      model_base_url,
-     parse_tool_call,
  )

  GROQ_API_KEY = "GROQ_API_KEY"
inspect_ai/model/_providers/hf.py
@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
              kwargs["output_logits"] = config.logprobs
          if "return_dict_in_generate" in kwargs:
              assert kwargs["return_dict_in_generate"]
+         if config.stop_seqs is not None:
+             from transformers.generation import StopStringCriteria  # type: ignore
+
+             stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+             kwargs["stopping_criteria"] = stopping_criteria
+
          kwargs["return_dict_in_generate"] = True
          generator = functools.partial(self.model.generate, **kwargs)

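The new branch wires GenerateConfig.stop_seqs into a transformers StopStringCriteria. A minimal sketch of the caller side, assuming only that stop_seqs accepts a list of strings as the hunk suggests:

    from inspect_ai.model import GenerateConfig

    # stop generation at any of these strings; the HF provider now turns this
    # list into StopStringCriteria(tokenizer, stop_seqs) before generate()
    config = GenerateConfig(stop_seqs=["\nUser:", "###"])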
inspect_ai/model/_providers/mistral.py
@@ -46,6 +46,7 @@ from inspect_ai._util.content import Content, ContentImage, ContentText
  from inspect_ai._util.images import file_as_data_uri
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

+ from .._call_tools import parse_tool_call
  from .._chat_message import (
      ChatMessage,
      ChatMessageAssistant,
@@ -59,7 +60,7 @@ from .._model_output import (
      ModelUsage,
      StopReason,
  )
- from .util import environment_prerequisite_error, model_base_url, parse_tool_call
+ from .util import environment_prerequisite_error, model_base_url

  AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
  AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
inspect_ai/model/_providers/openai.py
@@ -1,4 +1,3 @@
- import json
  import os
  from logging import getLogger
  from typing import Any
@@ -15,51 +14,39 @@ from openai import (
  from openai._types import NOT_GIVEN
  from openai.types.chat import (
      ChatCompletion,
-     ChatCompletionAssistantMessageParam,
-     ChatCompletionContentPartImageParam,
-     ChatCompletionContentPartInputAudioParam,
-     ChatCompletionContentPartParam,
-     ChatCompletionContentPartTextParam,
-     ChatCompletionDeveloperMessageParam,
-     ChatCompletionMessage,
-     ChatCompletionMessageParam,
-     ChatCompletionMessageToolCallParam,
-     ChatCompletionNamedToolChoiceParam,
-     ChatCompletionSystemMessageParam,
-     ChatCompletionToolChoiceOptionParam,
-     ChatCompletionToolMessageParam,
-     ChatCompletionToolParam,
-     ChatCompletionUserMessageParam,
  )
- from openai.types.shared_params.function_definition import FunctionDefinition
  from typing_extensions import override

  from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
- from inspect_ai._util.content import Content
  from inspect_ai._util.error import PrerequisiteError
- from inspect_ai._util.images import file_as_data_uri
  from inspect_ai._util.logger import warn_once
- from inspect_ai._util.url import is_http_url
- from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
+ from inspect_ai.model._openai import chat_choices_from_openai
+ from inspect_ai.tool import ToolChoice, ToolInfo

- from .._chat_message import ChatMessage, ChatMessageAssistant
+ from .._chat_message import ChatMessage
  from .._generate_config import GenerateConfig
  from .._image import image_url_filter
  from .._model import ModelAPI
  from .._model_call import ModelCall
  from .._model_output import (
      ChatCompletionChoice,
-     Logprobs,
      ModelOutput,
      ModelUsage,
      StopReason,
  )
+ from .._openai import (
+     is_o1,
+     is_o1_full,
+     is_o1_mini,
+     is_o1_preview,
+     openai_chat_messages,
+     openai_chat_tool_choice,
+     openai_chat_tools,
+ )
  from .openai_o1 import generate_o1
  from .util import (
-     as_stop_reason,
      environment_prerequisite_error,
      model_base_url,
-     parse_tool_call,
  )

  logger = getLogger(__name__)
@@ -87,20 +74,22 @@ class OpenAIAPI(ModelAPI):
              config=config,
          )

-         # pull out azure model_arg
-         AZURE_MODEL_ARG = "azure"
-         is_azure = False
-         if AZURE_MODEL_ARG in model_args:
-             is_azure = model_args.get(AZURE_MODEL_ARG, False)
-             del model_args[AZURE_MODEL_ARG]
+         # extract any service prefix from model name
+         parts = model_name.split("/")
+         if len(parts) > 1:
+             self.service: str | None = parts[0]
+             model_name = "/".join(parts[1:])
+         else:
+             self.service = None

          # resolve api_key
          if not self.api_key:
              self.api_key = os.environ.get(
                  AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
              )
-             if self.api_key:
-                 is_azure = True
+             # backward compatibility for when env vars determined service
+             if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
+                 self.service = "azure"
          else:
              self.api_key = os.environ.get(OPENAI_API_KEY, None)
              if not self.api_key:
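The prefix parsing above replaces the old azure model_arg: the service is now carried as the first path segment of the model name (e.g. "azure/gpt-4o"), with an environment-variable fallback kept for backward compatibility. A standalone sketch of the same parsing, with illustrative names:

    def split_service(model_name: str) -> tuple[str | None, str]:
        # mirrors the new __init__ logic: an optional service prefix,
        # then the model name proper (which may itself contain "/")
        parts = model_name.split("/")
        if len(parts) > 1:
            return parts[0], "/".join(parts[1:])
        return None, model_name

    assert split_service("azure/gpt-4o") == ("azure", "gpt-4o")
    assert split_service("gpt-4o") == (None, "gpt-4o")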
@@ -113,7 +102,7 @@
              )

          # azure client
-         if is_azure:
+         if self.is_azure():
              # resolve base_url
              base_url = model_base_url(
                  base_url,
@@ -148,17 +137,20 @@
              **model_args,
          )

+     def is_azure(self) -> bool:
+         return self.service == "azure"
+
      def is_o1(self) -> bool:
-         return self.model_name.startswith("o1")
+         return is_o1(self.model_name)

      def is_o1_full(self) -> bool:
-         return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+         return is_o1_full(self.model_name)

      def is_o1_mini(self) -> bool:
-         return self.model_name.startswith("o1-mini")
+         return is_o1_mini(self.model_name)

      def is_o1_preview(self) -> bool:
-         return self.model_name.startswith("o1-preview")
+         return is_o1_preview(self.model_name)

      async def generate(
          self,
@@ -198,9 +190,11 @@

          # prepare request (we do this so we can log the ModelCall)
          request = dict(
-             messages=await as_openai_chat_messages(input, self.is_o1_full()),
-             tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
-             tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
+             messages=await openai_chat_messages(input, self.model_name),
+             tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+             tool_choice=openai_chat_tool_choice(tool_choice)
+             if len(tools) > 0
+             else NOT_GIVEN,
              **self.completion_params(config, len(tools) > 0),
          )

@@ -237,7 +231,7 @@
          self, response: ChatCompletion, tools: list[ToolInfo]
      ) -> list[ChatCompletionChoice]:
          # adding this as a method so we can override from other classes (e.g together)
-         return chat_choices_from_response(response, tools)
+         return chat_choices_from_openai(response, tools)

      @override
      def is_rate_limit(self, ex: BaseException) -> bool:
@@ -327,163 +321,3 @@
              )
          else:
              return e
-
-
- async def as_openai_chat_messages(
-     messages: list[ChatMessage], o1_full: bool
- ) -> list[ChatCompletionMessageParam]:
-     return [await openai_chat_message(message, o1_full) for message in messages]
-
-
- async def openai_chat_message(
-     message: ChatMessage, o1_full: bool
- ) -> ChatCompletionMessageParam:
-     if message.role == "system":
-         if o1_full:
-             return ChatCompletionDeveloperMessageParam(
-                 role="developer", content=message.text
-             )
-         else:
-             return ChatCompletionSystemMessageParam(
-                 role=message.role, content=message.text
-             )
-     elif message.role == "user":
-         return ChatCompletionUserMessageParam(
-             role=message.role,
-             content=(
-                 message.content
-                 if isinstance(message.content, str)
-                 else [
-                     await as_chat_completion_part(content)
-                     for content in message.content
-                 ]
-             ),
-         )
-     elif message.role == "assistant":
-         if message.tool_calls:
-             return ChatCompletionAssistantMessageParam(
-                 role=message.role,
-                 content=message.text,
-                 tool_calls=[chat_tool_call(call) for call in message.tool_calls],
-             )
-         else:
-             return ChatCompletionAssistantMessageParam(
-                 role=message.role, content=message.text
-             )
-     elif message.role == "tool":
-         return ChatCompletionToolMessageParam(
-             role=message.role,
-             content=(
-                 f"Error: {message.error.message}" if message.error else message.text
-             ),
-             tool_call_id=str(message.tool_call_id),
-         )
-     else:
-         raise ValueError(f"Unexpected message role {message.role}")
-
-
- def chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCallParam:
-     return ChatCompletionMessageToolCallParam(
-         id=tool_call.id,
-         function=dict(
-             name=tool_call.function, arguments=json.dumps(tool_call.arguments)
-         ),
-         type=tool_call.type,
-     )
-
-
- def chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
-     return [chat_tool_param(tool) for tool in tools]
-
-
- def chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
-     function = FunctionDefinition(
-         name=tool.name,
-         description=tool.description,
-         parameters=tool.parameters.model_dump(exclude_none=True),
-     )
-     return ChatCompletionToolParam(type="function", function=function)
-
-
- def chat_tool_choice(tool_choice: ToolChoice) -> ChatCompletionToolChoiceOptionParam:
-     if isinstance(tool_choice, ToolFunction):
-         return ChatCompletionNamedToolChoiceParam(
-             type="function", function=dict(name=tool_choice.name)
-         )
-     # openai supports 'any' via the 'required' keyword
-     elif tool_choice == "any":
-         return "required"
-     else:
-         return tool_choice
-
-
- def chat_tool_calls(
-     message: ChatCompletionMessage, tools: list[ToolInfo]
- ) -> list[ToolCall] | None:
-     if message.tool_calls:
-         return [
-             parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
-             for call in message.tool_calls
-         ]
-     else:
-         return None
-
-
- def chat_choices_from_response(
-     response: ChatCompletion, tools: list[ToolInfo]
- ) -> list[ChatCompletionChoice]:
-     choices = list(response.choices)
-     choices.sort(key=lambda c: c.index)
-     return [
-         ChatCompletionChoice(
-             message=chat_message_assistant(choice.message, tools),
-             stop_reason=as_stop_reason(choice.finish_reason),
-             logprobs=(
-                 Logprobs(**choice.logprobs.model_dump())
-                 if choice.logprobs is not None
-                 else None
-             ),
-         )
-         for choice in choices
-     ]
-
-
- def chat_message_assistant(
-     message: ChatCompletionMessage, tools: list[ToolInfo]
- ) -> ChatMessageAssistant:
-     return ChatMessageAssistant(
-         content=message.content or "",
-         source="generate",
-         tool_calls=chat_tool_calls(message, tools),
-     )
-
-
- async def as_chat_completion_part(
-     content: Content,
- ) -> ChatCompletionContentPartParam:
-     if content.type == "text":
-         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-     elif content.type == "image":
-         # API takes URL or base64 encoded file. If it's a remote file or
-         # data URL leave it alone, otherwise encode it
-         image_url = content.image
-         detail = content.detail
-
-         if not is_http_url(image_url):
-             image_url = await file_as_data_uri(image_url)
-
-         return ChatCompletionContentPartImageParam(
-             type="image_url",
-             image_url=dict(url=image_url, detail=detail),
-         )
-     elif content.type == "audio":
-         audio_data = await file_as_data_uri(content.audio)
-
-         return ChatCompletionContentPartInputAudioParam(
-             type="input_audio", input_audio=dict(data=audio_data, format=content.format)
-         )
-
-     else:
-         raise RuntimeError(
-             "Video content is not currently supported by Open AI chat models."
-         )
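These helpers are not gone: per the import block added near the top of this file, equivalents now live in the new inspect_ai/model/_openai.py (+383 lines in the file list above). A minimal sketch of calling one relocated helper, with the signature inferred from the call site (await openai_chat_messages(input, self.model_name)), so treat the exact shape as an assumption:

    from inspect_ai.model import ChatMessageUser
    from inspect_ai.model._openai import openai_chat_messages

    async def to_openai_params(model_name: str):
        # converts Inspect chat messages to OpenAI chat-completion params,
        # as the new request-building code does above
        return await openai_chat_messages([ChatMessageUser(content="Hello")], model_name)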
inspect_ai/model/_providers/openai_o1.py
@@ -24,15 +24,13 @@ from inspect_ai.model import (
  )
  from inspect_ai.tool import ToolCall, ToolInfo

+ from .._call_tools import parse_tool_call, tool_parse_error_message
  from .._model_call import ModelCall
- from .._model_output import ModelUsage, StopReason
+ from .._model_output import ModelUsage, StopReason, as_stop_reason
  from .._providers.util import (
      ChatAPIHandler,
      ChatAPIMessage,
-     as_stop_reason,
      chat_api_input,
-     parse_tool_call,
-     tool_parse_error_message,
  )

  logger = getLogger(__name__)
inspect_ai/model/_providers/providers.py
@@ -239,6 +239,28 @@ def mockllm() -> type[ModelAPI]:
      return MockLLM


+ @modelapi("goodfire")
+ def goodfire() -> type[ModelAPI]:
+     """Get the Goodfire API provider."""
+     FEATURE = "Goodfire API"
+     PACKAGE = "goodfire"
+     MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility
+
+     # verify we have the package
+     try:
+         import goodfire  # noqa: F401
+     except ImportError:
+         raise pip_dependency_error(FEATURE, [PACKAGE])
+
+     # verify version
+     verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
+
+     # in the clear
+     from .goodfire import GoodfireAPI
+
+     return GoodfireAPI
+
+
  def validate_openai_client(feature: str) -> None:
      FEATURE = feature
      PACKAGE = "openai"