inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +4 -2
- inspect_ai/_cli/eval.py +2 -0
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +0 -2
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/rich/display.py +4 -4
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/samples.py +41 -5
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/run.py +16 -11
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/run.py +141 -119
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/datetime.py +1 -1
- inspect_ai/_util/deprecation.py +1 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/json.py +11 -1
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/logger.py +2 -1
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_util/trace.py +39 -3
- inspect_ai/_util/transcript.py +36 -7
- inspect_ai/_view/www/.prettierrc.js +12 -0
- inspect_ai/_view/www/dist/assets/index.js +322 -226
- inspect_ai/_view/www/log-schema.json +221 -138
- inspect_ai/_view/www/src/App.mjs +18 -9
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/api/Types.mjs +15 -4
- inspect_ai/_view/www/src/api/api-http.mjs +2 -0
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
- inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
- inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +18 -3
- inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
- inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
- inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
- inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
- inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
- inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
- inspect_ai/_view/www/src/types/log.d.ts +53 -35
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +27 -5
- inspect_ai/log/_recorders/eval.py +21 -8
- inspect_ai/log/_samples.py +10 -5
- inspect_ai/log/_transcript.py +28 -1
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +82 -17
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/{_trace.py → _conversation.py} +9 -8
- inspect_ai/model/_model.py +2 -2
- inspect_ai/model/_providers/anthropic.py +9 -7
- inspect_ai/model/_providers/azureai.py +6 -4
- inspect_ai/model/_providers/bedrock.py +6 -4
- inspect_ai/model/_providers/google.py +103 -14
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +6 -9
- inspect_ai/model/_providers/openai.py +34 -8
- inspect_ai/model/_providers/openai_o1.py +10 -12
- inspect_ai/model/_providers/vertex.py +17 -4
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/tool/__init__.py +9 -1
- inspect_ai/tool/_tool.py +9 -2
- inspect_ai/tool/_tool_info.py +2 -1
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
- inspect_ai/util/__init__.py +4 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -13
- inspect_ai/util/_sandbox/docker/docker.py +20 -13
- inspect_ai/util/_sandbox/docker/util.py +2 -1
- inspect_ai/util/_sandbox/environment.py +13 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- inspect_ai/util/_sandbox/self_check.py +18 -18
- inspect_ai/util/_store.py +2 -2
- inspect_ai/util/_subprocess.py +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/openai.py
CHANGED

@@ -17,6 +17,7 @@ from openai.types.chat import (
     ChatCompletion,
     ChatCompletionAssistantMessageParam,
     ChatCompletionContentPartImageParam,
+    ChatCompletionContentPartInputAudioParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionDeveloperMessageParam,
@@ -36,9 +37,9 @@ from typing_extensions import override
 from inspect_ai._util.constants import DEFAULT_MAX_RETRIES
 from inspect_ai._util.content import Content
 from inspect_ai._util.error import PrerequisiteError
-from inspect_ai._util.images import
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import
+from inspect_ai._util.url import is_http_url
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

 from .._chat_message import ChatMessage, ChatMessageAssistant
@@ -51,6 +52,7 @@ from .._model_output import (
     Logprobs,
     ModelOutput,
     ModelUsage,
+    StopReason,
 )
 from .openai_o1 import generate_o1
 from .util import (
@@ -262,7 +264,10 @@ class OpenAIAPI(ModelAPI):
             model=self.model_name,
         )
         if config.max_tokens is not None:
-
+            if self.is_o1():
+                params["max_completion_tokens"] = config.max_tokens
+            else:
+                params["max_tokens"] = config.max_tokens
         if config.frequency_penalty is not None:
             params["frequency_penalty"] = config.frequency_penalty
         if config.stop_seqs is not None:
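The net effect of this hunk: o1 models reject `max_tokens` and require `max_completion_tokens`, so the provider now branches on the model family when building request params. A minimal standalone sketch of that mapping (`is_o1` here is a plain boolean stand-in for the provider's `self.is_o1()` check):

```python
from typing import Any

def max_tokens_params(is_o1: bool, max_tokens: int | None) -> dict[str, Any]:
    """Map a generation config's max_tokens onto the right OpenAI param."""
    params: dict[str, Any] = {}
    if max_tokens is not None:
        if is_o1:
            # o1 models only accept max_completion_tokens
            params["max_completion_tokens"] = max_tokens
        else:
            params["max_tokens"] = max_tokens
    return params

assert max_tokens_params(True, 1024) == {"max_completion_tokens": 1024}
assert max_tokens_params(False, 1024) == {"max_tokens": 1024}
```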
@@ -303,13 +308,23 @@ class OpenAIAPI(ModelAPI):

     # convert some well known bad request errors into ModelOutput
     def handle_bad_request(self, e: BadRequestError) -> ModelOutput:
-        if e.status_code == 400
+        if e.status_code == 400:
+            # extract message
             if isinstance(e.body, dict) and "message" in e.body.keys():
                 content = str(e.body.get("message"))
             else:
                 content = e.message
+
+            # narrow stop_reason
+            if e.code == "context_length_exceeded":
+                stop_reason: StopReason = "model_length"
+            elif e.code == "invalid_prompt":
+                stop_reason = "content_filter"
+            else:
+                stop_reason = "unknown"
+
             return ModelOutput.from_content(
-                model=self.model_name, content=content, stop_reason=
+                model=self.model_name, content=content, stop_reason=stop_reason
             )
         else:
             raise e
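Both this provider and the o1 handler below now narrow OpenAI error codes to Inspect's `StopReason`. A standalone sketch of the shared narrowing logic (`StopReason` is a `Literal` type in inspect_ai, simplified to `str` here; the code values come straight from the diff):

```python
def narrow_stop_reason(code: str | None) -> str:
    # OpenAI error code -> Inspect StopReason (values per the diff)
    if code == "context_length_exceeded":
        return "model_length"
    elif code == "invalid_prompt":
        return "content_filter"
    else:
        return "unknown"

assert narrow_stop_reason("context_length_exceeded") == "model_length"
assert narrow_stop_reason("invalid_prompt") == "content_filter"
assert narrow_stop_reason(None) == "unknown"
```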
@@ -449,16 +464,27 @@ async def as_chat_completion_part(
 ) -> ChatCompletionContentPartParam:
     if content.type == "text":
         return ChatCompletionContentPartTextParam(type="text", text=content.text)
-
+    elif content.type == "image":
         # API takes URL or base64 encoded file. If it's a remote file or
         # data URL leave it alone, otherwise encode it
         image_url = content.image
         detail = content.detail

-        if not is_http_url(image_url)
-        image_url = await
+        if not is_http_url(image_url):
+            image_url = await file_as_data_uri(image_url)

         return ChatCompletionContentPartImageParam(
             type="image_url",
             image_url=dict(url=image_url, detail=detail),
         )
+    elif content.type == "audio":
+        audio_data = await file_as_data_uri(content.audio)
+
+        return ChatCompletionContentPartInputAudioParam(
+            type="input_audio", input_audio=dict(data=audio_data, format=content.format)
+        )
+
+    else:
+        raise RuntimeError(
+            "Video content is not currently supported by Open AI chat models."
+        )
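For context, the audio branch produces an OpenAI `input_audio` content part. A hedged sketch of the resulting payload shape (values illustrative; the `data` field holds whatever `file_as_data_uri()` returned for a local file):

```python
part = {
    "type": "input_audio",
    "input_audio": {
        "data": "data:audio/wav;base64,UklGR...",  # abbreviated
        "format": "wav",
    },
}
```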
inspect_ai/model/_providers/openai_o1.py
CHANGED

@@ -25,7 +25,7 @@ from inspect_ai.model import (
 from inspect_ai.tool import ToolCall, ToolInfo

 from .._model_call import ModelCall
-from .._model_output import ModelUsage
+from .._model_output import ModelUsage, StopReason
 from .._providers.util import (
     ChatAPIHandler,
     ChatAPIMessage,

@@ -48,12 +48,6 @@ async def generate_o1(
     # create chatapi handler
     handler = O1PreviewChatAPIHandler()

-    # map max_tokens => max_completion_tokens
-    max_tokens = params.get("max_tokens", None)
-    if max_tokens:
-        params["max_completion_tokens"] = max_tokens
-        del params["max_tokens"]
-
     # call model
     request = dict(
         model=model,

@@ -89,12 +83,16 @@


 def handle_bad_request(model: str, ex: BadRequestError) -> ModelOutput:
-    if ex.code == "
-
-
-
+    if ex.code == "context_length_exceeded":
+        stop_reason: StopReason = "model_length"
+    elif ex.code == "invalid_prompt":
+        stop_reason = "content_filter"
     else:
-
+        stop_reason = "unknown"
+
+    return ModelOutput.from_content(
+        model=model, content=str(ex), stop_reason=stop_reason
+    )


 def chat_messages(
inspect_ai/model/_providers/vertex.py
CHANGED

@@ -24,8 +24,14 @@ from vertexai.generative_models import ( # type: ignore
 from vertexai.generative_models import Content as VertexContent

 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import
-
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo

 from .._chat_message import (

@@ -308,9 +314,16 @@ async def content_part(content: Content | str) -> Part:
         return Part.from_text(content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return Part.from_text(content.text or NO_CONTENT)
-
-        image_bytes, mime_type = await
+    elif isinstance(content, ContentImage):
+        image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    else:
+        if isinstance(content, ContentAudio):
+            file = content.audio
+        elif isinstance(content, ContentVideo):
+            file = content.video
+        file_bytes, mime_type = await file_as_data(file)
+        return Part.from_data(file_bytes, mime_type)


 def prepend_system_messages(
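The new `else` branch hands audio and video to Vertex as raw bytes plus a mime type via `Part.from_data`. A small illustration of the mime-type half of that pair using the standard library (a stand-in for whatever detection `file_as_data()` actually performs):

```python
import mimetypes

def guess_mime_type(path: str) -> str:
    # stand-in for the mime type file_as_data() returns alongside the bytes
    mime, _ = mimetypes.guess_type(path)
    return mime or "application/octet-stream"

assert guess_mime_type("clip.mp4") == "video/mp4"
```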
inspect_ai/scorer/__init__.py
CHANGED

@@ -1,3 +1,5 @@
+from inspect_ai._util.deprecation import relocated_module_attribute
+
 from ._answer import AnswerPattern, answer
 from ._choice import choice
 from ._classification import exact, f1

@@ -16,7 +18,7 @@ from ._metric import (
 )
 from ._metrics.accuracy import accuracy
 from ._metrics.mean import mean
-from ._metrics.std import
+from ._metrics.std import bootstrap_stderr, std, stderr
 from ._model import model_graded_fact, model_graded_qa
 from ._multi import multi_scorer
 from ._pattern import pattern

@@ -50,7 +52,7 @@ __all__ = [
     "Target",
     "scorer",
     "accuracy",
-    "
+    "bootstrap_stderr",
     "std",
     "stderr",
     "mean",

@@ -76,3 +78,12 @@ __all__ = [
     "at_least",
     "pass_at",
 ]
+_BOOTSTRAP_RENAME_VERSION = "0.3.58"
+_REMOVED_IN = "0.4"
+
+relocated_module_attribute(
+    "bootstrap_std",
+    "inspect_ai.scorer.bootstrap_stderr",
+    _BOOTSTRAP_RENAME_VERSION,
+    _REMOVED_IN,
+)
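`bootstrap_std` is renamed to `bootstrap_stderr`, with a shim so the old name keeps working until 0.4. A usage sketch (the deprecation warning text itself comes from `relocated_module_attribute`):

```python
# new name: imports cleanly
from inspect_ai.scorer import bootstrap_stderr

# old name: still resolves through the shim, but emits a DeprecationWarning
# and is slated for removal in 0.4
from inspect_ai.scorer import bootstrap_std  # noqa: F401
```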
inspect_ai/scorer/_metrics/__init__.py
CHANGED

@@ -1,12 +1,12 @@
 from .accuracy import accuracy
 from .mean import mean, var
-from .std import
+from .std import bootstrap_stderr, std, stderr

 __all__ = [
     "accuracy",
     "mean",
     "var",
-    "
+    "bootstrap_stderr",
     "std",
     "stderr",
 ]
inspect_ai/scorer/_metrics/std.py
CHANGED

@@ -15,10 +15,10 @@ logger = getLogger(__name__)


 @metric
-def
+def bootstrap_stderr(
     num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
 ) -> Metric:
-    """Standard
+    """Standard error of the mean using bootstrap.

     Args:
        num_samples (int): Number of bootstrap samples to take.

@@ -31,7 +31,7 @@ def bootstrap_std(
        0 if the Value is a complex object (list or dict).

     Returns:
-
+       bootstrap_stderr metric
     """

     def metric(scores: list[Score]) -> float:
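For readers unfamiliar with the metric: the bootstrap standard error is the standard deviation of sample means computed over resampled datasets. A self-contained illustration (not the inspect_ai implementation, which operates on `Score` objects):

```python
import random
import statistics

def bootstrap_stderr_demo(values: list[float], num_samples: int = 1000) -> float:
    rng = random.Random(0)  # seeded so the example is reproducible
    means = [
        statistics.mean(rng.choices(values, k=len(values)))
        for _ in range(num_samples)
    ]
    # stderr of the mean = std dev of the bootstrap distribution of means
    return statistics.stdev(means)

print(bootstrap_stderr_demo([0.0, 1.0, 1.0, 0.0, 1.0]))  # ~0.2
```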
inspect_ai/tool/__init__.py
CHANGED

@@ -1,4 +1,10 @@
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute

 from ._tool import Tool, ToolError, ToolResult, tool

@@ -30,8 +36,10 @@ __all__ = [
     "ToolError",
     "ToolResult",
     "Content",
+    "ContentAudio",
     "ContentImage",
     "ContentText",
+    "ContentVideo",
     "ToolCall",
     "ToolCallContent",
     "ToolCallView",
inspect_ai/tool/_tool.py
CHANGED

@@ -11,7 +11,12 @@ from typing import (
     runtime_checkable,
 )

-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.registry import (
     RegistryInfo,
     registry_add,

@@ -31,7 +36,9 @@ ToolResult = (
     | bool
     | ContentText
     | ContentImage
-    |
+    | ContentAudio
+    | ContentVideo
+    | list[ContentText | ContentImage | ContentAudio | ContentVideo]
 )

inspect_ai/tool/_tool_info.py
CHANGED

@@ -8,6 +8,7 @@ from typing import (
     Dict,
     List,
     Optional,
+    Tuple,
     Type,
     Union,
     get_args,

@@ -155,7 +156,7 @@ def parse_type(type_hint: Type[Any]) -> ToolParam:
             return ToolParam(type="null")
         else:
             return ToolParam()
-    elif origin is list or origin is List:
+    elif origin is list or origin is List or origin is tuple or origin is Tuple:
         return ToolParam(
             type="array", items=parse_type(args[0]) if args else ToolParam()
         )
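With `Tuple` handled, tuple-typed tool parameters are described as JSON Schema arrays, the same as lists. A hedged example of a tool signature the schema generator should now accept (hypothetical tool, not from the package):

```python
from inspect_ai.tool import tool

@tool
def word_count():
    async def execute(words: tuple[str, ...]) -> int:
        """Count words.

        Args:
            words: Words to count.
        """
        # tuple[str, ...] is now advertised as
        # {"type": "array", "items": {"type": "string"}}
        return len(words)

    return execute
```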
inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py
CHANGED

@@ -38,9 +38,9 @@ class EnvironmentSpec:
         for i, obs_spec in enumerate(env_obs_spec.values()):
             self.observation_spec[i + 1] = convert(obs_spec)

-        assert isinstance(
-
-        )
+        assert isinstance(env.action_spec(), specs.Array), (
+            "Only a single action type is supported."
+        )
         self.action_spec = {1: convert(env.action_spec())}

         self.observation_manager = spec_manager.SpecManager(self.observation_spec)

@@ -234,12 +234,12 @@ class EnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer):
         observations.
         """
         with self._lock:
-            assert (
-
-            )
-            assert (
-
-            )
+            assert cur_world in self._envs, (
+                "Current world does not have an assosiated environment"
+            )
+            assert cur_world in self._joined_worlds, (
+                "Please join world before calling step."
+            )
             env = self._envs[cur_world]
             spec = self._specs[cur_world]

inspect_ai/tool/_tools/_web_browser/_web_browser.py
CHANGED

@@ -372,7 +372,9 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
         )
     else:
         response = parse_web_browser_output(result.stdout)
-        if "
+        if "error" in response and response.get("error", "").strip() != "":
+            raise ToolError(str(response.get("error")) or "(unknown error)")
+        elif "web_at" in response:
             web_at = (
                 str(response.get("web_at")) or "(no web accessiblity tree available)"
             )

@@ -384,8 +386,6 @@ async def web_browser_cmd(cmd: str, *args: str) -> str:
             web_at = "\n".join(web_at_lines)
             store_as(WebBrowserStore).web_at = web_at
             return web_at
-        elif "error" in response:
-            raise ToolError(str(response.get("error")) or "(unknown error)")
         else:
             raise RuntimeError(
                 f"web_browser output must contain either 'error' or 'web_at' field: {result.stdout}"
inspect_ai/util/__init__.py
CHANGED

@@ -1,3 +1,5 @@
+from inspect_ai._util.trace import trace_action, trace_message
+
 from ._concurrency import concurrency
 from ._console import input_screen
 from ._display import DisplayType, display_type

@@ -24,7 +26,6 @@ from ._subprocess import (
 )
 from ._subtask import Subtask, subtask
 from ._throttle import throttle
-from ._trace import trace_enabled, trace_panel

 __all__ = [
     "ExecResult",

@@ -54,6 +55,6 @@ __all__ = [
     "Subtask",
     "subtask",
     "throttle",
-    "
-    "
+    "trace_action",
+    "trace_message",
 ]
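`trace_action` and `trace_message` replace the removed `trace_enabled`/`trace_panel` exports as the public tracing API. A hedged usage sketch based on Inspect's trace-logging pattern (`my_operation` is a hypothetical placeholder):

```python
from logging import getLogger

from inspect_ai.util import trace_action, trace_message

logger = getLogger(__name__)

def my_operation() -> None:
    pass  # hypothetical work

# trace_action records enter/exit (including timing and errors) for an action
with trace_action(logger, "MyAction", "running my_operation"):
    my_operation()

# trace_message records a one-off message under a category
trace_message(logger, "MyCategory", "operation complete")
```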
inspect_ai/util/{_trace.py → _conversation.py}
RENAMED

@@ -1,5 +1,3 @@
-from contextvars import ContextVar
-
 from rich import print
 from rich.console import RenderableType
 from rich.text import Text

@@ -7,12 +5,7 @@ from rich.text import Text
 from inspect_ai._util.transcript import transcript_panel


-def
-    """Is trace mode currently enabled."""
-    return _trace.get(None) is True
-
-
-def trace_panel(
+def conversation_panel(
     title: str,
     *,
     subtitle: str | None = None,

@@ -20,8 +13,8 @@ def trace_panel(
 ) -> None:
     """Trace content into a standard trace panel display.

-    Typically you would call `
-
+    Typically you would call `display_type() == "conversation"` to confirm that
+    we are in conversation mode before calling `conversation_panel()`.

     Args:
       title (str): Panel title.

@@ -32,10 +25,3 @@ def trace_panel(
         transcript_panel(title, subtitle, content),
         Text(),
     )
-
-
-def init_trace(trace: bool | None) -> None:
-    _trace.set(trace)
-
-
-_trace: ContextVar[bool | None] = ContextVar("_trace_mode")
inspect_ai/util/_display.py
CHANGED

@@ -3,10 +3,11 @@ from logging import getLogger
 from typing import Literal

 from inspect_ai._util.constants import DEFAULT_DISPLAY
+from inspect_ai._util.thread import is_main_thread

 logger = getLogger(__name__)

-DisplayType = Literal["full", "rich", "plain", "none"]
+DisplayType = Literal["full", "conversation", "rich", "plain", "none"]
 """Console display type."""


@@ -15,15 +16,24 @@ _display_type: DisplayType | None = None

 def init_display_type(display: str | None = None) -> DisplayType:
     global _display_type
-    global _display_metrics
     display = (
         display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip()
     )
+
+    # if we are on a background thread then throttle down to "plain"
+    # ("full" requires textual which cannot run in a background thread
+    # b/c it calls the Python signal function; "rich" assumes exclusive
+    # display access which may not be the case for threads)
+    if display in ["full", "rich"] and not is_main_thread():
+        display = "plain"
+
     match display:
-        case "full" | "rich" | "plain" | "none":
+        case "full" | "conversation" | "rich" | "plain" | "none":
             _display_type = display
         case _:
-            logger.warning(
+            logger.warning(
+                f"Unknown display type '{display}' (setting display to 'full')"
+            )
             _display_type = "full"
     return _display_type

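The `is_main_thread` helper comes from the new `_util/thread.py` (+5 lines in the file list above); presumably the obvious `threading` check, sketched here as an assumption:

```python
import threading

def is_main_thread() -> bool:
    # True when called from the thread that started the interpreter
    return threading.current_thread() is threading.main_thread()
```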
inspect_ai/util/_sandbox/context.py
CHANGED

@@ -4,6 +4,8 @@ from typing import Any, NoReturn, cast

 from shortuuid import uuid

+from inspect_ai._util.constants import SANDBOX_SETUP_TIMEOUT
+
 from .environment import (
     SampleCleanup,
     SampleInit,

@@ -193,23 +195,20 @@ async def setup_sandbox_environment(
     setup_file = f"/tmp/{uuid()}"
     await env.write_file(setup_file, setup)

-    #
-
-
-
-
-
-
-    )
-
+    # execute and then remove setup script (don't retry it on timeout
+    # in case it is not idempotent)
+    try:
+        await env.exec(["chmod", "+x", setup_file], timeout=30)
+        result = await env.exec(
+            ["env", setup_file], timeout=SANDBOX_SETUP_TIMEOUT, timeout_retry=False
+        )
         if not result.success:
             raise RuntimeError(
                 f"Failed to execute setup script for sample: {result.stderr}"
             )
-
-
-
-    await exec(["rm", setup_file])
+        await env.exec(["rm", setup_file], timeout=30)
+    except TimeoutError:
+        raise RuntimeError("Timed out executing setup command in sandbox")


 def default_sandbox_environment(
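Two changes interact here: the setup script now runs under `SANDBOX_SETUP_TIMEOUT`, and that exec opts out of the timeout retry added to `compose.py` below, since a timed-out setup script may not be idempotent. A simplified, self-contained sketch of a retry loop with such an opt-out (requires Python 3.11+, where `asyncio.TimeoutError` is the builtin `TimeoutError`; names are illustrative, not the package's implementation):

```python
import asyncio
from typing import Awaitable, Callable

async def run_with_retry(
    cmd: Callable[[], Awaitable[str]],
    timeout: float,
    timeout_retry: bool = True,
    max_retries: int = 2,
) -> str:
    # mirrors the shape of compose_command's retry loop (simplified)
    retries = 0
    while True:
        try:
            return await asyncio.wait_for(cmd(), timeout)
        except TimeoutError:
            retries += 1
            if timeout_retry and retries <= max_retries:
                continue  # idempotent command: safe to re-run
            raise  # timeout_retry=False or retries exhausted

async def main() -> None:
    async def slow() -> str:
        await asyncio.sleep(10)
        return "done"

    try:
        await run_with_retry(slow, timeout=0.1, timeout_retry=False)
    except TimeoutError:
        print("timed out without retrying")

asyncio.run(main())
```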
inspect_ai/util/_sandbox/docker/compose.py
CHANGED

@@ -25,18 +25,17 @@ COMPOSE_WAIT = "120"


 async def compose_up(project: ComposeProject) -> None:
-    # Start the environment
-
+    # Start the environment. Note that we don't check the result because docker will
+    # return a non-zero exit code for services that exit (even successfully) when
+    # passing the --wait flag (see https://github.com/docker/compose/issues/10596).
+    # In practice, we will catch any errors when calling compose_check_running()
+    # immediately after we call compose_up().
+    await compose_command(
         ["up", "--detach", "--wait", "--wait-timeout", COMPOSE_WAIT],
         project=project,
         # wait up to 5 minutes for container to go up (compose wait + 3 minutes)
         timeout=300,
     )
-    if not result.success:
-        msg = (
-            f"Failed to start docker services for {project.config}: " f"{result.stderr}"
-        )
-        raise RuntimeError(msg)


 async def compose_down(project: ComposeProject, quiet: bool = True) -> None:

@@ -93,14 +92,21 @@ async def compose_cp(
         raise RuntimeError(msg)


-async def compose_check_running(
+async def compose_check_running(
+    services: list[str], project: ComposeProject
+) -> list[str]:
     # Check to ensure that the status of containers is healthy
     running_services = await compose_ps(project=project, status="running")
-
-
+    exited_services = await compose_ps(project=project, status="exited")
+    successful_services = running_services + [
+        service for service in exited_services if service["ExitCode"] == 0
+    ]
+
+    if len(successful_services) > 0:
+        if len(successful_services) != len(services):
             unhealthy_services = services
-            for
-            unhealthy_services.remove(
+            for successful_service in successful_services:
+                unhealthy_services.remove(successful_service["Service"])

             msg = (
                 "One or more docker containers failed to start from "

@@ -110,6 +116,8 @@ async def compose_check_running(services: list[str], project: ComposeProject) ->
     else:
         raise RuntimeError("No services started")

+    return [service["Service"] for service in running_services]
+

 async def compose_ps(
     project: ComposeProject,

@@ -168,6 +176,7 @@ async def compose_exec(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     output_limit: int | None = None,
 ) -> ExecResult[str]:

@@ -175,6 +184,7 @@ async def compose_exec(
         ["exec"] + command,
         project=project,
         timeout=timeout,
+        timeout_retry=timeout_retry,
         input=input,
         forward_env=False,
         output_limit=output_limit,

@@ -260,6 +270,7 @@ async def compose_command(
     *,
     project: ComposeProject,
     timeout: int | None,
+    timeout_retry: bool = True,
     input: str | bytes | None = None,
     cwd: str | Path | None = None,
     forward_env: bool = True,

@@ -327,7 +338,7 @@ async def compose_command(
             return await run_command(command_timeout)
         except TimeoutError:
             retries += 1
-            if retries <= MAX_RETRIES:
+            if timeout_retry and (retries <= MAX_RETRIES):
                 logger.info(
                     f"Retrying docker compose command: {shlex.join(compose_command)}"
                 )
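`compose_check_running()` now counts services that exited cleanly (exit code 0) as successful, which pairs with `compose_up()` no longer checking its own exit status. A standalone illustration of the new success filter over `docker compose ps`-style records (field names `Service`/`ExitCode` taken from the diff):

```python
running = [{"Service": "web", "ExitCode": 0}]
exited = [
    {"Service": "init-job", "ExitCode": 0},  # ran to completion: success
    {"Service": "worker", "ExitCode": 1},    # crashed: stays unhealthy
]

successful = running + [s for s in exited if s["ExitCode"] == 0]
successful_names = [s["Service"] for s in successful]
unhealthy = [n for n in ["web", "init-job", "worker"] if n not in successful_names]

assert successful_names == ["web", "init-job"]
assert unhealthy == ["worker"]
```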
|