inspect-ai 0.3.88__py3-none-any.whl → 0.3.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -256
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +71 -36
  22. inspect_ai/_view/www/dist/assets/index.js +573 -475
  23. inspect_ai/_view/www/log-schema.json +66 -0
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -2
  30. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
  31. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  32. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -6
  33. inspect_ai/_view/www/src/samples/transcript/TranscriptView.module.css +0 -2
  34. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  35. inspect_ai/_view/www/src/types/log.d.ts +24 -6
  36. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  37. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  38. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  39. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  40. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  41. inspect_ai/agent/_agent.py +12 -0
  42. inspect_ai/agent/_as_tool.py +1 -1
  43. inspect_ai/agent/_bridge/bridge.py +9 -2
  44. inspect_ai/agent/_react.py +142 -74
  45. inspect_ai/agent/_run.py +13 -2
  46. inspect_ai/agent/_types.py +6 -0
  47. inspect_ai/approval/_apply.py +6 -7
  48. inspect_ai/approval/_approver.py +3 -3
  49. inspect_ai/approval/_auto.py +2 -2
  50. inspect_ai/approval/_call.py +20 -4
  51. inspect_ai/approval/_human/approver.py +3 -3
  52. inspect_ai/approval/_human/manager.py +2 -2
  53. inspect_ai/approval/_human/panel.py +3 -3
  54. inspect_ai/approval/_policy.py +3 -3
  55. inspect_ai/log/__init__.py +2 -0
  56. inspect_ai/log/_log.py +23 -2
  57. inspect_ai/log/_model.py +58 -0
  58. inspect_ai/log/_recorders/file.py +14 -3
  59. inspect_ai/log/_transcript.py +3 -0
  60. inspect_ai/model/__init__.py +2 -0
  61. inspect_ai/model/_call_tools.py +4 -1
  62. inspect_ai/model/_model.py +49 -3
  63. inspect_ai/model/_openai.py +151 -21
  64. inspect_ai/model/_providers/anthropic.py +20 -12
  65. inspect_ai/model/_providers/bedrock.py +3 -3
  66. inspect_ai/model/_providers/cloudflare.py +29 -108
  67. inspect_ai/model/_providers/google.py +21 -10
  68. inspect_ai/model/_providers/grok.py +23 -17
  69. inspect_ai/model/_providers/groq.py +61 -37
  70. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  71. inspect_ai/model/_providers/mistral.py +8 -3
  72. inspect_ai/model/_providers/ollama.py +8 -9
  73. inspect_ai/model/_providers/openai.py +53 -157
  74. inspect_ai/model/_providers/openai_compatible.py +195 -0
  75. inspect_ai/model/_providers/openrouter.py +4 -15
  76. inspect_ai/model/_providers/providers.py +11 -0
  77. inspect_ai/model/_providers/together.py +25 -23
  78. inspect_ai/model/_trim.py +83 -0
  79. inspect_ai/solver/_plan.py +5 -3
  80. inspect_ai/tool/_tool_def.py +8 -2
  81. inspect_ai/util/__init__.py +3 -0
  82. inspect_ai/util/_concurrency.py +15 -2
  83. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/METADATA +1 -1
  84. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/RECORD +88 -83
  85. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/WHEEL +1 -1
  86. inspect_ai/_eval/task/rundir.py +0 -78
  87. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  88. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/entry_points.txt +0 -0
  89. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/licenses/LICENSE +0 -0
  90. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai.py
@@ -1,14 +1,8 @@
 import os
-import socket
 from logging import getLogger
-from typing import Any
+from typing import Any, Literal

-import httpx
 from openai import (
-    DEFAULT_CONNECTION_LIMITS,
-    DEFAULT_TIMEOUT,
-    APIStatusError,
-    APITimeoutError,
     AsyncAzureOpenAI,
     AsyncOpenAI,
     BadRequestError,
@@ -20,7 +14,6 @@ from openai.types.chat import ChatCompletion
 from typing_extensions import override

 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.logger import warn_once
 from inspect_ai.model._openai import chat_choices_from_openai
 from inspect_ai.model._providers.openai_responses import generate_responses
@@ -31,20 +24,23 @@ from .._chat_message import ChatMessage
 from .._generate_config import GenerateConfig
 from .._model import ModelAPI
 from .._model_call import ModelCall
-from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
+from .._model_output import ModelOutput
 from .._openai import (
-    OpenAIResponseError,
+    OpenAIAsyncHttpxClient,
     is_computer_use_preview,
     is_gpt,
     is_o1_mini,
     is_o1_preview,
     is_o1_pro,
     is_o_series,
+    model_output_from_openai,
     openai_chat_messages,
     openai_chat_tool_choice,
     openai_chat_tools,
+    openai_completion_params,
     openai_handle_bad_request,
     openai_media_filter,
+    openai_should_retry,
 )
 from .openai_o1 import generate_o1
 from .util import environment_prerequisite_error, model_base_url
@@ -55,6 +51,8 @@ OPENAI_API_KEY = "OPENAI_API_KEY"
 AZURE_OPENAI_API_KEY = "AZURE_OPENAI_API_KEY"
 AZUREAI_OPENAI_API_KEY = "AZUREAI_OPENAI_API_KEY"

+# NOTE: If you are creating a new provider that is OpenAI compatible you should inherit from OpenAICompatibleAPI rather than OpenAIAPI.
+

 class OpenAIAPI(ModelAPI):
     def __init__(
@@ -72,7 +70,6 @@ class OpenAIAPI(ModelAPI):
         parts = model_name.split("/")
         if parts[0] == "azure" and len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None

@@ -135,7 +132,7 @@ class OpenAIAPI(ModelAPI):
         else:
             api_version = os.environ.get(
                 "AZUREAI_OPENAI_API_VERSION",
-                os.environ.get("OPENAI_API_VERSION", "2025-02-01-preview"),
+                os.environ.get("OPENAI_API_VERSION", "2025-03-01-preview"),
             )

         self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI(
@@ -160,22 +157,22 @@ class OpenAIAPI(ModelAPI):
         return self.service == "azure"

     def is_o_series(self) -> bool:
-        return is_o_series(self.model_name)
+        return is_o_series(self.service_model_name())

     def is_o1_pro(self) -> bool:
-        return is_o1_pro(self.model_name)
+        return is_o1_pro(self.service_model_name())

     def is_o1_mini(self) -> bool:
-        return is_o1_mini(self.model_name)
+        return is_o1_mini(self.service_model_name())

     def is_o1_preview(self) -> bool:
-        return is_o1_preview(self.model_name)
+        return is_o1_preview(self.service_model_name())

     def is_computer_use_preview(self) -> bool:
-        return is_computer_use_preview(self.model_name)
+        return is_computer_use_preview(self.service_model_name())

     def is_gpt(self) -> bool:
-        return is_gpt(self.model_name)
+        return is_gpt(self.service_model_name())

     @override
     async def aclose(self) -> None:
@@ -217,7 +214,7 @@ class OpenAIAPI(ModelAPI):
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
-                model_name=self.model_name,
+                model_name=self.service_model_name(),
                 input=input,
                 tools=tools,
                 tool_choice=tool_choice,
@@ -242,15 +239,27 @@ class OpenAIAPI(ModelAPI):
         # unlike text models, vision models require a max_tokens (and set it to a very low
         # default, see https://community.openai.com/t/gpt-4-vision-preview-finish-details/475911/10)
         OPENAI_IMAGE_DEFAULT_TOKENS = 4096
-        if "vision" in self.model_name:
+        if "vision" in self.service_model_name():
             if isinstance(config.max_tokens, int):
                 config.max_tokens = max(config.max_tokens, OPENAI_IMAGE_DEFAULT_TOKENS)
             else:
                 config.max_tokens = OPENAI_IMAGE_DEFAULT_TOKENS

+        # determine system role
+        # o1-mini does not support developer or system messages
+        # (see Dec 17, 2024 changelog: https://platform.openai.com/docs/changelog)
+        if self.is_o1_mini():
+            system_role: Literal["user", "system", "developer"] = "user"
+        # other o-series models use 'developer' rather than 'system' messages
+        # https://platform.openai.com/docs/guides/reasoning#advice-on-prompting
+        elif self.is_o_series():
+            system_role = "developer"
+        else:
+            system_role = "system"
+
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await openai_chat_messages(input, self.model_name),
+            messages=await openai_chat_messages(input, system_role),
             tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
             tool_choice=openai_chat_tool_choice(tool_choice)
             if len(tools) > 0
@@ -267,49 +276,16 @@

             # save response for model_call
             response = completion.model_dump()
-            self.on_response(response)
-
-            # parse out choices
-            choices = self._chat_choices_from_response(completion, tools)

             # return output and call
-            return ModelOutput(
-                model=completion.model,
-                choices=choices,
-                usage=(
-                    ModelUsage(
-                        input_tokens=completion.usage.prompt_tokens,
-                        output_tokens=completion.usage.completion_tokens,
-                        input_tokens_cache_read=(
-                            completion.usage.prompt_tokens_details.cached_tokens
-                            if completion.usage.prompt_tokens_details is not None
-                            else None  # openai only have cache read stats/pricing.
-                        ),
-                        reasoning_tokens=(
-                            completion.usage.completion_tokens_details.reasoning_tokens
-                            if completion.usage.completion_tokens_details is not None
-                            else None
-                        ),
-                        total_tokens=completion.usage.total_tokens,
-                    )
-                    if completion.usage
-                    else None
-                ),
-            ), model_call()
+            choices = chat_choices_from_openai(completion, tools)
+            return model_output_from_openai(completion, choices), model_call()
         except (BadRequestError, UnprocessableEntityError) as e:
-            return self.handle_bad_request(e), model_call()
+            return openai_handle_bad_request(self.service_model_name(), e), model_call()

-    def on_response(self, response: dict[str, Any]) -> None:
-        pass
-
-    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
-        return openai_handle_bad_request(self.model_name, ex)
-
-    def _chat_choices_from_response(
-        self, response: ChatCompletion, tools: list[ToolInfo]
-    ) -> list[ChatCompletionChoice]:
-        # adding this as a method so we can override from other classes (e.g together)
-        return chat_choices_from_openai(response, tools)
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)

     @override
     def should_retry(self, ex: Exception) -> bool:
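The hunk above pairs with the constructor change at @@ -72,7 +70,6 @@: the "azure/" prefix is no longer stripped from model_name at construction time, so service_model_name() removes it wherever the bare name is needed. A minimal standalone sketch of that replace() call (hypothetical values, not library code):

# illustration of the one-shot prefix strip performed by service_model_name()
def service_model_name(model_name: str, service: str | None) -> str:
    # drop at most one leading "<service>/" prefix
    return model_name.replace(f"{service}/", "", 1)

assert service_model_name("azure/gpt-4o-mini", "azure") == "gpt-4o-mini"
# when service is None, "None/" never matches and the name passes through
assert service_model_name("gpt-4o-mini", None) == "gpt-4o-mini"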
@@ -321,14 +297,8 @@ class OpenAIAPI(ModelAPI):
                 return False
             else:
                 return True
-        elif isinstance(ex, APIStatusError):
-            return is_retryable_http_status(ex.status_code)
-        elif isinstance(ex, OpenAIResponseError):
-            return ex.code in ["rate_limit_exceeded", "server_error"]
-        elif isinstance(ex, APITimeoutError):
-            return True
         else:
-            return False
+            return openai_should_retry(ex)

     @override
     def connection_key(self) -> str:
@@ -336,105 +306,31 @@ class OpenAIAPI(ModelAPI):
         return str(self.api_key)

     def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
-        params: dict[str, Any] = dict(
-            model=self.model_name,
-        )
+        # first call the default processing
+        params = openai_completion_params(self.service_model_name(), config, tools)
+
+        # now tailor to current model
         if config.max_tokens is not None:
             if self.is_o_series():
                 params["max_completion_tokens"] = config.max_tokens
-            else:
-                params["max_tokens"] = config.max_tokens
-        if config.frequency_penalty is not None:
-            params["frequency_penalty"] = config.frequency_penalty
-        if config.stop_seqs is not None:
-            params["stop"] = config.stop_seqs
-        if config.presence_penalty is not None:
-            params["presence_penalty"] = config.presence_penalty
-        if config.logit_bias is not None:
-            params["logit_bias"] = config.logit_bias
-        if config.seed is not None:
-            params["seed"] = config.seed
+                del params["max_tokens"]
+
         if config.temperature is not None:
             if self.is_o_series():
                 warn_once(
                     logger,
                     "o series models do not support the 'temperature' parameter (temperature is always 1).",
                 )
-            else:
-                params["temperature"] = config.temperature
-        # TogetherAPI requires temperature w/ num_choices
-        elif config.num_choices is not None:
-            params["temperature"] = 1
-        if config.top_p is not None:
-            params["top_p"] = config.top_p
-        if config.num_choices is not None:
-            params["n"] = config.num_choices
-        params = self.set_logprobs_params(params, config)
-        if tools and config.parallel_tool_calls is not None and not self.is_o_series():
-            params["parallel_tool_calls"] = config.parallel_tool_calls
-        if (
-            config.reasoning_effort is not None
-            and not self.is_gpt()
-            and not self.is_o1_mini()
-            and not self.is_o1_preview()
-        ):
-            params["reasoning_effort"] = config.reasoning_effort
-        if config.response_schema is not None:
-            params["response_format"] = dict(
-                type="json_schema",
-                json_schema=dict(
-                    name=config.response_schema.name,
-                    schema=config.response_schema.json_schema.model_dump(
-                        exclude_none=True
-                    ),
-                    description=config.response_schema.description,
-                    strict=config.response_schema.strict,
-                ),
-            )
+                del params["temperature"]

-        return params
-
-    def set_logprobs_params(
-        self, params: dict[str, Any], config: GenerateConfig
-    ) -> dict[str, Any]:
-        if config.logprobs is not None:
-            params["logprobs"] = config.logprobs
-        if config.top_logprobs is not None:
-            params["top_logprobs"] = config.top_logprobs
-        return params
-
-
-class OpenAIAsyncHttpxClient(httpx.AsyncClient):
-    """Custom async client that deals better with long running Async requests.
-
-    Based on Anthropic DefaultAsyncHttpClient implementation that they
-    released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
-
-    """
+        # remove parallel_tool_calls if not supported
+        if "parallel_tool_calls" in params.keys() and self.is_o_series():
+            del params["parallel_tool_calls"]

-    def __init__(self, **kwargs: Any) -> None:
-        # This is based on the openai DefaultAsyncHttpxClient:
-        # https://github.com/openai/openai-python/commit/347363ed67a6a1611346427bb9ebe4becce53f7e
-        kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
-        kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
-        kwargs.setdefault("follow_redirects", True)
-
-        # This is based on the anthrpopic changes for claude 3.7:
-        # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403
-
-        # set socket options to deal with long running reasoning requests
-        socket_options = [
-            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
-            (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
-            (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
-        ]
-        TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
-        if TCP_KEEPIDLE is not None:
-            socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
-
-        kwargs["transport"] = httpx.AsyncHTTPTransport(
-            limits=DEFAULT_CONNECTION_LIMITS,
-            socket_options=socket_options,
-        )
+        # remove reasoning_effort if not supported
+        if "reasoning_effort" in params.keys() and (
+            self.is_gpt() or self.is_o1_mini() or self.is_o1_preview()
+        ):
+            del params["reasoning_effort"]

-        super().__init__(**kwargs)
+        return params
inspect_ai/model/_providers/openai_compatible.py (new file)
@@ -0,0 +1,195 @@
+import os
+from logging import getLogger
+from typing import Any
+
+from openai import (
+    APIStatusError,
+    AsyncOpenAI,
+    BadRequestError,
+    PermissionDeniedError,
+    UnprocessableEntityError,
+)
+from openai._types import NOT_GIVEN
+from openai.types.chat import ChatCompletion
+from typing_extensions import override
+
+from inspect_ai.model._openai import chat_choices_from_openai
+from inspect_ai.model._providers.util.hooks import HttpxHooks
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._generate_config import GenerateConfig
+from .._model import ModelAPI
+from .._model_call import ModelCall
+from .._model_output import ChatCompletionChoice, ModelOutput
+from .._openai import (
+    OpenAIAsyncHttpxClient,
+    model_output_from_openai,
+    openai_chat_messages,
+    openai_chat_tool_choice,
+    openai_chat_tools,
+    openai_completion_params,
+    openai_handle_bad_request,
+    openai_media_filter,
+    openai_should_retry,
+)
+from .util import environment_prerequisite_error, model_base_url
+
+logger = getLogger(__name__)
+
+
+class OpenAICompatibleAPI(ModelAPI):
+    def __init__(
+        self,
+        model_name: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        config: GenerateConfig = GenerateConfig(),
+        service: str | None = None,
+        service_base_url: str | None = None,
+        **model_args: Any,
+    ) -> None:
+        # extract service prefix from model name if not specified
+        if service is None:
+            parts = model_name.split("/")
+            if len(parts) == 1:
+                raise ValueError(
+                    "openai-api model names must include a service prefix (e.g. 'openai-api/service/model')"
+                )
+            self.service = parts[0]
+        else:
+            self.service = service
+
+        # compute api key
+        api_key_var = f"{self.service.upper()}_API_KEY"
+
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            api_key_vars=[api_key_var],
+            config=config,
+        )
+
+        # use service prefix to lookup api_key
+        if not self.api_key:
+            self.api_key = os.environ.get(api_key_var, None)
+            if not self.api_key:
+                raise environment_prerequisite_error(
+                    self.service,
+                    [api_key_var],
+                )
+
+        # use service prefix to lookup base_url
+        if not self.base_url:
+            base_url_var = f"{self.service.upper()}_BASE_URL"
+            self.base_url = model_base_url(base_url, [base_url_var]) or service_base_url
+            if not self.base_url:
+                raise environment_prerequisite_error(
+                    self.service,
+                    [base_url_var],
+                )
+
+        # create async http client
+        http_client = OpenAIAsyncHttpxClient()
+        self.client = AsyncOpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            http_client=http_client,
+            **model_args,
+        )
+
+        # create time tracker
+        self._http_hooks = HttpxHooks(self.client._client)
+
+    @override
+    async def aclose(self) -> None:
+        await self.client.close()
+
+    async def generate(
+        self,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # allocate request_id (so we can see it from ModelCall)
+        request_id = self._http_hooks.start_request()
+
+        # setup request and response for ModelCall
+        request: dict[str, Any] = {}
+        response: dict[str, Any] = {}
+
+        def model_call() -> ModelCall:
+            return ModelCall.create(
+                request=request,
+                response=response,
+                filter=openai_media_filter,
+                time=self._http_hooks.end_request(request_id),
+            )
+
+        # get completion params (slice off service from model name)
+        completion_params = self.completion_params(
+            config=config,
+            tools=len(tools) > 0,
+        )
+
+        # prepare request (we do this so we can log the ModelCall)
+        request = dict(
+            messages=await openai_chat_messages(input),
+            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+            tool_choice=openai_chat_tool_choice(tool_choice)
+            if len(tools) > 0
+            else NOT_GIVEN,
+            extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
+            **completion_params,
+        )
+
+        try:
+            # generate completion and save response for model call
+            completion: ChatCompletion = await self.client.chat.completions.create(
+                **request
+            )
+            response = completion.model_dump()
+            self.on_response(response)
+
+            # return output and call
+            choices = self.chat_choices_from_completion(completion, tools)
+            return model_output_from_openai(completion, choices), model_call()
+
+        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
+            return self.handle_bad_request(ex), model_call()
+
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
+    @override
+    def should_retry(self, ex: Exception) -> bool:
+        return openai_should_retry(ex)
+
+    @override
+    def connection_key(self) -> str:
+        """Scope for enforcing max_connections (could also use endpoint)."""
+        return str(self.api_key)
+
+    def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
+        return openai_completion_params(
+            model=self.service_model_name(),
+            config=config,
+            tools=tools,
+        )
+
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Hook for subclasses to do custom response handling."""
+        pass
+
+    def chat_choices_from_completion(
+        self, completion: ChatCompletion, tools: list[ToolInfo]
+    ) -> list[ChatCompletionChoice]:
+        """Hook for subclasses to do custom chat choice processing."""
+        return chat_choices_from_openai(completion, tools)
+
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        """Hook for subclasses to do bad request handling"""
+        return openai_handle_bad_request(self.service_model_name(), ex)
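Per the NOTE added in openai.py, new OpenAI-compatible providers should now inherit from this class rather than OpenAIAPI. A minimal sketch of such a subclass (the ExampleAPI name, service name, and URL are hypothetical; the pattern follows the OpenRouterAPI change below):

class ExampleAPI(OpenAICompatibleAPI):
    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
        **model_args: Any,
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
            # service name drives env var lookup: EXAMPLE_API_KEY / EXAMPLE_BASE_URL
            service="Example",
            # fallback endpoint used when EXAMPLE_BASE_URL is not set
            service_base_url="https://api.example.com/v1",
            **model_args,
        )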
inspect_ai/model/_providers/openrouter.py
@@ -1,16 +1,13 @@
 import json
-import os
 from typing import Any, TypedDict

 from typing_extensions import NotRequired, override

 from inspect_ai._util.error import PrerequisiteError
 from inspect_ai.model._openai import OpenAIResponseError
-from inspect_ai.model._providers.util import model_base_url
-from inspect_ai.model._providers.util.util import environment_prerequisite_error

 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI

 OPENROUTER_API_KEY = "OPENROUTER_API_KEY"

@@ -37,7 +34,7 @@ class OpenRouterError(Exception):
         )


-class OpenRouterAPI(OpenAIAPI):
+class OpenRouterAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -46,16 +43,6 @@ class OpenRouterAPI(OpenAIAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ) -> None:
-        # api_key
-        if not api_key:
-            api_key = os.environ.get(OPENROUTER_API_KEY, None)
-            if not api_key:
-                raise environment_prerequisite_error("OpenRouter", OPENROUTER_API_KEY)
-
-        # base_url
-        base_url = model_base_url(base_url, "OPENROUTER_BASE_URL")
-        base_url = base_url if base_url else "https://openrouter.ai/api/v1"
-
         # collect known model args that we forward to generate
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
@@ -88,6 +75,8 @@ class OpenRouterAPI(OpenAIAPI):
             base_url=base_url,
             api_key=api_key,
             config=config,
+            service="OpenRouter",
+            service_base_url="https://openrouter.ai/api/v1",
             **model_args,
         )

inspect_ai/model/_providers/providers.py
@@ -44,6 +44,17 @@ def openai() -> type[ModelAPI]:
     return OpenAIAPI


+@modelapi(name="openai-api")
+def openai_api() -> type[ModelAPI]:
+    # validate
+    validate_openai_client("OpenAI Compatible API")
+
+    # in the clear
+    from .openai_compatible import OpenAICompatibleAPI
+
+    return OpenAICompatibleAPI
+
+
 @modelapi(name="anthropic")
 def anthropic() -> type[ModelAPI]:
     FEATURE = "Anthropic API"
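With the openai-api registration in place, any OpenAI-compatible endpoint can be targeted without writing a dedicated provider class. A usage sketch (the myservice name, key, and URL are placeholders; the env var names follow the <SERVICE>_API_KEY / <SERVICE>_BASE_URL convention in OpenAICompatibleAPI above):

import os
from inspect_ai.model import get_model

# credentials and endpoint are resolved from <SERVICE>_API_KEY / <SERVICE>_BASE_URL
os.environ["MYSERVICE_API_KEY"] = "sk-placeholder"
os.environ["MYSERVICE_BASE_URL"] = "https://api.myservice.example/v1"

# model names take the form openai-api/<service>/<model>
model = get_model("openai-api/myservice/my-model")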