inspect-ai 0.3.81__py3-none-any.whl → 0.3.83__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/eval.py +35 -2
- inspect_ai/_cli/util.py +44 -1
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +13 -4
- inspect_ai/_display/core/results.py +1 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +8 -8
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +44 -0
- inspect_ai/_eval/eval.py +74 -25
- inspect_ai/_eval/evalset.py +22 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +13 -15
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +55 -6
- inspect_ai/_eval/task/run.py +51 -10
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/file.py +30 -1
- inspect_ai/_util/json.py +37 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/server.py +113 -1
- inspect_ai/_view/www/App.css +7 -1
- inspect_ai/_view/www/dist/assets/index.css +813 -415
- inspect_ai/_view/www/dist/assets/index.js +54475 -32003
- inspect_ai/_view/www/eslint.config.mjs +1 -1
- inspect_ai/_view/www/log-schema.json +137 -31
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +149 -0
- inspect_ai/_view/www/package.json +11 -2
- inspect_ai/_view/www/src/App.tsx +161 -853
- inspect_ai/_view/www/src/api/api-browser.ts +176 -5
- inspect_ai/_view/www/src/api/api-vscode.ts +75 -1
- inspect_ai/_view/www/src/api/client-api.ts +66 -10
- inspect_ai/_view/www/src/api/jsonrpc.ts +2 -0
- inspect_ai/_view/www/src/api/types.ts +107 -2
- inspect_ai/_view/www/src/appearance/icons.ts +2 -0
- inspect_ai/_view/www/src/components/AsciinemaPlayer.tsx +3 -3
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/DownloadPanel.tsx +2 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +56 -61
- inspect_ai/_view/www/src/components/FindBand.tsx +17 -9
- inspect_ai/_view/www/src/components/HumanBaselineView.tsx +1 -1
- inspect_ai/_view/www/src/components/JsonPanel.tsx +14 -24
- inspect_ai/_view/www/src/components/LargeModal.tsx +2 -35
- inspect_ai/_view/www/src/components/LightboxCarousel.tsx +27 -11
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.module.css +11 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +177 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +116 -26
- inspect_ai/_view/www/src/components/MessageBand.tsx +14 -9
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/components/MorePopOver.tsx +3 -3
- inspect_ai/_view/www/src/components/NavPills.tsx +20 -8
- inspect_ai/_view/www/src/components/NoContentsPanel.module.css +12 -0
- inspect_ai/_view/www/src/components/NoContentsPanel.tsx +20 -0
- inspect_ai/_view/www/src/components/ProgressBar.module.css +5 -4
- inspect_ai/_view/www/src/components/ProgressBar.tsx +3 -2
- inspect_ai/_view/www/src/components/PulsingDots.module.css +81 -0
- inspect_ai/_view/www/src/components/PulsingDots.tsx +45 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -37
- inspect_ai/_view/www/src/components/ToolButton.tsx +3 -4
- inspect_ai/_view/www/src/index.tsx +26 -94
- inspect_ai/_view/www/src/logfile/remoteLogFile.ts +9 -1
- inspect_ai/_view/www/src/logfile/remoteZipFile.ts +30 -4
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +4 -6
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/ScorerDetailView.tsx +1 -1
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +74 -28
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +58 -22
- inspect_ai/_view/www/src/samples/SampleDisplay.module.css +4 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +135 -104
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +10 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +83 -36
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +35 -30
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -1
- inspect_ai/_view/www/src/samples/chat/ChatMessageRenderer.tsx +1 -1
- inspect_ai/_view/www/src/samples/chat/ChatViewVirtualList.tsx +45 -53
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +6 -1
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +5 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +36 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.module.css +3 -0
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +11 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +22 -46
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +34 -20
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -10
- inspect_ai/_view/www/src/samples/descriptor/types.ts +6 -5
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +22 -3
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +27 -2
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +122 -85
- inspect_ai/_view/www/src/samples/list/SampleRow.module.css +6 -0
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +28 -15
- inspect_ai/_view/www/src/samples/sample-tools/SelectScorer.tsx +29 -18
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +28 -28
- inspect_ai/_view/www/src/samples/sample-tools/sample-filter/SampleFilter.tsx +19 -9
- inspect_ai/_view/www/src/samples/sampleDataAdapter.ts +33 -0
- inspect_ai/_view/www/src/samples/sampleLimit.ts +2 -2
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +12 -27
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/InputEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +10 -24
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +4 -22
- inspect_ai/_view/www/src/samples/transcript/SandboxEventView.tsx +15 -24
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.tsx +0 -13
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +6 -28
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.tsx +24 -34
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.module.css +4 -0
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +33 -17
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.tsx +197 -338
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.module.css +16 -0
- inspect_ai/_view/www/src/samples/transcript/TranscriptVirtualListComponent.tsx +44 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventNav.tsx +7 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +81 -60
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.module.css +23 -0
- inspect_ai/_view/www/src/samples/transcript/event/EventProgressPanel.tsx +27 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +29 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +102 -72
- inspect_ai/_view/www/src/scoring/utils.ts +87 -0
- inspect_ai/_view/www/src/state/appSlice.ts +244 -0
- inspect_ai/_view/www/src/state/hooks.ts +399 -0
- inspect_ai/_view/www/src/state/logPolling.ts +200 -0
- inspect_ai/_view/www/src/state/logSlice.ts +224 -0
- inspect_ai/_view/www/src/state/logsPolling.ts +118 -0
- inspect_ai/_view/www/src/state/logsSlice.ts +181 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +314 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +140 -0
- inspect_ai/_view/www/src/state/sampleUtils.ts +21 -0
- inspect_ai/_view/www/src/state/scrolling.ts +206 -0
- inspect_ai/_view/www/src/state/store.ts +168 -0
- inspect_ai/_view/www/src/state/store_filter.ts +84 -0
- inspect_ai/_view/www/src/state/utils.ts +23 -0
- inspect_ai/_view/www/src/storage/index.ts +26 -0
- inspect_ai/_view/www/src/types/log.d.ts +36 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/types.ts +94 -32
- inspect_ai/_view/www/src/utils/attachments.ts +58 -23
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/utils/logger.ts +52 -0
- inspect_ai/_view/www/src/utils/polling.ts +100 -0
- inspect_ai/_view/www/src/utils/react.ts +30 -0
- inspect_ai/_view/www/src/utils/vscode.ts +1 -1
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +184 -217
- inspect_ai/_view/www/src/workspace/WorkSpaceView.tsx +11 -53
- inspect_ai/_view/www/src/workspace/navbar/Navbar.tsx +8 -18
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +40 -22
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -1
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +159 -103
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.module.css +32 -0
- inspect_ai/_view/www/src/workspace/navbar/RunningStatusPanel.tsx +32 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +12 -14
- inspect_ai/_view/www/src/workspace/navbar/StatusPanel.tsx +6 -2
- inspect_ai/_view/www/src/workspace/sidebar/LogDirectoryTitleView.tsx +4 -4
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.tsx +28 -13
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +5 -10
- inspect_ai/_view/www/src/workspace/tabs/JsonTab.tsx +4 -4
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.module.css +22 -0
- inspect_ai/_view/www/src/workspace/tabs/RunningNoSamples.tsx +19 -0
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +128 -115
- inspect_ai/_view/www/src/workspace/tabs/grouping.ts +37 -5
- inspect_ai/_view/www/src/workspace/tabs/types.ts +4 -0
- inspect_ai/_view/www/src/workspace/types.ts +4 -3
- inspect_ai/_view/www/src/workspace/utils.ts +4 -4
- inspect_ai/_view/www/vite.config.js +6 -0
- inspect_ai/_view/www/yarn.lock +464 -355
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_condense.py +26 -0
- inspect_ai/log/_log.py +17 -5
- inspect_ai/log/_recorders/buffer/__init__.py +14 -0
- inspect_ai/log/_recorders/buffer/buffer.py +30 -0
- inspect_ai/log/_recorders/buffer/database.py +685 -0
- inspect_ai/log/_recorders/buffer/filestore.py +259 -0
- inspect_ai/log/_recorders/buffer/types.py +84 -0
- inspect_ai/log/_recorders/eval.py +2 -11
- inspect_ai/log/_recorders/types.py +30 -0
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +257 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_generate_config.py +2 -2
- inspect_ai/model/_model.py +114 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +4 -4
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +103 -62
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- inspect_ai/util/_subtask.py +1 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/RECORD +295 -229
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -169
- inspect_ai/_view/www/src/samples/transcript/SampleTranscript.tsx +0 -22
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.81.dist-info → inspect_ai-0.3.83.dist-info}/top_level.txt +0 -0
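The most significant structural change in the file list above is the new top-level `inspect_ai/agent` package (`_agent.py`, `_react.py`, `_handoff.py`, `_as_tool.py`, `_as_solver.py`, `_run.py`), which also absorbs the bridge and human-agent modules previously housed under `solver`. As rough orientation, here is a minimal sketch of the decorator-style API those file names suggest; the signatures below are assumptions inferred from the file list, not taken from the diff itself:

```python
# Hypothetical sketch of the new agent API; names inferred from the
# inspect_ai/agent file list above (signatures are assumptions).
from inspect_ai.agent import Agent, AgentState, agent


@agent
def echo() -> Agent:
    async def execute(state: AgentState) -> AgentState:
        # an agent receives conversation state and returns it updated;
        # _as_tool.py / _as_solver.py suggest adapters into tools and solvers
        return state

    return execute
```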
--- a/inspect_ai/model/_providers/anthropic.py
+++ b/inspect_ai/model/_providers/anthropic.py
@@ -3,7 +3,7 @@ import os
 import re
 from copy import copy
 from logging import getLogger
-from typing import Any, Literal,
+from typing import Any, Literal, Optional, Tuple, cast
 
 import httpcore
 import httpx
@@ -153,7 +153,7 @@ class AnthropicAPI(ModelAPI):
         self._http_hooks = HttpxHooks(self.client._client)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
     def is_bedrock(self) -> bool:
@@ -639,11 +639,7 @@ def message_tool_choice(
     elif tool_choice == "any":
         return {"type": "any"}
     elif tool_choice == "none":
-        warn_once(
-            logger,
-            'The Anthropic API does not support tool_choice="none" (using "auto" instead)',
-        )
-        return {"type": "auto"}
+        return {"type": "none"}
     else:
         return {"type": "auto"}
 
@@ -723,11 +719,12 @@ async def message_param(message: ChatMessage) -> MessageParam:
 
         # now add tools
         for tool_call in message.tool_calls:
+            internal_name = _internal_name_from_tool_call(tool_call)
             tools_content.append(
                 ToolUseBlockParam(
                     type="tool_use",
                     id=tool_call.id,
-                    name=
+                    name=internal_name or tool_call.function,
                     input=tool_call.arguments,
                 )
             )
@@ -774,14 +771,13 @@ async def model_output_from_message(
             content.append(ContentText(type="text", text=content_text))
         elif isinstance(content_block, ToolUseBlock):
             tool_calls = tool_calls or []
-
+            (tool_name, internal_name) = _names_for_tool_call(content_block.name, tools)
             tool_calls.append(
                 ToolCall(
-                    type=info.internal_type,
                     id=content_block.id,
-                    function=
-                    internal_name=info.internal_name,
+                    function=tool_name,
                     arguments=content_block.model_dump().get("input", {}),
+                    internal=internal_name,
                 )
             )
         elif isinstance(content_block, RedactedThinkingBlock):
@@ -801,7 +797,7 @@ async def model_output_from_message(
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
-            content=content, tool_calls=tool_calls, source="generate"
+            content=content, tool_calls=tool_calls, model=model, source="generate"
         ),
         stop_reason=message_stop_reason(message),
     )
@@ -831,15 +827,18 @@ async def model_output_from_message(
     )
 
 
-
-
-
-
+def _internal_name_from_tool_call(tool_call: ToolCall) -> str | None:
+    assert isinstance(tool_call.internal, str | None), (
+        f"ToolCall internal must be `str | None`: {tool_call.internal}"
+    )
+    return tool_call.internal
 
 
-def
+def _names_for_tool_call(
+    tool_called: str, tools: list[ToolInfo]
+) -> tuple[str, str | None]:
     """
-    Return
+    Return the name of the tool to call and potentially an internal name.
 
     Anthropic prescribes names for their native tools - `computer`, `bash`, and
     `str_replace_editor`. For a variety of reasons, Inspect's tool names to not
@@ -854,11 +853,11 @@ def maybe_mapped_call_info(tool_called: str, tools: list[ToolInfo]) -> CallInfo:
 
     return next(
         (
-
+            (entry[2], entry[0])
            for entry in mappings
            if entry[0] == tool_called and any(tool.name == entry[2] for tool in tools)
         ),
-
+        (tool_called, None),
     )
 
 
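The net effect of the anthropic.py hunks: the old `internal_name`/`internal_type` plumbing collapses into a single `ToolCall.internal` field that round-trips Anthropic's prescribed native tool names (`computer`, `bash`, `str_replace_editor`), and `tool_choice="none"` is now passed through as `{"type": "none"}` rather than being downgraded to `"auto"` with a warning. A simplified sketch of that name round trip, with an assumed illustrative `mappings` entry (the real table lies outside this diff):

```python
# Simplified sketch of the round trip implied by _names_for_tool_call /
# _internal_name_from_tool_call; the mappings entry is illustrative only.
mappings = [("bash", "bash_20250124", "bash_session")]  # (anthropic, type, inspect)


def names_for_tool_call(tool_called: str, tool_names: list[str]) -> tuple[str, str | None]:
    # map Anthropic's native name to Inspect's tool name, remembering the
    # native name (stored as ToolCall.internal) so the reply to Anthropic
    # can be sent back under the name it expects
    return next(
        (
            (entry[2], entry[0])
            for entry in mappings
            if entry[0] == tool_called and entry[2] in tool_names
        ),
        (tool_called, None),
    )


assert names_for_tool_call("bash", ["bash_session"]) == ("bash_session", "bash")
assert names_for_tool_call("my_tool", ["my_tool"]) == ("my_tool", None)
```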
--- a/inspect_ai/model/_providers/azureai.py
+++ b/inspect_ai/model/_providers/azureai.py
@@ -129,11 +129,6 @@ class AzureAIAPI(ModelAPI):
         self.endpoint_url = endpoint_url
         self.model_args = model_args
 
-    @override
-    async def close(self) -> None:
-        # client is created/destroyed each time in generate()
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -143,9 +138,9 @@ class AzureAIAPI(ModelAPI):
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         # emulate tools (auto for llama, opt-in for others)
         if self.emulate_tools is None and self.is_llama():
-            handler: ChatAPIHandler | None = Llama31Handler()
+            handler: ChatAPIHandler | None = Llama31Handler(self.model_name)
         elif self.emulate_tools:
-            handler = Llama31Handler()
+            handler = Llama31Handler(self.model_name)
         else:
             handler = None
 
@@ -190,7 +185,9 @@ class AzureAIAPI(ModelAPI):
             response: ChatCompletions = await client.complete(**request)
             return ModelOutput(
                 model=response.model,
-                choices=chat_completion_choices(
+                choices=chat_completion_choices(
+                    response.model, response.choices, tools, handler
+                ),
                 usage=ModelUsage(
                     input_tokens=response.usage.prompt_tokens,
                     output_tokens=response.usage.completion_tokens,
@@ -368,24 +365,37 @@ def chat_tool_choice(
 
 
 def chat_completion_choices(
-
+    model: str,
+    choices: list[ChatChoice],
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> list[ChatCompletionChoice]:
     choices = copy(choices)
     choices.sort(key=lambda c: c.index)
-    return [
+    return [
+        chat_complection_choice(model, choice, tools, handler) for choice in choices
+    ]
 
 
 def chat_complection_choice(
-
+    model: str,
+    choice: ChatChoice,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> ChatCompletionChoice:
     return ChatCompletionChoice(
-        message=chat_completion_assistant_message(
+        message=chat_completion_assistant_message(
+            model, choice.message, tools, handler
+        ),
         stop_reason=chat_completion_stop_reason(choice.finish_reason),
     )
 
 
 def chat_completion_assistant_message(
-
+    model: str,
+    response: ChatResponseMessage,
+    tools: list[ToolInfo],
+    handler: ChatAPIHandler | None,
 ) -> ChatMessageAssistant:
     if handler:
         return handler.parse_assistant_response(response.content, tools)
@@ -397,6 +407,7 @@ def chat_completion_assistant_message(
         ]
         if response.tool_calls is not None
         else None,
+        model=model,
     )
 
 
--- a/inspect_ai/model/_providers/bedrock.py
+++ b/inspect_ai/model/_providers/bedrock.py
@@ -269,11 +269,6 @@ class BedrockAPI(ModelAPI):
         except ImportError:
             raise pip_dependency_error("Bedrock API", ["aioboto3"])
 
-    @override
-    async def close(self) -> None:
-        # client is created/destroyed each time in generate()
-        pass
-
     @override
     def connection_key(self) -> str:
         return self.model_name
@@ -454,7 +449,6 @@ def model_output_from_response(
             tool_calls.append(
                 ToolCall(
                     id=c.toolUse.toolUseId,
-                    type="function",
                     function=c.toolUse.name,
                     arguments=cast(dict[str, Any], c.toolUse.input or {}),
                 )
@@ -465,7 +459,7 @@ def model_output_from_response(
     # resolve choice
     choice = ChatCompletionChoice(
         message=ChatMessageAssistant(
-            content=content, tool_calls=tool_calls, source="generate"
+            content=content, tool_calls=tool_calls, model=model, source="generate"
         ),
         stop_reason=message_stop_reason(response.stopReason),
     )
--- a/inspect_ai/model/_providers/cloudflare.py
+++ b/inspect_ai/model/_providers/cloudflare.py
@@ -59,7 +59,7 @@ class CloudFlareAPI(ModelAPI):
         self.model_args = model_args
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.aclose()
 
     async def generate(
@@ -141,6 +141,6 @@ class CloudFlareAPI(ModelAPI):
 
     def chat_api_handler(self) -> ChatAPIHandler:
         if "llama" in self.model_name.lower():
-            return Llama31Handler()
+            return Llama31Handler(self.model_name)
         else:
-            return ChatAPIHandler()
+            return ChatAPIHandler(self.model_name)
--- a/inspect_ai/model/_providers/goodfire.py
+++ b/inspect_ai/model/_providers/goodfire.py
@@ -115,11 +115,6 @@ class GoodfireAPI(ModelAPI):
         # Initialize variant directly with model name
         self.variant = Variant(self.model_name)  # type: ignore
 
-    @override
-    async def close(self) -> None:
-        # httpx.AsyncClient is created on each generate()
-        pass
-
     def _to_goodfire_message(self, message: ChatMessage) -> GoodfireChatMessage:
         """Convert an Inspect message to a Goodfire message format.
 
@@ -232,7 +227,8 @@ class GoodfireAPI(ModelAPI):
             choices=[
                 ChatCompletionChoice(
                     message=ChatMessageAssistant(
-                        content=response_dict["choices"][0]["message"]["content"]
+                        content=response_dict["choices"][0]["message"]["content"],
+                        model=self.model_name,
                     ),
                     stop_reason="stop",
                 )
--- a/inspect_ai/model/_providers/google.py
+++ b/inspect_ai/model/_providers/google.py
@@ -183,11 +183,6 @@ class GoogleGenAIAPI(ModelAPI):
         # save model args
         self.model_args = model_args
 
-    @override
-    async def close(self) -> None:
-        # GenerativeModel uses a cached/shared client so there is no 'close'
-        pass
-
     def is_vertex(self) -> bool:
         return self.service == "vertex"
 
@@ -257,9 +252,10 @@ class GoogleGenAIAPI(ModelAPI):
         except ClientError as ex:
             return self.handle_client_error(ex), model_call()
 
+        model_name = response.model_version or self.model_name
         output = ModelOutput(
-            model=
-            choices=completion_choices_from_candidates(response),
+            model=model_name,
+            choices=completion_choices_from_candidates(model_name, response),
             usage=usage_metadata_to_model_usage(response.usage_metadata),
         )
 
@@ -546,7 +542,9 @@ def chat_tool_config(tool_choice: ToolChoice) -> ToolConfig:
     )
 
 
-def completion_choice_from_candidate(
+def completion_choice_from_candidate(
+    model: str, candidate: Candidate
+) -> ChatCompletionChoice:
     # content can be None when the finish_reason is SAFETY
     if candidate.content is None:
         content = ""
@@ -572,7 +570,6 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
             if part.function_call:
                 tool_calls.append(
                     ToolCall(
-                        type="function",
                         id=part.function_call.name,
                         function=part.function_call.name,
                         arguments=part.function_call.args,
@@ -596,6 +593,7 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
         message=ChatMessageAssistant(
             content=choice_content,
             tool_calls=tool_calls if len(tool_calls) > 0 else None,
+            model=model,
             source="generate",
         ),
         stop_reason=stop_reason,
@@ -624,19 +622,22 @@ def completion_choice_from_candidate(candidate: Candidate) -> ChatCompletionChoice:
 
 
 def completion_choices_from_candidates(
+    model: str,
     response: GenerateContentResponse,
 ) -> list[ChatCompletionChoice]:
     candidates = response.candidates
     if candidates:
         candidates_list = sorted(candidates, key=lambda c: c.index)
         return [
-            completion_choice_from_candidate(candidate)
+            completion_choice_from_candidate(model, candidate)
+            for candidate in candidates_list
         ]
     elif response.prompt_feedback:
         return [
             ChatCompletionChoice(
                 message=ChatMessageAssistant(
                     content=prompt_feedback_to_content(response.prompt_feedback),
+                    model=model,
                     source="generate",
                 ),
                 stop_reason="content_filter",
--- a/inspect_ai/model/_providers/groq.py
+++ b/inspect_ai/model/_providers/groq.py
@@ -93,7 +93,7 @@ class GroqAPI(ModelAPI):
         self._http_hooks = HttpxHooks(self.client._client)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
     async def generate(
@@ -203,7 +203,7 @@ class GroqAPI(ModelAPI):
         choices.sort(key=lambda c: c.index)
         return [
             ChatCompletionChoice(
-                message=chat_message_assistant(choice.message, tools),
+                message=chat_message_assistant(self.model_name, choice.message, tools),
                 stop_reason=as_stop_reason(choice.finish_reason),
             )
             for choice in choices
@@ -323,7 +323,9 @@ def chat_tool_calls(message: Any, tools: list[ToolInfo]) -> Optional[List[ToolCall]]:
     return None
 
 
-def chat_message_assistant(
+def chat_message_assistant(
+    model: str, message: Any, tools: list[ToolInfo]
+) -> ChatMessageAssistant:
     reasoning = getattr(message, "reasoning", None)
     if reasoning is not None:
         content: str | list[Content] = [
@@ -335,6 +337,7 @@ def chat_message_assistant(message: Any, tools: list[ToolInfo]) -> ChatMessageAssistant:
 
     return ChatMessageAssistant(
         content=content,
+        model=model,
         source="generate",
         tool_calls=chat_tool_calls(message, tools),
     )
--- a/inspect_ai/model/_providers/hf.py
+++ b/inspect_ai/model/_providers/hf.py
@@ -123,7 +123,7 @@ class HuggingFaceAPI(ModelAPI):
         self.tokenizer.padding_side = "left"
 
     @override
-    async def close(self) -> None:
+    def close(self) -> None:
         self.model = None
         self.tokenizer = None
         gc.collect()
@@ -205,7 +205,9 @@ class HuggingFaceAPI(ModelAPI):
 
         # construct choice
         choice = ChatCompletionChoice(
-            message=ChatMessageAssistant(
+            message=ChatMessageAssistant(
+                content=response.output, model=self.model_name, source="generate"
+            ),
             logprobs=(
                 Logprobs(content=final_logprobs) if final_logprobs is not None else None
             ),
@@ -338,7 +340,9 @@ def chat_completion_assistant_message(
     if handler:
         return handler.parse_assistant_response(response.output, tools)
     else:
-        return ChatMessageAssistant(
+        return ChatMessageAssistant(
+            content=response.output, model=model_name, source="generate"
+        )
 
 
 def set_random_seeds(seed: int | None = None) -> None:
--- a/inspect_ai/model/_providers/mistral.py
+++ b/inspect_ai/model/_providers/mistral.py
@@ -135,11 +135,6 @@ class MistralAPI(ModelAPI):
     def is_azure(self) -> bool:
         return self.service == "azure"
 
-    @override
-    async def close(self) -> None:
-        # client is created and destroyed in generate
-        pass
-
     async def generate(
         self,
         input: list[ChatMessage],
@@ -448,13 +443,11 @@ def chat_tool_call(tool_call: MistralToolCall, tools: list[ToolInfo]) -> ToolCall:
             id, tool_call.function.name, tool_call.function.arguments, tools
         )
     else:
-        return ToolCall(
-            id, tool_call.function.name, tool_call.function.arguments, type="function"
-        )
+        return ToolCall(id, tool_call.function.name, tool_call.function.arguments)
 
 
 def completion_choice(
-    choice: MistralChatCompletionChoice, tools: list[ToolInfo]
+    model: str, choice: MistralChatCompletionChoice, tools: list[ToolInfo]
 ) -> ChatCompletionChoice:
     message = choice.message
     if message:
@@ -465,6 +458,7 @@ def completion_choice(
                 tool_calls=chat_tool_calls(message.tool_calls, tools)
                 if message.tool_calls
                 else None,
+                model=model,
                 source="generate",
             ),
             stop_reason=(
@@ -511,7 +505,10 @@ def completion_choices_from_response(
     if response.choices is None:
         return []
     else:
-        return [
+        return [
+            completion_choice(response.model, choice, tools)
+            for choice in response.choices
+        ]
 
 
 def choice_stop_reason(choice: MistralChatCompletionChoice) -> StopReason:
--- a/inspect_ai/model/_providers/openai.py
+++ b/inspect_ai/model/_providers/openai.py
@@ -33,6 +33,7 @@ from .._model_call import ModelCall
 from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .._openai import (
     OpenAIResponseError,
+    is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
@@ -45,10 +46,7 @@ from .._openai import (
     openai_media_filter,
 )
 from .openai_o1 import generate_o1
-from .util import (
-    environment_prerequisite_error,
-    model_base_url,
-)
+from .util import environment_prerequisite_error, model_base_url
 
 logger = getLogger(__name__)
 
@@ -77,9 +75,6 @@ class OpenAIAPI(ModelAPI):
         else:
             self.service = None
 
-        # note whether we are forcing the responses_api
-        self.responses_api = True if responses_api else False
-
         # call super
         super().__init__(
             model_name=model_name,
@@ -89,6 +84,11 @@ class OpenAIAPI(ModelAPI):
             config=config,
         )
 
+        # note whether we are forcing the responses_api
+        self.responses_api = (
+            responses_api or self.is_o1_pro() or self.is_computer_use_preview()
+        )
+
         # resolve api_key
         if not self.api_key:
             if self.service == "azure":
@@ -128,10 +128,14 @@ class OpenAIAPI(ModelAPI):
             )
 
             # resolve version
-
-                "
-
-
+            if model_args.get("api_version") is not None:
+                # use slightly complicated logic to allow for "api_version" to be removed
+                api_version = model_args.pop("api_version")
+            else:
+                api_version = os.environ.get(
+                    "AZUREAI_OPENAI_API_VERSION",
+                    os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+                )
 
             self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
                 api_key=self.api_key,
@@ -166,13 +170,33 @@ class OpenAIAPI(ModelAPI):
     def is_o1_preview(self) -> bool:
         return is_o1_preview(self.model_name)
 
+    def is_computer_use_preview(self) -> bool:
+        return is_computer_use_preview(self.model_name)
+
     def is_gpt(self) -> bool:
         return is_gpt(self.model_name)
 
     @override
-    async def close(self) -> None:
+    async def aclose(self) -> None:
         await self.client.close()
 
+    @override
+    def emulate_reasoning_history(self) -> bool:
+        return not self.responses_api
+
+    @override
+    def tool_result_images(self) -> bool:
+        # o1-pro, o1, and computer_use_preview support image inputs (but we're not strictly supporting o1)
+        return self.is_o1_pro() or self.is_computer_use_preview()
+
+    @override
+    def disable_computer_screenshot_truncation(self) -> bool:
+        # Because ComputerCallOutput has a required output field of type
+        # ResponseComputerToolCallOutputScreenshot, we must have an image in
+        # order to provide a valid tool call response. Therefore, we cannot
+        # support image truncation.
+        return True
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -188,7 +212,7 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
-        elif self.
+        elif self.responses_api:
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
@@ -344,10 +368,7 @@ class OpenAIAPI(ModelAPI):
             params["top_p"] = config.top_p
         if config.num_choices is not None:
             params["n"] = config.num_choices
-        if config.logprobs is not None:
-            params["logprobs"] = config.logprobs
-        if config.top_logprobs is not None:
-            params["top_logprobs"] = config.top_logprobs
+        params = self.set_logprobs_params(params, config)
         if tools and config.parallel_tool_calls is not None and not self.is_o_series():
             params["parallel_tool_calls"] = config.parallel_tool_calls
         if (
@@ -372,6 +393,15 @@ class OpenAIAPI(ModelAPI):
 
         return params
 
+    def set_logprobs_params(
+        self, params: dict[str, Any], config: GenerateConfig
+    ) -> dict[str, Any]:
+        if config.logprobs is not None:
+            params["logprobs"] = config.logprobs
+        if config.top_logprobs is not None:
+            params["top_logprobs"] = config.top_logprobs
+        return params
+
 
 class OpenAIAsyncHttpxClient(httpx.AsyncClient):
     """Custom async client that deals better with long running Async requests.
--- a/inspect_ai/model/_providers/openai_o1.py
+++ b/inspect_ai/model/_providers/openai_o1.py
@@ -40,7 +40,7 @@ async def generate_o1(
     **params: Any,
 ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
     # create chatapi handler
-    handler = O1PreviewChatAPIHandler()
+    handler = O1PreviewChatAPIHandler(model)
 
     # call model
     request = dict(
@@ -155,6 +155,9 @@ TOOL_CALL = "tool_call"
 
 
 class O1PreviewChatAPIHandler(ChatAPIHandler):
+    def __init__(self, model: str) -> None:
+        self.model = model
+
     @override
     def input_with_tools(
         self, input: list[ChatMessage], tools: list[ToolInfo]
@@ -234,12 +237,17 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
 
             # return the message
             return ChatMessageAssistant(
-                content=content,
+                content=content,
+                tool_calls=tool_calls,
+                model=self.model,
+                source="generate",
             )
 
         # otherwise this is just an ordinary assistant message
         else:
-            return ChatMessageAssistant(
+            return ChatMessageAssistant(
+                content=response, model=self.model, source="generate"
+            )
 
     @override
     def assistant_message(self, message: ChatMessageAssistant) -> ChatAPIMessage:
@@ -328,6 +336,5 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
         id="unknown",
         function="unknown",
         arguments={},
-        type="function",
         parse_error=parse_error,
     )