inspect-ai 0.3.57__py3-none-any.whl → 0.3.58__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +4 -2
- inspect_ai/_cli/eval.py +2 -0
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +0 -2
- inspect_ai/_display/rich/display.py +4 -4
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/samples.py +41 -5
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/run.py +141 -119
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/dist/assets/index.js +37 -3
- inspect_ai/_view/www/log-schema.json +97 -13
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +5 -1
- inspect_ai/_view/www/src/types/log.d.ts +51 -27
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +2 -5
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +10 -5
- inspect_ai/log/_transcript.py +28 -1
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +55 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/{_trace.py → _conversation.py} +9 -8
- inspect_ai/model/_model.py +2 -2
- inspect_ai/model/_providers/anthropic.py +9 -7
- inspect_ai/model/_providers/azureai.py +6 -4
- inspect_ai/model/_providers/bedrock.py +6 -4
- inspect_ai/model/_providers/google.py +79 -8
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +6 -9
- inspect_ai/model/_providers/openai.py +17 -5
- inspect_ai/model/_providers/vertex.py +17 -4
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/tool/__init__.py +9 -1
- inspect_ai/tool/_tool.py +9 -2
- inspect_ai/util/__init__.py +0 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +20 -13
- inspect_ai/util/_sandbox/environment.py +13 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +68 -65
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/google.py
CHANGED
@@ -1,12 +1,17 @@
+import asyncio
 import functools
+import hashlib
 import json
 from copy import copy
+from io import BytesIO
+from logging import getLogger
 from typing import Any, cast
 
 import proto  # type: ignore
 from google.ai.generativelanguage import (
     Blob,
     Candidate,
+    File,
     FunctionCall,
     FunctionCallingConfig,
     FunctionDeclaration,
@@ -28,6 +33,8 @@ from google.generativeai import (  # type: ignore
     GenerationConfig,
     GenerativeModel,
     configure,
+    get_file,
+    upload_file,
 )
 from google.generativeai.types import (  # type: ignore
     AsyncGenerateContentResponse,
@@ -45,8 +52,16 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import
-
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
+from inspect_ai._util.kvstore import inspect_kvstore
+from inspect_ai._util.trace import trace_message
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo, ToolParam, ToolParams
 
 from .._chat_message import (
@@ -70,6 +85,8 @@ from .._model_output import (
 )
 from .util import model_base_url
 
+logger = getLogger(__name__)
+
 SAFETY_SETTINGS = "safety_settings"
 
 DEFAULT_SAFETY_SETTINGS: SafetySettingDict = {
@@ -364,19 +381,23 @@ def dict_to_struct(x: dict[str, Any]) -> Struct:
     return struct
 
 
-async def content_part(content: Content | str) ->
+async def content_part(content: Content | str) -> PartType:
     if isinstance(content, str):
         return PartDict(text=content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return PartDict(text=content.text or NO_CONTENT)
     else:
-        return
+        return await chat_content_to_part(content)
 
 
-async def
-
-
-
+async def chat_content_to_part(
+    content: ContentImage | ContentAudio | ContentVideo,
+) -> PartType:
+    if isinstance(content, ContentImage):
+        content_bytes, mime_type = await file_as_data(content.image)
+        return Blob(mime_type=mime_type, data=content_bytes)
+    else:
+        return await file_for_content(content)
 
 
 def prepend_system_messages(
@@ -630,3 +651,53 @@ def str_to_harm_block_threshold(threshold: str) -> HarmBlockThreshold:
         return HarmBlockThreshold.BLOCK_NONE
     else:
         raise ValueError(f"Unknown HarmBlockThreshold: {threshold}")
+
+
+async def file_for_content(content: ContentAudio | ContentVideo) -> File:
+    # helper to write trace messages
+    def trace(message: str) -> None:
+        trace_message(logger, "Google Files", message)
+
+    # get the file bytes and compute sha256 hash
+    if isinstance(content, ContentAudio):
+        file = content.audio
+    else:
+        file = content.video
+    content_bytes, mime_type = await file_as_data(file)
+    content_sha256 = hashlib.sha256(content_bytes).hexdigest()
+
+    # we cache uploads for re-use, open the db where we track that
+    # (track up to 1 million previous uploads)
+    with inspect_kvstore("google_files", 1000000) as files_db:
+        # can we serve from existing uploads?
+        uploaded_file = files_db.get(content_sha256)
+        if uploaded_file:
+            try:
+                upload = cast(File, get_file(uploaded_file))
+                if upload.state.name == "ACTIVE":
+                    trace(f"Using uploaded file: {uploaded_file}")
+                    return upload
+                else:
+                    trace(
+                        f"Not using uploaded file '{uploaded_file} (state was {upload.state})"
+                    )
+            except Exception as ex:
+                trace(f"Error attempting to access uploaded file: {ex}")
+                files_db.delete(content_sha256)
+
+        # do the upload (and record it)
+        upload = upload_file(BytesIO(content_bytes), mime_type=mime_type)
+        while upload.state.name == "PROCESSING":
+            await asyncio.sleep(3)
+            upload = get_file(upload.name)
+
+        if upload.state.name == "FAILED":
+            trace(f"Failed to upload file '{upload.name}: {upload.error}")
+            raise ValueError(f"Google file upload failed: {upload.error}")
+
+        # trace and record it
+        trace(f"Uploaded file: {upload.name}")
+        files_db.put(content_sha256, upload.name)
+
+        # return the file
+        return upload
inspect_ai/model/_providers/groq.py
CHANGED
@@ -23,8 +23,8 @@ from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS
 from inspect_ai._util.content import Content
-from inspect_ai._util.images import
-from inspect_ai._util.url import
+from inspect_ai._util.images import file_as_data_uri
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -248,18 +248,20 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url)
-        image_url = await
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    else:
+        raise RuntimeError("Groq models do not support audio or video inputs.")
 
 
 def chat_tools(tools: List[ToolInfo]) -> List[Dict[str, Any]]:
inspect_ai/model/_providers/hf.py
CHANGED
@@ -239,12 +239,17 @@ class HuggingFaceAPI(ModelAPI):
         hf_messages = inspect_tools_to_string(hf_messages)
 
         # apply chat template
-
-
-
-
-
-
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                hf_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+            )
+        else:
+            chat = ""
+            for message in hf_messages:
+                chat += f"{message.role}: {message.content}\n"
         # return
         return cast(str, chat)
 
inspect_ai/model/_providers/mistral.py
CHANGED
@@ -42,8 +42,7 @@ from inspect_ai._util.constants import (
     DEFAULT_TIMEOUT,
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data_uri
-from inspect_ai._util.url import is_data_uri
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -351,16 +350,14 @@ def mistral_system_message_content(
 async def mistral_content_chunk(content: Content) -> ContentChunk:
     if isinstance(content, ContentText):
         return TextChunk(text=content.text or NO_CONTENT)
-
+    elif isinstance(content, ContentImage):
         # resolve image to url
-        image_url = content.image
-        if not is_data_uri(image_url):
-            image_url = await image_as_data_uri(image_url)
+        image_url = await file_as_data_uri(content.image)
 
         # return chunk
-        return ImageURLChunk(
-
-        )
+        return ImageURLChunk(image_url=ImageURL(url=image_url, detail=content.detail))
+    else:
+        raise RuntimeError("Mistral models do not support audio or video inputs.")
 
 
 def mistral_tool_call(tool_call: ToolCall) -> MistralToolCall:
inspect_ai/model/_providers/openai.py
CHANGED
@@ -17,6 +17,7 @@ from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -463,16 +464,27 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or
         # data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail
 
-        if not is_http_url(image_url)
-        image_url = await
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)
 
         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
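
With this change the OpenAI provider forwards audio input parts (video still raises). A hedged sketch of a user message carrying audio, using the `ContentAudio` fields the diff relies on (`audio` and `format`); the file path is illustrative:

    from inspect_ai.model import ChatMessageUser
    from inspect_ai.tool import ContentAudio, ContentText

    # a user message mixing text and audio ("question.wav" is a placeholder)
    message = ChatMessageUser(
        content=[
            ContentText(text="Please transcribe this recording."),
            ContentAudio(audio="question.wav", format="wav"),
        ]
    )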
inspect_ai/model/_providers/vertex.py
CHANGED
@@ -24,8 +24,14 @@ from vertexai.generative_models import (  # type: ignore
 from vertexai.generative_models import Content as VertexContent
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import
-
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
 
 from .._chat_message import (
@@ -308,9 +314,16 @@ async def content_part(content: Content | str) -> Part:
         return Part.from_text(content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return Part.from_text(content.text or NO_CONTENT)
-
-        image_bytes, mime_type = await
+    elif isinstance(content, ContentImage):
+        image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    else:
+        if isinstance(content, ContentAudio):
+            file = content.audio
+        elif isinstance(content, ContentVideo):
+            file = content.video
+        file_bytes, mime_type = await file_as_data(file)
+        return Part.from_data(file_bytes, mime_type)
 
 
 def prepend_system_messages(
inspect_ai/scorer/__init__.py
CHANGED
@@ -1,3 +1,5 @@
+from inspect_ai._util.deprecation import relocated_module_attribute
+
 from ._answer import AnswerPattern, answer
 from ._choice import choice
 from ._classification import exact, f1
@@ -16,7 +18,7 @@ from ._metric import (
 )
 from ._metrics.accuracy import accuracy
 from ._metrics.mean import mean
-from ._metrics.std import bootstrap_std, std, stderr
+from ._metrics.std import bootstrap_stderr, std, stderr
 from ._model import model_graded_fact, model_graded_qa
 from ._multi import multi_scorer
 from ._pattern import pattern
@@ -50,7 +52,7 @@ __all__ = [
     "Target",
     "scorer",
     "accuracy",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
     "mean",
@@ -76,3 +78,12 @@ __all__ = [
     "at_least",
     "pass_at",
 ]
+_BOOTSTRAP_RENAME_VERSION = "0.3.58"
+_REMOVED_IN = "0.4"
+
+relocated_module_attribute(
+    "bootstrap_std",
+    "inspect_ai.scorer.bootstrap_stderr",
+    _BOOTSTRAP_RENAME_VERSION,
+    _REMOVED_IN,
+)
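
For user code the net effect of this rename: `bootstrap_std` keeps working (with a deprecation warning via `relocated_module_attribute`) until 0.4, after which only the new name remains. A before/after sketch:

    # before (deprecated as of 0.3.58, removed in 0.4):
    #   from inspect_ai.scorer import bootstrap_std
    # after:
    from inspect_ai.scorer import bootstrap_stderr

    metric = bootstrap_stderr(num_samples=1000)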
inspect_ai/scorer/_metrics/__init__.py
CHANGED
@@ -1,12 +1,12 @@
 from .accuracy import accuracy
 from .mean import mean, var
-from .std import bootstrap_std, std, stderr
+from .std import bootstrap_stderr, std, stderr
 
 __all__ = [
     "accuracy",
     "mean",
     "var",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
 ]
inspect_ai/scorer/_metrics/std.py
CHANGED
@@ -15,10 +15,10 @@ logger = getLogger(__name__)
 
 
 @metric
-def bootstrap_std(
+def bootstrap_stderr(
     num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
 ) -> Metric:
-    """Standard
+    """Standard error of the mean using bootstrap.
 
     Args:
       num_samples (int): Number of bootstrap samples to take.
@@ -31,7 +31,7 @@ def bootstrap_std(
       0 if the Value is a complex object (list or dict).
 
     Returns:
-
+      bootstrap_stderr metric
     """
 
     def metric(scores: list[Score]) -> float:
inspect_ai/tool/__init__.py
CHANGED
@@ -1,4 +1,10 @@
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._tool import Tool, ToolError, ToolResult, tool
@@ -30,8 +36,10 @@ __all__ = [
     "ToolError",
     "ToolResult",
     "Content",
+    "ContentAudio",
     "ContentImage",
     "ContentText",
+    "ContentVideo",
     "ToolCall",
     "ToolCallContent",
     "ToolCallView",
inspect_ai/tool/_tool.py
CHANGED
@@ -11,7 +11,12 @@ from typing import (
     runtime_checkable,
 )
 
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.registry import (
     RegistryInfo,
     registry_add,
@@ -31,7 +36,9 @@ ToolResult = (
     | bool
     | ContentText
     | ContentImage
-    |
+    | ContentAudio
+    | ContentVideo
+    | list[ContentText | ContentImage | ContentAudio | ContentVideo]
 )
 
 
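
With the widened `ToolResult` union, a tool can now return audio or video content alongside text and images. A minimal hedged sketch (the `video` field is shown in the diff; the `format` argument is assumed by analogy with `ContentAudio`, and the file path is illustrative):

    from inspect_ai.tool import ContentVideo, ToolResult, tool


    @tool
    def screen_recording():
        async def execute() -> ToolResult:
            """Return the sample's screen recording."""
            return ContentVideo(video="recording.mp4", format="mp4")

        return execute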
inspect_ai/util/__init__.py
CHANGED
@@ -26,7 +26,6 @@ from ._subprocess import (
 )
 from ._subtask import Subtask, subtask
 from ._throttle import throttle
-from ._trace import trace_enabled, trace_panel
 
 __all__ = [
     "ExecResult",
@@ -56,8 +55,6 @@ __all__ = [
     "Subtask",
     "subtask",
     "throttle",
-    "trace_enabled",
-    "trace_panel",
     "trace_action",
     "trace_message",
 ]
inspect_ai/util/{_trace.py → _conversation.py}
RENAMED
@@ -1,5 +1,3 @@
-from contextvars import ContextVar
-
 from rich import print
 from rich.console import RenderableType
 from rich.text import Text
@@ -7,12 +5,7 @@ from rich.text import Text
 from inspect_ai._util.transcript import transcript_panel
 
 
-def trace_enabled() -> bool:
-    """Is trace mode currently enabled."""
-    return _trace.get(None) is True
-
-
-def trace_panel(
+def conversation_panel(
     title: str,
     *,
     subtitle: str | None = None,
@@ -20,8 +13,8 @@ def trace_panel(
 ) -> None:
     """Trace content into a standard trace panel display.
 
-    Typically you would call `
-
+    Typically you would call `display_type() == "conversation"` to confirm that
+    we are in conversation mode before calling `conversation_panel()`.
 
     Args:
       title (str): Panel title.
@@ -32,10 +25,3 @@ def trace_panel(
         transcript_panel(title, subtitle, content),
         Text(),
     )
-
-
-def init_trace(trace: bool | None) -> None:
-    _trace.set(trace)
-
-
-_trace: ContextVar[bool | None] = ContextVar("_trace_mode")
inspect_ai/util/_display.py
CHANGED
@@ -3,10 +3,11 @@ from logging import getLogger
 from typing import Literal
 
 from inspect_ai._util.constants import DEFAULT_DISPLAY
+from inspect_ai._util.thread import is_main_thread
 
 logger = getLogger(__name__)
 
-DisplayType = Literal["full", "rich", "plain", "none"]
+DisplayType = Literal["full", "conversation", "rich", "plain", "none"]
 """Console display type."""
 
 
@@ -15,15 +16,24 @@ _display_type: DisplayType | None = None
 
 def init_display_type(display: str | None = None) -> DisplayType:
     global _display_type
-    global _display_metrics
     display = (
         display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
     )
+
+    # if we are on a background thread then throttle down to "plain"
+    # ("full" requires textual which cannot run in a background thread
+    # b/c it calls the Python signal function; "rich" assumes exclusive
+    # display access which may not be the case for threads)
+    if display in ["full", "rich"] and not is_main_thread():
+        display = "plain"
+
     match display:
-        case "full" | "rich" | "plain" | "none":
+        case "full" | "conversation" | "rich" | "plain" | "none":
             _display_type = display
         case _:
-            logger.warning(
+            logger.warning(
+                f"Unknown display type '{display}' (setting display to 'full')"
+            )
             _display_type = "full"
     return _display_type
 
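
The new `conversation` display type takes over the role of the removed trace mode. A hedged sketch of selecting it via the `INSPECT_DISPLAY` environment variable that `init_display_type()` reads (normally the CLI sets this for you):

    import os

    from inspect_ai.util._display import init_display_type

    # select conversation display via the environment variable
    os.environ["INSPECT_DISPLAY"] = "conversation"
    assert init_display_type() == "conversation"

    # on a background thread, "full" and "rich" are downgraded to "plain"
    display = init_display_type("full")  # "plain" if not on the main thread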
inspect_ai/util/_sandbox/context.py
CHANGED
@@ -4,6 +4,8 @@ from typing import Any, NoReturn, cast
 
 from shortuuid import uuid
 
+from inspect_ai._util.constants import SANDBOX_SETUP_TIMEOUT
+
 from .environment import (
     SampleCleanup,
     SampleInit,
@@ -193,23 +195,20 @@ async def setup_sandbox_environment(
     setup_file = f"/tmp/{uuid()}"
     await env.write_file(setup_file, setup)
 
-    #
-
-
-
-
-
-
-    )
-
+    # execute and then remove setup script (don't retry it on timeout
+    # in case it is not idempotent)
+    try:
+        await env.exec(["chmod", "+x", setup_file], timeout=30)
+        result = await env.exec(
+            ["env", setup_file], timeout=SANDBOX_SETUP_TIMEOUT, timeout_retry=False
+        )
         if not result.success:
             raise RuntimeError(
                 f"Failed to execute setup script for sample: {result.stderr}"
            )
-
-
-
-    await exec(["rm", setup_file])
+        await env.exec(["rm", setup_file], timeout=30)
+    except TimeoutError:
+        raise RuntimeError("Timed out executing setup command in sandbox")
 
 
 def default_sandbox_environment(
inspect_ai/util/_sandbox/docker/compose.py
CHANGED
@@ -25,16 +25,17 @@ COMPOSE_WAIT = "120"
 
 
 async def compose_up(project: ComposeProject) -> None:
-    # Start the environment
-    result = await compose_command(
+    # Start the environment. Note that we don't check the result because docker will
+    # return a non-zero exit code for services that exit (even successfully) when
+    # passing the --wait flag (see https://github.com/docker/compose/issues/10596).
+    # In practice, we will catch any errors when calling compose_check_running()
+    # immediately after we call compose_up().
+    await compose_command(
         ["up", "--detach", "--wait", "--wait-timeout", COMPOSE_WAIT],
         project=project,
         # wait up to 5 minutes for container to go up (compose wait + 3 minutes)
         timeout=300,
     )
-    if not result.success:
-        msg = f"Failed to start docker services for {project.config}: {result.stderr}"
-        raise RuntimeError(msg)
 
 
 async def compose_down(project: ComposeProject, quiet: bool = True) -> None:
@@ -91,14 +92,21 @@ async def compose_cp(
         raise RuntimeError(msg)
 
 
-async def compose_check_running(services: list[str], project: ComposeProject) ->
+async def compose_check_running(
+    services: list[str], project: ComposeProject
+) -> list[str]:
     # Check to ensure that the status of containers is healthy
     running_services = await compose_ps(project=project, status="running")
-
-
+    exited_services = await compose_ps(project=project, status="exited")
+    successful_services = running_services + [
+        service for service in exited_services if service["ExitCode"] == 0
+    ]
+
+    if len(successful_services) > 0:
+        if len(successful_services) != len(services):
             unhealthy_services = services
-            for
-            unhealthy_services.remove(
+            for successful_service in successful_services:
+                unhealthy_services.remove(successful_service["Service"])
 
             msg = (
                 "One or more docker containers failed to start from "
@@ -108,6 +116,8 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
     else:
         raise RuntimeError("No services started")
 
+    return [service["Service"] for service in running_services]
+
 
 async def compose_ps(
     project: ComposeProject,
@@ -166,6 +176,7 @@ async def compose_exec(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     output_limit: int | None = None,
 ) -> ExecResult[str]:
@@ -173,6 +184,7 @@ async def compose_exec(
         ["exec"] + command,
         project=project,
         timeout=timeout,
+        timeout_retry=timeout_retry,
         input=input,
         forward_env=False,
         output_limit=output_limit,
@@ -258,6 +270,7 @@ async def compose_command(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     cwd: str | Path | None = None,
     forward_env: bool = True,
@@ -325,7 +338,7 @@ async def compose_command(
         return await run_command(command_timeout)
     except TimeoutError:
         retries += 1
-        if retries <= MAX_RETRIES:
+        if timeout_retry and (retries <= MAX_RETRIES):
             logger.info(
                 f"Retrying docker compose command: {shlex.join(compose_command)}"
             )
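
The `timeout_retry` flag threads through from `compose_exec()` to `compose_command()`, letting callers opt out of retry-on-timeout for non-idempotent commands (exactly what `setup_sandbox_environment()` above does). A hedged sketch of a call site; the service name, command, and the `ComposeProject` import location are assumptions:

    from inspect_ai.util._sandbox.docker.compose import compose_exec
    from inspect_ai.util._sandbox.docker.util import ComposeProject  # assumed module


    async def run_migration(project: ComposeProject) -> None:
        # non-idempotent command: disable retry-on-timeout so it can't run twice
        result = await compose_exec(
            ["default", "bash", "-c", "./migrate-once.sh"],
            project=project,
            timeout=300,
            timeout_retry=False,
        )
        if not result.success:
            raise RuntimeError(result.stderr)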
|