inspect-ai 0.3.59__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +0 -7
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +11 -11
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -11
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +2 -0
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +68 -63
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/goodfire.py

```diff
@@ -0,0 +1,248 @@
+import os
+from typing import Any, List, Literal, get_args
+
+from goodfire import AsyncClient
+from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage
+from goodfire.api.exceptions import InvalidRequestException, RateLimitException
+from goodfire.variants.variants import SUPPORTED_MODELS, Variant
+from typing_extensions import override
+
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from .._chat_message import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ChatMessageSystem,
+    ChatMessageTool,
+    ChatMessageUser,
+)
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+# Constants
+GOODFIRE_API_KEY = "GOODFIRE_API_KEY"
+DEFAULT_BASE_URL = "https://api.goodfire.ai"
+DEFAULT_MAX_TOKENS = 4096
+DEFAULT_TEMPERATURE = 1.0  # Standard sampling temperature (baseline)
+DEFAULT_TOP_P = 1.0  # No nucleus sampling truncation (baseline)
+
+
+class GoodfireAPI(ModelAPI):
+    """Goodfire API provider.
+
+    This provider implements the Goodfire API for LLM inference. It supports:
+    - Chat completions with standard message formats
+    - Basic parameter controls (temperature, top_p, etc.)
+    - Usage statistics tracking
+    - Stop reason handling
+
+    Does not currently support:
+    - Tool calls
+    - Feature analysis
+    - Streaming responses
+
+    Known limitations:
+    - Limited role support (system/user/assistant only)
+    - Tool messages converted to user messages
+    """
+
+    client: AsyncClient
+    variant: Variant
+    model_args: dict[str, Any]
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+        """Initialize the Goodfire API provider.
+
+        Args:
+            model_name: Name of the model to use
+            base_url: Optional custom API base URL
+            api_key: Optional API key (will check env vars if not provided)
+            config: Generation config options
+            **model_args: Additional arguments passed to the API
+        """
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[GOODFIRE_API_KEY],
+            config=config,
+        )
+
+        # resolve api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(GOODFIRE_API_KEY)
+        if not self.api_key:
+            raise environment_prerequisite_error("Goodfire", GOODFIRE_API_KEY)
+
+        # Validate model name against supported models
+        supported_models = list(get_args(SUPPORTED_MODELS))
+        if self.model_name not in supported_models:
+            raise ValueError(
+                f"Model {self.model_name} not supported. Supported models: {supported_models}"
+            )
+
+        # Initialize client with minimal configuration
+        base_url_val = model_base_url(base_url, "GOODFIRE_BASE_URL")
+        assert isinstance(base_url_val, str) or base_url_val is None
+
+        # Store model args for use in generate
+        self.model_args = model_args
+
+        self.client = AsyncClient(
+            api_key=self.api_key,
+            base_url=base_url_val or DEFAULT_BASE_URL,
+        )
+
+        # Initialize variant directly with model name
+        self.variant = Variant(self.model_name)  # type: ignore
+
+    def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
+        """Convert an Inspect message to a Goodfire message format.
+
+        Args:
+            message: The message to convert
+
+        Returns:
+            The converted message in Goodfire format
+
+        Raises:
+            ValueError: If the message type is unknown
+        """
+        role: Literal["system", "user", "assistant"] = "user"
+        if isinstance(message, ChatMessageSystem):
+            role = "system"
+        elif isinstance(message, ChatMessageUser):
+            role = "user"
+        elif isinstance(message, ChatMessageAssistant):
+            role = "assistant"
+        elif isinstance(message, ChatMessageTool):
+            role = "user"  # Convert tool messages to user messages
+        else:
+            raise ValueError(f"Unknown message type: {type(message)}")
+
+        content = str(message.content)
+        if isinstance(message, ChatMessageTool):
+            content = f"Tool {message.function}: {content}"
+
+        return GoodfireChatMessage(role=role, content=content)
+
+    def handle_error(self, ex: Exception) -> ModelOutput | Exception:
+        """Handle only errors that need special treatment for retry logic or model limits."""
+        # Handle token/context length errors
+        if isinstance(ex, InvalidRequestException):
+            error_msg = str(ex).lower()
+            if "context length" in error_msg or "max tokens" in error_msg:
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=str(ex),
+                    stop_reason="model_length",
+                    error=error_msg,
+                )
+
+        # Let all other errors propagate
+        return ex
+
+    @override
+    def is_rate_limit(self, ex: BaseException) -> bool:
+        """Check if exception is due to rate limiting."""
+        return isinstance(ex, RateLimitException)
+
+    @override
+    def connection_key(self) -> str:
+        """Return key for connection pooling."""
+        return f"goodfire:{self.api_key}"
+
+    @override
+    def max_tokens(self) -> int | None:
+        """Return maximum tokens supported by model."""
+        return DEFAULT_MAX_TOKENS  # Let Goodfire's Variant handle model-specific limits
+
+    async def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+        *,
+        cache: bool = True,
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
+        """Generate output from the model."""
+        # Convert messages and prepare request params
+        messages = [self._to_goodfire_message(msg) for msg in input]
+        # Build request parameters with type hints
+        params: dict[str, Any] = {
+            "model": self.variant.base_model,  # Use base_model instead of stringifying the Variant
+            "messages": messages,
+            "max_completion_tokens": int(config.max_tokens)
+            if config.max_tokens
+            else DEFAULT_MAX_TOKENS,
+            "stream": False,
+        }
+
+        # Add generation parameters from config if not in model_args
+        if "temperature" not in self.model_args and config.temperature is not None:
+            params["temperature"] = float(config.temperature)
+        elif "temperature" not in self.model_args:
+            params["temperature"] = DEFAULT_TEMPERATURE
+
+        if "top_p" not in self.model_args and config.top_p is not None:
+            params["top_p"] = float(config.top_p)
+        elif "top_p" not in self.model_args:
+            params["top_p"] = DEFAULT_TOP_P
+
+        # Add any additional model args (highest priority)
+        api_params = {
+            k: v
+            for k, v in self.model_args.items()
+            if k not in ["api_key", "base_url", "model_args"]
+        }
+        params.update(api_params)
+
+        try:
+            # Use native async client
+            response = await self.client.chat.completions.create(**params)
+            response_dict = response.model_dump()
+
+            output = ModelOutput(
+                model=self.model_name,
+                choices=[
+                    ChatCompletionChoice(
+                        message=ChatMessageAssistant(
+                            content=response_dict["choices"][0]["message"]["content"]
+                        ),
+                        stop_reason="stop",
+                    )
+                ],
+                usage=ModelUsage(**response_dict["usage"])
+                if "usage" in response_dict
+                else None,
+            )
+            model_call = ModelCall.create(request=params, response=response_dict)
+            return (output, model_call)
+        except Exception as ex:
+            result = self.handle_error(ex)
+            model_call = ModelCall.create(
+                request=params,
+                response={},  # Empty response for error case
+            )
+            return (result, model_call)
+
+    @property
+    def name(self) -> str:
+        """Get provider name."""
+        return "goodfire"
```
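A minimal usage sketch for the new provider: the model id below is illustrative and must appear in goodfire's `SUPPORTED_MODELS`, the `goodfire` package must be installed, and `GOODFIRE_API_KEY` must be set.

```python
import asyncio

from inspect_ai.model import get_model


async def main() -> None:
    # the "goodfire/" prefix routes to GoodfireAPI via the
    # @modelapi("goodfire") registration added in providers.py (see below)
    model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")
    output = await model.generate("Say hello in one word.")
    print(output.completion)


asyncio.run(main())
```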
inspect_ai/model/_providers/groq.py

```diff
@@ -27,6 +27,7 @@ from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -37,12 +38,15 @@ from .._chat_message import (
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import
-
+from .._model_output import (
+    ChatCompletionChoice,
+    ModelOutput,
+    ModelUsage,
     as_stop_reason,
+)
+from .util import (
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )
 
 GROQ_API_KEY = "GROQ_API_KEY"
```
inspect_ai/model/_providers/hf.py

```diff
@@ -150,6 +150,12 @@ class HuggingFaceAPI(ModelAPI):
             kwargs["output_logits"] = config.logprobs
         if "return_dict_in_generate" in kwargs:
             assert kwargs["return_dict_in_generate"]
+        if config.stop_seqs is not None:
+            from transformers.generation import StopStringCriteria  # type: ignore
+
+            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+            kwargs["stopping_criteria"] = stopping_criteria
+
         kwargs["return_dict_in_generate"] = True
         generator = functools.partial(self.model.generate, **kwargs)
 
```
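The hunk above maps Inspect's `stop_seqs` generation option onto transformers' stopping criteria. A standalone sketch of the same mechanism, assuming a transformers version that ships `StopStringCriteria` (the model and stop strings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteriaList, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
# generation halts as soon as any stop string appears in the decoded output
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    stopping_criteria=StoppingCriteriaList(
        [StopStringCriteria(tokenizer, ["\n\n", "Q:"])]
    ),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```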
inspect_ai/model/_providers/mistral.py

```diff
@@ -46,6 +46,7 @@ from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
+from .._call_tools import parse_tool_call
 from .._chat_message import (
     ChatMessage,
     ChatMessageAssistant,
@@ -59,7 +60,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
-from .util import environment_prerequisite_error, model_base_url
+from .util import environment_prerequisite_error, model_base_url
 
 AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY"
 AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY"
```
inspect_ai/model/_providers/openai.py

```diff
@@ -1,4 +1,3 @@
-import json
 import os
 from logging import getLogger
 from typing import Any
@@ -15,51 +14,39 @@ from openai import (
 from openai._types import NOT_GIVEN
 from openai.types.chat import (
     ChatCompletion,
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionContentPartImageParam,
-    ChatCompletionContentPartInputAudioParam,
-    ChatCompletionContentPartParam,
-    ChatCompletionContentPartTextParam,
-    ChatCompletionDeveloperMessageParam,
-    ChatCompletionMessage,
-    ChatCompletionMessageParam,
-    ChatCompletionMessageToolCallParam,
-    ChatCompletionNamedToolChoiceParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolChoiceOptionParam,
-    ChatCompletionToolMessageParam,
-    ChatCompletionToolParam,
-    ChatCompletionUserMessageParam,
 )
-from openai.types.shared_params.function_definition import FunctionDefinition
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
-from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai.
-from inspect_ai.tool import
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.tool import ToolChoice, ToolInfo
 
-from .._chat_message import ChatMessage
+from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import (
     ChatCompletionChoice,
-    Logprobs,
     ModelOutput,
     ModelUsage,
     StopReason,
 )
+from .._openai import (
+    is_o1,
+    is_o1_full,
+    is_o1_mini,
+    is_o1_preview,
+    openai_chat_messages,
+    openai_chat_tool_choice,
+    openai_chat_tools,
+)
 from .openai_o1 import generate_o1
 from .util import (
-    as_stop_reason,
     environment_prerequisite_error,
     model_base_url,
-    parse_tool_call,
 )
 
 logger = getLogger(__name__)
@@ -87,20 +74,22 @@ class OpenAIAPI(ModelAPI):
             config=config,
         )
 
-        #
-
-
-
-
-
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
 
         # resolve api_key
         if not self.api_key:
             self.api_key = os.environ.get(
                 AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
            )
-
-
+            # backward compatibility for when env vars determined service
+            if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
+                self.service = "azure"
         else:
             self.api_key = os.environ.get(OPENAI_API_KEY, None)
             if not self.api_key:
```
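The `@@ -87` hunk above replaces environment-variable-based service detection with a service prefix parsed out of the model name itself. A small sketch of the parsing rule (the names and values are hypothetical):

```python
def split_service(model_name: str) -> tuple[str | None, str]:
    # mirrors the parsing above: "azure/my-deployment" -> ("azure",
    # "my-deployment"); a bare model name carries no service prefix
    parts = model_name.split("/")
    if len(parts) > 1:
        return parts[0], "/".join(parts[1:])
    return None, model_name


assert split_service("azure/my-deployment") == ("azure", "my-deployment")
assert split_service("gpt-4o") == (None, "gpt-4o")
```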
```diff
@@ -113,7 +102,7 @@ class OpenAIAPI(ModelAPI):
         )
 
         # azure client
-        if is_azure:
+        if self.is_azure():
             # resolve base_url
             base_url = model_base_url(
                 base_url,
@@ -148,17 +137,20 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )
 
+    def is_azure(self) -> bool:
+        return self.service == "azure"
+
     def is_o1(self) -> bool:
-        return self.model_name
+        return is_o1(self.model_name)
 
     def is_o1_full(self) -> bool:
-        return
+        return is_o1_full(self.model_name)
 
     def is_o1_mini(self) -> bool:
-        return self.model_name
+        return is_o1_mini(self.model_name)
 
     def is_o1_preview(self) -> bool:
-        return self.model_name
+        return is_o1_preview(self.model_name)
 
     async def generate(
         self,
@@ -198,9 +190,11 @@ class OpenAIAPI(ModelAPI):
 
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await
-            tools=
-            tool_choice=
+            messages=await openai_chat_messages(input, self.model_name),
+            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+            tool_choice=openai_chat_tool_choice(tool_choice)
+            if len(tools) > 0
+            else NOT_GIVEN,
             **self.completion_params(config, len(tools) > 0),
         )
 
@@ -237,7 +231,7 @@ class OpenAIAPI(ModelAPI):
         self, response: ChatCompletion, tools: list[ToolInfo]
     ) -> list[ChatCompletionChoice]:
         # adding this as a method so we can override from other classes (e.g together)
-        return
+        return chat_choices_from_openai(response, tools)
 
     @override
     def is_rate_limit(self, ex: BaseException) -> bool:
@@ -327,163 +321,3 @@ class OpenAIAPI(ModelAPI):
         )
     else:
         return e
-
-
-async def as_openai_chat_messages(
-    messages: list[ChatMessage], o1_full: bool
-) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message, o1_full) for message in messages]
-
-
-async def openai_chat_message(
-    message: ChatMessage, o1_full: bool
-) -> ChatCompletionMessageParam:
-    if message.role == "system":
-        if o1_full:
-            return ChatCompletionDeveloperMessageParam(
-                role="developer", content=message.text
-            )
-        else:
-            return ChatCompletionSystemMessageParam(
-                role=message.role, content=message.text
-            )
-    elif message.role == "user":
-        return ChatCompletionUserMessageParam(
-            role=message.role,
-            content=(
-                message.content
-                if isinstance(message.content, str)
-                else [
-                    await as_chat_completion_part(content)
-                    for content in message.content
-                ]
-            ),
-        )
-    elif message.role == "assistant":
-        if message.tool_calls:
-            return ChatCompletionAssistantMessageParam(
-                role=message.role,
-                content=message.text,
-                tool_calls=[chat_tool_call(call) for call in message.tool_calls],
-            )
-        else:
-            return ChatCompletionAssistantMessageParam(
-                role=message.role, content=message.text
-            )
-    elif message.role == "tool":
-        return ChatCompletionToolMessageParam(
-            role=message.role,
-            content=(
-                f"Error: {message.error.message}" if message.error else message.text
-            ),
-            tool_call_id=str(message.tool_call_id),
-        )
-    else:
-        raise ValueError(f"Unexpected message role {message.role}")
-
-
-def chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCallParam:
-    return ChatCompletionMessageToolCallParam(
-        id=tool_call.id,
-        function=dict(
-            name=tool_call.function, arguments=json.dumps(tool_call.arguments)
-        ),
-        type=tool_call.type,
-    )
-
-
-def chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
-    return [chat_tool_param(tool) for tool in tools]
-
-
-def chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
-    function = FunctionDefinition(
-        name=tool.name,
-        description=tool.description,
-        parameters=tool.parameters.model_dump(exclude_none=True),
-    )
-    return ChatCompletionToolParam(type="function", function=function)
-
-
-def chat_tool_choice(tool_choice: ToolChoice) -> ChatCompletionToolChoiceOptionParam:
-    if isinstance(tool_choice, ToolFunction):
-        return ChatCompletionNamedToolChoiceParam(
-            type="function", function=dict(name=tool_choice.name)
-        )
-    # openai supports 'any' via the 'required' keyword
-    elif tool_choice == "any":
-        return "required"
-    else:
-        return tool_choice
-
-
-def chat_tool_calls(
-    message: ChatCompletionMessage, tools: list[ToolInfo]
-) -> list[ToolCall] | None:
-    if message.tool_calls:
-        return [
-            parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
-            for call in message.tool_calls
-        ]
-    else:
-        return None
-
-
-def chat_choices_from_response(
-    response: ChatCompletion, tools: list[ToolInfo]
-) -> list[ChatCompletionChoice]:
-    choices = list(response.choices)
-    choices.sort(key=lambda c: c.index)
-    return [
-        ChatCompletionChoice(
-            message=chat_message_assistant(choice.message, tools),
-            stop_reason=as_stop_reason(choice.finish_reason),
-            logprobs=(
-                Logprobs(**choice.logprobs.model_dump())
-                if choice.logprobs is not None
-                else None
-            ),
-        )
-        for choice in choices
-    ]
-
-
-def chat_message_assistant(
-    message: ChatCompletionMessage, tools: list[ToolInfo]
-) -> ChatMessageAssistant:
-    return ChatMessageAssistant(
-        content=message.content or "",
-        source="generate",
-        tool_calls=chat_tool_calls(message, tools),
-    )
-
-
-async def as_chat_completion_part(
-    content: Content,
-) -> ChatCompletionContentPartParam:
-    if content.type == "text":
-        return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    elif content.type == "image":
-        # API takes URL or base64 encoded file. If it's a remote file or
-        # data URL leave it alone, otherwise encode it
-        image_url = content.image
-        detail = content.detail
-
-        if not is_http_url(image_url):
-            image_url = await file_as_data_uri(image_url)
-
-        return ChatCompletionContentPartImageParam(
-            type="image_url",
-            image_url=dict(url=image_url, detail=detail),
-        )
-    elif content.type == "audio":
-        audio_data = await file_as_data_uri(content.audio)
-
-        return ChatCompletionContentPartInputAudioParam(
-            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
-        )
-
-    else:
-        raise RuntimeError(
-            "Video content is not currently supported by Open AI chat models."
-        )
```
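The helpers deleted above were not dropped: per the file listing, they moved into the new shared `inspect_ai/model/_openai.py` module (+383 lines), where other OpenAI-compatible providers can reuse them. A sketch of calling the relocated message converter, with the signature inferred from the call sites in this diff rather than from documented API:

```python
import asyncio

from inspect_ai.model import ChatMessageSystem, ChatMessageUser
from inspect_ai.model._openai import openai_chat_messages


async def main() -> None:
    messages = [
        ChatMessageSystem(content="You are terse."),
        ChatMessageUser(content="Hello!"),
    ]
    # the second argument is the model name, used to select
    # o1-specific message formatting
    print(await openai_chat_messages(messages, "gpt-4o"))


asyncio.run(main())
```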
inspect_ai/model/_providers/openai_o1.py

```diff
@@ -24,15 +24,13 @@ from inspect_ai.model import (
 )
 from inspect_ai.tool import ToolCall, ToolInfo
 
+from .._call_tools import parse_tool_call, tool_parse_error_message
 from .._model_call import ModelCall
-from .._model_output import ModelUsage, StopReason
+from .._model_output import ModelUsage, StopReason, as_stop_reason
 from .._providers.util import (
     ChatAPIHandler,
     ChatAPIMessage,
-    as_stop_reason,
     chat_api_input,
-    parse_tool_call,
-    tool_parse_error_message,
 )
 
 logger = getLogger(__name__)
```
inspect_ai/model/_providers/providers.py

```diff
@@ -239,6 +239,28 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi("goodfire")
+def goodfire() -> type[ModelAPI]:
+    """Get the Goodfire API provider."""
+    FEATURE = "Goodfire API"
+    PACKAGE = "goodfire"
+    MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility
+
+    # verify we have the package
+    try:
+        import goodfire  # noqa: F401
+    except ImportError:
+        raise pip_dependency_error(FEATURE, [PACKAGE])
+
+    # verify version
+    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
+
+    # in the clear
+    from .goodfire import GoodfireAPI
+
+    return GoodfireAPI
+
+
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
```