inspect-ai 0.3.91__py3-none-any.whl → 0.3.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +31 -0
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +13 -20
- inspect_ai/_util/local_server.py +368 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +159 -146
- inspect_ai/_view/www/dist/assets/index.js +1020 -1061
- inspect_ai/_view/www/log-schema.json +4 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +3 -2
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +20 -12
- inspect_ai/agent/_as_tool.py +15 -3
- inspect_ai/agent/_handoff.py +8 -1
- inspect_ai/agent/_run.py +11 -3
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +0 -8
- inspect_ai/log/_transcript.py +7 -1
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +32 -12
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +21 -48
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_openai_responses.py +13 -1
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +241 -0
- inspect_ai/model/_providers/vllm.py +207 -400
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +2 -0
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +12 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/RECORD +90 -109
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/WHEEL +1 -1
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/top_level.txt +0 -0
inspect_ai/log/_util.py
ADDED
@@ -0,0 +1,52 @@
+import textwrap
+from datetime import date, datetime, time
+from typing import Any
+
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentReasoning,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai.model._chat_message import ChatMessage
+
+
+def text_input_only(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
+    # Clean the input of any images
+    if isinstance(inputs, list):
+        input: list[ChatMessage] = []
+        for message in inputs:
+            if not isinstance(message.content, str):
+                filtered_content: list[
+                    ContentText
+                    | ContentReasoning
+                    | ContentImage
+                    | ContentAudio
+                    | ContentVideo
+                ] = []
+                for content in message.content:
+                    if content.type == "text":
+                        filtered_content.append(content)
+                    else:
+                        filtered_content.append(
+                            ContentText(text=f"({content.type.capitalize()})")
+                        )
+                message.content = filtered_content
+                input.append(message)
+            else:
+                input.append(message)
+
+        return input
+    else:
+        return inputs
+
+
+def thin_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
+    thinned: dict[str, Any] = {}
+    for key, value in metadata.items():
+        if isinstance(value, int | float | bool | date | time | datetime):
+            thinned[key] = value
+        elif isinstance(value, str):
+            thinned[key] = textwrap.shorten(value, width=1024, placeholder="...")
+    return thinned
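Note: a minimal usage sketch of the helpers above (inspect-ai 0.3.93 installed; the sample metadata values are illustrative). thin_metadata() keeps primitive values, shortens long strings to at most 1024 characters with a "..." placeholder, and drops everything else; text_input_only() replaces non-text content with a placeholder such as "(Image)".

from inspect_ai.log._util import thin_metadata

metadata = {
    "attempts": 3,              # primitive: kept as-is
    "notes": "word " * 2000,    # long string: shortened with a "..." placeholder
    "raw_bytes": b"\x00\x01",   # neither primitive nor str: dropped
}
print(thin_metadata(metadata))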
inspect_ai/model/__init__.py
CHANGED
@@ -28,7 +28,11 @@ from ._chat_message import (
     ChatMessageUser,
 )
 from ._conversation import ModelConversation
-from ._generate_config import
+from ._generate_config import (
+    GenerateConfig,
+    GenerateConfigArgs,
+    ResponseSchema,
+)
 from ._model import (
     Model,
     ModelAPI,
inspect_ai/model/_call_tools.py
CHANGED
@@ -60,6 +60,7 @@ from inspect_ai.tool._tool_info import parse_docstring
 from inspect_ai.tool._tool_params import ToolParams
 from inspect_ai.util import OutputLimitExceededError
 from inspect_ai.util._anyio import inner_exception
+from inspect_ai.util._limit import LimitExceededError, apply_limits
 
 from ._chat_message import (
     ChatMessage,
@@ -171,10 +172,15 @@ async def execute_tools(
             tool_error = ToolCallError("is_a_directory", err)
         except OutputLimitExceededError as ex:
             tool_error = ToolCallError(
-                "
-                f"The tool output limit of {ex.limit_str}
+                "limit",
+                f"The tool exceeded its output limit of {ex.limit_str}.",
             )
             result = ex.truncated_output or ""
+        except LimitExceededError as ex:
+            tool_error = ToolCallError(
+                "limit",
+                f"The tool exceeded its {ex.type} limit of {ex.limit}.",
+            )
         except ToolParsingError as ex:
             tool_error = ToolCallError("parsing", ex.message)
         except ToolApprovalError as ex:
@@ -344,6 +350,7 @@ async def call_tool(
     tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
 ) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
     from inspect_ai.agent._handoff import AgentTool
+    from inspect_ai.log._transcript import SampleLimitEvent, transcript
 
     # if there was an error parsing the ToolCall, raise that
     if call.parse_error:
@@ -362,14 +369,11 @@
         )
         if not approved:
             if approval and approval.decision == "terminate":
-
-
-
-                    "operator",
-                    value=1,
-                    limit=1,
-                    message="Tool call approver requested termination.",
+                message = "Tool call approver requested termination."
+                transcript()._event(
+                    SampleLimitEvent(type="operator", limit=1, message=message)
                 )
+                raise LimitExceededError("operator", value=1, limit=1, message=message)
             else:
                 raise ToolApprovalError(approval.explanation if approval else None)
         if approval and approval.modified:
@@ -454,9 +458,14 @@ async def agent_handoff(
     arguments = tool_params(arguments, agent_tool.agent)
     del arguments["state"]
 
-    #
+    # run the agent with limits
+    limit_error: LimitExceededError | None = None
     agent_state = AgentState(messages=copy(agent_conversation))
-
+    try:
+        with apply_limits(agent_tool.limits):
+            agent_state = await agent_tool.agent(agent_state, **arguments)
+    except LimitExceededError as ex:
+        limit_error = ex
 
     # determine which messages are new and return only those (but exclude new
     # system messages as they an internal matter for the handed off to agent.
@@ -474,9 +483,20 @@
     if agent_tool.output_filter is not None:
         agent_messages = await agent_tool.output_filter(agent_messages)
 
+    if limit_error is not None:
+        agent_messages.append(
+            ChatMessageUser(
+                content=(
+                    f"The {agent_name} exceeded its {limit_error.type} limit of "
+                    f"{limit_error.limit}."
+                )
+            )
+        )
     # if we end with an assistant message then add a user message
     # so that the calling agent carries on
-
+    elif len(agent_messages) == 0 or isinstance(
+        agent_messages[-1], ChatMessageAssistant
+    ):
         agent_messages.append(
             ChatMessageUser(content=f"The {agent_name} agent has completed its work.")
         )
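Note: combined with the new inspect_ai/util/_limit.py, the apply_limits() call above means an agent handoff can carry its own limits; when one trips, the parent conversation receives a "...exceeded its ... limit..." user message rather than the sample erroring. A hedged sketch of the caller side (the limits= parameter name is inferred from agent_tool.limits in this diff, token_limit from the new util exports, and researcher is a placeholder agent):

from inspect_ai.agent import handoff
from inspect_ai.util import token_limit

# cap the handed-off agent at ~50k tokens; on overflow the calling agent sees a
# user message describing the exceeded limit and carries on
researcher_tool = handoff(researcher, limits=[token_limit(50_000)])  # researcher: placeholder agent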
inspect_ai/model/_generate_config.py
CHANGED
@@ -106,6 +106,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     response_schema: ResponseSchema | None
     """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and Mistral only."""
 
+    extra_body: dict[str, Any] | None
+    """Extra body to be sent with requests to OpenAI compatible servers. OpenAI, vLLM, and SGLang only."""
+
 
 class GenerateConfig(BaseModel):
     """Model generation options."""
@@ -138,28 +141,28 @@ class GenerateConfig(BaseModel):
     """Generates best_of completions server-side and returns the 'best' (the one with the highest log probability per token). vLLM only."""
 
     frequency_penalty: float | None = Field(default=None)
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Google, Grok, Groq, and
+    """Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Google, Grok, Groq, vLLM, and SGLang only."""
 
     presence_penalty: float | None = Field(default=None)
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Google, Grok, Groq, and
+    """Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Google, Grok, Groq, vLLM, and SGLang only."""
 
     logit_bias: dict[int, float] | None = Field(default=None)
-    """Map token Ids to an associated bias value from -100 to 100 (e.g. "42=10,43=-10"). OpenAI, Grok, and
+    """Map token Ids to an associated bias value from -100 to 100 (e.g. "42=10,43=-10"). OpenAI, Grok, Grok, and vLLM only."""
 
     seed: int | None = Field(default=None)
     """Random seed. OpenAI, Google, Mistral, Groq, HuggingFace, and vLLM only."""
 
     top_k: int | None = Field(default=None)
-    """Randomly sample the next word from the top_k most likely next words. Anthropic, Google, HuggingFace, and
+    """Randomly sample the next word from the top_k most likely next words. Anthropic, Google, HuggingFace, vLLM, and SGLang only."""
 
     num_choices: int | None = Field(default=None)
-    """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and
+    """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, vLLM, and SGLang only."""
 
     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, vLLM, and SGLang only."""
 
     top_logprobs: int | None = Field(default=None)
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, and
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, vLLM, and SGLang only."""
 
     parallel_tool_calls: bool | None = Field(default=None)
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
@@ -190,7 +193,10 @@ class GenerateConfig(BaseModel):
     """Include reasoning in chat message history sent to generate."""
 
     response_schema: ResponseSchema | None = Field(default=None)
-    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, and
+    """Request a response format as JSONSchema (output should still be validated). OpenAI, Google, Mistral, vLLM, and SGLang only."""
+
+    extra_body: dict[str, Any] | None = Field(default=None)
+    """Extra body to be sent with requests to OpenAI compatible servers. OpenAI, vLLM, and SGLang only."""
 
     # migrate reasoning_history as a bool
     @model_validator(mode="before")
inspect_ai/model/_model.py
CHANGED
@@ -57,6 +57,11 @@ from inspect_ai.tool._tool import ToolSource
 from inspect_ai.tool._tool_call import ToolCallModelInputHints
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import (
+    check_message_limit,
+    check_token_limit,
+    record_model_usage,
+)
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import (
@@ -355,11 +360,15 @@ class Model:
         Returns:
            ModelOutput
         """
-        # if we are the default model then
-        # exists (raise an exception if it is exceeded)
+        # if we are the default model then update the displayed message count
         is_active_model = self == active_model()
         if is_active_model:
-
+            set_total_messages(input)
+
+        # check message limit, raise exception if we're already at the limit to prevent
+        # a wasteful generate()
+        conversation_length = len(input) if isinstance(input, list) else 1
+        check_message_limit(conversation_length, raise_for_equal=True)
 
         # base config for this model
         base_config = self.config
@@ -666,7 +675,7 @@
         # record usage
         if output.usage:
             # record usage
-
+            record_and_check_model_usage(f"{self}", output.usage)
 
             # send telemetry if its hooked up
             await send_telemetry(
@@ -1423,20 +1432,10 @@ _model_roles: ContextVar[dict[str, Model]] = ContextVar("model_roles", default={
 
 
 # shared contexts for asyncio tasks
-def
-    from inspect_ai.log._samples import
-        active_sample_message_limit,
-        set_active_sample_total_messages,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def set_total_messages(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_messages
 
     total_messages = 1 if isinstance(input, str) else len(input)
-    message_limit = active_sample_message_limit()
-    if message_limit is not None:
-        if total_messages >= message_limit:
-            raise SampleLimitExceededError(
-                "message", value=total_messages, limit=message_limit
-            )
 
     # set total messages
     set_active_sample_total_messages(total_messages)
@@ -1450,16 +1449,13 @@ def init_sample_model_usage() -> None:
     sample_model_usage_context_var.set({})
 
 
-def
-    from inspect_ai.log._samples import
-        active_sample_token_limit,
-        set_active_sample_total_tokens,
-    )
-    from inspect_ai.solver._limit import SampleLimitExceededError
+def record_and_check_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import set_active_sample_total_tokens
 
     # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
+    record_model_usage(usage)
 
     # compute total tokens
     total_tokens = sample_total_tokens()
@@ -1467,38 +1463,15 @@ def record_model_usage(model: str, usage: ModelUsage) -> None:
     # update active sample
     set_active_sample_total_tokens(total_tokens)
 
-
-    token_limit = active_sample_token_limit()
-    if token_limit is not None:
-        if total_tokens > token_limit:
-            raise SampleLimitExceededError(
-                "token", value=total_tokens, limit=token_limit
-            )
+    check_token_limit()
 
 
 def set_model_usage(
     model: str, usage: ModelUsage, model_usage: dict[str, ModelUsage] | None
 ) -> None:
     if model_usage is not None:
-        total_usage
-
-            total_usage = ModelUsage()
-        total_usage.input_tokens += usage.input_tokens
-        total_usage.output_tokens += usage.output_tokens
-        total_usage.total_tokens += usage.total_tokens
-        if usage.input_tokens_cache_write is not None:
-            if total_usage.input_tokens_cache_write is None:
-                total_usage.input_tokens_cache_write = 0
-            total_usage.input_tokens_cache_write += usage.input_tokens_cache_write
-        if usage.input_tokens_cache_read is not None:
-            if total_usage.input_tokens_cache_read is None:
-                total_usage.input_tokens_cache_read = 0
-            total_usage.input_tokens_cache_read += usage.input_tokens_cache_read
-        if usage.reasoning_tokens is not None:
-            if total_usage.reasoning_tokens is None:
-                total_usage.reasoning_tokens = 0
-            total_usage.reasoning_tokens += usage.reasoning_tokens
-
+        total_usage = model_usage.get(model, ModelUsage())
+        total_usage += usage
         model_usage[model] = total_usage
 
 
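Note: the raise_for_equal=True pre-check exists to fail before a wasteful generate(): if the conversation already holds as many messages as the limit allows, the assistant reply could not be appended anyway. A toy illustration of that semantics (not the library's implementation; the real check lives in the new inspect_ai/util/_limit.py):

def precheck_message_limit(conversation_len: int, limit: int | None) -> None:
    # raise when the conversation has already reached the limit, since the
    # upcoming generate() would only push it past the limit
    if limit is not None and conversation_len >= limit:
        raise RuntimeError(f"message limit of {limit} reached ({conversation_len} messages)")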
inspect_ai/model/_model_output.py
CHANGED
@@ -30,6 +30,31 @@ class ModelUsage(BaseModel):
     reasoning_tokens: int | None = Field(default=None)
     """Number of tokens used for reasoning."""
 
+    def __add__(self, other: "ModelUsage") -> "ModelUsage":
+        def optional_sum(a: int | None, b: int | None) -> int | None:
+            if a is not None and b is not None:
+                return a + b
+            if a is not None:
+                return a
+            if b is not None:
+                return b
+            return None
+
+        return ModelUsage(
+            input_tokens=self.input_tokens + other.input_tokens,
+            output_tokens=self.output_tokens + other.output_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+            input_tokens_cache_write=optional_sum(
+                self.input_tokens_cache_write, other.input_tokens_cache_write
+            ),
+            input_tokens_cache_read=optional_sum(
+                self.input_tokens_cache_read, other.input_tokens_cache_read
+            ),
+            reasoning_tokens=optional_sum(
+                self.reasoning_tokens, other.reasoning_tokens
+            ),
+        )
+
 
 StopReason = Literal[
     "stop",
inspect_ai/model/_openai.py
CHANGED
@@ -1,3 +1,4 @@
+import json
 from itertools import chain
 from typing import TypedDict, cast
 
@@ -306,6 +307,14 @@ def _openai_input_items_from_chat_message_assistant(
     """
     (output_message_id, tool_message_ids) = _ids_from_assistant_internal(message)
 
+    # we want to prevent yielding output messages in the case where we have an
+    # 'internal' field (so the message came from the model API as opposed to
+    # being user synthesized) AND there is no output_message_id (indicating that
+    # when reading the message from the server we didn't find output). this could
+    # happen e.g. when a react() agent sets the output.completion in response
+    # to a submit() tool call
+    suppress_output_message = message.internal is not None and output_message_id is None
+
     # if we are not storing messages on the server then blank these out
     if not store:
         output_message_id = None
@@ -341,6 +350,9 @@
                 )
             )
         case ContentText(text=text, refusal=refusal):
+            if suppress_output_message:
+                continue
+
             new_content = (
                 ResponseOutputRefusalParam(type="refusal", refusal=text)
                 if refusal
@@ -415,7 +427,7 @@ def _tool_call_items_from_assistant_message(
         type="function_call",
         call_id=call.id,
         name=_responses_tool_alias(call.function),
-        arguments=call.
+        arguments=json.dumps(call.arguments),
     )
 
     # add id if available
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -26,7 +26,6 @@ from anthropic.types import (
     TextBlockParam,
     ThinkingBlock,
     ThinkingBlockParam,
-    ToolBash20250124Param,
     ToolParam,
     ToolResultBlockParam,
     ToolTextEditor20250124Param,
@@ -76,6 +75,7 @@ class AnthropicAPI(ModelAPI):
         base_url: str | None = None,
         api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
+        streaming: bool | Literal["auto"] = "auto",
         **model_args: Any,
     ):
         # extract any service prefix from model name
@@ -85,6 +85,9 @@ class AnthropicAPI(ModelAPI):
         else:
             self.service = None
 
+        # record steraming pref
+        self.streaming = streaming
+
         # collect generate model_args (then delete them so we can pass the rest on)
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
@@ -224,8 +227,13 @@ class AnthropicAPI(ModelAPI):
         if self.extra_body is not None:
             request["extra_body"] = self.extra_body
 
-        # make request (stream if we are using reasoning)
-
+        # make request (unless overrideen, stream if we are using reasoning)
+        streaming = (
+            self.is_using_thinking(config)
+            if self.streaming == "auto"
+            else self.streaming
+        )
+        if streaming:
             async with self.client.messages.stream(**request) as stream:
                 message = await stream.get_final_message()
         else:
@@ -489,11 +497,7 @@ class AnthropicAPI(ModelAPI):
         self, tool: ToolInfo, config: GenerateConfig
     ) -> Optional["ToolParamDef"]:
         return (
-            (
-                self.computer_use_tool_param(tool)
-                or self.text_editor_tool_param(tool)
-                or self.bash_tool_param(tool)
-            )
+            (self.computer_use_tool_param(tool) or self.text_editor_tool_param(tool))
             if config.internal_tools is not False
             else None
         )
@@ -564,23 +568,10 @@ class AnthropicAPI(ModelAPI):
         else:
             return None
 
-    def bash_tool_param(self, tool: ToolInfo) -> Optional[ToolBash20250124Param]:
-        # check for compatible 'bash' tool
-        if tool.name == "bash_session" and (
-            sorted(tool.parameters.properties.keys()) == sorted(["command", "restart"])
-        ):
-            return ToolBash20250124Param(type="bash_20250124", name="bash")
-        # not a bash tool
-        else:
-            return None
-
 
 # tools can be either a stock tool param or a special Anthropic native use tool param
 ToolParamDef = (
-    ToolParam
-    | BetaToolComputerUse20250124Param
-    | ToolTextEditor20250124Param
-    | ToolBash20250124Param
+    ToolParam | BetaToolComputerUse20250124Param | ToolTextEditor20250124Param
 )
 
 
@@ -589,7 +580,6 @@ def add_cache_control(
     | ToolParam
    | BetaToolComputerUse20250124Param
     | ToolTextEditor20250124Param
-    | ToolBash20250124Param
     | dict[str, Any],
 ) -> None:
     cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
inspect_ai/model/_providers/openai_o1.py
CHANGED
@@ -211,8 +211,15 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         This method has an interdependency with `input_with_tools()` (as that is the
         prompt that asks the model to use the <tool_call>...</tool_call> syntax)
         """
-        #
+        # define regex patterns
+        # NOTE: If you change either of these regex patterns, please update the other
+        # tool_call_regex extracts the JSON content (in curly braces) between tool call tags
         tool_call_regex = rf"<{TOOL_CALL}>\s*(\{{[\s\S]*?\}})\s*</{TOOL_CALL}>"
+        # tool_call_content_regex matches the entire tool call block including tags for extracting
+        # the content outside of the tool call tags
+        tool_call_content_regex = rf"<{TOOL_CALL}>\s*\{{[\s\S]*?\}}\s*</{TOOL_CALL}>"
+
+        # extract tool calls
         tool_calls_content: list[str] = re.findall(tool_call_regex, response)
 
         # if there are tool calls proceed with parsing
@@ -226,7 +233,6 @@ class O1PreviewChatAPIHandler(ChatAPIHandler):
         ]
 
         # find other content that exists outside tool calls
-        tool_call_content_regex = rf"<{TOOL_CALL}>(?:.|\n)*?</{TOOL_CALL}>"
         other_content = re.split(tool_call_content_regex, response, flags=re.DOTALL)
         other_content = [
             str(content).strip()
inspect_ai/model/_providers/providers.py
CHANGED
@@ -136,10 +136,12 @@ def hf() -> type[ModelAPI]:
 
 @modelapi(name="vllm")
 def vllm() -> type[ModelAPI]:
-
-
-
-
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("vLLM API")
+
+    # Import VLLMAPI without checking for vllm package yet
+    # The actual vllm dependency will only be checked if needed to start a server
+    from .vllm import VLLMAPI
 
     return VLLMAPI
 
@@ -257,6 +259,18 @@ def mockllm() -> type[ModelAPI]:
     return MockLLM
 
 
+@modelapi(name="sglang")
+def sglang() -> type[ModelAPI]:
+    # Only validate OpenAI compatibility (needed for the API interface)
+    validate_openai_client("SGLang API")
+
+    # Import SGLangAPI without checking for sglang package yet
+    # The actual sglang dependency will only be checked if needed to start a server
+    from .sglang import SGLangAPI
+
+    return SGLangAPI
+
+
 @modelapi(name="none")
 def none() -> type[ModelAPI]:
     from .none import NoModel