inspect-ai 0.3.91__py3-none-any.whl → 0.3.93__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +31 -0
- inspect_ai/_eval/eval.py +19 -2
- inspect_ai/_eval/evalset.py +4 -1
- inspect_ai/_eval/run.py +41 -0
- inspect_ai/_eval/task/generate.py +38 -44
- inspect_ai/_eval/task/log.py +26 -28
- inspect_ai/_eval/task/run.py +13 -20
- inspect_ai/_util/local_server.py +368 -0
- inspect_ai/_util/working.py +10 -4
- inspect_ai/_view/www/dist/assets/index.css +159 -146
- inspect_ai/_view/www/dist/assets/index.js +1020 -1061
- inspect_ai/_view/www/log-schema.json +4 -3
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +3 -2
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +2 -2
- inspect_ai/_view/www/src/app/content/MetaDataView.module.css +1 -1
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +1 -1
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +1 -1
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +11 -0
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +2 -9
- inspect_ai/_view/www/src/app/log-view/tabs/ModelsTab.tsx +51 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.module.css +6 -0
- inspect_ai/_view/www/src/app/log-view/tabs/TaskTab.tsx +143 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +1 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +29 -7
- inspect_ai/_view/www/src/app/plan/PlanDetailView.module.css +1 -1
- inspect_ai/_view/www/src/app/plan/PlanDetailView.tsx +1 -198
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -1
- inspect_ai/_view/www/src/app/usage/ModelUsagePanel.tsx +3 -2
- inspect_ai/_view/www/src/app/usage/TokenTable.module.css +4 -1
- inspect_ai/_view/www/src/app/usage/TokenTable.tsx +2 -2
- inspect_ai/_view/www/src/app/usage/UsageCard.module.css +8 -3
- inspect_ai/_view/www/src/app/usage/UsageCard.tsx +1 -35
- inspect_ai/_view/www/src/components/Card.css +0 -1
- inspect_ai/_view/www/src/constants.ts +2 -0
- inspect_ai/_view/www/src/utils/numeric.ts +17 -0
- inspect_ai/agent/_agent.py +3 -3
- inspect_ai/agent/_as_solver.py +20 -12
- inspect_ai/agent/_as_tool.py +15 -3
- inspect_ai/agent/_handoff.py +8 -1
- inspect_ai/agent/_run.py +11 -3
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +56 -0
- inspect_ai/log/_log.py +99 -0
- inspect_ai/log/_recorders/__init__.py +2 -0
- inspect_ai/log/_recorders/buffer/database.py +12 -11
- inspect_ai/log/_recorders/buffer/filestore.py +2 -2
- inspect_ai/log/_recorders/buffer/types.py +2 -2
- inspect_ai/log/_recorders/eval.py +20 -65
- inspect_ai/log/_recorders/file.py +28 -6
- inspect_ai/log/_recorders/recorder.py +7 -0
- inspect_ai/log/_recorders/types.py +1 -23
- inspect_ai/log/_samples.py +0 -8
- inspect_ai/log/_transcript.py +7 -1
- inspect_ai/log/_util.py +52 -0
- inspect_ai/model/__init__.py +5 -1
- inspect_ai/model/_call_tools.py +32 -12
- inspect_ai/model/_generate_config.py +14 -8
- inspect_ai/model/_model.py +21 -48
- inspect_ai/model/_model_output.py +25 -0
- inspect_ai/model/_openai.py +2 -0
- inspect_ai/model/_openai_responses.py +13 -1
- inspect_ai/model/_providers/anthropic.py +13 -23
- inspect_ai/model/_providers/openai_o1.py +8 -2
- inspect_ai/model/_providers/providers.py +18 -4
- inspect_ai/model/_providers/sglang.py +241 -0
- inspect_ai/model/_providers/vllm.py +207 -400
- inspect_ai/solver/__init__.py +7 -2
- inspect_ai/solver/_basic_agent.py +3 -10
- inspect_ai/solver/_task_state.py +26 -88
- inspect_ai/tool/_json_rpc_helpers.py +45 -17
- inspect_ai/tool/_mcp/_mcp.py +2 -0
- inspect_ai/tool/_mcp/_sandbox.py +8 -2
- inspect_ai/tool/_mcp/server.py +3 -1
- inspect_ai/tool/_tool_call.py +4 -1
- inspect_ai/tool/_tool_support_helpers.py +51 -12
- inspect_ai/tool/_tools/_bash_session.py +190 -68
- inspect_ai/tool/_tools/_computer/_computer.py +25 -1
- inspect_ai/tool/_tools/_text_editor.py +4 -3
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +10 -3
- inspect_ai/util/__init__.py +12 -0
- inspect_ai/util/_limit.py +393 -0
- inspect_ai/util/_limited_conversation.py +57 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/RECORD +90 -109
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/WHEEL +1 -1
- inspect_ai/solver/_limit.py +0 -39
- inspect_ai/tool/_tools/_computer/_resources/Dockerfile +0 -102
- inspect_ai/tool/_tools/_computer/_resources/README.md +0 -30
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/entrypoint.sh +0 -18
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/novnc_startup.sh +0 -20
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/x11vnc_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xfce_startup.sh +0 -13
- inspect_ai/tool/_tools/_computer/_resources/entrypoint/xvfb_startup.sh +0 -48
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/globalStorage/state.vscdb +0 -0
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/Code/User/settings.json +0 -9
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-panel.xml +0 -61
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfwm4.xml +0 -91
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Terminal.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +0 -10
- inspect_ai/tool/_tools/_computer/_resources/tool/.pylintrc +0 -8
- inspect_ai/tool/_tools/_computer/_resources/tool/.vscode/settings.json +0 -12
- inspect_ai/tool/_tools/_computer/_resources/tool/_args.py +0 -78
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_logger.py +0 -22
- inspect_ai/tool/_tools/_computer/_resources/tool/_run.py +0 -42
- inspect_ai/tool/_tools/_computer/_resources/tool/_tool_result.py +0 -33
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +0 -341
- inspect_ai/tool/_tools/_computer/_resources/tool/computer_tool.py +0 -141
- inspect_ai/tool/_tools/_computer/_resources/tool/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/tool/_tools/_computer/test_args.py +0 -151
- /inspect_ai/{tool/_tools/_computer/_resources/tool/__init__.py → _view/www/src/app/log-view/tabs/ModelsTab.module.css} +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.91.dist-info → inspect_ai-0.3.93.dist-info}/top_level.txt +0 -0
inspect_ai/solver/_basic_agent.py
CHANGED
@@ -13,8 +13,8 @@ from inspect_ai.scorer._score import score
 from inspect_ai.solver._chain import chain
 from inspect_ai.tool._tool import Tool, ToolResult, tool
 from inspect_ai.tool._tool_with import tool_with
+from inspect_ai.util._limit import token_limit as create_token_limit
 
-from ._limit import SampleLimitExceededError
 from ._prompt import system_message
 from ._solver import Generate, Solver, solver
 from ._task_state import TaskState
@@ -172,14 +172,11 @@ def basic_agent(
     # (if there is no message_limit then default to 50)
     state.message_limit = message_limit or state.message_limit or 50
 
-    # resolve token limit
-    state.token_limit = token_limit or state.token_limit
-
     # track attempts
     attempts = 0
 
-
-    # main loop
+    with create_token_limit(token_limit):
+        # main loop
     while not state.completed:
         # generate output and append assistant message
         state.output = await get_model().generate(
@@ -247,10 +244,6 @@ def basic_agent(
             else:
                 state.messages.append(ChatMessageUser(content=continue_message))
 
-        # propagate current state along with sample limit exceeded
-        except SampleLimitExceededError as ex:
-            raise ex.with_state(state)
-
         return state
 
     return solve
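The loop above now takes its budget from the scoped limit API in the new `inspect_ai/util/_limit.py` rather than by mutating `state.token_limit`. A minimal sketch of that pattern, assuming only the context-manager behavior visible in this diff (where a `None` argument means no limit):

```python
from inspect_ai.util._limit import token_limit

async def generate_within_budget(state, generate, budget: int | None):
    # Token usage by model calls inside the block is charged against
    # `budget`; None leaves generation unlimited, mirroring the optional
    # token_limit argument of basic_agent() above.
    with token_limit(budget):
        return await generate(state)
```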
inspect_ai/solver/_task_state.py
CHANGED
@@ -2,9 +2,8 @@ from collections.abc import Sequence
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
-from itertools import tee
 from random import Random
-from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
+from typing import Any, Type, Union, cast, overload
 
 from pydantic_core import to_jsonable_python
 from shortuuid import uuid
@@ -18,12 +17,18 @@ from inspect_ai.model import (
     ModelOutput,
 )
 from inspect_ai.model._call_tools import tools_info
-from inspect_ai.model._chat_message import ChatMessageBase
 from inspect_ai.model._model import sample_total_tokens
 from inspect_ai.scorer._metric import Score
 from inspect_ai.scorer._target import Target
 from inspect_ai.tool import Tool, ToolChoice
 from inspect_ai.tool._tool_def import ToolDef
+from inspect_ai.util._limit import (
+    check_message_limit,
+    check_token_limit,
+)
+from inspect_ai.util._limit import message_limit as create_message_limit
+from inspect_ai.util._limit import token_limit as create_token_limit
+from inspect_ai.util._limited_conversation import ChatMessageList
 from inspect_ai.util._store import Store, store_jsonable
 from inspect_ai.util._store_model import SMT
 
@@ -159,11 +164,11 @@ class TaskState:
         self._input = input
         self._target = target
         self._metadata = metadata
-        self._messages: list[ChatMessage] = ChatMessageList(messages, self)
+        self._messages: list[ChatMessage] = ChatMessageList(messages)
         self._tools: list[Tool] = []
         self._output = output if output else ModelOutput(model=str(model))
-        self._message_limit = message_limit
-        self._token_limit = token_limit
+        self._message_limit = create_message_limit(message_limit)
+        self._token_limit = create_token_limit(token_limit)
         self._completed = completed
         self._store = Store()
         self._uuid = uuid()
@@ -254,7 +259,7 @@ class TaskState:
 
     @messages.setter
     def messages(self, messages: list[ChatMessage]) -> None:
-        self._messages = ChatMessageList(messages, self)
+        self._messages = ChatMessageList(messages)
 
     @property
     def output(self) -> ModelOutput:
@@ -302,12 +307,16 @@ class TaskState:
     @property
     def message_limit(self) -> int | None:
         """Limit on total messages allowed per conversation."""
-        return self._message_limit
+        return self._message_limit.limit
 
     @message_limit.setter
     def message_limit(self, messages: int | None) -> None:
-        """Set limit on total messages allowed per conversation.
-
+        """Set limit on total messages allowed per conversation.
+
+        Also checks whether the current message count exceeds the new limit.
+        """
+        self._message_limit.limit = messages
+        check_message_limit(len(self.messages), raise_for_equal=False)
 
         from inspect_ai.log._samples import set_active_sample_message_limit
 
@@ -316,12 +325,16 @@ class TaskState:
     @property
     def token_limit(self) -> int | None:
         """Limit on total tokens allowed per conversation."""
-        return self._token_limit
+        return self._token_limit.limit
 
     @token_limit.setter
    def token_limit(self, tokens: int | None) -> None:
-        """Set limit on total tokens allowed per conversation.
-
+        """Set limit on total tokens allowed per conversation.
+
+        Also checks whether the current token usage exceeds the new limit.
+        """
+        self._token_limit.limit = tokens
+        check_token_limit()
 
         from inspect_ai.log._samples import set_active_sample_token_limit
 
@@ -340,24 +353,11 @@ class TaskState:
         """
         from inspect_ai.log._samples import set_active_sample_total_messages
 
-        from ._limit import SampleLimitExceededError
-
         # update messages
        set_active_sample_total_messages(len(self.messages))
 
         if self._completed:
             return True
-        elif self.message_limit and len(self.messages) >= self.message_limit:
-            raise SampleLimitExceededError(
-                "message",
-                value=len(self.messages),
-                limit=self.message_limit,
-                state=self,
-            )
-        elif self.token_limit and self.token_usage >= self.token_limit:
-            raise SampleLimitExceededError(
-                "token", value=self.token_usage, limit=self.token_limit, state=self
-            )
         else:
             check_sample_interrupt()
             return self._completed
@@ -445,65 +445,3 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
 def sample_jsonable(sample: Sample) -> dict[str, Any]:
     jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
     return cast(dict[str, Any], deepcopy(jsonable))
-
-
-class ChatMessageList(list[ChatMessage]):
-    def __init__(self, iterable: Iterable[ChatMessage], parent_state: TaskState):
-        self.parent_state = parent_state
-        items, length = self._iterable_length(iterable)
-        self._check_size(length)
-        super().__init__(items)
-
-    def _check_size(self, additional_items: int = 1) -> None:
-        from inspect_ai.log._samples import active_sample_message_limit
-
-        from ._limit import SampleLimitExceededError
-
-        messages_limit = active_sample_message_limit()
-        if messages_limit is not None:
-            messages = len(self) + additional_items
-            if messages > messages_limit:
-                raise SampleLimitExceededError(
-                    "message",
-                    value=messages,
-                    limit=messages_limit,
-                    message=None,
-                    state=self.parent_state,
-                )
-
-    def append(self, item: ChatMessage) -> None:
-        self._check_size()
-        super().append(item)
-
-    def extend(self, items: Iterable[ChatMessage]) -> None:
-        items, length = self._iterable_length(items)
-        self._check_size(length)
-        super().extend(items)
-
-    def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
-        self._check_size()
-        super().insert(index, item)
-
-    @overload
-    def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
-
-    @overload
-    def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
-
-    def __setitem__(
-        self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
-    ) -> None:
-        if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
-            item, length = self._iterable_length(item)
-            size_change = length - len(self[index])
-            if size_change > 0:
-                self._check_size(size_change)
-
-        super().__setitem__(index, item)  # type: ignore[assignment,index]
-
-    def _iterable_length(
-        self, items: Iterable[ChatMessage]
-    ) -> tuple[Iterable[ChatMessage], int]:
-        items, counter = tee(items)
-        length = sum(1 for _ in counter)
-        return items, length
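One consequence of the setter changes above: limits are now validated eagerly when assigned, not on the next `completed` check. A sketch of the observable behavior, hedged because the concrete exception class raised by `check_message_limit` is defined in `inspect_ai/util/_limit.py`, which this diff only references:

```python
def shrink_message_limit(state) -> None:
    # Raising the limit, or clearing it with None, always succeeds.
    state.message_limit = None

    # Lowering it below the current conversation length now fails at
    # assignment time, because the setter calls check_message_limit() with
    # raise_for_equal=False (a count equal to the limit is still allowed).
    try:
        state.message_limit = len(state.messages) - 1
    except Exception as ex:  # concrete limit error lives in util._limit
        print(f"limit exceeded on assignment: {ex}")
```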
inspect_ai/tool/_json_rpc_helpers.py
CHANGED
@@ -4,7 +4,7 @@ from typing import Literal, Protocol, Type, TypeAlias, TypeVar
 
 from pydantic import BaseModel, RootModel
 
-from inspect_ai.tool._tool import ToolError
+from inspect_ai.tool._tool import ToolError, ToolParsingError
 
 
 class JSONRPCResponseBase(BaseModel):
@@ -70,6 +70,7 @@ async def exec_scalar_request(
     params: JSONRPCParamsType,
     result_type: Type[ScalarT],
     transport: JSONRPCTransport,
+    server_error_mapper: JSONRPCServerErrorMapper,
 ) -> ScalarT:
     """
     Execute a JSON-RPC command expecting a scalar result.
@@ -79,6 +80,7 @@ async def exec_scalar_request(
         params (JSONRPCParamsType): The parameters for the JSON-RPC method.
         result_type (Type[ScalarT]): The scalar type (str, int, float, bool, None) to validate the result against.
         transport (JSONRPCTransport): The transport callable to use for the RPC communication.
+        server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
 
     Returns:
         ScalarT: The scalar result of the JSON-RPC call.
@@ -88,7 +90,12 @@ async def exec_scalar_request(
         ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
         ValueError: If the result is not of the expected scalar type.
     """
-    rpc_result = await _exec_request(method, params, transport)
+    rpc_result = await _exec_request(
+        method=method,
+        params=params,
+        transport=transport,
+        server_error_mapper=server_error_mapper,
+    )
     if (result_type is type(None) and rpc_result is not None) or not isinstance(
         rpc_result, result_type
     ):
@@ -101,6 +108,7 @@ async def exec_model_request(
     params: JSONRPCParamsType,
     result_type: Type[BaseModelT],
     transport: JSONRPCTransport,
+    server_error_mapper: JSONRPCServerErrorMapper | None = None,
 ) -> BaseModelT:
     """
     Execute a JSON-RPC command to a sandbox environment expecting a model result.
@@ -110,6 +118,7 @@ async def exec_model_request(
         params (JSONRPCParamsType): The parameters for the JSON-RPC method.
         result_type (Type[BaseModelT]): The Pydantic model class to validate and parse the result.
         transport (JSONRPCTransport): The transport callable to use for the RPC communication.
+        server_error_mapper (JSONRPCServerErrorMapper): A callable to map server specific JSON-RPC errors to exceptions.
 
     Returns:
         BaseModelT: The parsed and validated result of the JSON-RPC call.
@@ -119,7 +128,12 @@ async def exec_model_request(
         ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
         ValueError: If the result cannot be validated against the provided model class.
     """
-    rpc_result = await _exec_request(method, params, transport)
+    rpc_result = await _exec_request(
+        method=method,
+        params=params,
+        transport=transport,
+        server_error_mapper=server_error_mapper,
+    )
     return result_type.model_validate(rpc_result, strict=True)
 
 
@@ -161,6 +175,7 @@ async def _exec_request(
     method: str,
     params: JSONRPCParamsType,
     transport: JSONRPCTransport,
+    server_error_mapper: JSONRPCServerErrorMapper | None = None,
 ) -> object:
     """Execute a request using the provided transport mechanism."""
     return parse_json_rpc_response(
@@ -171,6 +186,7 @@ async def _exec_request(
         ),
         method,
         params,
+        server_error_mapper,
     )
 
 
@@ -178,15 +194,16 @@ def parse_json_rpc_response(
     response_str: str,
     method: str,
     params: JSONRPCParamsType,
+    server_error_mapper: JSONRPCServerErrorMapper | None = None,
 ) -> object:
     """Validates the JSON RPC response and returns the result or raises a proper Inspect error."""
     match JSONRPCResponse.model_validate_json(response_str).root:
         case JSONRPCSuccessResponse(result=rpc_result):
             return rpc_result
-        case JSONRPCErrorResponse(
-
-
-
+        case JSONRPCErrorResponse(error=JSONRPCError(code=code, message=message)):
+            raise exception_for_rpc_response_error(
+                code, message, method, params, server_error_mapper
+            )
         case _:
             raise ValueError(
                 f"Unexpected JSON RPC response to request {_rpc_call_description(method, params)}: {response_str}"
@@ -220,16 +237,17 @@ def exception_for_rpc_response_error(
             if server_error_mapper
             else ToolError(message)
         )
+    elif code == -32602:  # (Invalid params)
+        # Even though the Inspect side does validation, it can't possibly be
+        # complete - especially for tools that have dynamic action dependant
+        # rules for optional/required params.
+        return ToolParsingError(message)
     elif code == -32603:
         return ToolError(message)
     else:
         # -32600 (Invalid Request)
        # If we sent a bogus request, it's 100% a code bug.
         # -32601 (Method not found)
-        # -32602 (Invalid params)
-        # These shouldn't be possible since Inspect did validation prior to
-        # making the tool call. Because of that, these errors should not make
-        # it back to the model, so choose RuntimeError.
         # -32700 (Parse error)
         # shouldn't be seen in this flow since we're processing responses, and
         # this is a request oriented error.
@@ -276,10 +294,20 @@ def create_json_rpc_request(
     is_notification: bool,
 ) -> str:
     return json.dumps(
-        {
-            "jsonrpc": "2.0",
-            "method": method,
-            **({"params": params} if params else {}),
-            **({"id": next(id_generator)} if not is_notification else {}),
-        }
+        remove_none_values(
+            {
+                "jsonrpc": "2.0",
+                "method": method,
+                **({"params": params} if params else {}),
+                **({"id": next(id_generator)} if not is_notification else {}),
+            }
+        )
     )
+
+
+def remove_none_values(obj: object) -> object:
+    if isinstance(obj, dict):
+        return {k: remove_none_values(v) for k, v in obj.items() if v is not None}
+    elif isinstance(obj, list):
+        return [remove_none_values(item) for item in obj if item is not None]
+    return obj
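The new `remove_none_values` helper keeps optional JSON-RPC fields out of the wire format entirely instead of serializing them as `null`. A usage sketch (the method name is illustrative; the import is from the internal module where the diff defines the helper):

```python
from inspect_ai.tool._json_rpc_helpers import remove_none_values

request = {
    "jsonrpc": "2.0",
    "method": "example_method",
    "params": {"input": "ls", "user": None},
}

# None-valued entries are pruned recursively, so "user" is omitted
# rather than sent as null.
print(remove_none_values(request))
# {'jsonrpc': '2.0', 'method': 'example_method', 'params': {'input': 'ls'}}
```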
inspect_ai/tool/_mcp/_mcp.py
CHANGED
@@ -259,6 +259,7 @@ def create_server_sandbox(
     cwd: str | Path | None = None,
     env: dict[str, str] | None = None,
     sandbox: str | None = None,
+    timeout: int | None = None,
 ) -> MCPServer:
     # TODO: Confirm the lifetime concepts. By the time a request makes it to the
     # sandbox, it's going to need both a session id and a server "name".
@@ -272,6 +273,7 @@ def create_server_sandbox(
             env=env,
         ),
         sandbox_name=sandbox,
+        timeout=timeout,
     ),
     name=name,
     events=False,
inspect_ai/tool/_mcp/_sandbox.py
CHANGED
@@ -11,7 +11,7 @@ from inspect_ai.tool._tool_support_helpers import (
     exec_model_request,
     exec_notification,
     exec_scalar_request,
-    tool_container_sandbox,
+    tool_support_sandbox,
 )
 
 from ._context import MCPServerContext
@@ -28,8 +28,10 @@ async def sandbox_client(  # type: ignore
     *,
     sandbox_name: str | None = None,
     errlog: TextIO = sys.stderr,
+    timeout: int | None = None,  # default 180 seconds
 ) -> MCPServerContext:  # type: ignore
-    sandbox_environment = await tool_container_sandbox(
+    timeout = timeout or 180
+    (sandbox_environment, _) = await tool_support_sandbox(
         "mcp support", sandbox_name=sandbox_name
     )
 
@@ -49,6 +51,7 @@ async def sandbox_client(  # type: ignore
         method="mcp_launch_server",
         params={"server_params": server.model_dump()},
         result_type=int,
+        timeout=timeout,
     )
 
     async def stdout_reader() -> None:
@@ -72,6 +75,7 @@ async def sandbox_client(  # type: ignore
                     "request": root.model_dump(),
                 },
                 result_type=JSONRPCMessage,
+                timeout=timeout,
             )
         )
     elif isinstance(root, JSONRPCNotification):
@@ -82,6 +86,7 @@ async def sandbox_client(  # type: ignore
                 "session_id": session_id,
                 "notification": root.model_dump(),
             },
+            timeout=timeout,
         )
     else:
         assert False, f"Unexpected message type {message=}"
@@ -101,4 +106,5 @@ async def sandbox_client(  # type: ignore
         method="mcp_kill_server",
         params={"session_id": session_id},
         result_type=type(None),
+        timeout=timeout,
     )
inspect_ai/tool/_mcp/server.py
CHANGED
@@ -73,6 +73,7 @@ def mcp_server_sandbox(
     cwd: str | Path | None = None,
     env: dict[str, str] | None = None,
     sandbox: str | None = None,
+    timeout: int | None = None,
 ) -> MCPServer:
     """MCP Server (Sandbox).
 
@@ -87,6 +88,7 @@ def mcp_server_sandbox(
            "SHELL", "TERM", and "USER" for Posix-based systems).
         cwd: The working directory to use when spawning the process.
         sandbox: The sandbox to use when spawning the process.
+        timeout: Timeout (in seconds) for command.
 
     Returns:
        McpClient: Client for MCP Server
@@ -94,7 +96,7 @@ def mcp_server_sandbox(
     verfify_mcp_package()
     from ._mcp import create_server_sandbox
 
-    return create_server_sandbox(command, args, cwd, env, sandbox)
+    return create_server_sandbox(command, args, cwd, env, sandbox, timeout)
 
 
 def verfify_mcp_package() -> None:
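With the new parameter, callers of `mcp_server_sandbox` can bound how long each sandboxed MCP command may run; per `_sandbox.py` above, the default when unspecified is 180 seconds. A usage sketch, assuming the public export from `inspect_ai.tool`; the server command and module are illustrative:

```python
from inspect_ai.tool import mcp_server_sandbox

server = mcp_server_sandbox(
    command="python3",
    args=["-m", "my_mcp_server"],  # hypothetical MCP server module
    timeout=60,                    # per-command timeout, in seconds
)
```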
inspect_ai/tool/_tool_call.py
CHANGED
@@ -68,9 +68,12 @@ class ToolCallError:
         "permission",
         "file_not_found",
         "is_a_directory",
-        "output_limit",
+        "limit",
         "approval",
         "unknown",
+        # Retained for backward compatibility when loading logs created with an older
+        # version of inspect.
+        "output_limit",
     ]
     """Error type."""
 
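New logs record limit violations with the consolidated `"limit"` type, while `"output_limit"` survives in the Literal only so logs written by earlier releases still validate. A sketch, assuming `ToolCallError`'s other dataclass field is `message` (not shown in this hunk):

```python
from inspect_ai.tool import ToolCallError

# New logs use the consolidated "limit" error type...
err = ToolCallError(type="limit", message="tool exceeded a limit")

# ...while "output_limit" remains accepted when reading logs written by
# older versions of inspect.
legacy = ToolCallError(type="output_limit", message="from an older log")
```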
inspect_ai/tool/_tool_support_helpers.py
CHANGED
@@ -7,13 +7,17 @@ It includes definitions for JSON-RPC request and response models, as well as fun
 from textwrap import dedent
 from typing import Type
 
+import semver
+
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai.tool._tool import ToolError
 from inspect_ai.util import sandbox_with
 from inspect_ai.util._sandbox.environment import SandboxEnvironment
 
 from ._json_rpc_helpers import (
     BaseModelT,
     JSONRPCParamsType,
+    JSONRPCServerErrorMapper,
     JSONRPCTransport,
     ScalarT,
     _rpc_call_description,
@@ -29,7 +33,7 @@ async def exec_scalar_request(
     method: str,
     params: JSONRPCParamsType,
     result_type: Type[ScalarT],
-    timeout: int
+    timeout: int,
     user: str | None = None,
 ) -> ScalarT:
     return await scalar_request(
@@ -37,6 +41,7 @@ async def exec_scalar_request(
         params,
         result_type,
         transport=ToolSupportSandboxTransport(sandbox, timeout, user),
+        server_error_mapper=ToolSupportServerErrorMapper(),
     )
 
 
@@ -45,7 +50,7 @@ async def exec_model_request(
     method: str,
     params: JSONRPCParamsType,
     result_type: Type[BaseModelT],
-    timeout: int
+    timeout: int,
     user: str | None = None,
 ) -> BaseModelT:
     return await model_request(
@@ -53,6 +58,7 @@ async def exec_model_request(
         params,
         result_type,
         transport=ToolSupportSandboxTransport(sandbox, timeout, user),
+        server_error_mapper=ToolSupportServerErrorMapper(),
     )
 
 
@@ -60,7 +66,7 @@ async def exec_notification(
     sandbox: SandboxEnvironment,
     method: str,
     params: JSONRPCParamsType,
-    timeout: int
+    timeout: int,
     user: str | None = None,
 ) -> None:
     return await notification_helper(
@@ -68,19 +74,33 @@ async def exec_notification(
     )
 
 
+class ToolSupportServerErrorMapper(JSONRPCServerErrorMapper):
+    def __call__(
+        self, code: int, message: str, method: str, params: JSONRPCParamsType
+    ) -> Exception:
+        """Map `inspect-tool-support` defined custom codes to an exception."""
+        match code:
+            case -32099:  # This is a ToolException from the container
+                return ToolError(message)
+            case -32098:  # This is an unexpected exception inside the container
+                return RuntimeError(message)
+            case _:
+                return RuntimeError(message)
+
+
 class ToolSupportSandboxTransport(JSONRPCTransport):
     """
-    A transport
+    A transport that uses a sandbox for RPC communication.
 
-    This class implements the TransportCallable protocol and encapsulates
-
-
+    This class implements the TransportCallable protocol and encapsulates the
+    sandbox, timeout, and user parameters needed for sandbox-based RPC
+    communication.
     """
 
     def __init__(
         self,
         sandbox: SandboxEnvironment,
-        timeout: int
+        timeout: int,
         user: str | None = None,
     ):
         """
@@ -128,13 +148,32 @@ class ToolSupportSandboxTransport(JSONRPCTransport):
 
 SANDBOX_CLI = "inspect-tool-support"
 INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB = "aisiuk/inspect-tool-support"
+FIRST_PUBLISHED_VERSION = semver.Version.parse("0.1.6")
+MIN_SUPPORTED_VERSION = FIRST_PUBLISHED_VERSION
+MIN_NON_DEPRECATED_VERSION = semver.Version.parse("1.0.0")
+
+
+async def _get_sandbox_tool_support_version(
+    sandbox: SandboxEnvironment,
+) -> semver.Version:
+    try:
+        return semver.Version.parse(
+            await exec_scalar_request(sandbox, "version", {}, str, 5)
+        )
+    except RuntimeError as rte:
+        if "-32601" in str(rte):
+            # The container doesn't even have a version method. The first version
+            # published was 0.1.6, so we'll have to assume it was that old.
+            return FIRST_PUBLISHED_VERSION
+        raise rte
 
 
-async def tool_container_sandbox(
+async def tool_support_sandbox(
     tool_name: str, *, sandbox_name: str | None = None
-) -> SandboxEnvironment:
+) -> tuple[SandboxEnvironment, semver.Version]:
     if sb := await sandbox_with(SANDBOX_CLI, True, name=sandbox_name):
-        return sb
+        current_version = await _get_sandbox_tool_support_version(sb)
+        return (sb, current_version)
 
     # This sort of programmatic sentence building will not cut it if we ever
     # support other languages.
@@ -160,7 +199,7 @@ async def tool_container_sandbox(
 
 
 def create_sandbox_transport(
-    sandbox: SandboxEnvironment, timeout: int
+    sandbox: SandboxEnvironment, timeout: int, user: str | None = None
 ) -> JSONRPCTransport:
     """
     Create a transport callable that uses a sandbox for RPC communication.
|