inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +0 -8
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/generate.py +41 -35
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +21 -15
- inspect_ai/_util/hooks.py +17 -7
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_chat_message.py +2 -2
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model.py +90 -25
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -14
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +4 -0
- inspect_ai/solver/_basic_agent.py +65 -55
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/{util → solver}/_limit.py +13 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +37 -7
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
- inspect_ai/util/__init__.py +0 -2
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/self_check.py +51 -28
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -33,7 +33,6 @@ from inspect_ai._util.trace import trace_action
|
|
33
33
|
from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
|
34
34
|
from inspect_ai.tool._tool_def import ToolDef, tool_defs
|
35
35
|
from inspect_ai.util import concurrency
|
36
|
-
from inspect_ai.util._limit import SampleLimitExceededError
|
37
36
|
|
38
37
|
from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
|
39
38
|
from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
|
@@ -764,40 +763,104 @@ def resolve_tool_model_input(
|
|
764
763
|
def tool_result_images_as_user_message(
    messages: list[ChatMessage],
) -> list[ChatMessage]:
    """Relocate images found in tool results into fabricated user messages.

    Some models cannot accept images inside tool responses. This builds an
    alternate message history (the input list is not modified). Each image
    in a tool result is replaced with the text "Image content is included
    below.". The images themselves are carried in a ChatMessageUser, which
    is emitted just before the next message that is not an image-bearing
    tool result, or at the end of the history.

    Args:
        messages: The original message history.

    Returns:
        A new list of messages with tool-result images moved into user messages.
    """
    state: ImagesAccumulator = ([], [], [])
    for message in messages:
        state = tool_result_images_reducer(state, message)
    processed, pending_images, pending_ids = state
    # a trailing image-bearing tool result leaves images pending; flush them now
    return maybe_adding_user_message(processed, pending_images, pending_ids)
|
777
|
+
|
778
|
+
|
779
|
+
# Accumulator threaded through tool_result_images_reducer, which unpacks it
# as (messages, pending_content, tool_call_ids).
ImagesAccumulator = tuple[list[ChatMessage], list[Content], list[str]]
"""
ImagesAccumulator is a tuple containing three lists:
- The first list contains ChatMessages that are the result of processing.
- The second list contains ContentImages that need to be inserted into a fabricated user message.
- The third list contains the tool_call_id's associated with the tool responses.
"""
|
768
786
|
|
769
787
|
|
770
788
|
def tool_result_images_reducer(
    accum: ImagesAccumulator,
    message: ChatMessage,
) -> ImagesAccumulator:
    """Fold one message into the image-relocation accumulator.

    A ChatMessageTool whose list content contains at least one ContentImage
    is re-emitted with every image replaced by placeholder text. Its images,
    and its tool_call_id if present, are appended to the pending lists.

    Any other message first flushes the pending images into a fabricated
    ChatMessageUser (via maybe_adding_user_message). It is then appended
    unchanged, and the pending lists are reset.

    Args:
        accum: (messages so far, pending image content, pending tool_call_ids).
        message: The next message in the history.

    Returns:
        The updated accumulator.
    """
    messages, pending_content, tool_call_ids = accum

    # if there are tool result images, pull them out into a ChatUserMessage
    if (
        isinstance(message, ChatMessageTool)
        and isinstance(message.content, list)
        and any(isinstance(c, ContentImage) for c in message.content)
    ):
        init_accum: ImageContentAccumulator = ([], [])
        new_user_message_content, edited_tool_message_content = functools.reduce(
            tool_result_image_content_reducer, message.content, init_accum
        )

        return (
            messages
            + [
                ChatMessageTool(
                    content=edited_tool_message_content,
                    tool_call_id=message.tool_call_id,
                    function=message.function,
                    # carry the tool error over so it is not silently lost
                    error=message.error,
                )
            ],
            pending_content + new_user_message_content,
            tool_call_ids + ([message.tool_call_id] if message.tool_call_id else []),
        )
    else:
        return (
            maybe_adding_user_message(messages, pending_content, tool_call_ids)
            + [message],
            [],
            [],
        )
|
798
824
|
|
799
|
-
|
800
|
-
|
825
|
+
|
826
|
+
# Accumulator threaded through tool_result_image_content_reducer, which unpacks
# it as (new_user_message_content, edited_tool_message_content).
ImageContentAccumulator = tuple[list[Content], list[Content]]
"""
ImageContentAccumulator is a tuple containing two lists of Content objects:
- The first list contains ContentImages that will be included in a fabricated user message.
- The second list contains modified content for the tool message with images replaced with text.
"""
|
832
|
+
|
833
|
+
|
834
|
+
def tool_result_image_content_reducer(
    acc: ImageContentAccumulator, content: Content
) -> ImageContentAccumulator:
    """Route one piece of tool-message content.

    A ContentImage is moved to the fabricated-user-message list, and a
    ContentText placeholder takes its place in the tool message. Any other
    content stays in the tool message as-is.

    Returns:
        ImageContentAccumulator: A tuple of
        (images for the fabricated user message,
        tool message content with images replaced by text).
    """
    images, tool_content = acc
    if not isinstance(content, ContentImage):
        return images, tool_content + [content]
    placeholder = ContentText(text="Image content is included below.")
    return images + [content], tool_content + [placeholder]
|
853
|
+
|
854
|
+
|
855
|
+
def maybe_adding_user_message(
    messages: list[ChatMessage], content: list[Content], tool_call_ids: list[str]
) -> list[ChatMessage]:
    """Append a fabricated user message holding `content`, if there is any.

    When `content` is empty, `messages` itself is returned unchanged.
    Otherwise a new list is returned: `messages` plus a ChatMessageUser
    carrying `content` and the associated `tool_call_ids`.
    """
    if not content:
        return messages
    fabricated = ChatMessageUser(content=content, tool_call_id=tool_call_ids)
    return messages + [fabricated]
|
801
864
|
|
802
865
|
|
803
866
|
# Functions to reduce consecutive user messages to a single user message -> required for some models
|
@@ -884,6 +947,7 @@ def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
|
|
884
947
|
active_sample_message_limit,
|
885
948
|
set_active_sample_total_messages,
|
886
949
|
)
|
950
|
+
from inspect_ai.solver._limit import SampleLimitExceededError
|
887
951
|
|
888
952
|
total_messages = 1 if isinstance(input, str) else len(input)
|
889
953
|
message_limit = active_sample_message_limit()
|
@@ -910,6 +974,7 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
|
|
910
974
|
active_sample_token_limit,
|
911
975
|
set_active_sample_total_tokens,
|
912
976
|
)
|
977
|
+
from inspect_ai.solver._limit import SampleLimitExceededError
|
913
978
|
|
914
979
|
# record usage
|
915
980
|
set_model_usage(model, usage, sample_model_usage_context_var.get(None))
|
@@ -214,3 +214,18 @@ class ModelOutput(BaseModel):
|
|
214
214
|
)
|
215
215
|
],
|
216
216
|
)
|
217
|
+
|
218
|
+
|
219
|
+
def as_stop_reason(reason: str | None) -> StopReason:
    """Encode common reason strings into standard StopReason.

    The mapping is:
    - "stop" and "eos" become "stop".
    - "length" becomes "max_tokens".
    - "tool_calls" and "function_call" become "tool_calls".
    - "content_filter", "model_length" and "max_tokens" are kept as-is.
    - Anything else, including None, becomes "unknown".
    """
    known: dict[str | None, StopReason] = {
        "stop": "stop",
        "eos": "stop",
        "length": "max_tokens",
        "tool_calls": "tool_calls",
        "function_call": "tool_calls",
        "content_filter": "content_filter",
        "model_length": "model_length",
        "max_tokens": "max_tokens",
    }
    return known.get(reason, "unknown")
|
@@ -0,0 +1,383 @@
|
|
1
|
+
import json
|
2
|
+
from typing import Literal
|
3
|
+
|
4
|
+
from openai.types.chat import (
|
5
|
+
ChatCompletion,
|
6
|
+
ChatCompletionAssistantMessageParam,
|
7
|
+
ChatCompletionContentPartImageParam,
|
8
|
+
ChatCompletionContentPartInputAudioParam,
|
9
|
+
ChatCompletionContentPartParam,
|
10
|
+
ChatCompletionContentPartRefusalParam,
|
11
|
+
ChatCompletionContentPartTextParam,
|
12
|
+
ChatCompletionDeveloperMessageParam,
|
13
|
+
ChatCompletionMessage,
|
14
|
+
ChatCompletionMessageParam,
|
15
|
+
ChatCompletionMessageToolCall,
|
16
|
+
ChatCompletionMessageToolCallParam,
|
17
|
+
ChatCompletionNamedToolChoiceParam,
|
18
|
+
ChatCompletionSystemMessageParam,
|
19
|
+
ChatCompletionToolChoiceOptionParam,
|
20
|
+
ChatCompletionToolMessageParam,
|
21
|
+
ChatCompletionToolParam,
|
22
|
+
ChatCompletionUserMessageParam,
|
23
|
+
)
|
24
|
+
from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
|
25
|
+
from openai.types.chat.chat_completion_message_tool_call import Function
|
26
|
+
from openai.types.completion_usage import CompletionUsage
|
27
|
+
from openai.types.shared_params.function_definition import FunctionDefinition
|
28
|
+
|
29
|
+
from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentText
|
30
|
+
from inspect_ai._util.images import file_as_data_uri
|
31
|
+
from inspect_ai._util.url import is_http_url
|
32
|
+
from inspect_ai.model._call_tools import parse_tool_call
|
33
|
+
from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
|
34
|
+
from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
|
35
|
+
|
36
|
+
from ._chat_message import (
|
37
|
+
ChatMessage,
|
38
|
+
ChatMessageAssistant,
|
39
|
+
ChatMessageSystem,
|
40
|
+
ChatMessageTool,
|
41
|
+
ChatMessageUser,
|
42
|
+
)
|
43
|
+
from ._model_output import ModelUsage, StopReason, as_stop_reason
|
44
|
+
|
45
|
+
|
46
|
+
def is_o1(name: str) -> bool:
    """Return True when the model name belongs to the o1 family (prefix "o1")."""
    family_prefix = "o1"
    return name.startswith(family_prefix)
|
48
|
+
|
49
|
+
|
50
|
+
def is_o1_full(name: str) -> bool:
    """Return True for o1 models that are neither o1-mini nor o1-preview."""
    if not is_o1(name):
        return False
    return not (is_o1_mini(name) or is_o1_preview(name))
|
52
|
+
|
53
|
+
|
54
|
+
def is_o1_mini(name: str) -> bool:
    """Return True when the model name starts with "o1-mini"."""
    mini_prefix = "o1-mini"
    return name.startswith(mini_prefix)
|
56
|
+
|
57
|
+
|
58
|
+
def is_o1_preview(name: str) -> bool:
    """Return True when the model name starts with "o1-preview"."""
    preview_prefix = "o1-preview"
    return name.startswith(preview_prefix)
|
60
|
+
|
61
|
+
|
62
|
+
def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
    """Convert an inspect ToolCall into an OpenAI ChatCompletionMessageToolCall.

    The call arguments are serialized with json.dumps, because the OpenAI
    type carries them as a string.
    """
    encoded_arguments = json.dumps(tool_call.arguments)
    fn = Function(name=tool_call.function, arguments=encoded_arguments)
    return ChatCompletionMessageToolCall(type="function", id=tool_call.id, function=fn)
|
70
|
+
|
71
|
+
|
72
|
+
def openai_chat_tool_call_param(
    tool_call: ToolCall,
) -> ChatCompletionMessageToolCallParam:
    """Convert an inspect ToolCall into an OpenAI tool-call request param.

    The arguments are JSON-encoded with json.dumps. The call type is copied
    from the ToolCall.
    """
    call_id = tool_call.id
    return ChatCompletionMessageToolCallParam(
        id=call_id,
        function=dict(
            name=tool_call.function,
            arguments=json.dumps(tool_call.arguments),
        ),
        type=tool_call.type,
    )
|
82
|
+
|
83
|
+
|
84
|
+
async def openai_chat_completion_part(
    content: Content,
) -> ChatCompletionContentPartParam:
    """Convert one inspect Content item into an OpenAI content part param.

    - Text maps to a "text" part.
    - Images map to an "image_url" part. URLs accepted by is_http_url are
      sent unchanged; anything else is converted with file_as_data_uri.
    - Audio is always converted with file_as_data_uri and sent as an
      "input_audio" part with the content's format.

    Raises:
        RuntimeError: For any other content type (e.g. video).
    """
    if content.type == "text":
        return ChatCompletionContentPartTextParam(type="text", text=content.text)

    if content.type == "image":
        url = content.image
        detail = content.detail
        if not is_http_url(url):
            # not a remote URL: inline it as a data URI
            url = await file_as_data_uri(url)
        return ChatCompletionContentPartImageParam(
            type="image_url",
            image_url=dict(url=url, detail=detail),
        )

    if content.type == "audio":
        encoded_audio = await file_as_data_uri(content.audio)
        return ChatCompletionContentPartInputAudioParam(
            type="input_audio",
            input_audio=dict(data=encoded_audio, format=content.format),
        )

    raise RuntimeError(
        "Video content is not currently supported by Open AI chat models."
    )
|
113
|
+
|
114
|
+
|
115
|
+
async def openai_chat_message(
    message: ChatMessage, model: str
) -> ChatCompletionMessageParam:
    """Convert one inspect ChatMessage into an OpenAI chat message param.

    - System messages become "developer" messages when is_o1(model) is true,
      and "system" messages otherwise.
    - User content lists are converted part by part with
      openai_chat_completion_part.
    - Assistant messages are sent as text, plus their tool calls if any.
    - Tool messages are sent as text, or as "Error: ..." when the tool
      reported an error.

    Raises:
        ValueError: If the message role is not recognised.
    """
    if message.role == "system":
        if not is_o1(model):
            return ChatCompletionSystemMessageParam(
                role=message.role, content=message.text
            )
        return ChatCompletionDeveloperMessageParam(
            role="developer", content=message.text
        )

    if message.role == "user":
        user_content: str | list[ChatCompletionContentPartParam]
        if isinstance(message.content, str):
            user_content = message.content
        else:
            user_content = [
                await openai_chat_completion_part(part) for part in message.content
            ]
        return ChatCompletionUserMessageParam(role=message.role, content=user_content)

    if message.role == "assistant":
        if not message.tool_calls:
            return ChatCompletionAssistantMessageParam(
                role=message.role, content=message.text
            )
        return ChatCompletionAssistantMessageParam(
            role=message.role,
            content=message.text,
            tool_calls=[
                openai_chat_tool_call_param(call) for call in message.tool_calls
            ],
        )

    if message.role == "tool":
        tool_text = (
            f"Error: {message.error.message}" if message.error else message.text
        )
        return ChatCompletionToolMessageParam(
            role=message.role,
            content=tool_text,
            tool_call_id=str(message.tool_call_id),
        )

    raise ValueError(f"Unexpected message role {message.role}")
|
162
|
+
|
163
|
+
|
164
|
+
async def openai_chat_messages(
    messages: list[ChatMessage], model: str
) -> list[ChatCompletionMessageParam]:
    """Convert a message history, in order, into OpenAI message params.

    Messages are converted sequentially with openai_chat_message.
    """
    converted: list[ChatCompletionMessageParam] = []
    for message in messages:
        converted.append(await openai_chat_message(message, model))
    return converted
|
168
|
+
|
169
|
+
|
170
|
+
def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]:
    """Convert inspect ChatCompletionChoices into OpenAI Choice objects.

    - List content is flattened to its text parts joined with newlines.
    - Indexes follow input order.
    - Stop reasons are translated with openai_finish_reason.
    """
    result: list[Choice] = []
    for position, choice in enumerate(choices):
        source = choice.message
        if isinstance(source.content, str):
            text = source.content
        else:
            text = "\n".join(
                [part.text for part in source.content if part.type == "text"]
            )
        oai_tool_calls = (
            [openai_chat_tool_call(tc) for tc in source.tool_calls]
            if source.tool_calls
            else None
        )
        oai_message = ChatCompletionMessage(
            role="assistant", content=text, tool_calls=oai_tool_calls
        )
        result.append(
            Choice(
                finish_reason=openai_finish_reason(choice.stop_reason),
                index=position,
                message=oai_message,
                logprobs=(
                    ChoiceLogprobs(**choice.logprobs.model_dump())
                    if choice.logprobs is not None
                    else None
                ),
            )
        )
    return result
|
199
|
+
|
200
|
+
|
201
|
+
def openai_completion_usage(usage: ModelUsage) -> CompletionUsage:
    """Translate inspect ModelUsage token counts into an OpenAI CompletionUsage.

    Input tokens become prompt_tokens, output tokens become
    completion_tokens, and the total is copied as-is.
    """
    return CompletionUsage(
        prompt_tokens=usage.input_tokens,
        completion_tokens=usage.output_tokens,
        total_tokens=usage.total_tokens,
    )
|
207
|
+
|
208
|
+
|
209
|
+
def openai_finish_reason(
|
210
|
+
stop_reason: StopReason,
|
211
|
+
) -> Literal["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
212
|
+
match stop_reason:
|
213
|
+
case "stop" | "tool_calls" | "content_filter":
|
214
|
+
return stop_reason
|
215
|
+
case "model_length":
|
216
|
+
return "length"
|
217
|
+
case _:
|
218
|
+
return "stop"
|
219
|
+
|
220
|
+
|
221
|
+
def openai_chat_tool_param(tool: ToolInfo) -> ChatCompletionToolParam:
    """Describe an inspect ToolInfo as an OpenAI "function" tool param.

    Parameters are dumped with None-valued fields excluded.
    """
    return ChatCompletionToolParam(
        type="function",
        function=FunctionDefinition(
            name=tool.name,
            description=tool.description,
            parameters=tool.parameters.model_dump(exclude_none=True),
        ),
    )
|
228
|
+
|
229
|
+
|
230
|
+
def openai_chat_tools(tools: list[ToolInfo]) -> list[ChatCompletionToolParam]:
    """Convert each ToolInfo into an OpenAI tool param, preserving order."""
    return list(map(openai_chat_tool_param, tools))
|
232
|
+
|
233
|
+
|
234
|
+
def openai_chat_tool_choice(
    tool_choice: ToolChoice,
) -> ChatCompletionToolChoiceOptionParam:
    """Translate an inspect ToolChoice into OpenAI's tool_choice option.

    - A ToolFunction becomes a named function choice.
    - "any" becomes "required".
    - Any other value passes through unchanged.
    """
    if isinstance(tool_choice, ToolFunction):
        return ChatCompletionNamedToolChoiceParam(
            type="function", function=dict(name=tool_choice.name)
        )
    # openai supports 'any' via the 'required' keyword
    return "required" if tool_choice == "any" else tool_choice
|
246
|
+
|
247
|
+
|
248
|
+
def chat_tool_calls_from_openai(
    message: ChatCompletionMessage, tools: list[ToolInfo]
) -> list[ToolCall] | None:
    """Parse the tool calls on an OpenAI message into inspect ToolCalls.

    Returns None when the message has no tool calls (None or empty).
    """
    if not message.tool_calls:
        return None
    return [
        parse_tool_call(
            oai_call.id, oai_call.function.name, oai_call.function.arguments, tools
        )
        for oai_call in message.tool_calls
    ]
|
258
|
+
|
259
|
+
|
260
|
+
def chat_messages_from_openai(
    messages: list[ChatCompletionMessageParam],
) -> list[ChatMessage]:
    """Convert OpenAI chat message params into inspect ChatMessages.

    - "system" and "developer" messages both become ChatMessageSystem.
    - An assistant message with no content uses its refusal (if any) as
      its text, or "" otherwise.
    - Tool call ids seen on assistant messages are remembered, so later
      tool messages can be labelled with their function name ("" if the
      id was never seen).

    Raises:
        ValueError: If a message has an unrecognised role.
    """
    # track tool names by id
    tool_names: dict[str, str] = {}

    chat_messages: list[ChatMessage] = []

    for message in messages:
        content: str | list[Content]
        if message["role"] == "system" or message["role"] == "developer":
            sys_content = message["content"]
            if isinstance(sys_content, str):
                chat_messages.append(ChatMessageSystem(content=sys_content))
            else:
                chat_messages.append(
                    ChatMessageSystem(
                        content=[content_from_openai(c) for c in sys_content]
                    )
                )
        elif message["role"] == "user":
            user_content = message["content"]
            if isinstance(user_content, str):
                chat_messages.append(ChatMessageUser(content=user_content))
            else:
                chat_messages.append(
                    ChatMessageUser(
                        content=[content_from_openai(c) for c in user_content]
                    )
                )
        elif message["role"] == "assistant":
            # 'content' is optional for assistant messages (e.g. a message
            # carrying only tool_calls), so it must not be indexed directly
            asst_content = message.get("content", None)
            if isinstance(asst_content, str):
                content = asst_content
            elif asst_content is None:
                content = message.get("refusal", None) or ""
            else:
                content = [content_from_openai(c) for c in asst_content]

            # 'tool_calls' may be absent or explicitly None
            tool_calls: list[ToolCall] = []
            for tc in message.get("tool_calls", None) or []:
                tool_calls.append(tool_call_from_openai(tc))
                tool_names[tc["id"]] = tc["function"]["name"]

            chat_messages.append(
                ChatMessageAssistant(content=content, tool_calls=tool_calls or None)
            )
        elif message["role"] == "tool":
            tool_content = message.get("content", None) or ""
            if isinstance(tool_content, str):
                content = tool_content
            else:
                content = [content_from_openai(c) for c in tool_content]
            chat_messages.append(
                ChatMessageTool(
                    content=content,
                    tool_call_id=message["tool_call_id"],
                    function=tool_names.get(message["tool_call_id"], ""),
                )
            )
        else:
            # every param is a dict, so report the offending role instead of its type
            raise ValueError(f"Unexpected message param role: {message['role']}")

    return chat_messages
|
328
|
+
|
329
|
+
|
330
|
+
def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> ToolCall:
    """Parse an OpenAI tool-call param into an inspect ToolCall.

    No tool schema is passed to parse_tool_call here.
    """
    call_id = tool_call["id"]
    fn = tool_call["function"]
    return parse_tool_call(call_id, fn["name"], fn["arguments"])
|
336
|
+
|
337
|
+
|
338
|
+
def content_from_openai(
    content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
) -> Content:
    """Convert an OpenAI content part into inspect Content.

    - Text and refusal parts become ContentText.
    - Image parts become ContentImage. A missing "detail" (optional in
      OpenAI's image_url param) defaults to "auto".
    - Input-audio parts become ContentAudio.

    Raises:
        ValueError: For content part types with no inspect equivalent, rather
            than silently returning None.
    """
    if content["type"] == "text":
        return ContentText(text=content["text"])
    elif content["type"] == "image_url":
        image_url = content["image_url"]
        return ContentImage(
            image=image_url["url"], detail=image_url.get("detail", "auto")
        )
    elif content["type"] == "input_audio":
        return ContentAudio(
            audio=content["input_audio"]["data"],
            format=content["input_audio"]["format"],
        )
    elif content["type"] == "refusal":
        return ContentText(text=content["refusal"])
    else:
        raise ValueError(f"Unexpected content part type: {content['type']}")
|
354
|
+
|
355
|
+
|
356
|
+
def chat_message_assistant_from_openai(
    message: ChatCompletionMessage, tools: list[ToolInfo]
) -> ChatMessageAssistant:
    """Build an inspect ChatMessageAssistant from an OpenAI response message.

    The text is the refusal if present, otherwise the content, otherwise "".
    The message is marked with source="generate".
    """
    text = getattr(message, "refusal", None) or message.content or ""
    return ChatMessageAssistant(
        content=text,
        source="generate",
        tool_calls=chat_tool_calls_from_openai(message, tools),
    )
|
365
|
+
|
366
|
+
|
367
|
+
def chat_choices_from_openai(
    response: ChatCompletion, tools: list[ToolInfo]
) -> list[ChatCompletionChoice]:
    """Convert the choices of an OpenAI ChatCompletion into inspect choices.

    Choices are ordered by their index. Finish reasons are normalised with
    as_stop_reason.
    """
    converted: list[ChatCompletionChoice] = []
    for choice in sorted(response.choices, key=lambda c: c.index):
        converted.append(
            ChatCompletionChoice(
                message=chat_message_assistant_from_openai(choice.message, tools),
                stop_reason=as_stop_reason(choice.finish_reason),
                logprobs=(
                    Logprobs(**choice.logprobs.model_dump())
                    if choice.logprobs is not None
                    else None
                ),
            )
        )
    return converted
|