inspect-ai 0.3.88__py3-none-any.whl → 0.3.89__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (86)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_cli/score.py +1 -12
  3. inspect_ai/_cli/util.py +4 -2
  4. inspect_ai/_display/core/footer.py +2 -2
  5. inspect_ai/_display/plain/display.py +2 -2
  6. inspect_ai/_eval/context.py +7 -1
  7. inspect_ai/_eval/eval.py +51 -27
  8. inspect_ai/_eval/evalset.py +27 -10
  9. inspect_ai/_eval/loader.py +7 -8
  10. inspect_ai/_eval/run.py +23 -31
  11. inspect_ai/_eval/score.py +18 -1
  12. inspect_ai/_eval/task/log.py +5 -13
  13. inspect_ai/_eval/task/resolved.py +1 -0
  14. inspect_ai/_eval/task/run.py +231 -244
  15. inspect_ai/_eval/task/task.py +25 -2
  16. inspect_ai/_eval/task/util.py +1 -8
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/json.py +8 -3
  19. inspect_ai/_util/registry.py +30 -13
  20. inspect_ai/_view/www/App.css +5 -0
  21. inspect_ai/_view/www/dist/assets/index.css +55 -18
  22. inspect_ai/_view/www/dist/assets/index.js +550 -458
  23. inspect_ai/_view/www/log-schema.json +66 -0
  24. inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
  25. inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
  26. inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
  27. inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
  28. inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
  29. inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
  30. inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
  31. inspect_ai/_view/www/src/types/log.d.ts +24 -6
  32. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
  33. inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
  34. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
  35. inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
  36. inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
  37. inspect_ai/agent/_agent.py +12 -0
  38. inspect_ai/agent/_as_tool.py +1 -1
  39. inspect_ai/agent/_bridge/bridge.py +9 -2
  40. inspect_ai/agent/_react.py +142 -74
  41. inspect_ai/agent/_run.py +13 -2
  42. inspect_ai/agent/_types.py +6 -0
  43. inspect_ai/approval/_apply.py +6 -7
  44. inspect_ai/approval/_approver.py +3 -3
  45. inspect_ai/approval/_auto.py +2 -2
  46. inspect_ai/approval/_call.py +20 -4
  47. inspect_ai/approval/_human/approver.py +3 -3
  48. inspect_ai/approval/_human/manager.py +2 -2
  49. inspect_ai/approval/_human/panel.py +3 -3
  50. inspect_ai/approval/_policy.py +3 -3
  51. inspect_ai/log/__init__.py +2 -0
  52. inspect_ai/log/_log.py +23 -2
  53. inspect_ai/log/_model.py +58 -0
  54. inspect_ai/log/_recorders/file.py +14 -3
  55. inspect_ai/log/_transcript.py +3 -0
  56. inspect_ai/model/__init__.py +2 -0
  57. inspect_ai/model/_call_tools.py +4 -1
  58. inspect_ai/model/_model.py +49 -3
  59. inspect_ai/model/_openai.py +151 -21
  60. inspect_ai/model/_providers/anthropic.py +20 -12
  61. inspect_ai/model/_providers/bedrock.py +3 -3
  62. inspect_ai/model/_providers/cloudflare.py +29 -108
  63. inspect_ai/model/_providers/google.py +21 -10
  64. inspect_ai/model/_providers/grok.py +23 -17
  65. inspect_ai/model/_providers/groq.py +61 -37
  66. inspect_ai/model/_providers/llama_cpp_python.py +8 -9
  67. inspect_ai/model/_providers/mistral.py +8 -3
  68. inspect_ai/model/_providers/ollama.py +8 -9
  69. inspect_ai/model/_providers/openai.py +53 -157
  70. inspect_ai/model/_providers/openai_compatible.py +195 -0
  71. inspect_ai/model/_providers/openrouter.py +4 -15
  72. inspect_ai/model/_providers/providers.py +11 -0
  73. inspect_ai/model/_providers/together.py +25 -23
  74. inspect_ai/model/_trim.py +83 -0
  75. inspect_ai/solver/_plan.py +5 -3
  76. inspect_ai/tool/_tool_def.py +8 -2
  77. inspect_ai/util/__init__.py +3 -0
  78. inspect_ai/util/_concurrency.py +15 -2
  79. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
  80. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +84 -79
  81. inspect_ai/_eval/task/rundir.py +0 -78
  82. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  83. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
  84. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
  85. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
  86. {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/cloudflare.py

@@ -1,33 +1,24 @@
 import os
 from typing import Any
 
-import httpx
+from openai import APIStatusError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai.tool import ToolChoice, ToolInfo
+from inspect_ai.model._model_output import ModelOutput
+from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
 
-from ...model import ChatMessage, GenerateConfig, ModelAPI, ModelOutput
-from .._model_call import ModelCall
-from .._model_output import ChatCompletionChoice
-from .util import (
-    ChatAPIHandler,
-    Llama31Handler,
-    chat_api_input,
-    chat_api_request,
-    environment_prerequisite_error,
-    model_base_url,
-    should_retry_chat_api_error,
-)
-from .util.hooks import HttpxHooks
+from ...model import GenerateConfig
+from .util import environment_prerequisite_error
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
+# https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/
 
-
+CLOUDFLARE_API_KEY = "CLOUDFLARE_API_KEY"
 CLOUDFLARE_API_TOKEN = "CLOUDFLARE_API_TOKEN"
 
 
-class CloudFlareAPI(ModelAPI):
+class CloudFlareAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -36,98 +27,34 @@ class CloudFlareAPI(ModelAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
+        # migrate formerly used CLOUDFLARE_API_TOKEN if no other key is specified
+        if api_key is None and CLOUDFLARE_API_KEY not in os.environ:
+            api_key = os.environ.get(CLOUDFLARE_API_TOKEN, None)
+
+        # account id used for limits and forming base url
+        self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID", None)
+        if not self.account_id:
+            raise environment_prerequisite_error("CloudFlare", "CLOUDFLARE_ACCOUNT_ID")
+
         super().__init__(
-            model_name=model_name,
+            model_name=f"@cf/{model_name}",
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[CLOUDFLARE_API_TOKEN],
             config=config,
+            service="CloudFlare",
+            service_base_url=f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/v1",
+            **model_args,
         )
-        self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
-        if not self.account_id:
-            raise environment_prerequisite_error("CloudFlare", "CLOUDFLARE_ACCOUNT_ID")
-        if not self.api_key:
-            self.api_key = os.getenv(CLOUDFLARE_API_TOKEN)
-            if not self.api_key:
-                raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
-        self.client = httpx.AsyncClient()
-        self._http_hooks = HttpxHooks(self.client)
-        base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
-        self.base_url = (
-            base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
-        )
-        self.model_args = model_args
 
     @override
-    async def aclose(self) -> None:
-        await self.client.aclose()
-
-    async def generate(
-        self,
-        input: list[ChatMessage],
-        tools: list[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-    ) -> tuple[ModelOutput, ModelCall]:
-        # chat url
-        chat_url = f"{self.base_url}/{self.account_id}/ai/run/@cf"
-
-        # chat api input
-        json: dict[str, Any] = dict(**self.model_args)
-        if config.max_tokens is not None:
-            json["max_tokens"] = config.max_tokens
-        json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
-
-        # request_id
-        request_id = self._http_hooks.start_request()
-
-        # setup response
-        response: dict[str, Any] = {}
-
-        def model_call() -> ModelCall:
-            return ModelCall.create(
-                request=json,
-                response=response,
-                time=self._http_hooks.end_request(request_id),
-            )
-
-        # make the call
-        response = await chat_api_request(
-            self.client,
-            model_name=self.model_name,
-            url=f"{chat_url}/{self.model_name}",
-            headers={
-                "Authorization": f"Bearer {self.api_key}",
-                HttpxHooks.REQUEST_ID_HEADER: request_id,
-            },
-            json=json,
-        )
-
-        # handle response
-        if response["success"]:
-            # extract output
-            content = response["result"]["response"]
-            output = ModelOutput(
-                model=self.model_name,
-                choices=[
-                    ChatCompletionChoice(
-                        message=self.chat_api_handler().parse_assistant_response(
-                            content, tools
-                        ),
-                        stop_reason="stop",
-                    )
-                ],
-            )
-
-            # return
-            return output, model_call()
-        else:
-            error = str(response.get("errors", "Unknown"))
-            raise RuntimeError(f"Error calling {self.model_name}: {error}")
-
-    @override
-    def should_retry(self, ex: Exception) -> bool:
-        return should_retry_chat_api_error(ex)
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 403:
+            content = str(ex)
+            if "context window limit" in content:
+                return ModelOutput.from_content(
+                    self.model_name, content=content, stop_reason="model_length"
+                )
+        return ex
 
     # cloudflare enforces rate limits by model for each account
     @override
@@ -138,9 +65,3 @@ class CloudFlareAPI(ModelAPI):
     @override
     def max_tokens(self) -> int:
         return DEFAULT_MAX_TOKENS
-
-    def chat_api_handler(self) -> ChatAPIHandler:
-        if "llama" in self.model_name.lower():
-            return Llama31Handler(self.model_name)
-        else:
-            return ChatAPIHandler(self.model_name)
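
The rewrite removes the hand-rolled httpx transport entirely: the provider now talks to Cloudflare's OpenAI-compatible endpoint through the shared base class, prefixes model names with @cf/, and maps 403 "context window limit" errors onto stop_reason="model_length". The remaining Cloudflare-specific wrinkle is key migration. Below is a minimal standalone sketch of the resulting key-resolution order; it assumes the OpenAICompatibleAPI base class falls back to the CLOUDFLARE_API_KEY environment variable when no explicit key is passed, which this diff does not show:

import os

def resolve_cloudflare_key(api_key: str | None) -> str | None:
    # an explicitly passed key always wins
    if api_key is not None:
        return api_key
    # otherwise the new CLOUDFLARE_API_KEY variable takes precedence
    # (assumed: resolved by the OpenAICompatibleAPI base class)
    if "CLOUDFLARE_API_KEY" in os.environ:
        return os.environ["CLOUDFLARE_API_KEY"]
    # finally, migrate the legacy CLOUDFLARE_API_TOKEN if present
    return os.environ.get("CLOUDFLARE_API_TOKEN")
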
inspect_ai/model/_providers/google.py

@@ -127,7 +127,6 @@ class GoogleGenAIAPI(ModelAPI):
         parts = model_name.split("/")
         if len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None
 
@@ -245,14 +244,14 @@
 
         try:
             response = await client.aio.models.generate_content(
-                model=self.model_name,
+                model=self.service_model_name(),
                 contents=gemini_contents,
                 config=parameters,
            )
        except ClientError as ex:
            return self.handle_client_error(ex), model_call()
 
-        model_name = response.model_version or self.model_name
+        model_name = response.model_version or self.service_model_name()
         output = ModelOutput(
             model=model_name,
             choices=completion_choices_from_candidates(model_name, response),
@@ -261,6 +260,10 @@
 
         return output, model_call()
 
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, APIError) and ex.code is not None:
@@ -270,8 +273,8 @@
 
     @override
     def connection_key(self) -> str:
-        """Scope for enforcing max_connections (could also use endpoint)."""
-        return self.model_name
+        """Scope for enforcing max_connections."""
+        return str(self.api_key)
 
     def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
         if (
@@ -283,7 +286,9 @@
             )
         ):
             return ModelOutput.from_content(
-                self.model_name, content=ex.message, stop_reason="model_length"
+                self.service_model_name(),
+                content=ex.message,
+                stop_reason="model_length",
             )
         else:
             raise ex
@@ -644,10 +649,16 @@ def completion_choices_from_candidates(
             )
         ]
     else:
-        raise RuntimeError(
-            "Google response includes no completion candidates and no block reason: "
-            + f"{response.model_dump_json(indent=2)}"
-        )
+        return [
+            ChatCompletionChoice(
+                message=ChatMessageAssistant(
+                    content=NO_CONTENT,
+                    model=model,
+                    source="generate",
+                ),
+                stop_reason="stop",
+            )
+        ]
 
 
 def split_reasoning(content: str) -> tuple[str | None, str]:
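
Note that __init__ no longer strips the service prefix from model_name, so self.model_name keeps the full prefixed form; the new service_model_name() helper removes the prefix only where the bare model id is needed (the Mistral provider below gains the identical helper). A short sketch of the behavior, using a hypothetical prefixed name:

# hypothetical prefixed model name
model_name = "vertex/gemini-2.0-flash"

parts = model_name.split("/")
service = parts[0] if len(parts) > 1 else None  # -> "vertex"

# the prefix is stripped on demand, not at construction time
bare = model_name.replace(f"{service}/", "", 1)
assert bare == "gemini-2.0-flash"
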
inspect_ai/model/_providers/grok.py

@@ -1,15 +1,12 @@
-import os
+from openai import APIStatusError
 
-from inspect_ai.model._providers.util import model_base_url
-from inspect_ai.model._providers.util.util import environment_prerequisite_error
+from inspect_ai.model._model_output import ModelOutput
 
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
-GROK_API_KEY = "GROK_API_KEY"
 
-
-class GrokAPI(OpenAIAPI):
+class GrokAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -17,19 +14,28 @@ class GrokAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        # resolve base url
-        base_url = model_base_url(base_url, "GROK_BASE_URL")
-        base_url = base_url or "https://api.x.ai/v1"
-
-        # resolve api key
-        api_key = api_key or os.environ.get(GROK_API_KEY, None)
-        if api_key is None:
-            raise environment_prerequisite_error("Grok", GROK_API_KEY)
-
-        # call super
         super().__init__(
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
             config=config,
+            service="Grok",
+            service_base_url="https://api.x.ai/v1",
         )
+
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # extract message
+            if isinstance(ex.body, dict) and "message" in ex.body.keys():
+                content = str(ex.body.get("message"))
+            else:
+                content = ex.message
+
+            if "prompt length" in content:
+                return ModelOutput.from_content(
+                    model=self.model_name, content=content, stop_reason="model_length"
                )
+            else:
+                return ex
+        else:
+            return ex
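
handle_bad_request() converts xAI 400 responses whose message mentions "prompt length" into a ModelOutput with stop_reason="model_length"; anything else is handed back as the original exception. A standalone sketch of that classification (the error bodies are hypothetical, shaped like the dict the handler checks for):

def classify_grok_400(body: object, message: str) -> str:
    # mirror of the message extraction in handle_bad_request() above
    if isinstance(body, dict) and "message" in body:
        content = str(body.get("message"))
    else:
        content = message
    return "model_length" if "prompt length" in content else "reraise"

assert classify_grok_400({"message": "maximum prompt length exceeded"}, "") == "model_length"
assert classify_grok_400({"message": "invalid tool choice"}, "") == "reraise"
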
inspect_ai/model/_providers/groq.py

@@ -102,7 +102,7 @@ class GroqAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> tuple[ModelOutput, ModelCall]:
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
         request_id = self._http_hooks.start_request()
 
@@ -136,45 +136,48 @@
             **params,
         )
 
-        completion: ChatCompletion = await self.client.chat.completions.create(
-            **request,
-        )
+        try:
+            completion: ChatCompletion = await self.client.chat.completions.create(
+                **request,
+            )
 
-        response = completion.model_dump()
-
-        # extract metadata
-        metadata: dict[str, Any] = {
-            "id": completion.id,
-            "system_fingerprint": completion.system_fingerprint,
-            "created": completion.created,
-        }
-        if completion.usage:
-            metadata = metadata | {
-                "queue_time": completion.usage.queue_time,
-                "prompt_time": completion.usage.prompt_time,
-                "completion_time": completion.usage.completion_time,
-                "total_time": completion.usage.total_time,
-            }
+            response = completion.model_dump()
 
-        # extract output
-        choices = self._chat_choices_from_response(completion, tools)
-        output = ModelOutput(
-            model=completion.model,
-            choices=choices,
-            usage=(
-                ModelUsage(
-                    input_tokens=completion.usage.prompt_tokens,
-                    output_tokens=completion.usage.completion_tokens,
-                    total_tokens=completion.usage.total_tokens,
-                )
-                if completion.usage
-                else None
-            ),
-            metadata=metadata,
-        )
+            # extract metadata
+            metadata: dict[str, Any] = {
+                "id": completion.id,
+                "system_fingerprint": completion.system_fingerprint,
+                "created": completion.created,
+            }
+            if completion.usage:
+                metadata = metadata | {
+                    "queue_time": completion.usage.queue_time,
+                    "prompt_time": completion.usage.prompt_time,
+                    "completion_time": completion.usage.completion_time,
+                    "total_time": completion.usage.total_time,
+                }
+
+            # extract output
+            choices = self._chat_choices_from_response(completion, tools)
+            output = ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=(
+                    ModelUsage(
+                        input_tokens=completion.usage.prompt_tokens,
+                        output_tokens=completion.usage.completion_tokens,
+                        total_tokens=completion.usage.total_tokens,
+                    )
+                    if completion.usage
+                    else None
+                ),
+                metadata=metadata,
+            )
 
-        # return
-        return output, model_call()
+            # return
+            return output, model_call()
+        except APIStatusError as ex:
+            return self.handle_bad_request(ex), model_call()
 
     def completion_params(self, config: GenerateConfig) -> Dict[str, Any]:
         params: dict[str, Any] = {}
@@ -234,6 +237,27 @@
     def max_tokens(self) -> Optional[int]:
         return DEFAULT_MAX_TOKENS
 
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # extract code and message
+            content = ex.message
+            code = ""
+            if isinstance(ex.body, dict) and isinstance(
+                ex.body.get("error", None), dict
+            ):
+                error = ex.body.get("error", {})
+                content = str(error.get("message", content))
+                code = error.get("code", code)
+
+            if code == "context_length_exceeded":
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=content,
+                    stop_reason="model_length",
+                )
+
+        return ex
+
 
 async def as_groq_chat_messages(
     messages: list[ChatMessage],
inspect_ai/model/_providers/llama_cpp_python.py

@@ -1,10 +1,8 @@
-from inspect_ai.model._providers.util import model_base_url
-
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
 
-class LlamaCppPythonAPI(OpenAIAPI):
+class LlamaCppPythonAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -12,10 +10,11 @@ class LlamaCppPythonAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        base_url = model_base_url(base_url, "LLAMA_CPP_PYTHON_BASE_URL")
-        base_url = base_url if base_url else "http://localhost:8000/v1"
-        if not api_key:
-            api_key = "llama-cpp-python"
         super().__init__(
-            model_name=model_name, base_url=base_url, api_key=api_key, config=config
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key or "llama-cpp-python",
+            config=config,
+            service="llama_cpp_python",
+            service_base_url="http://localhost:8000/v1",
         )
inspect_ai/model/_providers/mistral.py

@@ -86,7 +86,6 @@ class MistralAPI(ModelAPI):
         parts = model_name.split("/")
         if len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None
 
@@ -150,7 +149,7 @@
         # build request
         request_id = http_hooks.start_request()
         request: dict[str, Any] = dict(
-            model=self.model_name,
+            model=self.service_model_name(),
            messages=await mistral_chat_messages(input),
            tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
            tool_choice=(
@@ -228,6 +227,10 @@
             ),
         ), model_call()
 
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, SDKError):
@@ -246,7 +249,9 @@
             content = body.get("message", ex.body)
             if "maximum context length" in ex.body:
                 return ModelOutput.from_content(
-                    model=self.model_name, content=content, stop_reason="model_length"
+                    model=self.service_model_name(),
+                    content=content,
+                    stop_reason="model_length",
                 )
             else:
                 return ex
inspect_ai/model/_providers/ollama.py

@@ -1,10 +1,8 @@
-from inspect_ai.model._providers.util import model_base_url
-
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
 
-class OllamaAPI(OpenAIAPI):
+class OllamaAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,
@@ -12,10 +10,11 @@ class OllamaAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        base_url = model_base_url(base_url, "OLLAMA_BASE_URL")
-        base_url = base_url if base_url else "http://localhost:11434/v1"
-        if not api_key:
-            api_key = "ollama"
         super().__init__(
-            model_name=model_name, base_url=base_url, api_key=api_key, config=config
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key or "ollama",
+            config=config,
+            service="Ollama",
+            service_base_url="http://localhost:11434/v1",
         )
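
Taken together, the grok.py, llama_cpp_python.py, ollama.py, and cloudflare.py changes reduce each provider to a thin subclass of the new OpenAICompatibleAPI (openai_compatible.py, +195 lines, not shown in this excerpt). A sketch of a provider following the same pattern, using only the constructor parameters and the handle_bad_request() hook visible in the diffs above; the service name, base URL, and error marker are hypothetical, and how the base class derives environment variable names from service is not shown in these diffs:

from openai import APIStatusError

from inspect_ai.model import GenerateConfig, ModelOutput
from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI


class ExampleAPI(OpenAICompatibleAPI):  # hypothetical provider
    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="Example",  # hypothetical
            service_base_url="https://api.example.com/v1",  # hypothetical
        )

    # optional hook: map provider-specific 4xx errors onto ModelOutput
    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
        if ex.status_code == 400 and "context length" in str(ex):  # hypothetical marker
            return ModelOutput.from_content(
                model=self.model_name, content=str(ex), stop_reason="model_length"
            )
        return ex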