inspect-ai 0.3.59__py3-none-any.whl → 0.3.61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88) hide show
  1. inspect_ai/_cli/eval.py +0 -8
  2. inspect_ai/_display/textual/widgets/samples.py +1 -1
  3. inspect_ai/_eval/eval.py +10 -1
  4. inspect_ai/_eval/loader.py +79 -19
  5. inspect_ai/_eval/registry.py +6 -0
  6. inspect_ai/_eval/score.py +2 -1
  7. inspect_ai/_eval/task/generate.py +41 -35
  8. inspect_ai/_eval/task/results.py +6 -5
  9. inspect_ai/_eval/task/run.py +21 -15
  10. inspect_ai/_util/hooks.py +17 -7
  11. inspect_ai/_view/www/dist/assets/index.js +262 -303
  12. inspect_ai/_view/www/package.json +1 -1
  13. inspect_ai/_view/www/src/App.mjs +6 -6
  14. inspect_ai/_view/www/src/Types.mjs +1 -1
  15. inspect_ai/_view/www/src/api/Types.ts +133 -0
  16. inspect_ai/_view/www/src/api/{api-browser.mjs → api-browser.ts} +25 -13
  17. inspect_ai/_view/www/src/api/api-http.ts +219 -0
  18. inspect_ai/_view/www/src/api/api-shared.ts +47 -0
  19. inspect_ai/_view/www/src/api/{api-vscode.mjs → api-vscode.ts} +22 -19
  20. inspect_ai/_view/www/src/api/{client-api.mjs → client-api.ts} +93 -53
  21. inspect_ai/_view/www/src/api/index.ts +51 -0
  22. inspect_ai/_view/www/src/api/jsonrpc.ts +225 -0
  23. inspect_ai/_view/www/src/components/DownloadButton.mjs +1 -1
  24. inspect_ai/_view/www/src/index.js +2 -2
  25. inspect_ai/_view/www/src/log/{remoteLogFile.mjs → remoteLogFile.ts} +62 -46
  26. inspect_ai/_view/www/src/navbar/Navbar.mjs +1 -1
  27. inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +1 -1
  28. inspect_ai/_view/www/src/samples/SampleList.mjs +1 -1
  29. inspect_ai/_view/www/src/samples/SampleScores.mjs +1 -1
  30. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +14 -14
  31. inspect_ai/_view/www/src/samples/SamplesTab.mjs +10 -10
  32. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +2 -2
  33. inspect_ai/_view/www/src/utils/{Json.mjs → json-worker.ts} +1 -3
  34. inspect_ai/_view/www/src/utils/vscode.ts +36 -0
  35. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  36. inspect_ai/approval/_human/manager.py +1 -1
  37. inspect_ai/model/_call_tools.py +55 -0
  38. inspect_ai/model/_chat_message.py +2 -2
  39. inspect_ai/model/_conversation.py +1 -4
  40. inspect_ai/model/_generate_config.py +2 -8
  41. inspect_ai/model/_model.py +90 -25
  42. inspect_ai/model/_model_output.py +15 -0
  43. inspect_ai/model/_openai.py +383 -0
  44. inspect_ai/model/_providers/anthropic.py +52 -14
  45. inspect_ai/model/_providers/azureai.py +1 -1
  46. inspect_ai/model/_providers/goodfire.py +248 -0
  47. inspect_ai/model/_providers/groq.py +7 -3
  48. inspect_ai/model/_providers/hf.py +6 -0
  49. inspect_ai/model/_providers/mistral.py +2 -1
  50. inspect_ai/model/_providers/openai.py +36 -202
  51. inspect_ai/model/_providers/openai_o1.py +2 -4
  52. inspect_ai/model/_providers/providers.py +22 -0
  53. inspect_ai/model/_providers/together.py +4 -4
  54. inspect_ai/model/_providers/util/__init__.py +2 -3
  55. inspect_ai/model/_providers/util/hf_handler.py +1 -1
  56. inspect_ai/model/_providers/util/llama31.py +1 -1
  57. inspect_ai/model/_providers/util/util.py +0 -76
  58. inspect_ai/scorer/_metric.py +3 -0
  59. inspect_ai/scorer/_scorer.py +2 -1
  60. inspect_ai/solver/__init__.py +4 -0
  61. inspect_ai/solver/_basic_agent.py +65 -55
  62. inspect_ai/solver/_bridge/__init__.py +3 -0
  63. inspect_ai/solver/_bridge/bridge.py +100 -0
  64. inspect_ai/solver/_bridge/patch.py +170 -0
  65. inspect_ai/{util → solver}/_limit.py +13 -0
  66. inspect_ai/solver/_solver.py +6 -0
  67. inspect_ai/solver/_task_state.py +37 -7
  68. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -1
  69. inspect_ai/tool/beta/_computer/_resources/Dockerfile +1 -3
  70. inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +1 -1
  71. inspect_ai/tool/beta/_computer/_resources/image_home_dir/.config/xfce4/xfconf/xfce-perchannel-xml/xfce4-screensaver.xml +10 -0
  72. inspect_ai/util/__init__.py +0 -2
  73. inspect_ai/util/_display.py +5 -0
  74. inspect_ai/util/_sandbox/docker/prereqs.py +1 -1
  75. inspect_ai/util/_sandbox/self_check.py +51 -28
  76. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/METADATA +3 -2
  77. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/RECORD +81 -76
  78. inspect_ai/_view/www/src/api/Types.mjs +0 -117
  79. inspect_ai/_view/www/src/api/api-http.mjs +0 -300
  80. inspect_ai/_view/www/src/api/api-shared.mjs +0 -10
  81. inspect_ai/_view/www/src/api/index.mjs +0 -49
  82. inspect_ai/_view/www/src/api/jsonrpc.mjs +0 -208
  83. inspect_ai/_view/www/src/utils/vscode.mjs +0 -16
  84. inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +0 -10
  85. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/LICENSE +0 -0
  86. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/WHEEL +0 -0
  87. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/entry_points.txt +0 -0
  88. {inspect_ai-0.3.59.dist-info → inspect_ai-0.3.61.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@ from inspect_ai.solver._chain import chain
13
13
  from inspect_ai.tool._tool import Tool, ToolResult, tool
14
14
  from inspect_ai.tool._tool_with import tool_with
15
15
 
16
+ from ._limit import SampleLimitExceededError
16
17
  from ._prompt import system_message
17
18
  from ._solver import Generate, Solver, solver
18
19
  from ._task_state import TaskState
@@ -119,7 +120,7 @@ def basic_agent(
119
120
  # resolve tools
120
121
  if tools is None:
121
122
  tools = []
122
- tools = tools if isinstance(tools, Solver) else use_tools(tools)
123
+ tools = tools if isinstance(tools, Solver) else use_tools(tools, append=True)
123
124
 
124
125
  # resolve score_value function
125
126
  score_value_fn = score_value or value_to_float()
@@ -167,61 +168,70 @@ def basic_agent(
167
168
  # track attempts
168
169
  attempts = 0
169
170
 
170
- # main loop (state.completed checks message_limit and token_limit)
171
- while not state.completed:
172
- # generate output and append assistant message
173
- state.output = await get_model().generate(
174
- input=state.messages, tools=state.tools, cache=cache
175
- )
176
- state.messages.append(state.output.message)
177
-
178
- # check for context window overflow
179
- if state.output.stop_reason == "model_length":
180
- from inspect_ai.log._transcript import transcript
181
-
182
- transcript().info("Agent terminated: model context window exceeded")
183
- break
184
-
185
- # resolve tools calls (if any)
186
- if state.output.message.tool_calls:
187
- # call tool functions
188
- tool_results = await call_tools(
189
- state.output.message, state.tools, max_output=max_tool_output
171
+ try:
172
+ # main loop (state.completed checks message_limit and token_limit)
173
+ while not state.completed:
174
+ # generate output and append assistant message
175
+ state.output = await get_model().generate(
176
+ input=state.messages, tools=state.tools, cache=cache
190
177
  )
191
- state.messages.extend(tool_results)
192
-
193
- # was an answer submitted?
194
- answer = submission(tool_results)
195
- if answer:
196
- # set the output to the answer for scoring
197
- state.output.completion = answer
198
-
199
- # exit if we are at max_attempts
200
- attempts += 1
201
- if attempts >= max_attempts:
202
- state.completed = True
203
- break
204
-
205
- # exit if the submission is successful
206
- answer_scores = await score(state)
207
- if score_value_fn(answer_scores[0].value) == 1.0:
208
- state.completed = True
209
- break
210
-
211
- # otherwise notify the model that it was incorrect and continue
212
- else:
213
- response_message = (
214
- incorrect_message(state, answer_scores)
215
- if callable(incorrect_message)
216
- else incorrect_message
217
- )
218
- state.messages.append(
219
- ChatMessageUser(content=response_message)
220
- )
221
-
222
- # no tool calls, urge the model to continue
223
- else:
224
- state.messages.append(ChatMessageUser(content=continue_message))
178
+ state.messages.append(state.output.message)
179
+
180
+ # check for context window overflow
181
+ if state.output.stop_reason == "model_length":
182
+ from inspect_ai.log._transcript import transcript
183
+
184
+ transcript().info(
185
+ "Agent terminated: model context window exceeded"
186
+ )
187
+ break
188
+
189
+ # resolve tools calls (if any)
190
+ if state.output.message.tool_calls:
191
+ # call tool functions
192
+ tool_results = await call_tools(
193
+ state.output.message,
194
+ state.tools,
195
+ max_output=max_tool_output,
196
+ )
197
+ state.messages.extend(tool_results)
198
+
199
+ # was an answer submitted?
200
+ answer = submission(tool_results)
201
+ if answer:
202
+ # set the output to the answer for scoring
203
+ state.output.completion = answer
204
+
205
+ # exit if we are at max_attempts
206
+ attempts += 1
207
+ if attempts >= max_attempts:
208
+ state.completed = True
209
+ break
210
+
211
+ # exit if the submission is successful
212
+ answer_scores = await score(state)
213
+ if score_value_fn(answer_scores[0].value) == 1.0:
214
+ state.completed = True
215
+ break
216
+
217
+ # otherwise notify the model that it was incorrect and continue
218
+ else:
219
+ response_message = (
220
+ incorrect_message(state, answer_scores)
221
+ if callable(incorrect_message)
222
+ else incorrect_message
223
+ )
224
+ state.messages.append(
225
+ ChatMessageUser(content=response_message)
226
+ )
227
+
228
+ # no tool calls, urge the model to continue
229
+ else:
230
+ state.messages.append(ChatMessageUser(content=continue_message))
231
+
232
+ # propagate current state along with sample limit exceeded
233
+ except SampleLimitExceededError as ex:
234
+ raise ex.with_state(state)
225
235
 
226
236
  return state
227
237
 
@@ -0,0 +1,3 @@
1
+ from .bridge import bridge
2
+
3
+ __all__ = ["bridge"]
@@ -0,0 +1,100 @@
1
+ from typing import Any, Awaitable, Callable
2
+
3
+ from jsonschema import Draft7Validator
4
+ from pydantic import BaseModel, Field, ValidationError
5
+ from pydantic_core import to_json
6
+
7
+ from inspect_ai._util._async import is_callable_coroutine
8
+ from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
9
+ from inspect_ai.model._providers.providers import validate_openai_client
10
+ from inspect_ai.scorer._metric import Score
11
+
12
+ from .._solver import Generate, Solver, solver
13
+ from .._task_state import TaskState
14
+
15
+
16
+ @solver
17
+ def bridge(agent: Callable[[dict[str, Any]], Awaitable[dict[str, Any]]]) -> Solver:
18
+ """Bridge an external agent into an Inspect Solver.
19
+
20
+ See documentation at https://inspect.ai-safety-institute.org.uk/agent-bridge.html
21
+
22
+ Args:
23
+ agent: Callable which takes a sample `dict` and returns a result `dict`.
24
+
25
+ Returns:
26
+ Standard Inspect solver.
27
+ """
28
+ validate_openai_client("Solver bridge()")
29
+
30
+ from openai.types.chat import ChatCompletionMessageParam
31
+
32
+ from inspect_ai.model._openai import (
33
+ chat_messages_from_openai,
34
+ openai_chat_messages,
35
+ )
36
+
37
+ from .patch import openai_request_to_inspect_model
38
+
39
+ class BridgeSample(BaseModel):
40
+ sample_id: str
41
+ epoch: int
42
+ input: list[ChatCompletionMessageParam]
43
+ metadata: dict[str, Any]
44
+ target: list[str]
45
+
46
+ class BridgeResult(BaseModel):
47
+ output: str
48
+ messages: list[ChatCompletionMessageParam] | None = Field(default=None)
49
+ scores: dict[str, Score] | None = Field(default=None)
50
+
51
+ result_schema = BridgeResult.model_json_schema()
52
+ result_validator = Draft7Validator(result_schema)
53
+
54
+ # validate that the agent is an async function
55
+ if not is_callable_coroutine(agent):
56
+ raise TypeError(f"'{agent.__name__}' is not declared as an async callable.")
57
+
58
+ async def solve(state: TaskState, generate: Generate) -> TaskState:
59
+ # resolve input to array
60
+ input: list[ChatMessage] = (
61
+ [ChatMessageUser(content=state.input)]
62
+ if isinstance(state.input, str)
63
+ else state.input
64
+ )
65
+
66
+ # create sample
67
+ sample = BridgeSample(
68
+ sample_id=str(state.sample_id),
69
+ epoch=state.epoch,
70
+ input=await openai_chat_messages(input, state.model.name),
71
+ metadata=state.metadata,
72
+ target=list(state.target),
73
+ )
74
+
75
+ # run target function
76
+ async with openai_request_to_inspect_model():
77
+ # call the function
78
+ result_dict = await agent(sample.model_dump())
79
+ try:
80
+ result = BridgeResult.model_validate(result_dict)
81
+ except ValidationError:
82
+ # if we fail to validate provide a better human readable error
83
+ errors = list(result_validator.iter_errors(result_dict))
84
+ message = "\n".join(
85
+ ["Result returned from bridged solver is not valid:"]
86
+ + [f" - {error.message}" for error in errors]
87
+ + ["", to_json(result_dict, indent=2).decode()]
88
+ )
89
+ raise ValueError(message)
90
+
91
+ # update and return state
92
+ state.output.completion = result.output
93
+ if result.messages is not None:
94
+ state.messages = chat_messages_from_openai(result.messages)
95
+ if result.scores is not None:
96
+ state.scores = result.scores
97
+
98
+ return state
99
+
100
+ return solve
@@ -0,0 +1,170 @@
1
+ import contextlib
2
+ import re
3
+ from contextvars import ContextVar
4
+ from functools import wraps
5
+ from time import time
6
+ from typing import Any, AsyncGenerator, Optional, Type, cast
7
+
8
+ from openai._base_client import AsyncAPIClient, _AsyncStreamT
9
+ from openai._models import FinalRequestOptions
10
+ from openai._types import ResponseT
11
+ from openai.types.chat import (
12
+ ChatCompletion,
13
+ ChatCompletionMessageParam,
14
+ ChatCompletionToolParam,
15
+ )
16
+ from shortuuid import uuid
17
+
18
+ from inspect_ai.model._generate_config import GenerateConfig
19
+ from inspect_ai.model._model import get_model
20
+ from inspect_ai.model._openai import (
21
+ chat_messages_from_openai,
22
+ openai_chat_choices,
23
+ openai_completion_usage,
24
+ )
25
+ from inspect_ai.solver._task_state import sample_state
26
+ from inspect_ai.tool._tool_info import ToolInfo
27
+ from inspect_ai.tool._tool_params import ToolParams
28
+
29
+
30
+ @contextlib.asynccontextmanager
31
+ async def openai_request_to_inspect_model() -> AsyncGenerator[None, None]:
32
+ # ensure one time init
33
+ init_openai_request_patch()
34
+
35
+ # set the patch enabled for this context and child coroutines
36
+ token = _patch_enabled.set(True)
37
+ try:
38
+ yield
39
+ finally:
40
+ _patch_enabled.reset(token)
41
+
42
+
43
+ _patch_initialised: bool = False
44
+
45
+ _patch_enabled: ContextVar[bool] = ContextVar(
46
+ "openai_request_patch_enabled", default=False
47
+ )
48
+
49
+
50
+ def init_openai_request_patch() -> None:
51
+ global _patch_initialised
52
+ if not _patch_initialised:
53
+ # get reference to original method
54
+ original_request = getattr(AsyncAPIClient, "request")
55
+ if original_request is None:
56
+ raise RuntimeError("Couldn't find 'request' method on AsyncAPIClient")
57
+
58
+ @wraps(original_request)
59
+ async def patched_request(
60
+ self: AsyncAPIClient,
61
+ cast_to: Type[ResponseT],
62
+ options: FinalRequestOptions,
63
+ *,
64
+ stream: bool = False,
65
+ stream_cls: type[_AsyncStreamT] | None = None,
66
+ remaining_retries: Optional[int] = None,
67
+ ) -> Any:
68
+ # we have patched the underlying request method so now need to figure out when to
69
+ # patch and when to stand down
70
+ if (
71
+ # enabled for this coroutine
72
+ _patch_enabled.get()
73
+ # completions request
74
+ and options.url == "/chat/completions"
75
+ # call to openai not another service (e.g. TogetherAI)
76
+ and self.base_url == "https://api.openai.com/v1/"
77
+ ):
78
+ # must also be an explicit request for an inspect model
79
+ json_data = cast(dict[str, Any], options.json_data)
80
+ model_name = str(json_data["model"])
81
+ if re.match(r"^inspect/?", model_name):
82
+ return await inspect_model_request(model_name, options)
83
+
84
+ # otherwise just delegate
85
+ return await original_request(
86
+ self,
87
+ cast_to,
88
+ options,
89
+ stream=stream,
90
+ stream_cls=stream_cls,
91
+ remaining_retries=remaining_retries,
92
+ )
93
+
94
+ setattr(AsyncAPIClient, "request", patched_request)
95
+
96
+
97
+ async def inspect_model_request(
98
+ model_name: str, options: FinalRequestOptions
99
+ ) -> ChatCompletion:
100
+ # convert openai messages to inspect messages
101
+ json_data = cast(dict[str, Any], options.json_data)
102
+ messages: list[ChatCompletionMessageParam] = json_data["messages"]
103
+ input = chat_messages_from_openai(messages)
104
+
105
+ # convert openai tools to inspect tools
106
+ tools: list[ChatCompletionToolParam] = json_data.get("tools", [])
107
+ inspect_tools: list[ToolInfo] = []
108
+ for tool in tools:
109
+ function = tool["function"].copy()
110
+ inspect_tools.append(
111
+ ToolInfo(
112
+ name=function["name"],
113
+ description=function["description"],
114
+ parameters=ToolParams.model_validate(function["parameters"]),
115
+ )
116
+ )
117
+
118
+ # resolve model
119
+ if model_name == "inspect":
120
+ model = get_model()
121
+ else:
122
+ model = get_model(model_name.removeprefix("inspect/"))
123
+
124
+ output = await model.generate(
125
+ input=input,
126
+ tools=inspect_tools,
127
+ config=generate_config_from_openai(options),
128
+ )
129
+
130
+ # if we are using the "default" inspect model for the task, update state.messages
131
+ if model_name == "inspect":
132
+ state = sample_state()
133
+ if state:
134
+ state.messages = input + [output.choices[0].message]
135
+
136
+ # inspect completion to openai completion
137
+ return ChatCompletion(
138
+ id=uuid(),
139
+ created=int(time()),
140
+ object="chat.completion",
141
+ choices=openai_chat_choices(output.choices),
142
+ model=model_name,
143
+ usage=openai_completion_usage(output.usage) if output.usage else None,
144
+ )
145
+
146
+
147
+ def generate_config_from_openai(options: FinalRequestOptions) -> GenerateConfig:
148
+ # get options dict
149
+ json_data = cast(dict[str, Any], options.json_data)
150
+
151
+ config = GenerateConfig()
152
+ config.max_tokens = json_data.get(
153
+ "max_completion_tokens", json_data.get("max_tokens", None)
154
+ )
155
+ config.top_p = json_data.get("top_p", None)
156
+ config.temperature = json_data.get("temperature", None)
157
+ stop = json_data.get("stop", None)
158
+ if stop:
159
+ config.stop_seqs = [stop] if isinstance(stop, str) else stop
160
+ config.frequency_penalty = json_data.get("frequency_penalty", None)
161
+ config.presence_penalty = json_data.get("presence_penalty", None)
162
+ config.seed = json_data.get("seed", None)
163
+ config.num_choices = json_data.get("n", None)
164
+ config.logprobs = json_data.get("logprobs", None)
165
+ config.top_logprobs = json_data.get("top_logprobs", None)
166
+ config.logit_bias = json_data.get("logit_bias", None)
167
+ config.parallel_tool_calls = json_data.get("parallel_tool_calls", None)
168
+ config.reasoning_effort = json_data.get("reasoning_effort", None)
169
+
170
+ return config
@@ -1,5 +1,7 @@
1
1
  from typing import Literal
2
2
 
3
+ from ._task_state import TaskState
4
+
3
5
 
4
6
  class SampleLimitExceededError(Exception):
5
7
  """Exception raised when a sample limit is exceeded.
@@ -18,9 +20,20 @@ class SampleLimitExceededError(Exception):
18
20
  value: int,
19
21
  limit: int,
20
22
  message: str | None = None,
23
+ state: TaskState | None = None,
21
24
  ) -> None:
22
25
  self.type = type
23
26
  self.value = value
24
27
  self.limit = limit
25
28
  self.message = f"Exceeded {type} limit: {limit:,}"
29
+ self.state = state
26
30
  super().__init__(message)
31
+
32
+ def with_state(self, state: TaskState) -> "SampleLimitExceededError":
33
+ return SampleLimitExceededError(
34
+ self.type,
35
+ value=self.value,
36
+ limit=self.limit,
37
+ message=self.message,
38
+ state=state,
39
+ )
@@ -180,6 +180,7 @@ def solver(
180
180
  solver_type, name if name else getattr(solver_type, "__name__")
181
181
  )
182
182
 
183
+ @wraps(solver_type)
183
184
  def solver_wrapper(*args: P.args, **kwargs: P.kwargs) -> Solver:
184
185
  solver = solver_type(*args, **kwargs)
185
186
 
@@ -193,6 +194,7 @@ def solver(
193
194
  if inspect.isclass(type(solver)):
194
195
  original_call = solver.__call__
195
196
 
197
+ @wraps(original_call)
196
198
  async def call_with_state(
197
199
  state: TaskState, generate: Generate
198
200
  ) -> TaskState:
@@ -225,6 +227,10 @@ def solver(
225
227
 
226
228
  return registered_solver
227
229
 
230
+ # functools.wraps overrides the return type annotation of the inner function, so
231
+ # we explicitly set it again
232
+ solver_wrapper.__annotations__["return"] = Solver
233
+
228
234
  return solver_register(cast(Callable[P, Solver], solver_wrapper), solver_name)
229
235
 
230
236
  # for decorators with an explicit name, one more wrapper for the name
@@ -22,7 +22,6 @@ from inspect_ai.scorer._metric import Score
22
22
  from inspect_ai.scorer._target import Target
23
23
  from inspect_ai.tool import Tool, ToolChoice
24
24
  from inspect_ai.tool._tool_def import ToolDef
25
- from inspect_ai.util._limit import SampleLimitExceededError
26
25
  from inspect_ai.util._store import Store, store_jsonable
27
26
  from inspect_ai.util._store_model import SMT
28
27
 
@@ -173,7 +172,7 @@ class TaskState:
173
172
  self.metadata = metadata
174
173
  """Metadata from the `Sample` for this `TaskState`"""
175
174
 
176
- self._messages: list[ChatMessage] = ChatMessageList(messages)
175
+ self._messages: list[ChatMessage] = ChatMessageList(messages, self)
177
176
  """
178
177
  Chat conversation history for sample.
179
178
 
@@ -272,7 +271,7 @@ class TaskState:
272
271
  @messages.setter
273
272
  def messages(self, messages: list[ChatMessage]) -> None:
274
273
  """Set messages in chat history."""
275
- self._messages = ChatMessageList(messages)
274
+ self._messages = ChatMessageList(messages, self)
276
275
 
277
276
  @property
278
277
  def max_messages(self) -> int | None:
@@ -319,8 +318,32 @@ class TaskState:
319
318
 
320
319
  @property
321
320
  def completed(self) -> bool:
322
- """Is the task completed."""
323
- return self._completed
321
+ """Is the task completed.
322
+
323
+ Additionally, checks message and token limits and raises if they are exceeded.
324
+ """
325
+ from inspect_ai.log._samples import set_active_sample_total_messages
326
+
327
+ from ._limit import SampleLimitExceededError
328
+
329
+ # update messages
330
+ set_active_sample_total_messages(len(self.messages))
331
+
332
+ if self._completed:
333
+ return True
334
+ elif self.message_limit and len(self.messages) >= self.message_limit:
335
+ raise SampleLimitExceededError(
336
+ "message",
337
+ value=len(self.messages),
338
+ limit=self.message_limit,
339
+ state=self,
340
+ )
341
+ elif self.token_limit and self.token_usage >= self.token_limit:
342
+ raise SampleLimitExceededError(
343
+ "token", value=self.token_usage, limit=self.token_limit, state=self
344
+ )
345
+ else:
346
+ return self._completed
324
347
 
325
348
  @completed.setter
326
349
  def completed(self, completed: bool) -> None:
@@ -403,7 +426,8 @@ def sample_jsonable(sample: Sample) -> dict[str, Any]:
403
426
 
404
427
 
405
428
  class ChatMessageList(list[ChatMessage]):
406
- def __init__(self, iterable: Iterable[ChatMessage]):
429
+ def __init__(self, iterable: Iterable[ChatMessage], parent_state: TaskState):
430
+ self.parent_state = parent_state
407
431
  items, length = self._iterable_length(iterable)
408
432
  self._check_size(length)
409
433
  super().__init__(items)
@@ -411,12 +435,18 @@ class ChatMessageList(list[ChatMessage]):
411
435
  def _check_size(self, additional_items: int = 1) -> None:
412
436
  from inspect_ai.log._samples import active_sample_message_limit
413
437
 
438
+ from ._limit import SampleLimitExceededError
439
+
414
440
  messages_limit = active_sample_message_limit()
415
441
  if messages_limit is not None:
416
442
  messages = len(self) + additional_items
417
443
  if messages > messages_limit:
418
444
  raise SampleLimitExceededError(
419
- "message", value=messages, limit=messages_limit
445
+ "message",
446
+ value=messages,
447
+ limit=messages_limit,
448
+ message=None,
449
+ state=self.parent_state,
420
450
  )
421
451
 
422
452
  def append(self, item: ChatMessage) -> None:
@@ -345,7 +345,9 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
345
345
  if sandbox_env:
346
346
  store = store_as(WebBrowserStore)
347
347
  if not store.session_id:
348
- result = await sandbox_env.exec(["python3", WEB_CLIENT_NEW_SESSION])
348
+ result = await sandbox_env.exec(
349
+ ["python3", WEB_CLIENT_NEW_SESSION], timeout=180
350
+ )
349
351
 
350
352
  if not result.success:
351
353
  raise RuntimeError(
@@ -33,8 +33,6 @@ RUN apt-get update && \
33
33
 
34
34
  # Userland apt-get'able apps
35
35
  RUN apt-get install -y --no-install-recommends \
36
- # A simple image viewer.
37
- xpaint \
38
36
  # A calculator application.
39
37
  galculator && \
40
38
  apt-get clean
@@ -78,7 +76,7 @@ RUN useradd -m -s /bin/bash -d $HOME $USERNAME
78
76
  RUN echo "${USERNAME} ALL=(ALL) NOPASSWD: ALL" >> /etc/sudoers
79
77
  USER ${USERNAME}
80
78
  WORKDIR $HOME
81
- COPY --chown=$USERNAME:$USERNAME image_home_dir/ $HOME
79
+ ADD --chown=$USERNAME:$USERNAME image_home_dir/ $HOME
82
80
 
83
81
  # configure Firefox to skip all 'first run' UI
84
82
  RUN mkdir -p $HOME/.mozilla/firefox-esr/profile.default && \
@@ -5,7 +5,7 @@ echo "starting vnc"
5
5
  -forever \
6
6
  -shared \
7
7
  -wait 50 \
8
- -cursor most \
8
+ -multiptr \
9
9
  -cursor arrow \
10
10
  -rfbport 5900 \
11
11
  -nopw \
@@ -0,0 +1,10 @@
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+
3
+ <channel name="xfce4-screensaver" version="1.0">
4
+ <property name="saver" type="empty">
5
+ <property name="mode" type="int" value="0" />
6
+ </property>
7
+ <property name="lock" type="empty">
8
+ <property name="enabled" type="bool" value="false" />
9
+ </property>
10
+ </channel>
@@ -3,7 +3,6 @@ from inspect_ai._util.trace import trace_action, trace_message
3
3
  from ._concurrency import concurrency
4
4
  from ._console import input_screen
5
5
  from ._display import DisplayType, display_type
6
- from ._limit import SampleLimitExceededError
7
6
  from ._panel import InputPanel, input_panel
8
7
  from ._resource import resource
9
8
  from ._sandbox import (
@@ -37,7 +36,6 @@ __all__ = [
37
36
  "input_panel",
38
37
  "input_screen",
39
38
  "OutputLimitExceededError",
40
- "SampleLimitExceededError",
41
39
  "resource",
42
40
  "subprocess",
43
41
  "SandboxEnvironment",
@@ -49,3 +49,8 @@ def display_type() -> DisplayType:
49
49
  return _display_type
50
50
  else:
51
51
  return init_display_type()
52
+
53
+
54
+ def display_type_initialized() -> bool:
55
+ global _display_type
56
+ return _display_type is not None
@@ -57,7 +57,7 @@ async def validate_docker_compose(
57
57
  version: str = DOCKER_COMPOSE_REQUIRED_VERSION,
58
58
  ) -> None:
59
59
  def parse_version(stdout: str) -> semver.Version:
60
- version = json.loads(stdout)["version"].removeprefix("v")
60
+ version = json.loads(stdout)["version"].removeprefix("v").split("+")[0]
61
61
  return semver.Version.parse(version)
62
62
 
63
63
  await validate_version(