inspect-ai 0.3.87__py3-none-any.whl → 0.3.89__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -244
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +55 -18
- inspect_ai/_view/www/dist/assets/index.js +550 -458
- inspect_ai/_view/www/log-schema.json +84 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +150 -129
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -9
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +15 -2
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +25 -14
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_call.py +3 -0
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +86 -81
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.87.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
--- a/inspect_ai/model/_providers/openai.py
+++ b/inspect_ai/model/_providers/openai.py
@@ -1,14 +1,8 @@
 import os
-import socket
 from logging import getLogger
-from typing import Any
+from typing import Any, Literal

-import httpx
 from openai import (
-    DEFAULT_CONNECTION_LIMITS,
-    DEFAULT_TIMEOUT,
-    APIStatusError,
-    APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
@@ -20,7 +14,6 @@ from openai.types.chat import ChatCompletion
 from typing_extensions import override

 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
 from inspect_ai.model._providers.openai_responses import generate_responses
@@ -31,20 +24,23 @@ from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import
+from .._model_output import ModelOutput
 from .._openai import (
-
+    OpenAIAsyncHttpxClient,
     is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
     is_o1_pro,
     is_o_series,
+    model_output_from_openai,
     openai_chat_messages,
     openai_chat_tool_choice,
     openai_chat_tools,
+    openai_completion_params,
     openai_handle_bad_request,
     openai_media_filter,
+    openai_should_retry,
 )
 from .openai_o1 import generate_o1
 from .util import environment_prerequisite_error, model_base_url
@@ -55,6 +51,8 @@ OPENAI_API_KEY = "OPENAI_API_KEY"
 AZURE_OPENAI_API_KEY = "AZURE_OPENAI_API_KEY"
 AZUREAI_OPENAI_API_KEY = "AZUREAI_OPENAI_API_KEY"

+# NOTE: If you are creating a new provider that is OpenAI compatible you should inherit from OpenAICompatibleAPI rather than OpenAPAPI.
+

 class OpenAIAPI(ModelAPI):
     def __init__(
@@ -72,7 +70,6 @@ class OpenAIAPI(ModelAPI):
         parts = model_name.split("/")
         if parts[0] == "azure" and len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None

@@ -135,7 +132,7 @@ class OpenAIAPI(ModelAPI):
         else:
             api_version = os.environ.get(
                 "AZUREAI_OPENAI_API_VERSION",
-                os.environ.get("OPENAI_API_VERSION", "2025-
+                os.environ.get("OPENAI_API_VERSION", "2025-03-01-preview"),
             )

             self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
@@ -160,22 +157,22 @@ class OpenAIAPI(ModelAPI):
         return self.service == "azure"

     def is_o_series(self) -> bool:
-        return is_o_series(self.
+        return is_o_series(self.service_model_name())

     def is_o1_pro(self) -> bool:
-        return is_o1_pro(self.
+        return is_o1_pro(self.service_model_name())

     def is_o1_mini(self) -> bool:
-        return is_o1_mini(self.
+        return is_o1_mini(self.service_model_name())

     def is_o1_preview(self) -> bool:
-        return is_o1_preview(self.
+        return is_o1_preview(self.service_model_name())

     def is_computer_use_preview(self) -> bool:
-        return is_computer_use_preview(self.
+        return is_computer_use_preview(self.service_model_name())

     def is_gpt(self) -> bool:
-        return is_gpt(self.
+        return is_gpt(self.service_model_name())

     @override
     async def aclose(self) -> None:
@@ -217,7 +214,7 @@ class OpenAIAPI(ModelAPI):
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
-                model_name=self.
+                model_name=self.service_model_name(),
                 input=input,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -242,15 +239,27 @@ class OpenAIAPI(ModelAPI):
         # unlike text models, vision models require a max_tokens (and set it to a very low
         # default, see https://community.openai.com/t/gpt-4-vision-preview-finish-details/475911/10)
         OPENAI_IMAGE_DEFAULT_TOKENS = 4096
-        if "vision" in self.
+        if "vision" in self.service_model_name():
             if isinstance(config.max_tokens, int):
                 config.max_tokens = max(config.max_tokens, OPENAI_IMAGE_DEFAULT_TOKENS)
             else:
                 config.max_tokens = OPENAI_IMAGE_DEFAULT_TOKENS

+        # determine system role
+        # o1-mini does not support developer or system messages
+        # (see Dec 17, 2024 changelog: https://platform.openai.com/docs/changelog)
+        if self.is_o1_mini():
+            system_role: Literal["user", "system", "developer"] = "user"
+        # other o-series models use 'developer' rather than 'system' messages
+        # https://platform.openai.com/docs/guides/reasoning#advice-on-prompting
+        elif self.is_o_series():
+            system_role = "developer"
+        else:
+            system_role = "system"
+
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await openai_chat_messages(input,
+            messages=await openai_chat_messages(input, system_role),
             tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
             tool_choice=openai_chat_tool_choice(tool_choice)
             if len(tools) > 0
@@ -267,49 +276,16 @@ class OpenAIAPI(ModelAPI):

             # save response for model_call
             response = completion.model_dump()
-            self.on_response(response)
-
-            # parse out choices
-            choices = self._chat_choices_from_response(completion, tools)

             # return output and call
-
-
-                choices=choices,
-                usage=(
-                    ModelUsage(
-                        input_tokens=completion.usage.prompt_tokens,
-                        output_tokens=completion.usage.completion_tokens,
-                        input_tokens_cache_read=(
-                            completion.usage.prompt_tokens_details.cached_tokens
-                            if completion.usage.prompt_tokens_details is not None
-                            else None  # openai only have cache read stats/pricing.
-                        ),
-                        reasoning_tokens=(
-                            completion.usage.completion_tokens_details.reasoning_tokens
-                            if completion.usage.completion_tokens_details is not None
-                            else None
-                        ),
-                        total_tokens=completion.usage.total_tokens,
-                    )
-                    if completion.usage
-                    else None
-                ),
-            ), model_call()
+            choices = chat_choices_from_openai(completion, tools)
+            return model_output_from_openai(completion, choices), model_call()
         except (BadRequestError, UnprocessableEntityError) as e:
-            return self.
+            return openai_handle_bad_request(self.service_model_name(), e), model_call()

-    def
-
-
-    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
-        return openai_handle_bad_request(self.model_name, ex)
-
-    def _chat_choices_from_response(
-        self, response: ChatCompletion, tools: list[ToolInfo]
-    ) -> list[ChatCompletionChoice]:
-        # adding this as a method so we can override from other classes (e.g together)
-        return chat_choices_from_openai(response, tools)
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)

     @override
     def should_retry(self, ex: Exception) -> bool:
@@ -321,14 +297,8 @@ class OpenAIAPI(ModelAPI):
                 return False
             else:
                 return True
-        elif isinstance(ex, APIStatusError):
-            return is_retryable_http_status(ex.status_code)
-        elif isinstance(ex, OpenAIResponseError):
-            return ex.code in ["rate_limit_exceeded", "server_error"]
-        elif isinstance(ex, APITimeoutError):
-            return True
         else:
-            return
+            return openai_should_retry(ex)

     @override
     def connection_key(self) -> str:
@@ -336,105 +306,31 @@ class OpenAIAPI(ModelAPI):
         return str(self.api_key)

     def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
-
-
-
+        # first call the default processing
+        params = openai_completion_params(self.service_model_name(), config, tools)
+
+        # now tailor to current model
         if config.max_tokens is not None:
             if self.is_o_series():
                 params["max_completion_tokens"] = config.max_tokens
-
-
-        if config.frequency_penalty is not None:
-            params["frequency_penalty"] = config.frequency_penalty
-        if config.stop_seqs is not None:
-            params["stop"] = config.stop_seqs
-        if config.presence_penalty is not None:
-            params["presence_penalty"] = config.presence_penalty
-        if config.logit_bias is not None:
-            params["logit_bias"] = config.logit_bias
-        if config.seed is not None:
-            params["seed"] = config.seed
+                del params["max_tokens"]
+
         if config.temperature is not None:
             if self.is_o_series():
                 warn_once(
                     logger,
                     "o series models do not support the 'temperature' parameter (temperature is always 1).",
                 )
-
-            params["temperature"] = config.temperature
-        # TogetherAPI requires temperature w/ num_choices
-        elif config.num_choices is not None:
-            params["temperature"] = 1
-        if config.top_p is not None:
-            params["top_p"] = config.top_p
-        if config.num_choices is not None:
-            params["n"] = config.num_choices
-        params = self.set_logprobs_params(params, config)
-        if tools and config.parallel_tool_calls is not None and not self.is_o_series():
-            params["parallel_tool_calls"] = config.parallel_tool_calls
-        if (
-            config.reasoning_effort is not None
-            and not self.is_gpt()
-            and not self.is_o1_mini()
-            and not self.is_o1_preview()
-        ):
-            params["reasoning_effort"] = config.reasoning_effort
-        if config.response_schema is not None:
-            params["response_format"] = dict(
-                type="json_schema",
-                json_schema=dict(
-                    name=config.response_schema.name,
-                    schema=config.response_schema.json_schema.model_dump(
-                        exclude_none=True
-                    ),
-                    description=config.response_schema.description,
-                    strict=config.response_schema.strict,
-                ),
-            )
+                del params["temperature"]

-
-
-
-        self, params: dict[str, Any], config: GenerateConfig
-    ) -> dict[str, Any]:
-        if config.logprobs is not None:
-            params["logprobs"] = config.logprobs
-        if config.top_logprobs is not None:
-            params["top_logprobs"] = config.top_logprobs
-        return params
-
-
-class OpenAIAsyncHttpxClient(httpx.AsyncClient):
-    """Custom async client that deals better with long running Async requests.
-
-    Based on Anthropic DefaultAsyncHttpClient implementation that they
-    released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
-
-    """
+        # remove parallel_tool_calls if not supported
+        if "parallel_tool_calls" in params.keys() and self.is_o_series():
+            del params["parallel_tool_calls"]

-
-
-
-
-
-        kwargs.setdefault("follow_redirects", True)
-
-        # This is based on the anthrpopic changes for claude 3.7:
-        # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403
-
-        # set socket options to deal with long running reasoning requests
-        socket_options = [
-            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
-            (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
-            (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
-        ]
-        TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
-        if TCP_KEEPIDLE is not None:
-            socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
-
-        kwargs["transport"] = httpx.AsyncHTTPTransport(
-            limits=DEFAULT_CONNECTION_LIMITS,
-            socket_options=socket_options,
-        )
+        # remove reasoning_effort if not supported
+        if "reasoning_effort" in params.keys() and (
+            self.is_gpt() or self.is_o1_mini() or self.is_o1_preview()
+        ):
+            del params["reasoning_effort"]

-
+        return params
--- /dev/null
+++ b/inspect_ai/model/_providers/openai_compatible.py
@@ -0,0 +1,195 @@
+import os
+from logging import getLogger
+from typing import Any
+
+from openai import (
+    APIStatusError,
+    AsyncOpenAI,
+    BadRequestError,
+    PermissionDeniedError,
+    UnprocessableEntityError,
+)
+from openai._types import NOT_GIVEN
+from openai.types.chat import ChatCompletion
+from typing_extensions import override
+
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.hooks import HttpxHooks
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import ChatCompletionChoice, ModelOutput
+from .._openai import (
+    OpenAIAsyncHttpxClient,
+    model_output_from_openai,
+    openai_chat_messages,
+    openai_chat_tool_choice,
+    openai_chat_tools,
+    openai_completion_params,
+    openai_handle_bad_request,
+    openai_media_filter,
+    openai_should_retry,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+logger = getLogger(__name__)
+
+
+class OpenAICompatibleAPI(ModelAPI):
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        service: str | None = None,
+        service_base_url: str | None = None,
+        **model_args: Any,
+    ) -> None:
+        # extract service prefix from model name if not specified
+        if service is None:
+            parts = model_name.split("/")
+            if len(parts) == 1:
+                raise ValueError(
+                    "openai-api model names must include a service prefix (e.g. 'openai-api/service/model')"
+                )
+            self.service = parts[0]
+        else:
+            self.service = service
+
+        # compute api key
+        api_key_var = f"{self.service.upper()}_API_KEY"
+
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[api_key_var],
+            config=config,
+        )
+
+        # use service prefix to lookup api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(api_key_var, None)
+            if not self.api_key:
+                raise environment_prerequisite_error(
+                    self.service,
+                    [api_key_var],
+                )
+
+        # use service prefix to lookup base_url
+        if not self.base_url:
+            base_url_var = f"{self.service.upper()}_BASE_URL"
+            self.base_url = model_base_url(base_url, [base_url_var]) or service_base_url
+            if not self.base_url:
+                raise environment_prerequisite_error(
+                    self.service,
+                    [base_url_var],
+                )
+
+        # create async http client
+        http_client = OpenAIAsyncHttpxClient()
+        self.client = AsyncOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            http_client=http_client,
+            **model_args,
+        )
+
+        # create time tracker
+        self._http_hooks = HttpxHooks(self.client._client)
+
+    @override
+    async def aclose(self) -> None:
+        await self.client.close()
+
+    async def generate(
+        self,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._http_hooks.start_request()
+
+        # setup request and response for ModelCall
+        request: dict[str, Any] = {}
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=request,
+                response=response,
+                filter=openai_media_filter,
+                time=self._http_hooks.end_request(request_id),
+            )
+
+        # get completion params (slice off service from model name)
+        completion_params = self.completion_params(
+            config=config,
+            tools=len(tools) > 0,
+        )
+
+        # prepare request (we do this so we can log the ModelCall)
+        request = dict(
+            messages=await openai_chat_messages(input),
+            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+            tool_choice=openai_chat_tool_choice(tool_choice)
+            if len(tools) > 0
+            else NOT_GIVEN,
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
+            **completion_params,
+        )
+
+        try:
+            # generate completion and save response for model call
+            completion: ChatCompletion = await self.client.chat.completions.create(
+                **request
+            )
+            response = completion.model_dump()
+            self.on_response(response)
+
+            # return output and call
+            choices = self.chat_choices_from_completion(completion, tools)
+            return model_output_from_openai(completion, choices), model_call()
+
+        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
+            return self.handle_bad_request(ex), model_call()
+
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
+    @override
+    def should_retry(self, ex: Exception) -> bool:
+        return openai_should_retry(ex)
+
+    @override
+    def connection_key(self) -> str:
+        """Scope for enforcing max_connections (could also use endpoint)."""
+        return str(self.api_key)
+
+    def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
+        return openai_completion_params(
+            model=self.service_model_name(),
+            config=config,
+            tools=tools,
+        )
+
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Hook for subclasses to do custom response handling."""
+        pass
+
+    def chat_choices_from_completion(
+        self, completion: ChatCompletion, tools: list[ToolInfo]
+    ) -> list[ChatCompletionChoice]:
+        """Hook for subclasses to do custom chat choice processing."""
+        return chat_choices_from_openai(completion, tools)
+
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        """Hook for subclasses to do bad request handling"""
+        return openai_handle_bad_request(self.service_model_name(), ex)
--- a/inspect_ai/model/_providers/openrouter.py
+++ b/inspect_ai/model/_providers/openrouter.py
@@ -1,16 +1,13 @@
 import json
-import os
 from typing import Any, TypedDict

 from typing_extensions import NotRequired, override

 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai.model._openai import OpenAIResponseError
-from inspect_ai.model._providers.util import model_base_url
-from inspect_ai.model._providers.util.util import environment_prerequisite_error

 from .._generate_config import GenerateConfig
-from .
+from .openai_compatible import OpenAICompatibleAPI

 OPENROUTER_API_KEY = "OPENROUTER_API_KEY"

@@ -37,7 +34,7 @@ class OpenRouterError(Exception):
     )


-class OpenRouterAPI(
+class OpenRouterAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -46,16 +43,6 @@ class OpenRouterAPI(OpenAIAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ) -> None:
-        # api_key
-        if not api_key:
-            api_key = os.environ.get(OPENROUTER_API_KEY, None)
-            if not api_key:
-                raise environment_prerequisite_error("OpenRouter", OPENROUTER_API_KEY)
-
-        # base_url
-        base_url = model_base_url(base_url, "OPENROUTER_BASE_URL")
-        base_url = base_url if base_url else "https://openrouter.ai/api/v1"
-
         # collect known model args that we forward to generate
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
@@ -88,6 +75,8 @@ class OpenRouterAPI(OpenAIAPI):
             base_url=base_url,
             api_key=api_key,
             config=config,
+            service="OpenRouter",
+            service_base_url="https://openrouter.ai/api/v1",
             **model_args,
         )

--- a/inspect_ai/model/_providers/providers.py
+++ b/inspect_ai/model/_providers/providers.py
@@ -44,6 +44,17 @@ def openai() -> type[ModelAPI]:
     return OpenAIAPI


+@modelapi(name="openai-api")
+def openai_api() -> type[ModelAPI]:
+    # validate
+    validate_openai_client("OpenAI Compatible API")
+
+    # in the clear
+    from .openai_compatible import OpenAICompatibleAPI
+
+    return OpenAICompatibleAPI
+
+
 @modelapi(name="anthropic")
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"