inspect-ai 0.3.76__py3-none-any.whl → 0.3.78__py3-none-any.whl

@@ -0,0 +1,283 @@
+import json
+
+from openai.types.responses import (
+    FunctionToolParam,
+    Response,
+    ResponseFunctionToolCall,
+    ResponseFunctionToolCallParam,
+    ResponseInputContentParam,
+    ResponseInputImageParam,
+    ResponseInputItemParam,
+    ResponseInputMessageContentListParam,
+    ResponseInputTextParam,
+    ResponseOutputMessage,
+    ResponseOutputMessageParam,
+    ResponseOutputRefusalParam,
+    ResponseOutputText,
+    ResponseOutputTextParam,
+    ResponseReasoningItem,
+    ResponseReasoningItemParam,
+    ToolChoiceFunctionParam,
+    ToolParam,
+)
+from openai.types.responses.response_create_params import (
+    ToolChoice as ResponsesToolChoice,
+)
+from openai.types.responses.response_input_item_param import FunctionCallOutput, Message
+from openai.types.responses.response_reasoning_item_param import Summary
+
+from inspect_ai._util.content import (
+    Content,
+    ContentImage,
+    ContentReasoning,
+    ContentText,
+)
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
+from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._model_output import ChatCompletionChoice, StopReason
+from inspect_ai.model._openai import is_o_series
+from inspect_ai.tool._tool_call import ToolCall
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+from ._chat_message import ChatMessage, ChatMessageAssistant
+
+
+async def openai_responses_inputs(
+    messages: list[ChatMessage], model: str
+) -> list[ResponseInputItemParam]:
+    responses_inputs: list[ResponseInputItemParam] = []
+    for message in messages:
+        responses_inputs.extend(await openai_responses_input(message, model))
+    return responses_inputs
+
+
+async def openai_responses_input(
+    message: ChatMessage, model: str
+) -> list[ResponseInputItemParam]:
+    if message.role == "system":
+        content = await openai_responses_content_list_param(message.content)
+        if is_o_series(model):
+            return [Message(type="message", role="developer", content=content)]
+        else:
+            return [Message(type="message", role="system", content=content)]
+    elif message.role == "user":
+        return [
+            Message(
+                type="message",
+                role="user",
+                content=await openai_responses_content_list_param(message.content),
+            )
+        ]
+    elif message.role == "assistant":
+        reasoning_content = openai_responses_reasoning_content_params(message.content)
+        if message.content:
+            formatted_id = str(message.id).replace("resp_", "msg_", 1)
+            if not formatted_id.startswith("msg_"):
+                # These messages MUST start with `msg_`.
+                # As `store=False` for this provider, OpenAI doesn't validate the IDs.
+                # This will keep them consistent across calls though.
+                formatted_id = f"msg_{formatted_id}"
+            text_content = [
+                ResponseOutputMessageParam(
+                    type="message",
+                    role="assistant",
+                    id=formatted_id,
+                    content=openai_responses_text_content_params(message.content),
+                    status="completed",
+                )
+            ]
+        else:
+            text_content = []
+        tools_content = openai_responses_tools_content_params(message.tool_calls)
+        return reasoning_content + text_content + tools_content
+    elif message.role == "tool":
+        # TODO: Return output types for internal tools e.g. computer, web_search
+        if message.error is not None:
+            output = message.error.message
+        else:
+            output = message.text
+        return [
+            FunctionCallOutput(
+                type="function_call_output",
+                call_id=message.tool_call_id or str(message.function),
+                output=output,
+            )
+        ]
+    else:
+        raise ValueError(f"Unexpected message role '{message.role}'")
+
+
+async def openai_responses_content_list_param(
+    content: str | list[Content],
+) -> ResponseInputMessageContentListParam:
+    if isinstance(content, str):
+        content = [ContentText(text=content)]
+    return [await openai_responses_content_param(c) for c in content]
+
+
+async def openai_responses_content_param(content: Content) -> ResponseInputContentParam:  # type: ignore[return]
+    if isinstance(content, ContentText):
+        return ResponseInputTextParam(type="input_text", text=content.text)
+    elif isinstance(content, ContentImage):
+        image_url = content.image
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
+
+        return ResponseInputImageParam(
+            type="input_image", detail=content.detail, image_url=image_url
+        )
+    else:
+        # TODO: support for files (PDFs) and audio and video whenever
+        # that is supported by the responses API (was not on initial release)
+
+        # TODO: note that when doing this we should ensure that the
+        # openai_media_filter is properly screening out base64 encoded
+        # audio and video (if it exists, looks like it may all be done
+        # w/ file uploads in the responses API)
+
+        raise ValueError("Unsupported content type.")
+
+
+def openai_responses_reasoning_content_params(
+    content: str | list[Content],
+) -> list[ResponseInputItemParam]:
+    if isinstance(content, list):
+        return [
+            ResponseReasoningItemParam(
+                type="reasoning",
+                id=str(c.signature),
+                summary=[Summary(type="summary_text", text=c.reasoning)],
+            )
+            for c in content
+            if isinstance(c, ContentReasoning)
+        ]
+    else:
+        return []
+
+
+def openai_responses_text_content_params(
+    content: str | list[Content],
+) -> list[ResponseOutputTextParam | ResponseOutputRefusalParam]:
+    if isinstance(content, str):
+        content = [ContentText(text=content)]
+
+    params: list[ResponseOutputTextParam | ResponseOutputRefusalParam] = []
+
+    for c in content:
+        if isinstance(c, ContentText):
+            if c.refusal:
+                params.append(
+                    ResponseOutputRefusalParam(type="refusal", refusal=c.text)
+                )
+            else:
+                params.append(
+                    ResponseOutputTextParam(
+                        type="output_text", text=c.text, annotations=[]
+                    )
+                )
+
+    return params
+
+
+def openai_responses_tools_content_params(
+    tool_calls: list[ToolCall] | None,
+) -> list[ResponseInputItemParam]:
+    if tool_calls is not None:
+        return [
+            ResponseFunctionToolCallParam(
+                type="function_call",
+                call_id=call.id,
+                name=call.function,
+                arguments=json.dumps(call.arguments),
+                status="completed",
+            )
+            for call in tool_calls
+        ]
+    else:
+        return []
+
+
+def openai_responses_tool_choice(tool_choice: ToolChoice) -> ResponsesToolChoice:
+    match tool_choice:
+        case "none" | "auto":
+            return tool_choice
+        case "any":
+            return "required"
+        # TODO: internal tools need to be converted to ToolChoiceTypesParam
+        case _:
+            return ToolChoiceFunctionParam(type="function", name=tool_choice.name)
+
+
+def openai_responses_tools(tools: list[ToolInfo]) -> list[ToolParam]:
+    # TODO: return special types for internal tools
+    return [
+        FunctionToolParam(
+            type="function",
+            name=tool.name,
+            description=tool.description,
+            parameters=tool.parameters.model_dump(exclude_none=True),
+            strict=False,  # default parameters don't work in strict mode
+        )
+        for tool in tools
+    ]
+
+
+def openai_responses_chat_choices(
+    response: Response, tools: list[ToolInfo]
+) -> list[ChatCompletionChoice]:
+    # determine the StopReason
+    stop_reason: StopReason = "stop"
+    if response.incomplete_details is not None:
+        if response.incomplete_details.reason == "max_output_tokens":
+            stop_reason = "max_tokens"
+        elif response.incomplete_details.reason == "content_filter":
+            stop_reason = "content_filter"
+
+    # collect output and tool calls
+    message_content: list[Content] = []
+    tool_calls: list[ToolCall] = []
+    for output in response.output:
+        if isinstance(output, ResponseOutputMessage):
+            for content in output.content:
+                if isinstance(content, ResponseOutputText):
+                    message_content.append(ContentText(text=content.text))
+                else:
+                    message_content.append(
+                        ContentText(text=content.refusal, refusal=True)
+                    )
+        elif isinstance(output, ResponseReasoningItem):
+            reasoning = "\n".join([summary.text for summary in output.summary])
+            if reasoning:
+                message_content.append(
+                    ContentReasoning(signature=output.id, reasoning=reasoning)
+                )
+        else:
+            stop_reason = "tool_calls"
+            if isinstance(output, ResponseFunctionToolCall):
+                tool_calls.append(
+                    parse_tool_call(
+                        output.call_id,
+                        output.name,
+                        output.arguments,
+                        tools,
+                    )
+                )
+                pass
+            else:
+                # TODO: implement support for internal tools
+                raise ValueError(f"Unexpected output type: {output.__class__}")
+
+    # return choice
+    return [
+        ChatCompletionChoice(
+            message=ChatMessageAssistant(
+                id=response.id,
+                content=message_content,
+                tool_calls=tool_calls if len(tool_calls) > 0 else None,
+                source="generate",
+            ),
+            stop_reason=stop_reason,
+        )
+    ]
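
This new module converts Inspect chat messages to and from Responses API items. A minimal standalone sketch of the role mapping it implements, using plain dicts rather than inspect_ai's `ChatMessage` and the openai typed params (the `is_o_series` check below is a simplified stand-in for the real helper):

```python
from typing import Any


def is_o_series(model: str) -> bool:
    # simplified stand-in for inspect_ai.model._openai.is_o_series
    return model.startswith(("o1", "o3"))


def responses_input(message: dict[str, Any], model: str) -> list[dict[str, Any]]:
    role = message["role"]
    if role == "system":
        # o-series models take their system prompt as a "developer" message
        mapped_role = "developer" if is_o_series(model) else "system"
        return [{"type": "message", "role": mapped_role, "content": message["content"]}]
    elif role == "user":
        return [{"type": "message", "role": "user", "content": message["content"]}]
    elif role == "tool":
        # tool results round-trip as function_call_output items
        return [
            {
                "type": "function_call_output",
                "call_id": message["tool_call_id"],
                "output": message["content"],
            }
        ]
    raise ValueError(f"Unexpected message role '{role}'")


print(responses_input({"role": "system", "content": "Be terse."}, "o1-pro"))
# [{'type': 'message', 'role': 'developer', 'content': 'Be terse.'}]
```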
@@ -815,6 +815,7 @@ async def model_output_from_message(
         + (input_tokens_cache_write or 0)
         + (input_tokens_cache_read or 0)
         + message.usage.output_tokens
+        + reasoning_tokens
     )
     return ModelOutput(
         model=message.model,
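
This hunk folds reasoning tokens into the total token count. A worked sketch of the accounting, with made-up values; variable names mirror the diff:

```python
input_tokens = 1200
input_tokens_cache_write = 300   # may be None when caching is off
input_tokens_cache_read = None
output_tokens = 450
reasoning_tokens = 800           # newly included in the total

total_tokens = (
    input_tokens
    + (input_tokens_cache_write or 0)
    + (input_tokens_cache_read or 0)
    + output_tokens
    + reasoning_tokens
)
print(total_tokens)  # 2750
```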
@@ -51,7 +51,6 @@ from .._chat_message import (
     ChatMessageUser,
 )
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
 from .._model_output import (
@@ -60,6 +59,7 @@ from .._model_output import (
     ModelUsage,
     StopReason,
 )
+from .._openai import openai_media_filter
 from .util import (
     environment_prerequisite_error,
     model_base_url,
@@ -182,7 +182,7 @@ class AzureAIAPI(ModelAPI):
                 else None,
             ),
             response=response.as_dict() if response else {},
-            filter=image_url_filter,
+            filter=openai_media_filter,
         )

        # make call
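
The Azure AI provider now logs model calls through `openai_media_filter` instead of the image-only `image_url_filter`, so base64 media payloads in general are screened out of transcripts. A standalone sketch of the idea (the real filter lives in `inspect_ai.model._openai`; its exact signature may differ):

```python
from typing import Any


def scrub_media(value: Any) -> Any:
    # recursively replace base64 data URIs so logged calls stay small
    if isinstance(value, dict):
        return {k: scrub_media(v) for k, v in value.items()}
    if isinstance(value, list):
        return [scrub_media(v) for v in value]
    if isinstance(value, str) and value.startswith("data:"):
        return "<media: base64 omitted>"
    return value


print(scrub_media({"image_url": {"url": "data:image/png;base64,iVBORw0..."}}))
# {'image_url': {'url': '<media: base64 omitted>'}}
```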
@@ -82,6 +82,14 @@ class MistralAPI(ModelAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
+        # extract any service prefix from model name
+        parts = model_name.split("/")
+        if len(parts) > 1:
+            self.service: str | None = parts[0]
+            model_name = "/".join(parts[1:])
+        else:
+            self.service = None
+
         super().__init__(
             model_name=model_name,
             base_url=base_url,
@@ -94,31 +102,39 @@ class MistralAPI(ModelAPI):
             config=config,
         )

-        # resolve api_key -- look for mistral then azure
+        # resolve api_key
         if not self.api_key:
-            self.api_key = os.environ.get(MISTRAL_API_KEY, None)
-            if self.api_key:
-                base_url = model_base_url(base_url, "MISTRAL_BASE_URL")
-            else:
+            if self.is_azure():
                 self.api_key = os.environ.get(
                     AZUREAI_MISTRAL_API_KEY, os.environ.get(AZURE_MISTRAL_API_KEY, None)
                 )
-                if not self.api_key:
-                    raise environment_prerequisite_error(
-                        "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
-                    )
-                base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
-                if not base_url:
+            else:
+                self.api_key = os.environ.get(MISTRAL_API_KEY, None)
+
+            if not self.api_key:
+                raise environment_prerequisite_error(
+                    "Mistral", [MISTRAL_API_KEY, AZUREAI_MISTRAL_API_KEY]
+                )
+
+        if not self.base_url:
+            if self.is_azure():
+                self.base_url = model_base_url(base_url, "AZUREAI_MISTRAL_BASE_URL")
+                if not self.base_url:
                     raise ValueError(
                         "You must provide a base URL when using Mistral on Azure. Use the AZUREAI_MISTRAL_BASE_URL "
                         + " environment variable or the --model-base-url CLI flag to set the base URL."
                     )
+            else:
+                self.base_url = model_base_url(base_url, "MISTRAL_BASE_URL")

-        if base_url:
-            model_args["server_url"] = base_url
+        if self.base_url:
+            model_args["server_url"] = self.base_url

         self.model_args = model_args

+    def is_azure(self) -> bool:
+        return self.service == "azure"
+
     @override
     async def close(self) -> None:
         # client is created and destroyed in generate
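
`MistralAPI` now recognizes a service prefix on the model name (e.g. `mistral/azure/<deployment>` arrives here as `azure/<deployment>`), and `is_azure()` drives both API key and base URL resolution. The parsing, sketched standalone:

```python
def split_service(model_name: str) -> tuple[str | None, str]:
    # anything before the first "/" is treated as the service
    parts = model_name.split("/")
    if len(parts) > 1:
        return parts[0], "/".join(parts[1:])
    return None, model_name


print(split_service("azure/my-mistral-deployment"))  # ('azure', 'my-mistral-deployment')
print(split_service("mistral-large-latest"))         # (None, 'mistral-large-latest')
```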
@@ -22,28 +22,27 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.openai_responses import generate_responses
 from inspect_ai.model._providers.util.hooks import HttpxHooks
 from inspect_ai.tool import ToolChoice, ToolInfo

 from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
-from .._image import image_url_filter
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import (
-    ChatCompletionChoice,
-    ModelOutput,
-    ModelUsage,
-    StopReason,
-)
+from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
 from .._openai import (
+    OpenAIResponseError,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
+    is_o1_pro,
     is_o_series,
     openai_chat_messages,
     openai_chat_tool_choice,
     openai_chat_tools,
+    openai_handle_bad_request,
+    openai_media_filter,
 )
 from .openai_o1 import generate_o1
 from .util import (
@@ -65,6 +64,7 @@ class OpenAIAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        responses_api: bool | None = None,
         **model_args: Any,
     ) -> None:
         # extract azure service prefix from model name (other providers
@@ -77,6 +77,9 @@ class OpenAIAPI(ModelAPI):
         else:
             self.service = None

+        # note whether we are forcing the responses_api
+        self.responses_api = True if responses_api else False
+
         # call super
         super().__init__(
             model_name=model_name,
@@ -88,22 +91,21 @@ class OpenAIAPI(ModelAPI):

         # resolve api_key
         if not self.api_key:
-            self.api_key = os.environ.get(
-                AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
-            )
-            # backward compatibility for when env vars determined service
-            if self.api_key and (os.environ.get(OPENAI_API_KEY, None) is None):
-                self.service = "azure"
+            if self.service == "azure":
+                self.api_key = os.environ.get(
+                    AZUREAI_OPENAI_API_KEY, os.environ.get(AZURE_OPENAI_API_KEY, None)
+                )
             else:
                 self.api_key = os.environ.get(OPENAI_API_KEY, None)
-                if not self.api_key:
-                    raise environment_prerequisite_error(
-                        "OpenAI",
-                        [
-                            OPENAI_API_KEY,
-                            AZUREAI_OPENAI_API_KEY,
-                        ],
-                    )
+
+            if not self.api_key:
+                raise environment_prerequisite_error(
+                    "OpenAI",
+                    [
+                        OPENAI_API_KEY,
+                        AZUREAI_OPENAI_API_KEY,
+                    ],
+                )

         # create async http client
         http_client = OpenAIAsyncHttpxClient()
@@ -125,10 +127,16 @@ class OpenAIAPI(ModelAPI):
                     + "environment variable or the --model-base-url CLI flag to set the base URL."
                 )

+            # resolve version
+            api_version = os.environ.get(
+                "AZUREAI_OPENAI_API_VERSION",
+                os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+            )
+
             self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
                 api_key=self.api_key,
+                api_version=api_version,
                 azure_endpoint=base_url,
-                azure_deployment=model_name,
                 http_client=http_client,
                 **model_args,
             )
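
The Azure OpenAI client now pins an explicit API version (and no longer passes `azure_deployment`). The resolution order, sketched standalone with the default from this diff; the override value below is hypothetical:

```python
import os

os.environ["OPENAI_API_VERSION"] = "2024-10-21"  # hypothetical override

# AZUREAI_OPENAI_API_VERSION wins, then OPENAI_API_VERSION, then the default
api_version = os.environ.get(
    "AZUREAI_OPENAI_API_VERSION",
    os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
)
print(api_version)  # 2024-10-21 (falls back to 2025-02-01-preview if unset)
```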
@@ -149,6 +157,9 @@ class OpenAIAPI(ModelAPI):
     def is_o_series(self) -> bool:
         return is_o_series(self.model_name)

+    def is_o1_pro(self) -> bool:
+        return is_o1_pro(self.model_name)
+
     def is_o1_mini(self) -> bool:
         return is_o1_mini(self.model_name)

@@ -177,6 +188,16 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
+        elif self.is_o1_pro() or self.responses_api:
+            return await generate_responses(
+                client=self.client,
+                http_hooks=self._http_hooks,
+                model_name=self.model_name,
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+            )

         # allocate request_id (so we can see it from ModelCall)
         request_id = self._http_hooks.start_request()
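
With this change, generation routes to the Responses API either when the caller passes `responses_api=True` as a model arg or when the model requires it (o1-pro at this release). The decision, sketched with a simplified stand-in for the `is_o1_pro` check:

```python
def use_responses_api(model_name: str, responses_api: bool) -> bool:
    # simplified stand-in for OpenAIAPI.is_o1_pro()
    is_o1_pro = model_name.startswith("o1-pro")
    return is_o1_pro or responses_api


print(use_responses_api("o1-pro", False))  # True
print(use_responses_api("gpt-4o", False))  # False
print(use_responses_api("gpt-4o", True))   # True
```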
@@ -189,7 +210,7 @@ class OpenAIAPI(ModelAPI):
             return ModelCall.create(
                 request=request,
                 response=response,
-                filter=image_url_filter,
+                filter=openai_media_filter,
                 time=self._http_hooks.end_request(request_id),
             )

@@ -221,6 +242,7 @@ class OpenAIAPI(ModelAPI):

             # save response for model_call
             response = completion.model_dump()
+            self.on_response(response)

             # parse out choices
             choices = self._chat_choices_from_response(completion, tools)
@@ -252,6 +274,12 @@ class OpenAIAPI(ModelAPI):
         except BadRequestError as e:
             return self.handle_bad_request(e), model_call()

+    def on_response(self, response: dict[str, Any]) -> None:
+        pass
+
+    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
+        return openai_handle_bad_request(self.model_name, ex)
+
     def _chat_choices_from_response(
         self, response: ChatCompletion, tools: list[ToolInfo]
     ) -> list[ChatCompletionChoice]:
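
`on_response` and `handle_bad_request` are now overridable hooks, so OpenAI-compatible providers built on `OpenAIAPI` can observe raw responses or remap provider-specific errors. A hypothetical subclass sketch (import paths follow this diff; the error code shown is invented):

```python
from typing import Any

from openai import BadRequestError

from inspect_ai.model._model_output import ModelOutput
from inspect_ai.model._providers.openai import OpenAIAPI


class LoggingOpenAIAPI(OpenAIAPI):
    """Hypothetical subclass exercising the new hooks."""

    def on_response(self, response: dict[str, Any]) -> None:
        # observe the raw response dict before choices are parsed
        print(f"response id: {response.get('id')}")

    def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | Exception:
        # remap a provider-specific error code, else defer to the default
        if ex.code == "my_custom_filter":  # hypothetical code
            return ModelOutput.from_content(
                model=self.model_name, content=ex.message, stop_reason="content_filter"
            )
        return super().handle_bad_request(ex)
```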
@@ -270,6 +298,8 @@ class OpenAIAPI(ModelAPI):
             return True
         elif isinstance(ex, APIStatusError):
             return is_retryable_http_status(ex.status_code)
+        elif isinstance(ex, OpenAIResponseError):
+            return ex.code in ["rate_limit_exceeded", "server_error"]
         elif isinstance(ex, APITimeoutError):
             return True
         else:
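
Retry handling gains a branch for the new `OpenAIResponseError`: only rate-limit and server errors are retried; other Responses API error codes fail fast. The predicate, sketched standalone:

```python
RETRYABLE_RESPONSE_CODES = {"rate_limit_exceeded", "server_error"}


def should_retry_responses_error(code: str) -> bool:
    # mirrors the membership test in the diff
    return code in RETRYABLE_RESPONSE_CODES


print(should_retry_responses_error("server_error"))    # True
print(should_retry_responses_error("invalid_prompt"))  # False
```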
@@ -342,32 +372,6 @@ class OpenAIAPI(ModelAPI):

         return params

-    # convert some well known bad request errors into ModelOutput
-    def handle_bad_request(self, e: BadRequestError) -> ModelOutput | Exception:
-        # extract message
-        if isinstance(e.body, dict) and "message" in e.body.keys():
-            content = str(e.body.get("message"))
-        else:
-            content = e.message
-
-        # narrow stop_reason
-        stop_reason: StopReason | None = None
-        if e.code == "context_length_exceeded":
-            stop_reason = "model_length"
-        elif (
-            e.code == "invalid_prompt"  # seems to happen for o1/o3
-            or e.code == "content_policy_violation"  # seems to happen for vision
-            or e.code == "content_filter"  # seems to happen on azure
-        ):
-            stop_reason = "content_filter"
-
-        if stop_reason:
-            return ModelOutput.from_content(
-                model=self.model_name, content=content, stop_reason=stop_reason
-            )
-        else:
-            return e
-

 class OpenAIAsyncHttpxClient(httpx.AsyncClient):
     """Custom async client that deals better with long running Async requests.