inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +7 -3
- inspect_ai/_cli/eval.py +17 -2
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +4 -3
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +4 -9
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +119 -16
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/score.py +1 -0
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/results.py +50 -22
- inspect_ai/_eval/task/run.py +180 -124
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25375 -1846
- inspect_ai/_view/www/log-schema.json +129 -15
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +8 -10
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +75 -2
- inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +62 -27
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/Json.mjs +12 -6
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +3 -6
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +26 -5
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +59 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/_conversation.py +61 -0
- inspect_ai/model/_generate_config.py +10 -4
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +7 -2
- inspect_ai/model/_providers/anthropic.py +109 -51
- inspect_ai/model/_providers/azureai.py +26 -24
- inspect_ai/model/_providers/bedrock.py +43 -44
- inspect_ai/model/_providers/google.py +121 -58
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +17 -20
- inspect_ai/model/_providers/openai.py +32 -21
- inspect_ai/model/_providers/openai_o1.py +9 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +8 -8
- inspect_ai/model/_providers/vertex.py +18 -8
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +2 -2
- inspect_ai/solver/__init__.py +2 -5
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +11 -1
- inspect_ai/tool/_tool.py +21 -3
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +84 -14
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/environment.py +27 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/model/_trace.py +0 -48
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/model/_providers/vertex.py
CHANGED
```diff
@@ -23,9 +23,15 @@ from vertexai.generative_models import (  # type: ignore
 )
 from vertexai.generative_models import Content as VertexContent
 
-from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import Content, ContentImage, ContentText
-from inspect_ai._util.images import image_as_data
+from inspect_ai._util.constants import BASE_64_DATA_REMOVED, NO_CONTENT
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from inspect_ai._util.images import file_as_data
 from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo
 
 from .._chat_message import (
@@ -244,9 +250,6 @@ def consective_tool_message_reducer(
     return messages
 
 
-NO_CONTENT = "(no content)"
-
-
 async def content_dict(
     message: ChatMessageUser | ChatMessageAssistant | ChatMessageTool,
 ) -> VertexContent:
@@ -308,9 +311,16 @@ async def content_part(content: Content | str) -> Part:
         return Part.from_text(content or NO_CONTENT)
     elif isinstance(content, ContentText):
         return Part.from_text(content.text or NO_CONTENT)
-    else:
-        image_bytes, mime_type = await image_as_data(content.image)
+    elif isinstance(content, ContentImage):
+        image_bytes, mime_type = await file_as_data(content.image)
         return Part.from_image(image=Image.from_bytes(data=image_bytes))
+    else:
+        if isinstance(content, ContentAudio):
+            file = content.audio
+        elif isinstance(content, ContentVideo):
+            file = content.video
+        file_bytes, mime_type = await file_as_data(file)
+        return Part.from_data(file_bytes, mime_type)
 
 
 def prepend_system_messages(
```
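The new `else` branch shows the intended input shape: audio content carries its file reference in `audio`, video in `video`. As a rough usage sketch of the expanded content model (only the `audio` field name is confirmed by this diff; the `format` argument is an assumption about `ContentAudio`'s other fields):

```python
from inspect_ai.model import ChatMessageUser
from inspect_ai.tool import ContentAudio, ContentText

# hypothetical multimodal user message; `audio` holds a file path or
# data URI, mirroring the `content.audio` access in the diff above
message = ChatMessageUser(
    content=[
        ContentText(text="Transcribe this clip."),
        ContentAudio(audio="clip.mp3", format="mp3"),  # `format` assumed
    ]
)
```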
inspect_ai/scorer/__init__.py
CHANGED
```diff
@@ -1,3 +1,5 @@
+from inspect_ai._util.deprecation import relocated_module_attribute
+
 from ._answer import AnswerPattern, answer
 from ._choice import choice
 from ._classification import exact, f1
@@ -16,7 +18,7 @@ from ._metric import (
 )
 from ._metrics.accuracy import accuracy
 from ._metrics.mean import mean
-from ._metrics.std import bootstrap_std, std, stderr
+from ._metrics.std import bootstrap_stderr, std, stderr
 from ._model import model_graded_fact, model_graded_qa
 from ._multi import multi_scorer
 from ._pattern import pattern
@@ -50,7 +52,7 @@ __all__ = [
     "Target",
     "scorer",
     "accuracy",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
     "mean",
@@ -76,3 +78,12 @@ __all__ = [
     "at_least",
     "pass_at",
 ]
+_BOOTSTRAP_RENAME_VERSION = "0.3.58"
+_REMOVED_IN = "0.4"
+
+relocated_module_attribute(
+    "bootstrap_std",
+    "inspect_ai.scorer.bootstrap_stderr",
+    _BOOTSTRAP_RENAME_VERSION,
+    _REMOVED_IN,
+)
```
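The `relocated_module_attribute` shim keeps `bootstrap_std` importable through the deprecation window. A minimal sketch of the expected caller-side behavior (assuming the shim warns with `DeprecationWarning` on access, which is the usual pattern for this helper):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    from inspect_ai.scorer import bootstrap_std  # old name, still resolves

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```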
inspect_ai/scorer/_metrics/__init__.py
CHANGED
```diff
@@ -1,12 +1,12 @@
 from .accuracy import accuracy
 from .mean import mean, var
-from .std import bootstrap_std, std, stderr
+from .std import bootstrap_stderr, std, stderr
 
 __all__ = [
     "accuracy",
     "mean",
     "var",
-    "bootstrap_std",
+    "bootstrap_stderr",
     "std",
     "stderr",
 ]
```
inspect_ai/scorer/_metrics/std.py
CHANGED
```diff
@@ -15,10 +15,10 @@ logger = getLogger(__name__)
 
 
 @metric
-def bootstrap_std(
+def bootstrap_stderr(
     num_samples: int = 1000, to_float: ValueToFloat = value_to_float()
 ) -> Metric:
-    """Standard deviation of a bootstrapped estimate of the mean.
+    """Standard error of the mean using bootstrap.
 
     Args:
        num_samples (int): Number of bootstrap samples to take.
@@ -31,7 +31,7 @@ def bootstrap_std(
        0 if the Value is a complex object (list or dict).
 
     Returns:
-       bootstrap_std metric
+       bootstrap_stderr metric
     """
 
     def metric(scores: list[Score]) -> float:
```
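Only the name and docstring change here; the statistic is unchanged: resample the scores with replacement, take each resample's mean, and report the standard deviation of those means. A self-contained sketch of that computation (not the library's implementation):

```python
import random
import statistics

def bootstrap_stderr_of_mean(values: list[float], num_samples: int = 1000) -> float:
    """Bootstrap estimate of the standard error of the mean."""
    means = [
        statistics.mean(random.choices(values, k=len(values)))
        for _ in range(num_samples)
    ]
    return statistics.stdev(means)

print(bootstrap_stderr_of_mean([0.0, 1.0, 1.0, 0.0, 1.0]))
```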
inspect_ai/scorer/_scorer.py
CHANGED
```diff
@@ -151,8 +151,8 @@ def scorer_metrics(
     return cast(list[Metric | dict[str, list[Metric]]], metrics_raw)
 
 
-def unique_scorer_name(scorer: Scorer, already_used_names: list[str]) -> str:
-    base_name = registry_unqualified_name(scorer)
+def unique_scorer_name(scorer: Scorer | str, already_used_names: list[str]) -> str:
+    base_name = scorer if isinstance(scorer, str) else registry_unqualified_name(scorer)
     scorer_name = base_name
     count = 1
     while scorer_name in already_used_names:
```
inspect_ai/solver/__init__.py
CHANGED
```diff
@@ -7,11 +7,7 @@ from ._fork import fork
 from ._human_agent.agent import human_agent
 from ._multiple_choice import MultipleChoiceTemplate, multiple_choice
 from ._plan import Plan, plan
-from ._prompt import (
-    chain_of_thought,
-    prompt_template,
-    system_message,
-)
+from ._prompt import chain_of_thought, prompt_template, system_message, user_message
 from ._solver import Generate, Solver, SolverSpec, generate, solver
 from ._task_state import Choice, Choices, TaskState
 from ._use_tools import use_tools
@@ -26,6 +22,7 @@ __all__ = [
     "chain_of_thought",
     "multiple_choice",
     "system_message",
+    "user_message",
     "self_critique",
     "use_tools",
     "plan",
```
inspect_ai/solver/_prompt.py
CHANGED
```diff
@@ -2,6 +2,7 @@ from typing import Any
 
 from inspect_ai._util.dict import omit
 from inspect_ai.model import ChatMessageSystem
+from inspect_ai.model._chat_message import ChatMessageUser
 from inspect_ai.util import resource
 
 from ._solver import Generate, Solver, solver
@@ -15,7 +16,8 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     Prompt template containing a `{prompt}` placeholder and any
     number of additional `params`. All values contained in sample
-    `metadata` are also automatically included in the `params`.
+    `metadata` and `store` are also automatically included in the
+    `params`.
 
     Args:
       template: (str): Template for prompt.
@@ -29,7 +31,7 @@ def prompt_template(template: str, **params: Any) -> Solver:
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         prompt = state.user_prompt
-        kwargs = omit(state.metadata, ["prompt"]) | params
+        kwargs = omit(state.metadata | state.store._data, ["prompt"]) | params
         prompt.text = prompt_template.format(prompt=prompt.text, **kwargs)
         return state
 
@@ -41,8 +43,9 @@ def system_message(template: str, **params: Any) -> Solver:
     """Solver which inserts a system message into the conversation.
 
     System message template containing any number of optional `params`.
-    for substitution using the `str.format()` method.
-
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
 
     The new message will go after other system messages (if there
     are none it will be inserted at the beginning of the conversation).
@@ -58,7 +61,7 @@ def system_message(template: str, **params: Any) -> Solver:
     content = resource(template)
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
-        kwargs = state.metadata | params
+        kwargs = state.metadata | state.store._data | params
         append_system_message(
             state.messages, ChatMessageSystem(content=content.format(**kwargs))
         )
@@ -67,6 +70,33 @@ def system_message(template: str, **params: Any) -> Solver:
     return solve
 
 
+@solver
+def user_message(template: str, **params: Any) -> Solver:
+    """Solver which inserts a user message into the conversation.
+
+    User message template containing any number of optional `params`.
+    for substitution using the `str.format()` method. All values
+    contained in sample `metadata` and `store` are also automatically
+    included in the `params`.
+
+    Args:
+      template (str): Template for user message.
+      **params (dict[str,Any]): Parameters to fill into the template.
+
+    Returns:
+      A solver that inserts the parameterised user message.
+    """
+    # read template
+    content = resource(template)
+
+    async def solve(state: TaskState, generate: Generate) -> TaskState:
+        kwargs = state.metadata | state.store._data | params
+        state.messages.append(ChatMessageUser(content=content.format(**kwargs)))
+        return state
+
+    return solve
+
+
 DEFAULT_COT_TEMPLATE = r"""
 {prompt}
 
```
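A plausible use of the new `user_message` solver, since templates now draw on sample `metadata`, the `store`, and explicit `params` (the `{topic}` metadata key here is hypothetical):

```python
from inspect_ai.solver import generate, user_message

# hypothetical solver sequence for a task: inject a templated follow-up
# turn between two generations; `{topic}` is resolved from sample
# metadata (or the store) when the solver runs
solvers = [
    generate(),
    user_message("Now give a one-sentence summary about {topic}."),
    generate(),
]
```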
inspect_ai/solver/_task_state.py
CHANGED
```diff
@@ -2,8 +2,9 @@ from collections.abc import Sequence
 from contextvars import ContextVar
 from copy import deepcopy
 from dataclasses import dataclass
+from itertools import tee
 from random import Random
-from typing import Any, Type, Union, cast, overload
+from typing import Any, Iterable, SupportsIndex, Type, Union, cast, overload
 
 from pydantic_core import to_jsonable_python
 
@@ -15,9 +16,13 @@ from inspect_ai.model import (
     ModelOutput,
 )
 from inspect_ai.model._call_tools import tools_info
+from inspect_ai.model._chat_message import ChatMessageBase
 from inspect_ai.model._model import sample_total_tokens
+from inspect_ai.scorer._metric import Score
+from inspect_ai.scorer._target import Target
 from inspect_ai.tool import Tool, ToolChoice
 from inspect_ai.tool._tool_def import ToolDef
+from inspect_ai.util._limit import SampleLimitExceededError
 from inspect_ai.util._store import Store, store_jsonable
 from inspect_ai.util._store_model import SMT
 
@@ -136,6 +141,7 @@ class TaskState:
         epoch: int,
         input: str | list[ChatMessage],
         messages: list[ChatMessage],
+        target: Target = Target(""),
         choices: list[str] | None = [],
         output: ModelOutput | None = None,
         message_limit: int | None = None,
@@ -161,10 +167,13 @@ class TaskState:
         or `input_text` only
         """
 
+        self.target = target
+        """The scoring target for this `Sample`."""
+
         self.metadata = metadata
         """Metadata from the `Sample` for this `TaskState`"""
 
-        self.messages = messages
+        self._messages: list[ChatMessage] = ChatMessageList(messages)
         """
         Chat conversation history for sample.
 
@@ -189,9 +198,7 @@ class TaskState:
         """
 
         self._message_limit = message_limit
-        self._message_limit_exceeded = False
         self._token_limit = token_limit
-        self._token_limit_exceeded = False
         self._completed = completed
 
         """Store for shared data"""
@@ -202,6 +209,9 @@ class TaskState:
         else:
             self.choices = Choices([])
 
+        self.scores: dict[str, Score] | None = None
+        """Scores yielded by running task."""
+
     @property
     def model(self) -> ModelName:
         """Name of model being evaluated."""
@@ -254,6 +264,16 @@ class TaskState:
         else:
             raise ValueError("user_prompt requested from TaskState but none available")
 
+    @property
+    def messages(self) -> list[ChatMessage]:
+        """Messages in chat history"""
+        return self._messages
+
+    @messages.setter
+    def messages(self, messages: list[ChatMessage]) -> None:
+        """Set messages in chat history."""
+        self._messages = ChatMessageList(messages)
+
     @property
     def max_messages(self) -> int | None:
         """Deprecated (use message_limit)."""
@@ -300,40 +320,7 @@ class TaskState:
     @property
     def completed(self) -> bool:
         """Is the task completed."""
-
-        from inspect_ai.log._samples import set_active_sample_total_messages
-        from inspect_ai.log._transcript import SampleLimitEvent, transcript
-
-        set_active_sample_total_messages(len(self.messages))
-
-        if self._completed:
-            return True
-        elif self.message_limit and len(self.messages) >= self.message_limit:
-            # log if this is the first time we hit this
-            if not self._message_limit_exceeded:
-                self._message_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="message",
-                        message=f"Sample completed: exceeded message limit ({self.message_limit})",
-                        limit=self.message_limit,
-                    )
-                )
-            return True
-        elif self.token_limit and self.token_usage >= self.token_limit:
-            # log if this is the first time we hit this
-            if not self._token_limit_exceeded:
-                self._token_limit_exceeded = True
-                transcript()._event(
-                    SampleLimitEvent(
-                        type="token",
-                        message=f"Sample completed: exceeded token limit ({self.token_limit:,})",
-                        limit=self.token_limit,
-                    )
-                )
-            return True
-        else:
-            return False
+        return self._completed
 
     @completed.setter
     def completed(self, completed: bool) -> None:
@@ -413,3 +400,58 @@ def state_jsonable(state: TaskState | None = None) -> dict[str, Any]:
 def sample_jsonable(sample: Sample) -> dict[str, Any]:
     jsonable = to_jsonable_python(sample, exclude_none=True, fallback=lambda _x: None)
     return cast(dict[str, Any], deepcopy(jsonable))
+
+
+class ChatMessageList(list[ChatMessage]):
+    def __init__(self, iterable: Iterable[ChatMessage]):
+        items, length = self._iterable_length(iterable)
+        self._check_size(length)
+        super().__init__(items)
+
+    def _check_size(self, additional_items: int = 1) -> None:
+        from inspect_ai.log._samples import active_sample_message_limit
+
+        messages_limit = active_sample_message_limit()
+        if messages_limit is not None:
+            messages = len(self) + additional_items
+            if messages > messages_limit:
+                raise SampleLimitExceededError(
+                    "message", value=messages, limit=messages_limit
+                )
+
+    def append(self, item: ChatMessage) -> None:
+        self._check_size()
+        super().append(item)
+
+    def extend(self, items: Iterable[ChatMessage]) -> None:
+        items, length = self._iterable_length(items)
+        self._check_size(length)
+        super().extend(items)
+
+    def insert(self, index: SupportsIndex, item: ChatMessage) -> None:
+        self._check_size()
+        super().insert(index, item)
+
+    @overload
+    def __setitem__(self, index: SupportsIndex, item: ChatMessage) -> None: ...
+
+    @overload
+    def __setitem__(self, index: slice, item: Iterable[ChatMessage]) -> None: ...
+
+    def __setitem__(
+        self, index: SupportsIndex | slice, item: ChatMessage | Iterable[ChatMessage]
+    ) -> None:
+        if isinstance(index, slice) and not isinstance(item, ChatMessageBase):
+            item, length = self._iterable_length(item)
+            size_change = length - len(self[index])
+            if size_change > 0:
+                self._check_size(size_change)
+
+        super().__setitem__(index, item)  # type: ignore[assignment,index]
+
+    def _iterable_length(
+        self, items: Iterable[ChatMessage]
+    ) -> tuple[Iterable[ChatMessage], int]:
+        items, counter = tee(items)
+        length = sum(1 for _ in counter)
+        return items, length
```
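The upshot: `completed` is now a plain flag, and the message limit is enforced eagerly by `ChatMessageList`, which raises `SampleLimitExceededError` as soon as a mutation would exceed `active_sample_message_limit()`. A sketch of what solver code can expect under that model (a minimal helper, assuming a running sample with `message_limit` set):

```python
from inspect_ai.model import ChatMessageUser
from inspect_ai.solver import TaskState
from inspect_ai.util._limit import SampleLimitExceededError

def append_turn(state: TaskState, text: str) -> bool:
    """Try to add a user turn; report False if the message limit is hit."""
    try:
        state.messages.append(ChatMessageUser(content=text))
        return True
    except SampleLimitExceededError:
        # raised eagerly by ChatMessageList once the limit would be exceeded
        return False
```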
inspect_ai/tool/__init__.py
CHANGED
```diff
@@ -1,4 +1,10 @@
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._tool import Tool, ToolError, ToolResult, tool
@@ -6,6 +12,7 @@ from ._tool_call import (
     ToolCall,
     ToolCallContent,
     ToolCallError,
+    ToolCallModelInput,
     ToolCallView,
     ToolCallViewer,
 )
@@ -30,10 +37,13 @@ __all__ = [
     "ToolError",
     "ToolResult",
     "Content",
+    "ContentAudio",
     "ContentImage",
     "ContentText",
+    "ContentVideo",
     "ToolCall",
     "ToolCallContent",
+    "ToolCallModelInput",
     "ToolCallView",
     "ToolCallViewer",
     "ToolChoice",
```
inspect_ai/tool/_tool.py
CHANGED
```diff
@@ -11,7 +11,12 @@ from typing import (
     runtime_checkable,
 )
 
-from inspect_ai._util.content import ContentImage, ContentText
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.registry import (
     RegistryInfo,
     registry_add,
@@ -19,7 +24,7 @@ from inspect_ai._util.registry import (
     registry_tag,
 )
 
-from ._tool_call import ToolCallViewer
+from ._tool_call import ToolCallModelInput, ToolCallViewer
 
 logger = getLogger(__name__)
 
@@ -31,7 +36,9 @@ ToolResult = (
     | bool
     | ContentText
     | ContentImage
-    | list[ContentText | ContentImage]
+    | ContentAudio
+    | ContentVideo
+    | list[ContentText | ContentImage | ContentAudio | ContentVideo]
 )
 
 
@@ -105,6 +112,7 @@ def tool(
     *,
     name: str | None = None,
     viewer: ToolCallViewer | None = None,
+    model_input: ToolCallModelInput | None = None,
     parallel: bool = True,
    prompt: str | None = None,
 ) -> Callable[[Callable[P, Tool]], Callable[P, Tool]]: ...
@@ -115,6 +123,7 @@ def tool(
     *,
     name: str | None = None,
     viewer: ToolCallViewer | None = None,
+    model_input: ToolCallModelInput | None = None,
     parallel: bool = True,
     prompt: str | None = None,
 ) -> Callable[P, Tool] | Callable[[Callable[P, Tool]], Callable[P, Tool]]:
@@ -128,6 +137,8 @@ def tool(
         will be used as the name of the tool.
       viewer (ToolCallViewer | None): Provide a custom view
         of tool call and context.
+      model_input (ToolCallModelInput | None): Provide a custom
+        function for playing back tool results as model input.
       parallel (bool):
         Does this tool support parallel execution?
         (defaults to True).
@@ -169,6 +180,9 @@ def tool(
                     TOOL_PROMPT: prompt,
                     TOOL_PARALLEL: parallel,
                     TOOL_VIEWER: viewer,
+                    TOOL_MODEL_INPUT: (
+                        model_input or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
+                    ),
                 },
             ),
             *args,
@@ -188,3 +202,7 @@ def tool(
 TOOL_PROMPT = "prompt"
 TOOL_PARALLEL = "parallel"
 TOOL_VIEWER = "viewer"
+TOOL_MODEL_INPUT = "model_input"
+
+
+TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
```
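With the widened `ToolResult` union, a tool can hand audio or video straight back to the model. A hypothetical sketch (the tool itself is invented for illustration; the `audio` field name comes from this release, while the `format` argument is an assumption):

```python
from inspect_ai.tool import ContentAudio, ContentText, Tool, ToolResult, tool

@tool
def fetch_clip() -> Tool:  # hypothetical example tool
    async def execute(path: str) -> ToolResult:
        """Fetch an audio clip for the model to hear.

        Args:
            path: Path to the audio file.
        """
        return [
            ContentText(text=f"Audio loaded from {path}"),
            ContentAudio(audio=path, format="mp3"),  # `format` assumed
        ]

    return execute
```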
inspect_ai/tool/_tool_call.py
CHANGED
```diff
@@ -3,6 +3,8 @@ from typing import Any, Callable, Literal
 
 from pydantic import BaseModel, Field
 
+from inspect_ai._util.content import Content
+
 
 class ToolCallContent(BaseModel):
     """Content to include in tool call view."""
@@ -71,3 +73,11 @@ class ToolCallError:
 
 ToolCallViewer = Callable[[ToolCall], ToolCallView]
 """Custom view renderer for tool calls."""
+
+
+ToolCallModelInput = Callable[[int, int, str | list[Content]], str | list[Content]]
+"""Determine how tool call results are played back as model input.
+
+The first argument is an index into the total number of tool results
+for this tool in the message history, the second is the total number.
+"""
```
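A `ToolCallModelInput` sees each prior tool result along with its position, so it can decide how much to replay. A minimal sketch, assuming (as the signature suggests) the third argument is that result's content:

```python
from inspect_ai.tool import Content, ToolCallModelInput

def latest_result_only(
    index: int, total: int, content: str | list[Content]
) -> str | list[Content]:
    """Replay only the newest tool result in full; elide earlier ones."""
    return content if index == total - 1 else "(earlier tool output elided)"

model_input: ToolCallModelInput = latest_result_only
```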
inspect_ai/tool/_tool_def.py
CHANGED
```diff
@@ -13,8 +13,8 @@ from inspect_ai._util.registry import (
     set_registry_params,
 )
 
-from ._tool import TOOL_PARALLEL, TOOL_PROMPT, TOOL_VIEWER, Tool
-from ._tool_call import ToolCallViewer
+from ._tool import TOOL_MODEL_INPUT, TOOL_PARALLEL, TOOL_PROMPT, TOOL_VIEWER, Tool
+from ._tool_call import ToolCallModelInput, ToolCallViewer
 from ._tool_description import (
     ToolDescription,
     set_tool_description,
@@ -33,6 +33,7 @@ class ToolDef:
         parameters: dict[str, str] | ToolParams | None = None,
         parallel: bool | None = None,
         viewer: ToolCallViewer | None = None,
+        model_input: ToolCallModelInput | None = None,
     ) -> None:
         """Tool definition.
 
@@ -46,6 +47,8 @@ class ToolDef:
           parallel (bool | None): Does the tool support parallel execution
             (defaults to True if not specified)
           viewer (ToolCallViewer | None): Optional tool call viewer implementation.
+          model_input (ToolCallModelInput | None): Optional function that determines how
+            tool call results are played back as model input.
 
         Returns:
           Tool definition.
@@ -68,6 +71,7 @@ class ToolDef:
             parameters = parameters if parameters is not None else tdef.parameters
             self.parallel = parallel if parallel is not None else tdef.parallel
             self.viewer = viewer or tdef.viewer
+            self.model_input = model_input or tdef.model_input
 
         # if its not a tool then extract tool_info if all fields have not
         # been provided explicitly
@@ -97,6 +101,7 @@ class ToolDef:
             # behavioral attributes
             self.parallel = parallel is not False
             self.viewer = viewer
+            self.model_input = model_input
 
     tool: Callable[..., Any]
     """Callable to execute tool."""
@@ -116,6 +121,9 @@ class ToolDef:
     viewer: ToolCallViewer | None
     """Custom viewer for tool call"""
 
+    model_input: ToolCallModelInput | None
+    """Custom model input presenter for tool calls."""
+
     def as_tool(self) -> Tool:
         """Convert a ToolDef to a Tool."""
         tool = self.tool
@@ -159,11 +167,12 @@ class ToolDefFields(NamedTuple):
     parameters: ToolParams
     parallel: bool
     viewer: ToolCallViewer | None
+    model_input: ToolCallModelInput | None
 
 
 def tool_def_fields(tool: Tool) -> ToolDefFields:
     # get tool_info
-    name, prompt, parallel, viewer = tool_registry_info(tool)
+    name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
     tool_info = parse_tool_info(tool)
 
     # if there is a description then append any prompt to the
@@ -213,15 +222,17 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
         parameters=tool_info.parameters,
         parallel=parallel,
         viewer=viewer,
+        model_input=model_input,
     )
 
 
 def tool_registry_info(
     tool: Tool,
-) -> tuple[str, str | None, bool, ToolCallViewer | None]:
+) -> tuple[str, str | None, bool, ToolCallViewer | None, ToolCallModelInput | None]:
     info = registry_info(tool)
     name = info.name.split("/")[-1]
     prompt = info.metadata.get(TOOL_PROMPT, None)
     parallel = info.metadata.get(TOOL_PARALLEL, True)
     viewer = info.metadata.get(TOOL_VIEWER, None)
-    return name, prompt, parallel, viewer
+    model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
+    return name, prompt, parallel, viewer, model_input
```
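Taken together, the registry plumbing means a `model_input` supplied at `@tool(...)` time travels through `tool_registry_info` into `ToolDefFields`, and it can equally be attached when wrapping a plain function in a `ToolDef`. A hedged sketch of the latter (the `add` function and the eliding filter are hypothetical; `name` and `description` are assumed `ToolDef` parameters not shown in these hunks):

```python
from inspect_ai.tool import Content, ToolDef

async def add(x: int, y: int) -> int:
    """Add two numbers."""
    return x + y

def elide_older(index: int, total: int, content: str | list[Content]) -> str | list[Content]:
    # keep only the most recent tool result verbatim
    return content if index == total - 1 else "(elided)"

tdef = ToolDef(
    add,
    name="add",
    description="Add two numbers.",
    parameters={"x": "first number", "y": "second number"},
    model_input=elide_older,
)
tool = tdef.as_tool()
```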