inspect-ai 0.3.72__py3-none-any.whl → 0.3.73__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +14 -3
- inspect_ai/_cli/sandbox.py +3 -3
- inspect_ai/_cli/score.py +6 -4
- inspect_ai/_cli/trace.py +53 -6
- inspect_ai/_display/core/config.py +1 -1
- inspect_ai/_display/core/display.py +2 -1
- inspect_ai/_display/core/footer.py +6 -6
- inspect_ai/_display/plain/display.py +11 -6
- inspect_ai/_display/rich/display.py +23 -13
- inspect_ai/_display/textual/app.py +10 -9
- inspect_ai/_display/textual/display.py +2 -2
- inspect_ai/_display/textual/widgets/footer.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +14 -5
- inspect_ai/_eval/context.py +1 -2
- inspect_ai/_eval/eval.py +54 -41
- inspect_ai/_eval/loader.py +9 -2
- inspect_ai/_eval/run.py +148 -81
- inspect_ai/_eval/score.py +13 -8
- inspect_ai/_eval/task/images.py +31 -21
- inspect_ai/_eval/task/run.py +62 -59
- inspect_ai/_eval/task/rundir.py +16 -9
- inspect_ai/_eval/task/sandbox.py +7 -8
- inspect_ai/_eval/task/util.py +7 -0
- inspect_ai/_util/_async.py +118 -10
- inspect_ai/_util/constants.py +0 -2
- inspect_ai/_util/file.py +15 -29
- inspect_ai/_util/future.py +37 -0
- inspect_ai/_util/http.py +3 -99
- inspect_ai/_util/httpx.py +60 -0
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/json.py +5 -52
- inspect_ai/_util/logger.py +30 -86
- inspect_ai/_util/retry.py +10 -61
- inspect_ai/_util/trace.py +2 -2
- inspect_ai/_view/server.py +86 -3
- inspect_ai/_view/www/dist/assets/index.js +25837 -13269
- inspect_ai/_view/www/log-schema.json +253 -186
- inspect_ai/_view/www/package.json +2 -2
- inspect_ai/_view/www/src/plan/PlanDetailView.tsx +8 -3
- inspect_ai/_view/www/src/samples/transcript/StepEventView.tsx +2 -3
- inspect_ai/_view/www/src/types/log.d.ts +122 -94
- inspect_ai/approval/_human/manager.py +6 -10
- inspect_ai/approval/_human/panel.py +2 -2
- inspect_ai/dataset/_sources/util.py +7 -6
- inspect_ai/log/__init__.py +4 -0
- inspect_ai/log/_file.py +35 -61
- inspect_ai/log/_log.py +18 -1
- inspect_ai/log/_recorders/eval.py +14 -23
- inspect_ai/log/_recorders/json.py +3 -18
- inspect_ai/log/_samples.py +27 -2
- inspect_ai/log/_transcript.py +8 -8
- inspect_ai/model/__init__.py +2 -1
- inspect_ai/model/_call_tools.py +60 -40
- inspect_ai/model/_chat_message.py +3 -2
- inspect_ai/model/_generate_config.py +25 -0
- inspect_ai/model/_model.py +74 -36
- inspect_ai/model/_openai.py +9 -1
- inspect_ai/model/_providers/anthropic.py +24 -26
- inspect_ai/model/_providers/azureai.py +11 -9
- inspect_ai/model/_providers/bedrock.py +33 -24
- inspect_ai/model/_providers/cloudflare.py +8 -9
- inspect_ai/model/_providers/goodfire.py +7 -3
- inspect_ai/model/_providers/google.py +47 -13
- inspect_ai/model/_providers/groq.py +15 -15
- inspect_ai/model/_providers/hf.py +24 -17
- inspect_ai/model/_providers/mistral.py +36 -20
- inspect_ai/model/_providers/openai.py +30 -25
- inspect_ai/model/_providers/openai_o1.py +1 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +3 -4
- inspect_ai/model/_providers/util/__init__.py +2 -2
- inspect_ai/model/_providers/util/chatapi.py +6 -19
- inspect_ai/model/_providers/util/hooks.py +165 -0
- inspect_ai/model/_providers/vertex.py +20 -3
- inspect_ai/model/_providers/vllm.py +16 -19
- inspect_ai/scorer/_multi.py +5 -2
- inspect_ai/solver/_bridge/patch.py +31 -1
- inspect_ai/solver/_fork.py +5 -3
- inspect_ai/solver/_human_agent/agent.py +3 -2
- inspect_ai/tool/__init__.py +8 -2
- inspect_ai/tool/_tool_info.py +4 -90
- inspect_ai/tool/_tool_params.py +4 -34
- inspect_ai/tool/_tools/_web_search.py +30 -24
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_concurrency.py +5 -6
- inspect_ai/util/_display.py +6 -0
- inspect_ai/util/_json.py +170 -0
- inspect_ai/util/_sandbox/docker/cleanup.py +13 -9
- inspect_ai/util/_sandbox/docker/docker.py +5 -0
- inspect_ai/util/_sandbox/environment.py +56 -9
- inspect_ai/util/_sandbox/service.py +12 -5
- inspect_ai/util/_subprocess.py +94 -113
- inspect_ai/util/_subtask.py +2 -4
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/METADATA +6 -2
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/RECORD +99 -99
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/WHEEL +1 -1
- inspect_ai/_util/timeouts.py +0 -160
- inspect_ai/_view/www/node_modules/flatted/python/flatted.py +0 -149
- inspect_ai/_view/www/node_modules/flatted/python/test.py +0 -63
- inspect_ai/model/_providers/util/tracker.py +0 -92
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.72.dist-info → inspect_ai-0.3.73.dist-info}/top_level.txt +0 -0
inspect_ai/tool/_tool_params.py
CHANGED
@@ -1,44 +1,14 @@
|
|
1
1
|
from typing import (
|
2
|
-
Any,
|
3
2
|
Literal,
|
4
|
-
|
3
|
+
TypeAlias,
|
5
4
|
)
|
6
5
|
|
7
6
|
from pydantic import BaseModel, Field
|
8
7
|
|
9
|
-
|
10
|
-
"""Validate types within JSON schema."""
|
8
|
+
from inspect_ai.util._json import JSONSchema
|
11
9
|
|
12
|
-
|
13
|
-
|
14
|
-
"""Description of tool parameter in JSON Schema format."""
|
15
|
-
|
16
|
-
type: JSONType | None = Field(default=None)
|
17
|
-
"""JSON type of tool parameter."""
|
18
|
-
|
19
|
-
description: str | None = Field(default=None)
|
20
|
-
"""Parameter description."""
|
21
|
-
|
22
|
-
default: Any = Field(default=None)
|
23
|
-
"""Default value for parameter."""
|
24
|
-
|
25
|
-
enum: list[Any] | None = Field(default=None)
|
26
|
-
"""Valid values for enum parameters."""
|
27
|
-
|
28
|
-
items: Optional["ToolParam"] = Field(default=None)
|
29
|
-
"""Valid type for array parameters."""
|
30
|
-
|
31
|
-
properties: dict[str, "ToolParam"] | None = Field(default=None)
|
32
|
-
"""Valid fields for object parametrs."""
|
33
|
-
|
34
|
-
additionalProperties: Optional["ToolParam"] | bool | None = Field(default=None)
|
35
|
-
"""Are additional properties allowed?"""
|
36
|
-
|
37
|
-
anyOf: list["ToolParam"] | None = Field(default=None)
|
38
|
-
"""Valid types for union parameters."""
|
39
|
-
|
40
|
-
required: list[str] | None = Field(default=None)
|
41
|
-
"""Required fields for object parameters."""
|
10
|
+
# ToolParam is now an alias of the general-purpose JSONSchema model
# (imported from inspect_ai.util._json) rather than a class of its own.
ToolParam: TypeAlias = JSONSchema
"""Description of tool parameter in JSON Schema format."""
|
42
12
|
|
43
13
|
|
44
14
|
class ToolParams(BaseModel):
|
@@ -1,7 +1,7 @@
|
|
1
|
-
import asyncio
|
2
1
|
import os
|
3
2
|
from typing import Literal, Protocol, runtime_checkable
|
4
3
|
|
4
|
+
import anyio
|
5
5
|
import httpx
|
6
6
|
from bs4 import BeautifulSoup, NavigableString
|
7
7
|
from tenacity import (
|
@@ -13,7 +13,7 @@ from tenacity import (
|
|
13
13
|
)
|
14
14
|
|
15
15
|
from inspect_ai._util.error import PrerequisiteError
|
16
|
-
from inspect_ai._util.
|
16
|
+
from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
|
17
17
|
from inspect_ai.util._concurrency import concurrency
|
18
18
|
|
19
19
|
from .._tool import Tool, ToolResult, tool
|
@@ -25,6 +25,17 @@ Page Content: {text}
|
|
25
25
|
"""
|
26
26
|
|
27
27
|
|
28
|
+
class SearchLink:
    """One web search result: the result URL together with its snippet text."""

    def __init__(self, url: str, snippet: str) -> None:
        self.url, self.snippet = url, snippet
|
32
|
+
|
33
|
+
|
34
|
+
@runtime_checkable
class SearchProvider(Protocol):
    """Interface for web search backends used by the `web_search` tool.

    `web_search` awaits providers as ``provider(query, start_idx=...)``, with
    ``start_idx`` advancing by 10 on each successive search call (presumably
    the offset of the first result to return -- confirm against each provider
    implementation). Providers return the matching links as `SearchLink`s.
    """

    async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
|
37
|
+
|
38
|
+
|
28
39
|
@tool
|
29
40
|
def web_search(
|
30
41
|
provider: Literal["google"] = "google",
|
@@ -84,16 +95,22 @@ def web_search(
|
|
84
95
|
async with concurrency(f"{provider}_web_search", max_connections):
|
85
96
|
links = await search_provider(query, start_idx=search_calls * 10)
|
86
97
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
98
|
+
async with anyio.create_task_group() as tg:
|
99
|
+
|
100
|
+
async def process_link(link: SearchLink) -> None:
|
101
|
+
try:
|
102
|
+
page = await page_if_relevant(link.url, query, model, client)
|
103
|
+
if page:
|
104
|
+
page_contents.append(page)
|
105
|
+
urls.append(link.url)
|
106
|
+
snippets.append(link.snippet)
|
107
|
+
# exceptions fetching pages are very common!
|
108
|
+
except Exception:
|
109
|
+
pass
|
110
|
+
|
111
|
+
for lk in links:
|
112
|
+
tg.start_soon(process_link, lk)
|
113
|
+
|
97
114
|
search_calls += 1
|
98
115
|
|
99
116
|
all_page_contents = "\n\n".join(page_contents)
|
@@ -168,17 +185,6 @@ async def page_if_relevant(
|
|
168
185
|
return None
|
169
186
|
|
170
187
|
|
171
|
-
class SearchLink:
|
172
|
-
def __init__(self, url: str, snippet: str) -> None:
|
173
|
-
self.url = url
|
174
|
-
self.snippet = snippet
|
175
|
-
|
176
|
-
|
177
|
-
@runtime_checkable
|
178
|
-
class SearchProvider(Protocol):
|
179
|
-
async def __call__(self, query: str, start_idx: int) -> list[SearchLink]: ...
|
180
|
-
|
181
|
-
|
182
188
|
def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
|
183
189
|
google_api_key = os.environ.get("GOOGLE_CSE_API_KEY", None)
|
184
190
|
google_cse_id = os.environ.get("GOOGLE_CSE_ID", None)
|
@@ -204,7 +210,7 @@ def google_search_provider(client: httpx.AsyncClient) -> SearchProvider:
|
|
204
210
|
wait=wait_exponential_jitter(),
|
205
211
|
stop=stop_after_attempt(5) | stop_after_delay(60),
|
206
212
|
retry=retry_if_exception(httpx_should_retry),
|
207
|
-
before_sleep=
|
213
|
+
before_sleep=log_httpx_retry_attempt(search_url),
|
208
214
|
)
|
209
215
|
async def execute_search() -> httpx.Response:
|
210
216
|
return await client.get(search_url)
|
inspect_ai/util/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from inspect_ai._util.trace import trace_action, trace_message
|
|
3
3
|
from ._concurrency import concurrency
|
4
4
|
from ._console import input_screen
|
5
5
|
from ._display import DisplayType, display_counter, display_type
|
6
|
+
from ._json import JSONSchema, JSONType, json_schema
|
6
7
|
from ._panel import InputPanel, input_panel
|
7
8
|
from ._resource import resource
|
8
9
|
from ._sandbox import (
|
@@ -36,6 +37,9 @@ __all__ = [
|
|
36
37
|
"InputPanel",
|
37
38
|
"input_panel",
|
38
39
|
"input_screen",
|
40
|
+
"JSONType",
|
41
|
+
"JSONSchema",
|
42
|
+
"json_schema",
|
39
43
|
"OutputLimitExceededError",
|
40
44
|
"resource",
|
41
45
|
"subprocess",
|
inspect_ai/util/_concurrency.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
|
-
import asyncio
|
2
1
|
import contextlib
|
3
2
|
import time
|
4
3
|
from dataclasses import dataclass
|
5
4
|
from typing import AsyncIterator
|
6
5
|
|
6
|
+
import anyio
|
7
|
+
|
7
8
|
from inspect_ai._util.working import report_sample_waiting_time
|
8
9
|
|
9
10
|
|
@@ -45,9 +46,7 @@ async def concurrency(
|
|
45
46
|
# do we have an existing semaphore? if not create one and store it
|
46
47
|
semaphore = _concurrency_semaphores.get(key, None)
|
47
48
|
if semaphore is None:
|
48
|
-
semaphore = ConcurencySempahore(
|
49
|
-
name, concurrency, asyncio.Semaphore(concurrency)
|
50
|
-
)
|
49
|
+
semaphore = ConcurencySempahore(name, concurrency, anyio.Semaphore(concurrency))
|
51
50
|
_concurrency_semaphores[key] = semaphore
|
52
51
|
|
53
52
|
# wait and yield to protected code
|
@@ -60,7 +59,7 @@ async def concurrency(
|
|
60
59
|
def concurrency_status() -> dict[str, tuple[int, int]]:
    """Summarize usage of every registered concurrency semaphore.

    Returns:
        Mapping from semaphore name to ``(limit - available permits, limit)``.
        If two semaphores share a name, the one registered later wins.
    """
    return {
        entry.name: (entry.concurrency - entry.semaphore.value, entry.concurrency)
        for entry in _concurrency_semaphores.values()
    }
|
65
64
|
|
66
65
|
|
@@ -72,7 +71,7 @@ def init_concurrency() -> None:
|
|
72
71
|
class ConcurencySempahore:
|
73
72
|
name: str
|
74
73
|
concurrency: int
|
75
|
-
semaphore:
|
74
|
+
semaphore: anyio.Semaphore
|
76
75
|
|
77
76
|
|
78
77
|
_concurrency_semaphores: dict[str, ConcurencySempahore] = {}
|
inspect_ai/util/_display.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
from logging import getLogger
|
3
3
|
from typing import Literal
|
4
4
|
|
5
|
+
from inspect_ai._util._async import configured_async_backend
|
5
6
|
from inspect_ai._util.constants import DEFAULT_DISPLAY
|
6
7
|
from inspect_ai._util.thread import is_main_thread
|
7
8
|
|
@@ -20,6 +21,11 @@ def init_display_type(display: str | None = None) -> DisplayType:
|
|
20
21
|
display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
|
21
22
|
)
|
22
23
|
|
24
|
+
# if trio is configured as the backend then throttle down to "rich"
|
25
|
+
# (as textual uses asyncio directly so is not compatible with trio)
|
26
|
+
if configured_async_backend() == "trio" and display == "full":
|
27
|
+
display = "rich"
|
28
|
+
|
23
29
|
# if we are on a background thread then throttle down to "plain"
|
24
30
|
# ("full" requires textual which cannot run in a background thread
|
25
31
|
# b/c it calls the Python signal function; "rich" assumes exclusive
|
inspect_ai/util/_json.py
ADDED
@@ -0,0 +1,170 @@
|
|
1
|
+
import types
|
2
|
+
import typing
|
3
|
+
from dataclasses import is_dataclass
|
4
|
+
from typing import (
|
5
|
+
Any,
|
6
|
+
Dict,
|
7
|
+
List,
|
8
|
+
Literal,
|
9
|
+
Optional,
|
10
|
+
Tuple,
|
11
|
+
Type,
|
12
|
+
Union,
|
13
|
+
get_args,
|
14
|
+
get_origin,
|
15
|
+
get_type_hints,
|
16
|
+
is_typeddict,
|
17
|
+
)
|
18
|
+
|
19
|
+
from pydantic import BaseModel, Field
|
20
|
+
|
21
|
+
# The seven type names accepted by the JSON Schema "type" keyword.
JSONType = Literal["string", "integer", "number", "boolean", "array", "object", "null"]
"""Valid types within JSON schema."""
|
23
|
+
|
24
|
+
|
25
|
+
class JSONSchema(BaseModel):
    """JSON Schema for type.

    Models the subset of JSON Schema keywords used to describe Python types
    (e.g. tool parameters): `type`, `description`, `default`, `enum`, `items`,
    `properties`, `additionalProperties`, `anyOf` and `required`. Every
    keyword is optional; an instance with no keywords set is an unconstrained
    schema.
    """

    type: JSONType | None = Field(default=None)
    """JSON type of the value."""

    description: str | None = Field(default=None)
    """Description of the value."""

    default: Any = Field(default=None)
    """Default value."""

    enum: list[Any] | None = Field(default=None)
    """Valid values for enum types."""

    items: Optional["JSONSchema"] = Field(default=None)
    """Schema for the elements of array types."""

    properties: dict[str, "JSONSchema"] | None = Field(default=None)
    """Schemas for the fields of object types."""

    additionalProperties: Optional["JSONSchema"] | bool | None = Field(default=None)
    """Schema for additional properties of object types, or whether they are allowed."""

    anyOf: list["JSONSchema"] | None = Field(default=None)
    """Alternative schemas for union types."""

    required: list[str] | None = Field(default=None)
    """Required fields for object types."""
|
54
|
+
|
55
|
+
|
56
|
+
def json_schema(t: Type[Any]) -> JSONSchema:
    """Provide a JSON Schema for the specified type.

    Schemas can be automatically inferred for a wide variety of
    Python class types including Pydantic BaseModel, dataclasses,
    and typed dicts.

    Args:
      t: Python type

    Returns:
      JSON Schema for type. Types that cannot be mapped produce an empty
      (unconstrained) schema.
    """
    origin = get_origin(t)
    args = get_args(t)

    if origin is None:
        if t is int:
            return JSONSchema(type="integer")
        elif t is float:
            return JSONSchema(type="number")
        elif t is str:
            return JSONSchema(type="string")
        elif t is bool:
            return JSONSchema(type="boolean")
        elif t is list:
            return JSONSchema(type="array", items=JSONSchema())
        elif t is dict:
            return JSONSchema(type="object", additionalProperties=JSONSchema())
        elif (
            is_dataclass(t)
            or is_typeddict(t)
            or (isinstance(t, type) and issubclass(t, BaseModel))
        ):
            return cls_json_schema(t)
        elif t is type(None):
            return JSONSchema(type="null")
        else:
            return JSONSchema()
    elif origin is typing.Annotated:
        # Annotated[T, ...] only attaches metadata: describe the underlying T
        return json_schema(args[0])
    elif origin is list or origin is List:
        return JSONSchema(
            type="array", items=json_schema(args[0]) if args else JSONSchema()
        )
    elif origin is tuple or origin is Tuple:
        return JSONSchema(type="array", items=_tuple_items_schema(args))
    elif origin is dict or origin is Dict:
        return JSONSchema(
            type="object",
            additionalProperties=json_schema(args[1])
            if len(args) > 1
            else JSONSchema(),
        )
    elif origin is Union or origin is types.UnionType:
        # this also covers Optional[T]: typing normalises it to Union[T, None]
        # (get_origin never returns Optional), and NoneType maps to "null"
        return JSONSchema(anyOf=[json_schema(arg) for arg in args])
    elif origin is typing.Literal:
        return JSONSchema(enum=list(args))

    return JSONSchema()  # Default case if we can't determine the type


def _tuple_items_schema(args: tuple[Any, ...]) -> JSONSchema:
    """Schema for the elements of a tuple type with type arguments `args`.

    Handles variadic (``tuple[T, ...]``) and fixed-length (``tuple[A, B]``)
    tuples. JSONSchema has no per-position item schemas, so a fixed-length
    tuple whose element types differ is described as ``anyOf`` the distinct
    element schemas; a homogeneous tuple is described by its single element
    schema, and a tuple without type arguments by an empty schema.
    """
    schemas: list[JSONSchema] = []
    for arg in args:
        if arg is Ellipsis:
            continue
        schema = json_schema(arg)
        if schema not in schemas:
            schemas.append(schema)
    if not schemas:
        return JSONSchema()
    if len(schemas) == 1:
        return schemas[0]
    return JSONSchema(anyOf=schemas)
|
116
|
+
|
117
|
+
|
118
|
+
def cls_json_schema(cls: Type[Any]) -> JSONSchema:
    """Build an object schema for a dataclass, Pydantic model, or TypedDict.

    Args:
      cls: Dataclass type, Pydantic ``BaseModel`` subclass, or ``TypedDict``.

    Returns:
      Object schema with one entry in ``properties`` per field, the fields
      without defaults listed in ``required`` (``None`` when there are none),
      and ``additionalProperties=False``.
    """
    from dataclasses import MISSING
    from dataclasses import fields as dataclass_fields

    properties: Dict[str, JSONSchema] = {}
    required: List[str] = []

    if is_dataclass(cls):
        # resolve string annotations (e.g. under `from __future__ import
        # annotations`); if resolution fails fall back to the raw field types
        try:
            hints = get_type_hints(cls)
        except Exception:
            hints = {}
        # fields() omits ClassVar and InitVar pseudo-fields, which
        # __dataclass_fields__ includes but which are not instance data
        for field in dataclass_fields(cls):
            properties[field.name] = json_schema(hints.get(field.name, field.type))
            # identity checks avoid invoking the default value's __eq__
            if field.default is MISSING and field.default_factory is MISSING:
                required.append(field.name)
    elif isinstance(cls, type) and issubclass(cls, BaseModel):
        schema = cls.model_json_schema()
        for name, prop in schema.get("properties", {}).items():
            properties[name] = JSONSchema(**prop)
        required = schema.get("required", [])
    elif is_typeddict(cls):
        annotations = get_type_hints(cls)
        for name, type_hint in annotations.items():
            properties[name] = json_schema(type_hint)
            if name in cls.__required_keys__:
                required.append(name)

    return JSONSchema(
        type="object",
        properties=properties,
        required=required if required else None,
        additionalProperties=False,
    )
|
146
|
+
|
147
|
+
|
148
|
+
def python_type_to_json_type(python_type: str | None) -> JSONType:
    """Translate the name of a builtin Python type into a JSON Schema type.

    Args:
      python_type: Python type name (e.g. "str", "int", "None"), or ``None``
        when the type is unknown.

    Returns:
      The corresponding JSON type name. An unknown (``None``) type is mapped
      to "string", since any value can be converted to a string.

    Raises:
      ValueError: If `python_type` names an unsupported type.
    """
    if python_type is None:
        return "string"

    known: tuple[tuple[str, JSONType], ...] = (
        ("str", "string"),
        ("int", "integer"),
        ("float", "number"),
        ("bool", "boolean"),
        ("list", "array"),
        ("dict", "object"),
        ("None", "null"),
    )
    for py_name, json_name in known:
        if python_type == py_name:
            return json_name

    raise ValueError(
        f"Unsupported type: {python_type} for Python to JSON conversion."
    )
|
@@ -1,12 +1,14 @@
|
|
1
|
-
import asyncio
|
2
1
|
from contextvars import ContextVar
|
3
2
|
from pathlib import Path
|
4
3
|
from typing import Awaitable, Callable, Set
|
5
4
|
|
5
|
+
import anyio
|
6
6
|
from rich import box, print
|
7
7
|
from rich.panel import Panel
|
8
8
|
from rich.table import Table
|
9
9
|
|
10
|
+
from inspect_ai._util._async import coro_print_exceptions
|
11
|
+
|
10
12
|
from .compose import compose_down, compose_ls, compose_ps
|
11
13
|
from .config import is_auto_compose_file, safe_cleanup_auto_compose
|
12
14
|
from .util import ComposeProject
|
@@ -94,13 +96,15 @@ async def cleanup_projects(
|
|
94
96
|
)
|
95
97
|
|
96
98
|
# cleanup all of the projects in parallel
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
99
|
+
async with anyio.create_task_group() as tg:
|
100
|
+
for project in projects:
|
101
|
+
tg.start_soon(
|
102
|
+
coro_print_exceptions,
|
103
|
+
"cleaning up Docker environment",
|
104
|
+
cleanup_fn,
|
105
|
+
project,
|
106
|
+
False,
|
107
|
+
)
|
104
108
|
|
105
109
|
|
106
110
|
async def cli_cleanup(project_name: str | None) -> None:
|
@@ -141,7 +145,7 @@ def auto_compose_files() -> Set[str]:
|
|
141
145
|
|
142
146
|
|
143
147
|
# Compose projects started in the current context.
# NOTE(review): `default=[]` is a single list object returned by `.get()` in
# every context that has not called `.set()`; if any caller appends to that
# list without first setting a fresh one, the shared default is mutated and
# leaks across contexts -- confirm all call sites `.set()` a new list first.
_running_projects: ContextVar[list[ComposeProject]] = ContextVar(
    "docker_running_projects", default=[]
)
|
146
150
|
|
147
151
|
_auto_compose_files: ContextVar[Set[str]] = ContextVar("docker_auto_compose_files")
|
@@ -293,6 +293,9 @@ class DockerSandboxEnvironment(SandboxEnvironment):
|
|
293
293
|
|
294
294
|
@override
|
295
295
|
async def write_file(self, file: str, contents: str | bytes) -> None:
|
296
|
+
# defualt timeout for write_file operations
|
297
|
+
TIMEOUT = 180
|
298
|
+
|
296
299
|
# resolve relative file paths
|
297
300
|
file = self.container_file(file)
|
298
301
|
|
@@ -309,6 +312,7 @@ class DockerSandboxEnvironment(SandboxEnvironment):
|
|
309
312
|
result = await self.exec(
|
310
313
|
["sh", "-e", "-c", 'tee -- "$1"', "write_file_script", file],
|
311
314
|
input=contents,
|
315
|
+
timeout=TIMEOUT,
|
312
316
|
)
|
313
317
|
else:
|
314
318
|
base64_contents = base64.b64encode(contents).decode("US-ASCII")
|
@@ -322,6 +326,7 @@ class DockerSandboxEnvironment(SandboxEnvironment):
|
|
322
326
|
file,
|
323
327
|
],
|
324
328
|
input=base64_contents,
|
329
|
+
timeout=TIMEOUT,
|
325
330
|
)
|
326
331
|
if result.returncode != 0:
|
327
332
|
if "permission denied" in result.stderr.casefold():
|
@@ -3,18 +3,19 @@ from __future__ import annotations
|
|
3
3
|
import abc
|
4
4
|
from dataclasses import dataclass, field
|
5
5
|
from typing import (
|
6
|
+
Annotated,
|
6
7
|
Any,
|
7
8
|
Awaitable,
|
8
9
|
Callable,
|
9
10
|
Literal,
|
10
|
-
NamedTuple,
|
11
11
|
Type,
|
12
12
|
TypeVar,
|
13
13
|
Union,
|
14
|
+
cast,
|
14
15
|
overload,
|
15
16
|
)
|
16
17
|
|
17
|
-
from pydantic import BaseModel, Field
|
18
|
+
from pydantic import BaseModel, Field, model_validator
|
18
19
|
|
19
20
|
from .._subprocess import ExecResult
|
20
21
|
|
@@ -38,6 +39,7 @@ SampleCleanup = Callable[
|
|
38
39
|
],
|
39
40
|
Awaitable[None],
|
40
41
|
]
|
42
|
+
# Signature of `SandboxEnvironment.config_deserialize`: builds a sandbox's
# configuration model from its serialized dict form.
ConfigDeserialize = Callable[[dict[str, Any]], BaseModel]
|
41
43
|
|
42
44
|
|
43
45
|
class HostMapping(BaseModel):
|
@@ -211,11 +213,6 @@ class SandboxEnvironment(abc.ABC):
|
|
211
213
|
f"Expected instance of {sandbox_cls.__name__}, got {type(self).__name__}"
|
212
214
|
)
|
213
215
|
|
214
|
-
@classmethod
|
215
|
-
def config_files(cls) -> list[str]:
|
216
|
-
"""Standard config files for this provider (used for automatic discovery)"""
|
217
|
-
return []
|
218
|
-
|
219
216
|
@classmethod
|
220
217
|
def default_concurrency(cls) -> int | None:
|
221
218
|
"""Default max_sandboxes for this provider (`None` means no maximum)"""
|
@@ -296,6 +293,30 @@ class SandboxEnvironment(abc.ABC):
|
|
296
293
|
"""
|
297
294
|
pass
|
298
295
|
|
296
|
+
@classmethod
def config_files(cls) -> list[str]:
    """Standard config files for this provider (used for automatic discovery).

    The base implementation reports no standard config files; subclasses may
    override it to list the file names they look for.
    """
    return []
|
300
|
+
|
301
|
+
@classmethod
def config_deserialize(cls, config: dict[str, Any]) -> BaseModel:
    """Deserialize a sandbox-specific configuration model from a dict.

    Override this method if you support a custom configuration model.

    A basic implementation would be: `return MySandboxEnvironmentConfig(**config)`

    Args:
      config: Configuration dictionary produced by serializing the configuration
        model.

    Returns:
      The sandbox-specific configuration model.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(
        "The SandboxEnvironment provider has not implemented config_deserialize."
    )
|
319
|
+
|
299
320
|
|
300
321
|
@dataclass
|
301
322
|
class SandboxEnvironments:
|
@@ -311,15 +332,30 @@ class SandboxEnvironments:
|
|
311
332
|
"""
|
312
333
|
|
313
334
|
|
314
|
-
class SandboxEnvironmentSpec(
|
335
|
+
class SandboxEnvironmentSpec(BaseModel, frozen=True):
    """Specification of a SandboxEnvironment."""

    type: str
    """Sandbox type (e.g. 'local', 'docker')"""

    # Any is used to prevent Pydantic from trying to initialise a BaseModel.
    config: Annotated[Any, "BaseModel, str or None"] = None
    """Sandbox configuration (filename or config object)."""

    def __init__(self, type: str, config: BaseModel | str | None = None):
        # allows positional construction, e.g. SandboxEnvironmentSpec("docker", "compose.yaml")
        super().__init__(type=type, config=config)

    @model_validator(mode="before")
    @classmethod
    def load_config_model(cls, data: Any) -> Any:
        """Turn a serialized (dict) `config` back into the sandbox's config model.

        Pydantic won't know what concrete type to instantiate for config, so
        the sandbox environment registered for `type` is asked to deserialize
        it via `deserialize_sandbox_specific_config`.
        """
        # 'before' validators receive raw input, which need not be a dict;
        # leave anything else to pydantic's standard validation
        if not isinstance(data, dict):
            return data
        config = data.get("config")
        # a missing "type" is left for pydantic to report as a validation
        # error (rather than a raw KeyError from this validator)
        if isinstance(config, dict) and len(config) > 0 and "type" in data:
            # build a new dict rather than mutating the caller's input
            data = {
                **data,
                "config": deserialize_sandbox_specific_config(data["type"], config),
            }
        return data
|
358
|
+
|
323
359
|
|
324
360
|
SandboxEnvironmentConfigType = BaseModel | str
|
325
361
|
|
@@ -343,3 +379,14 @@ def resolve_sandbox_environment(
|
|
343
379
|
return SandboxEnvironmentSpec(sandbox[0], sandbox[1])
|
344
380
|
else:
|
345
381
|
return None
|
382
|
+
|
383
|
+
|
384
|
+
def deserialize_sandbox_specific_config(type: str, config: dict[str, Any]) -> BaseModel:
    """Rebuild a sandbox-specific config model from its dict form.

    Looks up the sandbox environment registered under `type` and delegates to
    its `config_deserialize` classmethod. Providers that do not override it
    inherit the base `SandboxEnvironment.config_deserialize`, which raises
    NotImplementedError.

    Args:
      type: Registered sandbox type name (e.g. 'docker').
      config: Serialized configuration dict.

    Returns:
      The provider's configuration model.
    """
    # imported lazily to avoid a circular import with the sandbox registry
    from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

    return cast(
        ConfigDeserialize,
        getattr(registry_find_sandboxenv(type), "config_deserialize"),
    )(config)
|
@@ -1,4 +1,3 @@
|
|
1
|
-
import asyncio
|
2
1
|
import json
|
3
2
|
from logging import getLogger
|
4
3
|
from pathlib import PurePosixPath
|
@@ -9,8 +8,10 @@ from typing import (
|
|
9
8
|
cast,
|
10
9
|
)
|
11
10
|
|
11
|
+
import anyio
|
12
12
|
from pydantic import JsonValue
|
13
13
|
|
14
|
+
from inspect_ai._util._async import coro_log_exceptions
|
14
15
|
from inspect_ai.util._subprocess import ExecResult
|
15
16
|
|
16
17
|
from .environment import SandboxEnvironment
|
@@ -59,7 +60,7 @@ async def sandbox_service(
|
|
59
60
|
|
60
61
|
# wait for and process methods
|
61
62
|
while not until():
|
62
|
-
await
|
63
|
+
await anyio.sleep(POLLING_INTERVAL)
|
63
64
|
await service.handle_requests()
|
64
65
|
|
65
66
|
|
@@ -141,9 +142,15 @@ class SandboxService:
|
|
141
142
|
if result.success:
|
142
143
|
request_files = result.stdout.strip().splitlines()
|
143
144
|
if request_files:
|
144
|
-
|
145
|
-
|
146
|
-
|
145
|
+
async with anyio.create_task_group() as tg:
|
146
|
+
for file in request_files:
|
147
|
+
tg.start_soon(
|
148
|
+
coro_log_exceptions,
|
149
|
+
logger,
|
150
|
+
"handling sandbox service request",
|
151
|
+
self._handle_request,
|
152
|
+
file,
|
153
|
+
)
|
147
154
|
|
148
155
|
async def _handle_request(self, request_file: str) -> None:
|
149
156
|
# read request
|