inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +4 -2
  3. inspect_ai/_cli/eval.py +2 -0
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +0 -2
  6. inspect_ai/_display/core/panel.py +1 -1
  7. inspect_ai/_display/rich/display.py +4 -4
  8. inspect_ai/_display/textual/app.py +4 -1
  9. inspect_ai/_display/textual/widgets/samples.py +41 -5
  10. inspect_ai/_eval/eval.py +32 -20
  11. inspect_ai/_eval/evalset.py +7 -5
  12. inspect_ai/_eval/run.py +16 -11
  13. inspect_ai/_eval/task/__init__.py +2 -2
  14. inspect_ai/_eval/task/images.py +40 -25
  15. inspect_ai/_eval/task/run.py +141 -119
  16. inspect_ai/_eval/task/task.py +140 -25
  17. inspect_ai/_util/constants.py +1 -0
  18. inspect_ai/_util/content.py +23 -1
  19. inspect_ai/_util/datetime.py +1 -1
  20. inspect_ai/_util/deprecation.py +1 -1
  21. inspect_ai/_util/images.py +20 -17
  22. inspect_ai/_util/json.py +11 -1
  23. inspect_ai/_util/kvstore.py +73 -0
  24. inspect_ai/_util/logger.py +2 -1
  25. inspect_ai/_util/notgiven.py +18 -0
  26. inspect_ai/_util/thread.py +5 -0
  27. inspect_ai/_util/trace.py +39 -3
  28. inspect_ai/_util/transcript.py +36 -7
  29. inspect_ai/_view/www/.prettierrc.js +12 -0
  30. inspect_ai/_view/www/dist/assets/index.js +322 -226
  31. inspect_ai/_view/www/log-schema.json +221 -138
  32. inspect_ai/_view/www/src/App.mjs +18 -9
  33. inspect_ai/_view/www/src/Types.mjs +0 -1
  34. inspect_ai/_view/www/src/api/Types.mjs +15 -4
  35. inspect_ai/_view/www/src/api/api-http.mjs +2 -0
  36. inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
  37. inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
  38. inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
  39. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  40. inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
  41. inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
  42. inspect_ai/_view/www/src/components/Tools.mjs +18 -3
  43. inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
  44. inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
  45. inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
  46. inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
  47. inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
  48. inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
  49. inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
  50. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
  51. inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
  52. inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
  53. inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
  54. inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
  55. inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
  56. inspect_ai/_view/www/src/types/log.d.ts +53 -35
  57. inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
  58. inspect_ai/approval/_human/util.py +2 -2
  59. inspect_ai/dataset/_sources/csv.py +2 -1
  60. inspect_ai/dataset/_sources/json.py +2 -1
  61. inspect_ai/dataset/_sources/util.py +15 -7
  62. inspect_ai/log/_condense.py +11 -1
  63. inspect_ai/log/_log.py +27 -5
  64. inspect_ai/log/_recorders/eval.py +21 -8
  65. inspect_ai/log/_samples.py +10 -5
  66. inspect_ai/log/_transcript.py +28 -1
  67. inspect_ai/model/__init__.py +10 -2
  68. inspect_ai/model/_call_tools.py +82 -17
  69. inspect_ai/model/_chat_message.py +2 -4
  70. inspect_ai/model/{_trace.py → _conversation.py} +9 -8
  71. inspect_ai/model/_model.py +2 -2
  72. inspect_ai/model/_providers/anthropic.py +9 -7
  73. inspect_ai/model/_providers/azureai.py +6 -4
  74. inspect_ai/model/_providers/bedrock.py +6 -4
  75. inspect_ai/model/_providers/google.py +103 -14
  76. inspect_ai/model/_providers/groq.py +7 -5
  77. inspect_ai/model/_providers/hf.py +11 -6
  78. inspect_ai/model/_providers/mistral.py +6 -9
  79. inspect_ai/model/_providers/openai.py +34 -8
  80. inspect_ai/model/_providers/openai_o1.py +10 -12
  81. inspect_ai/model/_providers/vertex.py +17 -4
  82. inspect_ai/scorer/__init__.py +13 -2
  83. inspect_ai/scorer/_metrics/__init__.py +2 -2
  84. inspect_ai/scorer/_metrics/std.py +3 -3
  85. inspect_ai/tool/__init__.py +9 -1
  86. inspect_ai/tool/_tool.py +9 -2
  87. inspect_ai/tool/_tool_info.py +2 -1
  88. inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
  89. inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
  90. inspect_ai/util/__init__.py +4 -3
  91. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  92. inspect_ai/util/_display.py +14 -4
  93. inspect_ai/util/_sandbox/context.py +12 -13
  94. inspect_ai/util/_sandbox/docker/compose.py +24 -13
  95. inspect_ai/util/_sandbox/docker/docker.py +20 -13
  96. inspect_ai/util/_sandbox/docker/util.py +2 -1
  97. inspect_ai/util/_sandbox/environment.py +13 -1
  98. inspect_ai/util/_sandbox/local.py +1 -0
  99. inspect_ai/util/_sandbox/self_check.py +18 -18
  100. inspect_ai/util/_store.py +2 -2
  101. inspect_ai/util/_subprocess.py +3 -3
  102. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
  103. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
  104. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
  105. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
  106. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
  107. {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai.py CHANGED
@@ -17,6 +17,7 @@ from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import image_as_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import is_data_uri, is_http_url
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -51,6 +52,7 @@ from .._model_output import (
     Logprobs,
     ModelOutput,
     ModelUsage,
+    StopReason,
 )
 from .openai_o1 import generate_o1
 from .util import (
@@ -262,7 +264,10 @@ class OpenAIAPI(ModelAPI):
             model=self.model_name,
         )
         if config.max_tokens is not None:
-            params["max_tokens"] = config.max_tokens
+            if self.is_o1():
+                params["max_completion_tokens"] = config.max_tokens
+            else:
+                params["max_tokens"] = config.max_tokens
         if config.frequency_penalty is not None:
             params["frequency_penalty"] = config.frequency_penalty
         if config.stop_seqs is not None:
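The o1 branch above takes over the max_tokens mapping that previously lived in openai_o1.py (removed further down). A minimal sketch of the resulting behavior; pick_token_param is a hypothetical stand-in, not the provider's actual helper:

# Hypothetical stand-in illustrating the mapping above: o1 models take
# "max_completion_tokens", other OpenAI chat models take "max_tokens".
def pick_token_param(model_name: str, max_tokens: int) -> dict[str, int]:
    if model_name.startswith("o1"):
        return {"max_completion_tokens": max_tokens}
    else:
        return {"max_tokens": max_tokens}

assert pick_token_param("o1-preview", 1024) == {"max_completion_tokens": 1024}
assert pick_token_param("gpt-4o", 1024) == {"max_tokens": 1024}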
@@ -303,13 +308,23 @@ class OpenAIAPI(ModelAPI):

     # convert some well known bad request errors into ModelOutput
     def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-        if e.status_code == 400 and e.code == "context_length_exceeded":
+        if e.status_code == 400:
+            # extract message
             if isinstance(e.body, dict) and "message" in e.body.keys():
                 content = str(e.body.get("message"))
             else:
                 content = e.message
+
+            # narrow stop_reason
+            if e.code == "context_length_exceeded":
+                stop_reason: StopReason = "model_length"
+            elif e.code == "invalid_prompt":
+                stop_reason = "content_filter"
+            else:
+                stop_reason = "unknown"
+
             return ModelOutput.from_content(
-                model=self.model_name, content=content, stop_reason="model_length"
+                model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
             raise e
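The narrowing logic above, restated as a self-contained sketch (the StopReason alias here is reduced to just the values this handler produces; the library's literal has more members):

from typing import Literal

StopReason = Literal["model_length", "content_filter", "unknown"]  # reduced for illustration

def narrow_stop_reason(code: str | None) -> StopReason:
    if code == "context_length_exceeded":
        return "model_length"
    elif code == "invalid_prompt":
        return "content_filter"
    else:
        return "unknown"

assert narrow_stop_reason("context_length_exceeded") == "model_length"
assert narrow_stop_reason("anything_else") == "unknown"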
@@ -449,16 +464,27 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-    else:
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or
         # data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail

-        if not is_http_url(image_url) and not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)

         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
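With the new audio branch, audio inputs flow through to OpenAI as input_audio content parts. A sketch of supplying audio from user code, assuming the content classes are re-exported from inspect_ai.model as the model/__init__.py entry in the file list suggests ("clip.wav" is a stand-in path):

from inspect_ai.model import ChatMessageUser, ContentAudio, ContentText

message = ChatMessageUser(
    content=[
        ContentText(text="Transcribe this recording."),
        ContentAudio(audio="clip.wav", format="wav"),  # stand-in file
    ]
)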
inspect_ai/model/_providers/openai_o1.py CHANGED
@@ -25,7 +25,7 @@ from inspect_ai.model import (
 from inspect_ai.tool import ToolCall, ToolInfo

 from .._model_call import ModelCall
-from .._model_output import ModelUsage
+from .._model_output import ModelUsage, StopReason
 from .._providers.util import (
     ChatAPIHandler,
     ChatAPIMessage,
@@ -48,12 +48,6 @@ async def generate_o1(
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()

-    # map max_tokens => max_completion_tokens
-    max_tokens = params.get("max_tokens", None)
-    if max_tokens:
-        params["max_completion_tokens"] = max_tokens
-        del params["max_tokens"]
-
     # call model
     request = dict(
         model=model,
@@ -89,12 +83,16 @@ async def generate_o1(


 def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
-    if ex.code == "invalid_prompt":
-        return ModelOutput.from_content(
-            model=model, content=str(ex), stop_reason="content_filter"
-        )
+    if ex.code == "context_length_exceeded":
+        stop_reason: StopReason = "model_length"
+    elif ex.code == "invalid_prompt":
+        stop_reason = "content_filter"
     else:
-        raise ex
+        stop_reason = "unknown"
+
+    return ModelOutput.from_content(
+        model=model, content=str(ex), stop_reason=stop_reason
+    )


 def chat_messages(
inspect_ai/model/_providers/vertex.py CHANGED
@@ -24,8 +24,14 @@ from vertexai.generative_models import (  # type: ignore
 from vertexai.generative_models import Content as VertexContent

 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import Content, ContentText
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo

 from .._chat_message import (
@@ -308,9 +314,16 @@ async def content_part(content: Content | str) -> Part:
         return Part.from_text(content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return Part.from_text(content.text or NO_CONTENT)
-    else:
-        image_bytes, mime_type = await image_as_data(content.image)
+    elif isinstance(content, ContentImage):
+        image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    else:
+        if isinstance(content, ContentAudio):
+            file = content.audio
+        elif isinstance(content, ContentVideo):
+            file = content.video
+        file_bytes, mime_type = await file_as_data(file)
+        return Part.from_data(file_bytes, mime_type)


 def prepend_system_messages(
inspect_ai/scorer/__init__.py CHANGED
@@ -1,3 +1,5 @@
+from inspect_ai._util.deprecation import relocated_module_attribute
+
 from ._answer import AnswerPattern, answer
 from ._choice import choice
 from ._classification import exact, f1
@@ -16,7 +18,7 @@ from ._metric import (
 )
 from ._metrics.accuracy import accuracy
 from ._metrics.mean import mean
-from ._metrics.std import bootstrap_std, std, stderr
+from ._metrics.std import bootstrap_stderr, std, stderr
 from ._model import model_graded_fact, model_graded_qa
 from ._multi import multi_scorer
 from ._pattern import pattern
@@ -50,7 +52,7 @@ __all__ = [
     "Target",
     "scorer",
     "accuracy",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
     "mean",
@@ -76,3 +78,12 @@ __all__ = [
     "at_least",
     "pass_at",
 ]
+_BOOTSTRAP_RENAME_VERSION = "0.3.58"
+_REMOVED_IN = "0.4"
+
+relocated_module_attribute(
+    "bootstrap_std",
+    "inspect_ai.scorer.bootstrap_stderr",
+    _BOOTSTRAP_RENAME_VERSION,
+    _REMOVED_IN,
+)
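relocated_module_attribute keeps the old name importable during the deprecation window. A usage sketch, assuming the shim emits a deprecation warning on first access to the old attribute:

from inspect_ai.scorer import bootstrap_stderr  # new name
from inspect_ai.scorer import bootstrap_std     # old name: still resolves, but warns

metric = bootstrap_stderr(num_samples=1000)     # preferred going forward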
inspect_ai/scorer/_metrics/__init__.py CHANGED
@@ -1,12 +1,12 @@
 from .accuracy import accuracy
 from .mean import mean, var
-from .std import bootstrap_std, std, stderr
+from .std import bootstrap_stderr, std, stderr

 __all__ = [
     "accuracy",
     "mean",
     "var",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
 ]
inspect_ai/scorer/_metrics/std.py CHANGED
@@ -15,10 +15,10 @@ logger = getLogger(__name__)


 @metric
-def bootstrap_std(
+def bootstrap_stderr(
     num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
 ) -> Metric:
-    """Standard deviation of a bootstrapped estimate of the mean.
+    """Standard error of the mean using bootstrap.

     Args:
        num_samples (int): Number of bootstrap samples to take.
@@ -31,7 +31,7 @@ def bootstrap_std(
        0 if the Value is a complex object (list or dict).

     Returns:
-       bootstrap_std metric
+       bootstrap_stderr metric
     """

     def metric(scores: list[Score]) -> float:
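The rename reflects what the metric computes: the standard deviation of bootstrap resample means, which estimates the standard error of the mean. A self-contained sketch of that computation (not the library's implementation):

import random
import statistics

def bootstrap_stderr_sketch(values: list[float], num_samples: int = 1000) -> float:
    # std. dev. of the resample means estimates the standard error of the mean
    means = [
        statistics.fmean(random.choices(values, k=len(values)))
        for _ in range(num_samples)
    ]
    return statistics.stdev(means)

print(bootstrap_stderr_sketch([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]))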
inspect_ai/tool/__init__.py CHANGED
@@ -1,4 +1,10 @@
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute

 from ._tool import Tool, ToolError, ToolResult, tool
@@ -30,8 +36,10 @@ __all__ = [
     "ToolError",
     "ToolResult",
     "Content",
+    "ContentAudio",
     "ContentImage",
     "ContentText",
+    "ContentVideo",
     "ToolCall",
     "ToolCallContent",
     "ToolCallView",
inspect_ai/tool/_tool.py CHANGED
@@ -11,7 +11,12 @@ from typing import (
     runtime_checkable,
 )

-from inspect_ai._util.content import ContentImage, ContentText
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.registry import (
     RegistryInfo,
     registry_add,
@@ -31,7 +36,9 @@ ToolResult = (
     | bool
     | ContentText
     | ContentImage
-    | list[ContentText | ContentImage]
+    | ContentAudio
+    | ContentVideo
+    | list[ContentText | ContentImage | ContentAudio | ContentVideo]
 )

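With the widened ToolResult union, a tool can return audio or video content directly. A sketch using the newly exported ContentAudio ("recording.wav" stands in for a file a real tool would produce):

from inspect_ai.tool import ContentAudio, Tool, tool

@tool
def record_audio() -> Tool:
    async def execute() -> ContentAudio:
        """Record audio and return it as tool output."""
        return ContentAudio(audio="recording.wav", format="wav")  # stand-in file

    return execute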
inspect_ai/tool/_tool_info.py CHANGED
@@ -8,6 +8,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Tuple,
     Type,
     Union,
     get_args,
@@ -155,7 +156,7 @@ def parse_type(type_hint: Type[Any]) -> ToolParam:
             return ToolParam(type="null")
         else:
             return ToolParam()
-    elif origin is list or origin is List:
+    elif origin is list or origin is List or origin is tuple or origin is Tuple:
         return ToolParam(
             type="array", items=parse_type(args[0]) if args else ToolParam()
         )
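With the tuple branch above, tuple-typed tool parameters are schematized as JSON arrays, with the items schema taken from the first type argument (as for lists). A hypothetical tool illustrating the now-supported hint:

from inspect_ai.tool import Tool, tool

@tool
def move_to() -> Tool:
    async def execute(point: tuple[float, float]) -> str:
        """Move to a coordinate.

        Args:
            point: (x, y) coordinate, now accepted as a JSON array.
        """
        x, y = point
        return f"moved to ({x}, {y})"

    return execute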
inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py CHANGED
@@ -38,9 +38,9 @@ class EnvironmentSpec:
         for i, obs_spec in enumerate(env_obs_spec.values()):
             self.observation_spec[i + 1] = convert(obs_spec)

-        assert isinstance(
-            env.action_spec(), specs.Array
-        ), "Only a single action type is supported."
+        assert isinstance(env.action_spec(), specs.Array), (
+            "Only a single action type is supported."
+        )
         self.action_spec = {1: convert(env.action_spec())}

         self.observation_manager = spec_manager.SpecManager(self.observation_spec)
@@ -234,12 +234,12 @@ class EnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer):
             observations.
         """
         with self._lock:
-            assert (
-                cur_world in self._envs
-            ), "Current world does not have an assosiated environment"
-            assert (
-                cur_world in self._joined_worlds
-            ), "Please join world before calling step."
+            assert cur_world in self._envs, (
+                "Current world does not have an assosiated environment"
+            )
+            assert cur_world in self._joined_worlds, (
+                "Please join world before calling step."
+            )
             env = self._envs[cur_world]
             spec = self._specs[cur_world]

inspect_ai/tool/_tools/_web_browser/_web_browser.py CHANGED
@@ -372,7 +372,9 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
         )
     else:
         response = parse_web_browser_output(result.stdout)
-        if "web_at" in response:
+        if "error" in response and response.get("error", "").strip() != "":
+            raise ToolError(str(response.get("error")) or "(unknown error)")
+        elif "web_at" in response:
             web_at = (
                 str(response.get("web_at")) or "(no web accessiblity tree available)"
             )
@@ -384,8 +386,6 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
             web_at = "\n".join(web_at_lines)
             store_as(WebBrowserStore).web_at = web_at
             return web_at
-        elif "error" in response:
-            raise ToolError(str(response.get("error")) or "(unknown error)")
         else:
             raise RuntimeError(
                 f"web_browser output must contain either 'error' or 'web_at' field: {result.stdout}"
inspect_ai/util/__init__.py CHANGED
@@ -1,3 +1,5 @@
+from inspect_ai._util.trace import trace_action, trace_message
+
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_type
@@ -24,7 +26,6 @@ from ._subprocess import (
 )
 from ._subtask import Subtask, subtask
 from ._throttle import throttle
-from ._trace import trace_enabled, trace_panel

 __all__ = [
     "ExecResult",
@@ -54,6 +55,6 @@ __all__ = [
     "Subtask",
     "subtask",
     "throttle",
-    "trace_enabled",
-    "trace_panel",
+    "trace_action",
+    "trace_message",
 ]
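trace_enabled()/trace_panel() leave the public API (see the _conversation.py rename below) and the trace logging helpers take their place. A usage sketch, assuming the inspect_ai._util.trace signatures (logger first, then an action/category name and a printf-style message):

from logging import getLogger

from inspect_ai.util import trace_action, trace_message

logger = getLogger(__name__)

# times the enclosed block and records whether it completed or failed
with trace_action(logger, "FetchDataset", "downloading %s", "dataset.json"):
    ...  # the work being traced

# one-off trace log entry
trace_message(logger, "Cache", "cache miss for %s", "dataset.json")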
inspect_ai/util/{_trace.py → _conversation.py} RENAMED
@@ -1,5 +1,3 @@
-from contextvars import ContextVar
-
 from rich import print
 from rich.console import RenderableType
 from rich.text import Text
@@ -7,12 +5,7 @@ from rich.text import Text
 from inspect_ai._util.transcript import transcript_panel


-def trace_enabled() -> bool:
-    """Is trace mode currently enabled."""
-    return _trace.get(None) is True
-
-
-def trace_panel(
+def conversation_panel(
     title: str,
     *,
     subtitle: str | None = None,
@@ -20,8 +13,8 @@ def trace_panel(
 ) -> None:
     """Trace content into a standard trace panel display.

-    Typically you would call `trace_enabled()` to confirm that trace mode
-    is enabled before calling `trace_panel()`.
+    Typically you would call `display_type() == "conversation"` to confirm that
+    we are in conversation mode before calling `conversation_panel()`.

     Args:
        title (str): Panel title.
@@ -32,10 +25,3 @@ def trace_panel(
         transcript_panel(title, subtitle, content),
         Text(),
     )
-
-
-def init_trace(trace: bool | None) -> None:
-    _trace.set(trace)
-
-
-_trace: ContextVar[bool | None] = ContextVar("_trace_mode")
inspect_ai/util/_display.py CHANGED
@@ -3,10 +3,11 @@ from logging import getLogger
 from typing import Literal

 from inspect_ai._util.constants import DEFAULT_DISPLAY
+from inspect_ai._util.thread import is_main_thread

 logger = getLogger(__name__)

-DisplayType = Literal["full", "rich", "plain", "none"]
+DisplayType = Literal["full", "conversation", "rich", "plain", "none"]
 """Console display type."""


@@ -15,15 +16,24 @@ _display_type: DisplayType | None = None

 def init_display_type(display: str | None = None) -> DisplayType:
     global _display_type
-    global _display_metrics
     display = (
         display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
     )
+
+    # if we are on a background thread then throttle down to "plain"
+    # ("full" requires textual which cannot run in a background thread
+    # b/c it calls the Python signal function; "rich" assumes exclusive
+    # display access which may not be the case for threads)
+    if display in ["full", "rich"] and not is_main_thread():
+        display = "plain"
+
     match display:
-        case "full" | "rich" | "plain" | "none":
+        case "full" | "conversation" | "rich" | "plain" | "none":
             _display_type = display
         case _:
-            logger.warning(f"Unknown display type '{display}'")
+            logger.warning(
+                f"Unknown display type '{display}' (setting display to 'full')"
+            )
             _display_type = "full"
     return _display_type

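The new "conversation" display type replaces the old trace mode (see the _trace.py → _conversation.py rename above). A sketch of opting into it via the INSPECT_DISPLAY variable read by init_display_type():

import os

# select the conversation display before inspect initializes
os.environ["INSPECT_DISPLAY"] = "conversation"

from inspect_ai import eval  # noqa: E402

# eval(my_task)  # model interactions now render as conversation panels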
inspect_ai/util/_sandbox/context.py CHANGED
@@ -4,6 +4,8 @@ from typing import Any, NoReturn, cast

 from shortuuid import uuid

+from inspect_ai._util.constants import SANDBOX_SETUP_TIMEOUT
+
 from .environment import (
     SampleCleanup,
     SampleInit,
@@ -193,23 +195,20 @@ async def setup_sandbox_environment(
     setup_file = f"/tmp/{uuid()}"
     await env.write_file(setup_file, setup)

-    # chmod, execute, and remove
-    async def exec(cmd: list[str]) -> None:
-        try:
-            result = await env.exec(cmd, timeout=30)
-        except TimeoutError:
-            raise RuntimeError(
-                f"Timed out executing command {' '.join(cmd)} in sandbox"
-            )
-
+    # execute and then remove setup script (don't retry it on timeout
+    # in case it is not idempotent)
+    try:
+        await env.exec(["chmod", "+x", setup_file], timeout=30)
+        result = await env.exec(
+            ["env", setup_file], timeout=SANDBOX_SETUP_TIMEOUT, timeout_retry=False
+        )
         if not result.success:
             raise RuntimeError(
                 f"Failed to execute setup script for sample: {result.stderr}"
             )
-
-    await exec(["chmod", "+x", setup_file])
-    await exec(["env", setup_file])
-    await exec(["rm", setup_file])
+        await env.exec(["rm", setup_file], timeout=30)
+    except TimeoutError:
+        raise RuntimeError("Timed out executing setup command in sandbox")


 def default_sandbox_environment(
inspect_ai/util/_sandbox/docker/compose.py CHANGED
@@ -25,18 +25,17 @@ COMPOSE_WAIT = "120"


 async def compose_up(project: ComposeProject) -> None:
-    # Start the environment
-    result = await compose_command(
+    # Start the environment. Note that we don't check the result because docker will
+    # return a non-zero exit code for services that exit (even successfully) when
+    # passing the --wait flag (see https://github.com/docker/compose/issues/10596).
+    # In practice, we will catch any errors when calling compose_check_running()
+    # immediately after we call compose_up().
+    await compose_command(
         ["up", "--detach", "--wait", "--wait-timeout", COMPOSE_WAIT],
         project=project,
         # wait up to 5 minutes for container to go up (compose wait + 3 minutes)
         timeout=300,
     )
-    if not result.success:
-        msg = (
-            f"Failed to start docker services for {project.config}: " f"{result.stderr}"
-        )
-        raise RuntimeError(msg)


 async def compose_down(project: ComposeProject, quiet: bool = True) -> None:
@@ -93,14 +92,21 @@ async def compose_cp(
         raise RuntimeError(msg)


-async def compose_check_running(services: list[str], project: ComposeProject) -> None:
+async def compose_check_running(
+    services: list[str], project: ComposeProject
+) -> list[str]:
     # Check to ensure that the status of containers is healthy
     running_services = await compose_ps(project=project, status="running")
-    if len(running_services) > 0:
-        if len(running_services) != len(services):
+    exited_services = await compose_ps(project=project, status="exited")
+    successful_services = running_services + [
+        service for service in exited_services if service["ExitCode"] == 0
+    ]
+
+    if len(successful_services) > 0:
+        if len(successful_services) != len(services):
             unhealthy_services = services
-            for running_service in running_services:
-                unhealthy_services.remove(running_service["Service"])
+            for successful_service in successful_services:
+                unhealthy_services.remove(successful_service["Service"])

             msg = (
                 "One or more docker containers failed to start from "
@@ -110,6 +116,8 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
     else:
         raise RuntimeError("No services started")

+    return [service["Service"] for service in running_services]
+

 async def compose_ps(
     project: ComposeProject,
@@ -168,6 +176,7 @@ async def compose_exec(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     output_limit: int | None = None,
 ) -> ExecResult[str]:
@@ -175,6 +184,7 @@ async def compose_exec(
         ["exec"] + command,
         project=project,
         timeout=timeout,
+        timeout_retry=timeout_retry,
         input=input,
         forward_env=False,
         output_limit=output_limit,
@@ -260,6 +270,7 @@ async def compose_command(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     cwd: str | Path | None = None,
     forward_env: bool = True,
@@ -327,7 +338,7 @@ async def compose_command(
             return await run_command(command_timeout)
         except TimeoutError:
             retries += 1
-            if retries <= MAX_RETRIES:
+            if timeout_retry and (retries <= MAX_RETRIES):
                 logger.info(
                     f"Retrying docker compose command: {shlex.join(compose_command)}"
                 )
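The new timeout_retry flag lets callers opt non-idempotent commands (like the sandbox setup script above) out of the timeout retry loop; it threads from compose_exec through to compose_command. The guard in isolation, with an assumed MAX_RETRIES value:

MAX_RETRIES = 2  # assumed for illustration; the real constant lives in compose.py

async def run_with_optional_retry(run_command, *, timeout_retry: bool = True):
    retries = 0
    while True:
        try:
            return await run_command()
        except TimeoutError:
            retries += 1
            if timeout_retry and (retries <= MAX_RETRIES):
                continue  # retry the command
            raise  # retries exhausted, or retry disabled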