inspect-ai 0.3.88__py3-none-any.whl → 0.3.89__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -244
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +55 -18
- inspect_ai/_view/www/dist/assets/index.js +550 -458
- inspect_ai/_view/www/log-schema.json +66 -0
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +24 -6
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -7
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +4 -1
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +20 -12
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/RECORD +84 -79
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.89.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/cloudflare.py

@@ -1,33 +1,24 @@
 import os
 from typing import Any
 
-import httpx
+from openai import APIStatusError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai.
+from inspect_ai.model._model_output import ModelOutput
+from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI
 
-from ...model import
-from
-from .._model_output import ChatCompletionChoice
-from .util import (
-    ChatAPIHandler,
-    Llama31Handler,
-    chat_api_input,
-    chat_api_request,
-    environment_prerequisite_error,
-    model_base_url,
-    should_retry_chat_api_error,
-)
-from .util.hooks import HttpxHooks
+from ...model import GenerateConfig
+from .util import environment_prerequisite_error
 
 # https://developers.cloudflare.com/workers-ai/models/#text-generation
+# https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility/
 
-
+CLOUDFLARE_API_KEY = "CLOUDFLARE_API_KEY"
 CLOUDFLARE_API_TOKEN = "CLOUDFLARE_API_TOKEN"
 
 
-class CloudFlareAPI(ModelAPI):
+class CloudFlareAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,

@@ -36,98 +27,34 @@ class CloudFlareAPI(ModelAPI):
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
+        # migrate formerly used CLOUDFLARE_API_TOKEN if no other key is specified
+        if api_key is None and CLOUDFLARE_API_KEY not in os.environ:
+            api_key = os.environ.get(CLOUDFLARE_API_TOKEN, None)
+
+        # account id used for limits and forming base url
+        self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID", None)
+        if not self.account_id:
+            raise environment_prerequisite_error("CloudFlare", "CLOUDFLARE_ACCOUNT_ID")
+
         super().__init__(
-            model_name=model_name,
+            model_name=f"@cf/{model_name}",
             base_url=base_url,
             api_key=api_key,
-            api_key_vars=[CLOUDFLARE_API_TOKEN],
             config=config,
+            service="CloudFlare",
+            service_base_url=f"https://api.cloudflare.com/client/v4/accounts/{self.account_id}/ai/v1",
+            **model_args,
         )
-        self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
-        if not self.account_id:
-            raise environment_prerequisite_error("CloudFlare", "CLOUDFLARE_ACCOUNT_ID")
-        if not self.api_key:
-            self.api_key = os.getenv(CLOUDFLARE_API_TOKEN)
-            if not self.api_key:
-                raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN)
-        self.client = httpx.AsyncClient()
-        self._http_hooks = HttpxHooks(self.client)
-        base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
-        self.base_url = (
-            base_url if base_url else "https://api.cloudflare.com/client/v4/accounts"
-        )
-        self.model_args = model_args
 
     @override
-    async def generate(
-        self,
-        input: list[ChatMessage],
-        tools: list[ToolInfo],
-        tool_choice: ToolChoice,
-        config: GenerateConfig,
-    ) -> tuple[ModelOutput, ModelCall]:
-        # chat url
-        chat_url = f"{self.base_url}/{self.account_id}/ai/run/@cf"
-
-        # chat api input
-        json: dict[str, Any] = dict(**self.model_args)
-        if config.max_tokens is not None:
-            json["max_tokens"] = config.max_tokens
-        json["messages"] = chat_api_input(input, tools, self.chat_api_handler())
-
-        # request_id
-        request_id = self._http_hooks.start_request()
-
-        # setup response
-        response: dict[str, Any] = {}
-
-        def model_call() -> ModelCall:
-            return ModelCall.create(
-                request=json,
-                response=response,
-                time=self._http_hooks.end_request(request_id),
-            )
-
-        # make the call
-        response = await chat_api_request(
-            self.client,
-            model_name=self.model_name,
-            url=f"{chat_url}/{self.model_name}",
-            headers={
-                "Authorization": f"Bearer {self.api_key}",
-                HttpxHooks.REQUEST_ID_HEADER: request_id,
-            },
-            json=json,
-        )
-
-        # handle response
-        if response["success"]:
-            # extract output
-            content = response["result"]["response"]
-            output = ModelOutput(
-                model=self.model_name,
-                choices=[
-                    ChatCompletionChoice(
-                        message=self.chat_api_handler().parse_assistant_response(
-                            content, tools
-                        ),
-                        stop_reason="stop",
-                    )
-                ],
-            )
-
-            # return
-            return output, model_call()
-        else:
-            error = str(response.get("errors", "Unknown"))
-            raise RuntimeError(f"Error calling {self.model_name}: {error}")
-
-    @override
-    def should_retry(self, ex: Exception) -> bool:
-        return should_retry_chat_api_error(ex)
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 403:
+            content = str(ex)
+            if "context window limit" in content:
+                return ModelOutput.from_content(
+                    self.model_name, content=content, stop_reason="model_length"
+                )
+        return ex
 
     # cloudflare enforces rate limits by model for each account
     @override

@@ -138,9 +65,3 @@ class CloudFlareAPI(ModelAPI):
     @override
     def max_tokens(self) -> int:
         return DEFAULT_MAX_TOKENS
-
-    def chat_api_handler(self) -> ChatAPIHandler:
-        if "llama" in self.model_name.lower():
-            return Llama31Handler(self.model_name)
-        else:
-            return ChatAPIHandler(self.model_name)
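The CloudFlare rewrite above is the template for the rest of this release's provider changes: the hand-rolled httpx request loop is deleted and request handling is inherited from the new OpenAICompatibleAPI base class (the new file inspect_ai/model/_providers/openai_compatible.py, +195 in the list above). A minimal sketch of the pattern, using only the constructor arguments and the handle_bad_request() override visible in these hunks; the service name, base URL, and error substring below are hypothetical:

from openai import APIStatusError

from inspect_ai.model import GenerateConfig
from inspect_ai.model._model_output import ModelOutput
from inspect_ai.model._providers.openai_compatible import OpenAICompatibleAPI


class ExampleAPI(OpenAICompatibleAPI):
    def __init__(
        self,
        model_name: str,
        base_url: str | None = None,
        api_key: str | None = None,
        config: GenerateConfig = GenerateConfig(),
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            api_key=api_key,
            config=config,
            service="Example",  # hypothetical service name
            service_base_url="https://api.example.invalid/v1",  # hypothetical default
        )

    # translate this provider's context-overflow error into a normal
    # ModelOutput with stop_reason="model_length" (mirrors CloudFlareAPI above)
    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
        content = str(ex)
        if ex.status_code == 400 and "context window" in content:  # hypothetical marker
            return ModelOutput.from_content(
                self.model_name, content=content, stop_reason="model_length"
            )
        return ex

The division of labor suggested by these diffs: subclasses keep credentials, default endpoints, and error mapping, while the base class owns the OpenAI-protocol request cycle.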
inspect_ai/model/_providers/google.py

@@ -127,7 +127,6 @@ class GoogleGenAIAPI(ModelAPI):
         parts = model_name.split("/")
         if len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None
 

@@ -245,14 +244,14 @@ class GoogleGenAIAPI(ModelAPI):
 
         try:
             response = await client.aio.models.generate_content(
-                model=self.model_name,
+                model=self.service_model_name(),
                 contents=gemini_contents,
                 config=parameters,
             )
         except ClientError as ex:
             return self.handle_client_error(ex), model_call()
 
-        model_name = response.model_version or self.model_name
+        model_name = response.model_version or self.service_model_name()
         output = ModelOutput(
             model=model_name,
             choices=completion_choices_from_candidates(model_name, response),

@@ -261,6 +260,10 @@ class GoogleGenAIAPI(ModelAPI):
 
         return output, model_call()
 
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, APIError) and ex.code is not None:

@@ -270,8 +273,8 @@ class GoogleGenAIAPI(ModelAPI):
 
     @override
     def connection_key(self) -> str:
-        """Scope for enforcing max_connections
-        return self.
+        """Scope for enforcing max_connections."""
+        return str(self.api_key)
 
     def handle_client_error(self, ex: ClientError) -> ModelOutput | Exception:
         if (

@@ -283,7 +286,9 @@ class GoogleGenAIAPI(ModelAPI):
             )
         ):
             return ModelOutput.from_content(
-                self.model_name, content=ex.message, stop_reason="model_length"
+                self.service_model_name(),
+                content=ex.message,
+                stop_reason="model_length",
             )
         else:
             raise ex

@@ -644,10 +649,16 @@ def completion_choices_from_candidates(
            )
        ]
    else:
-
-
-
-
+        return [
+            ChatCompletionChoice(
+                message=ChatMessageAssistant(
+                    content=NO_CONTENT,
+                    model=model,
+                    source="generate",
+                ),
+                stop_reason="stop",
+            )
+        ]
 
 
 def split_reasoning(content: str) -> tuple[str | None, str]:
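Google (and Mistral, below) previously stripped the service prefix off model_name at construction time (the deleted "/".join(parts[1:]) line); after this change the prefixed name is kept and the prefix is removed only at the API boundary via the new service_model_name() helper. Its behavior, restated as a standalone sketch (the "vertex" prefix is just an illustration):

def service_model_name(model_name: str, service: str | None) -> str:
    """Model name without any service prefix (mirrors the method added above)."""
    return model_name.replace(f"{service}/", "", 1)


# only the first occurrence of the prefix is removed, and an absent
# prefix leaves the name untouched
assert service_model_name("vertex/gemini-2.0-flash", "vertex") == "gemini-2.0-flash"
assert service_model_name("gemini-2.0-flash", None) == "gemini-2.0-flash"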
inspect_ai/model/_providers/grok.py

@@ -1,15 +1,12 @@
-import os
+from openai import APIStatusError
 
-from inspect_ai.model.
-from inspect_ai.model._providers.util.util import environment_prerequisite_error
+from inspect_ai.model._model_output import ModelOutput
 
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
-GROK_API_KEY = "GROK_API_KEY"
 
-
-class GrokAPI(OpenAIAPI):
+class GrokAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,

@@ -17,19 +14,28 @@ class GrokAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        # resolve base url
-        base_url = model_base_url(base_url, "GROK_BASE_URL")
-        base_url = base_url or "https://api.x.ai/v1"
-
-        # resolve api key
-        api_key = api_key or os.environ.get(GROK_API_KEY, None)
-        if api_key is None:
-            raise environment_prerequisite_error("Grok", GROK_API_KEY)
-
-        # call super
         super().__init__(
             model_name=model_name,
             base_url=base_url,
             api_key=api_key,
             config=config,
+            service="Grok",
+            service_base_url="https://api.x.ai/v1",
         )
+
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # extract message
+            if isinstance(ex.body, dict) and "message" in ex.body.keys():
+                content = str(ex.body.get("message"))
+            else:
+                content = ex.message
+
+            if "prompt length" in content:
+                return ModelOutput.from_content(
+                    model=self.model_name, content=content, stop_reason="model_length"
+                )
+            else:
+                return ex
+        else:
+            return ex
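The Grok override above converts exactly one failure mode, a 400 whose message mentions "prompt length", into a normal ModelOutput with stop_reason="model_length", and returns every other error unchanged. The same decision logic isolated as a pure function, with fabricated message strings:

def grok_stop_reason(status_code: int, body: object, message: str) -> str | None:
    """Return "model_length" for Grok's prompt-too-long 400s, else None."""
    if status_code != 400:
        return None
    # prefer the message field from the error body when present
    if isinstance(body, dict) and "message" in body:
        content = str(body.get("message"))
    else:
        content = message
    return "model_length" if "prompt length" in content else None


assert grok_stop_reason(400, {"message": "prompt length exceeds limit"}, "") == "model_length"
assert grok_stop_reason(400, None, "malformed request") is None
assert grok_stop_reason(429, None, "rate limited") is None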
inspect_ai/model/_providers/groq.py

@@ -102,7 +102,7 @@ class GroqAPI(ModelAPI):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> tuple[ModelOutput, ModelCall]:
+    ) -> tuple[ModelOutput | Exception, ModelCall]:
         # allocate request_id (so we can see it from ModelCall)
         request_id = self._http_hooks.start_request()
 

@@ -136,45 +136,48 @@ class GroqAPI(ModelAPI):
             **params,
         )
 
-        completion: ChatCompletion = await self.client.chat.completions.create(
-            **request,
-        )
+        try:
+            completion: ChatCompletion = await self.client.chat.completions.create(
+                **request,
+            )
 
-        response = completion.model_dump()
-
-        # extract metadata
-        metadata: dict[str, Any] = {
-            "id": completion.id,
-            "system_fingerprint": completion.system_fingerprint,
-            "created": completion.created,
-        }
-        if completion.usage:
-            metadata = metadata | {
-                "queue_time": completion.usage.queue_time,
-                "prompt_time": completion.usage.prompt_time,
-                "completion_time": completion.usage.completion_time,
-                "total_time": completion.usage.total_time,
-            }
+            response = completion.model_dump()
 
-        # extract output
-        choices = self._chat_choices_from_response(completion, tools)
-        output = ModelOutput(
-            model=completion.model,
-            choices=choices,
-            usage=(
-                ModelUsage(
-                    input_tokens=completion.usage.prompt_tokens,
-                    output_tokens=completion.usage.completion_tokens,
-                    total_tokens=completion.usage.total_tokens,
-                )
-                if completion.usage
-                else None
-            ),
-            metadata=metadata,
-        )
+            # extract metadata
+            metadata: dict[str, Any] = {
+                "id": completion.id,
+                "system_fingerprint": completion.system_fingerprint,
+                "created": completion.created,
+            }
+            if completion.usage:
+                metadata = metadata | {
+                    "queue_time": completion.usage.queue_time,
+                    "prompt_time": completion.usage.prompt_time,
+                    "completion_time": completion.usage.completion_time,
+                    "total_time": completion.usage.total_time,
+                }
+
+            # extract output
+            choices = self._chat_choices_from_response(completion, tools)
+            output = ModelOutput(
+                model=completion.model,
+                choices=choices,
+                usage=(
+                    ModelUsage(
+                        input_tokens=completion.usage.prompt_tokens,
+                        output_tokens=completion.usage.completion_tokens,
+                        total_tokens=completion.usage.total_tokens,
+                    )
+                    if completion.usage
+                    else None
+                ),
+                metadata=metadata,
+            )
 
-        # return
-        return output, model_call()
+            # return
+            return output, model_call()
+        except APIStatusError as ex:
+            return self.handle_bad_request(ex), model_call()
 
     def completion_params(self, config: GenerateConfig) -> Dict[str, Any]:
         params: dict[str, Any] = {}

@@ -234,6 +237,27 @@ class GroqAPI(ModelAPI):
     def max_tokens(self) -> Optional[int]:
         return DEFAULT_MAX_TOKENS
 
+    def handle_bad_request(self, ex: APIStatusError) -> ModelOutput | Exception:
+        if ex.status_code == 400:
+            # extract code and message
+            content = ex.message
+            code = ""
+            if isinstance(ex.body, dict) and isinstance(
+                ex.body.get("error", None), dict
+            ):
+                error = ex.body.get("error", {})
+                content = str(error.get("message", content))
+                code = error.get("code", code)
+
+            if code == "context_length_exceeded":
+                return ModelOutput.from_content(
+                    model=self.model_name,
+                    content=content,
+                    stop_reason="model_length",
+                )
+
+            return ex
+
 
 async def as_groq_chat_messages(
     messages: list[ChatMessage],
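Groq shows how these handlers get wired in: generate() now returns ModelOutput | Exception, and the completions.create() call is wrapped in try/except so an APIStatusError flows through handle_bad_request() instead of propagating. A self-contained sketch of that control flow; the URL and error body below are fabricated purely for illustration:

import httpx
from openai import APIStatusError


def make_status_error(status: int, body: dict) -> APIStatusError:
    # build the minimal httpx plumbing an APIStatusError requires
    request = httpx.Request("POST", "https://api.example.invalid/v1/chat/completions")
    response = httpx.Response(status, request=request)
    return APIStatusError("bad request", response=response, body=body)


def handle_bad_request(ex: APIStatusError) -> str | Exception:
    # condensed version of GroqAPI.handle_bad_request above
    if ex.status_code == 400 and isinstance(ex.body, dict):
        error = ex.body.get("error", {})
        if isinstance(error, dict) and error.get("code") == "context_length_exceeded":
            return "model_length"
    return ex


ex = make_status_error(400, {"error": {"code": "context_length_exceeded"}})
assert handle_bad_request(ex) == "model_length"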
inspect_ai/model/_providers/llama_cpp_python.py

@@ -1,10 +1,8 @@
-from inspect_ai.model._providers.util import model_base_url
-
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
 
-class LlamaCppPythonAPI(OpenAIAPI):
+class LlamaCppPythonAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,

@@ -12,10 +10,11 @@ class LlamaCppPythonAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        base_url = model_base_url(base_url, "LLAMA_CPP_PYTHON_BASE_URL")
-        base_url = base_url if base_url else "http://localhost:8000/v1"
-        if not api_key:
-            api_key = "llama-cpp-python"
         super().__init__(
-            model_name=model_name,
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key or "llama-cpp-python",
+            config=config,
+            service="llama_cpp_python",
+            service_base_url="http://localhost:8000/v1",
         )
inspect_ai/model/_providers/mistral.py

@@ -86,7 +86,6 @@ class MistralAPI(ModelAPI):
         parts = model_name.split("/")
         if len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None
 

@@ -150,7 +149,7 @@ class MistralAPI(ModelAPI):
         # build request
         request_id = http_hooks.start_request()
         request: dict[str, Any] = dict(
-            model=self.model_name,
+            model=self.service_model_name(),
             messages=await mistral_chat_messages(input),
             tools=mistral_chat_tools(tools) if len(tools) > 0 else None,
             tool_choice=(

@@ -228,6 +227,10 @@ class MistralAPI(ModelAPI):
             ),
         ), model_call()
 
+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, SDKError):

@@ -246,7 +249,9 @@ class MistralAPI(ModelAPI):
             content = body.get("message", ex.body)
             if "maximum context length" in ex.body:
                 return ModelOutput.from_content(
-                    model=self.model_name, content=content, stop_reason="model_length"
+                    model=self.service_model_name(),
+                    content=content,
+                    stop_reason="model_length",
                 )
             else:
                 return ex
inspect_ai/model/_providers/ollama.py

@@ -1,10 +1,8 @@
-from inspect_ai.model._providers.util import model_base_url
-
 from .._generate_config import GenerateConfig
-from .openai import OpenAIAPI
+from .openai_compatible import OpenAICompatibleAPI
 
 
-class OllamaAPI(OpenAIAPI):
+class OllamaAPI(OpenAICompatibleAPI):
     def __init__(
         self,
         model_name: str,

@@ -12,10 +10,11 @@ class OllamaAPI(OpenAIAPI):
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
     ) -> None:
-        base_url = model_base_url(base_url, "OLLAMA_BASE_URL")
-        base_url = base_url if base_url else "http://localhost:11434/v1"
-        if not api_key:
-            api_key = "ollama"
         super().__init__(
-            model_name=model_name,
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key or "ollama",
+            config=config,
+            service="Ollama",
+            service_base_url="http://localhost:11434/v1",
        )