inspect-ai 0.3.76__py3-none-any.whl → 0.3.78__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_display/core/results.py +6 -1
- inspect_ai/_eval/eval.py +8 -1
- inspect_ai/_eval/evalset.py +3 -0
- inspect_ai/_eval/run.py +3 -2
- inspect_ai/_util/content.py +3 -0
- inspect_ai/_view/www/dist/assets/index.js +18 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +22 -4
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +1 -1
- inspect_ai/model/_openai.py +67 -4
- inspect_ai/model/_openai_responses.py +283 -0
- inspect_ai/model/_providers/anthropic.py +1 -0
- inspect_ai/model/_providers/azureai.py +2 -2
- inspect_ai/model/_providers/mistral.py +29 -13
- inspect_ai/model/_providers/openai.py +53 -49
- inspect_ai/model/_providers/openai_responses.py +177 -0
- inspect_ai/model/_providers/openrouter.py +52 -2
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/tool/__init__.py +2 -0
- inspect_ai/tool/_tool.py +23 -3
- inspect_ai/tool/_tools/_think.py +48 -0
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/RECORD +27 -25
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/WHEEL +1 -1
- inspect_ai/model/_image.py +0 -15
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.76.dist-info → inspect_ai-0.3.78.dist-info}/top_level.txt +0 -0
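Two of the headline changes are a new OpenAI Responses API backend (`_openai_responses.py` and `openai_responses.py`) and a new built-in think tool (`_think.py`, plus two new exports from `inspect_ai/tool/__init__.py`). Before the hunks themselves, here is a minimal usage sketch for the think tool; it assumes the new exports include a `think()` factory, which the diff below does not show.

```python
# Sketch only: trying the new think tool from this release.
# Assumption: inspect_ai.tool now exports a think() tool factory
# (the diff adds inspect_ai/tool/_tools/_think.py but its contents
# are not shown here).
from inspect_ai import Task, eval
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import think

task = Task(
    dataset=[Sample(input="Plan first, then answer: what is 17 * 24?")],
    solver=[use_tools(think()), generate()],
)

eval(task, model="openai/gpt-4o")
```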
inspect_ai/model/_openai_responses.py (new file)
@@ -0,0 +1,283 @@
+import json
+
+from openai.types.responses import (
+    FunctionToolParam,
+    Response,
+    ResponseFunctionToolCall,
+    ResponseFunctionToolCallParam,
+    ResponseInputContentParam,
+    ResponseInputImageParam,
+    ResponseInputItemParam,
+    ResponseInputMessageContentListParam,
+    ResponseInputTextParam,
+    ResponseOutputMessage,
+    ResponseOutputMessageParam,
+    ResponseOutputRefusalParam,
+    ResponseOutputText,
+    ResponseOutputTextParam,
+    ResponseReasoningItem,
+    ResponseReasoningItemParam,
+    ToolChoiceFunctionParam,
+    ToolParam,
+)
+from openai.types.responses.response_create_params import (
+    ToolChoice as ResponsesToolChoice,
+)
+from openai.types.responses.response_input_item_param import FunctionCallOutput, Message
+from openai.types.responses.response_reasoning_item_param import Summary
+
+from inspect_ai._util.content import (
+    Content,
+    ContentImage,
+    ContentReasoning,
+    ContentText,
+)
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
+from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._model_output import ChatCompletionChoice, StopReason
+from inspect_ai.model._openai import is_o_series
+from inspect_ai.tool._tool_call import ToolCall
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from ._chat_message import ChatMessage, ChatMessageAssistant
+
+
+async def openai_responses_inputs(
+    messages: list[ChatMessage], model: str
+) -> list[ResponseInputItemParam]:
+    responses_inputs: list[ResponseInputItemParam] = []
+    for message in messages:
+        responses_inputs.extend(await openai_responses_input(message, model))
+    return responses_inputs
+
+
+async def openai_responses_input(
+    message: ChatMessage, model: str
+) -> list[ResponseInputItemParam]:
+    if message.role == "system":
+        content = await openai_responses_content_list_param(message.content)
+        if is_o_series(model):
+            return [Message(type="message", role="developer", content=content)]
+        else:
+            return [Message(type="message", role="system", content=content)]
+    elif message.role == "user":
+        return [
+            Message(
+                type="message",
+                role="user",
+                content=await openai_responses_content_list_param(message.content),
+            )
+        ]
+    elif message.role == "assistant":
+        reasoning_content = openai_responses_reasponing_content_params(message.content)
+        if message.content:
+            formatted_id = str(message.id).replace("resp_", "msg_", 1)
+            if not formatted_id.startswith("msg_"):
+                # These messages MUST start with `msg_`.
+                # As `store=False` for this provider, OpenAI doesn't validate the IDs.
+                # This will keep them consistent across calls though.
+                formatted_id = f"msg_{formatted_id}"
+            text_content = [
+                ResponseOutputMessageParam(
+                    type="message",
+                    role="assistant",
+                    id=formatted_id,
+                    content=openai_responses_text_content_params(message.content),
+                    status="completed",
+                )
+            ]
+        else:
+            text_content = []
+        tools_content = openai_responses_tools_content_params(message.tool_calls)
+        return reasoning_content + text_content + tools_content
+    elif message.role == "tool":
+        # TODO: Return ouptut types for internal tools e.g. computer, web_search
+        if message.error is not None:
+            output = message.error.message
+        else:
+            output = message.text
+        return [
+            FunctionCallOutput(
+                type="function_call_output",
+                call_id=message.tool_call_id or str(message.function),
+                output=output,
+            )
+        ]
+    else:
+        raise ValueError(f"Unexpected message role '{message.role}'")
+
+
+async def openai_responses_content_list_param(
+    content: str | list[Content],
+) -> ResponseInputMessageContentListParam:
+    if isinstance(content, str):
+        content = [ContentText(text=content)]
+    return [await openai_responses_content_param(c) for c in content]
+
+
+async def openai_responses_content_param(content: Content) -> ResponseInputContentParam:  # type: ignore[return]
+    if isinstance(content, ContentText):
+        return ResponseInputTextParam(type="input_text", text=content.text)
+    elif isinstance(content, ContentImage):
+        image_url = content.image
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
+
+        return ResponseInputImageParam(
+            type="input_image", detail=content.detail, image_url=image_url
+        )
+    else:
+        # TODO: support for files (PDFs) and audio and video whenever
+        # that is supported by the responses API (was not on initial release)
+
+        # TODO: note that when doing this we should ensure that the
+        # openai_media_filter is properly screening out base64 encoded
+        # audio and video (if it exists, looks like it may all be done
+        # w/ file uploads in the responses API)
+
+        raise ValueError("Unsupported content type.")
+
+
+def openai_responses_reasponing_content_params(
+    content: str | list[Content],
+) -> list[ResponseInputItemParam]:
+    if isinstance(content, list):
+        return [
+            ResponseReasoningItemParam(
+                type="reasoning",
+                id=str(c.signature),
+                summary=[Summary(type="summary_text", text=c.reasoning)],
+            )
+            for c in content
+            if isinstance(c, ContentReasoning)
+        ]
+    else:
+        return []
+
+
+def openai_responses_text_content_params(
+    content: str | list[Content],
+) -> list[ResponseOutputTextParam | ResponseOutputRefusalParam]:
+    if isinstance(content, str):
+        content = [ContentText(text=content)]
+
+    params: list[ResponseOutputTextParam | ResponseOutputRefusalParam] = []
+
+    for c in content:
+        if isinstance(c, ContentText):
+            if c.refusal:
+                params.append(
+                    ResponseOutputRefusalParam(type="refusal", refusal=c.text)
+                )
+            else:
+                params.append(
+                    ResponseOutputTextParam(
+                        type="output_text", text=c.text, annotations=[]
+                    )
+                )
+
+    return params
+
+
+def openai_responses_tools_content_params(
+    tool_calls: list[ToolCall] | None,
+) -> list[ResponseInputItemParam]:
+    if tool_calls is not None:
+        return [
+            ResponseFunctionToolCallParam(
+                type="function_call",
+                call_id=call.id,
+                name=call.function,
+                arguments=json.dumps(call.arguments),
+                status="completed",
+            )
+            for call in tool_calls
+        ]
+    else:
+        return []
+
+
+def openai_responses_tool_choice(tool_choice: ToolChoice) -> ResponsesToolChoice:
+    match tool_choice:
+        case "none" | "auto":
+            return tool_choice
+        case "any":
+            return "required"
+        # TODO: internal tools need to be converted to ToolChoiceTypesParam
+        case _:
+            return ToolChoiceFunctionParam(type="function", name=tool_choice.name)
+
+
+def openai_responses_tools(tools: list[ToolInfo]) -> list[ToolParam]:
+    # TODO: return special types for internal tools
+    return [
+        FunctionToolParam(
+            type="function",
+            name=tool.name,
+            description=tool.description,
+            parameters=tool.parameters.model_dump(exclude_none=True),
+            strict=False,  # default parameters don't work in strict mode
+        )
+        for tool in tools
+    ]
+
+
+def openai_responses_chat_choices(
+    response: Response, tools: list[ToolInfo]
+) -> list[ChatCompletionChoice]:
+    # determine the StopReason
+    stop_reason: StopReason = "stop"
+    if response.incomplete_details is not None:
+        if response.incomplete_details.reason == "max_output_tokens":
+            stop_reason = "max_tokens"
+        elif response.incomplete_details.reason == "content_filter":
+            stop_reason = "content_filter"
+
+    # collect output and tool calls
+    message_content: list[Content] = []
+    tool_calls: list[ToolCall] = []
+    for output in response.output:
+        if isinstance(output, ResponseOutputMessage):
+            for content in output.content:
+                if isinstance(content, ResponseOutputText):
+                    message_content.append(ContentText(text=content.text))
+                else:
+                    message_content.append(
+                        ContentText(text=content.refusal, refusal=True)
+                    )
+        elif isinstance(output, ResponseReasoningItem):
+            reasoning = "\n".join([summary.text for summary in output.summary])
+            if reasoning:
+                message_content.append(
+                    ContentReasoning(signature=output.id, reasoning=reasoning)
+                )
+        else:
+            stop_reason = "tool_calls"
+            if isinstance(output, ResponseFunctionToolCall):
+                tool_calls.append(
+                    parse_tool_call(
+                        output.call_id,
+                        output.name,
+                        output.arguments,
+                        tools,
+                    )
+                )
+                pass
+            else:
+                ## TODO: implement support for internal tools
+                raise ValueError(f"Unexpected output type: {output.__class__}")
+
+    # return choice
+    return [
+        ChatCompletionChoice(
+            message=ChatMessageAssistant(
+                id=response.id,
+                content=message_content,
+                tool_calls=tool_calls if len(tool_calls) > 0 else None,
+                source="generate",
+            ),
+            stop_reason=stop_reason,
+        )
+    ]
inspect_ai/model/_providers/azureai.py
@@ -51,7 +51,6 @@ from .._chat_message import (
     ChatMessageUser,
 )
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import (
@@ -60,6 +59,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
+from .._openai import openai_media_filter
 from .util import (
     environment_prerequisite_error,
     model_base_url,
@@ -182,7 +182,7 @@ class AzureAIAPI(ModelAPI):
                     else None,
                 ),
                 response=response.as_dict() if response else {},
-                filter=image_url_filter,
+                filter=openai_media_filter,
             )
 
         # make call
inspect_ai/model/_providers/mistral.py
@@ -82,6 +82,14 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
         super().__init__(
             model_name=model_name,
             base_url=base_url,
@@ -94,31 +102,39 @@ class MistralAPI(ModelAPI):
             config=config,
         )
 
-        # resolve api_key
+        # resolve api_key
         if not self.api_key:
-            self.
-            if self.api_key:
-                base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
-            else:
+            if self.is_azure():
                 self.api_key = os.environ.get(
                     AZUREAI_MISTRAL_API_KEY, os.environ.get(AZURE_MISTRAL_API_KEY, None)
                 )
-
-
-
-
-
-
+            else:
+                self.api_key = os.environ.get(MISTRAL_API_KEY, None)
+
+        if not self.api_key:
+            raise environment_prerequisite_error(
+                "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
+            )
+
+        if not self.base_url:
+            if self.is_azure():
+                self.base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
+                if not self.base_url:
                     raise ValueError(
                         "You must provide a base URL when using Mistral on Azure. Use the AZUREAI_MISTRAL_BASE_URL "
                         + " environment variable or the --model-base-url CLI flag to set the base URL."
                     )
+            else:
+                self.base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
 
-        if base_url:
-            model_args["server_url"] = base_url
+        if self.base_url:
+            model_args["server_url"] = self.base_url
 
         self.model_args = model_args
 
+    def is_azure(self) -> bool:
+        return self.service == "azure"
+
     @override
     async def close(self) -> None:
         # client is created and destroyed in generate
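The practical effect of the service-prefix parsing: an `azure/` segment in the model name selects the Azure code path, with credentials and endpoint resolved from the environment variables named in the hunk. A configuration sketch (the deployment name and endpoint are placeholders):

```python
# Sketch of the convention implied by the parsing above: the segment
# before the first "/" becomes self.service, so "azure/<deployment>"
# makes is_azure() true. Env var names are taken from the diff.
import os

from inspect_ai.model import get_model

os.environ["AZUREAI_MISTRAL_API_KEY"] = "<your-key>"
os.environ["AZUREAI_MISTRAL_BASE_URL"] = "https://<endpoint>"  # required on Azure

model = get_model("mistral/azure/<deployment-name>")
```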
inspect_ai/model/_providers/openai.py
@@ -22,28 +22,27 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.openai_responses import generate_responses
 from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .._openai import (
+    OpenAIResponseError,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
+    is_o1_pro,
     is_o_series,
     openai_chat_messages,
     openai_chat_tool_choice,
     openai_chat_tools,
+    openai_handle_bad_request,
+    openai_media_filter,
 )
 from .openai_o1 import generate_o1
 from .util import (
@@ -65,6 +64,7 @@ class OpenAIAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        responses_api: bool | None = None,
         **model_args: Any,
     ) -> None:
         # extract azure service prefix from model name (other providers
@@ -77,6 +77,9 @@ class OpenAIAPI(ModelAPI):
         else:
             self.service = None
 
+        # note whether we are forcing the responses_api
+        self.responses_api = True if responses_api else False
+
         # call super
         super().__init__(
             model_name=model_name,
@@ -88,22 +91,21 @@ class OpenAIAPI(ModelAPI):
 
         # resolve api_key
         if not self.api_key:
-            self.
-
-
-
-            if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
-                self.service = "azure"
+            if self.service == "azure":
+                self.api_key = os.environ.get(
+                    AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
+                )
             else:
                 self.api_key = os.environ.get(OPENAI_API_KEY, None)
-
-
-
-
-
-
-
-
+
+        if not self.api_key:
+            raise environment_prerequisite_error(
+                "OpenAI",
+                [
+                    OPENAI_API_KEY,
+                    AZUREAI_OPENAI_API_KEY,
+                ],
+            )
 
         # create async http client
         http_client = OpenAIAsyncHttpxClient()
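With this change the Azure path reads its key from `AZUREAI_OPENAI_API_KEY` (falling back to the legacy `AZURE_OPENAI_API_KEY`) and, per the hunks that follow, pins an `api_version` resolved from `AZUREAI_OPENAI_API_VERSION` or `OPENAI_API_VERSION`. A configuration sketch; the base URL variable name and the `openai/azure/<deployment>` model form are assumptions based on the service-prefix handling, which is not shown in these hunks:

```python
# Assumed Azure OpenAI setup under the new resolution order.
import os

from inspect_ai.model import get_model

os.environ["AZUREAI_OPENAI_API_KEY"] = "<your-key>"
os.environ["AZUREAI_OPENAI_BASE_URL"] = "https://<resource>.openai.azure.com"  # assumed name
os.environ["AZUREAI_OPENAI_API_VERSION"] = "2025-02-01-preview"  # optional override

model = get_model("openai/azure/<deployment-name>")
```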
inspect_ai/model/_providers/openai.py (continued)
@@ -125,10 +127,16 @@ class OpenAIAPI(ModelAPI):
                     + "environment variable or the --model-base-url CLI flag to set the base URL."
                 )
 
+            # resolve version
+            api_version = os.environ.get(
+                "AZUREAI_OPENAI_API_VERSION",
+                os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+            )
+
             self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
                 api_key=self.api_key,
+                api_version=api_version,
                 azure_endpoint=base_url,
-                azure_deployment=model_name,
                 http_client=http_client,
                 **model_args,
             )
@@ -149,6 +157,9 @@ class OpenAIAPI(ModelAPI):
     def is_o_series(self) -> bool:
         return is_o_series(self.model_name)
 
+    def is_o1_pro(self) -> bool:
+        return is_o1_pro(self.model_name)
+
     def is_o1_mini(self) -> bool:
         return is_o1_mini(self.model_name)
 
@@ -177,6 +188,16 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
+        elif self.is_o1_pro() or self.responses_api:
+            return await generate_responses(
+                client=self.client,
+                http_hooks=self._http_hooks,
+                model_name=self.model_name,
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+            )
 
         # allocate request_id (so we can see it from ModelCall)
         request_id = self._http_hooks.start_request()
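So `o1-pro` models always go through `generate_responses()`, and any other OpenAI model can opt in via the new constructor flag. A sketch, assuming model args are forwarded to the provider constructor as usual (the CLI equivalent would be `-M responses_api=true`):

```python
# Forcing the Responses API code path for a model that would
# otherwise use chat completions.
from inspect_ai.model import get_model

model = get_model("openai/gpt-4o", responses_api=True)
```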
inspect_ai/model/_providers/openai.py (continued)
@@ -189,7 +210,7 @@ class OpenAIAPI(ModelAPI):
             return ModelCall.create(
                 request=request,
                 response=response,
-                filter=image_url_filter,
+                filter=openai_media_filter,
                 time=self._http_hooks.end_request(request_id),
             )
 
@@ -221,6 +242,7 @@ class OpenAIAPI(ModelAPI):
 
         # save response for model_call
         response = completion.model_dump()
+        self.on_response(response)
 
         # parse out choices
         choices = self._chat_choices_from_response(completion, tools)
@@ -252,6 +274,12 @@ class OpenAIAPI(ModelAPI):
         except BadRequestError as e:
             return self.handle_bad_request(e), model_call()
 
+    def on_response(self, response: dict[str, Any]) -> None:
+        pass
+
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        return openai_handle_bad_request(self.model_name, ex)
+
     def _chat_choices_from_response(
         self, response: ChatCompletion, tools: list[ToolInfo]
     ) -> list[ChatCompletionChoice]:
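`on_response()` and `handle_bad_request()` are extension points: subclasses (such as the OpenRouter provider, also updated in this release) can observe raw responses or remap provider-specific errors without re-implementing `generate()`. A hypothetical subclass for illustration:

```python
# Hypothetical subclass (not part of the diff) using the new hooks.
from typing import Any

from openai import BadRequestError

from inspect_ai.model import ModelOutput
from inspect_ai.model._providers.openai import OpenAIAPI

class LoggingOpenAIAPI(OpenAIAPI):
    def on_response(self, response: dict[str, Any]) -> None:
        # observe the raw completion dict before choices are parsed
        print(response.get("id"), response.get("usage"))

    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
        # delegate to the shared openai_handle_bad_request() behavior
        return super().handle_bad_request(ex)
```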
inspect_ai/model/_providers/openai.py (continued)
@@ -270,6 +298,8 @@ class OpenAIAPI(ModelAPI):
             return True
         elif isinstance(ex, APIStatusError):
             return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, OpenAIResponseError):
+            return ex.code in ["rate_limit_exceeded", "server_error"]
         elif isinstance(ex, APITimeoutError):
             return True
         else:
@@ -342,32 +372,6 @@ class OpenAIAPI(ModelAPI):
 
         return params
 
-    # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
-        # extract message
-        if isinstance(e.body, dict) and "message" in e.body.keys():
-            content = str(e.body.get("message"))
-        else:
-            content = e.message
-
-        # narrow stop_reason
-        stop_reason: StopReason | None = None
-        if e.code == "context_length_exceeded":
-            stop_reason = "model_length"
-        elif (
-            e.code == "invalid_prompt"  # seems to happen for o1/o3
-            or e.code == "content_policy_violation"  # seems to happen for vision
-            or e.code == "content_filter"  # seems to happen on azure
-        ):
-            stop_reason = "content_filter"
-
-        if stop_reason:
-            return ModelOutput.from_content(
-                model=self.model_name, content=content, stop_reason=stop_reason
-            )
-        else:
-            return e
-
 
 class OpenAIAsyncHttpxClient(httpx.AsyncClient):
     """Custom async client that deals better with long running Async requests.