inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103) hide show
  1. inspect_ai/_cli/eval.py +14 -3
  2. inspect_ai/_cli/sandbox.py +3 -3
  3. inspect_ai/_cli/score.py +6 -4
  4. inspect_ai/_cli/trace.py +53 -6
  5. inspect_ai/_display/core/config.py +1 -1
  6. inspect_ai/_display/core/display.py +2 -1
  7. inspect_ai/_display/core/footer.py +6 -6
  8. inspect_ai/_display/plain/display.py +11 -6
  9. inspect_ai/_display/rich/display.py +23 -13
  10. inspect_ai/_display/textual/app.py +10 -9
  11. inspect_ai/_display/textual/display.py +2 -2
  12. inspect_ai/_display/textual/widgets/footer.py +4 -0
  13. inspect_ai/_display/textual/widgets/samples.py +14 -5
  14. inspect_ai/_eval/context.py +1 -2
  15. inspect_ai/_eval/eval.py +54 -41
  16. inspect_ai/_eval/loader.py +9 -2
  17. inspect_ai/_eval/run.py +148 -81
  18. inspect_ai/_eval/score.py +13 -8
  19. inspect_ai/_eval/task/images.py +31 -21
  20. inspect_ai/_eval/task/run.py +62 -59
  21. inspect_ai/_eval/task/rundir.py +16 -9
  22. inspect_ai/_eval/task/sandbox.py +7 -8
  23. inspect_ai/_eval/task/util.py +7 -0
  24. inspect_ai/_util/_async.py +118 -10
  25. inspect_ai/_util/constants.py +0 -2
  26. inspect_ai/_util/file.py +15 -29
  27. inspect_ai/_util/future.py +37 -0
  28. inspect_ai/_util/http.py +3 -99
  29. inspect_ai/_util/httpx.py +60 -0
  30. inspect_ai/_util/interrupt.py +2 -2
  31. inspect_ai/_util/json.py +5 -52
  32. inspect_ai/_util/logger.py +30 -86
  33. inspect_ai/_util/retry.py +10 -61
  34. inspect_ai/_util/trace.py +2 -2
  35. inspect_ai/_view/server.py +86 -3
  36. inspect_ai/_view/www/dist/assets/index.js +25837 -13269
  37. inspect_ai/_view/www/log-schema.json +253 -186
  38. inspect_ai/_view/www/package.json +2 -2
  39. inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
  40. inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
  41. inspect_ai/_view/www/src/types/log.d.ts +122 -94
  42. inspect_ai/approval/_human/manager.py +6 -10
  43. inspect_ai/approval/_human/panel.py +2 -2
  44. inspect_ai/dataset/_sources/util.py +7 -6
  45. inspect_ai/log/__init__.py +4 -0
  46. inspect_ai/log/_file.py +35 -61
  47. inspect_ai/log/_log.py +18 -1
  48. inspect_ai/log/_recorders/eval.py +14 -23
  49. inspect_ai/log/_recorders/json.py +3 -18
  50. inspect_ai/log/_samples.py +27 -2
  51. inspect_ai/log/_transcript.py +8 -8
  52. inspect_ai/model/__init__.py +2 -1
  53. inspect_ai/model/_call_tools.py +60 -40
  54. inspect_ai/model/_chat_message.py +3 -2
  55. inspect_ai/model/_generate_config.py +25 -0
  56. inspect_ai/model/_model.py +74 -36
  57. inspect_ai/model/_openai.py +9 -1
  58. inspect_ai/model/_providers/anthropic.py +24 -26
  59. inspect_ai/model/_providers/azureai.py +11 -9
  60. inspect_ai/model/_providers/bedrock.py +33 -24
  61. inspect_ai/model/_providers/cloudflare.py +8 -9
  62. inspect_ai/model/_providers/goodfire.py +7 -3
  63. inspect_ai/model/_providers/google.py +47 -13
  64. inspect_ai/model/_providers/groq.py +15 -15
  65. inspect_ai/model/_providers/hf.py +24 -17
  66. inspect_ai/model/_providers/mistral.py +36 -20
  67. inspect_ai/model/_providers/openai.py +30 -25
  68. inspect_ai/model/_providers/openai_o1.py +1 -1
  69. inspect_ai/model/_providers/providers.py +1 -1
  70. inspect_ai/model/_providers/together.py +3 -4
  71. inspect_ai/model/_providers/util/__init__.py +2 -2
  72. inspect_ai/model/_providers/util/chatapi.py +6 -19
  73. inspect_ai/model/_providers/util/hooks.py +165 -0
  74. inspect_ai/model/_providers/vertex.py +20 -3
  75. inspect_ai/model/_providers/vllm.py +16 -19
  76. inspect_ai/scorer/_multi.py +5 -2
  77. inspect_ai/solver/_bridge/patch.py +31 -1
  78. inspect_ai/solver/_fork.py +5 -3
  79. inspect_ai/solver/_human_agent/agent.py +3 -2
  80. inspect_ai/tool/__init__.py +8 -2
  81. inspect_ai/tool/_tool_info.py +4 -90
  82. inspect_ai/tool/_tool_params.py +4 -34
  83. inspect_ai/tool/_tools/_web_search.py +30 -24
  84. inspect_ai/util/__init__.py +4 -0
  85. inspect_ai/util/_concurrency.py +5 -6
  86. inspect_ai/util/_display.py +6 -0
  87. inspect_ai/util/_json.py +170 -0
  88. inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
  89. inspect_ai/util/_sandbox/docker/docker.py +5 -0
  90. inspect_ai/util/_sandbox/environment.py +56 -9
  91. inspect_ai/util/_sandbox/service.py +12 -5
  92. inspect_ai/util/_subprocess.py +94 -113
  93. inspect_ai/util/_subtask.py +2 -4
  94. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
  95. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
  96. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
  97. inspect_ai/_util/timeouts.py +0 -160
  98. inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
  99. inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
  100. inspect_ai/model/_providers/util/tracker.py +0 -92
  101. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
  102. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
  103. {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
@@ -1,44 +1,14 @@
1
1
  from typing import (
2
- Any,
3
2
  Literal,
4
- Optional,
3
+ TypeAlias,
5
4
  )
6
5
 
7
6
  from pydantic import BaseModel, Field
8
7
 
9
- JSONType = Literal["string", "integer", "number", "boolean", "array", "object", "null"]
10
- """Validate types within JSON schema."""
8
+ from inspect_ai.util._json import JSONSchema
11
9
 
12
-
13
- class ToolParam(BaseModel):
14
- """Description of tool parameter in JSON Schema format."""
15
-
16
- type: JSONType | None = Field(default=None)
17
- """JSON type of tool parameter."""
18
-
19
- description: str | None = Field(default=None)
20
- """Parameter description."""
21
-
22
- default: Any = Field(default=None)
23
- """Default value for parameter."""
24
-
25
- enum: list[Any] | None = Field(default=None)
26
- """Valid values for enum parameters."""
27
-
28
- items: Optional["ToolParam"] = Field(default=None)
29
- """Valid type for array parameters."""
30
-
31
- properties: dict[str, "ToolParam"] | None = Field(default=None)
32
- """Valid fields for object parametrs."""
33
-
34
- additionalProperties: Optional["ToolParam"] | bool | None = Field(default=None)
35
- """Are additional properties allowed?"""
36
-
37
- anyOf: list["ToolParam"] | None = Field(default=None)
38
- """Valid types for union parameters."""
39
-
40
- required: list[str] | None = Field(default=None)
41
- """Required fields for object parameters."""
10
+ ToolParam: TypeAlias = JSONSchema
11
+ """Description of tool parameter in JSON Schema format."""
42
12
 
43
13
 
44
14
  class ToolParams(BaseModel):
@@ -1,7 +1,7 @@
1
- import asyncio
2
1
  import os
3
2
  from typing import Literal, Protocol, runtime_checkable
4
3
 
4
+ import anyio
5
5
  import httpx
6
6
  from bs4 import BeautifulSoup, NavigableString
7
7
  from tenacity import (
@@ -13,7 +13,7 @@ from tenacity import (
13
13
  )
14
14
 
15
15
  from inspect_ai._util.error import PrerequisiteError
16
- from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt
16
+ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
17
17
  from inspect_ai.util._concurrency import concurrency
18
18
 
19
19
  from .._tool import Tool, ToolResult, tool
@@ -25,6 +25,17 @@ Page Content: {text}
25
25
  """
26
26
 
27
27
 
28
+ class SearchLink:
29
+ def __init__(self, url: str, snippet: str) -> None:
30
+ self.url = url
31
+ self.snippet = snippet
32
+
33
+
34
+ @runtime_checkable
35
+ class SearchProvider(Protocol):
36
+ async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
37
+
38
+
28
39
  @tool
29
40
  def web_search(
30
41
  provider: Literal["google"] = "google",
@@ -84,16 +95,22 @@ def web_search(
84
95
  async with concurrency(f"{provider}_web_search", max_connections):
85
96
  links = await search_provider(query, start_idx=search_calls * 10)
86
97
 
87
- # Extract and summarize each page individually
88
- pages = await asyncio.gather(
89
- *[page_if_relevant(link.url, query, model, client) for link in links],
90
- return_exceptions=True,
91
- )
92
- for page, link in zip(pages, links):
93
- if page and not isinstance(page, BaseException):
94
- page_contents.append(page)
95
- urls.append(link.url)
96
- snippets.append(link.snippet)
98
+ async with anyio.create_task_group() as tg:
99
+
100
+ async def process_link(link: SearchLink) -> None:
101
+ try:
102
+ page = await page_if_relevant(link.url, query, model, client)
103
+ if page:
104
+ page_contents.append(page)
105
+ urls.append(link.url)
106
+ snippets.append(link.snippet)
107
+ # exceptions fetching pages are very common!
108
+ except Exception:
109
+ pass
110
+
111
+ for lk in links:
112
+ tg.start_soon(process_link, lk)
113
+
97
114
  search_calls += 1
98
115
 
99
116
  all_page_contents = "\n\n".join(page_contents)
@@ -168,17 +185,6 @@ async def page_if_relevant(
168
185
  return None
169
186
 
170
187
 
171
- class SearchLink:
172
- def __init__(self, url: str, snippet: str) -> None:
173
- self.url = url
174
- self.snippet = snippet
175
-
176
-
177
- @runtime_checkable
178
- class SearchProvider(Protocol):
179
- async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
180
-
181
-
182
188
  def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
183
189
  google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
184
190
  google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
@@ -204,7 +210,7 @@ def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
204
210
  wait=wait_exponential_jitter(),
205
211
  stop=stop_after_attempt(5) | stop_after_delay(60),
206
212
  retry=retry_if_exception(httpx_should_retry),
207
- before_sleep=log_retry_attempt(search_url),
213
+ before_sleep=log_httpx_retry_attempt(search_url),
208
214
  )
209
215
  async def execute_search() -> httpx.Response:
210
216
  return await client.get(search_url)
@@ -3,6 +3,7 @@ from inspect_ai._util.trace import trace_action, trace_message
3
3
  from ._concurrency import concurrency
4
4
  from ._console import input_screen
5
5
  from ._display import DisplayType, display_counter, display_type
6
+ from ._json import JSONSchema, JSONType, json_schema
6
7
  from ._panel import InputPanel, input_panel
7
8
  from ._resource import resource
8
9
  from ._sandbox import (
@@ -36,6 +37,9 @@ __all__ = [
36
37
  "InputPanel",
37
38
  "input_panel",
38
39
  "input_screen",
40
+ "JSONType",
41
+ "JSONSchema",
42
+ "json_schema",
39
43
  "OutputLimitExceededError",
40
44
  "resource",
41
45
  "subprocess",
@@ -1,9 +1,10 @@
1
- import asyncio
2
1
  import contextlib
3
2
  import time
4
3
  from dataclasses import dataclass
5
4
  from typing import AsyncIterator
6
5
 
6
+ import anyio
7
+
7
8
  from inspect_ai._util.working import report_sample_waiting_time
8
9
 
9
10
 
@@ -45,9 +46,7 @@ async def concurrency(
45
46
  # do we have an existing semaphore? if not create one and store it
46
47
  semaphore = _concurrency_semaphores.get(key, None)
47
48
  if semaphore is None:
48
- semaphore = ConcurencySempahore(
49
- name, concurrency, asyncio.Semaphore(concurrency)
50
- )
49
+ semaphore = ConcurencySempahore(name, concurrency, anyio.Semaphore(concurrency))
51
50
  _concurrency_semaphores[key] = semaphore
52
51
 
53
52
  # wait and yield to protected code
@@ -60,7 +59,7 @@ async def concurrency(
60
59
  def concurrency_status() -> dict[str, tuple[int, int]]:
61
60
  status: dict[str, tuple[int, int]] = {}
62
61
  for c in _concurrency_semaphores.values():
63
- status[c.name] = (c.concurrency - c.semaphore._value, c.concurrency)
62
+ status[c.name] = (c.concurrency - c.semaphore.value, c.concurrency)
64
63
  return status
65
64
 
66
65
 
@@ -72,7 +71,7 @@ def init_concurrency() -> None:
72
71
  class ConcurencySempahore:
73
72
  name: str
74
73
  concurrency: int
75
- semaphore: asyncio.Semaphore
74
+ semaphore: anyio.Semaphore
76
75
 
77
76
 
78
77
  _concurrency_semaphores: dict[str, ConcurencySempahore] = {}
@@ -2,6 +2,7 @@ import os
2
2
  from logging import getLogger
3
3
  from typing import Literal
4
4
 
5
+ from inspect_ai._util._async import configured_async_backend
5
6
  from inspect_ai._util.constants import DEFAULT_DISPLAY
6
7
  from inspect_ai._util.thread import is_main_thread
7
8
 
@@ -20,6 +21,11 @@ def init_display_type(display: str | None = None) -> DisplayType:
20
21
  display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
21
22
  )
22
23
 
24
+ # if trio is configured as the backend then throttle down to "rich"
25
+ # (as textual uses asyncio directly so is not compatible with trio)
26
+ if configured_async_backend() == "trio" and display == "full":
27
+ display = "rich"
28
+
23
29
  # if we are on a background thread then throttle down to "plain"
24
30
  # ("full" requires textual which cannot run in a background thread
25
31
  # b/c it calls the Python signal function; "rich" assumes exclusive
@@ -0,0 +1,170 @@
1
+ import types
2
+ import typing
3
+ from dataclasses import is_dataclass
4
+ from typing import (
5
+ Any,
6
+ Dict,
7
+ List,
8
+ Literal,
9
+ Optional,
10
+ Tuple,
11
+ Type,
12
+ Union,
13
+ get_args,
14
+ get_origin,
15
+ get_type_hints,
16
+ is_typeddict,
17
+ )
18
+
19
+ from pydantic import BaseModel, Field
20
+
21
+ JSONType = Literal["string", "integer", "number", "boolean", "array", "object", "null"]
22
+ """Valid types within JSON schema."""
23
+
24
+
25
+ class JSONSchema(BaseModel):
26
+ """JSON Schema for type."""
27
+
28
+ type: JSONType | None = Field(default=None)
29
+ """JSON type of tool parameter."""
30
+
31
+ description: str | None = Field(default=None)
32
+ """Parameter description."""
33
+
34
+ default: Any = Field(default=None)
35
+ """Default value for parameter."""
36
+
37
+ enum: list[Any] | None = Field(default=None)
38
+ """Valid values for enum parameters."""
39
+
40
+ items: Optional["JSONSchema"] = Field(default=None)
41
+ """Valid type for array parameters."""
42
+
43
+ properties: dict[str, "JSONSchema"] | None = Field(default=None)
44
+ """Valid fields for object parametrs."""
45
+
46
+ additionalProperties: Optional["JSONSchema"] | bool | None = Field(default=None)
47
+ """Are additional properties allowed?"""
48
+
49
+ anyOf: list["JSONSchema"] | None = Field(default=None)
50
+ """Valid types for union parameters."""
51
+
52
+ required: list[str] | None = Field(default=None)
53
+ """Required fields for object parameters."""
54
+
55
+
56
+ def json_schema(t: Type[Any]) -> JSONSchema:
57
+ """Provide a JSON Schema for the specified type.
58
+
59
+ Schemas can be automatically inferred for a wide variety of
60
+ Python class types including Pydantic BaseModel, dataclasses,
61
+ and typed dicts.
62
+
63
+ Args:
64
+ t: Python type
65
+
66
+ Returns:
67
+ JSON Schema for type.
68
+ """
69
+ origin = get_origin(t)
70
+ args = get_args(t)
71
+
72
+ if origin is None:
73
+ if t is int:
74
+ return JSONSchema(type="integer")
75
+ elif t is float:
76
+ return JSONSchema(type="number")
77
+ elif t is str:
78
+ return JSONSchema(type="string")
79
+ elif t is bool:
80
+ return JSONSchema(type="boolean")
81
+ elif t is list:
82
+ return JSONSchema(type="array", items=JSONSchema())
83
+ elif t is dict:
84
+ return JSONSchema(type="object", additionalProperties=JSONSchema())
85
+ elif (
86
+ is_dataclass(t)
87
+ or is_typeddict(t)
88
+ or (isinstance(t, type) and issubclass(t, BaseModel))
89
+ ):
90
+ return cls_json_schema(t)
91
+ elif t is type(None):
92
+ return JSONSchema(type="null")
93
+ else:
94
+ return JSONSchema()
95
+ elif origin is list or origin is List or origin is tuple or origin is Tuple:
96
+ return JSONSchema(
97
+ type="array", items=json_schema(args[0]) if args else JSONSchema()
98
+ )
99
+ elif origin is dict or origin is Dict:
100
+ return JSONSchema(
101
+ type="object",
102
+ additionalProperties=json_schema(args[1])
103
+ if len(args) > 1
104
+ else JSONSchema(),
105
+ )
106
+ elif origin is Union or origin is types.UnionType:
107
+ return JSONSchema(anyOf=[json_schema(arg) for arg in args])
108
+ elif origin is Optional:
109
+ return JSONSchema(
110
+ anyOf=[json_schema(arg) for arg in args] + [JSONSchema(type="null")]
111
+ )
112
+ elif origin is typing.Literal:
113
+ return JSONSchema(enum=list(args))
114
+
115
+ return JSONSchema() # Default case if we can't determine the type
116
+
117
+
118
+ def cls_json_schema(cls: Type[Any]) -> JSONSchema:
119
+ properties: Dict[str, JSONSchema] = {}
120
+ required: List[str] = []
121
+
122
+ if is_dataclass(cls):
123
+ fields = cls.__dataclass_fields__ # type: ignore
124
+ for name, field in fields.items():
125
+ properties[name] = json_schema(field.type) # type: ignore
126
+ if field.default == field.default_factory:
127
+ required.append(name)
128
+ elif isinstance(cls, type) and issubclass(cls, BaseModel):
129
+ schema = cls.model_json_schema()
130
+ for name, prop in schema.get("properties", {}).items():
131
+ properties[name] = JSONSchema(**prop)
132
+ required = schema.get("required", [])
133
+ elif is_typeddict(cls):
134
+ annotations = get_type_hints(cls)
135
+ for name, type_hint in annotations.items():
136
+ properties[name] = json_schema(type_hint)
137
+ if name in cls.__required_keys__:
138
+ required.append(name)
139
+
140
+ return JSONSchema(
141
+ type="object",
142
+ properties=properties,
143
+ required=required if required else None,
144
+ additionalProperties=False,
145
+ )
146
+
147
+
148
+ def python_type_to_json_type(python_type: str | None) -> JSONType:
149
+ match python_type:
150
+ case "str":
151
+ return "string"
152
+ case "int":
153
+ return "integer"
154
+ case "float":
155
+ return "number"
156
+ case "bool":
157
+ return "boolean"
158
+ case "list":
159
+ return "array"
160
+ case "dict":
161
+ return "object"
162
+ case "None":
163
+ return "null"
164
+ # treat 'unknown' as string as anything can be converted to string
165
+ case None:
166
+ return "string"
167
+ case _:
168
+ raise ValueError(
169
+ f"Unsupported type: {python_type} for Python to JSON conversion."
170
+ )
@@ -1,12 +1,14 @@
1
- import asyncio
2
1
  from contextvars import ContextVar
3
2
  from pathlib import Path
4
3
  from typing import Awaitable, Callable, Set
5
4
 
5
+ import anyio
6
6
  from rich import box, print
7
7
  from rich.panel import Panel
8
8
  from rich.table import Table
9
9
 
10
+ from inspect_ai._util._async import coro_print_exceptions
11
+
10
12
  from .compose import compose_down, compose_ls, compose_ps
11
13
  from .config import is_auto_compose_file, safe_cleanup_auto_compose
12
14
  from .util import ComposeProject
@@ -94,13 +96,15 @@ async def cleanup_projects(
94
96
  )
95
97
 
96
98
  # cleanup all of the projects in parallel
97
- tasks = [cleanup_fn(project, False) for project in projects]
98
- results = await asyncio.gather(*tasks, return_exceptions=True)
99
-
100
- # report errors
101
- for result in results:
102
- if result is not None:
103
- print(f"Error cleaning up Docker environment: {result}")
99
+ async with anyio.create_task_group() as tg:
100
+ for project in projects:
101
+ tg.start_soon(
102
+ coro_print_exceptions,
103
+ "cleaning up Docker environment",
104
+ cleanup_fn,
105
+ project,
106
+ False,
107
+ )
104
108
 
105
109
 
106
110
  async def cli_cleanup(project_name: str | None) -> None:
@@ -141,7 +145,7 @@ def auto_compose_files() -> Set[str]:
141
145
 
142
146
 
143
147
  _running_projects: ContextVar[list[ComposeProject]] = ContextVar(
144
- "docker_running_projects"
148
+ "docker_running_projects", default=[]
145
149
  )
146
150
 
147
151
  _auto_compose_files: ContextVar[Set[str]] = ContextVar("docker_auto_compose_files")
@@ -293,6 +293,9 @@ class DockerSandboxEnvironment(SandboxEnvironment):
293
293
 
294
294
  @override
295
295
  async def write_file(self, file: str, contents: str | bytes) -> None:
296
+ # defualt timeout for write_file operations
297
+ TIMEOUT = 180
298
+
296
299
  # resolve relative file paths
297
300
  file = self.container_file(file)
298
301
 
@@ -309,6 +312,7 @@ class DockerSandboxEnvironment(SandboxEnvironment):
309
312
  result = await self.exec(
310
313
  ["sh", "-e", "-c", 'tee -- "$1"', "write_file_script", file],
311
314
  input=contents,
315
+ timeout=TIMEOUT,
312
316
  )
313
317
  else:
314
318
  base64_contents = base64.b64encode(contents).decode("US-ASCII")
@@ -322,6 +326,7 @@ class DockerSandboxEnvironment(SandboxEnvironment):
322
326
  file,
323
327
  ],
324
328
  input=base64_contents,
329
+ timeout=TIMEOUT,
325
330
  )
326
331
  if result.returncode != 0:
327
332
  if "permission denied" in result.stderr.casefold():
@@ -3,18 +3,19 @@ from __future__ import annotations
3
3
  import abc
4
4
  from dataclasses import dataclass, field
5
5
  from typing import (
6
+ Annotated,
6
7
  Any,
7
8
  Awaitable,
8
9
  Callable,
9
10
  Literal,
10
- NamedTuple,
11
11
  Type,
12
12
  TypeVar,
13
13
  Union,
14
+ cast,
14
15
  overload,
15
16
  )
16
17
 
17
- from pydantic import BaseModel, Field
18
+ from pydantic import BaseModel, Field, model_validator
18
19
 
19
20
  from .._subprocess import ExecResult
20
21
 
@@ -38,6 +39,7 @@ SampleCleanup = Callable[
38
39
  ],
39
40
  Awaitable[None],
40
41
  ]
42
+ ConfigDeserialize = Callable[[dict[str, Any]], BaseModel]
41
43
 
42
44
 
43
45
  class HostMapping(BaseModel):
@@ -211,11 +213,6 @@ class SandboxEnvironment(abc.ABC):
211
213
  f"Expected instance of {sandbox_cls.__name__}, got {type(self).__name__}"
212
214
  )
213
215
 
214
- @classmethod
215
- def config_files(cls) -> list[str]:
216
- """Standard config files for this provider (used for automatic discovery)"""
217
- return []
218
-
219
216
  @classmethod
220
217
  def default_concurrency(cls) -> int | None:
221
218
  """Default max_sandboxes for this provider (`None` means no maximum)"""
@@ -296,6 +293,30 @@ class SandboxEnvironment(abc.ABC):
296
293
  """
297
294
  pass
298
295
 
296
+ @classmethod
297
+ def config_files(cls) -> list[str]:
298
+ """Standard config files for this provider (used for automatic discovery)"""
299
+ return []
300
+
301
+ @classmethod
302
+ def config_deserialize(cls, config: dict[str, Any]) -> BaseModel:
303
+ """Deserialize a sandbox-specific configuration model from a dict.
304
+
305
+ Override this method if you support a custom configuration model.
306
+
307
+ A basic implementation would be: `return MySandboxEnvironmentConfig(**config)`
308
+
309
+ Args:
310
+ config: Configuration dictionary produced by serializing the configuration
311
+ model.
312
+
313
+ Returns:
314
+ The sandbox-specific configuration model.
315
+ """
316
+ raise NotImplementedError(
317
+ "The SandboxEnvironment provider has not implemented config_deserialize."
318
+ )
319
+
299
320
 
300
321
  @dataclass
301
322
  class SandboxEnvironments:
@@ -311,15 +332,30 @@ class SandboxEnvironments:
311
332
  """
312
333
 
313
334
 
314
- class SandboxEnvironmentSpec(NamedTuple):
335
+ class SandboxEnvironmentSpec(BaseModel, frozen=True):
315
336
  """Specification of a SandboxEnvironment."""
316
337
 
317
338
  type: str
318
339
  """Sandbox type (e.g. 'local', 'docker')"""
319
340
 
320
- config: SandboxEnvironmentConfigType | None = None
341
+ # Any is used to prevent Pydantic from trying to initialise a BaseModel.
342
+ config: Annotated[Any, "BaseModel, str or None"] = None
321
343
  """Sandbox configuration (filename or config object)."""
322
344
 
345
+ def __init__(self, type: str, config: BaseModel | str | None = None):
346
+ super().__init__(type=type, config=config)
347
+
348
+ @model_validator(mode="before")
349
+ @classmethod
350
+ def load_config_model(cls, data: dict[str, Any]) -> dict[str, Any]:
351
+ type = data["type"]
352
+ config = data.get("config")
353
+ # Pydantic won't know what concrete type to instantiate for config, so
354
+ # ask the relevant sandbox environment to deserialize it.
355
+ if isinstance(config, dict) and len(config) > 0:
356
+ data["config"] = deserialize_sandbox_specific_config(type, config)
357
+ return data
358
+
323
359
 
324
360
  SandboxEnvironmentConfigType = BaseModel | str
325
361
 
@@ -343,3 +379,14 @@ def resolve_sandbox_environment(
343
379
  return SandboxEnvironmentSpec(sandbox[0], sandbox[1])
344
380
  else:
345
381
  return None
382
+
383
+
384
+ def deserialize_sandbox_specific_config(type: str, config: dict[str, Any]) -> BaseModel:
385
+ # Avoid circular import
386
+ from inspect_ai.util._sandbox.registry import registry_find_sandboxenv
387
+
388
+ sandboxenv_type = registry_find_sandboxenv(type)
389
+ config_deserialize = cast(
390
+ ConfigDeserialize, getattr(sandboxenv_type, "config_deserialize")
391
+ )
392
+ return config_deserialize(config)
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import json
3
2
  from logging import getLogger
4
3
  from pathlib import PurePosixPath
@@ -9,8 +8,10 @@ from typing import (
9
8
  cast,
10
9
  )
11
10
 
11
+ import anyio
12
12
  from pydantic import JsonValue
13
13
 
14
+ from inspect_ai._util._async import coro_log_exceptions
14
15
  from inspect_ai.util._subprocess import ExecResult
15
16
 
16
17
  from .environment import SandboxEnvironment
@@ -59,7 +60,7 @@ async def sandbox_service(
59
60
 
60
61
  # wait for and process methods
61
62
  while not until():
62
- await asyncio.sleep(POLLING_INTERVAL)
63
+ await anyio.sleep(POLLING_INTERVAL)
63
64
  await service.handle_requests()
64
65
 
65
66
 
@@ -141,9 +142,15 @@ class SandboxService:
141
142
  if result.success:
142
143
  request_files = result.stdout.strip().splitlines()
143
144
  if request_files:
144
- await asyncio.gather(
145
- *[self._handle_request(file) for file in request_files]
146
- )
145
+ async with anyio.create_task_group() as tg:
146
+ for file in request_files:
147
+ tg.start_soon(
148
+ coro_log_exceptions,
149
+ logger,
150
+ "handling sandbox service request",
151
+ self._handle_request,
152
+ file,
153
+ )
147
154
 
148
155
  async def _handle_request(self, request_file: str) -> None:
149
156
  # read request