inspect-ai 0.3.75__py3-none-any.whl → 0.3.77__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- inspect_ai/_cli/eval.py +16 -0
- inspect_ai/_display/core/results.py +6 -1
- inspect_ai/_eval/eval.py +8 -1
- inspect_ai/_eval/evalset.py +6 -2
- inspect_ai/_eval/registry.py +3 -5
- inspect_ai/_eval/run.py +7 -2
- inspect_ai/_eval/task/run.py +4 -0
- inspect_ai/_util/content.py +3 -0
- inspect_ai/_util/logger.py +3 -0
- inspect_ai/_view/www/dist/assets/index.css +28 -16
- inspect_ai/_view/www/dist/assets/index.js +4811 -4609
- inspect_ai/_view/www/log-schema.json +79 -9
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +22 -4
- inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/CategoricalScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -2
- inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +1 -1
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +2 -2
- inspect_ai/_view/www/src/types/log.d.ts +11 -5
- inspect_ai/log/_recorders/json.py +8 -0
- inspect_ai/log/_transcript.py +13 -4
- inspect_ai/model/_call_tools.py +13 -4
- inspect_ai/model/_chat_message.py +3 -0
- inspect_ai/model/_model.py +5 -1
- inspect_ai/model/_model_output.py +6 -1
- inspect_ai/model/_openai.py +78 -10
- inspect_ai/model/_openai_responses.py +277 -0
- inspect_ai/model/_providers/anthropic.py +134 -75
- inspect_ai/model/_providers/azureai.py +2 -2
- inspect_ai/model/_providers/mistral.py +29 -13
- inspect_ai/model/_providers/openai.py +64 -57
- inspect_ai/model/_providers/openai_responses.py +177 -0
- inspect_ai/model/_providers/openrouter.py +52 -2
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/vertex.py +5 -2
- inspect_ai/tool/__init__.py +6 -0
- inspect_ai/tool/_tool.py +23 -3
- inspect_ai/tool/_tool_call.py +5 -2
- inspect_ai/tool/_tool_support_helpers.py +200 -0
- inspect_ai/tool/_tools/_bash_session.py +119 -0
- inspect_ai/tool/_tools/_computer/_computer.py +1 -1
- inspect_ai/tool/_tools/_text_editor.py +121 -0
- inspect_ai/tool/_tools/_think.py +48 -0
- inspect_ai/tool/_tools/_web_browser/_back_compat.py +150 -0
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +75 -130
- inspect_ai/tool/_tools/_web_search.py +1 -1
- inspect_ai/util/_json.py +28 -0
- inspect_ai/util/_sandbox/context.py +16 -7
- inspect_ai/util/_sandbox/docker/config.py +1 -1
- inspect_ai/util/_sandbox/docker/internal.py +3 -3
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/METADATA +5 -2
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/RECORD +56 -80
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/WHEEL +1 -1
- inspect_ai/model/_image.py +0 -15
- inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +0 -8
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +0 -24
- inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +0 -25
- inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +0 -22
- inspect_ai/tool/_tools/_web_browser/_resources/README.md +0 -63
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +0 -71
- inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +0 -323
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +0 -5
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +0 -279
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +0 -9
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +0 -293
- inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +0 -94
- inspect_ai/tool/_tools/_web_browser/_resources/constants.py +0 -2
- inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +0 -2
- inspect_ai/tool/_tools/_web_browser/_resources/mock_environment.py +0 -45
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +0 -50
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +0 -48
- inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +0 -280
- inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +0 -65
- inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +0 -64
- inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +0 -146
- inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +0 -64
- inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +0 -180
- inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +0 -99
- inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +0 -15
- inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +0 -44
- inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +0 -39
- inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +0 -214
- inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +0 -35
- inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +0 -192
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info/licenses}/LICENSE +0 -0
- {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai_responses.py
ADDED
```diff
@@ -0,0 +1,177 @@
+from logging import getLogger
+from typing import Any
+
+from openai import (
+    AsyncAzureOpenAI,
+    AsyncOpenAI,
+    BadRequestError,
+)
+from openai._types import NOT_GIVEN
+from openai.types.responses import Response, ResponseFormatTextJSONSchemaConfigParam
+
+from inspect_ai._util.logger import warn_once
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._generate_config import GenerateConfig
+from .._model_call import ModelCall
+from .._model_output import (
+    ModelOutput,
+    ModelUsage,
+)
+from .._openai import (
+    OpenAIResponseError,
+    is_gpt,
+    is_o1_mini,
+    is_o1_preview,
+    is_o_series,
+    openai_handle_bad_request,
+    openai_media_filter,
+)
+from .._openai_responses import (
+    openai_responses_chat_choices,
+    openai_responses_inputs,
+    openai_responses_tool_choice,
+    openai_responses_tools,
+)
+from .util.hooks import HttpxHooks
+
+logger = getLogger(__name__)
+
+
+async def generate_responses(
+    client: AsyncAzureOpenAI | AsyncOpenAI,
+    http_hooks: HttpxHooks,
+    model_name: str,
+    input: list[ChatMessage],
+    tools: list[ToolInfo],
+    tool_choice: ToolChoice,
+    config: GenerateConfig,
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+    # allocate request_id (so we can see it from ModelCall)
+    request_id = http_hooks.start_request()
+
+    # setup request and response for ModelCall
+    request: dict[str, Any] = {}
+    response: dict[str, Any] = {}
+
+    def model_call() -> ModelCall:
+        return ModelCall.create(
+            request=request,
+            response=response,
+            # TODO: is this the right filter?
+            filter=openai_media_filter,
+            time=http_hooks.end_request(request_id),
+        )
+
+    # prepare request (we do this so we can log the ModelCall)
+    request = dict(
+        input=await openai_responses_inputs(input, model_name),
+        tools=openai_responses_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+        tool_choice=openai_responses_tool_choice(tool_choice)
+        if len(tools) > 0
+        else NOT_GIVEN,
+        extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
+        **completion_params_responses(model_name, config, len(tools) > 0),
+    )
+
+    try:
+        # generate response
+        model_response: Response = await client.responses.create(**request)
+
+        # check for error
+        if model_response.error is not None:
+            raise OpenAIResponseError(
+                code=model_response.error.code, message=model_response.error.message
+            )
+
+        # save response for model_call
+        response = model_response.model_dump()
+
+        # parse out choices
+        choices = openai_responses_chat_choices(model_response, tools)
+
+        # return output and call
+        return ModelOutput(
+            model=model_response.model,
+            choices=choices,
+            usage=(
+                ModelUsage(
+                    input_tokens=model_response.usage.input_tokens,
+                    output_tokens=model_response.usage.output_tokens,
+                    input_tokens_cache_read=(
+                        model_response.usage.input_tokens_details.cached_tokens
+                    ),
+                    reasoning_tokens=model_response.usage.output_tokens_details.reasoning_tokens,
+                    total_tokens=model_response.usage.total_tokens,
+                )
+                if model_response.usage
+                else None
+            ),
+        ), model_call()
+    except BadRequestError as e:
+        return openai_handle_bad_request(model_name, e), model_call()
+
+
+def completion_params_responses(
+    model_name: str, config: GenerateConfig, tools: bool
+) -> dict[str, Any]:
+    # TODO: we'll need a computer_use_preview bool for the 'include'
+    # and 'reasoning' parameters
+    def unsupported_warning(param: str) -> None:
+        warn_once(
+            logger,
+            f"OpenAI Responses API does not support the '{param}' parameter.",
+        )
+
+    params: dict[str, Any] = dict(model=model_name, store=False)
+    if config.max_tokens is not None:
+        params["max_output_tokens"] = config.max_tokens
+    if config.frequency_penalty is not None:
+        unsupported_warning("frequency_penalty")
+    if config.stop_seqs is not None:
+        unsupported_warning("stop_seqs")
+    if config.presence_penalty is not None:
+        unsupported_warning("presence_penalty")
+    if config.logit_bias is not None:
+        unsupported_warning("logit_bias")
+    if config.seed is not None:
+        unsupported_warning("seed")
+    if config.temperature is not None:
+        if is_o_series(model_name):
+            warn_once(
+                logger,
+                "o series models do not support the 'temperature' parameter (temperature is always 1).",
+            )
+        else:
+            params["temperature"] = config.temperature
+    if config.top_p is not None:
+        params["top_p"] = config.top_p
+    if config.num_choices is not None:
+        unsupported_warning("num_choices")
+    if config.logprobs is not None:
+        unsupported_warning("logprobs")
+    if config.top_logprobs is not None:
+        unsupported_warning("top_logprobs")
+    if tools and config.parallel_tool_calls is not None and not is_o_series(model_name):
+        params["parallel_tool_calls"] = config.parallel_tool_calls
+    if (
+        config.reasoning_effort is not None
+        and not is_gpt(model_name)
+        and not is_o1_mini(model_name)
+        and not is_o1_preview(model_name)
+    ):
+        params["reasoning"] = dict(effort=config.reasoning_effort)
+    if config.response_schema is not None:
+        params["text"] = dict(
+            format=ResponseFormatTextJSONSchemaConfigParam(
+                type="json_schema",
+                name=config.response_schema.name,
+                schema=config.response_schema.json_schema.model_dump(exclude_none=True),
+                description=config.response_schema.description
+                or config.response_schema.name,
+                strict=config.response_schema.strict,
+            )
+        )
+
+    return params
```
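For orientation (not part of the diff): a sketch of how `completion_params_responses` maps a `GenerateConfig` onto Responses API parameters, assuming the private module path above stays importable and that `is_o_series` matches o3-style model names.

```python
from inspect_ai.model import GenerateConfig
from inspect_ai.model._providers.openai_responses import completion_params_responses

# max_tokens maps to max_output_tokens; temperature is dropped (with a
# warning) for o-series models; reasoning_effort becomes a 'reasoning' dict
params = completion_params_responses(
    "o3-mini",
    GenerateConfig(max_tokens=1024, temperature=0.7, reasoning_effort="high"),
    tools=False,
)
print(params)
# {'model': 'o3-mini', 'store': False, 'max_output_tokens': 1024,
#  'reasoning': {'effort': 'high'}}
```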
inspect_ai/model/_providers/openrouter.py
CHANGED
```diff
@@ -1,9 +1,11 @@
+import json
 import os
-from typing import Any
+from typing import Any, TypedDict
 
-from typing_extensions import override
+from typing_extensions import NotRequired, override
 
 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai.model._openai import OpenAIResponseError
 from inspect_ai.model._providers.util import model_base_url
 from inspect_ai.model._providers.util.util import environment_prerequisite_error
 
@@ -13,6 +15,28 @@ from .openai import OpenAIAPI
 OPENROUTER_API_KEY = "OPENROUTER_API_KEY"
 
 
+class ErrorResponse(TypedDict):
+    code: int
+    message: str
+    metadata: NotRequired[dict[str, Any]]
+
+
+class OpenRouterError(Exception):
+    def __init__(self, response: ErrorResponse) -> None:
+        self.response = response
+
+    @property
+    def message(self) -> str:
+        return f"Error {self.response['code']} - {self.response['message']}"
+
+    def __str__(self) -> str:
+        return (
+            self.message + ("\n" + json.dumps(self.response["metadata"], indent=2))
+            if "metadata" in self.response
+            else ""
+        )
+
+
 class OpenRouterAPI(OpenAIAPI):
     def __init__(
         self,
@@ -67,6 +91,32 @@ class OpenRouterAPI(OpenAIAPI):
             **model_args,
         )
 
+    @override
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Handle documented OpenRouter error conditions.
+
+        https://openrouter.ai/docs/api-reference/errors
+        """
+        # check if open-router yielded an error (raise explicit
+        # OpenAIResponseError for cases where we should retry)
+        error: ErrorResponse | None = response.get("error", None)
+        if error is not None:
+            if error["code"] == 429:
+                raise OpenAIResponseError("rate_limit_exceeded", error["message"])
+            elif error["code"] in [408, 502]:
+                raise OpenAIResponseError("server_error", error["message"])
+            else:
+                raise OpenRouterError(error)
+
+        # check for an empty response (which they document can occur on
+        # startup). for this we'll return a "server_error" which will
+        # trigger a retry w/ exponential backoff
+        elif response.get("choices", None) is None:
+            raise OpenAIResponseError(
+                "server_error",
+                "Model is warming up, please retry again after waiting for warmup.",
+            )
+
     @override
     def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
         # default params
```
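A quick illustration of the dispatch in `on_response` (a sketch only; it invokes the method on an uninitialized instance, which works here because `on_response` reads nothing from `self`):

```python
from inspect_ai.model._openai import OpenAIResponseError
from inspect_ai.model._providers.openrouter import OpenRouterAPI, OpenRouterError

probe = object.__new__(OpenRouterAPI)  # skip __init__, illustration only
for payload in (
    {"error": {"code": 429, "message": "rate limited"}},     # -> retryable
    {"error": {"code": 502, "message": "provider outage"}},  # -> retryable
    {"error": {"code": 400, "message": "bad request"}},      # -> OpenRouterError
    {},  # no "choices" yet: treated as warmup, retried with backoff
):
    try:
        probe.on_response(payload)
    except (OpenAIResponseError, OpenRouterError) as ex:
        print(type(ex).__name__)
```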
inspect_ai/model/_providers/vertex.py
CHANGED
```diff
@@ -34,8 +34,8 @@ from inspect_ai._util.content import (
     Content,
     ContentAudio,
     ContentImage,
+    ContentReasoning,
     ContentText,
-    ContentVideo,
 )
 from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
@@ -336,10 +336,13 @@ async def content_part(content: Content | str) -> Part:
     elif isinstance(content, ContentImage):
         image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    elif isinstance(content, ContentReasoning):
+        return Part.from_text(content.reasoning or NO_CONTENT)
     else:
         if isinstance(content, ContentAudio):
             file = content.audio
-        elif isinstance(content, ContentVideo):
+        else:
+            # it's ContentVideo
            file = content.video
         file_bytes, mime_type = await file_as_data(file)
         return Part.from_data(file_bytes, mime_type)
```
inspect_ai/tool/__init__.py
CHANGED
```diff
@@ -22,17 +22,23 @@ from ._tool_def import ToolDef
 from ._tool_info import ToolInfo
 from ._tool_params import ToolParam, ToolParams
 from ._tool_with import tool_with
+from ._tools._bash_session import bash_session
 from ._tools._computer import computer
 from ._tools._execute import bash, python
+from ._tools._text_editor import text_editor
+from ._tools._think import think
 from ._tools._web_browser import web_browser
 from ._tools._web_search import web_search
 
 __all__ = [
     "bash",
+    "bash_session",
     "computer",
     "python",
     "web_browser",
     "web_search",
+    "think",
+    "text_editor",
     "tool",
     "tool_with",
     "Tool",
```
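The three new tools are now part of the public API. A minimal sketch of attaching them to a solver chain (this assumes `think` and `text_editor` are usable with their default arguments, which is not shown in this diff):

```python
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import bash_session, text_editor, think

# give the model a persistent shell, a file editor, and a scratchpad
solver = [use_tools(bash_session(), text_editor(), think()), generate()]
```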
inspect_ai/tool/_tool.py
CHANGED
```diff
@@ -20,6 +20,7 @@ from inspect_ai._util.content import (
 )
 from inspect_ai._util.registry import (
     RegistryInfo,
+    is_registry_object,
     registry_add,
     registry_name,
     registry_tag,
@@ -200,7 +201,25 @@ def tool(
         # wrap instantiations of scorer so they carry registry info and metrics
         @wraps(tool_type)
         def tool_wrapper(*args: P.args, **kwargs: P.kwargs) -> Tool:
+            # create the tool
             tool = tool_type(*args, **kwargs)
+
+            # this might already have registry info, in that case
+            # capture it and use it as defaults
+            from inspect_ai.tool._tool_def import tool_registry_info
+
+            tool_parallel = parallel
+            tool_viewer = viewer
+            tool_model_input = model_input
+            if is_registry_object(tool):
+                _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
+                    tool
+                )
+                tool_parallel = parallel and reg_parallel
+                tool_viewer = viewer or reg_viewer
+                tool_model_input = model_input or reg_model_input
+
+            # tag the object
             registry_tag(
                 tool_type,
                 tool,
@@ -209,10 +228,11 @@ def tool(
                 name=tool_name,
                 metadata={
                     TOOL_PROMPT: prompt,
-                    TOOL_PARALLEL: parallel,
-                    TOOL_VIEWER: viewer,
+                    TOOL_PARALLEL: tool_parallel,
+                    TOOL_VIEWER: tool_viewer,
                     TOOL_MODEL_INPUT: (
-                        model_input or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
+                        tool_model_input
+                        or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
                     ),
                 },
             ),
```
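The net effect of the change above: a `@tool` factory that returns an already-registered tool now merges registry metadata instead of overwriting it. A hedged sketch:

```python
from inspect_ai.tool import Tool, bash, tool

@tool(parallel=False)
def restricted_bash() -> Tool:
    # bash() already carries registry info from its own @tool decorator;
    # with this change parallel becomes (False and bash's parallel) and
    # viewer/model_input fall back to bash's values rather than being lost
    return bash(timeout=60)
```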
inspect_ai/tool/_tool_call.py
CHANGED
```diff
@@ -44,8 +44,11 @@ class ToolCall:
     arguments: dict[str, Any]
     """Arguments to function."""
 
-    type: Literal["function"]
-    """Type of tool call ('function')"""
+    type: str
+    """Type of tool call ('function' or a model specific internal tool type)"""
+
+    internal_name: str | None = field(default=None)
+    """Model's internal name for the tool - if any."""
 
     parse_error: str | None = field(default=None)
     """Error which occurred parsing tool call."""
```
inspect_ai/tool/_tool_support_helpers.py
ADDED
```diff
@@ -0,0 +1,200 @@
+"""
+This module provides helper code for handling JSON-RPC communication between the inspect process and the `inspect-tool-support` package code running in the sandbox environment.
+
+It includes definitions for JSON-RPC request and response models, as well as functions to create and parse JSON-RPC requests and responses.
+"""
+
+import json
+from itertools import count
+from textwrap import dedent
+from typing import Literal, Type, TypeVar, cast
+
+from pydantic import BaseModel, RootModel
+
+from inspect_ai._util.error import PrerequisiteError
+from inspect_ai.tool._tool import ToolError, ToolParsingError
+from inspect_ai.util import sandbox_with
+from inspect_ai.util._sandbox.environment import SandboxEnvironment
+
+
+class JSONRPCResponseBase(BaseModel):
+    jsonrpc: Literal["2.0"]
+    id: int | float | str
+
+
+class JSONRPCSuccessResponse(JSONRPCResponseBase):
+    result: object
+
+
+class JSONRPCError(BaseModel):
+    """See: https://www.jsonrpc.org/specification#error_object"""
+
+    code: int
+    message: str
+    data: object | None = None
+
+
+class JSONRPCErrorResponse(JSONRPCResponseBase):
+    error: JSONRPCError
+
+
+class JSONRPCResponse(RootModel[JSONRPCSuccessResponse | JSONRPCErrorResponse]):
+    pass
+
+
+BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
+StrOrModelT = TypeVar("StrOrModelT", bound=str | BaseModel)
+
+id_generator = count(666)
+
+
+async def exec_sandbox_rpc(
+    sandbox: SandboxEnvironment,
+    method: str,
+    params: dict[str, object] | tuple[object, ...],
+    result_cls: Type[StrOrModelT],
+    timeout: int | None = None,
+    user: str | None = None,
+) -> StrOrModelT:
+    """
+    Execute a JSON-RPC command to a sandbox environment.
+
+    Note that the JSON RPC request is sent to the exec'ed program via stdin.
+
+    Args:
+        sandbox (SandboxEnvironment): The sandbox environment to execute the command in.
+        method (str): The JSON-RPC method to call.
+        params (dict[str, object] | tuple[object, ...]): The parameters for the JSON-RPC method.
+        result_cls (Type[BaseModelT]): The class to use for parsing the result.
+        timeout (int | None, optional): The timeout for the execution. Defaults to None.
+        user: Optional username or UID to run the command as.
+
+    Returns:
+        BaseModelT: The parsed result of the JSON-RPC call.
+
+    Raises:
+        RuntimeError: If the sandbox execution fails or if there is an error in the JSON-RPC response.
+        ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
+    """
+    exec_result = await sandbox.exec(
+        [SANDBOX_CLI, "exec"],
+        input=_create_json_rpc_request(method, params),
+        timeout=timeout,
+        user=user,
+    )
+
+    if not exec_result.success:
+        raise RuntimeError(
+            f"Sandbox.exec failure executing {_rpc_call_description(method, params)}: {exec_result.stderr}"
+        )
+
+    match _parse_json_rpc_response(exec_result.stdout, result_cls):
+        case JSONRPCError(code=-32601 | -32602, message=message):
+            raise ToolParsingError(message)
+        case JSONRPCError(code=-32000, message=message):
+            raise ToolError(message)
+        case JSONRPCError(code=code, message=message):
+            raise RuntimeError(
+                f"Error executing tool command {_rpc_call_description(method, params)}: {code=} {message}"
+            )
+        # case result_cls() as model: yields a mypy error since it has narrowed model down
+        # to BaseModel and not BaseModelT. ???
+        case model if isinstance(model, result_cls):
+            return model
+        case not_possible:
+            raise RuntimeError(
+                f"Error executing tool command {_rpc_call_description(method, params)}: {not_possible}"
+            )
+
+
+SANDBOX_CLI = "inspect-tool-support"
+INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB = "aisiuk/inspect-tool-support"
+
+
+async def tool_container_sandbox(tool_name: str) -> SandboxEnvironment:
+    sb = await sandbox_with(SANDBOX_CLI, True)
+    if sb:
+        return sb
+    else:
+        msg = dedent(f"""
+            The {tool_name} service was not found in any of the sandboxes for this sample. Please add the {tool_name} to your configuration.
+
+            For example, the following Docker compose file uses the {INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB} reference image as its default sandbox:
+
+            services:
+              default:
+                image: "{INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB}"
+                init: true
+
+            Alternatively, you can include the service into your own Dockerfile:
+
+            RUN python -m venv /opt/inspect_tool_support
+            ENV PATH="/opt/inspect_tool_support/bin:$PATH"
+            RUN pip install inspect-tool-support
+            RUN inspect-tool-support post-install
+            """).strip()
+        raise PrerequisiteError(msg)
+
+
+def _create_json_rpc_request(
+    method: str, params: dict[str, object] | tuple[object, ...]
+) -> str:
+    return json.dumps(
+        {
+            "jsonrpc": "2.0",
+            "method": method,
+            "id": next(id_generator),
+            "params": list(params) if isinstance(params, tuple) else params,
+        }
+    )
+
+
+def _rpc_call_description(
+    method: str, params: dict[str, object] | tuple[object, ...]
+) -> str:
+    """
+    Generate a string description of an RPC call.
+
+    Args:
+        method (str): The name of the RPC method.
+        params (dict[str, object] | tuple[object, ...]): The parameters for the RPC method.
+
+    Returns:
+        str: A string description of the RPC call.
+
+    Examples:
+        >>> _rpc_call_description("subtract", {"minuend": 42, "subtrahend": 23})
+        'subtract(minuend: 42, subtrahend: 23)'
+
+        >>> _rpc_call_description("subtract", (42, 23))
+        'subtract(42, 23)'
+    """
+    normalized_params = (
+        list(map(str, params))
+        if isinstance(params, tuple)
+        else [f"{k}: {v}" for k, v in params.items()]
+    )
+    return f"{method}({', '.join(normalized_params)})"
+
+
+def _parse_json_rpc_response(
+    response_str: str,
+    result_cls: Type[StrOrModelT],
+) -> StrOrModelT | JSONRPCError:
+    match JSONRPCResponse.model_validate_json(response_str).root:
+        case JSONRPCErrorResponse(error=error):
+            return error
+        case JSONRPCSuccessResponse(result=rpc_result):
+            # TODO: Wow. Is there really no way to convince Python to narrow these types
+            # and avoid the cast's
+            if result_cls is str:
+                if not isinstance(rpc_result, str):
+                    raise ValueError(f"Expected string result, got {type(rpc_result)}")
+                return cast(StrOrModelT, rpc_result)
+            else:
+                return cast(
+                    StrOrModelT,
+                    cast(BaseModel, result_cls).model_validate(rpc_result, strict=True),
+                )
+        case _:
+            raise ValueError(f"Unexpected JSON RPC response: {response_str}")
```
inspect_ai/tool/_tools/_bash_session.py
ADDED
```diff
@@ -0,0 +1,119 @@
+from pydantic import BaseModel, Field, RootModel
+
+from inspect_ai.tool import ToolResult
+from inspect_ai.tool._tool_support_helpers import (
+    exec_sandbox_rpc,
+    tool_container_sandbox,
+)
+from inspect_ai.util import StoreModel, store_as
+
+from .._tool import Tool, ToolParsingError, tool
+from .._tool_call import ToolCall, ToolCallContent, ToolCallView, ToolCallViewer
+
+
+# These models are cloned from the container code. If/when we decide to create
+# a package that is shared between the inspect and tool-container codebases, we'll
+# just have to live with it.
+class NewSessionResult(BaseModel):
+    session_name: str
+
+
+class BashRestartResult(BaseModel):
+    pass
+
+
+class BashCommandResult(BaseModel):
+    status: int
+    stdout: str
+    stderr: str
+
+
+class BashResult(RootModel[BashRestartResult | BashCommandResult]):
+    pass
+
+
+class BashSessionStore(StoreModel):
+    session_id: str = Field(default_factory=str)
+
+
+# custom viewer for bash
+def code_viewer(language: str, code_param: str) -> ToolCallViewer:
+    def viewer(tool_call: ToolCall) -> ToolCallView:
+        code = tool_call.arguments.get(code_param, None)
+        code = (code or tool_call.function).strip()
+        call = ToolCallContent(
+            title=language,
+            format="markdown",
+            content=f"```{language}\n" + code + "\n```\n",
+        )
+        return ToolCallView(call=call)
+
+    return viewer
+
+
+@tool(viewer=code_viewer("bash", "command"))
+def bash_session(timeout: int | None = None) -> Tool:
+    """Bash shell session command execution tool.
+
+    Execute bash shell commands in a long running session using a sandbox environment (e.g. "docker").
+
+    Args:
+        timeout: Timeout (in seconds) for command.
+
+    Returns:
+        String with command output (stdout) or command error (stderr).
+    """
+
+    async def execute(
+        command: str | None = None,
+        restart: bool | None = None,
+    ) -> ToolResult:
+        """
+        Use this function to execute bash commands.
+
+        Args:
+            command: The bash command to run. Required unless the tool is being restarted.
+            restart: Specifying true will restart this tool. Otherwise, leave this unspecified.
+
+        Returns:
+            The output of the command.
+        """
+        if not ((command is None) ^ (restart is None)):
+            raise ToolParsingError(
+                "Either 'command' or 'restart' must be specified, but not both."
+            )
+        params: dict[str, object] = {"command": command, "restart": restart}
+
+        sandbox = await tool_container_sandbox("bash session")
+        store = store_as(BashSessionStore)
+
+        if not store.session_id:
+            store.session_id = (
+                await exec_sandbox_rpc(
+                    sandbox,
+                    "bash_session_new_session",
+                    {},
+                    NewSessionResult,
+                    timeout=timeout,
+                )
+            ).session_name
+
+        params["session_name"] = store.session_id
+
+        result = (
+            await exec_sandbox_rpc(
+                sandbox,
+                "bash_session",
+                params,
+                BashResult,
+                timeout=timeout,
+            )
+        ).root
+
+        if isinstance(result, BashRestartResult):
+            return "Bash session restarted."
+
+        # return output (including stderr if any)
+        return f"{result.stderr}\n{result.stdout}" if result.stderr else result.stdout
+
+    return execute
```
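End-to-end, the new tool plugs into an eval like any other. A minimal sketch (the `compose.yaml` is assumed to reference a sandbox image that provides the `inspect-tool-support` CLI, e.g. the `aisiuk/inspect-tool-support` image named above):

```python
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import bash_session

@task
def bash_session_demo() -> Task:
    return Task(
        dataset=[Sample(input="Write 'hello' to /tmp/out.txt, then cat it.")],
        solver=[use_tools(bash_session(timeout=180)), generate()],
        sandbox=("docker", "compose.yaml"),  # image must include inspect-tool-support
    )
```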