inspect-ai 0.3.57__py3-none-any.whl → 0.3.59__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +7 -3
- inspect_ai/_cli/eval.py +17 -2
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +4 -3
- inspect_ai/_display/core/config.py +3 -3
- inspect_ai/_display/core/panel.py +7 -3
- inspect_ai/_display/plain/__init__.py +0 -0
- inspect_ai/_display/plain/display.py +203 -0
- inspect_ai/_display/rich/display.py +4 -9
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/port_mappings.py +110 -0
- inspect_ai/_display/textual/widgets/samples.py +119 -16
- inspect_ai/_display/textual/widgets/sandbox.py +37 -0
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/score.py +1 -0
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/results.py +50 -22
- inspect_ai/_eval/task/run.py +180 -124
- inspect_ai/_eval/task/sandbox.py +10 -5
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +2 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/port_names.py +61 -0
- inspect_ai/_util/text.py +23 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/App.css +31 -1
- inspect_ai/_view/www/dist/assets/index.css +31 -1
- inspect_ai/_view/www/dist/assets/index.js +25375 -1846
- inspect_ai/_view/www/log-schema.json +129 -15
- inspect_ai/_view/www/package.json +2 -0
- inspect_ai/_view/www/src/App.mjs +8 -10
- inspect_ai/_view/www/src/Types.mjs +0 -1
- inspect_ai/_view/www/src/components/ChatView.mjs +133 -43
- inspect_ai/_view/www/src/components/ExpandablePanel.mjs +0 -4
- inspect_ai/_view/www/src/components/LargeModal.mjs +19 -20
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/components/TabSet.mjs +3 -1
- inspect_ai/_view/www/src/components/VirtualList.mjs +266 -84
- inspect_ai/_view/www/src/index.js +75 -2
- inspect_ai/_view/www/src/navbar/Navbar.mjs +3 -0
- inspect_ai/_view/www/src/navbar/SecondaryBar.mjs +18 -9
- inspect_ai/_view/www/src/samples/SampleDialog.mjs +5 -1
- inspect_ai/_view/www/src/samples/SampleDisplay.mjs +23 -15
- inspect_ai/_view/www/src/samples/SampleList.mjs +18 -48
- inspect_ai/_view/www/src/samples/SampleTranscript.mjs +8 -3
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +29 -13
- inspect_ai/_view/www/src/samples/SamplesTab.mjs +4 -1
- inspect_ai/_view/www/src/samples/SamplesTools.mjs +8 -8
- inspect_ai/_view/www/src/samples/tools/SampleFilter.mjs +712 -89
- inspect_ai/_view/www/src/samples/tools/filters.mjs +260 -87
- inspect_ai/_view/www/src/samples/transcript/ErrorEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/EventPanel.mjs +29 -24
- inspect_ai/_view/www/src/samples/transcript/EventRow.mjs +1 -1
- inspect_ai/_view/www/src/samples/transcript/InfoEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/InputEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/ModelEventView.mjs +31 -10
- inspect_ai/_view/www/src/samples/transcript/SampleInitEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.mjs +23 -2
- inspect_ai/_view/www/src/samples/transcript/ScoreEventView.mjs +24 -2
- inspect_ai/_view/www/src/samples/transcript/StepEventView.mjs +33 -3
- inspect_ai/_view/www/src/samples/transcript/SubtaskEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.mjs +25 -2
- inspect_ai/_view/www/src/samples/transcript/TranscriptView.mjs +193 -11
- inspect_ai/_view/www/src/samples/transcript/Types.mjs +10 -0
- inspect_ai/_view/www/src/samples/transcript/state/StateEventView.mjs +26 -2
- inspect_ai/_view/www/src/types/log.d.ts +62 -27
- inspect_ai/_view/www/src/utils/Format.mjs +10 -3
- inspect_ai/_view/www/src/utils/Json.mjs +12 -6
- inspect_ai/_view/www/src/workspace/WorkSpace.mjs +10 -4
- inspect_ai/_view/www/vite.config.js +7 -0
- inspect_ai/_view/www/yarn.lock +116 -0
- inspect_ai/approval/_human/__init__.py +0 -0
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/approval/_policy.py +12 -6
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +3 -6
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +26 -5
- inspect_ai/log/_transcript.py +32 -2
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +59 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/_conversation.py +61 -0
- inspect_ai/model/_generate_config.py +10 -4
- inspect_ai/model/_model.py +117 -18
- inspect_ai/model/_model_output.py +7 -2
- inspect_ai/model/_providers/anthropic.py +109 -51
- inspect_ai/model/_providers/azureai.py +26 -24
- inspect_ai/model/_providers/bedrock.py +43 -44
- inspect_ai/model/_providers/google.py +121 -58
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +17 -20
- inspect_ai/model/_providers/openai.py +32 -21
- inspect_ai/model/_providers/openai_o1.py +9 -8
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/together.py +8 -8
- inspect_ai/model/_providers/vertex.py +18 -8
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/scorer/_scorer.py +2 -2
- inspect_ai/solver/__init__.py +2 -5
- inspect_ai/solver/_prompt.py +35 -5
- inspect_ai/solver/_task_state.py +80 -38
- inspect_ai/tool/__init__.py +11 -1
- inspect_ai/tool/_tool.py +21 -3
- inspect_ai/tool/_tool_call.py +10 -0
- inspect_ai/tool/_tool_def.py +16 -5
- inspect_ai/tool/_tool_with.py +21 -4
- inspect_ai/tool/beta/__init__.py +5 -0
- inspect_ai/tool/beta/_computer/__init__.py +3 -0
- inspect_ai/tool/beta/_computer/_common.py +133 -0
- inspect_ai/tool/beta/_computer/_computer.py +155 -0
- inspect_ai/tool/beta/_computer/_computer_split.py +198 -0
- inspect_ai/tool/beta/_computer/_resources/Dockerfile +100 -0
- inspect_ai/tool/beta/_computer/_resources/README.md +30 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/entrypoint.sh +18 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/novnc_startup.sh +20 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/x11vnc_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xfce_startup.sh +13 -0
- inspect_ai/tool/beta/_computer/_resources/entrypoint/xvfb_startup.sh +48 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Firefox Web Browser.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/Visual Studio Code.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/image_home_dir/Desktop/XPaint.desktop +10 -0
- inspect_ai/tool/beta/_computer/_resources/tool/__init__.py +0 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_logger.py +22 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_run.py +42 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_tool_result.py +33 -0
- inspect_ai/tool/beta/_computer/_resources/tool/_x11_client.py +262 -0
- inspect_ai/tool/beta/_computer/_resources/tool/computer_tool.py +85 -0
- inspect_ai/tool/beta/_computer/_resources/tool/requirements.txt +0 -0
- inspect_ai/util/__init__.py +2 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_limit.py +26 -0
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +84 -14
- inspect_ai/util/_sandbox/docker/internal.py +3 -1
- inspect_ai/util/_sandbox/environment.py +27 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/RECORD +159 -128
- inspect_ai/_view/www/src/samples/transcript/TranscriptState.mjs +0 -70
- inspect_ai/model/_trace.py +0 -48
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.59.dist-info}/top_level.txt +0 -0
inspect_ai/log/_transcript.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
@@ -11,7 +12,7 @@ from typing import (
     Union,
 )
 
-from pydantic import BaseModel, Field, JsonValue, field_serializer
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
 
 from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai._util.error import EvalError
@@ -69,7 +70,7 @@ class SampleLimitEvent(BaseEvent):
     event: Literal["sample_limit"] = Field(default="sample_limit")
     """Event type."""
 
-    type: Literal["message", "time", "token", "operator"]
+    type: Literal["message", "time", "token", "operator", "custom"]
     """Type of limit that halted processing"""
 
     message: str
@@ -123,6 +124,9 @@ class ModelEvent(BaseEvent):
     output: ModelOutput
     """Output from model."""
 
+    error: str | None = Field(default=None)
+    """Error which occurred during model call."""
+
     cache: Literal["read", "write"] | None = Field(default=None)
     """Was this a cache read or write."""
 
@@ -176,6 +180,32 @@ class ToolEvent(BaseEvent):
         self.events = events
         self.pending = None
 
+    # mechanism for operator to cancel the tool call
+
+    def set_task(self, task: asyncio.Task[Any]) -> None:
+        """Set the tool task (for possible cancellation)"""
+        self._task = task
+
+    def cancel(self) -> None:
+        """Cancel the tool task."""
+        if self._task:
+            self._cancelled = True
+            self._task.cancel()
+
+    @property
+    def cancelled(self) -> bool:
+        """Was the task cancelled?"""
+        return self._cancelled is True
+
+    _cancelled: bool | None = None
+    """Was this tool call cancelled?"""
+
+    _task: asyncio.Task[Any] | None = None
+    """Handle to task (used for cancellation)"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    """Required so that we can include '_task' as a member."""
+
 
 class ApprovalEvent(BaseEvent):
     """Tool approval."""
inspect_ai/model/__init__.py
CHANGED
@@ -1,6 +1,12 @@
 # ruff: noqa: F401 F403 F405
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._cache import (
@@ -42,8 +48,10 @@ __all__ = [
     "GenerateConfig",
     "GenerateConfigArgs",
     "CachePolicy",
-    "ContentText",
+    "ContentAudio",
     "ContentImage",
+    "ContentText",
+    "ContentVideo",
     "Content",
     "ChatMessage",
     "ChatMessageSystem",
inspect_ai/model/_call_tools.py
CHANGED
@@ -24,11 +24,17 @@ from typing import (
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
-from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
 from inspect_ai._util.trace import trace_action
-from inspect_ai.model._trace import trace_tool_mesage
+from inspect_ai.model._conversation import conversation_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
@@ -120,10 +126,14 @@ async def call_tools(
         # massage result, leave list[Content] alone, convert all other
         # types to string as that is what the model APIs accept
         truncated: tuple[int, int] | None = None
-        if isinstance(result, ContentText | ContentImage):
+        if isinstance(
+            result, ContentText | ContentImage | ContentAudio | ContentVideo
+        ):
             content: str | list[Content] = [result]
         elif isinstance(result, list) and (
-            isinstance(result[0], ContentText | ContentImage)
+            isinstance(
+                result[0], ContentText | ContentImage | ContentAudio | ContentVideo
+            )
         ):
             content = result
         else:
@@ -163,6 +173,9 @@ async def call_tools(
     # call tools
     tool_messages: list[ChatMessageTool] = []
     for call in message.tool_calls:
+        # create the task
+        task = asyncio.create_task(call_tool_task(call))
+
         # create pending tool event and add it to the transcript
         event = ToolEvent(
             id=call.id,
@@ -171,15 +184,44 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
+        event.set_task(task)
         transcript()._event(event)
 
-        # execute the tool call
-        tool_message, result_event = await call_tool_task(call)
-
+        # execute the tool call. if the operator cancelled the
+        # tool call then synthesize the appropriate message/event
+        try:
+            tool_message, result_event = await task
+        except asyncio.CancelledError:
+            if event.cancelled:
+                tool_message = ChatMessageTool(
+                    content="",
+                    function=call.function,
+                    tool_call_id=call.id,
+                    error=ToolCallError(
+                        "timeout", "Command timed out before completing."
+                    ),
+                )
+                result_event = ToolEvent(
+                    id=call.id,
+                    function=call.function,
+                    arguments=call.arguments,
+                    result=tool_message.content,
+                    truncated=None,
+                    view=call.view,
+                    error=tool_message.error,
+                    events=[],
+                )
+                transcript().info(
+                    f"Tool call '{call.function}' was cancelled by operator."
+                )
+            else:
+                raise
+
+        # update return messages
         tool_messages.append(tool_message)
 
-        #
-        trace_tool_mesage(tool_message)
+        # print conversation if display is conversation
+        conversation_tool_mesage(tool_message)
 
         # update the event with the results
         event.set_result(
@@ -286,6 +328,10 @@ def tool_params(input: dict[str, Any], func: Callable[..., Any]) -> dict[str, Any]:
     type_hints = get_type_hints(func)
     docstring = inspect.getdoc(func)
 
+    # if the function takes **kwargs: Any then just pass the tool arguments through
+    if "kwargs" in type_hints and type_hints["kwargs"] == Any:
+        return input
+
     # build params
     params: dict[str, Any] = {}
     for param_name, param in signature.parameters.items():
@@ -411,12 +457,13 @@ def truncate_tool_output(
     # truncate if required
     truncated = truncate_string_to_bytes(output, active_max_output)
     if truncated:
-        truncated_output = dedent(f"""
+        truncated_output = dedent("""
            The output of your call to {tool_name} was too long to be displayed.
            Here is a truncated version:
            <START_TOOL_OUTPUT>
-           {truncated.output}
-           <END_TOOL_OUTPUT>""")
+           {truncated_output}
+           <END_TOOL_OUTPUT>
+           """).format(tool_name=tool_name, truncated_output=truncated.output)
         return TruncatedToolOutput(
            truncated_output, truncated.original_bytes, active_max_output
        )
inspect_ai/model/_chat_message.py
CHANGED
@@ -59,10 +59,8 @@ class ChatMessageBase(BaseModel):
         if isinstance(self.content, str):
             self.content = text
         else:
-            all_images = [
-                content for content in self.content if content.type == "image"
-            ]
-            self.content = [ContentText(text=text)] + all_images
+            all_other = [content for content in self.content if content.type != "text"]
+            self.content = [ContentText(text=text)] + all_other
 
 
 class ChatMessageSystem(ChatMessageBase):
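The setter change above preserves every non-text part of a message when its text is replaced (previously only images survived). A sketch, assuming ContentAudio accepts an audio source via its audio field:

from inspect_ai.model import ChatMessageUser, ContentAudio, ContentText

msg = ChatMessageUser(
    content=[
        ContentText(text="transcribe this"),
        ContentAudio(audio="clip.mp3"),  # 'audio' field name is an assumption
    ]
)

# replaces only the text portion; audio (and video) content is now kept
msg.text = "transcribe this clip, please"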
inspect_ai/model/_conversation.py
ADDED
@@ -0,0 +1,61 @@
+from rich.console import RenderableType
+from rich.text import Text
+
+from inspect_ai._util.constants import NO_CONTENT
+from inspect_ai._util.rich import lines_display
+from inspect_ai._util.transcript import transcript_markdown
+from inspect_ai.util._conversation import conversation_panel
+from inspect_ai.util._display import display_type
+
+from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
+from ._render import messages_preceding_assistant, render_tool_calls
+
+MESSAGE_TITLE = "Message"
+
+
+def conversation_tool_mesage(message: ChatMessageTool) -> None:
+    if display_type() == "conversation":
+        # truncate output to 100 lines
+        output = (
+            message.error.message.strip() if message.error else message.text.strip()
+        )
+        if output:
+            content = lines_display(output, 100)
+
+            conversation_panel(
+                title=f"Tool Output: {message.function}",
+                content=content,
+            )
+
+
+def conversation_assistant_message(
+    input: list[ChatMessage], message: ChatMessageAssistant
+) -> None:
+    if display_type() == "conversation":
+        # print precding messages that aren't tool or assistant
+        for m in messages_preceding_assistant(input):
+            conversation_panel(
+                title=m.role.capitalize(),
+                content=transcript_markdown(m.text, escape=True),
+            )
+
+        # start with assistant content
+        content: list[RenderableType] = (
+            [transcript_markdown(message.text, escape=True)]
+            if message.text and message.text != NO_CONTENT
+            else []
+        )
+
+        # print tool calls
+        if message.tool_calls:
+            if content:
+                content.append(Text())
+            content.extend(render_tool_calls(message.tool_calls))
+
+        # print the assistant message
+        conversation_panel(title="Assistant", content=content)
+
+
+def conversation_assistant_error(error: Exception) -> None:
+    if display_type() == "conversation":
+        conversation_panel(title="Assistant", content=repr(error))
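Everything in this new module is gated on display_type() == "conversation"; this release renames the old trace display to conversation (see inspect_ai/util/{_trace.py → _conversation.py} in the file list). A sketch of the same gating pattern, using the internal accessor this module itself imports:

from inspect_ai.util._display import display_type


def render_note(text: str) -> None:
    # render only when running under the conversation display
    # (e.g. `inspect eval ... --display conversation`, assuming the
    # renamed CLI option in this release)
    if display_type() == "conversation":
        print(text)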
inspect_ai/model/_generate_config.py
CHANGED
@@ -58,14 +58,17 @@ class GenerateConfigArgs(TypedDict, total=False):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, and TogetherAI only."""
 
     logprobs: bool | None
-    """Return log probabilities of the output tokens. OpenAI,
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI,
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, and Huggingface only."""
 
     parallel_tool_calls: bool | None
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
 
+    internal_tools: bool | None
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""
 
@@ -128,14 +131,17 @@ class GenerateConfig(BaseModel):
     """How many chat completion choices to generate for each input message. OpenAI, Grok, Google, TogetherAI, and vLLM only."""
 
     logprobs: bool | None = Field(default=None)
-    """Return log probabilities of the output tokens. OpenAI,
+    """Return log probabilities of the output tokens. OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM only."""
 
     top_logprobs: int | None = Field(default=None)
-    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI,
+    """Number of most likely tokens (0-20) to return at each token position, each with an associated log probability. OpenAI, Grok, Huggingface, and vLLM only."""
 
     parallel_tool_calls: bool | None = Field(default=None)
     """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""
 
+    internal_tools: bool | None = Field(default=None)
+    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""
+
     max_tool_output: int | None = Field(default=None)
     """Maximum tool output (in bytes). Defaults to 16 * 1024."""
 
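internal_tools is the only new generation option here. A minimal sketch of disabling it for a single model (the model name is illustrative):

from inspect_ai.model import GenerateConfig, get_model

# opt out of mapping tools like 'computer' to the provider's native
# implementation; leaving internal_tools unset keeps the automatic mapping
model = get_model(
    "anthropic/claude-3-5-sonnet-latest",  # illustrative model name
    config=GenerateConfig(internal_tools=False),
)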
inspect_ai/model/_model.py
CHANGED
@@ -33,6 +33,7 @@ from inspect_ai._util.trace import trace_action
 from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
 from inspect_ai.tool._tool_def import ToolDef, tool_defs
 from inspect_ai.util import concurrency
+from inspect_ai.util._limit import SampleLimitExceededError
 
 from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
 from ._call_tools import disable_parallel_tools, tool_call_view, tools_info
@@ -43,6 +44,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
+from ._conversation import conversation_assistant_error, conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -50,7 +52,6 @@ from ._generate_config import (
 )
 from ._model_call import ModelCall
 from ._model_output import ModelOutput, ModelUsage
-from ._trace import trace_assistant_message
 
 logger = logging.getLogger(__name__)
 
@@ -116,7 +117,7 @@ class ModelAPI(abc.ABC):
         tools: list[ToolInfo],
         tool_choice: ToolChoice,
         config: GenerateConfig,
-    ) -> ModelOutput | tuple[ModelOutput, ModelCall]:
+    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
         """Generate output from the model.
 
         Args:
@@ -165,7 +166,7 @@ class ModelAPI(abc.ABC):
         return False
 
     def tool_result_images(self) -> bool:
-        """Tool results can
+        """Tool results can contain images"""
         return False
 
 
@@ -222,11 +223,17 @@ class Model:
         Returns:
            ModelOutput
         """
+        # if we are the default model then enforce message limit if it
+        # exists (raise an exception if it is exceeded)
+        is_active_model = self == active_model()
+        if is_active_model:
+            handle_sample_message_limit(input)
+
         # base config for this model
         base_config = self.config
 
         # if we are the active_model then merge active generate config
-        if self == active_model():
+        if is_active_model:
             base_config = base_config.merge(active_generate_config())
 
         # merge passed config
@@ -296,6 +303,9 @@ class Model:
             tools = []
             tool_choice = "none"
 
+        # apply any tool model_input handlers
+        input = resolve_tool_model_input(tdefs, input)
+
         # break tool image content out into user messages if the model doesn't
         # support tools returning images
         if not self.api.tool_result_images():
@@ -389,6 +399,17 @@ class Model:
             output = result
             call = None
 
+        # raise error
+        if isinstance(output, Exception):
+            complete(output, call)
+
+            # Wrap the error in a runtime error which will show the
+            # request which caused the error
+            error = repr(output)
+            request = json.dumps(call.request, indent=2) if call is not None else ""
+            error_message = f"{error}\n\nRequest:\n{request}"
+            raise RuntimeError(error_message)
+
         # update output with time elapsed
         output.time = time_elapsed
 
@@ -464,7 +485,7 @@ class Model:
         cache: Literal["read", "write"] | None,
         output: ModelOutput | None = None,
         call: ModelCall | None = None,
-    ) -> Callable[[ModelOutput, ModelCall | None], None]:
+    ) -> Callable[[ModelOutput | Exception, ModelCall | None], None]:
         from inspect_ai.log._transcript import ModelEvent, transcript
 
         # create event and add it to the transcript
@@ -484,13 +505,16 @@ class Model:
 
         # callable that can be used to update the interaction w/ output
         def complete(
-            updated_output: ModelOutput, updated_call: ModelCall | None
+            result: ModelOutput | Exception, updated_call: ModelCall | None
         ) -> None:
             # trace
-            trace_assistant_message(input, updated_output.choices[0].message)
+            if isinstance(result, ModelOutput):
+                conversation_assistant_message(input, result.choices[0].message)
+                event.output = result
+            else:
+                conversation_assistant_error(result)
+                event.error = repr(result)
 
-            # update event
-            event.output = updated_output
             event.call = updated_call
             event.pending = None
 
@@ -703,6 +727,40 @@ def simple_input_messages(
     return messages
 
 
+def resolve_tool_model_input(
+    tdefs: list[ToolDef], messages: list[ChatMessage]
+) -> list[ChatMessage]:
+    # filter on tooldefs that have a model input handler
+    tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
+
+    # bail if there are no handlers
+    if len(tdefs) == 0:
+        return messages
+
+    # don't mutate the original messages
+    messages = deepcopy(messages)
+
+    # extract tool messages
+    tool_messages = [
+        message for message in messages if isinstance(message, ChatMessageTool)
+    ]
+    # run model_input handlers over all tool_messages with the same function name
+    for tdef in tdefs:
+        assert tdef.model_input
+        # filter messages down to just this tool
+        tdef_tool_messages = [
+            message for message in tool_messages if message.function == tdef.name
+        ]
+        # call the function for each tool, passing the index, total, and content
+        for index, message in enumerate(tdef_tool_messages):
+            message.content = tdef.model_input(
+                index, len(tool_messages), message.content
+            )
+
+    # return modified messages
+    return messages
+
+
 def tool_result_images_as_user_message(
     messages: list[ChatMessage],
 ) -> list[ChatMessage]:
@@ -713,16 +771,21 @@ def tool_result_images_reducer(
     messages: list[ChatMessage],
     message: ChatMessage,
 ) -> list[ChatMessage]:
-    # append the message
-    messages.append(message)
-
     # if there are tool result images, pull them out into a ChatUserMessage
     if isinstance(message, ChatMessageTool) and isinstance(message.content, list):
+        tool_message = ChatMessageTool(
+            content=message.content.copy(),
+            tool_call_id=message.tool_call_id,
+            function=message.function,
+        )
+        assert isinstance(tool_message.content, list)
+        messages.append(tool_message)
+
         user_content: list[Content] = []
-        for i in range(0, len(message.content)):
-            if isinstance(message.content[i], ContentImage):
+        for i in range(0, len(tool_message.content)):
+            if isinstance(tool_message.content[i], ContentImage):
                 user_content.append(message.content[i])
-                message.content[i] = ContentText(
+                tool_message.content[i] = ContentText(
                     text="Image content is in the message below."
                 )
         if len(user_content) > 0:
@@ -730,6 +793,9 @@ def tool_result_images_reducer(
             ChatMessageUser(content=user_content, tool_call_id=message.tool_call_id)
         )
 
+    else:
+        messages.append(message)
+
     # return messages
     return messages
 
@@ -813,6 +879,24 @@ def active_model() -> Model | None:
 active_model_context_var: ContextVar[Model] = ContextVar("active_model")
 
 
+def handle_sample_message_limit(input: str | list[ChatMessage]) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_message_limit,
+        set_active_sample_total_messages,
+    )
+
+    total_messages = 1 if isinstance(input, str) else len(input)
+    message_limit = active_sample_message_limit()
+    if message_limit is not None:
+        if total_messages >= message_limit:
+            raise SampleLimitExceededError(
+                "message", value=total_messages, limit=message_limit
+            )
+
+    # set total messages
+    set_active_sample_total_messages(total_messages)
+
+
 def init_model_usage() -> None:
     model_usage_context_var.set({})
 
@@ -822,13 +906,28 @@ def init_sample_model_usage() -> None:
 
 
 def record_model_usage(model: str, usage: ModelUsage) -> None:
+    from inspect_ai.log._samples import (
+        active_sample_token_limit,
+        set_active_sample_total_tokens,
+    )
+
+    # record usage
     set_model_usage(model, usage, sample_model_usage_context_var.get(None))
     set_model_usage(model, usage, model_usage_context_var.get(None))
 
-    #
-
+    # compute total tokens
+    total_tokens = sample_total_tokens()
 
-
+    # update active sample
+    set_active_sample_total_tokens(total_tokens)
+
+    # check for token limit overflow and raise
+    token_limit = active_sample_token_limit()
+    if token_limit is not None:
+        if total_tokens > token_limit:
+            raise SampleLimitExceededError(
+                "token", value=total_tokens, limit=token_limit
+            )
 
 
 def set_model_usage(
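The limit handling added above raises the new SampleLimitExceededError (from the added inspect_ai/util/_limit.py) as soon as a message or token limit is breached, rather than merely recording it. A distilled sketch of the token check performed after each model call:

from inspect_ai.util._limit import SampleLimitExceededError


def check_token_limit(total_tokens: int, token_limit: int | None) -> None:
    # mirrors the check in record_model_usage above
    if token_limit is not None and total_tokens > token_limit:
        raise SampleLimitExceededError(
            "token", value=total_tokens, limit=token_limit
        )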
inspect_ai/model/_model_output.py
CHANGED
@@ -26,9 +26,14 @@ class ModelUsage(BaseModel):
 
 
 StopReason = Literal[
-    "stop",
+    "stop",
+    "max_tokens",
+    "model_length",
+    "tool_calls",
+    "content_filter",
+    "unknown",
 ]
-"""Reason that the model stopped
+"""Reason that the model stopped or failed to generate."""
 
 
 class TopLogprob(BaseModel):
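
With StopReason spelled out one value per line, its role is clearer: provider adapters normalize their native finish reasons onto this vocabulary. A hedged sketch of such a mapping (the provider-side strings are illustrative, and StopReason is assumed to be importable from inspect_ai.model):

from inspect_ai.model import StopReason


def as_stop_reason(reason: str | None) -> StopReason:
    # map a provider-native finish reason onto inspect's StopReason
    match reason:
        case "stop" | "end_turn":
            return "stop"
        case "length":
            return "max_tokens"
        case "tool_use" | "function_call":
            return "tool_calls"
        case "content_filter":
            return "content_filter"
        case _:
            return "unknown"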
|