inspect-ai 0.3.56__py3-none-any.whl → 0.3.58__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +4 -2
- inspect_ai/_cli/eval.py +2 -0
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +0 -2
- inspect_ai/_display/core/panel.py +1 -1
- inspect_ai/_display/rich/display.py +4 -4
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/samples.py +41 -5
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/run.py +16 -11
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/run.py +141 -119
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/datetime.py +1 -1
- inspect_ai/_util/deprecation.py +1 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/json.py +11 -1
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/logger.py +2 -1
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_util/trace.py +39 -3
- inspect_ai/_util/transcript.py +36 -7
- inspect_ai/_view/www/.prettierrc.js +12 -0
- inspect_ai/_view/www/dist/assets/index.js +322 -226
- inspect_ai/_view/www/log-schema.json +221 -138
- inspect_ai/_view/www/src/App.mjs +18 -9
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/api/Types.mjs +15 -4
- inspect_ai/_view/www/src/api/api-http.mjs +2 -0
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +2 -2
- inspect_ai/_view/www/src/components/FindBand.mjs +5 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +1 -1
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +44 -2
- inspect_ai/_view/www/src/components/TabSet.mjs +1 -1
- inspect_ai/_view/www/src/components/Tools.mjs +18 -3
- inspect_ai/_view/www/src/components/VirtualList.mjs +15 -17
- inspect_ai/_view/www/src/log/remoteLogFile.mjs +2 -1
- inspect_ai/_view/www/src/navbar/Navbar.mjs +44 -32
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +1 -2
- inspect_ai/_view/www/src/samples/SampleList.mjs +35 -4
- inspect_ai/_view/www/src/samples/SampleScoreView.mjs +13 -2
- inspect_ai/_view/www/src/samples/SampleScores.mjs +11 -2
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +242 -178
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -2
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +5 -5
- inspect_ai/_view/www/src/samples/tools/SelectScorer.mjs +7 -0
- inspect_ai/_view/www/src/samples/tools/SortFilter.mjs +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +1 -1
- inspect_ai/_view/www/src/types/log.d.ts +53 -35
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +1 -1
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +27 -5
- inspect_ai/log/_recorders/eval.py +21 -8
- inspect_ai/log/_samples.py +10 -5
- inspect_ai/log/_transcript.py +28 -1
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +82 -17
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/{_trace.py → _conversation.py} +9 -8
- inspect_ai/model/_model.py +2 -2
- inspect_ai/model/_providers/anthropic.py +9 -7
- inspect_ai/model/_providers/azureai.py +6 -4
- inspect_ai/model/_providers/bedrock.py +6 -4
- inspect_ai/model/_providers/google.py +103 -14
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +6 -9
- inspect_ai/model/_providers/openai.py +34 -8
- inspect_ai/model/_providers/openai_o1.py +10 -12
- inspect_ai/model/_providers/vertex.py +17 -4
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/tool/__init__.py +9 -1
- inspect_ai/tool/_tool.py +9 -2
- inspect_ai/tool/_tool_info.py +2 -1
- inspect_ai/tool/_tools/_web_browser/_resources/dm_env_servicer.py +9 -9
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +3 -3
- inspect_ai/util/__init__.py +4 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -13
- inspect_ai/util/_sandbox/docker/docker.py +20 -13
- inspect_ai/util/_sandbox/docker/util.py +2 -1
- inspect_ai/util/_sandbox/environment.py +13 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- inspect_ai/util/_sandbox/self_check.py +18 -18
- inspect_ai/util/_store.py +2 -2
- inspect_ai/util/_subprocess.py +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +107 -103
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.56.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/_eval/task/images.py
CHANGED
@@ -1,66 +1,69 @@
 import asyncio
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.…
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentVideo
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.url import is_data_uri
 from inspect_ai.dataset import Sample
-from inspect_ai.model import ChatMessage, ChatMessageUser…
+from inspect_ai.model import ChatMessage, ChatMessageUser
 from inspect_ai.solver import TaskState
 
 
-async def …
-    return await asyncio.gather(*[…
+async def states_with_base64_content(states: list[TaskState]) -> list[TaskState]:
+    return await asyncio.gather(*[state_with_base64_content(state) for state in states])
 
 
-async def …
-    state.messages = await …
+async def state_with_base64_content(state: TaskState) -> TaskState:
+    state.messages = await messages_with_base64_content(state.messages)
     return state
 
 
-def …
-    state.messages = …
+def state_without_base64_content(state: TaskState) -> TaskState:
+    state.messages = messages_without_base64_content(state.messages)
     return state
 
 
-async def …
+async def samples_with_base64_content(samples: list[Sample]) -> list[Sample]:
     return await asyncio.gather(
-        *[…
+        *[sample_with_base64_content(sample) for sample in samples]
     )
 
 
-async def …
+async def sample_with_base64_content(sample: Sample) -> Sample:
     if isinstance(sample.input, list):
         return sample.model_copy(
-            update={"input": await …
+            update={"input": await messages_with_base64_content(sample.input)}
         )
     else:
         return sample
 
 
-def …
+def sample_without_base64_content(sample: Sample) -> Sample:
     if isinstance(sample.input, list):
         return sample.model_copy(
-            update={"input": …
+            update={"input": messages_without_base64_content(sample.input)}
        )
     else:
         return sample
 
 
-async def …
+async def messages_with_base64_content(
+    messages: list[ChatMessage],
+) -> list[ChatMessage]:
     return await asyncio.gather(
-        *[…
+        *[message_with_base64_content(message) for message in messages]
     )
 
 
-def …
-    return […
+def messages_without_base64_content(messages: list[ChatMessage]) -> list[ChatMessage]:
+    return [message_without_base64_content(message) for message in messages]
 
 
-async def …
+async def message_with_base64_content(message: ChatMessage) -> ChatMessage:
     if isinstance(message, ChatMessageUser) and not isinstance(message.content, str):
         return ChatMessageUser(
             content=[
-                await …
+                await chat_content_with_base64_content(content)
                 for content in message.content
             ],
             source=message.source,
@@ -69,11 +72,11 @@ async def message_with_base64_image(message: ChatMessage) -> ChatMessage:
         return message
 
 
-def …
+def message_without_base64_content(message: ChatMessage) -> ChatMessage:
     if isinstance(message, ChatMessageUser) and not isinstance(message.content, str):
         return ChatMessageUser(
             content=[
-                …
+                chat_content_without_base64_content(content)
                 for content in message.content
             ],
             source=message.source,
@@ -82,18 +85,30 @@ def message_without_base64_image(message: ChatMessage) -> ChatMessage:
         return message
 
 
-async def …
+async def chat_content_with_base64_content(content: Content) -> Content:
     if isinstance(content, ContentImage):
         return ContentImage(
-            image=await …
+            image=await file_as_data_uri(content.image),
             detail=content.detail,
         )
+    elif isinstance(content, ContentAudio):
+        return ContentAudio(
+            audio=await file_as_data_uri(content.audio), format=content.format
+        )
+    elif isinstance(content, ContentVideo):
+        return ContentVideo(
+            video=await file_as_data_uri(content.video), format=content.format
+        )
     else:
         return content
 
 
-def …
+def chat_content_without_base64_content(content: Content) -> Content:
     if isinstance(content, ContentImage) and is_data_uri(content.image):
         return ContentImage(image=BASE_64_DATA_REMOVED, detail=content.detail)
+    elif isinstance(content, ContentAudio) and is_data_uri(content.audio):
+        return ContentAudio(audio=BASE_64_DATA_REMOVED, format="mp3")
+    elif isinstance(content, ContentVideo) and is_data_uri(content.video):
+        return ContentVideo(video=BASE_64_DATA_REMOVED, format="mp4")
     else:
         return content
inspect_ai/_eval/task/run.py
CHANGED
@@ -4,6 +4,7 @@ import sys
 import time
 from copy import deepcopy
 from dataclasses import dataclass, field
+from datetime import datetime
 from logging import getLogger
 from pathlib import PurePath
 from typing import Callable, Literal
@@ -71,6 +72,7 @@ from inspect_ai.solver._chain import Chain, unroll
 from inspect_ai.solver._fork import set_task_generate
 from inspect_ai.solver._solver import Solver
 from inspect_ai.solver._task_state import sample_state, set_sample_state, state_jsonable
+from inspect_ai.util._sandbox.context import sandbox_connections
 from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
 from inspect_ai.util._subtask import init_subtask
 
@@ -79,10 +81,10 @@ from ..task import Task
 from .error import SampleErrorHandler
 from .generate import task_generate
 from .images import (
-    …
+    sample_without_base64_content,
+    samples_with_base64_content,
+    state_without_base64_content,
+    states_with_base64_content,
 )
 from .log import TaskLogger, collect_eval_data, log_start
 from .results import eval_results
@@ -533,11 +535,6 @@ async def task_run_sample(
         else contextlib.nullcontext()
     )
 
-    # use timeout if provided
-    timeout_cm = (
-        timeout(time_limit) if time_limit is not None else contextlib.nullcontext()
-    )
-
     # helper to handle exceptions (will throw if we've exceeded the limit)
     def handle_error(ex: BaseException) -> EvalError:
         err = sample_error(ex)
@@ -547,7 +544,6 @@ async def task_run_sample(
     # solver loop
     async with (
         semaphore_cm,
-        sandboxenv_cm,
         active_sample(
             task=task_name,
             model=str(state.model),
@@ -561,125 +557,151 @@ async def task_run_sample(
         ) as active,
     ):
         error: EvalError | None = None
+        results: dict[str, SampleScore] = {}
         try:
-            async with …
+            async with sandboxenv_cm:
+                try:
+                    # update active sample with sandboxes now that we are initialised
+                    active.sandboxes = await sandbox_connections()
+
+                    # initialise timeout context manager
+                    timeout_cm = (
+                        timeout(time_limit)
+                        if time_limit is not None
+                        else contextlib.nullcontext()
+                    )
 
-…
+                    # run sample w/ optional timeout
+                    async with timeout_cm:
+                        # mark started
+                        active.started = datetime.now().timestamp()
 
-…
-                    "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
-                )
+                        # sample init event (remove file bodies as they have content or absolute paths)
+                        event_sample = sample.model_copy(
+                            update=dict(files={k: "" for k in sample.files.keys()})
+                            if sample.files
+                            else None
+                        )
+                        transcript()._event(
+                            SampleInitEvent(
+                                sample=event_sample, state=state_jsonable(state)
+                            )
+                        )
 
-…
+                        # set progress for plan then run it
+                        state = await plan(state, generate)
 
-…
+                except TimeoutError:
+                    if time_limit is not None:
+                        transcript()._event(
+                            SampleLimitEvent(
+                                type="time",
+                                message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)",
+                                limit=time_limit,
+                            )
+                        )
+                    else:
+                        py_logger.warning(
+                            "Unexpected timeout error reached top of sample stack. Are you handling TimeoutError when applying timeouts?"
+                        )
 
-…
-                    case "score":
-                        # continue to scoring (capture the most recent state)
-                        state = sample_state() or state
-                    case "error":
-                        # default error handling
-                        error = handle_error(ex)
+                    # capture most recent state for scoring
+                    state = sample_state() or state
 
-…
+                except asyncio.CancelledError as ex:
+                    if active.interrupt_action:
+                        # record event
+                        transcript()._event(
+                            SampleLimitEvent(
+                                type="operator",
+                                message="Sample completed: interrupted by operator",
+                            )
+                        )
 
-…
+                        # handle the action
+                        match active.interrupt_action:
+                            case "score":
+                                # continue to scoring (capture the most recent state)
+                                state = sample_state() or state
+                            case "error":
+                                # default error handling
+                                error = handle_error(ex)
+
+                    else:
+                        raise
+
+                except BaseException as ex:
+                    error = handle_error(ex)
+
+                # set timeout for scoring. if the original timeout was never hit
+                # then just create a new timeout_cm targeting the original
+                # timeout time. if the original timeout was hit we still want
+                # to provide an opportunity for scoring, but we don't necessarily
+                # want to wait the full timeout again (especially in the case where
+                # the cause of the timeout is a hung container and scoring requires
+                # interacting with the container). as a middle ground we use half
+                # of the original timeout value for scoring.
+                if isinstance(timeout_cm, Timeout):
+                    if not timeout_cm.expired():
+                        timeout_cm = timeout_at(timeout_cm.when())
+                    else:
+                        assert time_limit
+                        timeout_cm = timeout(time_limit / 2)
+
+                # scoring
+                try:
+                    # timeout during scoring will result in an ordinary sample error
+                    async with timeout_cm:
+                        if scorers and error is None:
+                            for scorer in scorers:
+                                scorer_name = unique_scorer_name(
+                                    scorer, list(results.keys())
+                                )
+                                with transcript().step(name=scorer_name, type="scorer"):
+                                    score_result = (
+                                        await scorer(state, Target(sample.target))
+                                        if scorer
+                                        else None
+                                    )
+                                    if score_result is not None:
+                                        sample_score = SampleScore(
+                                            score=score_result,
+                                            sample_id=sample.id,
+                                        )
+                                        transcript()._event(
+                                            ScoreEvent(
+                                                score=score_result, target=sample.target
+                                            )
+                                        )
+                                        results[scorer_name] = sample_score
+
+                except asyncio.CancelledError:
+                    if active.interrupt_action:
+                        transcript()._event(
+                            SampleLimitEvent(
+                                type="operator",
+                                message="Unable to score sample due to operator interruption",
+                            )
+                        )
 
-…
-        # then just create a new timeout_cm targeting the original
-        # timeout time. if the original timeout was hit we still want
-        # to provide an opportunity for scoring, but we don't necessarily
-        # want to wait the full timeout again (especially in the case where
-        # the cause of the timeout is a hung container and scoring requires
-        # interacting with the container). as a middle ground we use half
-        # of the original timeout value for scoring.
-        if isinstance(timeout_cm, Timeout):
-            if not timeout_cm.expired():
-                timeout_cm = timeout_at(timeout_cm.when())
-            else:
-                assert time_limit
-                timeout_cm = timeout(time_limit / 2)
+                    raise
 
-…
-                        with transcript().step(name=scorer_name, type="scorer"):
-                            score_result = (
-                                await scorer(state, Target(sample.target))
-                                if scorer
-                                else None
-                            )
-
-                            sample_score = SampleScore(
-                                score=score_result,
-                                sample_id=sample.id,
-                            )
-                            transcript()._event(
-                                ScoreEvent(score=score_result, target=sample.target)
-                            )
-                            results[scorer_name] = sample_score
-
-        except asyncio.CancelledError:
-            if active.interrupt_action:
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="operator",
-                        message="Unable to score sample due to operator interruption",
-                    )
-                )
-…
+                except BaseException as ex:
+                    # note timeout
+                    if isinstance(ex, TimeoutError):
+                        transcript()._event(
+                            SampleLimitEvent(
+                                type="time",
+                                message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)",
+                                limit=time_limit,
+                            )
+                        )
+
+                    # handle error (this will throw if we've exceeded the limit)
+                    error = handle_error(ex)
 
+        # handle sandboxenv init errors
         except BaseException as ex:
-            # note timeout
-            if isinstance(ex, TimeoutError):
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="time",
-                        message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)",
-                        limit=time_limit,
-                    )
-                )
-
-            # handle error (this will throw if we've exceeded the limit)
             error = handle_error(ex)
 
         # complete the sample
@@ -689,12 +711,12 @@ async def task_run_sample(
     if logger is not None:
         # if we are logging images then be sure to base64 images injected by solvers
        if log_images:
-            state = (await …
+            state = (await states_with_base64_content([state]))[0]
 
         # otherwise ensure there are no base64 images in sample or messages
         else:
-            sample = …
-            state = …
+            sample = sample_without_base64_content(sample)
+            state = state_without_base64_content(state)
 
         # log the sample
         await log_sample(
@@ -784,7 +806,7 @@ async def resolve_dataset(
 
     # if we are logging images then resolve sample images here
     if log_images:
-        samples = await …
+        samples = await samples_with_base64_content(samples)
 
     # prime the eval tasks (deep copy so they share no state w/ sample)
     sample_epochs: list[int] = []
inspect_ai/_eval/task/task.py
CHANGED
@@ -1,3 +1,4 @@
+from copy import deepcopy
 from dataclasses import dataclass
 from logging import getLogger
 from typing import Any, Callable, Sequence, cast
@@ -6,6 +7,7 @@ from pydantic import BaseModel
 from typing_extensions import TypedDict, Unpack
 
 from inspect_ai._util.logger import warn_once
+from inspect_ai._util.notgiven import NOT_GIVEN, NotGiven
 from inspect_ai._util.registry import is_registry_object, registry_info
 from inspect_ai.approval._policy import ApprovalPolicy, approval_policies_from_config
 from inspect_ai.dataset import Dataset, MemoryDataset, Sample
@@ -115,35 +117,15 @@ class Task:
                 f"DEPRECATED: the '{arg}' parameter is deprecated (please use the '{newarg}' parameter instead)",
             )
 
-…
-        if isinstance(epochs, int):
-            epochs = Epochs(epochs)
-        if epochs is not None and epochs.epochs < 1:
-            raise ValueError("epochs must be a positive integer.")
-
-        # resolve dataset (provide empty sample to bootstrap tasks w/o samples,
-        # which could occur for testing or for an interactive mode eval)
-        dataset = dataset or [Sample(input="prompt")]
-        self.dataset: Dataset = (
-            dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
-        )
+        self.dataset = resolve_dataset(dataset)
         self.setup = setup
-        self.solver = …
-        self.scorer = (
-            scorer
-            if isinstance(scorer, list)
-            else [scorer]
-            if scorer is not None
-            else None
-        )
+        self.solver = resolve_solver(solver)
+        self.scorer = resolve_scorer(scorer)
         self.metrics = metrics
         self.config = config
         self.sandbox = resolve_sandbox_environment(sandbox)
-        self.approval = (
-            …
-            if isinstance(approval, str)
-            else approval
-        )
+        self.approval = resolve_approval(approval)
+        epochs = resolve_epochs(epochs)
         self.epochs = epochs.epochs if epochs else None
         self.epochs_reducer = epochs.reducer if epochs else None
         self.fail_on_error = fail_on_error
@@ -171,6 +153,106 @@ class Task:
         return dict()
 
 
+def task_with(
+    task: Task,
+    *,
+    dataset: Dataset | Sequence[Sample] | None | NotGiven = NOT_GIVEN,
+    setup: Solver | list[Solver] | None | NotGiven = NOT_GIVEN,
+    solver: Solver | list[Solver] | NotGiven = NOT_GIVEN,
+    scorer: Scorer | list[Scorer] | None | NotGiven = NOT_GIVEN,
+    metrics: list[Metric] | dict[str, list[Metric]] | None | NotGiven = NOT_GIVEN,
+    config: GenerateConfig | NotGiven = NOT_GIVEN,
+    sandbox: SandboxEnvironmentType | None | NotGiven = NOT_GIVEN,
+    approval: str | list[ApprovalPolicy] | None | NotGiven = NOT_GIVEN,
+    epochs: int | Epochs | None | NotGiven = NOT_GIVEN,
+    fail_on_error: bool | float | None | NotGiven = NOT_GIVEN,
+    message_limit: int | None | NotGiven = NOT_GIVEN,
+    token_limit: int | None | NotGiven = NOT_GIVEN,
+    time_limit: int | None | NotGiven = NOT_GIVEN,
+    name: str | None | NotGiven = NOT_GIVEN,
+    version: int | NotGiven = NOT_GIVEN,
+    metadata: dict[str, Any] | None | NotGiven = NOT_GIVEN,
+) -> Task:
+    """Task adapted with alternate values for one or more options.
+
+    Args:
+        task (Task): Task to adapt (it is deep copied prior to mutating options).
+        dataset (Dataset | Sequence[Sample]): Dataset to evaluate.
+        setup (Solver | list[Solver] | None): Setup step (always run
+            even when the main `solver` is replaced).
+        solver (Solver | list[Solver]): Solver or list of solvers.
+            Defaults to generate(), a normal call to the model.
+        scorer (Scorer | list[Scorer] | None): Scorer used to evaluate model output.
+        metrics (list[Metric] | dict[str, list[Metric]] | None):
+            Alternative metrics (overrides the metrics provided by the specified scorer).
+        config (GenerateConfig): Model generation config.
+        sandbox (SandboxEnvironmentType | None): Sandbox environment type
+            (or optionally a str or tuple with a shorthand spec).
+        approval (str | list[ApprovalPolicy] | None): Tool use approval policies.
+            Either a path to an approval policy config file or a list of approval policies.
+            Defaults to no approval policy.
+        epochs (int | Epochs | None): Epochs to repeat samples for and optional score
+            reducer function(s) used to combine sample scores (defaults to "mean").
+        fail_on_error (bool | float | None): `True` to fail on the first sample error
+            (default); `False` to never fail on sample errors; a value between 0 and 1
+            to fail if that proportion of total samples fails; a value greater than 1
+            to fail if that count of samples fails.
+        message_limit (int | None): Limit on total messages used for each sample.
+        token_limit (int | None): Limit on total tokens used for each sample.
+        time_limit (int | None): Limit on time (in seconds) for execution of each sample.
+        name (str | None): Task name. If not specified it is automatically
+            determined based on the name of the task directory (or "task")
+            if it's an anonymous task (e.g. created in a notebook and passed to
+            eval() directly).
+        version (int): Version of task (to distinguish evolutions
+            of the task spec or breaking changes to it).
+        metadata (dict[str, Any] | None): Additional metadata to associate with the task.
+
+    Returns:
+        Task: Task adapted with alternate options.
+    """
+    # deep copy the task
+    task = deepcopy(task)
+
+    if not isinstance(dataset, NotGiven):
+        task.dataset = resolve_dataset(dataset)
+    if not isinstance(setup, NotGiven):
+        task.setup = setup
+    if not isinstance(solver, NotGiven):
+        task.solver = resolve_solver(solver)
+    if not isinstance(scorer, NotGiven):
+        task.scorer = resolve_scorer(scorer)
+    if not isinstance(metrics, NotGiven):
+        task.metrics = metrics
+    if not isinstance(config, NotGiven):
+        task.config = config
+    if not isinstance(sandbox, NotGiven):
+        task.sandbox = resolve_sandbox_environment(sandbox)
+    if not isinstance(approval, NotGiven):
+        task.approval = resolve_approval(approval)
+    if not isinstance(epochs, NotGiven):
+        epochs = resolve_epochs(epochs)
+        task.epochs = epochs.epochs if epochs else None
+        task.epochs_reducer = epochs.reducer if epochs else None
+    if not isinstance(fail_on_error, NotGiven):
+        task.fail_on_error = fail_on_error
+    if not isinstance(message_limit, NotGiven):
+        task.message_limit = message_limit
+    if not isinstance(token_limit, NotGiven):
+        task.token_limit = token_limit
+    if not isinstance(time_limit, NotGiven):
+        task.time_limit = time_limit
+    if not isinstance(version, NotGiven):
+        task.version = version
+    if not isinstance(name, NotGiven):
+        task._name = name
+    if not isinstance(metadata, NotGiven):
+        task.metadata = metadata
+
+    # return modified task
+    return task
+
+
 class TaskInfo(BaseModel):
     """Task information (file, name, and attributes)."""
 
@@ -225,3 +307,36 @@ classes, and task instances (a single task or list of tasks
 can be specified). None is a request to read a task out
 of the current working directory.
 """
+
+
+def resolve_approval(
+    approval: str | list[ApprovalPolicy] | None,
+) -> list[ApprovalPolicy] | None:
+    return (
+        approval_policies_from_config(approval)
+        if isinstance(approval, str)
+        else approval
+    )
+
+
+def resolve_epochs(epochs: int | Epochs | None) -> Epochs | None:
+    if isinstance(epochs, int):
+        epochs = Epochs(epochs)
+    if epochs is not None and epochs.epochs < 1:
+        raise ValueError("epochs must be a positive integer.")
+    return epochs
+
+
+def resolve_dataset(dataset: Dataset | Sequence[Sample] | None) -> Dataset:
+    dataset = dataset or [Sample(input="prompt")]
+    return dataset if isinstance(dataset, Dataset) else MemoryDataset(list(dataset))
+
+
+def resolve_solver(solver: Solver | list[Solver]) -> Solver:
+    return chain(solver) if isinstance(solver, list) else solver
+
+
+def resolve_scorer(scorer: Scorer | list[Scorer] | None) -> list[Scorer] | None:
+    return (
+        scorer if isinstance(scorer, list) else [scorer] if scorer is not None else None
    )
inspect_ai/_util/constants.py
CHANGED