inspect-ai 0.3.52__py3-none-any.whl → 0.3.54__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. inspect_ai/_cli/eval.py +55 -1
  2. inspect_ai/_cli/main.py +2 -0
  3. inspect_ai/_cli/trace.py +244 -0
  4. inspect_ai/_display/core/progress.py +9 -3
  5. inspect_ai/_display/core/results.py +8 -4
  6. inspect_ai/_display/textual/app.py +5 -1
  7. inspect_ai/_display/textual/widgets/task_detail.py +3 -0
  8. inspect_ai/_display/textual/widgets/tasks.py +97 -6
  9. inspect_ai/_eval/eval.py +33 -0
  10. inspect_ai/_eval/evalset.py +4 -0
  11. inspect_ai/_eval/registry.py +2 -2
  12. inspect_ai/_eval/task/images.py +4 -14
  13. inspect_ai/_eval/task/results.py +22 -4
  14. inspect_ai/_eval/task/run.py +40 -20
  15. inspect_ai/_eval/task/sandbox.py +72 -43
  16. inspect_ai/_eval/task/task.py +4 -0
  17. inspect_ai/_eval/task/util.py +2 -0
  18. inspect_ai/_util/constants.py +3 -3
  19. inspect_ai/_util/display.py +1 -0
  20. inspect_ai/_util/logger.py +34 -8
  21. inspect_ai/_util/trace.py +275 -0
  22. inspect_ai/_view/www/App.css +13 -0
  23. inspect_ai/_view/www/dist/assets/index.css +13 -0
  24. inspect_ai/_view/www/dist/assets/index.js +80 -43
  25. inspect_ai/_view/www/src/App.mjs +31 -6
  26. inspect_ai/_view/www/src/Types.mjs +6 -0
  27. inspect_ai/_view/www/src/components/JsonPanel.mjs +11 -17
  28. inspect_ai/_view/www/src/components/MessageContent.mjs +9 -2
  29. inspect_ai/_view/www/src/components/Tools.mjs +46 -18
  30. inspect_ai/_view/www/src/navbar/Navbar.mjs +12 -0
  31. inspect_ai/_view/www/src/samples/SampleList.mjs +2 -2
  32. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +2 -2
  33. inspect_ai/log/_log.py +6 -0
  34. inspect_ai/log/_message.py +2 -2
  35. inspect_ai/log/_recorders/eval.py +8 -18
  36. inspect_ai/log/_recorders/json.py +19 -17
  37. inspect_ai/model/_cache.py +22 -16
  38. inspect_ai/model/_call_tools.py +9 -1
  39. inspect_ai/model/_generate_config.py +8 -2
  40. inspect_ai/model/_model.py +11 -12
  41. inspect_ai/model/_providers/azureai.py +1 -1
  42. inspect_ai/model/_providers/bedrock.py +18 -2
  43. inspect_ai/model/_providers/hf.py +1 -1
  44. inspect_ai/model/_providers/openai.py +32 -8
  45. inspect_ai/model/_providers/providers.py +1 -1
  46. inspect_ai/model/_providers/vllm.py +1 -1
  47. inspect_ai/tool/_tools/_web_browser/_web_browser.py +1 -1
  48. inspect_ai/util/_sandbox/context.py +7 -3
  49. inspect_ai/util/_sandbox/docker/compose.py +58 -19
  50. inspect_ai/util/_sandbox/docker/config.py +8 -10
  51. inspect_ai/util/_sandbox/docker/docker.py +20 -16
  52. inspect_ai/util/_sandbox/docker/util.py +3 -9
  53. inspect_ai/util/_sandbox/environment.py +7 -2
  54. inspect_ai/util/_sandbox/limits.py +1 -1
  55. inspect_ai/util/_sandbox/local.py +8 -9
  56. inspect_ai/util/_sandbox/service.py +17 -7
  57. inspect_ai/util/_subprocess.py +6 -1
  58. inspect_ai/util/_subtask.py +8 -2
  59. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/METADATA +6 -8
  60. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/RECORD +64 -62
  61. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/LICENSE +0 -0
  62. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/WHEEL +0 -0
  63. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/entry_points.txt +0 -0
  64. {inspect_ai-0.3.52.dist-info → inspect_ai-0.3.54.dist-info}/top_level.txt +0 -0
inspect_ai/model/_cache.py

@@ -6,10 +6,12 @@ from datetime import datetime, timezone
 from hashlib import md5
 from pathlib import Path
 from shutil import rmtree
+from typing import Any
 
 from dateutil.relativedelta import relativedelta
 
 from inspect_ai._util.appdirs import inspect_cache_dir
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolChoice, ToolInfo
 
 from ._chat_message import ChatMessage
@@ -19,6 +21,10 @@ from ._model_output import ModelOutput
 logger = logging.getLogger(__name__)
 
 
+def trace(msg: str, *args: Any) -> None:
+    trace_message(logger, "Cache", msg, args)
+
+
 def _path_is_in_cache(path: Path | str) -> bool:
     """This ensures the path is in our cache directory, just in case the `model` is ../../../home/ubuntu/maliciousness"""
     if isinstance(path, str):
@@ -153,7 +159,7 @@ def _cache_key(entry: CacheEntry) -> str:
 
     base_string = "|".join([str(component) for component in components])
 
-    logger.debug(_cache_key_debug_string([str(component) for component in components]))
+    trace(_cache_key_debug_string([str(component) for component in components]))
 
     return md5(base_string.encode("utf-8")).hexdigest()
 
@@ -192,11 +198,11 @@ def cache_store(
 
         with open(filename, "wb") as f:
             expiry = _cache_expiry(entry.policy)
-            logger.debug("Storing in cache: %s (expires: %s)", filename, expiry)
+            trace("Storing in cache: %s (expires: %s)", filename, expiry)
             pickle.dump((expiry, output), f)
             return True
     except Exception as e:
-        logger.debug(f"Failed to cache {filename}: {e}")
+        trace(f"Failed to cache {filename}: {e}")
         return False
 
 
@@ -204,12 +210,12 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
     """Fetch a value from the cache directory."""
     filename = cache_path(model=entry.model) / _cache_key(entry)
     try:
-        logger.debug("Fetching from cache: %s", filename)
+        trace("Fetching from cache: %s", filename)
 
         with open(filename, "rb") as f:
             expiry, output = pickle.load(f)
             if not isinstance(output, ModelOutput):
-                logger.debug(
+                trace(
                     "Unexpected cached type, can only fetch ModelOutput: %s (%s)",
                     type(output),
                     filename,
@@ -217,7 +223,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
                 return None
 
             if _is_expired(expiry):
-                logger.debug("Cache expired for %s (%s)", filename, expiry)
+                trace("Cache expired for %s (%s)", filename, expiry)
                 # If it's expired, no point keeping it as we'll never access it
                 # successfully again.
                 filename.unlink(missing_ok=True)
@@ -225,7 +231,7 @@ def cache_fetch(entry: CacheEntry) -> ModelOutput | None:
 
             return output
     except Exception as e:
-        logger.debug(f"Failed to fetch from cache {filename}: {e}")
+        trace(f"Failed to fetch from cache {filename}: {e}")
         return None
 
 
@@ -235,7 +241,7 @@ def cache_clear(model: str = "") -> bool:
     path = cache_path(model)
 
     if (model == "" or _path_is_in_cache(path)) and path.exists():
-        logger.debug("Clearing cache: %s", path)
+        trace("Clearing cache: %s", path)
         rmtree(path)
         return True
 
@@ -351,24 +357,24 @@ def cache_list_expired(filter_by: list[str] = []) -> list[Path]:
         # "../../foo/bar") but we don't want to search the entire cache
         return []
 
-    logger.debug("Filtering by paths: %s", filter_by_paths)
+    trace("Filtering by paths: %s", filter_by_paths)
    for dirpath, _dirnames, filenames in os.walk(cache_path()):
        if filter_by_paths and Path(dirpath) not in filter_by_paths:
-            logger.debug("Skipping path %s", dirpath)
+            trace("Skipping path %s", dirpath)
            continue
 
-        logger.debug("Checking dirpath %s", dirpath)
+        trace("Checking dirpath %s", dirpath)
        for filename in filenames:
            path = Path(dirpath) / filename
-            logger.debug("Checking path %s", path)
+            trace("Checking path %s", path)
            try:
                with open(path, "rb") as f:
                    expiry, _cache_entry = pickle.load(f)
                    if _is_expired(expiry):
-                        logger.debug("Expired cache entry found: %s (%s)", path, expiry)
+                        trace("Expired cache entry found: %s (%s)", path, expiry)
                        expired_cache_entries.append(path)
            except Exception as e:
-                logger.debug("Failed to load cached item %s: %s", path, e)
+                trace("Failed to load cached item %s: %s", path, e)
                continue
 
    return expired_cache_entries
@@ -389,8 +395,8 @@ def cache_prune(files: list[Path] = []) -> None:
            with open(file, "rb") as f:
                expiry, _cache_entry = pickle.load(f)
                if _is_expired(expiry):
-                    logger.debug("Pruning expired cache: %s", file)
+                    trace("Pruning expired cache: %s", file)
                    file.unlink(missing_ok=True)
        except Exception as e:
-            logger.debug("Failed to prune cache %s: %s", file, e)
+            trace("Failed to prune cache %s: %s", file, e)
            continue
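
The change above introduces a small module-level trace() helper that forwards to the new trace_message() utility with a fixed "Cache" category, replacing the previous logger.debug() calls. A minimal sketch of the same convention applied to a hypothetical module, based only on the call signature visible in this diff:

    import logging

    from inspect_ai._util.trace import trace_message

    logger = logging.getLogger(__name__)


    # hypothetical helper mirroring the one added to _cache.py;
    # "Example" is an illustrative category name, not one used by the package
    def trace(msg: str, *args: object) -> None:
        trace_message(logger, "Example", msg, args)


    trace("processed %d files under %s", 3, "some/dir")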
inspect_ai/model/_call_tools.py

@@ -1,6 +1,7 @@
 import asyncio
 import inspect
 from dataclasses import is_dataclass
+from logging import getLogger
 from textwrap import dedent
 from typing import (
     Any,
@@ -19,7 +20,9 @@ from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
 from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
+from inspect_ai._util.trace import trace_action
 from inspect_ai.model._trace import trace_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import (
@@ -35,6 +38,8 @@ from inspect_ai.util import OutputLimitExceededError
 from ._chat_message import ChatMessageAssistant, ChatMessageTool
 from ._generate_config import active_generate_config
 
+logger = getLogger(__name__)
+
 
 async def call_tools(
     message: ChatMessageAssistant,
@@ -215,7 +220,10 @@ async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
     arguments = tool_params(call.arguments, tool_def.tool)
 
     # call the tool
-    result = await tool_def.tool(**arguments)
+    with trace_action(
+        logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
+    ):
+        result = await tool_def.tool(**arguments)
 
     # return result
     return result
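
Here trace_action() is used as an ordinary context manager wrapped around the awaited tool call, so the call is bracketed by trace events. A minimal sketch of that pattern as it appears at this call site (run_tool and its arguments are hypothetical; the internals of trace_action are not part of this diff):

    from logging import getLogger
    from typing import Any

    from inspect_ai._util.trace import trace_action

    logger = getLogger(__name__)


    async def run_tool(tool: Any, name: str, arguments: dict[str, Any]) -> Any:
        # bracket the call so its start, completion, and any raised error
        # are recorded under the "Tool Call" trace category
        with trace_action(logger, "Tool Call", f"{name}({arguments})"):
            return await tool(**arguments)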
inspect_ai/model/_generate_config.py

@@ -58,7 +58,7 @@ class GenerateConfigArgs(TypedDict, total=False):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, and TogetherAI only."""
 
     logprobs: bool | None
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, and Huggingface only."""
+    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None
     """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, and Huggingface only."""
@@ -72,6 +72,9 @@ class GenerateConfigArgs(TypedDict, total=False):
     cache_prompt: Literal["auto"] | bool | None
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""
 
+    reasoning_effort: Literal["low", "medium", "high"] | None
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
 
 class GenerateConfig(BaseModel):
     """Base class for model generation configs."""
@@ -125,7 +128,7 @@ class GenerateConfig(BaseModel):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and vLLM only."""
 
     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, and vLLM only."""
+    """Return log probabilities of the output tokens. OpenAI, Google, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None = Field(default=None)
     """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Google, Grok, Huggingface, and vLLM only."""
@@ -139,6 +142,9 @@ class GenerateConfig(BaseModel):
     cache_prompt: Literal["auto"] | bool | None = Field(default=None)
     """Whether to cache the prompt prefix. Defaults to "auto", which will enable caching for requests with tools. Anthropic only."""
 
+    reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None)
+    """Constrains effort on reasoning for reasoning models. Open AI o1 models only."""
+
     def merge(
         self, other: Union["GenerateConfig", GenerateConfigArgs]
     ) -> "GenerateConfig":
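
Since reasoning_effort is an ordinary GenerateConfig field, it should be usable like any other generation option once this version is installed; a minimal sketch (model name chosen for illustration; the option is only forwarded for OpenAI o1 models, per openai.py below):

    from inspect_ai.model import GenerateConfig, get_model

    model = get_model("openai/o1", config=GenerateConfig(reasoning_effort="high"))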
inspect_ai/model/_model.py

@@ -9,7 +9,6 @@ from contextvars import ContextVar
 from copy import deepcopy
 from typing import Any, Callable, Literal, Type, cast
 
-from shortuuid import uuid
 from tenacity import (
     retry,
     retry_if_exception,
@@ -30,6 +29,7 @@ from inspect_ai._util.registry import (
     registry_unqualified_name,
 )
 from inspect_ai._util.retry import log_rate_limit_retry
+from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
@@ -363,17 +363,16 @@ class Model:
             cache="write" if cache else None,
         )
 
-        generate_id = uuid()
-        logger.debug(f"model generate {generate_id} ({str(self)})")
-        time_start = time.perf_counter()
-        result = await self.api.generate(
-            input=input,
-            tools=tools,
-            tool_choice=tool_choice,
-            config=config,
-        )
-        time_elapsed = time.perf_counter() - time_start
-        logger.debug(f"model generate {generate_id} (completed)")
+        with trace_action(logger, "Model", f"generate ({str(self)})"):
+            time_start = time.perf_counter()
+            result = await self.api.generate(
+                input=input,
+                tools=tools,
+                tool_choice=tool_choice,
+                config=config,
+            )
+            time_elapsed = time.perf_counter() - time_start
+
         if isinstance(result, tuple):
             output, call = result
         else:
inspect_ai/model/_providers/azureai.py

@@ -93,7 +93,7 @@ class AzureAIAPI(ModelAPI):
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
             value = model_args.get(name, None)
-            if value:
+            if value is not None:
                 model_args.pop(name)
             return value
 
inspect_ai/model/_providers/bedrock.py

@@ -236,15 +236,21 @@ class BedrockAPI(ModelAPI):
         self,
         model_name: str,
         base_url: str | None,
+        api_key: str | None = None,
         config: GenerateConfig = GenerateConfig(),
         **model_args: Any,
     ):
         super().__init__(
             model_name=model_name,
             base_url=model_base_url(base_url, "BEDROCK_BASE_URL"),
+            api_key=api_key,
+            api_key_vars=[],
             config=config,
         )
 
+        # save model_args
+        self.model_args = model_args
+
         # import aioboto3 on demand
         try:
             import aioboto3
@@ -263,6 +269,9 @@ class BedrockAPI(ModelAPI):
 
     @override
     def max_tokens(self) -> int | None:
+        if "llama3-70" in self.model_name or "llama3-8" in self.model_name:
+            return 2048
+
         if "llama3" in self.model_name or "claude3" in self.model_name:
             return 4096
 
@@ -303,7 +312,7 @@ class BedrockAPI(ModelAPI):
         from botocore.exceptions import ClientError
 
         # The bedrock client
-        async with self.session.client(
+        async with self.session.client(  # type: ignore[call-overload]
             service_name="bedrock-runtime",
             endpoint_url=self.base_url,
             config=Config(
@@ -316,6 +325,7 @@ class BedrockAPI(ModelAPI):
                     mode="adaptive",
                 ),
             ),
+            **self.model_args,
         ) as client:
             # Process the tools
             resolved_tools = converse_tools(tools)
@@ -658,6 +668,8 @@ def converse_image_type(type: str) -> ConverseImageFormat:
             return "png"
         case "image/webp":
             return "webp"
+        case "image/jpeg":
+            return "jpeg"
         case _:
             raise ValueError(
                 f"Image mime type {type} is not supported for Bedrock Converse models."
@@ -673,7 +685,11 @@ def converse_tools(tools: list[ToolInfo]) -> list[ConverseTool] | None:
         tool_spec = ConverseToolSpec(
             name=tool.name,
             description=tool.description,
-            inputSchema={"json": tool.parameters.model_dump(exclude_none=True)},
+            inputSchema={
+                "json": tool.parameters.model_dump(
+                    exclude_none=True, exclude={"additionalProperties"}
+                )
+            },
         )
         result.append(ConverseTool(toolSpec=tool_spec))
     return result
inspect_ai/model/_providers/hf.py

@@ -64,7 +64,7 @@ class HuggingFaceAPI(ModelAPI):
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
             value = model_args.get(name, None)
-            if value:
+            if value is not None:
                 model_args.pop(name)
             return value
 
inspect_ai/model/_providers/openai.py

@@ -18,6 +18,7 @@ from openai.types.chat import (
     ChatCompletionContentPartImageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
+    ChatCompletionDeveloperMessageParam,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
     ChatCompletionMessageToolCallParam,
@@ -141,6 +142,18 @@ class OpenAIAPI(ModelAPI):
             **model_args,
         )
 
+    def is_o1(self) -> bool:
+        return self.model_name.startswith("o1")
+
+    def is_o1_full(self) -> bool:
+        return self.is_o1() and not self.is_o1_mini() and not self.is_o1_preview()
+
+    def is_o1_mini(self) -> bool:
+        return self.model_name.startswith("o1-mini")
+
+    def is_o1_preview(self) -> bool:
+        return self.model_name.startswith("o1-preview")
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -148,8 +161,8 @@ class OpenAIAPI(ModelAPI):
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
-        # short-circuit to call o1- model
-        if self.model_name.startswith("o1-"):
+        # short-circuit to call o1- models that are text only
+        if self.is_o1_preview() or self.is_o1_mini():
             return await generate_o1(
                 client=self.client,
                 input=input,
@@ -179,7 +192,7 @@ class OpenAIAPI(ModelAPI):
 
         # prepare request (we do this so we can log the ModelCall)
         request = dict(
-            messages=await as_openai_chat_messages(input),
+            messages=await as_openai_chat_messages(input, self.is_o1_full()),
             tools=chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
             tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
             **self.completion_params(config, len(tools) > 0),
@@ -271,8 +284,10 @@ class OpenAIAPI(ModelAPI):
             params["logprobs"] = config.logprobs
         if config.top_logprobs is not None:
             params["top_logprobs"] = config.top_logprobs
-        if tools and config.parallel_tool_calls is not None:
+        if tools and config.parallel_tool_calls is not None and not self.is_o1():
             params["parallel_tool_calls"] = config.parallel_tool_calls
+        if config.reasoning_effort is not None and self.is_o1_full():
+            params["reasoning_effort"] = config.reasoning_effort
 
         return params
 
@@ -291,14 +306,23 @@ class OpenAIAPI(ModelAPI):
 
 
 async def as_openai_chat_messages(
-    messages: list[ChatMessage],
+    messages: list[ChatMessage], o1_full: bool
 ) -> list[ChatCompletionMessageParam]:
-    return [await openai_chat_message(message) for message in messages]
+    return [await openai_chat_message(message, o1_full) for message in messages]
 
 
-async def openai_chat_message(message: ChatMessage) -> ChatCompletionMessageParam:
+async def openai_chat_message(
+    message: ChatMessage, o1_full: bool
+) -> ChatCompletionMessageParam:
     if message.role == "system":
-        return ChatCompletionSystemMessageParam(role=message.role, content=message.text)
+        if o1_full:
+            return ChatCompletionDeveloperMessageParam(
+                role="developer", content=message.text
+            )
+        else:
+            return ChatCompletionSystemMessageParam(
+                role=message.role, content=message.text
+            )
     elif message.role == "user":
         return ChatCompletionUserMessageParam(
             role=message.role,
inspect_ai/model/_providers/providers.py

@@ -242,7 +242,7 @@ def mockllm() -> type[ModelAPI]:
 def validate_openai_client(feature: str) -> None:
     FEATURE = feature
     PACKAGE = "openai"
-    MIN_VERSION = "1.45.0"
+    MIN_VERSION = "1.58.1"
 
     # verify we have the package
     try:
inspect_ai/model/_providers/vllm.py

@@ -75,7 +75,7 @@ class VLLMAPI(ModelAPI):
         def collect_model_arg(name: str) -> Any | None:
             nonlocal model_args
             value = model_args.get(name, None)
-            if value:
+            if value is not None:
                 model_args.pop(name)
             return value
 
inspect_ai/tool/_tools/_web_browser/_web_browser.py

@@ -362,7 +362,7 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
     else:
         arg_list = ["python3", WEB_CLIENT_REQUEST, cmd] + list(args)
 
-    result = await sandbox_env.exec(arg_list)
+    result = await sandbox_env.exec(arg_list, timeout=180)
     if not result.success:
         raise RuntimeError(
             f"Error executing web browser command {cmd}({', '.join(args)}): {result.stderr}"
inspect_ai/util/_sandbox/context.py

@@ -109,7 +109,7 @@ def raise_no_sandbox() -> NoReturn:
 
 
 async def init_sandbox_environments_sample(
-    type: str,
+    sandboxenv_type: type[SandboxEnvironment],
     task_name: str,
     config: SandboxEnvironmentConfigType | None,
     files: dict[str, bytes],
@@ -117,7 +117,6 @@ async def init_sandbox_environments_sample(
     metadata: dict[str, Any],
 ) -> dict[str, SandboxEnvironment]:
     # get setup and cleanup functions
-    sandboxenv_type = registry_find_sandboxenv(type)
     sample_init = cast(SampleInit, getattr(sandboxenv_type, "sample_init"))
     sample_cleanup = cast(SampleCleanup, getattr(sandboxenv_type, "sample_cleanup"))
 
@@ -192,7 +191,12 @@ async def setup_sandbox_environment(
 
     # chmod, execute, and remove
     async def exec(cmd: list[str]) -> None:
-        result = await env.exec(cmd)
+        try:
+            result = await env.exec(cmd, timeout=30)
+        except TimeoutError:
+            raise RuntimeError(
+                f"Timed out executing command {' '.join(cmd)} in sandbox"
+            )
 
         if not result.success:
             raise RuntimeError(
inspect_ai/util/_sandbox/docker/compose.py

@@ -16,7 +16,7 @@ from .prereqs import (
     DOCKER_COMPOSE_REQUIRED_VERSION_PULL_POLICY,
     validate_docker_compose,
 )
-from .util import ComposeProject, is_inspect_project, sandbox_log
+from .util import ComposeProject, is_inspect_project
 
 logger = getLogger(__name__)
 
@@ -31,7 +31,9 @@ async def compose_up(project: ComposeProject) -> None:
         project=project,
     )
     if not result.success:
-        msg = f"Failed to start docker services {result.stderr}"
+        msg = (
+            f"Failed to start docker services for {project.config}: " f"{result.stderr}"
+        )
         raise RuntimeError(msg)
 
 
@@ -94,7 +96,10 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
         for running_service in running_services:
             unhealthy_services.remove(running_service["Service"])
 
-        msg = f"One or more docker containers failed to start {','.join(unhealthy_services)}"
+        msg = (
+            "One or more docker containers failed to start from "
+            f"{project.config}: {','.join(unhealthy_services)}"
+        )
         raise RuntimeError(msg)
     else:
         raise RuntimeError("No services started")
@@ -152,8 +157,9 @@ async def compose_pull(
 
 async def compose_exec(
     command: list[str],
+    *,
     project: ComposeProject,
-    timeout: int | None = None,
+    timeout: int | None,
     input: str | bytes | None = None,
     output_limit: int | None = None,
 ) -> ExecResult[str]:
@@ -206,7 +212,6 @@ async def compose_cleanup_images(
     cwd: str | None = None,
     timeout: int | None = None,
 ) -> None:
-    sandbox_log("Removing images")
     # List the images that would be created for this compose
     images_result = await compose_command(
         ["config", "--images"], project=project, cwd=cwd
@@ -241,10 +246,14 @@ async def compose_cleanup_images(
         logger.warning(msg)
 
 
+DEFAULT_COMPOSE_TIMEOUT = 60
+
+
 async def compose_command(
     command: list[str],
+    *,
     project: ComposeProject,
-    timeout: int | None = None,
+    timeout: int | None = DEFAULT_COMPOSE_TIMEOUT,
     input: str | bytes | None = None,
     cwd: str | Path | None = None,
     forward_env: bool = True,
@@ -278,16 +287,46 @@ async def compose_command(
     # build final command
     compose_command = compose_command + command
 
-    # Execute the command
-    sandbox_log(f"compose command: {shlex.join(compose_command)}")
-    result = await subprocess(
-        compose_command,
-        input=input,
-        cwd=cwd,
-        env=env,
-        timeout=timeout,
-        capture_output=capture_output,
-        output_limit=output_limit,
-    )
-    sandbox_log(f"compose command completed: {shlex.join(compose_command)}")
-    return result
+    # function to run command
+    async def run_command(command_timeout: int | None) -> ExecResult[str]:
+        result = await subprocess(
+            compose_command,
+            input=input,
+            cwd=cwd,
+            env=env,
+            timeout=command_timeout,
+            capture_output=capture_output,
+            output_limit=output_limit,
+        )
+        return result
+
+    # we have observed underlying unreliability in docker compose in some linux
+    # environments on EC2 -- this exhibits in very simple commands (e.g. compose config)
+    # simply never returning. this tends to happen when we know there is a large
+    # number of commands in flight (task/sample init) so could be some sort of
+    # timing issue / race condition in the docker daemon. we've also observed that
+    # these same commands succeed if you just retry them. therefore, we add some
+    # extra resiliance by retrying commands with a timeout once. we were observing
+    # commands hanging at a rate of ~ 1/1000, so we retry up to twice (tweaking the
+    # retry time down) to make the odds of hanging vanishingly small
+
+    if timeout is not None:
+        MAX_RETRIES = 2
+        retries = 0
+        while True:
+            try:
+                command_timeout = (
+                    timeout if retries == 0 else (min(timeout, 60) // retries)
+                )
+                return await run_command(command_timeout)
+            except TimeoutError:
+                retries += 1
+                if retries <= MAX_RETRIES:
+                    logger.info(
+                        f"Retrying docker compose command: {shlex.join(compose_command)}"
+                    )
+                else:
+                    raise
+
+    else:
+        return await run_command(timeout)
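
Read in isolation, the retry logic added above amounts to: run the compose command with a timeout, and on TimeoutError retry up to twice with a shortened timeout before re-raising. A stripped-down sketch of that control flow, independent of docker compose, with a generic run() coroutine factory standing in for the subprocess call:

    import asyncio
    from typing import Any, Awaitable, Callable

    MAX_RETRIES = 2


    async def run_with_retries(run: Callable[[], Awaitable[Any]], timeout: int) -> Any:
        retries = 0
        while True:
            try:
                # first attempt gets the full timeout, later attempts a shorter one
                attempt_timeout = timeout if retries == 0 else min(timeout, 60) // retries
                return await asyncio.wait_for(run(), attempt_timeout)
            except asyncio.TimeoutError:
                retries += 1
                if retries > MAX_RETRIES:
                    raise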
inspect_ai/util/_sandbox/docker/config.py

@@ -2,8 +2,6 @@ import os
 from logging import getLogger
 from pathlib import Path
 
-import aiofiles
-
 logger = getLogger(__name__)
 
 
@@ -17,7 +15,7 @@ CONFIG_FILES = [
 DOCKERFILE = "Dockerfile"
 
 
-async def resolve_compose_file(parent: str = "") -> str:
+def resolve_compose_file(parent: str = "") -> str:
     # existing compose file provides all the config we need
     compose = find_compose_file(parent)
     if compose is not None:
@@ -29,11 +27,11 @@ async def resolve_compose_file(parent: str = "") -> str:
 
     # dockerfile just needs a compose.yaml synthesized
     elif has_dockerfile(parent):
-        return await auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)
+        return auto_compose_file(COMPOSE_DOCKERFILE_YAML, parent)
 
     # otherwise provide a generic python container
     else:
-        return await auto_compose_file(COMPOSE_GENERIC_YAML, parent)
+        return auto_compose_file(COMPOSE_GENERIC_YAML, parent)
 
 
 def find_compose_file(parent: str = "") -> str | None:
@@ -59,9 +57,9 @@ def is_auto_compose_file(file: str) -> bool:
     return os.path.basename(file) == AUTO_COMPOSE_YAML
 
 
-async def ensure_auto_compose_file(file: str | None) -> None:
+def ensure_auto_compose_file(file: str | None) -> None:
     if file is not None and is_auto_compose_file(file) and not os.path.exists(file):
-        await resolve_compose_file(os.path.dirname(file))
+        resolve_compose_file(os.path.dirname(file))
 
 
 def safe_cleanup_auto_compose(file: str | None) -> None:
@@ -100,8 +98,8 @@ services:
 """
 
 
-async def auto_compose_file(contents: str, parent: str = "") -> str:
+def auto_compose_file(contents: str, parent: str = "") -> str:
     path = os.path.join(parent, AUTO_COMPOSE_YAML)
-    async with aiofiles.open(path, "w", encoding="utf-8") as f:
-        await f.write(contents)
+    with open(path, "w", encoding="utf-8") as f:
+        f.write(contents)
     return Path(path).resolve().as_posix()