inspect-ai 0.3.58__py3-none-any.whl → 0.3.60__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/common.py +3 -1
- inspect_ai/_cli/eval.py +15 -9
- inspect_ai/_display/core/active.py +4 -1
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +0 -5
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +79 -12
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +3 -1
- inspect_ai/_eval/task/results.py +51 -22
- inspect_ai/_eval/task/run.py +47 -13
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25498 -2044
- inspect_ai/_view/www/log-schema.json +32 -2
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +14 -16
- inspect_ai/_view/www/src/Types.mjs +1 -2
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +77 -4
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +4 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +19 -10
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +19 -49
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +38 -26
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +14 -11
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +13 -2
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +13 -9
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +11 -5
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_samples.py +16 -0
- inspect_ai/log/_transcript.py +4 -1
- inspect_ai/model/_call_tools.py +59 -0
- inspect_ai/model/_conversation.py +16 -7
- inspect_ai/model/_generate_config.py +12 -12
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +22 -2
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +152 -55
- inspect_ai/model/_providers/azureai.py +21 -21
- inspect_ai/model/_providers/bedrock.py +37 -40
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/google.py +46 -54
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +13 -12
- inspect_ai/model/_providers/openai.py +51 -218
- inspect_ai/model/_providers/openai_o1.py +11 -12
- inspect_ai/model/_providers/providers.py +23 -1
- inspect_ai/model/_providers/together.py +12 -12
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/model/_providers/vertex.py +1 -4
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +4 -3
- inspect_ai/solver/__init__.py +4 -5
- inspect_ai/solver/_basic_agent.py +1 -1
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +12 -1
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/docker/docker.py +64 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/environment.py +14 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/RECORD +159 -126
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.58.dist-info → inspect_ai-0.3.60.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,3 @@
|
|
1
|
-
import json
|
2
1
|
import os
|
3
2
|
from logging import getLogger
|
4
3
|
from typing import Any
|
@@ -15,51 +14,39 @@ from openai import (
|
|
15
14
|
from openai._types import NOT_GIVEN
|
16
15
|
from openai.types.chat import (
|
17
16
|
ChatCompletion,
|
18
|
-
ChatCompletionAssistantMessageParam,
|
19
|
-
ChatCompletionContentPartImageParam,
|
20
|
-
ChatCompletionContentPartInputAudioParam,
|
21
|
-
ChatCompletionContentPartParam,
|
22
|
-
ChatCompletionContentPartTextParam,
|
23
|
-
ChatCompletionDeveloperMessageParam,
|
24
|
-
ChatCompletionMessage,
|
25
|
-
ChatCompletionMessageParam,
|
26
|
-
ChatCompletionMessageToolCallParam,
|
27
|
-
ChatCompletionNamedToolChoiceParam,
|
28
|
-
ChatCompletionSystemMessageParam,
|
29
|
-
ChatCompletionToolChoiceOptionParam,
|
30
|
-
ChatCompletionToolMessageParam,
|
31
|
-
ChatCompletionToolParam,
|
32
|
-
ChatCompletionUserMessageParam,
|
33
17
|
)
|
34
|
-
from openai.types.shared_params.function_definition import FunctionDefinition
|
35
18
|
from typing_extensions import override
|
36
19
|
|
37
20
|
from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
|
38
|
-
from inspect_ai._util.content import Content
|
39
21
|
from inspect_ai._util.error import PrerequisiteError
|
40
|
-
from inspect_ai._util.images import file_as_data_uri
|
41
22
|
from inspect_ai._util.logger import warn_once
|
42
|
-
from inspect_ai.
|
43
|
-
from inspect_ai.tool import
|
23
|
+
from inspect_ai.model._openai import chat_choices_from_openai
|
24
|
+
from inspect_ai.tool import ToolChoice, ToolInfo
|
44
25
|
|
45
|
-
from .._chat_message import ChatMessage
|
26
|
+
from .._chat_message import ChatMessage
|
46
27
|
from .._generate_config import GenerateConfig
|
47
28
|
from .._image import image_url_filter
|
48
29
|
from .._model import ModelAPI
|
49
30
|
from .._model_call import ModelCall
|
50
31
|
from .._model_output import (
|
51
32
|
ChatCompletionChoice,
|
52
|
-
Logprobs,
|
53
33
|
ModelOutput,
|
54
34
|
ModelUsage,
|
55
35
|
StopReason,
|
56
36
|
)
|
37
|
+
from .._openai import (
|
38
|
+
is_o1,
|
39
|
+
is_o1_full,
|
40
|
+
is_o1_mini,
|
41
|
+
is_o1_preview,
|
42
|
+
openai_chat_messages,
|
43
|
+
openai_chat_tool_choice,
|
44
|
+
openai_chat_tools,
|
45
|
+
)
|
57
46
|
from .openai_o1 import generate_o1
|
58
47
|
from .util import (
|
59
|
-
as_stop_reason,
|
60
48
|
environment_prerequisite_error,
|
61
49
|
model_base_url,
|
62
|
-
parse_tool_call,
|
63
50
|
)
|
64
51
|
|
65
52
|
logger = getLogger(__name__)
|
@@ -87,20 +74,22 @@ class OpenAIAPI(ModelAPI):
|
|
87
74
|
config=config,
|
88
75
|
)
|
89
76
|
|
90
|
-
#
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
77
|
+
# extract any service prefix from model name
|
78
|
+
parts = model_name.split("/")
|
79
|
+
if len(parts) > 1:
|
80
|
+
self.service: str | None = parts[0]
|
81
|
+
model_name = "/".join(parts[1:])
|
82
|
+
else:
|
83
|
+
self.service = None
|
96
84
|
|
97
85
|
# resolve api_key
|
98
86
|
if not self.api_key:
|
99
87
|
self.api_key = os.environ.get(
|
100
88
|
AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
|
101
89
|
)
|
102
|
-
|
103
|
-
|
90
|
+
# backward compatibility for when env vars determined service
|
91
|
+
if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
|
92
|
+
self.service = "azure"
|
104
93
|
else:
|
105
94
|
self.api_key = os.environ.get(OPENAI_API_KEY, None)
|
106
95
|
if not self.api_key:
|
@@ -113,7 +102,7 @@ class OpenAIAPI(ModelAPI):
|
|
113
102
|
)
|
114
103
|
|
115
104
|
# azure client
|
116
|
-
if is_azure:
|
105
|
+
if self.is_azure():
|
117
106
|
# resolve base_url
|
118
107
|
base_url = model_base_url(
|
119
108
|
base_url,
|
@@ -148,17 +137,20 @@ class OpenAIAPI(ModelAPI):
|
|
148
137
|
**model_args,
|
149
138
|
)
|
150
139
|
|
140
|
+
def is_azure(self) -> bool:
|
141
|
+
return self.service == "azure"
|
142
|
+
|
151
143
|
def is_o1(self) -> bool:
|
152
|
-
return self.model_name
|
144
|
+
return is_o1(self.model_name)
|
153
145
|
|
154
146
|
def is_o1_full(self) -> bool:
|
155
|
-
return
|
147
|
+
return is_o1_full(self.model_name)
|
156
148
|
|
157
149
|
def is_o1_mini(self) -> bool:
|
158
|
-
return self.model_name
|
150
|
+
return is_o1_mini(self.model_name)
|
159
151
|
|
160
152
|
def is_o1_preview(self) -> bool:
|
161
|
-
return self.model_name
|
153
|
+
return is_o1_preview(self.model_name)
|
162
154
|
|
163
155
|
async def generate(
|
164
156
|
self,
|
@@ -166,7 +158,7 @@ class OpenAIAPI(ModelAPI):
|
|
166
158
|
tools: list[ToolInfo],
|
167
159
|
tool_choice: ToolChoice,
|
168
160
|
config: GenerateConfig,
|
169
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
161
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
170
162
|
# short-circuit to call o1- models that are text only
|
171
163
|
if self.is_o1_preview() or self.is_o1_mini():
|
172
164
|
return await generate_o1(
|
@@ -198,9 +190,11 @@ class OpenAIAPI(ModelAPI):
|
|
198
190
|
|
199
191
|
# prepare request (we do this so we can log the ModelCall)
|
200
192
|
request = dict(
|
201
|
-
messages=await
|
202
|
-
tools=
|
203
|
-
tool_choice=
|
193
|
+
messages=await openai_chat_messages(input, self.model_name),
|
194
|
+
tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
|
195
|
+
tool_choice=openai_chat_tool_choice(tool_choice)
|
196
|
+
if len(tools) > 0
|
197
|
+
else NOT_GIVEN,
|
204
198
|
**self.completion_params(config, len(tools) > 0),
|
205
199
|
)
|
206
200
|
|
@@ -237,7 +231,7 @@ class OpenAIAPI(ModelAPI):
|
|
237
231
|
self, response: ChatCompletion, tools: list[ToolInfo]
|
238
232
|
) -> list[ChatCompletionChoice]:
|
239
233
|
# adding this as a method so we can override from other classes (e.g together)
|
240
|
-
return
|
234
|
+
return chat_choices_from_openai(response, tools)
|
241
235
|
|
242
236
|
@override
|
243
237
|
def is_rate_limit(self, ex: BaseException) -> bool:
|
@@ -307,184 +301,23 @@ class OpenAIAPI(ModelAPI):
|
|
307
301
|
return params
|
308
302
|
|
309
303
|
# convert some well known bad request errors into ModelOutput
|
310
|
-
def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
content = e.message
|
304
|
+
def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
|
305
|
+
# extract message
|
306
|
+
if isinstance(e.body, dict) and "message" in e.body.keys():
|
307
|
+
content = str(e.body.get("message"))
|
308
|
+
else:
|
309
|
+
content = e.message
|
317
310
|
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
stop_reason = "unknown"
|
311
|
+
# narrow stop_reason
|
312
|
+
stop_reason: StopReason | None = None
|
313
|
+
if e.code == "context_length_exceeded":
|
314
|
+
stop_reason = "model_length"
|
315
|
+
elif e.code == "invalid_prompt":
|
316
|
+
stop_reason = "content_filter"
|
325
317
|
|
318
|
+
if stop_reason:
|
326
319
|
return ModelOutput.from_content(
|
327
320
|
model=self.model_name, content=content, stop_reason=stop_reason
|
328
321
|
)
|
329
322
|
else:
|
330
|
-
|
331
|
-
|
332
|
-
|
333
|
-
async def as_openai_chat_messages(
|
334
|
-
messages: list[ChatMessage], o1_full: bool
|
335
|
-
) -> list[ChatCompletionMessageParam]:
|
336
|
-
return [await openai_chat_message(message, o1_full) for message in messages]
|
337
|
-
|
338
|
-
|
339
|
-
async def openai_chat_message(
|
340
|
-
message: ChatMessage, o1_full: bool
|
341
|
-
) -> ChatCompletionMessageParam:
|
342
|
-
if message.role == "system":
|
343
|
-
if o1_full:
|
344
|
-
return ChatCompletionDeveloperMessageParam(
|
345
|
-
role="developer", content=message.text
|
346
|
-
)
|
347
|
-
else:
|
348
|
-
return ChatCompletionSystemMessageParam(
|
349
|
-
role=message.role, content=message.text
|
350
|
-
)
|
351
|
-
elif message.role == "user":
|
352
|
-
return ChatCompletionUserMessageParam(
|
353
|
-
role=message.role,
|
354
|
-
content=(
|
355
|
-
message.content
|
356
|
-
if isinstance(message.content, str)
|
357
|
-
else [
|
358
|
-
await as_chat_completion_part(content)
|
359
|
-
for content in message.content
|
360
|
-
]
|
361
|
-
),
|
362
|
-
)
|
363
|
-
elif message.role == "assistant":
|
364
|
-
if message.tool_calls:
|
365
|
-
return ChatCompletionAssistantMessageParam(
|
366
|
-
role=message.role,
|
367
|
-
content=message.text,
|
368
|
-
tool_calls=[chat_tool_call(call) for call in message.tool_calls],
|
369
|
-
)
|
370
|
-
else:
|
371
|
-
return ChatCompletionAssistantMessageParam(
|
372
|
-
role=message.role, content=message.text
|
373
|
-
)
|
374
|
-
elif message.role == "tool":
|
375
|
-
return ChatCompletionToolMessageParam(
|
376
|
-
role=message.role,
|
377
|
-
content=(
|
378
|
-
f"Error: {message.error.message}" if message.error else message.text
|
379
|
-
),
|
380
|
-
tool_call_id=str(message.tool_call_id),
|
381
|
-
)
|
382
|
-
else:
|
383
|
-
raise ValueError(f"Unexpected message role {message.role}")
|
384
|
-
|
385
|
-
|
386
|
-
def chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCallParam:
|
387
|
-
return ChatCompletionMessageToolCallParam(
|
388
|
-
id=tool_call.id,
|
389
|
-
function=dict(
|
390
|
-
name=tool_call.function, arguments=json.dumps(tool_call.arguments)
|
391
|
-
),
|
392
|
-
type=tool_call.type,
|
393
|
-
)
|
394
|
-
|
395
|
-
|
396
|
-
def chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
|
397
|
-
return [chat_tool_param(tool) for tool in tools]
|
398
|
-
|
399
|
-
|
400
|
-
def chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
|
401
|
-
function = FunctionDefinition(
|
402
|
-
name=tool.name,
|
403
|
-
description=tool.description,
|
404
|
-
parameters=tool.parameters.model_dump(exclude_none=True),
|
405
|
-
)
|
406
|
-
return ChatCompletionToolParam(type="function", function=function)
|
407
|
-
|
408
|
-
|
409
|
-
def chat_tool_choice(tool_choice: ToolChoice) -> ChatCompletionToolChoiceOptionParam:
|
410
|
-
if isinstance(tool_choice, ToolFunction):
|
411
|
-
return ChatCompletionNamedToolChoiceParam(
|
412
|
-
type="function", function=dict(name=tool_choice.name)
|
413
|
-
)
|
414
|
-
# openai supports 'any' via the 'required' keyword
|
415
|
-
elif tool_choice == "any":
|
416
|
-
return "required"
|
417
|
-
else:
|
418
|
-
return tool_choice
|
419
|
-
|
420
|
-
|
421
|
-
def chat_tool_calls(
|
422
|
-
message: ChatCompletionMessage, tools: list[ToolInfo]
|
423
|
-
) -> list[ToolCall] | None:
|
424
|
-
if message.tool_calls:
|
425
|
-
return [
|
426
|
-
parse_tool_call(call.id, call.function.name, call.function.arguments, tools)
|
427
|
-
for call in message.tool_calls
|
428
|
-
]
|
429
|
-
else:
|
430
|
-
return None
|
431
|
-
|
432
|
-
|
433
|
-
def chat_choices_from_response(
|
434
|
-
response: ChatCompletion, tools: list[ToolInfo]
|
435
|
-
) -> list[ChatCompletionChoice]:
|
436
|
-
choices = list(response.choices)
|
437
|
-
choices.sort(key=lambda c: c.index)
|
438
|
-
return [
|
439
|
-
ChatCompletionChoice(
|
440
|
-
message=chat_message_assistant(choice.message, tools),
|
441
|
-
stop_reason=as_stop_reason(choice.finish_reason),
|
442
|
-
logprobs=(
|
443
|
-
Logprobs(**choice.logprobs.model_dump())
|
444
|
-
if choice.logprobs is not None
|
445
|
-
else None
|
446
|
-
),
|
447
|
-
)
|
448
|
-
for choice in choices
|
449
|
-
]
|
450
|
-
|
451
|
-
|
452
|
-
def chat_message_assistant(
|
453
|
-
message: ChatCompletionMessage, tools: list[ToolInfo]
|
454
|
-
) -> ChatMessageAssistant:
|
455
|
-
return ChatMessageAssistant(
|
456
|
-
content=message.content or "",
|
457
|
-
source="generate",
|
458
|
-
tool_calls=chat_tool_calls(message, tools),
|
459
|
-
)
|
460
|
-
|
461
|
-
|
462
|
-
async def as_chat_completion_part(
|
463
|
-
content: Content,
|
464
|
-
) -> ChatCompletionContentPartParam:
|
465
|
-
if content.type == "text":
|
466
|
-
return ChatCompletionContentPartTextParam(type="text", text=content.text)
|
467
|
-
elif content.type == "image":
|
468
|
-
# API takes URL or base64 encoded file. If it's a remote file or
|
469
|
-
# data URL leave it alone, otherwise encode it
|
470
|
-
image_url = content.image
|
471
|
-
detail = content.detail
|
472
|
-
|
473
|
-
if not is_http_url(image_url):
|
474
|
-
image_url = await file_as_data_uri(image_url)
|
475
|
-
|
476
|
-
return ChatCompletionContentPartImageParam(
|
477
|
-
type="image_url",
|
478
|
-
image_url=dict(url=image_url, detail=detail),
|
479
|
-
)
|
480
|
-
elif content.type == "audio":
|
481
|
-
audio_data = await file_as_data_uri(content.audio)
|
482
|
-
|
483
|
-
return ChatCompletionContentPartInputAudioParam(
|
484
|
-
type="input_audio", input_audio=dict(data=audio_data, format=content.format)
|
485
|
-
)
|
486
|
-
|
487
|
-
else:
|
488
|
-
raise RuntimeError(
|
489
|
-
"Video content is not currently supported by Open AI chat models."
|
490
|
-
)
|
323
|
+
return e
|
@@ -24,15 +24,13 @@ from inspect_ai.model import (
|
|
24
24
|
)
|
25
25
|
from inspect_ai.tool import ToolCall, ToolInfo
|
26
26
|
|
27
|
+
from .._call_tools import parse_tool_call, tool_parse_error_message
|
27
28
|
from .._model_call import ModelCall
|
28
|
-
from .._model_output import ModelUsage, StopReason
|
29
|
+
from .._model_output import ModelUsage, StopReason, as_stop_reason
|
29
30
|
from .._providers.util import (
|
30
31
|
ChatAPIHandler,
|
31
32
|
ChatAPIMessage,
|
32
|
-
as_stop_reason,
|
33
33
|
chat_api_input,
|
34
|
-
parse_tool_call,
|
35
|
-
tool_parse_error_message,
|
36
34
|
)
|
37
35
|
|
38
36
|
logger = getLogger(__name__)
|
@@ -44,7 +42,7 @@ async def generate_o1(
|
|
44
42
|
input: list[ChatMessage],
|
45
43
|
tools: list[ToolInfo],
|
46
44
|
**params: Any,
|
47
|
-
) -> ModelOutput | tuple[ModelOutput, ModelCall]:
|
45
|
+
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
48
46
|
# create chatapi handler
|
49
47
|
handler = O1PreviewChatAPIHandler()
|
50
48
|
|
@@ -82,17 +80,18 @@ async def generate_o1(
|
|
82
80
|
), model_call()
|
83
81
|
|
84
82
|
|
85
|
-
def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
|
83
|
+
def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput | Exception:
|
86
84
|
if ex.code == "context_length_exceeded":
|
87
|
-
stop_reason: StopReason = "model_length"
|
85
|
+
stop_reason: StopReason | None = "model_length"
|
88
86
|
elif ex.code == "invalid_prompt":
|
89
87
|
stop_reason = "content_filter"
|
90
|
-
else:
|
91
|
-
stop_reason = "unknown"
|
92
88
|
|
93
|
-
|
94
|
-
|
95
|
-
|
89
|
+
if stop_reason:
|
90
|
+
return ModelOutput.from_content(
|
91
|
+
model=model, content=str(ex), stop_reason=stop_reason
|
92
|
+
)
|
93
|
+
else:
|
94
|
+
return ex
|
96
95
|
|
97
96
|
|
98
97
|
def chat_messages(
|
@@ -94,7 +94,7 @@ def vertex() -> type[ModelAPI]:
|
|
94
94
|
def google() -> type[ModelAPI]:
|
95
95
|
FEATURE = "Google API"
|
96
96
|
PACKAGE = "google-generativeai"
|
97
|
-
MIN_VERSION = "0.8.
|
97
|
+
MIN_VERSION = "0.8.4"
|
98
98
|
|
99
99
|
# workaround log spam
|
100
100
|
# https://github.com/ray-project/ray/issues/24917
|
@@ -239,6 +239,28 @@ def mockllm() -> type[ModelAPI]:
|
|
239
239
|
return MockLLM
|
240
240
|
|
241
241
|
|
242
|
+
@modelapi("goodfire")
|
243
|
+
def goodfire() -> type[ModelAPI]:
|
244
|
+
"""Get the Goodfire API provider."""
|
245
|
+
FEATURE = "Goodfire API"
|
246
|
+
PACKAGE = "goodfire"
|
247
|
+
MIN_VERSION = "0.3.4" # Support for newer Llama models and OpenAI compatibility
|
248
|
+
|
249
|
+
# verify we have the package
|
250
|
+
try:
|
251
|
+
import goodfire # noqa: F401
|
252
|
+
except ImportError:
|
253
|
+
raise pip_dependency_error(FEATURE, [PACKAGE])
|
254
|
+
|
255
|
+
# verify version
|
256
|
+
verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
|
257
|
+
|
258
|
+
# in the clear
|
259
|
+
from .goodfire import GoodfireAPI
|
260
|
+
|
261
|
+
return GoodfireAPI
|
262
|
+
|
263
|
+
|
242
264
|
def validate_openai_client(feature: str) -> None:
|
243
265
|
FEATURE = feature
|
244
266
|
PACKAGE = "openai"
|
@@ -24,13 +24,13 @@ from .._model_output import (
|
|
24
24
|
ModelOutput,
|
25
25
|
ModelUsage,
|
26
26
|
StopReason,
|
27
|
+
as_stop_reason,
|
27
28
|
)
|
29
|
+
from .._openai import chat_message_assistant_from_openai
|
28
30
|
from .openai import (
|
29
31
|
OpenAIAPI,
|
30
|
-
chat_message_assistant,
|
31
32
|
)
|
32
33
|
from .util import (
|
33
|
-
as_stop_reason,
|
34
34
|
chat_api_input,
|
35
35
|
chat_api_request,
|
36
36
|
environment_prerequisite_error,
|
@@ -68,7 +68,7 @@ def chat_choices_from_response_together(
|
|
68
68
|
logprobs_models.append(Logprobs(content=logprobs_sequence))
|
69
69
|
return [
|
70
70
|
ChatCompletionChoice(
|
71
|
-
message=
|
71
|
+
message=chat_message_assistant_from_openai(choice.message, tools),
|
72
72
|
stop_reason=as_stop_reason(choice.finish_reason),
|
73
73
|
logprobs=logprobs,
|
74
74
|
)
|
@@ -99,22 +99,22 @@ class TogetherAIAPI(OpenAIAPI):
|
|
99
99
|
|
100
100
|
# Together uses a default of 512 so we bump it up
|
101
101
|
@override
|
102
|
-
def max_tokens(self) -> int:
|
102
|
+
def max_tokens(self) -> int | None:
|
103
103
|
return DEFAULT_MAX_TOKENS
|
104
104
|
|
105
105
|
@override
|
106
|
-
def handle_bad_request(self, ex: BadRequestError) -> ModelOutput:
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
106
|
+
def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
|
107
|
+
response = ex.response.json()
|
108
|
+
if "error" in response and "message" in response.get("error"):
|
109
|
+
content = response.get("error").get("message")
|
110
|
+
else:
|
111
|
+
content = str(response)
|
112
|
+
if "max_new_tokens" in ex.message:
|
113
113
|
return ModelOutput.from_content(
|
114
114
|
model=self.model_name, content=content, stop_reason="model_length"
|
115
115
|
)
|
116
116
|
else:
|
117
|
-
|
117
|
+
return ex
|
118
118
|
|
119
119
|
# Together has a slightly different logprobs structure to OpenAI, so we need to remap it.
|
120
120
|
def _chat_choices_from_response(
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
2
|
+
from ..._model_output import as_stop_reason
|
1
3
|
from .chatapi import (
|
2
4
|
ChatAPIHandler,
|
3
5
|
ChatAPIMessage,
|
@@ -8,11 +10,8 @@ from .chatapi import (
|
|
8
10
|
from .hf_handler import HFHandler
|
9
11
|
from .llama31 import Llama31Handler
|
10
12
|
from .util import (
|
11
|
-
as_stop_reason,
|
12
13
|
environment_prerequisite_error,
|
13
14
|
model_base_url,
|
14
|
-
parse_tool_call,
|
15
|
-
tool_parse_error_message,
|
16
15
|
)
|
17
16
|
|
18
17
|
__all__ = [
|
@@ -8,9 +8,9 @@ from typing_extensions import override
|
|
8
8
|
from inspect_ai.tool._tool_call import ToolCall
|
9
9
|
from inspect_ai.tool._tool_info import ToolInfo
|
10
10
|
|
11
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
11
12
|
from ..._chat_message import ChatMessageAssistant
|
12
13
|
from .chatapi import ChatAPIHandler
|
13
|
-
from .util import parse_tool_call, tool_parse_error_message
|
14
14
|
|
15
15
|
logger = getLogger(__name__)
|
16
16
|
|
@@ -9,6 +9,7 @@ from typing_extensions import override
|
|
9
9
|
from inspect_ai.tool._tool_call import ToolCall
|
10
10
|
from inspect_ai.tool._tool_info import ToolInfo
|
11
11
|
|
12
|
+
from ..._call_tools import parse_tool_call, tool_parse_error_message
|
12
13
|
from ..._chat_message import (
|
13
14
|
ChatMessage,
|
14
15
|
ChatMessageAssistant,
|
@@ -16,7 +17,6 @@ from ..._chat_message import (
|
|
16
17
|
ChatMessageTool,
|
17
18
|
)
|
18
19
|
from .chatapi import ChatAPIHandler, ChatAPIMessage
|
19
|
-
from .util import parse_tool_call, tool_parse_error_message
|
20
20
|
|
21
21
|
logger = getLogger(__name__)
|
22
22
|
|
@@ -1,34 +1,11 @@
|
|
1
|
-
import json
|
2
1
|
import os
|
3
2
|
from logging import getLogger
|
4
|
-
from typing import Any
|
5
|
-
|
6
|
-
import yaml
|
7
3
|
|
8
4
|
from inspect_ai._util.error import PrerequisiteError
|
9
|
-
from inspect_ai.tool._tool_call import ToolCall
|
10
|
-
from inspect_ai.tool._tool_info import ToolInfo
|
11
|
-
|
12
|
-
from ..._model_output import StopReason
|
13
5
|
|
14
6
|
logger = getLogger(__name__)
|
15
7
|
|
16
8
|
|
17
|
-
def as_stop_reason(reason: str | None) -> StopReason:
|
18
|
-
"""Encode common reason strings into standard StopReason."""
|
19
|
-
match reason:
|
20
|
-
case "stop" | "eos":
|
21
|
-
return "stop"
|
22
|
-
case "length":
|
23
|
-
return "max_tokens"
|
24
|
-
case "tool_calls" | "function_call":
|
25
|
-
return "tool_calls"
|
26
|
-
case "content_filter" | "model_length" | "max_tokens":
|
27
|
-
return reason
|
28
|
-
case _:
|
29
|
-
return "unknown"
|
30
|
-
|
31
|
-
|
32
9
|
def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | None:
|
33
10
|
if base_url:
|
34
11
|
return base_url
|
@@ -44,59 +21,6 @@ def model_base_url(base_url: str | None, env_vars: str | list[str]) -> str | Non
|
|
44
21
|
return os.getenv("INSPECT_EVAL_MODEL_BASE_URL", None)
|
45
22
|
|
46
23
|
|
47
|
-
def tool_parse_error_message(arguments: str, ex: Exception) -> str:
|
48
|
-
return f"Error parsing the following tool call arguments:\n\n{arguments}\n\nError details: {ex}"
|
49
|
-
|
50
|
-
|
51
|
-
def parse_tool_call(
|
52
|
-
id: str, function: str, arguments: str, tools: list[ToolInfo]
|
53
|
-
) -> ToolCall:
|
54
|
-
error: str | None = None
|
55
|
-
arguments_dict: dict[str, Any] = {}
|
56
|
-
|
57
|
-
def report_parse_error(ex: Exception) -> None:
|
58
|
-
nonlocal error
|
59
|
-
error = tool_parse_error_message(arguments, ex)
|
60
|
-
logger.info(error)
|
61
|
-
|
62
|
-
# if the arguments is a dict, then handle it with a plain json.loads
|
63
|
-
arguments = arguments.strip()
|
64
|
-
if arguments.startswith("{"):
|
65
|
-
try:
|
66
|
-
arguments_dict = json.loads(arguments)
|
67
|
-
except json.JSONDecodeError as ex:
|
68
|
-
report_parse_error(ex)
|
69
|
-
|
70
|
-
# otherwise parse it as yaml (which will pickup unquoted strings, numbers, and true/false)
|
71
|
-
# and then create a dict that maps it to the first function argument
|
72
|
-
else:
|
73
|
-
tool_info = next(
|
74
|
-
(
|
75
|
-
tool
|
76
|
-
for tool in tools
|
77
|
-
if tool.name == function and len(tool.parameters.properties) > 0
|
78
|
-
),
|
79
|
-
None,
|
80
|
-
)
|
81
|
-
if tool_info:
|
82
|
-
param_names = list(tool_info.parameters.properties.keys())
|
83
|
-
try:
|
84
|
-
value = yaml.safe_load(arguments)
|
85
|
-
arguments_dict[param_names[0]] = value
|
86
|
-
except yaml.error.YAMLError:
|
87
|
-
# If the yaml parser fails, we treat it as a string argument.
|
88
|
-
arguments_dict[param_names[0]] = arguments
|
89
|
-
|
90
|
-
# return ToolCall with error payload
|
91
|
-
return ToolCall(
|
92
|
-
id=id,
|
93
|
-
function=function,
|
94
|
-
arguments=arguments_dict,
|
95
|
-
type="function",
|
96
|
-
parse_error=error,
|
97
|
-
)
|
98
|
-
|
99
|
-
|
100
24
|
def environment_prerequisite_error(
|
101
25
|
client: str, env_vars: str | list[str]
|
102
26
|
) -> PrerequisiteError:
|
@@ -23,7 +23,7 @@ from vertexai.generative_models import ( # type: ignore
|
|
23
23
|
)
|
24
24
|
from vertexai.generative_models import Content as VertexContent
|
25
25
|
|
26
|
-
from inspect_ai._util.constants import BASE_64_DATA_REMOVED
|
26
|
+
from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
|
27
27
|
from inspect_ai._util.content import (
|
28
28
|
Content,
|
29
29
|
ContentAudio,
|
@@ -250,9 +250,6 @@ def consective_tool_message_reducer(
|
|
250
250
|
return messages
|
251
251
|
|
252
252
|
|
253
|
-
NO_CONTENT = "(no content)"
|
254
|
-
|
255
|
-
|
256
253
|
async def content_dict(
|
257
254
|
message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
|
258
255
|
) -> VertexContent:
|
inspect_ai/scorer/_metric.py
CHANGED
@@ -125,6 +125,9 @@ class SampleScore(BaseModel):
|
|
125
125
|
sample_id: str | int | None = Field(default=None)
|
126
126
|
"""A sample id"""
|
127
127
|
|
128
|
+
scorer: str | None = Field(default=None)
|
129
|
+
"""Registry name of scorer that created this score."""
|
130
|
+
|
128
131
|
|
129
132
|
ValueToFloat = Callable[[Value], float]
|
130
133
|
"""Function used by metrics to translate from a Score value to a float value."""
|