inspect-ai 0.3.88__py3-none-any.whl → 0.3.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_cli/score.py +1 -12
- inspect_ai/_cli/util.py +4 -2
- inspect_ai/_display/core/footer.py +2 -2
- inspect_ai/_display/plain/display.py +2 -2
- inspect_ai/_eval/context.py +7 -1
- inspect_ai/_eval/eval.py +51 -27
- inspect_ai/_eval/evalset.py +27 -10
- inspect_ai/_eval/loader.py +7 -8
- inspect_ai/_eval/run.py +23 -31
- inspect_ai/_eval/score.py +18 -1
- inspect_ai/_eval/task/log.py +5 -13
- inspect_ai/_eval/task/resolved.py +1 -0
- inspect_ai/_eval/task/run.py +231 -256
- inspect_ai/_eval/task/task.py +25 -2
- inspect_ai/_eval/task/util.py +1 -8
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/json.py +8 -3
- inspect_ai/_util/registry.py +30 -13
- inspect_ai/_view/www/App.css +5 -0
- inspect_ai/_view/www/dist/assets/index.css +71 -36
- inspect_ai/_view/www/dist/assets/index.js +573 -475
- inspect_ai/_view/www/log-schema.json +66 -0
- inspect_ai/_view/www/src/metadata/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/metadata/MetaDataView.tsx +13 -8
- inspect_ai/_view/www/src/metadata/RenderedContent.tsx +3 -0
- inspect_ai/_view/www/src/plan/ModelCard.module.css +16 -0
- inspect_ai/_view/www/src/plan/ModelCard.tsx +93 -0
- inspect_ai/_view/www/src/samples/chat/ChatMessage.tsx +2 -2
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.module.css +2 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.tsx +5 -1
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +12 -6
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.module.css +0 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.tsx +6 -29
- inspect_ai/_view/www/src/types/log.d.ts +24 -6
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ModelRolesView.tsx +43 -0
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.module.css +1 -1
- inspect_ai/_view/www/src/workspace/navbar/PrimaryBar.tsx +5 -0
- inspect_ai/_view/www/src/workspace/tabs/InfoTab.tsx +2 -0
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_as_tool.py +1 -1
- inspect_ai/agent/_bridge/bridge.py +9 -2
- inspect_ai/agent/_react.py +142 -74
- inspect_ai/agent/_run.py +13 -2
- inspect_ai/agent/_types.py +6 -0
- inspect_ai/approval/_apply.py +6 -7
- inspect_ai/approval/_approver.py +3 -3
- inspect_ai/approval/_auto.py +2 -2
- inspect_ai/approval/_call.py +20 -4
- inspect_ai/approval/_human/approver.py +3 -3
- inspect_ai/approval/_human/manager.py +2 -2
- inspect_ai/approval/_human/panel.py +3 -3
- inspect_ai/approval/_policy.py +3 -3
- inspect_ai/log/__init__.py +2 -0
- inspect_ai/log/_log.py +23 -2
- inspect_ai/log/_model.py +58 -0
- inspect_ai/log/_recorders/file.py +14 -3
- inspect_ai/log/_transcript.py +3 -0
- inspect_ai/model/__init__.py +2 -0
- inspect_ai/model/_call_tools.py +4 -1
- inspect_ai/model/_model.py +49 -3
- inspect_ai/model/_openai.py +151 -21
- inspect_ai/model/_providers/anthropic.py +20 -12
- inspect_ai/model/_providers/bedrock.py +3 -3
- inspect_ai/model/_providers/cloudflare.py +29 -108
- inspect_ai/model/_providers/google.py +21 -10
- inspect_ai/model/_providers/grok.py +23 -17
- inspect_ai/model/_providers/groq.py +61 -37
- inspect_ai/model/_providers/llama_cpp_python.py +8 -9
- inspect_ai/model/_providers/mistral.py +8 -3
- inspect_ai/model/_providers/ollama.py +8 -9
- inspect_ai/model/_providers/openai.py +53 -157
- inspect_ai/model/_providers/openai_compatible.py +195 -0
- inspect_ai/model/_providers/openrouter.py +4 -15
- inspect_ai/model/_providers/providers.py +11 -0
- inspect_ai/model/_providers/together.py +25 -23
- inspect_ai/model/_trim.py +83 -0
- inspect_ai/solver/_plan.py +5 -3
- inspect_ai/tool/_tool_def.py +8 -2
- inspect_ai/util/__init__.py +3 -0
- inspect_ai/util/_concurrency.py +15 -2
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/RECORD +88 -83
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/WHEEL +1 -1
- inspect_ai/_eval/task/rundir.py +0 -78
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.88.dist-info → inspect_ai-0.3.90.dist-info}/top_level.txt +0 -0
inspect_ai/log/_model.py
ADDED
@@ -0,0 +1,58 @@
+from inspect import isgenerator
+from typing import Any, Iterator
+
+from inspect_ai.log._log import EvalModelConfig
+from inspect_ai.model._model import Model, get_model
+
+
+def model_roles_to_model_roles_config(
+    model_roles: dict[str, Model] | None,
+) -> dict[str, EvalModelConfig] | None:
+    if model_roles is not None:
+        return {k: model_to_model_config(v) for k, v in model_roles.items()}
+    else:
+        return None
+
+
+def model_roles_config_to_model_roles(
+    model_config: dict[str, EvalModelConfig] | None,
+) -> dict[str, Model] | None:
+    if model_config is not None:
+        return {k: model_config_to_model(v) for k, v in model_config.items()}
+    else:
+        return None
+
+
+def model_to_model_config(model: Model) -> EvalModelConfig:
+    return EvalModelConfig(
+        model=str(model),
+        config=model.config,
+        base_url=model.api.base_url,
+        args=model_args_for_log(model.model_args),
+    )
+
+
+def model_config_to_model(model_config: EvalModelConfig) -> Model:
+    return get_model(
+        model=model_config.model,
+        config=model_config.config,
+        base_url=model_config.base_url,
+        memoize=False,
+        **model_config.args,
+    )
+
+
+def model_args_for_log(model_args: dict[str, Any]) -> dict[str, Any]:
+    # redact authentication oriented model_args
+    model_args = model_args.copy()
+    if "api_key" in model_args:
+        del model_args["api_key"]
+    model_args = {k: v for k, v in model_args.items() if not k.startswith("aws_")}
+
+    # don't try to serialise generators
+    model_args = {
+        k: v
+        for k, v in model_args.items()
+        if not isgenerator(v) and not isinstance(v, Iterator)
+    }
+    return model_args
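The new module above converts between live Model objects and the EvalModelConfig records stored in eval logs, redacting credential-style args before serialisation. A minimal sketch of the round trip it enables (the model name below is illustrative, and these helpers live in a private module):

    from inspect_ai.log._model import model_config_to_model, model_to_model_config
    from inspect_ai.model import get_model

    model = get_model("openai/gpt-4o")        # example model
    config = model_to_model_config(model)     # EvalModelConfig with redacted/serialisable args
    restored = model_config_to_model(config)  # re-created via get_model(memoize=False)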
inspect_ai/log/_recorders/file.py
CHANGED
@@ -1,8 +1,10 @@
+import os
 from logging import getLogger
 from typing import Any

 from typing_extensions import override

+from inspect_ai._util.constants import MODEL_NONE
 from inspect_ai._util.file import filesystem
 from inspect_ai._util.registry import registry_unqualified_name

@@ -71,9 +73,18 @@ class FileRecorder(Recorder):
             return s.replace("_", "-").replace("/", "-").replace(":", "-")

         # remove package from task name
-        task = registry_unqualified_name(eval.task)
-
-
+        task = registry_unqualified_name(eval.task)  # noqa: F841
+
+        # derive log file pattern
+        log_file_pattern = os.getenv("INSPECT_EVAL_LOG_FILE_PATTERN", "{task}_{id}")
+
+        # compute and return log file name
+        log_file_name = f"{clean(eval.created)}_" + log_file_pattern
+        log_file_name = log_file_name.replace("{task}", clean(task))
+        log_file_name = log_file_name.replace("{id}", clean(eval.task_id))
+        model = clean(eval.model) if eval.model != MODEL_NONE else ""
+        log_file_name = log_file_name.replace("{model}", model)
+        return log_file_name

     def _log_file_path(self, eval: EvalSpec) -> str:
         return f"{self.log_dir}{self.fs.sep}{self._log_file_key(eval)}{self.suffix}"
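This change makes log file names templatable via the INSPECT_EVAL_LOG_FILE_PATTERN environment variable, with {task}, {id}, and {model} placeholders and a default of "{task}_{id}". A sketch of opting in to a pattern that also includes the model (the resulting name is illustrative):

    import os

    os.environ["INSPECT_EVAL_LOG_FILE_PATTERN"] = "{task}_{model}_{id}"
    # FileRecorder then builds names like "<created>_my-task_openai-gpt-4o_<task-id>"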
inspect_ai/log/_transcript.py
CHANGED
inspect_ai/model/__init__.py
CHANGED
@@ -47,6 +47,7 @@ from ._model_output import (
 )
 from ._providers.providers import *
 from ._registry import modelapi
+from ._trim import trim_messages

 __all__ = [
     "GenerateConfig",
@@ -80,6 +81,7 @@ __all__ = [
     "call_tools",
     "execute_tools",
     "ExecuteToolsResult",
+    "trim_messages",
     "cache_clear",
     "cache_list_expired",
     "cache_path",
inspect_ai/model/_call_tools.py
CHANGED
@@ -264,6 +264,7 @@ async def execute_tools(
             tuple[ExecuteToolsResult, ToolEvent, Exception | None]
         ]()

+        result_exception = None
         async with anyio.create_task_group() as tg:
             tg.start_soon(call_tool_task, call, messages, send_stream)
             event._set_cancel_fn(tg.cancel_scope.cancel)
@@ -348,7 +349,9 @@ async def call_tool(
     # if we have a tool approver, apply it now
     from inspect_ai.approval._apply import apply_tool_approval

-    approved, approval = await apply_tool_approval(
+    approved, approval = await apply_tool_approval(
+        message, call, tool_def.viewer, conversation
+    )
     if not approved:
         if approval and approval.decision == "terminate":
             from inspect_ai.solver._limit import SampleLimitExceededError
inspect_ai/model/_model.py
CHANGED
@@ -270,6 +270,7 @@ class Model:
         self.api = api
         self.config = config
         self.model_args = model_args
+        self._role: str | None = None

         # state indicating whether our lifetime is bound by a context manager
         self._context_bound = False
@@ -311,6 +312,14 @@
         """Model name."""
         return self.api.model_name

+    @property
+    def role(self) -> str | None:
+        """Model role."""
+        return self._role
+
+    def _set_role(self, role: str) -> None:
+        self._role = role
+
     def __str__(self) -> str:
         return f"{ModelName(self)}"

@@ -716,7 +725,7 @@
             )
             model_name = ModelName(self)
             async with concurrency(
-                name=
+                name=str(model_name),
                 concurrency=max_connections,
                 key=f"Model{self.api.connection_key()}",
             ):
@@ -738,6 +747,7 @@
         model = str(self)
         event = ModelEvent(
             model=model,
+            role=self.role,
             input=input,
             tools=tools,
             tool_choice=tool_choice,
@@ -828,6 +838,9 @@ class ModelName:

 def get_model(
     model: str | Model | None = None,
+    *,
+    role: str | None = None,
+    default: str | Model | None = None,
     config: GenerateConfig = GenerateConfig(),
     base_url: str | None = None,
     api_key: str | None = None,
@@ -858,6 +871,11 @@ def get_model(
            if `None` is passed then the model currently being
            evaluated is returned (or if there is no evaluation
            then the model referred to by `INSPECT_EVAL_MODEL`).
+        role: Optional named role for model (e.g. for roles specified
+            at the task or eval level). Provide a `default` as a fallback
+            in the case where the `role` hasn't been externally specified.
+        default: Optional. Fallback model in case the specified
+            `model` or `role` is not found.
         config: Configuration for model.
         base_url: Optional. Alternate base URL for model.
         api_key: Optional. API key for model.
@@ -878,6 +896,22 @@ def get_model(
     if model == "none":
         model = "none/none"

+    # resolve model role
+    if role is not None:
+        model_for_role = model_roles().get(role, None)
+        if model_for_role is not None:
+            return model_for_role
+
+    # if a default was specified then use it as the model if
+    # no model was passed
+    if model is None:
+        if isinstance(default, Model):
+            if role is not None:
+                default._set_role(role)
+            return default
+        else:
+            model = default
+
     # now try finding an 'ambient' model (active or env var)
     if model is None:
         # return active_model if there is one
@@ -901,6 +935,7 @@ def get_model(
     if memoize:
         model_cache_key = (
             model
+            + str(role)
             + config.model_dump_json(exclude_none=True)
             + str(base_url)
             + str(api_key)
@@ -941,10 +976,11 @@ def get_model(
            **model_args,
        )
        m = Model(modelapi_instance, config, model_args)
+        if role is not None:
+            m._set_role(role)
        if memoize:
            _models[model_cache_key] = m
        return m
-
    else:
        from_api = f" from {api_name}" if api_name else ""
        raise ValueError(f"Model name {model}{from_api} not recognized.")
@@ -1353,10 +1389,20 @@ def active_model() -> Model | None:
     return active_model_context_var.get(None)


-
+def init_model_roles(roles: dict[str, Model]) -> None:
+    _model_roles.set(roles)
+
+
+def model_roles() -> dict[str, Model]:
+    return _model_roles.get()
+
+
 active_model_context_var: ContextVar[Model | None] = ContextVar("active_model")

+_model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={})
+

+# shared contexts for asyncio tasks
 def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
     from inspect_ai.log._samples import (
         active_sample_message_limit,
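get_model() now accepts keyword-only role and default arguments: a role is resolved against the mapping installed via init_model_roles(), falling back to default when no role binding exists. A usage sketch (role and model names are illustrative):

    from inspect_ai.model import get_model

    # returns the model bound to the "grader" role if the eval defined one,
    # otherwise falls back to the default (which is tagged with the role for logging)
    grader = get_model(role="grader", default="openai/gpt-4o-mini")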
inspect_ai/model/_openai.py
CHANGED
@@ -1,9 +1,18 @@
 import json
 import re
+import socket
 from copy import copy
-from typing import Literal
-
-
+from typing import Any, Literal
+
+import httpx
+from openai import (
+    DEFAULT_CONNECTION_LIMITS,
+    DEFAULT_TIMEOUT,
+    APIStatusError,
+    APITimeoutError,
+    OpenAIError,
+    RateLimitError,
+)
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
@@ -38,9 +47,11 @@ from inspect_ai._util.content import (
     ContentReasoning,
     ContentText,
 )
+from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_http_url
 from inspect_ai.model._call_tools import parse_tool_call
+from inspect_ai.model._generate_config import GenerateConfig
 from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
 from inspect_ai.model._reasoning import parse_content_with_reasoning
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
@@ -146,24 +157,20 @@ async def openai_chat_completion_part(


 async def openai_chat_message(
-    message: ChatMessage,
+    message: ChatMessage, system_role: Literal["user", "system", "developer"] = "system"
 ) -> ChatCompletionMessageParam:
     if message.role == "system":
-
-
-
-
-
-
-
-
-
-
-
-        else:
-            return ChatCompletionSystemMessageParam(
-                role=message.role, content=message.text
-            )
+        match system_role:
+            case "user":
+                return ChatCompletionUserMessageParam(role="user", content=message.text)
+            case "system":
+                return ChatCompletionSystemMessageParam(
+                    role=message.role, content=message.text
+                )
+            case "developer":
+                return ChatCompletionDeveloperMessageParam(
+                    role="developer", content=message.text
+                )
     elif message.role == "user":
         return ChatCompletionUserMessageParam(
             role=message.role,
@@ -202,9 +209,54 @@ async def openai_chat_message(


 async def openai_chat_messages(
-    messages: list[ChatMessage],
+    messages: list[ChatMessage],
+    system_role: Literal["user", "system", "developer"] = "system",
 ) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message,
+    return [await openai_chat_message(message, system_role) for message in messages]
+
+
+def openai_completion_params(
+    model: str, config: GenerateConfig, tools: bool
+) -> dict[str, Any]:
+    params: dict[str, Any] = dict(model=model)
+    if config.max_tokens is not None:
+        params["max_tokens"] = config.max_tokens
+    if config.frequency_penalty is not None:
+        params["frequency_penalty"] = config.frequency_penalty
+    if config.stop_seqs is not None:
+        params["stop"] = config.stop_seqs
+    if config.presence_penalty is not None:
+        params["presence_penalty"] = config.presence_penalty
+    if config.logit_bias is not None:
+        params["logit_bias"] = config.logit_bias
+    if config.seed is not None:
+        params["seed"] = config.seed
+    if config.temperature is not None:
+        params["temperature"] = config.temperature
+    if config.top_p is not None:
+        params["top_p"] = config.top_p
+    if config.num_choices is not None:
+        params["n"] = config.num_choices
+    if config.logprobs is not None:
+        params["logprobs"] = config.logprobs
+    if config.top_logprobs is not None:
+        params["top_logprobs"] = config.top_logprobs
+    if tools and config.parallel_tool_calls is not None:
+        params["parallel_tool_calls"] = config.parallel_tool_calls
+    if config.reasoning_effort is not None:
+        params["reasoning_effort"] = config.reasoning_effort
+    if config.response_schema is not None:
+        params["response_format"] = dict(
+            type="json_schema",
+            json_schema=dict(
+                name=config.response_schema.name,
+                schema=config.response_schema.json_schema.model_dump(exclude_none=True),
+                description=config.response_schema.description,
+                strict=config.response_schema.strict,
+            ),
+        )
+
+    return params


 def openai_assistant_content(message: ChatMessageAssistant) -> str:
@@ -496,6 +548,35 @@ def chat_message_assistant_from_openai(
     )


+def model_output_from_openai(
+    completion: ChatCompletion,
+    choices: list[ChatCompletionChoice],
+) -> ModelOutput:
+    return ModelOutput(
+        model=completion.model,
+        choices=choices,
+        usage=(
+            ModelUsage(
+                input_tokens=completion.usage.prompt_tokens,
+                output_tokens=completion.usage.completion_tokens,
+                input_tokens_cache_read=(
+                    completion.usage.prompt_tokens_details.cached_tokens
+                    if completion.usage.prompt_tokens_details is not None
+                    else None  # openai only have cache read stats/pricing.
+                ),
+                reasoning_tokens=(
+                    completion.usage.completion_tokens_details.reasoning_tokens
+                    if completion.usage.completion_tokens_details is not None
+                    else None
+                ),
+                total_tokens=completion.usage.total_tokens,
+            )
+            if completion.usage
+            else None
+        ),
+    )
+
+
 def chat_choices_from_openai(
     response: ChatCompletion, tools: list[ToolInfo]
 ) -> list[ChatCompletionChoice]:
@@ -517,6 +598,19 @@ def chat_choices_from_openai(
     ]


+def openai_should_retry(ex: Exception) -> bool:
+    if isinstance(ex, RateLimitError):
+        return True
+    elif isinstance(ex, APIStatusError):
+        return is_retryable_http_status(ex.status_code)
+    elif isinstance(ex, OpenAIResponseError):
+        return ex.code in ["rate_limit_exceeded", "server_error"]
+    elif isinstance(ex, APITimeoutError):
+        return True
+    else:
+        return False
+
+
 def openai_handle_bad_request(
     model_name: str, e: APIStatusError
 ) -> ModelOutput | Exception:
@@ -559,3 +653,39 @@ def openai_media_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
         value = copy(value)
         value.update(data=BASE_64_DATA_REMOVED)
     return value
+
+
+class OpenAIAsyncHttpxClient(httpx.AsyncClient):
+    """Custom async client that deals better with long running Async requests.
+
+    Based on Anthropic DefaultAsyncHttpClient implementation that they
+    released along with Claude 3.7 as well as the OpenAI DefaultAsyncHttpxClient
+
+    """
+
+    def __init__(self, **kwargs: Any) -> None:
+        # This is based on the openai DefaultAsyncHttpxClient:
+        # https://github.com/openai/openai-python/commit/347363ed67a6a1611346427bb9ebe4becce53f7e
+        kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
+        kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
+        kwargs.setdefault("follow_redirects", True)
+
+        # This is based on the anthrpopic changes for claude 3.7:
+        # https://github.com/anthropics/anthropic-sdk-python/commit/c5387e69e799f14e44006ea4e54fdf32f2f74393#diff-3acba71f89118b06b03f2ba9f782c49ceed5bb9f68d62727d929f1841b61d12bR1387-R1403

+        # set socket options to deal with long running reasoning requests
+        socket_options = [
+            (socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60),
+            (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5),
+        ]
+        TCP_KEEPIDLE = getattr(socket, "TCP_KEEPIDLE", None)
+        if TCP_KEEPIDLE is not None:
+            socket_options.append((socket.IPPROTO_TCP, TCP_KEEPIDLE, 60))
+
+        kwargs["transport"] = httpx.AsyncHTTPTransport(
+            limits=DEFAULT_CONNECTION_LIMITS,
+            socket_options=socket_options,
+        )
+
+        super().__init__(**kwargs)
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -82,7 +82,6 @@ class AnthropicAPI(ModelAPI):
         parts = model_name.split("/")
         if len(parts) > 1:
             self.service: str | None = parts[0]
-            model_name = "/".join(parts[1:])
         else:
             self.service = None

@@ -237,7 +236,7 @@ class AnthropicAPI(ModelAPI):

             # extract output
             output = await model_output_from_message(
-                self.client, self.
+                self.client, self.service_model_name(), message, tools
             )

             # return output and call
@@ -249,7 +248,7 @@ class AnthropicAPI(ModelAPI):
         except APIStatusError as ex:
             if ex.status_code == 413:
                 return ModelOutput.from_content(
-                    model=self.
+                    model=self.service_model_name(),
                     content=ex.message,
                     stop_reason="model_length",
                     error=ex.message,
@@ -261,7 +260,7 @@ class AnthropicAPI(ModelAPI):
         self, config: GenerateConfig
     ) -> tuple[dict[str, Any], dict[str, str], list[str]]:
         max_tokens = cast(int, config.max_tokens)
-        params = dict(model=self.
+        params = dict(model=self.service_model_name(), max_tokens=max_tokens)
         headers: dict[str, str] = {}
         betas: list[str] = []
         # some params not compatible with thinking models
@@ -311,18 +310,22 @@ class AnthropicAPI(ModelAPI):
         return not self.is_claude_3() and not self.is_claude_3_5()

     def is_claude_3(self) -> bool:
-        return re.search(r"claude-3-[a-zA-Z]", self.
+        return re.search(r"claude-3-[a-zA-Z]", self.service_model_name()) is not None

     def is_claude_3_5(self) -> bool:
-        return "claude-3-5-" in self.
+        return "claude-3-5-" in self.service_model_name()

     def is_claude_3_7(self) -> bool:
-        return "claude-3-7-" in self.
+        return "claude-3-7-" in self.service_model_name()

     @override
     def connection_key(self) -> str:
         return str(self.api_key)

+    def service_model_name(self) -> str:
+        """Model name without any service prefix."""
+        return self.model_name.replace(f"{self.service}/", "", 1)
+
     @override
     def should_retry(self, ex: Exception) -> bool:
         if isinstance(ex, APIStatusError):
@@ -371,7 +374,11 @@ class AnthropicAPI(ModelAPI):
         # NOTE: Using case insensitive matching because the Anthropic Bedrock API seems to capitalize the work 'input' in its error message, other times it doesn't.
         if any(
             message in error.lower()
-            for message in [
+            for message in [
+                "prompt is too long",
+                "input is too long",
+                "input length and `max_tokens` exceed context limit",
+            ]
         ):
             if (
                 isinstance(ex.body, dict)
@@ -392,7 +399,7 @@ class AnthropicAPI(ModelAPI):

         if content and stop_reason:
             return ModelOutput.from_content(
-                model=self.
+                model=self.service_model_name(),
                 content=content,
                 stop_reason=stop_reason,
                 error=error,
@@ -440,10 +447,11 @@ class AnthropicAPI(ModelAPI):

         # only certain claude models qualify
         if cache_prompt:
+            model_name = self.service_model_name()
             if (
-                "claude-3-sonnet" in
-                or "claude-2" in
-                or "claude-instant" in
+                "claude-3-sonnet" in model_name
+                or "claude-2" in model_name
+                or "claude-instant" in model_name
             ):
                 cache_prompt = False

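AnthropicAPI.__init__ now keeps the full model name (service prefix included) and the new service_model_name() strips the prefix wherever the bare model id is needed. A minimal sketch of the behaviour it implements (the model string is an example, assuming a bedrock/ service prefix):

    model_name = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"  # example value
    parts = model_name.split("/")
    service = parts[0] if len(parts) > 1 else None
    # what service_model_name() computes:
    print(model_name.replace(f"{service}/", "", 1))
    # -> "anthropic.claude-3-5-sonnet-20240620-v1:0"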
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -368,7 +368,7 @@ class BedrockAPI(ModelAPI):
             toolConfig=tool_config,
         )

-        def model_call(response: dict[str, Any]
+        def model_call(response: dict[str, Any] = {}) -> ModelCall:
            return ModelCall.create(
                request=replace_bytes_with_placeholder(
                    request.model_dump(exclude_none=True)
@@ -388,14 +388,14 @@ class BedrockAPI(ModelAPI):
            # Look for an explicit validation exception
            if ex.response["Error"]["Code"] == "ValidationException":
                response = ex.response["Error"]["Message"]
-                if "
+                if "too many input tokens" in response.lower():
                    return ModelOutput.from_content(
                        model=self.model_name,
                        content=response,
                        stop_reason="model_length",
                    )
                else:
-                    return ex, model_call(
+                    return ex, model_call()
            else:
                raise ex
