inspect-ai 0.3.92__py3-none-any.whl → 0.3.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +27 -0
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +23 -27
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/local_server.py +398 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +173 -159
- inspect_ai/_view/www/dist/assets/index.js +1417 -1142
- inspect_ai/_view/www/log-schema.json +379 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +93 -14
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +22 -12
- inspect_ai/agent/_as_tool.py +20 -6
- inspect_ai/agent/_handoff.py +12 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +16 -3
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +14 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +14 -25
- inspect_ai/log/_transcript.py +84 -36
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +72 -44
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +66 -88
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +247 -0
- inspect_ai/model/_providers/vllm.py +211 -400
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +5 -22
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +8 -5
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +16 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/RECORD +120 -134
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.92.dist-info → inspect_ai-0.3.94.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,8 @@ from inspect_ai.scorer._score import score
|
|
13
13
|
from inspect_ai.solver._chain import chain
|
14
14
|
from inspect_ai.tool._tool import Tool, ToolResult, tool
|
15
15
|
from inspect_ai.tool._tool_with import tool_with
|
16
|
+
from inspect_ai.util._limit import token_limit as create_token_limit
|
16
17
|
|
17
|
-
from ._limit import SampleLimitExceededError
|
18
18
|
from ._prompt import system_message
|
19
19
|
from ._solver import Generate, Solver, solver
|
20
20
|
from ._task_state import TaskState
|
@@ -172,14 +172,11 @@ def basic_agent(
|
|
172
172
|
# (if there is no message_limit then default to 50)
|
173
173
|
state.message_limit = message_limit or state.message_limit or 50
|
174
174
|
|
175
|
-
# resolve token limit
|
176
|
-
state.token_limit = token_limit or state.token_limit
|
177
|
-
|
178
175
|
# track attempts
|
179
176
|
attempts = 0
|
180
177
|
|
181
|
-
|
182
|
-
# main loop
|
178
|
+
with create_token_limit(token_limit):
|
179
|
+
# main loop
|
183
180
|
while not state.completed:
|
184
181
|
# generate output and append assistant message
|
185
182
|
state.output = await get_model().generate(
|
@@ -247,10 +244,6 @@ def basic_agent(
|
|
247
244
|
else:
|
248
245
|
state.messages.append(ChatMessageUser(content=continue_message))
|
249
246
|
|
250
|
-
# propagate current state along with sample limit exceeded
|
251
|
-
except SampleLimitExceededError as ex:
|
252
|
-
raise ex.with_state(state)
|
253
|
-
|
254
247
|
return state
|
255
248
|
|
256
249
|
return solve
|
inspect_ai/solver/_chain.py
CHANGED
@@ -82,7 +82,7 @@ class Chain(Sequence[Solver], Solver):
|
|
82
82
|
from ._transcript import solver_transcript
|
83
83
|
|
84
84
|
for slv in self._solvers:
|
85
|
-
with solver_transcript(slv, state) as st:
|
85
|
+
async with solver_transcript(slv, state) as st:
|
86
86
|
state = await slv(state, generate)
|
87
87
|
st.complete(state)
|
88
88
|
if state.completed:
|
inspect_ai/solver/_fork.py
CHANGED
@@ -73,7 +73,7 @@ async def solver_subtask(state: TaskState, solver: Solver) -> TaskState:
|
|
73
73
|
@subtask(name=name, store=state.store, type="fork", input=input) # type: ignore
|
74
74
|
async def solve() -> TaskState:
|
75
75
|
if not isinstance(solver, Chain):
|
76
|
-
with solver_transcript(solver, state) as st:
|
76
|
+
async with solver_transcript(solver, state) as st:
|
77
77
|
new_state = await solver(state, generate)
|
78
78
|
st.complete(new_state)
|
79
79
|
return new_state
|
@@ -6,6 +6,7 @@ from typing import Match, TypedDict
|
|
6
6
|
|
7
7
|
from typing_extensions import Unpack
|
8
8
|
|
9
|
+
from inspect_ai._util.answer import answer_character, answer_index
|
9
10
|
from inspect_ai._util.logger import warn_once
|
10
11
|
from inspect_ai.util import resource
|
11
12
|
|
@@ -64,31 +65,13 @@ def answer_options(choices: Choices) -> str:
|
|
64
65
|
indexes = list(range(len(choices)))
|
65
66
|
|
66
67
|
return "\n".join(
|
67
|
-
[f"{
|
68
|
+
[f"{answer_character(i)}) {choices[j].value}" for i, j in enumerate(indexes)]
|
68
69
|
)
|
69
70
|
|
70
71
|
|
71
|
-
def answer_character(index: int) -> str:
|
72
|
-
r"""
|
73
|
-
Helper to go from array index to char, for example:
|
74
|
-
|
75
|
-
0 -> 'A', 1 -> 'B', etc
|
76
|
-
"""
|
77
|
-
return chr(ord("A") + index)
|
78
|
-
|
79
|
-
|
80
|
-
def answer_index(char: str) -> int:
|
81
|
-
r"""
|
82
|
-
Helper to go from char to array index, for example:
|
83
|
-
|
84
|
-
'A' -> 0, 'B' -> 1, etc
|
85
|
-
"""
|
86
|
-
return ord(char.upper()) - ord("A")
|
87
|
-
|
88
|
-
|
89
72
|
def prompt(question: str, choices: Choices, template: str) -> str:
|
90
73
|
choices_text = answer_options(choices)
|
91
|
-
letters = ",".join(
|
74
|
+
letters = ",".join(answer_character(i) for i in range(len(choices)))
|
92
75
|
|
93
76
|
return template.format(
|
94
77
|
choices=choices_text,
|
@@ -112,7 +95,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
|
|
112
95
|
# In this case, we're looking for a single line which contains the expected
|
113
96
|
# ANSWER: B,C string with only whitespace after it
|
114
97
|
match = re.search(
|
115
|
-
r"(?i)^ANSWER\s*:\s*([A-Za-z ,]+)\s*(?:$|\n)",
|
98
|
+
r"(?i)^ANSWER\s*:\s*([A-Za-z\d ,]+)\s*(?:$|\n)",
|
116
99
|
state.output.completion,
|
117
100
|
flags=re.MULTILINE,
|
118
101
|
)
|
@@ -121,7 +104,7 @@ def parse_answers(state: TaskState) -> Match[str] | None:
|
|
121
104
|
# version for backward compatibility
|
122
105
|
if match is None:
|
123
106
|
return re.search(
|
124
|
-
r"(?i)ANSWER\s*:\s*([A-Za-z ,]+)(?:[^\w]|\n|$)", state.output.completion
|
107
|
+
r"(?i)ANSWER\s*:\s*([A-Za-z\d ,]+)(?:[^\w]|\n|$)", state.output.completion
|
125
108
|
)
|
126
109
|
else:
|
127
110
|
return match
|
inspect_ai/solver/_plan.py
CHANGED
@@ -102,7 +102,7 @@ class Plan(Solver):
|
|
102
102
|
# execute steps
|
103
103
|
for index, solver in enumerate(self.steps):
|
104
104
|
# run solver
|
105
|
-
with solver_transcript(solver, state) as st:
|
105
|
+
async with solver_transcript(solver, state) as st:
|
106
106
|
state = await solver(state, generate)
|
107
107
|
st.complete(state)
|
108
108
|
|
@@ -113,7 +113,7 @@ class Plan(Solver):
|
|
113
113
|
|
114
114
|
# execute finish
|
115
115
|
if self.finish:
|
116
|
-
with solver_transcript(self.finish, state) as st:
|
116
|
+
async with solver_transcript(self.finish, state) as st:
|
117
117
|
state = await self.finish(state, generate)
|
118
118
|
st.complete(state)
|
119
119
|
check_sample_interrupt()
|
inspect_ai/solver/_task_state.py
CHANGED
@@ -2,9 +2,8 @@ from collections.abc import Sequence
|
|
2
2
|
from contextvars import ContextVar
|
3
3
|
from copy import deepcopy
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from itertools import tee
|
6
5
|
from random import Random
|
7
|
-
from typing import Any,
|
6
|
+
from typing import Any, Type, Union, cast, overload
|
8
7
|
|
9
8
|
from pydantic_core import to_jsonable_python
|
10
9
|
from shortuuid import uuid
|
@@ -18,12 +17,18 @@ from inspect_ai.model import (
|
|
18
17
|
ModelOutput,
|
19
18
|
)
|
20
19
|
from inspect_ai.model._call_tools import tools_info
|
21
|
-
from inspect_ai.model._chat_message import ChatMessageBase
|
22
20
|
from inspect_ai.model._model import sample_total_tokens
|
23
21
|
from inspect_ai.scorer._metric import Score
|
24
22
|
from inspect_ai.scorer._target import Target
|
25
23
|
from inspect_ai.tool import Tool, ToolChoice
|
26
24
|
from inspect_ai.tool._tool_def import ToolDef
|
25
|
+
from inspect_ai.util._limit import (
|
26
|
+
check_message_limit,
|
27
|
+
check_token_limit,
|
28
|
+
)
|
29
|
+
from inspect_ai.util._limit import message_limit as create_message_limit
|
30
|
+
from inspect_ai.util._limit import token_limit as create_token_limit
|
31
|
+
from inspect_ai.util._limited_conversation import ChatMessageList
|
27
32
|
from inspect_ai.util._store import Store, store_jsonable
|
28
33
|
from inspect_ai.util._store_model import SMT
|
29
34
|
|
@@ -159,11 +164,11 @@ class TaskState:
|
|
159
164
|
self._input = input
|
160
165
|
self._target = target
|
161
166
|
self._metadata = metadata
|
162
|
-
self._messages: list[ChatMessage] = ChatMessageList(messages
|
167
|
+
self._messages: list[ChatMessage] = ChatMessageList(messages)
|
163
168
|
self._tools: list[Tool] = []
|
164
169
|
self._output = output if output else ModelOutput(model=str(model))
|
165
|
-
self._message_limit = message_limit
|
166
|
-
self._token_limit = token_limit
|
170
|
+
self._message_limit = create_message_limit(message_limit)
|
171
|
+
self._token_limit = create_token_limit(token_limit)
|
167
172
|
self._completed = completed
|
168
173
|
self._store = Store()
|
169
174
|
self._uuid = uuid()
|
@@ -254,7 +259,7 @@ class TaskState:
|
|
254
259
|
|
255
260
|
@messages.setter
|
256
261
|
def messages(self, messages: list[ChatMessage]) -> None:
|
257
|
-
self._messages = ChatMessageList(messages
|
262
|
+
self._messages = ChatMessageList(messages)
|
258
263
|
|
259
264
|
@property
|
260
265
|
def output(self) -> ModelOutput:
|
@@ -302,12 +307,16 @@ class TaskState:
|
|
302
307
|
@property
|
303
308
|
def message_limit(self) -> int | None:
|
304
309
|
"""Limit on total messages allowed per conversation."""
|
305
|
-
return self._message_limit
|
310
|
+
return self._message_limit.limit
|
306
311
|
|
307
312
|
@message_limit.setter
|
308
313
|
def message_limit(self, messages: int | None) -> None:
|
309
|
-
"""Set limit on total messages allowed per conversation.
|
310
|
-
|
314
|
+
"""Set limit on total messages allowed per conversation.
|
315
|
+
|
316
|
+
Also checks whether the current message count exceeds the new limit.
|
317
|
+
"""
|
318
|
+
self._message_limit.limit = messages
|
319
|
+
check_message_limit(len(self.messages), raise_for_equal=False)
|
311
320
|
|
312
321
|
from inspect_ai.log._samples import set_active_sample_message_limit
|
313
322
|
|
@@ -316,12 +325,16 @@ class TaskState:
|
|
316
325
|
@property
|
317
326
|
def token_limit(self) -> int | None:
|
318
327
|
"""Limit on total tokens allowed per conversation."""
|
319
|
-
return self._token_limit
|
328
|
+
return self._token_limit.limit
|
320
329
|
|
321
330
|
@token_limit.setter
|
322
331
|
def token_limit(self, tokens: int | None) -> None:
|
323
|
-
"""Set limit on total tokens allowed per conversation.
|
324
|
-
|
332
|
+
"""Set limit on total tokens allowed per conversation.
|
333
|
+
|
334
|
+
Also checks whether the current token usage exceeds the new limit.
|
335
|
+
"""
|
336
|
+
self._token_limit.limit = tokens
|
337
|
+
check_token_limit()
|
325
338
|
|
326
339
|
from inspect_ai.log._samples import set_active_sample_token_limit
|
327
340
|
|
@@ -340,24 +353,11 @@ class TaskState:
|
|
340
353
|
"""
|
341
354
|
from inspect_ai.log._samples import set_active_sample_total_messages
|
342
355
|
|
343
|
-
from ._limit import SampleLimitExceededError
|
344
|
-
|
345
356
|
# update messages
|
346
357
|
set_active_sample_total_messages(len(self.messages))
|
347
358
|
|
348
359
|
if self._completed:
|
349
360
|
return True
|
350
|
-
elif self.message_limit and len(self.messages) >= self.message_limit:
|
351
|
-
raise SampleLimitExceededError(
|
352
|
-
"message",
|
353
|
-
value=len(self.messages),
|
354
|
-
limit=self.message_limit,
|
355
|
-
state=self,
|
356
|
-
)
|
357
|
-
elif self.token_limit and self.token_usage >= self.token_limit:
|
358
|
-
raise SampleLimitExceededError(
|
359
|
-
"token", value=self.token_usage, limit=self.token_limit, state=self
|
360
|
-
)
|
361
361
|
else:
|
362
362
|
check_sample_interrupt()
|
363
363
|
return self._completed
|
@@ -445,65 +445,3 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
|
|
445
445
|
def sample_jsonable(sample: Sample) -> dict[str, Any]:
|
446
446
|
jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
|
447
447
|
return cast(dict[str, Any], deepcopy(jsonable))
|
448
|
-
|
449
|
-
|
450
|
-
class ChatMessageList(list[ChatMessage]):
|
451
|
-
def __init__(self, iterable: Iterable[ChatMessage], parent_state: TaskState):
|
452
|
-
self.parent_state = parent_state
|
453
|
-
items, length = self._iterable_length(iterable)
|
454
|
-
self._check_size(length)
|
455
|
-
super().__init__(items)
|
456
|
-
|
457
|
-
def _check_size(self, additional_items: int = 1) -> None:
|
458
|
-
from inspect_ai.log._samples import active_sample_message_limit
|
459
|
-
|
460
|
-
from ._limit import SampleLimitExceededError
|
461
|
-
|
462
|
-
messages_limit = active_sample_message_limit()
|
463
|
-
if messages_limit is not None:
|
464
|
-
messages = len(self) + additional_items
|
465
|
-
if messages > messages_limit:
|
466
|
-
raise SampleLimitExceededError(
|
467
|
-
"message",
|
468
|
-
value=messages,
|
469
|
-
limit=messages_limit,
|
470
|
-
message=None,
|
471
|
-
state=self.parent_state,
|
472
|
-
)
|
473
|
-
|
474
|
-
def append(self, item: ChatMessage) -> None:
|
475
|
-
self._check_size()
|
476
|
-
super().append(item)
|
477
|
-
|
478
|
-
def extend(self, items: Iterable[ChatMessage]) -> None:
|
479
|
-
items, length = self._iterable_length(items)
|
480
|
-
self._check_size(length)
|
481
|
-
super().extend(items)
|
482
|
-
|
483
|
-
def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
|
484
|
-
self._check_size()
|
485
|
-
super().insert(index, item)
|
486
|
-
|
487
|
-
@overload
|
488
|
-
def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
|
489
|
-
|
490
|
-
@overload
|
491
|
-
def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
|
492
|
-
|
493
|
-
def __setitem__(
|
494
|
-
self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
|
495
|
-
) -> None:
|
496
|
-
if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
|
497
|
-
item, length = self._iterable_length(item)
|
498
|
-
size_change = length - len(self[index])
|
499
|
-
if size_change > 0:
|
500
|
-
self._check_size(size_change)
|
501
|
-
|
502
|
-
super().__setitem__(index, item) # type: ignore[assignment,index]
|
503
|
-
|
504
|
-
def _iterable_length(
|
505
|
-
self, items: Iterable[ChatMessage]
|
506
|
-
) -> tuple[Iterable[ChatMessage], int]:
|
507
|
-
items, counter = tee(items)
|
508
|
-
length = sum(1 for _ in counter)
|
509
|
-
return items, length
|
inspect_ai/solver/_transcript.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1
1
|
import contextlib
|
2
|
-
from typing import
|
2
|
+
from typing import AsyncIterator
|
3
3
|
|
4
4
|
from inspect_ai._util.json import json_changes
|
5
5
|
from inspect_ai._util.registry import registry_log_name
|
6
|
+
from inspect_ai.util._span import span
|
6
7
|
|
7
8
|
from ._solver import Solver
|
8
9
|
from ._task_state import TaskState, state_jsonable
|
@@ -22,12 +23,10 @@ class SolverTranscript:
|
|
22
23
|
transcript()._event(StateEvent(changes=changes))
|
23
24
|
|
24
25
|
|
25
|
-
@contextlib.
|
26
|
-
def solver_transcript(
|
26
|
+
@contextlib.asynccontextmanager
|
27
|
+
async def solver_transcript(
|
27
28
|
solver: Solver, state: TaskState, name: str | None = None
|
28
|
-
) ->
|
29
|
-
from inspect_ai.log._transcript import transcript
|
30
|
-
|
29
|
+
) -> AsyncIterator[SolverTranscript]:
|
31
30
|
name = registry_log_name(name or solver)
|
32
|
-
with
|
31
|
+
async with span(name=name, type="solver"):
|
33
32
|
yield SolverTranscript(name, state)
|
@@ -4,7 +4,7 @@ from typing import Literal, Protocol, Type, TypeAlias, TypeVar
|
|
4
4
|
|
5
5
|
from pydantic import BaseModel, RootModel
|
6
6
|
|
7
|
-
from inspect_ai.tool._tool import ToolError
|
7
|
+
from inspect_ai.tool._tool import ToolError, ToolParsingError
|
8
8
|
|
9
9
|
|
10
10
|
class JSONRPCResponseBase(BaseModel):
|
@@ -70,6 +70,7 @@ async def exec_scalar_request(
|
|
70
70
|
params: JSONRPCParamsType,
|
71
71
|
result_type: Type[ScalarT],
|
72
72
|
transport: JSONRPCTransport,
|
73
|
+
server_error_mapper: JSONRPCServerErrorMapper,
|
73
74
|
) -> ScalarT:
|
74
75
|
"""
|
75
76
|
Execute a JSON-RPC command expecting a scalar result.
|
@@ -79,6 +80,7 @@ async def exec_scalar_request(
|
|
79
80
|
params (JSONRPCParamsType): The parameters for the JSON-RPC method.
|
80
81
|
result_type (Type[ScalarT]): The scalar type (str, int, float, bool, None) to validate the result against.
|
81
82
|
transport (JSONRPCTransport): The transport callable to use for the RPC communication.
|
83
|
+
server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
|
82
84
|
|
83
85
|
Returns:
|
84
86
|
ScalarT: The scalar result of the JSON-RPC call.
|
@@ -88,7 +90,12 @@ async def exec_scalar_request(
|
|
88
90
|
ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
|
89
91
|
ValueError: If the result is not of the expected scalar type.
|
90
92
|
"""
|
91
|
-
rpc_result = await _exec_request(
|
93
|
+
rpc_result = await _exec_request(
|
94
|
+
method=method,
|
95
|
+
params=params,
|
96
|
+
transport=transport,
|
97
|
+
server_error_mapper=server_error_mapper,
|
98
|
+
)
|
92
99
|
if (result_type is type(None) and rpc_result is not None) or not isinstance(
|
93
100
|
rpc_result, result_type
|
94
101
|
):
|
@@ -101,6 +108,7 @@ async def exec_model_request(
|
|
101
108
|
params: JSONRPCParamsType,
|
102
109
|
result_type: Type[BaseModelT],
|
103
110
|
transport: JSONRPCTransport,
|
111
|
+
server_error_mapper: JSONRPCServerErrorMapper | None = None,
|
104
112
|
) -> BaseModelT:
|
105
113
|
"""
|
106
114
|
Execute a JSON-RPC command to a sandbox environment expecting a model result.
|
@@ -110,6 +118,7 @@ async def exec_model_request(
|
|
110
118
|
params (JSONRPCParamsType): The parameters for the JSON-RPC method.
|
111
119
|
result_type (Type[BaseModelT]): The Pydantic model class to validate and parse the result.
|
112
120
|
transport (JSONRPCTransport): The transport callable to use for the RPC communication.
|
121
|
+
server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
|
113
122
|
|
114
123
|
Returns:
|
115
124
|
BaseModelT: The parsed and validated result of the JSON-RPC call.
|
@@ -119,7 +128,12 @@ async def exec_model_request(
|
|
119
128
|
ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
|
120
129
|
ValueError: If the result cannot be validated against the provided model class.
|
121
130
|
"""
|
122
|
-
rpc_result = await _exec_request(
|
131
|
+
rpc_result = await _exec_request(
|
132
|
+
method=method,
|
133
|
+
params=params,
|
134
|
+
transport=transport,
|
135
|
+
server_error_mapper=server_error_mapper,
|
136
|
+
)
|
123
137
|
return result_type.model_validate(rpc_result, strict=True)
|
124
138
|
|
125
139
|
|
@@ -161,6 +175,7 @@ async def _exec_request(
|
|
161
175
|
method: str,
|
162
176
|
params: JSONRPCParamsType,
|
163
177
|
transport: JSONRPCTransport,
|
178
|
+
server_error_mapper: JSONRPCServerErrorMapper | None = None,
|
164
179
|
) -> object:
|
165
180
|
"""Execute a request using the provided transport mechanism."""
|
166
181
|
return parse_json_rpc_response(
|
@@ -171,6 +186,7 @@ async def _exec_request(
|
|
171
186
|
),
|
172
187
|
method,
|
173
188
|
params,
|
189
|
+
server_error_mapper,
|
174
190
|
)
|
175
191
|
|
176
192
|
|
@@ -178,15 +194,16 @@ def parse_json_rpc_response(
|
|
178
194
|
response_str: str,
|
179
195
|
method: str,
|
180
196
|
params: JSONRPCParamsType,
|
197
|
+
server_error_mapper: JSONRPCServerErrorMapper | None = None,
|
181
198
|
) -> object:
|
182
199
|
"""Validates the JSON RPC response and returns the result or raises a proper Inspect error."""
|
183
200
|
match JSONRPCResponse.model_validate_json(response_str).root:
|
184
201
|
case JSONRPCSuccessResponse(result=rpc_result):
|
185
202
|
return rpc_result
|
186
|
-
case JSONRPCErrorResponse(
|
187
|
-
|
188
|
-
|
189
|
-
|
203
|
+
case JSONRPCErrorResponse(error=JSONRPCError(code=code, message=message)):
|
204
|
+
raise exception_for_rpc_response_error(
|
205
|
+
code, message, method, params, server_error_mapper
|
206
|
+
)
|
190
207
|
case _:
|
191
208
|
raise ValueError(
|
192
209
|
f"Unexpected JSON RPC response to request {_rpc_call_description(method, params)}: {response_str}"
|
@@ -220,16 +237,17 @@ def exception_for_rpc_response_error(
|
|
220
237
|
if server_error_mapper
|
221
238
|
else ToolError(message)
|
222
239
|
)
|
240
|
+
elif code == -32602: # (Invalid params)
|
241
|
+
# Even though the Inspect side does validation, it can't possibly be
|
242
|
+
# complete - especially for tools that have dynamic action dependant
|
243
|
+
# rules for optional/required params.
|
244
|
+
return ToolParsingError(message)
|
223
245
|
elif code == -32603:
|
224
246
|
return ToolError(message)
|
225
247
|
else:
|
226
248
|
# -32600 (Invalid Request)
|
227
249
|
# If we sent a bogus request, it's 100% a code bug.
|
228
250
|
# -32601 (Method not found)
|
229
|
-
# -32602 (Invalid params)
|
230
|
-
# These shouldn't be possible since Inspect did validation prior to
|
231
|
-
# making the tool call. Because of that, these errors should not make
|
232
|
-
# it back to the model, so choose RuntimeError.
|
233
251
|
# -32700 (Parse error)
|
234
252
|
# shouldn't be seen in this flow since we're processing responses, and
|
235
253
|
# this is a request oriented error.
|
@@ -276,10 +294,20 @@ def create_json_rpc_request(
|
|
276
294
|
is_notification: bool,
|
277
295
|
) -> str:
|
278
296
|
return json.dumps(
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
297
|
+
remove_none_values(
|
298
|
+
{
|
299
|
+
"jsonrpc": "2.0",
|
300
|
+
"method": method,
|
301
|
+
**({"params": params} if params else {}),
|
302
|
+
**({"id": next(id_generator)} if not is_notification else {}),
|
303
|
+
}
|
304
|
+
)
|
285
305
|
)
|
306
|
+
|
307
|
+
|
308
|
+
def remove_none_values(obj: object) -> object:
|
309
|
+
if isinstance(obj, dict):
|
310
|
+
return {k: remove_none_values(v) for k, v in obj.items() if v is not None}
|
311
|
+
elif isinstance(obj, list):
|
312
|
+
return [remove_none_values(item) for item in obj if item is not None]
|
313
|
+
return obj
|
inspect_ai/tool/_mcp/_mcp.py
CHANGED
@@ -61,16 +61,17 @@ class MCPServerImpl(MCPServer):
|
|
61
61
|
) -> list[Tool]:
|
62
62
|
return await self._task_session()._list_tools(tools)
|
63
63
|
|
64
|
-
# create a separate MCPServer session per async task
|
65
|
-
_task_sessions: dict[
|
64
|
+
# create a separate MCPServer session per async task / server name
|
65
|
+
_task_sessions: dict[str, "MCPServerSession"] = {}
|
66
66
|
|
67
67
|
def _task_session(self) -> "MCPServerSession":
|
68
68
|
task_id = anyio.get_current_task().id
|
69
|
-
|
70
|
-
|
69
|
+
session_key = f"{task_id}_{self._name}"
|
70
|
+
if session_key not in self._task_sessions:
|
71
|
+
MCPServerImpl._task_sessions[session_key] = MCPServerSession(
|
71
72
|
self._client, name=self._name, events=self._events
|
72
73
|
)
|
73
|
-
return MCPServerImpl._task_sessions[
|
74
|
+
return MCPServerImpl._task_sessions[session_key]
|
74
75
|
|
75
76
|
|
76
77
|
class MCPServerSession(MCPServer):
|
@@ -259,6 +260,7 @@ def create_server_sandbox(
|
|
259
260
|
cwd: str | Path | None = None,
|
260
261
|
env: dict[str, str] | None = None,
|
261
262
|
sandbox: str | None = None,
|
263
|
+
timeout: int | None = None,
|
262
264
|
) -> MCPServer:
|
263
265
|
# TODO: Confirm the lifetime concepts. By the time a request makes it to the
|
264
266
|
# sandbox, it's going to need both a session id and a server "name".
|
@@ -272,6 +274,7 @@ def create_server_sandbox(
|
|
272
274
|
env=env,
|
273
275
|
),
|
274
276
|
sandbox_name=sandbox,
|
277
|
+
timeout=timeout,
|
275
278
|
),
|
276
279
|
name=name,
|
277
280
|
events=False,
|
inspect_ai/tool/_mcp/_sandbox.py
CHANGED
@@ -11,7 +11,7 @@ from inspect_ai.tool._tool_support_helpers import (
|
|
11
11
|
exec_model_request,
|
12
12
|
exec_notification,
|
13
13
|
exec_scalar_request,
|
14
|
-
|
14
|
+
tool_support_sandbox,
|
15
15
|
)
|
16
16
|
|
17
17
|
from ._context import MCPServerContext
|
@@ -28,8 +28,10 @@ async def sandbox_client( # type: ignore
|
|
28
28
|
*,
|
29
29
|
sandbox_name: str | None = None,
|
30
30
|
errlog: TextIO = sys.stderr,
|
31
|
+
timeout: int | None = None, # default 180 seconds
|
31
32
|
) -> MCPServerContext: # type: ignore
|
32
|
-
|
33
|
+
timeout = timeout or 180
|
34
|
+
(sandbox_environment, _) = await tool_support_sandbox(
|
33
35
|
"mcp support", sandbox_name=sandbox_name
|
34
36
|
)
|
35
37
|
|
@@ -49,6 +51,7 @@ async def sandbox_client( # type: ignore
|
|
49
51
|
method="mcp_launch_server",
|
50
52
|
params={"server_params": server.model_dump()},
|
51
53
|
result_type=int,
|
54
|
+
timeout=timeout,
|
52
55
|
)
|
53
56
|
|
54
57
|
async def stdout_reader() -> None:
|
@@ -72,6 +75,7 @@ async def sandbox_client( # type: ignore
|
|
72
75
|
"request": root.model_dump(),
|
73
76
|
},
|
74
77
|
result_type=JSONRPCMessage,
|
78
|
+
timeout=timeout,
|
75
79
|
)
|
76
80
|
)
|
77
81
|
elif isinstance(root, JSONRPCNotification):
|
@@ -82,6 +86,7 @@ async def sandbox_client( # type: ignore
|
|
82
86
|
"session_id": session_id,
|
83
87
|
"notification": root.model_dump(),
|
84
88
|
},
|
89
|
+
timeout=timeout,
|
85
90
|
)
|
86
91
|
else:
|
87
92
|
assert False, f"Unexpected message type {message=}"
|
@@ -101,4 +106,5 @@ async def sandbox_client( # type: ignore
|
|
101
106
|
method="mcp_kill_server",
|
102
107
|
params={"session_id": session_id},
|
103
108
|
result_type=type(None),
|
109
|
+
timeout=timeout,
|
104
110
|
)
|
inspect_ai/tool/_mcp/server.py
CHANGED
@@ -73,6 +73,7 @@ def mcp_server_sandbox(
|
|
73
73
|
cwd: str | Path | None = None,
|
74
74
|
env: dict[str, str] | None = None,
|
75
75
|
sandbox: str | None = None,
|
76
|
+
timeout: int | None = None,
|
76
77
|
) -> MCPServer:
|
77
78
|
"""MCP Server (Sandbox).
|
78
79
|
|
@@ -87,6 +88,7 @@ def mcp_server_sandbox(
|
|
87
88
|
"SHELL", "TERM", and "USER" for Posix-based systems).
|
88
89
|
cwd: The working directory to use when spawning the process.
|
89
90
|
sandbox: The sandbox to use when spawning the process.
|
91
|
+
timeout: Timeout (in seconds) for command.
|
90
92
|
|
91
93
|
Returns:
|
92
94
|
McpClient: Client for MCP Server
|
@@ -94,7 +96,7 @@ def mcp_server_sandbox(
|
|
94
96
|
verfify_mcp_package()
|
95
97
|
from ._mcp import create_server_sandbox
|
96
98
|
|
97
|
-
return create_server_sandbox(command, args, cwd, env, sandbox)
|
99
|
+
return create_server_sandbox(command, args, cwd, env, sandbox, timeout)
|
98
100
|
|
99
101
|
|
100
102
|
def verfify_mcp_package() -> None:
|
inspect_ai/tool/_tool_call.py
CHANGED
@@ -68,9 +68,12 @@ class ToolCallError:
|
|
68
68
|
"permission",
|
69
69
|
"file_not_found",
|
70
70
|
"is_a_directory",
|
71
|
-
"
|
71
|
+
"limit",
|
72
72
|
"approval",
|
73
73
|
"unknown",
|
74
|
+
# Retained for backward compatibility when loading logs created with an older
|
75
|
+
# version of inspect.
|
76
|
+
"output_limit",
|
74
77
|
]
|
75
78
|
"""Error type."""
|
76
79
|
|