inspect-ai 0.3.52__py3-none-any.whl → 0.3.54__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear there.
- inspect_ai/_cli/eval.py +55 -1
- inspect_ai/_cli/main.py +2 -0
- inspect_ai/_cli/trace.py +244 -0
- inspect_ai/_display/core/progress.py +9 -3
- inspect_ai/_display/core/results.py +8 -4
- inspect_ai/_display/textual/app.py +5 -1
- inspect_ai/_display/textual/widgets/task_detail.py +3 -0
- inspect_ai/_display/textual/widgets/tasks.py +97 -6
- inspect_ai/_eval/eval.py +33 -0
- inspect_ai/_eval/evalset.py +4 -0
- inspect_ai/_eval/registry.py +2 -2
- inspect_ai/_eval/task/images.py +4 -14
- inspect_ai/_eval/task/results.py +22 -4
- inspect_ai/_eval/task/run.py +40 -20
- inspect_ai/_eval/task/sandbox.py +72 -43
- inspect_ai/_eval/task/task.py +4 -0
- inspect_ai/_eval/task/util.py +2 -0
- inspect_ai/_util/constants.py +3 -3
- inspect_ai/_util/display.py +1 -0
- inspect_ai/_util/logger.py +34 -8
- inspect_ai/_util/trace.py +275 -0
- inspect_ai/_view/www/App.css +13 -0
- inspect_ai/_view/www/dist/assets/index.css +13 -0
- inspect_ai/_view/www/dist/assets/index.js +80 -43
- inspect_ai/_view/www/src/App.mjs +31 -6
- inspect_ai/_view/www/src/Types.mjs +6 -0
- inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
- inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
- inspect_ai/_view/www/src/components/Tools.mjs +46 -18
- inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
- inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
- inspect_ai/log/_log.py +6 -0
- inspect_ai/log/_message.py +2 -2
- inspect_ai/log/_recorders/eval.py +8 -18
- inspect_ai/log/_recorders/json.py +19 -17
- inspect_ai/model/_cache.py +22 -16
- inspect_ai/model/_call_tools.py +9 -1
- inspect_ai/model/_generate_config.py +8 -2
- inspect_ai/model/_model.py +11 -12
- inspect_ai/model/_providers/azureai.py +1 -1
- inspect_ai/model/_providers/bedrock.py +18 -2
- inspect_ai/model/_providers/hf.py +1 -1
- inspect_ai/model/_providers/openai.py +32 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/vllm.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
- inspect_ai/util/_sandbox/context.py +7 -3
- inspect_ai/util/_sandbox/docker/compose.py +58 -19
- inspect_ai/util/_sandbox/docker/config.py +8 -10
- inspect_ai/util/_sandbox/docker/docker.py +20 -16
- inspect_ai/util/_sandbox/docker/util.py +3 -9
- inspect_ai/util/_sandbox/environment.py +7 -2
- inspect_ai/util/_sandbox/limits.py +1 -1
- inspect_ai/util/_sandbox/local.py +8 -9
- inspect_ai/util/_sandbox/service.py +17 -7
- inspect_ai/util/_subprocess.py +6 -1
- inspect_ai/util/_subtask.py +8 -2
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/METADATA +6 -8
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/RECORD +64 -62
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/top_level.txt +0 -0
inspect_ai/model/_cache.py
CHANGED
@@ -6,10 +6,12 @@ from datetime import datetime, timezone
 from hashlib import md5
 from pathlib import Path
 from shutil import rmtree
+from typing import Any

 from dateutil.relativedelta import relativedelta

 from inspect_ai._util.appdirs import inspect_cache_dir
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolChoice, ToolInfo

 from ._chat_message import ChatMessage
@@ -19,6 +21,10 @@ from ._model_output import ModelOutput
 logger = logging.getLogger(__name__)


+def trace(msg: str, *args: Any) -> None:
+    trace_message(logger, "Cache", msg, args)
+
+
 def _path_is_in_cache(path: Path | str) -> bool:
     """This ensures the path is in our cache directory, just in case the `model` is ../../../home/ubuntu/maliciousness"""
     if isinstance(path, str):
@@ -153,7 +159,7 @@ def _cache_key(entry: CacheEntry) -> str:

     base_string = "|".join([str(component) for component in components])

-
+    trace(_cache_key_debug_string([str(component) for component in components]))

     return md5(base_string.encode("utf-8")).hexdigest()

@@ -192,11 +198,11 @@ def cache_store(

         with open(filename, "wb") as f:
             expiry = _cache_expiry(entry.policy)
-
+            trace("Storing in cache: %s (expires: %s)", filename, expiry)
             pickle.dump((expiry, output), f)
             return True
     except Exception as e:
-
+        trace(f"Failed to cache {filename}: {e}")
        return False


@@ -204,12 +210,12 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
     """Fetch a value from the cache directory."""
     filename = cache_path(model=entry.model) / _cache_key(entry)
     try:
-
+        trace("Fetching from cache: %s", filename)

         with open(filename, "rb") as f:
             expiry, output = pickle.load(f)
             if not isinstance(output, ModelOutput):
-
+                trace(
                     "Unexpected cached type, can only fetch ModelOutput: %s (%s)",
                     type(output),
                     filename,
@@ -217,7 +223,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
                 return None

             if _is_expired(expiry):
-
+                trace("Cache expired for %s (%s)", filename, expiry)
                 # If it's expired, no point keeping it as we'll never access it
                 # successfully again.
                 filename.unlink(missing_ok=True)
@@ -225,7 +231,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:

             return output
     except Exception as e:
-
+        trace(f"Failed to fetch from cache {filename}: {e}")
         return None


@@ -235,7 +241,7 @@ def cache_clear(model: str = "") -> bool:
     path = cache_path(model)

     if (model == "" or _path_is_in_cache(path)) and path.exists():
-
+        trace("Clearing cache: %s", path)
         rmtree(path)
         return True

@@ -351,24 +357,24 @@ def cache_list_expired(filter_by: list[str] = []) -> list[Path]:
        # "../../foo/bar") but we don't want to search the entire cache
        return []

-
+    trace("Filtering by paths: %s", filter_by_paths)
    for dirpath, _dirnames, filenames in os.walk(cache_path()):
        if filter_by_paths and Path(dirpath) not in filter_by_paths:
-
+            trace("Skipping path %s", dirpath)
            continue

-
+        trace("Checking dirpath %s", dirpath)
        for filename in filenames:
            path = Path(dirpath) / filename
-
+            trace("Checking path %s", path)
            try:
                with open(path, "rb") as f:
                    expiry, _cache_entry = pickle.load(f)
                    if _is_expired(expiry):
-
+                        trace("Expired cache entry found: %s (%s)", path, expiry)
                        expired_cache_entries.append(path)
            except Exception as e:
-
+                trace("Failed to load cached item %s: %s", path, e)
                continue

    return expired_cache_entries

@@ -389,8 +395,8 @@ def cache_prune(files: list[Path] = []) -> None:
            with open(file, "rb") as f:
                expiry, _cache_entry = pickle.load(f)
                if _is_expired(expiry):
-
+                    trace("Pruning expired cache: %s", file)
                    file.unlink(missing_ok=True)
        except Exception as e:
-
+            trace("Failed to prune cache %s: %s", file, e)
            continue
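Note: the module-level trace() helper above — a thin wrapper that forwards to trace_message() with a "Cache" category — is a pattern this release applies in several modules. A minimal illustrative sketch (not the library's implementation; only trace()'s signature comes from the diff):

import logging
from typing import Any

logger = logging.getLogger(__name__)


def trace(msg: str, *args: Any) -> None:
    # prefix a category label; %-style args are interpolated lazily by
    # the logging framework, only when the record is actually emitted
    logger.debug(f"Cache: {msg}", *args)


trace("Storing in cache: %s (expires: %s)", "/path/to/entry", "1W")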
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,6 +1,7 @@
 import asyncio
 import inspect
 from dataclasses import is_dataclass
+from logging import getLogger
 from textwrap import dedent
 from typing import (
     Any,
@@ -19,7 +20,9 @@ from jsonschema import Draft7Validator
 from pydantic import BaseModel

 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._trace import trace_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import (
@@ -35,6 +38,8 @@ from inspect_ai.util import OutputLimitExceededError
 from ._chat_message import ChatMessageAssistant, ChatMessageTool
 from ._generate_config import active_generate_config

+logger = getLogger(__name__)
+

 async def call_tools(
     message: ChatMessageAssistant,
@@ -215,7 +220,10 @@ async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
     arguments = tool_params(call.arguments, tool_def.tool)

     # call the tool
-
+    with trace_action(
+        logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
+    ):
+        result = await tool_def.tool(**arguments)

     # return result
     return result
inspect_ai/model/_generate_config.py
CHANGED
@@ -58,7 +58,7 @@ class GenerateConfigArgs(TypedDict, total=False):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, and TogetherAI only."""

     logprobs: bool | None
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, and
+    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""

     top_logprobs: int | None
     """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, and Huggingface only."""
@@ -72,6 +72,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     cache_prompt: Literal["auto"] | bool | None
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+    reasoning_effort: Literal["low", "medium", "high"] | None
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+

 class GenerateConfig(BaseModel):
     """Base class for model generation configs."""
@@ -125,7 +128,7 @@ class GenerateConfig(BaseModel):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and vLLM only."""

     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, and vLLM only."""
+    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""

     top_logprobs: int | None = Field(default=None)
     """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, Huggingface, and vLLM only."""
@@ -139,6 +142,9 @@ class GenerateConfig(BaseModel):
     cache_prompt: Literal["auto"] | bool | None = Field(default=None)
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""

+    reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
     def merge(
         self, other: Union["GenerateConfig", GenerateConfigArgs]
     ) -> "GenerateConfig":
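Note: a usage sketch for the new reasoning_effort option. GenerateConfig and the field are from the diff above; the get_model() call follows inspect_ai's public API, but verify against your installed version:

from inspect_ai.model import GenerateConfig, get_model

# reasoning_effort is only forwarded for full o1 models (see the openai.py
# provider changes below)
model = get_model("openai/o1", config=GenerateConfig(reasoning_effort="high"))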
inspect_ai/model/_model.py
CHANGED
@@ -9,7 +9,6 @@ from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, Callable, Literal, Type, cast

-from shortuuid import uuid
 from tenacity import (
     retry,
     retry_if_exception,
@@ -30,6 +29,7 @@ from inspect_ai._util.registry import (
     registry_unqualified_name,
 )
 from inspect_ai._util.retry import log_rate_limit_retry
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
@@ -363,17 +363,16 @@ class Model:
                 cache="write" if cache else None,
             )

-
-
-
-
-
-
-
-
-
-
-        logger.debug(f"model generate {generate_id} (completed)")
+        with trace_action(logger, "Model", f"generate ({str(self)})"):
+            time_start = time.perf_counter()
+            result = await self.api.generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+            )
+            time_elapsed = time.perf_counter() - time_start
+
         if isinstance(result, tuple):
             output, call = result
         else:
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -236,15 +236,21 @@ class BedrockAPI(ModelAPI):
         self,
         model_name: str,
         base_url: str | None,
+        api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
         super().__init__(
             model_name=model_name,
             base_url=model_base_url(base_url, "BEDROCK_BASE_URL"),
+            api_key=api_key,
+            api_key_vars=[],
             config=config,
         )

+        # save model_args
+        self.model_args = model_args
+
         # import aioboto3 on demand
         try:
             import aioboto3
@@ -263,6 +269,9 @@ class BedrockAPI(ModelAPI):

     @override
     def max_tokens(self) -> int | None:
+        if "llama3-70" in self.model_name or "llama3-8" in self.model_name:
+            return 2048
+
         if "llama3" in self.model_name or "claude3" in self.model_name:
             return 4096

@@ -303,7 +312,7 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError

         # The bedrock client
-        async with self.session.client(
+        async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
             config=Config(
@@ -316,6 +325,7 @@ class BedrockAPI(ModelAPI):
                     mode="adaptive",
                 ),
             ),
+            **self.model_args,
         ) as client:
             # Process the tools
             resolved_tools = converse_tools(tools)
@@ -658,6 +668,8 @@ def converse_image_type(type: str) -> ConverseImageFormat:
             return "png"
         case "image/webp":
             return "webp"
+        case "image/jpeg":
+            return "jpeg"
         case _:
             raise ValueError(
                 f"Image mime type {type} is not supported for Bedrock Converse models."
@@ -673,7 +685,11 @@ def converse_tools(tools: list[ToolInfo]) -> list[ConverseTool] | None:
         tool_spec = ConverseToolSpec(
             name=tool.name,
             description=tool.description,
-            inputSchema={
+            inputSchema={
+                "json": tool.parameters.model_dump(
+                    exclude_none=True, exclude={"additionalProperties"}
+                )
+            },
         )
         result.append(ConverseTool(toolSpec=tool_spec))
     return result
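Note: with model_args now saved and splatted into session.client(...), extra boto3 client options can be passed at model construction. A hedged sketch — region_name is a standard boto3 client parameter, and the model id is shown for illustration only:

from inspect_ai.model import get_model

# **model_args flows through BedrockAPI.__init__ into session.client(...)
model = get_model(
    "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
    region_name="us-east-1",
)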
inspect_ai/model/_providers/openai.py
CHANGED
@@ -18,6 +18,7 @@ from openai.types.chat import (
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
     ChatCompletionMessageToolCallParam,
@@ -141,6 +142,18 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )

+    def is_o1(self) -> bool:
+        return self.model_name.startswith("o1")
+
+    def is_o1_full(self) -> bool:
+        return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+
+    def is_o1_mini(self) -> bool:
+        return self.model_name.startswith("o1-mini")
+
+    def is_o1_preview(self) -> bool:
+        return self.model_name.startswith("o1-preview")
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -148,8 +161,8 @@ class OpenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
-        # short-circuit to call o1-
-        if self.
+        # short-circuit to call o1- models that are text only
+        if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
                 client=self.client,
                 input=input,
@@ -179,7 +192,7 @@ class OpenAIAPI(ModelAPI):

         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await as_openai_chat_messages(input),
+            messages=await as_openai_chat_messages(input, self.is_o1_full()),
             tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
             tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
             **self.completion_params(config, len(tools) > 0),
@@ -271,8 +284,10 @@ class OpenAIAPI(ModelAPI):
             params["logprobs"] = config.logprobs
         if config.top_logprobs is not None:
             params["top_logprobs"] = config.top_logprobs
-        if tools and config.parallel_tool_calls is not None:
+        if tools and config.parallel_tool_calls is not None and not self.is_o1():
             params["parallel_tool_calls"] = config.parallel_tool_calls
+        if config.reasoning_effort is not None and self.is_o1_full():
+            params["reasoning_effort"] = config.reasoning_effort

         return params

@@ -291,14 +306,23 @@ class OpenAIAPI(ModelAPI):


 async def as_openai_chat_messages(
-    messages: list[ChatMessage],
+    messages: list[ChatMessage], o1_full: bool
 ) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message) for message in messages]
+    return [await openai_chat_message(message, o1_full) for message in messages]


-async def openai_chat_message(
+async def openai_chat_message(
+    message: ChatMessage, o1_full: bool
+) -> ChatCompletionMessageParam:
     if message.role == "system":
-
+        if o1_full:
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
     elif message.role == "user":
         return ChatCompletionUserMessageParam(
             role=message.role,
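Note: the prefix predicates partition model names as follows (a standalone restatement of the methods above, runnable outside the class):

def is_o1(name: str) -> bool:
    return name.startswith("o1")


def is_o1_mini(name: str) -> bool:
    return name.startswith("o1-mini")


def is_o1_preview(name: str) -> bool:
    return name.startswith("o1-preview")


def is_o1_full(name: str) -> bool:
    return is_o1(name) and not is_o1_mini(name) and not is_o1_preview(name)


assert is_o1_full("o1")           # full o1: system messages sent as "developer"
assert not is_o1_full("o1-mini")  # o1-mini/o1-preview: text-only short-circuit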
inspect_ai/tool/_tools/_web_browser/_web_browser.py
CHANGED
@@ -362,7 +362,7 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
     else:
         arg_list = ["python3", WEB_CLIENT_REQUEST, cmd] + list(args)

-    result = await sandbox_env.exec(arg_list)
+    result = await sandbox_env.exec(arg_list, timeout=180)
     if not result.success:
         raise RuntimeError(
             f"Error executing web browser command {cmd}({', '.join(args)}): {result.stderr}"
inspect_ai/util/_sandbox/context.py
CHANGED
@@ -109,7 +109,7 @@ def raise_no_sandbox() -> NoReturn:


 async def init_sandbox_environments_sample(
-
+    sandboxenv_type: type[SandboxEnvironment],
     task_name: str,
     config: SandboxEnvironmentConfigType | None,
     files: dict[str, bytes],
@@ -117,7 +117,6 @@ async def init_sandbox_environments_sample(
     metadata: dict[str, Any],
 ) -> dict[str, SandboxEnvironment]:
     # get setup and cleanup functions
-    sandboxenv_type = registry_find_sandboxenv(type)
     sample_init = cast(SampleInit, getattr(sandboxenv_type, "sample_init"))
     sample_cleanup = cast(SampleCleanup, getattr(sandboxenv_type, "sample_cleanup"))

@@ -192,7 +191,12 @@ async def setup_sandbox_environment(

     # chmod, execute, and remove
     async def exec(cmd: list[str]) -> None:
-
+        try:
+            result = await env.exec(cmd, timeout=30)
+        except TimeoutError:
+            raise RuntimeError(
+                f"Timed out executing command {' '.join(cmd)} in sandbox"
+            )

         if not result.success:
             raise RuntimeError(
inspect_ai/util/_sandbox/docker/compose.py
CHANGED
@@ -16,7 +16,7 @@ from .prereqs import (
     DOCKER_COMPOSE_REQUIRED_VERSION_PULL_POLICY,
     validate_docker_compose,
 )
-from .util import ComposeProject, is_inspect_project
+from .util import ComposeProject, is_inspect_project

 logger = getLogger(__name__)

@@ -31,7 +31,9 @@ async def compose_up(project: ComposeProject) -> None:
         project=project,
     )
     if not result.success:
-        msg =
+        msg = (
+            f"Failed to start docker services for {project.config}: " f"{result.stderr}"
+        )
         raise RuntimeError(msg)


@@ -94,7 +96,10 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
         for running_service in running_services:
             unhealthy_services.remove(running_service["Service"])

-        msg =
+        msg = (
+            "One or more docker containers failed to start from "
+            f"{project.config}: {','.join(unhealthy_services)}"
+        )
         raise RuntimeError(msg)
     else:
         raise RuntimeError("No services started")
@@ -152,8 +157,9 @@ async def compose_pull(

 async def compose_exec(
     command: list[str],
+    *,
     project: ComposeProject,
-    timeout: int | None
+    timeout: int | None,
     input: str | bytes | None = None,
     output_limit: int | None = None,
 ) -> ExecResult[str]:
@@ -206,7 +212,6 @@ async def compose_cleanup_images(
     cwd: str | None = None,
     timeout: int | None = None,
 ) -> None:
-    sandbox_log("Removing images")
     # List the images that would be created for this compose
     images_result = await compose_command(
         ["config", "--images"], project=project, cwd=cwd
@@ -241,10 +246,14 @@ async def compose_cleanup_images(
         logger.warning(msg)


+DEFAULT_COMPOSE_TIMEOUT = 60
+
+
 async def compose_command(
     command: list[str],
+    *,
     project: ComposeProject,
-    timeout: int | None =
+    timeout: int | None = DEFAULT_COMPOSE_TIMEOUT,
     input: str | bytes | None = None,
     cwd: str | Path | None = None,
     forward_env: bool = True,
@@ -278,16 +287,46 @@ async def compose_command(
     # build final command
     compose_command = compose_command + command

-    #
-
-
-
-
-
-
-
-
-
-
-
+    # function to run command
+    async def run_command(command_timeout: int | None) -> ExecResult[str]:
+        result = await subprocess(
+            compose_command,
+            input=input,
+            cwd=cwd,
+            env=env,
+            timeout=command_timeout,
+            capture_output=capture_output,
+            output_limit=output_limit,
+        )
+        return result
+
+    # we have observed underlying unreliability in docker compose in some linux
+    # environments on EC2 -- this exhibits in very simple commands (e.g. compose config)
+    # simply never returning. this tends to happen when we know there is a large
+    # number of commands in flight (task/sample init) so could be some sort of
+    # timing issue / race condition in the docker daemon. we've also observed that
+    # these same commands succeed if you just retry them. therefore, we add some
+    # extra resiliance by retrying commands with a timeout once. we were observing
+    # commands hanging at a rate of ~ 1/1000, so we retry up to twice (tweaking the
+    # retry time down) to make the odds of hanging vanishingly small
+
+    if timeout is not None:
+        MAX_RETRIES = 2
+        retries = 0
+        while True:
+            try:
+                command_timeout = (
+                    timeout if retries == 0 else (min(timeout, 60) // retries)
+                )
+                return await run_command(command_timeout)
+            except TimeoutError:
+                retries += 1
+                if retries <= MAX_RETRIES:
+                    logger.info(
+                        f"Retrying docker compose command: {shlex.join(compose_command)}"
+                    )
+                else:
+                    raise
+
+    else:
+        return await run_command(timeout)
inspect_ai/util/_sandbox/docker/config.py
CHANGED
@@ -2,8 +2,6 @@ import os
 from logging import getLogger
 from pathlib import Path

-import aiofiles
-
 logger = getLogger(__name__)


@@ -17,7 +15,7 @@ CONFIG_FILES = [
 DOCKERFILE = "Dockerfile"


-
+def resolve_compose_file(parent: str = "") -> str:
     # existing compose file provides all the config we need
     compose = find_compose_file(parent)
     if compose is not None:
@@ -29,11 +27,11 @@ async def resolve_compose_file(parent: str = "") -> str:

     # dockerfile just needs a compose.yaml synthesized
     elif has_dockerfile(parent):
-        return
+        return auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)

     # otherwise provide a generic python container
     else:
-        return
+        return auto_compose_file(COMPOSE_GENERIC_YAML, parent)


 def find_compose_file(parent: str = "") -> str | None:
@@ -59,9 +57,9 @@ def is_auto_compose_file(file: str) -> bool:
     return os.path.basename(file) == AUTO_COMPOSE_YAML


-
+def ensure_auto_compose_file(file: str | None) -> None:
     if file is not None and is_auto_compose_file(file) and not os.path.exists(file):
-
+        resolve_compose_file(os.path.dirname(file))


 def safe_cleanup_auto_compose(file: str | None) -> None:
@@ -100,8 +98,8 @@ services:
 """


-
+def auto_compose_file(contents: str, parent: str = "") -> str:
     path = os.path.join(parent, AUTO_COMPOSE_YAML)
-
-
+    with open(path, "w", encoding="utf-8") as f:
+        f.write(contents)
     return Path(path).resolve().as_posix()