inspect-ai 0.3.75__py3-none-any.whl → 0.3.77__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. inspect_ai/_cli/eval.py +16 -0
  2. inspect_ai/_display/core/results.py +6 -1
  3. inspect_ai/_eval/eval.py +8 -1
  4. inspect_ai/_eval/evalset.py +6 -2
  5. inspect_ai/_eval/registry.py +3 -5
  6. inspect_ai/_eval/run.py +7 -2
  7. inspect_ai/_eval/task/run.py +4 -0
  8. inspect_ai/_util/content.py +3 -0
  9. inspect_ai/_util/logger.py +3 -0
  10. inspect_ai/_view/www/dist/assets/index.css +28 -16
  11. inspect_ai/_view/www/dist/assets/index.js +4811 -4609
  12. inspect_ai/_view/www/log-schema.json +79 -9
  13. inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +22 -4
  14. inspect_ai/_view/www/src/samples/chat/tools/ToolInput.tsx +1 -1
  15. inspect_ai/_view/www/src/samples/descriptor/score/CategoricalScoreDescriptor.tsx +1 -1
  16. inspect_ai/_view/www/src/samples/descriptor/score/NumericScoreDescriptor.tsx +2 -2
  17. inspect_ai/_view/www/src/samples/sample-tools/SortFilter.tsx +1 -1
  18. inspect_ai/_view/www/src/samples/transcript/ModelEventView.module.css +2 -2
  19. inspect_ai/_view/www/src/types/log.d.ts +11 -5
  20. inspect_ai/log/_recorders/json.py +8 -0
  21. inspect_ai/log/_transcript.py +13 -4
  22. inspect_ai/model/_call_tools.py +13 -4
  23. inspect_ai/model/_chat_message.py +3 -0
  24. inspect_ai/model/_model.py +5 -1
  25. inspect_ai/model/_model_output.py +6 -1
  26. inspect_ai/model/_openai.py +78 -10
  27. inspect_ai/model/_openai_responses.py +277 -0
  28. inspect_ai/model/_providers/anthropic.py +134 -75
  29. inspect_ai/model/_providers/azureai.py +2 -2
  30. inspect_ai/model/_providers/mistral.py +29 -13
  31. inspect_ai/model/_providers/openai.py +64 -57
  32. inspect_ai/model/_providers/openai_responses.py +177 -0
  33. inspect_ai/model/_providers/openrouter.py +52 -2
  34. inspect_ai/model/_providers/providers.py +1 -1
  35. inspect_ai/model/_providers/vertex.py +5 -2
  36. inspect_ai/tool/__init__.py +6 -0
  37. inspect_ai/tool/_tool.py +23 -3
  38. inspect_ai/tool/_tool_call.py +5 -2
  39. inspect_ai/tool/_tool_support_helpers.py +200 -0
  40. inspect_ai/tool/_tools/_bash_session.py +119 -0
  41. inspect_ai/tool/_tools/_computer/_computer.py +1 -1
  42. inspect_ai/tool/_tools/_text_editor.py +121 -0
  43. inspect_ai/tool/_tools/_think.py +48 -0
  44. inspect_ai/tool/_tools/_web_browser/_back_compat.py +150 -0
  45. inspect_ai/tool/_tools/_web_browser/_web_browser.py +75 -130
  46. inspect_ai/tool/_tools/_web_search.py +1 -1
  47. inspect_ai/util/_json.py +28 -0
  48. inspect_ai/util/_sandbox/context.py +16 -7
  49. inspect_ai/util/_sandbox/docker/config.py +1 -1
  50. inspect_ai/util/_sandbox/docker/internal.py +3 -3
  51. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/METADATA +5 -2
  52. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/RECORD +56 -80
  53. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/WHEEL +1 -1
  54. inspect_ai/model/_image.py +0 -15
  55. inspect_ai/tool/_tools/_web_browser/_resources/.pylintrc +0 -8
  56. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/launch.json +0 -24
  57. inspect_ai/tool/_tools/_web_browser/_resources/.vscode/settings.json +0 -25
  58. inspect_ai/tool/_tools/_web_browser/_resources/Dockerfile +0 -22
  59. inspect_ai/tool/_tools/_web_browser/_resources/README.md +0 -63
  60. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree.py +0 -71
  61. inspect_ai/tool/_tools/_web_browser/_resources/accessibility_tree_node.py +0 -323
  62. inspect_ai/tool/_tools/_web_browser/_resources/cdp/__init__.py +0 -5
  63. inspect_ai/tool/_tools/_web_browser/_resources/cdp/a11y.py +0 -279
  64. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom.py +0 -9
  65. inspect_ai/tool/_tools/_web_browser/_resources/cdp/dom_snapshot.py +0 -293
  66. inspect_ai/tool/_tools/_web_browser/_resources/cdp/page.py +0 -94
  67. inspect_ai/tool/_tools/_web_browser/_resources/constants.py +0 -2
  68. inspect_ai/tool/_tools/_web_browser/_resources/images/usage_diagram.svg +0 -2
  69. inspect_ai/tool/_tools/_web_browser/_resources/mock_environment.py +0 -45
  70. inspect_ai/tool/_tools/_web_browser/_resources/playwright_browser.py +0 -50
  71. inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +0 -48
  72. inspect_ai/tool/_tools/_web_browser/_resources/playwright_page_crawler.py +0 -280
  73. inspect_ai/tool/_tools/_web_browser/_resources/pyproject.toml +0 -65
  74. inspect_ai/tool/_tools/_web_browser/_resources/rectangle.py +0 -64
  75. inspect_ai/tool/_tools/_web_browser/_resources/rpc_client_helpers.py +0 -146
  76. inspect_ai/tool/_tools/_web_browser/_resources/scale_factor.py +0 -64
  77. inspect_ai/tool/_tools/_web_browser/_resources/test_accessibility_tree_node.py +0 -180
  78. inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +0 -99
  79. inspect_ai/tool/_tools/_web_browser/_resources/test_rectangle.py +0 -15
  80. inspect_ai/tool/_tools/_web_browser/_resources/test_web_client.py +0 -44
  81. inspect_ai/tool/_tools/_web_browser/_resources/web_browser_rpc_types.py +0 -39
  82. inspect_ai/tool/_tools/_web_browser/_resources/web_client.py +0 -214
  83. inspect_ai/tool/_tools/_web_browser/_resources/web_client_new_session.py +0 -35
  84. inspect_ai/tool/_tools/_web_browser/_resources/web_server.py +0 -192
  85. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/entry_points.txt +0 -0
  86. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info/licenses}/LICENSE +0 -0
  87. {inspect_ai-0.3.75.dist-info → inspect_ai-0.3.77.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai_responses.py ADDED
@@ -0,0 +1,177 @@
+from logging import getLogger
+from typing import Any
+
+from openai import (
+    AsyncAzureOpenAI,
+    AsyncOpenAI,
+    BadRequestError,
+)
+from openai._types import NOT_GIVEN
+from openai.types.responses import Response, ResponseFormatTextJSONSchemaConfigParam
+
+from inspect_ai._util.logger import warn_once
+from inspect_ai.tool import ToolChoice, ToolInfo
+
+from .._chat_message import ChatMessage
+from .._generate_config import GenerateConfig
+from .._model_call import ModelCall
+from .._model_output import (
+    ModelOutput,
+    ModelUsage,
+)
+from .._openai import (
+    OpenAIResponseError,
+    is_gpt,
+    is_o1_mini,
+    is_o1_preview,
+    is_o_series,
+    openai_handle_bad_request,
+    openai_media_filter,
+)
+from .._openai_responses import (
+    openai_responses_chat_choices,
+    openai_responses_inputs,
+    openai_responses_tool_choice,
+    openai_responses_tools,
+)
+from .util.hooks import HttpxHooks
+
+logger = getLogger(__name__)
+
+
+async def generate_responses(
+    client: AsyncAzureOpenAI | AsyncOpenAI,
+    http_hooks: HttpxHooks,
+    model_name: str,
+    input: list[ChatMessage],
+    tools: list[ToolInfo],
+    tool_choice: ToolChoice,
+    config: GenerateConfig,
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+    # allocate request_id (so we can see it from ModelCall)
+    request_id = http_hooks.start_request()
+
+    # setup request and response for ModelCall
+    request: dict[str, Any] = {}
+    response: dict[str, Any] = {}
+
+    def model_call() -> ModelCall:
+        return ModelCall.create(
+            request=request,
+            response=response,
+            # TODO: is this the right filter?
+            filter=openai_media_filter,
+            time=http_hooks.end_request(request_id),
+        )
+
+    # prepare request (we do this so we can log the ModelCall)
+    request = dict(
+        input=await openai_responses_inputs(input, model_name),
+        tools=openai_responses_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+        tool_choice=openai_responses_tool_choice(tool_choice)
+        if len(tools) > 0
+        else NOT_GIVEN,
+        extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
+        **completion_params_responses(model_name, config, len(tools) > 0),
+    )
+
+    try:
+        # generate response
+        model_response: Response = await client.responses.create(**request)
+
+        # check for error
+        if model_response.error is not None:
+            raise OpenAIResponseError(
+                code=model_response.error.code, message=model_response.error.message
+            )
+
+        # save response for model_call
+        response = model_response.model_dump()
+
+        # parse out choices
+        choices = openai_responses_chat_choices(model_response, tools)
+
+        # return output and call
+        return ModelOutput(
+            model=model_response.model,
+            choices=choices,
+            usage=(
+                ModelUsage(
+                    input_tokens=model_response.usage.input_tokens,
+                    output_tokens=model_response.usage.output_tokens,
+                    input_tokens_cache_read=(
+                        model_response.usage.input_tokens_details.cached_tokens
+                    ),
+                    reasoning_tokens=model_response.usage.output_tokens_details.reasoning_tokens,
+                    total_tokens=model_response.usage.total_tokens,
+                )
+                if model_response.usage
+                else None
+            ),
+        ), model_call()
+    except BadRequestError as e:
+        return openai_handle_bad_request(model_name, e), model_call()
+
+
+def completion_params_responses(
+    model_name: str, config: GenerateConfig, tools: bool
+) -> dict[str, Any]:
+    # TODO: we'll need a computer_use_preview bool for the 'include'
+    # and 'reasoning' parameters
+    def unsupported_warning(param: str) -> None:
+        warn_once(
+            logger,
+            f"OpenAI Responses API does not support the '{param}' parameter.",
+        )
+
+    params: dict[str, Any] = dict(model=model_name, store=False)
+    if config.max_tokens is not None:
+        params["max_output_tokens"] = config.max_tokens
+    if config.frequency_penalty is not None:
+        unsupported_warning("frequency_penalty")
+    if config.stop_seqs is not None:
+        unsupported_warning("stop_seqs")
+    if config.presence_penalty is not None:
+        unsupported_warning("presence_penalty")
+    if config.logit_bias is not None:
+        unsupported_warning("logit_bias")
+    if config.seed is not None:
+        unsupported_warning("seed")
+    if config.temperature is not None:
+        if is_o_series(model_name):
+            warn_once(
+                logger,
+                "o series models do not support the 'temperature' parameter (temperature is always 1).",
+            )
+        else:
+            params["temperature"] = config.temperature
+    if config.top_p is not None:
+        params["top_p"] = config.top_p
+    if config.num_choices is not None:
+        unsupported_warning("num_choices")
+    if config.logprobs is not None:
+        unsupported_warning("logprobs")
+    if config.top_logprobs is not None:
+        unsupported_warning("top_logprobs")
+    if tools and config.parallel_tool_calls is not None and not is_o_series(model_name):
+        params["parallel_tool_calls"] = config.parallel_tool_calls
+    if (
+        config.reasoning_effort is not None
+        and not is_gpt(model_name)
+        and not is_o1_mini(model_name)
+        and not is_o1_preview(model_name)
+    ):
+        params["reasoning"] = dict(effort=config.reasoning_effort)
+    if config.response_schema is not None:
+        params["text"] = dict(
+            format=ResponseFormatTextJSONSchemaConfigParam(
+                type="json_schema",
+                name=config.response_schema.name,
+                schema=config.response_schema.json_schema.model_dump(exclude_none=True),
+                description=config.response_schema.description
+                or config.response_schema.name,
+                strict=config.response_schema.strict,
+            )
+        )
+
+    return params
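
Note: for orientation, here is a sketch of the parameter translation that completion_params_responses performs. It assumes a direct import of this private helper (illustrative only, not a supported API) and that the is_gpt/is_o_series checks behave as their names suggest:

    # illustrative only: completion_params_responses is private to
    # inspect_ai.model._providers.openai_responses; importing it directly
    # is an assumption for demonstration
    from inspect_ai.model import GenerateConfig
    from inspect_ai.model._providers.openai_responses import completion_params_responses

    config = GenerateConfig(max_tokens=1024, temperature=0.2, reasoning_effort="medium")

    # GPT model: temperature passes through, reasoning effort is dropped
    print(completion_params_responses("gpt-4o", config, tools=False))
    # {'model': 'gpt-4o', 'store': False, 'max_output_tokens': 1024, 'temperature': 0.2}

    # o-series model: temperature is warned away, reasoning effort is kept
    print(completion_params_responses("o3-mini", config, tools=False))
    # {'model': 'o3-mini', 'store': False, 'max_output_tokens': 1024,
    #  'reasoning': {'effort': 'medium'}}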
inspect_ai/model/_providers/openrouter.py CHANGED
@@ -1,9 +1,11 @@
+import json
 import os
-from typing import Any
+from typing import Any, TypedDict

-from typing_extensions import override
+from typing_extensions import NotRequired, override

 from inspect_ai._util.error import PrerequisiteError
+from inspect_ai.model._openai import OpenAIResponseError
 from inspect_ai.model._providers.util import model_base_url
 from inspect_ai.model._providers.util.util import environment_prerequisite_error

@@ -13,6 +15,28 @@ from .openai import OpenAIAPI
 OPENROUTER_API_KEY = "OPENROUTER_API_KEY"


+class ErrorResponse(TypedDict):
+    code: int
+    message: str
+    metadata: NotRequired[dict[str, Any]]
+
+
+class OpenRouterError(Exception):
+    def __init__(self, response: ErrorResponse) -> None:
+        self.response = response
+
+    @property
+    def message(self) -> str:
+        return f"Error {self.response['code']} - {self.response['message']}"
+
+    def __str__(self) -> str:
+        return (
+            self.message + ("\n" + json.dumps(self.response["metadata"], indent=2))
+            if "metadata" in self.response
+            else ""
+        )
+
+
 class OpenRouterAPI(OpenAIAPI):
     def __init__(
         self,
@@ -67,6 +91,32 @@ class OpenRouterAPI(OpenAIAPI):
             **model_args,
         )

+    @override
+    def on_response(self, response: dict[str, Any]) -> None:
+        """Handle documented OpenRouter error conditions.
+
+        https://openrouter.ai/docs/api-reference/errors
+        """
+        # check if open-router yielded an error (raise explicit
+        # OpenAIResponseError for cases where we should retry)
+        error: ErrorResponse | None = response.get("error", None)
+        if error is not None:
+            if error["code"] == 429:
+                raise OpenAIResponseError("rate_limit_exceeded", error["message"])
+            elif error["code"] in [408, 502]:
+                raise OpenAIResponseError("server_error", error["message"])
+            else:
+                raise OpenRouterError(error)
+
+        # check for an empty response (which they document can occur on
+        # startup). for this we'll return a "server_error" which will
+        # trigger a retry w/ exponential backoff
+        elif response.get("choices", None) is None:
+            raise OpenAIResponseError(
+                "server_error",
+                "Model is warming up, please retry again after waiting for warmup.",
+            )
+
     @override
     def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, Any]:
         # default params
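
Note: as a sketch of what on_response is matching against, documented OpenRouter error payloads (field values here are hypothetical) route as follows:

    # hypothetical payloads illustrating the branches of on_response
    rate_limited = {"error": {"code": 429, "message": "Rate limit exceeded"}}
    transient = {"error": {"code": 502, "message": "Provider unavailable"}}
    hard_error = {
        "error": {
            "code": 400,
            "message": "Bad request",
            "metadata": {"provider_name": "example-provider"},
        }
    }
    warming_up = {"id": "gen-abc123"}  # no "choices" key yet

    # 429          -> OpenAIResponseError("rate_limit_exceeded", ...)  retried
    # 408 / 502    -> OpenAIResponseError("server_error", ...)         retried
    # other codes  -> OpenRouterError (carries code/message/metadata)
    # no "choices" -> OpenAIResponseError("server_error", warmup message)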
inspect_ai/model/_providers/providers.py CHANGED
@@ -282,7 +282,7 @@ def goodfire() -> type[ModelAPI]:
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
-    MIN_VERSION = "1.58.1"
+    MIN_VERSION = "1.68.0"

     # verify we have the package
     try:
inspect_ai/model/_providers/vertex.py CHANGED
@@ -34,8 +34,8 @@ from inspect_ai._util.content import (
     Content,
     ContentAudio,
     ContentImage,
+    ContentReasoning,
     ContentText,
-    ContentVideo,
 )
 from inspect_ai._util.http import is_retryable_http_status
 from inspect_ai._util.images import file_as_data
@@ -336,10 +336,13 @@ async def content_part(content: Content | str) -> Part:
     elif isinstance(content, ContentImage):
         image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    elif isinstance(content, ContentReasoning):
+        return Part.from_text(content.reasoning or NO_CONTENT)
     else:
         if isinstance(content, ContentAudio):
             file = content.audio
-        elif isinstance(content, ContentVideo):
+        else:
+            # it's ContentVideo
             file = content.video
         file_bytes, mime_type = await file_as_data(file)
         return Part.from_data(file_bytes, mime_type)
inspect_ai/tool/__init__.py CHANGED
@@ -22,17 +22,23 @@ from ._tool_def import ToolDef
 from ._tool_info import ToolInfo
 from ._tool_params import ToolParam, ToolParams
 from ._tool_with import tool_with
+from ._tools._bash_session import bash_session
 from ._tools._computer import computer
 from ._tools._execute import bash, python
+from ._tools._text_editor import text_editor
+from ._tools._think import think
 from ._tools._web_browser import web_browser
 from ._tools._web_search import web_search

 __all__ = [
     "bash",
+    "bash_session",
     "computer",
     "python",
     "web_browser",
     "web_search",
+    "think",
+    "text_editor",
     "tool",
     "tool_with",
     "Tool",
inspect_ai/tool/_tool.py CHANGED
@@ -20,6 +20,7 @@ from inspect_ai._util.content import (
 )
 from inspect_ai._util.registry import (
     RegistryInfo,
+    is_registry_object,
     registry_add,
     registry_name,
     registry_tag,
@@ -200,7 +201,25 @@ def tool(
     # wrap instantiations of scorer so they carry registry info and metrics
     @wraps(tool_type)
     def tool_wrapper(*args: P.args, **kwargs: P.kwargs) -> Tool:
+        # create the tool
         tool = tool_type(*args, **kwargs)
+
+        # this might already have registry info, in that case
+        # capture it and use it as defaults
+        from inspect_ai.tool._tool_def import tool_registry_info
+
+        tool_parallel = parallel
+        tool_viewer = viewer
+        tool_model_input = model_input
+        if is_registry_object(tool):
+            _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
+                tool
+            )
+            tool_parallel = parallel and reg_parallel
+            tool_viewer = viewer or reg_viewer
+            tool_model_input = model_input or reg_model_input
+
+        # tag the object
         registry_tag(
             tool_type,
             tool,
@@ -209,10 +228,11 @@ def tool(
                 name=tool_name,
                 metadata={
                     TOOL_PROMPT: prompt,
-                    TOOL_PARALLEL: parallel,
-                    TOOL_VIEWER: viewer,
+                    TOOL_PARALLEL: tool_parallel,
+                    TOOL_VIEWER: tool_viewer,
                     TOOL_MODEL_INPUT: (
-                        model_input or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
+                        tool_model_input
+                        or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
                     ),
                 },
             ),
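
Note: the merge of registry defaults above matters when a @tool factory returns a tool that is itself already registered. A hypothetical sketch of that composition (names are illustrative):

    from inspect_ai.tool import Tool, tool

    @tool(parallel=False)
    def serial_tool() -> Tool:
        async def execute(x: int) -> int:
            """Double a number.

            Args:
                x: Number to double.
            """
            return x * 2

        return execute

    @tool
    def wrapping_tool() -> Tool:
        # returns an already-registered tool; with the change above its
        # parallel=False survives the outer registration
        # (tool_parallel = parallel and reg_parallel)
        return serial_tool()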
inspect_ai/tool/_tool_call.py CHANGED
@@ -44,8 +44,11 @@ class ToolCall:
     arguments: dict[str, Any]
     """Arguments to function."""

-    type: Literal["function"]
-    """Type of tool call (currently only 'function')"""
+    type: str
+    """Type of tool call ('function' or a model specific internal tool type)"""
+
+    internal_name: str | None = field(default=None)
+    """Model's internal name for the tool - if any."""

     parse_error: str | None = field(default=None)
     """Error which occurred parsing tool call."""
inspect_ai/tool/_tool_support_helpers.py ADDED
@@ -0,0 +1,200 @@
+"""
+This module provides helper code for handling JSON-RPC communication between the inspect process and the `inspect-tool-support` package code running in the sandbox environment.
+
+It includes definitions for JSON-RPC request and response models, as well as functions to create and parse JSON-RPC requests and responses.
+"""
+
+import json
+from itertools import count
+from textwrap import dedent
+from typing import Literal, Type, TypeVar, cast
+
+from pydantic import BaseModel, RootModel
+
+from inspect_ai._util.error import PrerequisiteError
+from inspect_ai.tool._tool import ToolError, ToolParsingError
+from inspect_ai.util import sandbox_with
+from inspect_ai.util._sandbox.environment import SandboxEnvironment
+
+
+class JSONRPCResponseBase(BaseModel):
+    jsonrpc: Literal["2.0"]
+    id: int | float | str
+
+
+class JSONRPCSuccessResponse(JSONRPCResponseBase):
+    result: object
+
+
+class JSONRPCError(BaseModel):
+    """See: https://www.jsonrpc.org/specification#error_object"""
+
+    code: int
+    message: str
+    data: object | None = None
+
+
+class JSONRPCErrorResponse(JSONRPCResponseBase):
+    error: JSONRPCError
+
+
+class JSONRPCResponse(RootModel[JSONRPCSuccessResponse | JSONRPCErrorResponse]):
+    pass
+
+
+BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
+StrOrModelT = TypeVar("StrOrModelT", bound=str | BaseModel)
+
+id_generator = count(666)
+
+
+async def exec_sandbox_rpc(
+    sandbox: SandboxEnvironment,
+    method: str,
+    params: dict[str, object] | tuple[object, ...],
+    result_cls: Type[StrOrModelT],
+    timeout: int | None = None,
+    user: str | None = None,
+) -> StrOrModelT:
+    """
+    Execute a JSON-RPC command to a sandbox environment.
+
+    Note that the JSON RPC request is sent to the exec'ed program via stdin.
+
+    Args:
+        sandbox (SandboxEnvironment): The sandbox environment to execute the command in.
+        method (str): The JSON-RPC method to call.
+        params (dict[str, object] | tuple[object, ...]): The parameters for the JSON-RPC method.
+        result_cls (Type[BaseModelT]): The class to use for parsing the result.
+        timeout (int | None, optional): The timeout for the execution. Defaults to None.
+        user: Optional username or UID to run the command as.
+
+    Returns:
+        BaseModelT: The parsed result of the JSON-RPC call.
+
+    Raises:
+        RuntimeError: If the sandbox execution fails or if there is an error in the JSON-RPC response.
+        ToolParsingError: If the JSON-RPC response contains a specific error code indicating a parsing error.
+    """
+    exec_result = await sandbox.exec(
+        [SANDBOX_CLI, "exec"],
+        input=_create_json_rpc_request(method, params),
+        timeout=timeout,
+        user=user,
+    )
+
+    if not exec_result.success:
+        raise RuntimeError(
+            f"Sandbox.exec failure executing {_rpc_call_description(method, params)}: {exec_result.stderr}"
+        )
+
+    match _parse_json_rpc_response(exec_result.stdout, result_cls):
+        case JSONRPCError(code=-32601 | -32602, message=message):
+            raise ToolParsingError(message)
+        case JSONRPCError(code=-32000, message=message):
+            raise ToolError(message)
+        case JSONRPCError(code=code, message=message):
+            raise RuntimeError(
+                f"Error executing tool command {_rpc_call_description(method, params)}: {code=} {message}"
+            )
+        # case result_cls() as model: yields a mypy error since it has narrowed model down
+        # to BaseModel and not BaseModelT. ???
+        case model if isinstance(model, result_cls):
+            return model
+        case not_possible:
+            raise RuntimeError(
+                f"Error executing tool command {_rpc_call_description(method, params)}: {not_possible}"
+            )
+
+
+SANDBOX_CLI = "inspect-tool-support"
+INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB = "aisiuk/inspect-tool-support"
+
+
+async def tool_container_sandbox(tool_name: str) -> SandboxEnvironment:
+    sb = await sandbox_with(SANDBOX_CLI, True)
+    if sb:
+        return sb
+    else:
+        msg = dedent(f"""
+            The {tool_name} service was not found in any of the sandboxes for this sample. Please add the {tool_name} to your configuration.
+
+            For example, the following Docker compose file uses the {INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB} reference image as its default sandbox:
+
+            services:
+              default:
+                image: "{INSPECT_TOOL_SUPPORT_IMAGE_DOCKERHUB}"
+                init: true
+
+            Alternatively, you can include the service into your own Dockerfile:
+
+            RUN python -m venv /opt/inspect_tool_support
+            ENV PATH="/opt/inspect_tool_support/bin:$PATH"
+            RUN pip install inspect-tool-support
+            RUN inspect-tool-support post-install
+            """).strip()
+        raise PrerequisiteError(msg)
+
+
+def _create_json_rpc_request(
+    method: str, params: dict[str, object] | tuple[object, ...]
+) -> str:
+    return json.dumps(
+        {
+            "jsonrpc": "2.0",
+            "method": method,
+            "id": next(id_generator),
+            "params": list(params) if isinstance(params, tuple) else params,
+        }
+    )
+
+
+def _rpc_call_description(
+    method: str, params: dict[str, object] | tuple[object, ...]
+) -> str:
+    """
+    Generate a string description of an RPC call.
+
+    Args:
+        method (str): The name of the RPC method.
+        params (dict[str, object] | tuple[object, ...]): The parameters for the RPC method.
+
+    Returns:
+        str: A string description of the RPC call.
+
+    Examples:
+        >>> _rpc_call_description("subtract", {"minuend": 42, "subtrahend": 23})
+        'subtract(minuend: 42, subtrahend: 23)'
+
+        >>> _rpc_call_description("subtract", (42, 23))
+        'subtract(42, 23)'
+    """
+    normalized_params = (
+        list(map(str, params))
+        if isinstance(params, tuple)
+        else [f"{k}: {v}" for k, v in params.items()]
+    )
+    return f"{method}({', '.join(normalized_params)})"
+
+
+def _parse_json_rpc_response(
+    response_str: str,
+    result_cls: Type[StrOrModelT],
+) -> StrOrModelT | JSONRPCError:
+    match JSONRPCResponse.model_validate_json(response_str).root:
+        case JSONRPCErrorResponse(error=error):
+            return error
+        case JSONRPCSuccessResponse(result=rpc_result):
+            # TODO: Wow. Is there really no way to convince Python to narrow these types
+            # and avoid the cast's
+            if result_cls is str:
+                if not isinstance(rpc_result, str):
+                    raise ValueError(f"Expected string result, got {type(rpc_result)}")
+                return cast(StrOrModelT, rpc_result)
+            else:
+                return cast(
+                    StrOrModelT,
+                    cast(BaseModel, result_cls).model_validate(rpc_result, strict=True),
+                )
+        case _:
+            raise ValueError(f"Unexpected JSON RPC response: {response_str}")
inspect_ai/tool/_tools/_bash_session.py ADDED
@@ -0,0 +1,119 @@
+from pydantic import BaseModel, Field, RootModel
+
+from inspect_ai.tool import ToolResult
+from inspect_ai.tool._tool_support_helpers import (
+    exec_sandbox_rpc,
+    tool_container_sandbox,
+)
+from inspect_ai.util import StoreModel, store_as
+
+from .._tool import Tool, ToolParsingError, tool
+from .._tool_call import ToolCall, ToolCallContent, ToolCallView, ToolCallViewer
+
+
+# These models are cloned from the container code. If/when we decide to create
+# a package that is shared between the inspect and tool-container codebases, we'll
+# just have to live with it.
+class NewSessionResult(BaseModel):
+    session_name: str
+
+
+class BashRestartResult(BaseModel):
+    pass
+
+
+class BashCommandResult(BaseModel):
+    status: int
+    stdout: str
+    stderr: str
+
+
+class BashResult(RootModel[BashRestartResult | BashCommandResult]):
+    pass
+
+
+class BashSessionStore(StoreModel):
+    session_id: str = Field(default_factory=str)
+
+
+# custom viewer for bash
+def code_viewer(language: str, code_param: str) -> ToolCallViewer:
+    def viewer(tool_call: ToolCall) -> ToolCallView:
+        code = tool_call.arguments.get(code_param, None)
+        code = (code or tool_call.function).strip()
+        call = ToolCallContent(
+            title=language,
+            format="markdown",
+            content=f"```{language}\n" + code + "\n```\n",
+        )
+        return ToolCallView(call=call)
+
+    return viewer
+
+
+@tool(viewer=code_viewer("bash", "command"))
+def bash_session(timeout: int | None = None) -> Tool:
+    """Bash shell session command execution tool.
+
+    Execute bash shell commands in a long running session using a sandbox environment (e.g. "docker").
+
+    Args:
+        timeout: Timeout (in seconds) for command.
+
+    Returns:
+        String with command output (stdout) or command error (stderr).
+    """
+
+    async def execute(
+        command: str | None = None,
+        restart: bool | None = None,
+    ) -> ToolResult:
+        """
+        Use this function to execute bash commands.
+
+        Args:
+            command: The bash command to run. Required unless the tool is being restarted.
+            restart: Specifying true will restart this tool. Otherwise, leave this unspecified.
+
+        Returns:
+            The output of the command.
+        """
+        if not ((command is None) ^ (restart is None)):
+            raise ToolParsingError(
+                "Either 'command' or 'restart' must be specified, but not both."
+            )
+        params: dict[str, object] = {"command": command, "restart": restart}
+
+        sandbox = await tool_container_sandbox("bash session")
+        store = store_as(BashSessionStore)
+
+        if not store.session_id:
+            store.session_id = (
+                await exec_sandbox_rpc(
+                    sandbox,
+                    "bash_session_new_session",
+                    {},
+                    NewSessionResult,
+                    timeout=timeout,
+                )
+            ).session_name
+
+        params["session_name"] = store.session_id
+
+        result = (
+            await exec_sandbox_rpc(
+                sandbox,
+                "bash_session",
+                params,
+                BashResult,
+                timeout=timeout,
+            )
+        ).root
+
+        if isinstance(result, BashRestartResult):
+            return "Bash session restarted."
+
+        # return output (including stderr if any)
+        return f"{result.stderr}\n{result.stdout}" if result.stderr else result.stdout
+
+    return execute
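
Note: putting it together, a minimal sketch (task, sample, and compose file are illustrative) of wiring the new bash_session tool into an eval, assuming a sandbox image that bundles inspect-tool-support as described in tool_container_sandbox above:

    from inspect_ai import Task, task
    from inspect_ai.dataset import Sample
    from inspect_ai.scorer import includes
    from inspect_ai.solver import generate, use_tools
    from inspect_ai.tool import bash_session

    @task
    def list_etc() -> Task:
        return Task(
            dataset=[
                Sample(
                    input="Use bash to list /etc and report what you find.",
                    target="passwd",
                )
            ],
            solver=[use_tools(bash_session(timeout=180)), generate()],
            scorer=includes(),
            # compose.yaml should reference the aisiuk/inspect-tool-support image
            sandbox=("docker", "compose.yaml"),
        )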