inspect-ai 0.3.57__py3-none-any.whl → 0.3.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. inspect_ai/__init__.py +2 -1
  2. inspect_ai/_cli/common.py +4 -2
  3. inspect_ai/_cli/eval.py +2 -0
  4. inspect_ai/_cli/trace.py +21 -2
  5. inspect_ai/_display/core/active.py +0 -2
  6. inspect_ai/_display/rich/display.py +4 -4
  7. inspect_ai/_display/textual/app.py +4 -1
  8. inspect_ai/_display/textual/widgets/samples.py +41 -5
  9. inspect_ai/_eval/eval.py +32 -20
  10. inspect_ai/_eval/evalset.py +7 -5
  11. inspect_ai/_eval/task/__init__.py +2 -2
  12. inspect_ai/_eval/task/images.py +40 -25
  13. inspect_ai/_eval/task/run.py +141 -119
  14. inspect_ai/_eval/task/task.py +140 -25
  15. inspect_ai/_util/constants.py +1 -0
  16. inspect_ai/_util/content.py +23 -1
  17. inspect_ai/_util/images.py +20 -17
  18. inspect_ai/_util/kvstore.py +73 -0
  19. inspect_ai/_util/notgiven.py +18 -0
  20. inspect_ai/_util/thread.py +5 -0
  21. inspect_ai/_view/www/dist/assets/index.js +37 -3
  22. inspect_ai/_view/www/log-schema.json +97 -13
  23. inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
  24. inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
  25. inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +5 -1
  26. inspect_ai/_view/www/src/types/log.d.ts +51 -27
  27. inspect_ai/approval/_human/util.py +2 -2
  28. inspect_ai/dataset/_sources/csv.py +2 -1
  29. inspect_ai/dataset/_sources/json.py +2 -1
  30. inspect_ai/dataset/_sources/util.py +15 -7
  31. inspect_ai/log/_condense.py +11 -1
  32. inspect_ai/log/_log.py +2 -5
  33. inspect_ai/log/_recorders/eval.py +19 -8
  34. inspect_ai/log/_samples.py +10 -5
  35. inspect_ai/log/_transcript.py +28 -1
  36. inspect_ai/model/__init__.py +10 -2
  37. inspect_ai/model/_call_tools.py +55 -12
  38. inspect_ai/model/_chat_message.py +2 -4
  39. inspect_ai/model/{_trace.py → _conversation.py} +9 -8
  40. inspect_ai/model/_model.py +2 -2
  41. inspect_ai/model/_providers/anthropic.py +9 -7
  42. inspect_ai/model/_providers/azureai.py +6 -4
  43. inspect_ai/model/_providers/bedrock.py +6 -4
  44. inspect_ai/model/_providers/google.py +79 -8
  45. inspect_ai/model/_providers/groq.py +7 -5
  46. inspect_ai/model/_providers/hf.py +11 -6
  47. inspect_ai/model/_providers/mistral.py +6 -9
  48. inspect_ai/model/_providers/openai.py +17 -5
  49. inspect_ai/model/_providers/vertex.py +17 -4
  50. inspect_ai/scorer/__init__.py +13 -2
  51. inspect_ai/scorer/_metrics/__init__.py +2 -2
  52. inspect_ai/scorer/_metrics/std.py +3 -3
  53. inspect_ai/tool/__init__.py +9 -1
  54. inspect_ai/tool/_tool.py +9 -2
  55. inspect_ai/util/__init__.py +0 -3
  56. inspect_ai/util/{_trace.py → _conversation.py} +3 -17
  57. inspect_ai/util/_display.py +14 -4
  58. inspect_ai/util/_sandbox/context.py +12 -13
  59. inspect_ai/util/_sandbox/docker/compose.py +24 -11
  60. inspect_ai/util/_sandbox/docker/docker.py +20 -13
  61. inspect_ai/util/_sandbox/environment.py +13 -1
  62. inspect_ai/util/_sandbox/local.py +1 -0
  63. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +2 -2
  64. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +68 -65
  65. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
  66. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +0 -0
  67. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
  68. {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,17 @@
1
+ import asyncio
1
2
  import functools
3
+ import hashlib
2
4
  import json
3
5
  from copy import copy
6
+ from io import BytesIO
7
+ from logging import getLogger
4
8
  from typing import Any, cast
5
9
 
6
10
  import proto # type: ignore
7
11
  from google.ai.generativelanguage import (
8
12
  Blob,
9
13
  Candidate,
14
+ File,
10
15
  FunctionCall,
11
16
  FunctionCallingConfig,
12
17
  FunctionDeclaration,
@@ -28,6 +33,8 @@ from google.generativeai import ( # type: ignore
28
33
  GenerationConfig,
29
34
  GenerativeModel,
30
35
  configure,
36
+ get_file,
37
+ upload_file,
31
38
  )
32
39
  from google.generativeai.types import ( # type: ignore
33
40
  AsyncGenerateContentResponse,
@@ -45,8 +52,16 @@ from pydantic import JsonValue
45
52
  from typing_extensions import override
46
53
 
47
54
  from inspect_ai._util.constants import BASE_64_DATA_REMOVED
48
- from inspect_ai._util.content import Content, ContentImage, ContentText
49
- from inspect_ai._util.images import image_as_data
55
+ from inspect_ai._util.content import (
56
+ Content,
57
+ ContentAudio,
58
+ ContentImage,
59
+ ContentText,
60
+ ContentVideo,
61
+ )
62
+ from inspect_ai._util.images import file_as_data
63
+ from inspect_ai._util.kvstore import inspect_kvstore
64
+ from inspect_ai._util.trace import trace_message
50
65
  from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
51
66
 
52
67
  from .._chat_message import (
@@ -70,6 +85,8 @@ from .._model_output import (
70
85
  )
71
86
  from .util import model_base_url
72
87
 
88
+ logger = getLogger(__name__)
89
+
73
90
  SAFETY_SETTINGS = "safety_settings"
74
91
 
75
92
  DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
@@ -364,19 +381,23 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
364
381
  return struct
365
382
 
366
383
 
367
- async def content_part(content: Content | str) -> PartDict:
384
+ async def content_part(content: Content | str) -> PartType:
368
385
  if isinstance(content, str):
369
386
  return PartDict(text=content or NO_CONTENT)
370
387
  elif isinstance(content, ContentText):
371
388
  return PartDict(text=content.text or NO_CONTENT)
372
389
  else:
373
- return PartDict(inline_data=await chat_content_image_to_blob(content))
390
+ return await chat_content_to_part(content)
374
391
 
375
392
 
376
- async def chat_content_image_to_blob(image: ContentImage) -> Blob:
377
- image_url = image.image
378
- image_bytes, mime_type = await image_as_data(image_url)
379
- return Blob(mime_type=mime_type, data=image_bytes)
393
+ async def chat_content_to_part(
394
+ content: ContentImage | ContentAudio | ContentVideo,
395
+ ) -> PartType:
396
+ if isinstance(content, ContentImage):
397
+ content_bytes, mime_type = await file_as_data(content.image)
398
+ return Blob(mime_type=mime_type, data=content_bytes)
399
+ else:
400
+ return await file_for_content(content)
380
401
 
381
402
 
382
403
  def prepend_system_messages(
@@ -630,3 +651,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
630
651
  return HarmBlockThreshold.BLOCK_NONE
631
652
  else:
632
653
  raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
654
+
655
+
656
+ async def file_for_content(content: ContentAudio | ContentVideo) -> File:
657
+ # helper to write trace messages
658
+ def trace(message: str) -> None:
659
+ trace_message(logger, "Google Files", message)
660
+
661
+ # get the file bytes and compute sha256 hash
662
+ if isinstance(content, ContentAudio):
663
+ file = content.audio
664
+ else:
665
+ file = content.video
666
+ content_bytes, mime_type = await file_as_data(file)
667
+ content_sha256 = hashlib.sha256(content_bytes).hexdigest()
668
+
669
+ # we cache uploads for re-use, open the db where we track that
670
+ # (track up to 1 million previous uploads)
671
+ with inspect_kvstore("google_files", 1000000) as files_db:
672
+ # can we serve from existing uploads?
673
+ uploaded_file = files_db.get(content_sha256)
674
+ if uploaded_file:
675
+ try:
676
+ upload = cast(File, get_file(uploaded_file))
677
+ if upload.state.name == "ACTIVE":
678
+ trace(f"Using uploaded file: {uploaded_file}")
679
+ return upload
680
+ else:
681
+ trace(
682
+ f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
683
+ )
684
+ except Exception as ex:
685
+ trace(f"Error attempting to access uploaded file: {ex}")
686
+ files_db.delete(content_sha256)
687
+
688
+ # do the upload (and record it)
689
+ upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
690
+ while upload.state.name == "PROCESSING":
691
+ await asyncio.sleep(3)
692
+ upload = get_file(upload.name)
693
+
694
+ if upload.state.name == "FAILED":
695
+ trace(f"Failed to upload file '{upload.name}: {upload.error}")
696
+ raise ValueError(f"Google file upload failed: {upload.error}")
697
+
698
+ # trace and record it
699
+ trace(f"Uploaded file: {upload.name}")
700
+ files_db.put(content_sha256, upload.name)
701
+
702
+ # return the file
703
+ return upload
@@ -23,8 +23,8 @@ from typing_extensions import override
23
23
 
24
24
  from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
25
25
  from inspect_ai._util.content import Content
26
- from inspect_ai._util.images import image_as_data_uri
27
- from inspect_ai._util.url import is_data_uri, is_http_url
26
+ from inspect_ai._util.images import file_as_data_uri
27
+ from inspect_ai._util.url import is_http_url
28
28
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
29
29
 
30
30
  from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
248
248
  ) -> ChatCompletionContentPartParam:
249
249
  if content.type == "text":
250
250
  return ChatCompletionContentPartTextParam(type="text", text=content.text)
251
- else:
251
+ elif content.type == "image":
252
252
  # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
253
253
  image_url = content.image
254
254
  detail = content.detail
255
255
 
256
- if not is_http_url(image_url) and not is_data_uri(image_url):
257
- image_url = await image_as_data_uri(image_url)
256
+ if not is_http_url(image_url):
257
+ image_url = await file_as_data_uri(image_url)
258
258
 
259
259
  return ChatCompletionContentPartImageParam(
260
260
  type="image_url",
261
261
  image_url=dict(url=image_url, detail=detail),
262
262
  )
263
+ else:
264
+ raise RuntimeError("Groq models do not support audio or video inputs.")
263
265
 
264
266
 
265
267
  def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:
@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
239
239
  hf_messages = inspect_tools_to_string(hf_messages)
240
240
 
241
241
  # apply chat template
242
- chat = self.tokenizer.apply_chat_template(
243
- hf_messages,
244
- add_generation_prompt=True,
245
- tokenize=False,
246
- tools=tools_list if len(tools_list) > 0 else None,
247
- )
242
+ if self.tokenizer.chat_template is not None:
243
+ chat = self.tokenizer.apply_chat_template(
244
+ hf_messages,
245
+ add_generation_prompt=True,
246
+ tokenize=False,
247
+ tools=tools_list if len(tools_list) > 0 else None,
248
+ )
249
+ else:
250
+ chat = ""
251
+ for message in hf_messages:
252
+ chat += f"{message.role}: {message.content}\n"
248
253
  # return
249
254
  return cast(str, chat)
250
255
 
@@ -42,8 +42,7 @@ from inspect_ai._util.constants import (
42
42
  DEFAULT_TIMEOUT,
43
43
  )
44
44
  from inspect_ai._util.content import Content, ContentImage, ContentText
45
- from inspect_ai._util.images import image_as_data_uri
46
- from inspect_ai._util.url import is_data_uri
45
+ from inspect_ai._util.images import file_as_data_uri
47
46
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
48
47
 
49
48
  from .._chat_message import (
@@ -351,16 +350,14 @@ def mistral_system_message_content(
351
350
  async def mistral_content_chunk(content: Content) -> ContentChunk:
352
351
  if isinstance(content, ContentText):
353
352
  return TextChunk(text=content.text or NO_CONTENT)
354
- else:
353
+ elif isinstance(content, ContentImage):
355
354
  # resolve image to url
356
- image_url = content.image
357
- if not is_data_uri(image_url):
358
- image_url = await image_as_data_uri(image_url)
355
+ image_url = await file_as_data_uri(content.image)
359
356
 
360
357
  # return chunk
361
- return ImageURLChunk(
362
- image_url=ImageURL(url=content.image, detail=content.detail)
363
- )
358
+ return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
359
+ else:
360
+ raise RuntimeError("Mistral models do not support audio or video inputs.")
364
361
 
365
362
 
366
363
  def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall:
@@ -17,6 +17,7 @@ from openai.types.chat import (
17
17
  ChatCompletion,
18
18
  ChatCompletionAssistantMessageParam,
19
19
  ChatCompletionContentPartImageParam,
20
+ ChatCompletionContentPartInputAudioParam,
20
21
  ChatCompletionContentPartParam,
21
22
  ChatCompletionContentPartTextParam,
22
23
  ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
36
37
  from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
37
38
  from inspect_ai._util.content import Content
38
39
  from inspect_ai._util.error import PrerequisiteError
39
- from inspect_ai._util.images import image_as_data_uri
40
+ from inspect_ai._util.images import file_as_data_uri
40
41
  from inspect_ai._util.logger import warn_once
41
- from inspect_ai._util.url import is_data_uri, is_http_url
42
+ from inspect_ai._util.url import is_http_url
42
43
  from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
43
44
 
44
45
  from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -463,16 +464,27 @@ async def as_chat_completion_part(
463
464
  ) -> ChatCompletionContentPartParam:
464
465
  if content.type == "text":
465
466
  return ChatCompletionContentPartTextParam(type="text", text=content.text)
466
- else:
467
+ elif content.type == "image":
467
468
  # API takes URL or base64 encoded file. If it's a remote file or
468
469
  # data URL leave it alone, otherwise encode it
469
470
  image_url = content.image
470
471
  detail = content.detail
471
472
 
472
- if not is_http_url(image_url) and not is_data_uri(image_url):
473
- image_url = await image_as_data_uri(image_url)
473
+ if not is_http_url(image_url):
474
+ image_url = await file_as_data_uri(image_url)
474
475
 
475
476
  return ChatCompletionContentPartImageParam(
476
477
  type="image_url",
477
478
  image_url=dict(url=image_url, detail=detail),
478
479
  )
480
+ elif content.type == "audio":
481
+ audio_data = await file_as_data_uri(content.audio)
482
+
483
+ return ChatCompletionContentPartInputAudioParam(
484
+ type="input_audio", input_audio=dict(data=audio_data, format=content.format)
485
+ )
486
+
487
+ else:
488
+ raise RuntimeError(
489
+ "Video content is not currently supported by Open AI chat models."
490
+ )
@@ -24,8 +24,14 @@ from vertexai.generative_models import ( # type: ignore
24
24
  from vertexai.generative_models import Content as VertexContent
25
25
 
26
26
  from inspect_ai._util.constants import BASE_64_DATA_REMOVED
27
- from inspect_ai._util.content import Content, ContentText
28
- from inspect_ai._util.images import image_as_data
27
+ from inspect_ai._util.content import (
28
+ Content,
29
+ ContentAudio,
30
+ ContentImage,
31
+ ContentText,
32
+ ContentVideo,
33
+ )
34
+ from inspect_ai._util.images import file_as_data
29
35
  from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
30
36
 
31
37
  from .._chat_message import (
@@ -308,9 +314,16 @@ async def content_part(content: Content | str) -> Part:
308
314
  return Part.from_text(content or NO_CONTENT)
309
315
  elif isinstance(content, ContentText):
310
316
  return Part.from_text(content.text or NO_CONTENT)
311
- else:
312
- image_bytes, mime_type = await image_as_data(content.image)
317
+ elif isinstance(content, ContentImage):
318
+ image_bytes, mime_type = await file_as_data(content.image)
313
319
  return Part.from_image(image=Image.from_bytes(data=image_bytes))
320
+ else:
321
+ if isinstance(content, ContentAudio):
322
+ file = content.audio
323
+ elif isinstance(content, ContentVideo):
324
+ file = content.video
325
+ file_bytes, mime_type = await file_as_data(file)
326
+ return Part.from_data(file_bytes, mime_type)
314
327
 
315
328
 
316
329
  def prepend_system_messages(
@@ -1,3 +1,5 @@
1
+ from inspect_ai._util.deprecation import relocated_module_attribute
2
+
1
3
  from ._answer import AnswerPattern, answer
2
4
  from ._choice import choice
3
5
  from ._classification import exact, f1
@@ -16,7 +18,7 @@ from ._metric import (
16
18
  )
17
19
  from ._metrics.accuracy import accuracy
18
20
  from ._metrics.mean import mean
19
- from ._metrics.std import bootstrap_std, std, stderr
21
+ from ._metrics.std import bootstrap_stderr, std, stderr
20
22
  from ._model import model_graded_fact, model_graded_qa
21
23
  from ._multi import multi_scorer
22
24
  from ._pattern import pattern
@@ -50,7 +52,7 @@ __all__ = [
50
52
  "Target",
51
53
  "scorer",
52
54
  "accuracy",
53
- "bootstrap_std",
55
+ "bootstrap_stderr",
54
56
  "std",
55
57
  "stderr",
56
58
  "mean",
@@ -76,3 +78,12 @@ __all__ = [
76
78
  "at_least",
77
79
  "pass_at",
78
80
  ]
81
+ _BOOTSTRAP_RENAME_VERSION = "0.3.58"
82
+ _REMOVED_IN = "0.4"
83
+
84
+ relocated_module_attribute(
85
+ "bootstrap_std",
86
+ "inspect_ai.scorer.bootstrap_stderr",
87
+ _BOOTSTRAP_RENAME_VERSION,
88
+ _REMOVED_IN,
89
+ )
@@ -1,12 +1,12 @@
1
1
  from .accuracy import accuracy
2
2
  from .mean import mean, var
3
- from .std import bootstrap_std, std, stderr
3
+ from .std import bootstrap_stderr, std, stderr
4
4
 
5
5
  __all__ = [
6
6
  "accuracy",
7
7
  "mean",
8
8
  "var",
9
- "bootstrap_std",
9
+ "bootstrap_stderr",
10
10
  "std",
11
11
  "stderr",
12
12
  ]
@@ -15,10 +15,10 @@ logger = getLogger(__name__)
15
15
 
16
16
 
17
17
  @metric
18
- def bootstrap_std(
18
+ def bootstrap_stderr(
19
19
  num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
20
20
  ) -> Metric:
21
- """Standard deviation of a bootstrapped estimate of the mean.
21
+ """Standard error of the mean using bootstrap.
22
22
 
23
23
  Args:
24
24
  num_samples (int): Number of bootstrap samples to take.
@@ -31,7 +31,7 @@ def bootstrap_std(
31
31
  0 if the Value is a complex object (list or dict).
32
32
 
33
33
  Returns:
34
- bootstrap_std metric
34
+ bootstrap_stderr metric
35
35
  """
36
36
 
37
37
  def metric(scores: list[Score]) -> float:
@@ -1,4 +1,10 @@
1
- from inspect_ai._util.content import Content, ContentImage, ContentText
1
+ from inspect_ai._util.content import (
2
+ Content,
3
+ ContentAudio,
4
+ ContentImage,
5
+ ContentText,
6
+ ContentVideo,
7
+ )
2
8
  from inspect_ai._util.deprecation import relocated_module_attribute
3
9
 
4
10
  from ._tool import Tool, ToolError, ToolResult, tool
@@ -30,8 +36,10 @@ __all__ = [
30
36
  "ToolError",
31
37
  "ToolResult",
32
38
  "Content",
39
+ "ContentAudio",
33
40
  "ContentImage",
34
41
  "ContentText",
42
+ "ContentVideo",
35
43
  "ToolCall",
36
44
  "ToolCallContent",
37
45
  "ToolCallView",
inspect_ai/tool/_tool.py CHANGED
@@ -11,7 +11,12 @@ from typing import (
11
11
  runtime_checkable,
12
12
  )
13
13
 
14
- from inspect_ai._util.content import ContentImage, ContentText
14
+ from inspect_ai._util.content import (
15
+ ContentAudio,
16
+ ContentImage,
17
+ ContentText,
18
+ ContentVideo,
19
+ )
15
20
  from inspect_ai._util.registry import (
16
21
  RegistryInfo,
17
22
  registry_add,
@@ -31,7 +36,9 @@ ToolResult = (
31
36
  | bool
32
37
  | ContentText
33
38
  | ContentImage
34
- | list[ContentText | ContentImage]
39
+ | ContentAudio
40
+ | ContentVideo
41
+ | list[ContentText | ContentImage | ContentAudio | ContentVideo]
35
42
  )
36
43
 
37
44
 
@@ -26,7 +26,6 @@ from ._subprocess import (
26
26
  )
27
27
  from ._subtask import Subtask, subtask
28
28
  from ._throttle import throttle
29
- from ._trace import trace_enabled, trace_panel
30
29
 
31
30
  __all__ = [
32
31
  "ExecResult",
@@ -56,8 +55,6 @@ __all__ = [
56
55
  "Subtask",
57
56
  "subtask",
58
57
  "throttle",
59
- "trace_enabled",
60
- "trace_panel",
61
58
  "trace_action",
62
59
  "trace_message",
63
60
  ]
@@ -1,5 +1,3 @@
1
- from contextvars import ContextVar
2
-
3
1
  from rich import print
4
2
  from rich.console import RenderableType
5
3
  from rich.text import Text
@@ -7,12 +5,7 @@ from rich.text import Text
7
5
  from inspect_ai._util.transcript import transcript_panel
8
6
 
9
7
 
10
- def trace_enabled() -> bool:
11
- """Is trace mode currently enabled."""
12
- return _trace.get(None) is True
13
-
14
-
15
- def trace_panel(
8
+ def conversation_panel(
16
9
  title: str,
17
10
  *,
18
11
  subtitle: str | None = None,
@@ -20,8 +13,8 @@ def trace_panel(
20
13
  ) -> None:
21
14
  """Trace content into a standard trace panel display.
22
15
 
23
- Typically you would call `trace_enabled()` to confirm that trace mode
24
- is enabled before calling `trace_panel()`.
16
+ Typically you would call `display_type() == "conversation"` to confirm that
17
+ we are in conversation mode before calling `conversation_panel()`.
25
18
 
26
19
  Args:
27
20
  title (str): Panel title.
@@ -32,10 +25,3 @@ def trace_panel(
32
25
  transcript_panel(title, subtitle, content),
33
26
  Text(),
34
27
  )
35
-
36
-
37
- def init_trace(trace: bool | None) -> None:
38
- _trace.set(trace)
39
-
40
-
41
- _trace: ContextVar[bool | None] = ContextVar("_trace_mode")
@@ -3,10 +3,11 @@ from logging import getLogger
3
3
  from typing import Literal
4
4
 
5
5
  from inspect_ai._util.constants import DEFAULT_DISPLAY
6
+ from inspect_ai._util.thread import is_main_thread
6
7
 
7
8
  logger = getLogger(__name__)
8
9
 
9
- DisplayType = Literal["full", "rich", "plain", "none"]
10
+ DisplayType = Literal["full", "conversation", "rich", "plain", "none"]
10
11
  """Console display type."""
11
12
 
12
13
 
@@ -15,15 +16,24 @@ _display_type: DisplayType | None = None
15
16
 
16
17
  def init_display_type(display: str | None = None) -> DisplayType:
17
18
  global _display_type
18
- global _display_metrics
19
19
  display = (
20
20
  display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
21
21
  )
22
+
23
+ # if we are on a background thread then throttle down to "plain"
24
+ # ("full" requires textual which cannot run in a background thread
25
+ # b/c it calls the Python signal function; "rich" assumes exclusive
26
+ # display access which may not be the case for threads)
27
+ if display in ["full", "rich"] and not is_main_thread():
28
+ display = "plain"
29
+
22
30
  match display:
23
- case "full" | "rich" | "plain" | "none":
31
+ case "full" | "conversation" | "rich" | "plain" | "none":
24
32
  _display_type = display
25
33
  case _:
26
- logger.warning(f"Unknown display type '{display}'")
34
+ logger.warning(
35
+ f"Unknown display type '{display}' (setting display to 'full')"
36
+ )
27
37
  _display_type = "full"
28
38
  return _display_type
29
39
 
@@ -4,6 +4,8 @@ from typing import Any, NoReturn, cast
4
4
 
5
5
  from shortuuid import uuid
6
6
 
7
+ from inspect_ai._util.constants import SANDBOX_SETUP_TIMEOUT
8
+
7
9
  from .environment import (
8
10
  SampleCleanup,
9
11
  SampleInit,
@@ -193,23 +195,20 @@ async def setup_sandbox_environment(
193
195
  setup_file = f"/tmp/{uuid()}"
194
196
  await env.write_file(setup_file, setup)
195
197
 
196
- # chmod, execute, and remove
197
- async def exec(cmd: list[str]) -> None:
198
- try:
199
- result = await env.exec(cmd, timeout=30)
200
- except TimeoutError:
201
- raise RuntimeError(
202
- f"Timed out executing command {' '.join(cmd)} in sandbox"
203
- )
204
-
198
+ # execute and then remove setup script (don't retry it on timeout
199
+ # in case it is not idempotent)
200
+ try:
201
+ await env.exec(["chmod", "+x", setup_file], timeout=30)
202
+ result = await env.exec(
203
+ ["env", setup_file], timeout=SANDBOX_SETUP_TIMEOUT, timeout_retry=False
204
+ )
205
205
  if not result.success:
206
206
  raise RuntimeError(
207
207
  f"Failed to execute setup script for sample: {result.stderr}"
208
208
  )
209
-
210
- await exec(["chmod", "+x", setup_file])
211
- await exec(["env", setup_file])
212
- await exec(["rm", setup_file])
209
+ await env.exec(["rm", setup_file], timeout=30)
210
+ except TimeoutError:
211
+ raise RuntimeError("Timed out executing setup command in sandbox")
213
212
 
214
213
 
215
214
  def default_sandbox_environment(
@@ -25,16 +25,17 @@ COMPOSE_WAIT = "120"
25
25
 
26
26
 
27
27
  async def compose_up(project: ComposeProject) -> None:
28
- # Start the environment
29
- result = await compose_command(
28
+ # Start the environment. Note that we don't check the result because docker will
29
+ # return a non-zero exit code for services that exit (even successfully) when
30
+ # passing the --wait flag (see https://github.com/docker/compose/issues/10596).
31
+ # In practice, we will catch any errors when calling compose_check_running()
32
+ # immediately after we call compose_up().
33
+ await compose_command(
30
34
  ["up", "--detach", "--wait", "--wait-timeout", COMPOSE_WAIT],
31
35
  project=project,
32
36
  # wait up to 5 minutes for container to go up (compose wait + 3 minutes)
33
37
  timeout=300,
34
38
  )
35
- if not result.success:
36
- msg = f"Failed to start docker services for {project.config}: {result.stderr}"
37
- raise RuntimeError(msg)
38
39
 
39
40
 
40
41
  async def compose_down(project: ComposeProject, quiet: bool = True) -> None:
@@ -91,14 +92,21 @@ async def compose_cp(
91
92
  raise RuntimeError(msg)
92
93
 
93
94
 
94
- async def compose_check_running(services: list[str], project: ComposeProject) -> None:
95
+ async def compose_check_running(
96
+ services: list[str], project: ComposeProject
97
+ ) -> list[str]:
95
98
  # Check to ensure that the status of containers is healthy
96
99
  running_services = await compose_ps(project=project, status="running")
97
- if len(running_services) > 0:
98
- if len(running_services) != len(services):
100
+ exited_services = await compose_ps(project=project, status="exited")
101
+ successful_services = running_services + [
102
+ service for service in exited_services if service["ExitCode"] == 0
103
+ ]
104
+
105
+ if len(successful_services) > 0:
106
+ if len(successful_services) != len(services):
99
107
  unhealthy_services = services
100
- for running_service in running_services:
101
- unhealthy_services.remove(running_service["Service"])
108
+ for successful_service in successful_services:
109
+ unhealthy_services.remove(successful_service["Service"])
102
110
 
103
111
  msg = (
104
112
  "One or more docker containers failed to start from "
@@ -108,6 +116,8 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
108
116
  else:
109
117
  raise RuntimeError("No services started")
110
118
 
119
+ return [service["Service"] for service in running_services]
120
+
111
121
 
112
122
  async def compose_ps(
113
123
  project: ComposeProject,
@@ -166,6 +176,7 @@ async def compose_exec(
166
176
  *,
167
177
  project: ComposeProject,
168
178
  timeout: int | None,
179
+ timeout_retry: bool = True,
169
180
  input: str | bytes | None = None,
170
181
  output_limit: int | None = None,
171
182
  ) -> ExecResult[str]:
@@ -173,6 +184,7 @@ async def compose_exec(
173
184
  ["exec"] + command,
174
185
  project=project,
175
186
  timeout=timeout,
187
+ timeout_retry=timeout_retry,
176
188
  input=input,
177
189
  forward_env=False,
178
190
  output_limit=output_limit,
@@ -258,6 +270,7 @@ async def compose_command(
258
270
  *,
259
271
  project: ComposeProject,
260
272
  timeout: int | None,
273
+ timeout_retry: bool = True,
261
274
  input: str | bytes | None = None,
262
275
  cwd: str | Path | None = None,
263
276
  forward_env: bool = True,
@@ -325,7 +338,7 @@ async def compose_command(
325
338
  return await run_command(command_timeout)
326
339
  except TimeoutError:
327
340
  retries += 1
328
- if retries <= MAX_RETRIES:
341
+ if timeout_retry and (retries <= MAX_RETRIES):
329
342
  logger.info(
330
343
  f"Retrying docker compose command: {shlex.join(compose_command)}"
331
344
  )