inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +0 -8
- inspect_ai/_display/textual/widgets/samples.py +1 -1
- inspect_ai/_eval/eval.py +10 -1
- inspect_ai/_eval/loader.py +79 -19
- inspect_ai/_eval/registry.py +6 -0
- inspect_ai/_eval/score.py +2 -1
- inspect_ai/_eval/task/generate.py +41 -35
- inspect_ai/_eval/task/results.py +6 -5
- inspect_ai/_eval/task/run.py +21 -15
- inspect_ai/_util/hooks.py +17 -7
- inspect_ai/_view/www/dist/assets/index.js +262 -303
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/App.mjs +6 -6
- inspect_ai/_view/www/src/Types.mjs +1 -1
- inspect_ai/_view/www/src/api/Types.ts +133 -0
- inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
- inspect_ai/_view/www/src/api/api-http.ts +219 -0
- inspect_ai/_view/www/src/api/api-shared.ts +47 -0
- inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
- inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
- inspect_ai/_view/www/src/api/index.ts +51 -0
- inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
- inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
- inspect_ai/_view/www/src/index.js +2 -2
- inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
- inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
- inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
- inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
- inspect_ai/_view/www/src/utils/vscode.ts +36 -0
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/manager.py +1 -1
- inspect_ai/model/_call_tools.py +55 -0
- inspect_ai/model/_chat_message.py +2 -2
- inspect_ai/model/_conversation.py +1 -4
- inspect_ai/model/_generate_config.py +2 -8
- inspect_ai/model/_model.py +90 -25
- inspect_ai/model/_model_output.py +15 -0
- inspect_ai/model/_openai.py +383 -0
- inspect_ai/model/_providers/anthropic.py +52 -14
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/goodfire.py +248 -0
- inspect_ai/model/_providers/groq.py +7 -3
- inspect_ai/model/_providers/hf.py +6 -0
- inspect_ai/model/_providers/mistral.py +2 -1
- inspect_ai/model/_providers/openai.py +36 -202
- inspect_ai/model/_providers/openai_o1.py +2 -4
- inspect_ai/model/_providers/providers.py +22 -0
- inspect_ai/model/_providers/together.py +4 -4
- inspect_ai/model/_providers/util/__init__.py +2 -3
- inspect_ai/model/_providers/util/hf_handler.py +1 -1
- inspect_ai/model/_providers/util/llama31.py +1 -1
- inspect_ai/model/_providers/util/util.py +0 -76
- inspect_ai/scorer/_metric.py +3 -0
- inspect_ai/scorer/_scorer.py +2 -1
- inspect_ai/solver/__init__.py +4 -0
- inspect_ai/solver/_basic_agent.py +65 -55
- inspect_ai/solver/_bridge/__init__.py +3 -0
- inspect_ai/solver/_bridge/bridge.py +100 -0
- inspect_ai/solver/_bridge/patch.py +170 -0
- inspect_ai/{util → solver}/_limit.py +13 -0
- inspect_ai/solver/_solver.py +6 -0
- inspect_ai/solver/_task_state.py +37 -7
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
- inspect_ai/util/__init__.py +0 -2
- inspect_ai/util/_display.py +5 -0
- inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
- inspect_ai/util/_sandbox/self_check.py +51 -28
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
- inspect_ai/_view/www/src/api/Types.mjs +0 -117
- inspect_ai/_view/www/src/api/api-http.mjs +0 -300
- inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
- inspect_ai/_view/www/src/api/index.mjs +0 -49
- inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
- inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0
inspect_ai/solver/__init__.py
CHANGED
@@ -13,6 +13,7 @@ from inspect_ai.solver._chain import chain
 from inspect_ai.tool._tool import Tool, ToolResult, tool
 from inspect_ai.tool._tool_with import tool_with
 
+from ._limit import SampleLimitExceededError
 from ._prompt import system_message
 from ._solver import Generate, Solver, solver
 from ._task_state import TaskState
inspect_ai/solver/_basic_agent.py
CHANGED
@@ -119,7 +120,7 @@ def basic_agent(
     # resolve tools
     if tools is None:
         tools = []
-    tools = tools if isinstance(tools, Solver) else use_tools(tools)
+    tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
 
     # resolve score_value function
     score_value_fn = score_value or value_to_float()
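The `append=True` change means `basic_agent` now adds its tools to any tools already present on the `TaskState` rather than replacing them. A minimal sketch of the distinction (the tool choices here are illustrative):

```python
from inspect_ai.solver import chain, use_tools
from inspect_ai.tool import bash, python

plan = chain(
    use_tools([bash()]),                 # replaces state.tools with [bash]
    use_tools([python()], append=True),  # extends state.tools to [bash, python]
)
```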
@@ -167,61 +168,70 @@ def basic_agent(
         # track attempts
         attempts = 0
 
-        # main loop (state.completed checks message_limit and token_limit)
-        while not state.completed:
-            # generate output and append assistant message
-            state.output = await get_model().generate(
-                input=state.messages, tools=state.tools, cache=cache
-            )
-            state.messages.append(state.output.message)
-
-            # check for context window overflow
-            if state.output.stop_reason == "model_length":
-                from inspect_ai.log._transcript import transcript
-
-                transcript().info("Agent terminated: model context window exceeded")
-                break
-
-            # resolve tools calls (if any)
-            if state.output.message.tool_calls:
-                # call tool functions
-                tool_results = await call_tools(
-                    state.output.message, state.tools, max_output=max_tool_output
+        try:
+            # main loop (state.completed checks message_limit and token_limit)
+            while not state.completed:
+                # generate output and append assistant message
+                state.output = await get_model().generate(
+                    input=state.messages, tools=state.tools, cache=cache
                 )
-                state.messages.extend(tool_results)
-
-                # was an answer submitted?
-                answer = submission(tool_results)
-                if answer:
-                    # set the output to the answer for scoring
-                    state.output.completion = answer
-
-                    # exit if we are at max_attempts
-                    attempts += 1
-                    if attempts >= max_attempts:
-                        state.completed = True
-                        break
-
-                    # exit if the submission is successful
-                    answer_scores = await score(state)
-                    if score_value_fn(answer_scores[0].value) == 1.0:
-                        state.completed = True
-                        break
-
-                    # otherwise notify the model that it was incorrect and continue
-                    else:
-                        response_message = (
-                            incorrect_message(state, answer_scores)
-                            if callable(incorrect_message)
-                            else incorrect_message
-                        )
-                        state.messages.append(
-                            ChatMessageUser(content=response_message)
-                        )
-
-            # no tool calls, urge the model to continue
-            else:
-                state.messages.append(ChatMessageUser(content=continue_message))
+                state.messages.append(state.output.message)
+
+                # check for context window overflow
+                if state.output.stop_reason == "model_length":
+                    from inspect_ai.log._transcript import transcript
+
+                    transcript().info(
+                        "Agent terminated: model context window exceeded"
+                    )
+                    break
+
+                # resolve tools calls (if any)
+                if state.output.message.tool_calls:
+                    # call tool functions
+                    tool_results = await call_tools(
+                        state.output.message,
+                        state.tools,
+                        max_output=max_tool_output,
+                    )
+                    state.messages.extend(tool_results)
+
+                    # was an answer submitted?
+                    answer = submission(tool_results)
+                    if answer:
+                        # set the output to the answer for scoring
+                        state.output.completion = answer
+
+                        # exit if we are at max_attempts
+                        attempts += 1
+                        if attempts >= max_attempts:
+                            state.completed = True
+                            break
+
+                        # exit if the submission is successful
+                        answer_scores = await score(state)
+                        if score_value_fn(answer_scores[0].value) == 1.0:
+                            state.completed = True
+                            break
+
+                        # otherwise notify the model that it was incorrect and continue
+                        else:
+                            response_message = (
+                                incorrect_message(state, answer_scores)
+                                if callable(incorrect_message)
+                                else incorrect_message
+                            )
+                            state.messages.append(
+                                ChatMessageUser(content=response_message)
+                            )
+
+                # no tool calls, urge the model to continue
+                else:
+                    state.messages.append(ChatMessageUser(content=continue_message))
+
+        # propagate current state along with sample limit exceeded
+        except SampleLimitExceededError as ex:
+            raise ex.with_state(state)
 
         return state
 
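The net effect: when a message or token limit trips inside the loop, `basic_agent` now re-raises the error with the in-progress `TaskState` attached instead of losing it. A sketch of where these limits come into play when using the agent (dataset, limits, and model below are illustrative):

```python
from inspect_ai import Task, eval
from inspect_ai.dataset import Sample
from inspect_ai.scorer import includes
from inspect_ai.solver import basic_agent
from inspect_ai.tool import bash

task = Task(
    dataset=[Sample(input="What is 1+1? Submit just the number.", target="2")],
    solver=basic_agent(tools=[bash()], max_attempts=2),
    scorer=includes(),
    message_limit=30,  # enforced by state.completed inside the agent loop
)

eval(task, model="openai/gpt-4o")
```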
inspect_ai/solver/_bridge/bridge.py
ADDED
@@ -0,0 +1,100 @@
+from typing import Any, Awaitable, Callable
+
+from jsonschema import Draft7Validator
+from pydantic import BaseModel, Field, ValidationError
+from pydantic_core import to_json
+
+from inspect_ai._util._async import is_callable_coroutine
+from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
+from inspect_ai.model._providers.providers import validate_openai_client
+from inspect_ai.scorer._metric import Score
+
+from .._solver import Generate, Solver, solver
+from .._task_state import TaskState
+
+
+@solver
+def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
+    """Bridge an external agent into an Inspect Solver.
+
+    See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
+
+    Args:
+      agent: Callable which takes a sample `dict` and returns a result `dict`.
+
+    Returns:
+      Standard Inspect solver.
+    """
+    validate_openai_client("Solver bridge()")
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    from inspect_ai.model._openai import (
+        chat_messages_from_openai,
+        openai_chat_messages,
+    )
+
+    from .patch import openai_request_to_inspect_model
+
+    class BridgeSample(BaseModel):
+        sample_id: str
+        epoch: int
+        input: list[ChatCompletionMessageParam]
+        metadata: dict[str, Any]
+        target: list[str]
+
+    class BridgeResult(BaseModel):
+        output: str
+        messages: list[ChatCompletionMessageParam] | None = Field(default=None)
+        scores: dict[str, Score] | None = Field(default=None)
+
+    result_schema = BridgeResult.model_json_schema()
+    result_validator = Draft7Validator(result_schema)
+
+    # validate that the agent is an async function
+    if not is_callable_coroutine(agent):
+        raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        # resolve input to array
+        input: list[ChatMessage] = (
+            [ChatMessageUser(content=state.input)]
+            if isinstance(state.input, str)
+            else state.input
+        )
+
+        # create sample
+        sample = BridgeSample(
+            sample_id=str(state.sample_id),
+            epoch=state.epoch,
+            input=await openai_chat_messages(input, state.model.name),
+            metadata=state.metadata,
+            target=list(state.target),
+        )
+
+        # run target function
+        async with openai_request_to_inspect_model():
+            # call the function
+            result_dict = await agent(sample.model_dump())
+            try:
+                result = BridgeResult.model_validate(result_dict)
+            except ValidationError:
+                # if we fail to validate provide a better human readable error
+                errors = list(result_validator.iter_errors(result_dict))
+                message = "\n".join(
+                    ["Result returned from bridged solver is not valid:"]
+                    + [f" - {error.message}" for error in errors]
+                    + ["", to_json(result_dict, indent=2).decode()]
+                )
+                raise ValueError(message)
+
+        # update and return state
+        state.output.completion = result.output
+        if result.messages is not None:
+            state.messages = chat_messages_from_openai(result.messages)
+        if result.scores is not None:
+            state.scores = result.scores
+
+        return state
+
+    return solve
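To illustrate the intended use: the bridged agent speaks plain OpenAI `dict`s, and (per the patch module below) requests for model `"inspect"` are routed to the active Inspect model. A sketch, assuming `bridge` is re-exported from `inspect_ai.solver` (the task and dataset are illustrative):

```python
from openai import AsyncOpenAI

from inspect_ai import Task
from inspect_ai.dataset import Sample
from inspect_ai.solver import bridge

async def my_agent(sample: dict) -> dict:
    client = AsyncOpenAI()
    # model="inspect" is intercepted and served by the active Inspect model
    completion = await client.chat.completions.create(
        model="inspect", messages=sample["input"]
    )
    return {"output": completion.choices[0].message.content or ""}

task = Task(dataset=[Sample(input="Say hello.")], solver=bridge(my_agent))
```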
inspect_ai/solver/_bridge/patch.py
ADDED
@@ -0,0 +1,170 @@
+import contextlib
+import re
+from contextvars import ContextVar
+from functools import wraps
+from time import time
+from typing import Any, AsyncGenerator, Optional, Type, cast
+
+from openai._base_client import AsyncAPIClient, _AsyncStreamT
+from openai._models import FinalRequestOptions
+from openai._types import ResponseT
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionMessageParam,
+    ChatCompletionToolParam,
+)
+from shortuuid import uuid
+
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model import get_model
+from inspect_ai.model._openai import (
+    chat_messages_from_openai,
+    openai_chat_choices,
+    openai_completion_usage,
+)
+from inspect_ai.solver._task_state import sample_state
+from inspect_ai.tool._tool_info import ToolInfo
+from inspect_ai.tool._tool_params import ToolParams
+
+
+@contextlib.asynccontextmanager
+async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
+    # ensure one time init
+    init_openai_request_patch()
+
+    # set the patch enabled for this context and child coroutines
+    token = _patch_enabled.set(True)
+    try:
+        yield
+    finally:
+        _patch_enabled.reset(token)
+
+
+_patch_initialised: bool = False
+
+_patch_enabled: ContextVar[bool] = ContextVar(
+    "openai_request_patch_enabled", default=False
+)
+
+
+def init_openai_request_patch() -> None:
+    global _patch_initialised
+    if not _patch_initialised:
+        # get reference to original method
+        original_request = getattr(AsyncAPIClient, "request")
+        if original_request is None:
+            raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
+
+        @wraps(original_request)
+        async def patched_request(
+            self: AsyncAPIClient,
+            cast_to: Type[ResponseT],
+            options: FinalRequestOptions,
+            *,
+            stream: bool = False,
+            stream_cls: type[_AsyncStreamT] | None = None,
+            remaining_retries: Optional[int] = None,
+        ) -> Any:
+            # we have patched the underlying request method so now need to figure out when to
+            # patch and when to stand down
+            if (
+                # enabled for this coroutine
+                _patch_enabled.get()
+                # completions request
+                and options.url == "/chat/completions"
+                # call to openai not another service (e.g. TogetherAI)
+                and self.base_url == "https://api.openai.com/v1/"
+            ):
+                # must also be an explicit request for an inspect model
+                json_data = cast(dict[str, Any], options.json_data)
+                model_name = str(json_data["model"])
+                if re.match(r"^inspect/?", model_name):
+                    return await inspect_model_request(model_name, options)
+
+            # otherwise just delegate
+            return await original_request(
+                self,
+                cast_to,
+                options,
+                stream=stream,
+                stream_cls=stream_cls,
+                remaining_retries=remaining_retries,
+            )
+
+        setattr(AsyncAPIClient, "request", patched_request)
+
+
+async def inspect_model_request(
+    model_name: str, options: FinalRequestOptions
+) -> ChatCompletion:
+    # convert openai messages to inspect messages
+    json_data = cast(dict[str, Any], options.json_data)
+    messages: list[ChatCompletionMessageParam] = json_data["messages"]
+    input = chat_messages_from_openai(messages)
+
+    # convert openai tools to inspect tools
+    tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
+    inspect_tools: list[ToolInfo] = []
+    for tool in tools:
+        function = tool["function"].copy()
+        inspect_tools.append(
+            ToolInfo(
+                name=function["name"],
+                description=function["description"],
+                parameters=ToolParams.model_validate(function["parameters"]),
+            )
+        )
+
+    # resolve model
+    if model_name == "inspect":
+        model = get_model()
+    else:
+        model = get_model(model_name.removeprefix("inspect/"))
+
+    output = await model.generate(
+        input=input,
+        tools=inspect_tools,
+        config=generate_config_from_openai(options),
+    )
+
+    # if we are using the "default" inspect model for the task, update state.messages
+    if model_name == "inspect":
+        state = sample_state()
+        if state:
+            state.messages = input + [output.choices[0].message]
+
+    # inspect completion to openai completion
+    return ChatCompletion(
+        id=uuid(),
+        created=int(time()),
+        object="chat.completion",
+        choices=openai_chat_choices(output.choices),
+        model=model_name,
+        usage=openai_completion_usage(output.usage) if output.usage else None,
+    )
+
+
+def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
+    # get options dict
+    json_data = cast(dict[str, Any], options.json_data)
+
+    config = GenerateConfig()
+    config.max_tokens = json_data.get(
+        "max_completion_tokens", json_data.get("max_tokens", None)
+    )
+    config.top_p = json_data.get("top_p", None)
+    config.temperature = json_data.get("temperature", None)
+    stop = json_data.get("stop", None)
+    if stop:
+        config.stop_seqs = [stop] if isinstance(stop, str) else stop
+    config.frequency_penalty = json_data.get("frequency_penalty", None)
+    config.presence_penalty = json_data.get("presence_penalty", None)
+    config.seed = json_data.get("seed", None)
+    config.num_choices = json_data.get("n", None)
+    config.logprobs = json_data.get("logprobs", None)
+    config.top_logprobs = json_data.get("top_logprobs", None)
+    config.logit_bias = json_data.get("logit_bias", None)
+    config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
+    config.reasoning_effort = json_data.get("reasoning_effort", None)
+
+    return config
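The design choice worth noting here is the `ContextVar` gate: the monkey-patch on `AsyncAPIClient.request` is installed process-wide exactly once, but only takes effect inside the `openai_request_to_inspect_model()` context, so concurrent samples that are not bridged are untouched. A stripped-down, runnable sketch of the pattern (all names hypothetical):

```python
import asyncio
import contextlib
from contextvars import ContextVar
from typing import Any, AsyncGenerator, Awaitable, Callable

_enabled: ContextVar[bool] = ContextVar("patch_enabled", default=False)

@contextlib.asynccontextmanager
async def patch_active() -> AsyncGenerator[None, None]:
    # flag is visible to this coroutine and its children only
    token = _enabled.set(True)
    try:
        yield
    finally:
        _enabled.reset(token)

def wrap(original: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    # installed once, process-wide; behavior gated per-context by _enabled
    async def patched(*args: Any, **kwargs: Any) -> Any:
        if _enabled.get():
            return "diverted"  # stand-in for inspect_model_request(...)
        return await original(*args, **kwargs)
    return patched

async def main() -> None:
    async def original() -> str:
        return "real"

    request = wrap(original)
    assert await request() == "real"
    async with patch_active():
        assert await request() == "diverted"

asyncio.run(main())
```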
inspect_ai/{util → solver}/_limit.py
CHANGED
@@ -1,5 +1,7 @@
 from typing import Literal
 
+from ._task_state import TaskState
+
 
 class SampleLimitExceededError(Exception):
     """Exception raised when a sample limit is exceeded.
@@ -18,9 +20,20 @@ class SampleLimitExceededError(Exception):
         value: int,
         limit: int,
         message: str | None = None,
+        state: TaskState | None = None,
     ) -> None:
         self.type = type
         self.value = value
         self.limit = limit
         self.message = f"Exceeded {type} limit: {limit:,}"
+        self.state = state
         super().__init__(message)
+
+    def with_state(self, state: TaskState) -> "SampleLimitExceededError":
+        return SampleLimitExceededError(
+            self.type,
+            value=self.value,
+            limit=self.limit,
+            message=self.message,
+            state=state,
+        )
inspect_ai/solver/_solver.py
CHANGED
@@ -180,6 +180,7 @@ def solver(
             solver_type, name if name else getattr(solver_type, "__name__")
         )
 
+        @wraps(solver_type)
         def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
             solver = solver_type(*args, **kwargs)
 
@@ -193,6 +194,7 @@ def solver(
             if inspect.isclass(type(solver)):
                 original_call = solver.__call__
 
+                @wraps(original_call)
                 async def call_with_state(
                     state: TaskState, generate: Generate
                 ) -> TaskState:
@@ -225,6 +227,10 @@ def solver(
 
             return registered_solver
 
+        # functools.wraps overrides the return type annotation of the inner function, so
+        # we explicitly set it again
+        solver_wrapper.__annotations__["return"] = Solver
+
        return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
 
     # for decorators with an explicit name, one more wrapper for the name
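The final addition fixes a subtle consequence of the new `@wraps` decorators: `functools.wraps` copies `__annotations__` from the wrapped function, which would make `solver_wrapper` appear to return whatever `solver_type` declares rather than `Solver`. A self-contained demonstration of the behavior:

```python
import functools

def inner() -> int:
    return 1

@functools.wraps(inner)
def wrapper() -> str:
    return "one"

print(wrapper.__annotations__)  # {'return': <class 'int'>} -- copied from inner
wrapper.__annotations__["return"] = str  # restore, as the diff does with Solver
```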
inspect_ai/solver/_task_state.py
CHANGED
@@ -22,7 +22,6 @@ from inspect_ai.scorer._metric import Score
 from inspect_ai.scorer._target import Target
 from inspect_ai.tool import Tool, ToolChoice
 from inspect_ai.tool._tool_def import ToolDef
-from inspect_ai.util._limit import SampleLimitExceededError
 from inspect_ai.util._store import Store, store_jsonable
 from inspect_ai.util._store_model import SMT
 
@@ -173,7 +172,7 @@ class TaskState:
         self.metadata = metadata
         """Metadata from the `Sample` for this `TaskState`"""
 
-        self._messages: list[ChatMessage] = ChatMessageList(messages)
+        self._messages: list[ChatMessage] = ChatMessageList(messages, self)
         """
         Chat conversation history for sample.
 
@@ -272,7 +271,7 @@ class TaskState:
     @messages.setter
     def messages(self, messages: list[ChatMessage]) -> None:
         """Set messages in chat history."""
-        self._messages = ChatMessageList(messages)
+        self._messages = ChatMessageList(messages, self)
 
     @property
     def max_messages(self) -> int | None:
@@ -319,8 +318,32 @@ class TaskState:
 
     @property
     def completed(self) -> bool:
-        """Is the task completed."""
-        return self._completed
+        """Is the task completed.
+
+        Additionally, checks message and token limits and raises if they are exceeded.
+        """
+        from inspect_ai.log._samples import set_active_sample_total_messages
+
+        from ._limit import SampleLimitExceededError
+
+        # update messages
+        set_active_sample_total_messages(len(self.messages))
+
+        if self._completed:
+            return True
+        elif self.message_limit and len(self.messages) >= self.message_limit:
+            raise SampleLimitExceededError(
+                "message",
+                value=len(self.messages),
+                limit=self.message_limit,
+                state=self,
+            )
+        elif self.token_limit and self.token_usage >= self.token_limit:
+            raise SampleLimitExceededError(
+                "token", value=self.token_usage, limit=self.token_limit, state=self
+            )
+        else:
+            return self._completed
 
     @completed.setter
     def completed(self, completed: bool) -> None:
@@ -403,7 +426,8 @@ def sample_jsonable(sample: Sample) -> dict[str, Any]:
 
 
 class ChatMessageList(list[ChatMessage]):
-    def __init__(self, iterable: Iterable[ChatMessage]):
+    def __init__(self, iterable: Iterable[ChatMessage], parent_state: TaskState):
+        self.parent_state = parent_state
         items, length = self._iterable_length(iterable)
         self._check_size(length)
         super().__init__(items)
@@ -411,12 +435,18 @@ class ChatMessageList(list[ChatMessage]):
     def _check_size(self, additional_items: int = 1) -> None:
         from inspect_ai.log._samples import active_sample_message_limit
 
+        from ._limit import SampleLimitExceededError
+
         messages_limit = active_sample_message_limit()
         if messages_limit is not None:
             messages = len(self) + additional_items
             if messages > messages_limit:
                 raise SampleLimitExceededError(
-                    "message", value=messages, limit=messages_limit
+                    "message",
+                    value=messages,
+                    limit=messages_limit,
+                    message=None,
+                    state=self.parent_state,
                 )
 
     def append(self, item: ChatMessage) -> None:
inspect_ai/tool/_tools/_web_browser/_web_browser.py
CHANGED
@@ -345,7 +345,9 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
     if sandbox_env:
         store = store_as(WebBrowserStore)
         if not store.session_id:
-            result = await sandbox_env.exec(["python3", WEB_CLIENT_NEW_SESSION])
+            result = await sandbox_env.exec(
+                ["python3", WEB_CLIENT_NEW_SESSION], timeout=180
+            )
 
             if not result.success:
                 raise RuntimeError(
inspect_ai/tool/beta/_computer/_resources/Dockerfile
CHANGED
@@ -33,8 +33,6 @@ RUN apt-get update && \
 
 # Userland apt-get'able apps
 RUN apt-get install -y --no-install-recommends \
-    # A simple image viewer.
-    xpaint \
     # A calculator application.
     galculator && \
     apt-get clean
@@ -78,7 +76,7 @@ RUN useradd -m -s /bin/bash -d $HOME $USERNAME
 RUN echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
 USER ${USERNAME}
 WORKDIR $HOME
-
+ADD --chown=$USERNAME:$USERNAME image_home_dir/ $HOME
 
 # configure Firefox to skip all 'first run' UI
 RUN mkdir -p $HOME/.mozilla/firefox-esr/profile.default && \
inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml
ADDED
@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+
+<channel name="xfce4-screensaver" version="1.0">
+  <property name="saver" type="empty">
+    <property name="mode" type="int" value="0" />
+  </property>
+  <property name="lock" type="empty">
+    <property name="enabled" type="bool" value="false" />
+  </property>
+</channel>
inspect_ai/util/__init__.py
CHANGED
@@ -3,7 +3,6 @@ from inspect_ai._util.trace import trace_action, trace_message
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_type
-from ._limit import SampleLimitExceededError
 from ._panel import InputPanel, input_panel
 from ._resource import resource
 from ._sandbox import (
@@ -37,7 +36,6 @@ __all__ = [
     "input_panel",
    "input_screen",
     "OutputLimitExceededError",
-    "SampleLimitExceededError",
     "resource",
     "subprocess",
     "SandboxEnvironment",
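For code that imported the error from its old location, the move is a one-line migration (the symbol is re-exported from `inspect_ai.solver`, per the `__init__.py` hunk above):

```python
# inspect-ai 0.3.59
from inspect_ai.util import SampleLimitExceededError

# inspect-ai 0.3.61
from inspect_ai.solver import SampleLimitExceededError
```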
inspect_ai/util/_sandbox/docker/prereqs.py
CHANGED
@@ -57,7 +57,7 @@ async def validate_docker_compose(
     version: str = DOCKER_COMPOSE_REQUIRED_VERSION,
 ) -> None:
     def parse_version(stdout: str) -> semver.Version:
-        version = json.loads(stdout)["version"].removeprefix("v")
+        version = json.loads(stdout)["version"].removeprefix("v").split("+")[0]
         return semver.Version.parse(version)
 
     await validate_version(
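The `.split("+")[0]` guard strips everything after a `+` before semver parsing, since distribution-packaged Docker Compose builds report versions with suffixes that are not valid semver. A sketch of the failure mode (the version string below is a hypothetical Debian-style example):

```python
import semver

raw = "v2.24.6+ds1-0ubuntu1~22.04.1"  # hypothetical Debian-style version string
print(semver.Version.parse(raw.removeprefix("v").split("+")[0]))  # 2.24.6

# without the split, parsing raises ValueError: "~" is not a legal character
# in semver build metadata, so distribution suffixes like this fail to parse
```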
|