inspect-ai 0.3.57__py3-none-any.whl → 0.3.58__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_cli/common.py +4 -2
- inspect_ai/_cli/eval.py +2 -0
- inspect_ai/_cli/trace.py +21 -2
- inspect_ai/_display/core/active.py +0 -2
- inspect_ai/_display/rich/display.py +4 -4
- inspect_ai/_display/textual/app.py +4 -1
- inspect_ai/_display/textual/widgets/samples.py +41 -5
- inspect_ai/_eval/eval.py +32 -20
- inspect_ai/_eval/evalset.py +7 -5
- inspect_ai/_eval/task/__init__.py +2 -2
- inspect_ai/_eval/task/images.py +40 -25
- inspect_ai/_eval/task/run.py +141 -119
- inspect_ai/_eval/task/task.py +140 -25
- inspect_ai/_util/constants.py +1 -0
- inspect_ai/_util/content.py +23 -1
- inspect_ai/_util/images.py +20 -17
- inspect_ai/_util/kvstore.py +73 -0
- inspect_ai/_util/notgiven.py +18 -0
- inspect_ai/_util/thread.py +5 -0
- inspect_ai/_view/www/dist/assets/index.js +37 -3
- inspect_ai/_view/www/log-schema.json +97 -13
- inspect_ai/_view/www/src/components/MessageBand.mjs +2 -2
- inspect_ai/_view/www/src/components/MessageContent.mjs +43 -1
- inspect_ai/_view/www/src/samples/SamplesDescriptor.mjs +5 -1
- inspect_ai/_view/www/src/types/log.d.ts +51 -27
- inspect_ai/approval/_human/util.py +2 -2
- inspect_ai/dataset/_sources/csv.py +2 -1
- inspect_ai/dataset/_sources/json.py +2 -1
- inspect_ai/dataset/_sources/util.py +15 -7
- inspect_ai/log/_condense.py +11 -1
- inspect_ai/log/_log.py +2 -5
- inspect_ai/log/_recorders/eval.py +19 -8
- inspect_ai/log/_samples.py +10 -5
- inspect_ai/log/_transcript.py +28 -1
- inspect_ai/model/__init__.py +10 -2
- inspect_ai/model/_call_tools.py +55 -12
- inspect_ai/model/_chat_message.py +2 -4
- inspect_ai/model/{_trace.py → _conversation.py} +9 -8
- inspect_ai/model/_model.py +2 -2
- inspect_ai/model/_providers/anthropic.py +9 -7
- inspect_ai/model/_providers/azureai.py +6 -4
- inspect_ai/model/_providers/bedrock.py +6 -4
- inspect_ai/model/_providers/google.py +79 -8
- inspect_ai/model/_providers/groq.py +7 -5
- inspect_ai/model/_providers/hf.py +11 -6
- inspect_ai/model/_providers/mistral.py +6 -9
- inspect_ai/model/_providers/openai.py +17 -5
- inspect_ai/model/_providers/vertex.py +17 -4
- inspect_ai/scorer/__init__.py +13 -2
- inspect_ai/scorer/_metrics/__init__.py +2 -2
- inspect_ai/scorer/_metrics/std.py +3 -3
- inspect_ai/tool/__init__.py +9 -1
- inspect_ai/tool/_tool.py +9 -2
- inspect_ai/util/__init__.py +0 -3
- inspect_ai/util/{_trace.py → _conversation.py} +3 -17
- inspect_ai/util/_display.py +14 -4
- inspect_ai/util/_sandbox/context.py +12 -13
- inspect_ai/util/_sandbox/docker/compose.py +24 -11
- inspect_ai/util/_sandbox/docker/docker.py +20 -13
- inspect_ai/util/_sandbox/environment.py +13 -1
- inspect_ai/util/_sandbox/local.py +1 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/METADATA +2 -2
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/RECORD +68 -65
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/LICENSE +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.57.dist-info → inspect_ai-0.3.58.dist-info}/top_level.txt +0 -0
inspect_ai/approval/_human/util.py
CHANGED
@@ -5,7 +5,7 @@ from rich.text import Text
 
 from inspect_ai._util.transcript import transcript_markdown
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallView
-from inspect_ai.util.
+from inspect_ai.util._display import display_type
 
 HUMAN_APPROVED = "Human operator approved tool call."
 HUMAN_REJECTED = "Human operator rejected the tool call."
@@ -18,7 +18,7 @@ def render_tool_approval(message: str, view: ToolCallView) -> list[RenderableType]:
     text_highlighter = ReprHighlighter()
 
     # ignore content if trace enabled
-    message = message.strip() if
+    message = message.strip() if display_type() != "conversation" else ""
 
     def add_view_content(view_content: ToolCallContent) -> None:
         if view_content.title:
inspect_ai/dataset/_sources/csv.py
CHANGED
@@ -1,4 +1,5 @@
 import csv
+import os
 from io import TextIOWrapper
 from pathlib import Path
 from typing import Any
@@ -75,7 +76,7 @@ def csv_dataset(
     dataset = MemoryDataset(
         samples=data_to_samples(valid_data, data_to_sample, auto_id),
         name=name,
-        location=csv_file,
+        location=os.path.abspath(csv_file),
     )
 
     # resolve relative file paths
inspect_ai/dataset/_sources/json.py
CHANGED
@@ -1,4 +1,5 @@
 import json
+import os
 from io import TextIOWrapper
 from pathlib import Path
 from typing import Any, cast
@@ -75,7 +76,7 @@ def json_dataset(
     dataset = MemoryDataset(
         samples=data_to_samples(dataset_reader(f), data_to_sample, auto_id),
         name=name,
-        location=json_file,
+        location=os.path.abspath(json_file),
     )
 
     # resolve relative file paths
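Both dataset loaders get the same one-line fix: the `location` recorded on the `MemoryDataset` is now normalized to an absolute path, so it remains resolvable if the working directory changes later in the run. A minimal sketch of the normalization (hypothetical path):

import os

# before 0.3.58 a relative path was stored verbatim, so a later os.chdir()
# could make the recorded dataset location unresolvable
print(os.path.abspath("data/samples.csv"))  # e.g. /home/user/project/data/samples.csv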
inspect_ai/dataset/_sources/util.py
CHANGED
@@ -1,6 +1,6 @@
 from typing import Callable
 
-from inspect_ai._util.content import Content, ContentImage
+from inspect_ai._util.content import Content, ContentAudio, ContentImage, ContentVideo
 from inspect_ai._util.file import filesystem
 from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser
 from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
@@ -44,24 +44,28 @@ def resolve_sample_files(dataset: Dataset) -> None:
             for path in sample.files.keys():
                 sample.files[path] = resolve_file(sample.files[path])
 
+        # check for setup script
+        if sample.setup is not None:
+            sample.setup = resolve_file(sample.setup)
+
         # check for image paths
         if not isinstance(sample.input, str):
-            sample.input =
+            sample.input = messages_with_resolved_content(sample.input, resolve_file)
 
 
-def
+def messages_with_resolved_content(
     messages: list[ChatMessage], resolver: Callable[[str], str]
 ) -> list[ChatMessage]:
-    return [
+    return [message_with_resolved_content(message, resolver) for message in messages]
 
 
-def
+def message_with_resolved_content(
     message: ChatMessage, resolver: Callable[[str], str]
 ) -> ChatMessage:
     if isinstance(message, ChatMessageUser) and not isinstance(message.content, str):
         return ChatMessageUser(
             content=[
-
+                chat_content_with_resolved_content(content, resolver)
                 for content in message.content
             ],
             source=message.source,
@@ -70,7 +74,7 @@ def message_with_resolved_image(
     return message
 
 
-def
+def chat_content_with_resolved_content(
     content: Content, resolver: Callable[[str], str]
 ) -> Content:
     if isinstance(content, ContentImage):
@@ -78,5 +82,9 @@ def chat_content_with_resolved_image(
             image=resolver(content.image),
             detail=content.detail,
         )
+    elif isinstance(content, ContentAudio):
+        return ContentAudio(audio=resolver(content.audio), format=content.format)
+    elif isinstance(content, ContentVideo):
+        return ContentVideo(video=resolver(content.video), format=content.format)
     else:
         return content
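With the resolver now walking audio and video as well as images, sample inputs can reference local media files and have them rewritten just like image paths. A minimal sketch of such an input (hypothetical file name; `ContentAudio` fields as used in the diff above):

from inspect_ai.model import ChatMessageUser, ContentAudio, ContentText

# user message mixing text and audio; "interview.mp3" is a placeholder path
# that resolve_sample_files() would rewrite via its resolver
message = ChatMessageUser(
    content=[
        ContentText(text="Transcribe the following recording."),
        ContentAudio(audio="interview.mp3", format="mp3"),
    ]
)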
inspect_ai/log/_condense.py
CHANGED
@@ -6,7 +6,13 @@ from typing import (
 from pydantic import JsonValue
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.hash import mm3_hash
 from inspect_ai._util.json import JsonChange
 from inspect_ai._util.url import is_data_uri
@@ -304,3 +310,7 @@ def walk_content(content: Content, content_fn: Callable[[str], str]) -> Content:
         return content.model_copy(update=dict(text=content_fn(content.text)))
     elif isinstance(content, ContentImage):
         return content.model_copy(update=dict(image=content_fn(content.image)))
+    elif isinstance(content, ContentAudio):
+        return content.model_copy(update=dict(audio=content_fn(content.audio)))
+    elif isinstance(content, ContentVideo):
+        return content.model_copy(update=dict(video=content_fn(content.video)))
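The new branches apply the same `model_copy(update=...)` pattern as the image case to the audio and video payload fields. Roughly what `walk_content` does for audio, in isolation (a sketch with a made-up transform):

from inspect_ai._util.content import ContentAudio

audio = ContentAudio(audio="data:audio/mp3;base64,AAAA", format="mp3")
# walk_content applies content_fn to the payload and returns an updated copy,
# e.g. replacing inline base64 data with a placeholder when condensing logs
condensed = audio.model_copy(update=dict(audio="<base-64-data-removed>"))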
inspect_ai/log/_log.py
CHANGED
@@ -48,9 +48,6 @@ class EvalConfig(BaseModel):
     epochs_reducer: list[str] | None = Field(default=None)
     """Reducers for aggregating per-sample scores."""
 
-    trace: bool | None = Field(default=None)
-    """Trace message interactions with evaluated model to terminal."""
-
     approval: ApprovalPolicyConfig | None = Field(default=None)
     """Approval policy for tool use."""
 
@@ -355,7 +352,7 @@ class EvalResults(BaseModel):
         """Scorer used to compute results (deprecated)."""
         warn_once(
             logger,
-            "The 'scorer' field is deprecated. Use '
+            "The 'scorer' field is deprecated. Use 'scores' instead.",
         )
         return self.scores[0] if self.scores else None
 
@@ -364,7 +361,7 @@ class EvalResults(BaseModel):
         """Metrics computed (deprecated)."""
         warn_once(
             logger,
-            "The 'metrics' field is deprecated. Access metrics through '
+            "The 'metrics' field is deprecated. Access metrics through 'scores' instead.",
        )
        return self.scores[0].metrics if self.scores else {}
 
inspect_ai/log/_recorders/eval.py
CHANGED
@@ -13,7 +13,12 @@ from pydantic_core import to_json
 from typing_extensions import override
 
 from inspect_ai._util.constants import LOG_SCHEMA_VERSION
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.error import EvalError
 from inspect_ai._util.file import FileSystem, async_fileystem, dirname, file, filesystem
 from inspect_ai._util.json import jsonable_python
@@ -90,9 +95,11 @@ class EvalRecorder(FileRecorder):
         self.data: dict[str, ZipLogFile] = {}
 
     @override
-    async def log_init(
+    async def log_init(
+        self, eval: EvalSpec, location: str | None = None, *, clean: bool = False
+    ) -> str:
         # if the file exists then read summaries
-        if location is not None and self.fs.exists(location):
+        if not clean and location is not None and self.fs.exists(location):
             with file(location, "rb") as f:
                 with ZipFile(f, "r") as zip:
                     log_start = _read_start(zip)
@@ -229,7 +236,7 @@ class EvalRecorder(FileRecorder):
     async def write_log(cls, location: str, log: EvalLog) -> None:
         # write using the recorder (so we get all of the extra streams)
         recorder = EvalRecorder(dirname(location))
-        await recorder.log_init(log.eval, location)
+        await recorder.log_init(log.eval, location, clean=True)
         await recorder.log_start(log.eval, log.plan)
         for sample in log.samples or []:
             await recorder.log_sample(log.eval, sample)
@@ -244,12 +251,16 @@ def text_inputs(inputs: str | list[ChatMessage]) -> str | list[ChatMessage]:
     input: list[ChatMessage] = []
     for message in inputs:
         if not isinstance(message.content, str):
-            filtered_content: list[
+            filtered_content: list[
+                ContentText | ContentImage | ContentAudio | ContentVideo
+            ] = []
             for content in message.content:
-                if content.type
+                if content.type == "text":
                     filtered_content.append(content)
-
-
+                else:
+                    filtered_content.append(
+                        ContentText(text=f"({content.type.capitalize()})")
+                    )
             message.content = filtered_content
             input.append(message)
         else:
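`log_init` gains a keyword-only `clean` flag, and `write_log` passes `clean=True`, so rewriting a log replaces any existing file outright instead of merging summaries read from it. A usage sketch (hypothetical log path, run inside an async context):

import asyncio

from inspect_ai.log import read_eval_log
from inspect_ai.log._recorders.eval import EvalRecorder

async def rewrite(path: str) -> None:
    log = read_eval_log(path)
    # with clean=True under the hood, this overwrites rather than merges
    await EvalRecorder.write_log(path, log)

asyncio.run(rewrite("logs/example.eval"))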
inspect_ai/log/_samples.py
CHANGED
@@ -29,7 +29,7 @@ class ActiveSample:
         sandboxes: dict[str, SandboxConnection],
     ) -> None:
         self.id = uuid()
-        self.started =
+        self.started: float | None = None
        self.completed: float | None = None
        self.task = task
        self.model = model
@@ -48,10 +48,15 @@ class ActiveSample:
 
     @property
     def execution_time(self) -> float:
-
-
-
-
+        if self.started is not None:
+            completed = (
+                self.completed
+                if self.completed is not None
+                else datetime.now().timestamp()
+            )
+            return completed - self.started
+        else:
+            return 0
 
     def interrupt(self, action: Literal["score", "error"]) -> None:
         self._interrupt_action = action
inspect_ai/log/_transcript.py
CHANGED
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 from contextvars import ContextVar
 from datetime import datetime
@@ -11,7 +12,7 @@ from typing import (
     Union,
 )
 
-from pydantic import BaseModel, Field, JsonValue, field_serializer
+from pydantic import BaseModel, ConfigDict, Field, JsonValue, field_serializer
 
 from inspect_ai._util.constants import SAMPLE_SUBTASK
 from inspect_ai._util.error import EvalError
@@ -176,6 +177,32 @@ class ToolEvent(BaseEvent):
         self.events = events
         self.pending = None
 
+    # mechanism for operator to cancel the tool call
+
+    def set_task(self, task: asyncio.Task[Any]) -> None:
+        """Set the tool task (for possible cancellation)"""
+        self._task = task
+
+    def cancel(self) -> None:
+        """Cancel the tool task."""
+        if self._task:
+            self._cancelled = True
+            self._task.cancel()
+
+    @property
+    def cancelled(self) -> bool:
+        """Was the task cancelled?"""
+        return self._cancelled is True
+
+    _cancelled: bool | None = None
+    """Was this tool call cancelled?"""
+
+    _task: asyncio.Task[Any] | None = None
+    """Handle to task (used for cancellation)"""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+    """Required so that we can include '_task' as a member."""
+
 
 class ApprovalEvent(BaseEvent):
     """Tool approval."""
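The new `set_task`/`cancel`/`cancelled` members wire `ToolEvent` into standard `asyncio` cancellation. Stripped of the event bookkeeping, the handshake looks like this (a standalone sketch, not inspect_ai code):

import asyncio

async def tool_call() -> str:
    await asyncio.sleep(60)
    return "ok"

async def main() -> None:
    task = asyncio.create_task(tool_call())
    task.cancel()  # what ToolEvent.cancel() ultimately invokes
    try:
        await task
    except asyncio.CancelledError:
        # call_tools() catches this and synthesizes a "timeout" tool message
        print("tool call cancelled by operator")

asyncio.run(main())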
inspect_ai/model/__init__.py
CHANGED
@@ -1,6 +1,12 @@
 # ruff: noqa: F401 F403 F405
 
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.deprecation import relocated_module_attribute
 
 from ._cache import (
@@ -42,8 +48,10 @@ __all__ = [
     "GenerateConfig",
     "GenerateConfigArgs",
     "CachePolicy",
-    "
+    "ContentAudio",
     "ContentImage",
+    "ContentText",
+    "ContentVideo",
     "Content",
     "ChatMessage",
     "ChatMessageSystem",
inspect_ai/model/_call_tools.py
CHANGED
@@ -24,11 +24,17 @@ from typing import (
 from jsonschema import Draft7Validator
 from pydantic import BaseModel
 
-from inspect_ai._util.content import
+from inspect_ai._util.content import (
+    Content,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
 from inspect_ai._util.format import format_function_call
 from inspect_ai._util.text import truncate_string_to_bytes
 from inspect_ai._util.trace import trace_action
-from inspect_ai.model.
+from inspect_ai.model._conversation import conversation_tool_mesage
 from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
 from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
 from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
@@ -120,10 +126,14 @@ async def call_tools(
     # massage result, leave list[Content] alone, convert all other
     # types to string as that is what the model APIs accept
     truncated: tuple[int, int] | None = None
-    if isinstance(
+    if isinstance(
+        result, ContentText | ContentImage | ContentAudio | ContentVideo
+    ):
         content: str | list[Content] = [result]
     elif isinstance(result, list) and (
-        isinstance(
+        isinstance(
+            result[0], ContentText | ContentImage | ContentAudio | ContentVideo
+        )
     ):
         content = result
     else:
@@ -163,6 +173,9 @@ async def call_tools(
     # call tools
     tool_messages: list[ChatMessageTool] = []
     for call in message.tool_calls:
+        # create the task
+        task = asyncio.create_task(call_tool_task(call))
+
         # create pending tool event and add it to the transcript
         event = ToolEvent(
             id=call.id,
@@ -171,15 +184,44 @@ async def call_tools(
             view=call.view,
             pending=True,
         )
+        event.set_task(task)
         transcript()._event(event)
 
-        # execute the tool call
-
-
+        # execute the tool call. if the operator cancelled the
+        # tool call then synthesize the appropriate message/event
+        try:
+            tool_message, result_event = await task
+        except asyncio.CancelledError:
+            if event.cancelled:
+                tool_message = ChatMessageTool(
+                    content="",
+                    function=call.function,
+                    tool_call_id=call.id,
+                    error=ToolCallError(
+                        "timeout", "Command timed out before completing."
+                    ),
+                )
+                result_event = ToolEvent(
+                    id=call.id,
+                    function=call.function,
+                    arguments=call.arguments,
+                    result=tool_message.content,
+                    truncated=None,
+                    view=call.view,
+                    error=tool_message.error,
+                    events=[],
+                )
+                transcript().info(
+                    f"Tool call '{call.function}' was cancelled by operator."
+                )
+            else:
+                raise
+
+        # update return messages
         tool_messages.append(tool_message)
 
-        #
-
+        # print conversation if display is conversation
+        conversation_tool_mesage(tool_message)
 
         # update the event with the results
         event.set_result(
@@ -411,12 +453,13 @@ def truncate_tool_output(
     # truncate if required
     truncated = truncate_string_to_bytes(output, active_max_output)
     if truncated:
-        truncated_output = dedent(
+        truncated_output = dedent("""
            The output of your call to {tool_name} was too long to be displayed.
            Here is a truncated version:
            <START_TOOL_OUTPUT>
-            {
-            <END_TOOL_OUTPUT>
+            {truncated_output}
+            <END_TOOL_OUTPUT>
+            """).format(tool_name=tool_name, truncated_output=truncated.output)
        return TruncatedToolOutput(
            truncated_output, truncated.original_bytes, active_max_output
        )
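The widened `isinstance` checks mean a tool can now return audio or video content and have it passed through as `list[Content]` rather than stringified. A hypothetical tool illustrating this (a sketch only; the docstring follows inspect_ai's usual `@tool` conventions, and "clip.wav" is a placeholder file):

from inspect_ai.model import Content, ContentAudio, ContentText
from inspect_ai.tool import tool

@tool
def record():
    async def execute() -> list[Content]:
        """Record a short audio clip.

        Returns:
            The recorded audio and a text label.
        """
        return [
            ContentText(text="Recorded clip:"),
            ContentAudio(audio="clip.wav", format="wav"),  # placeholder file
        ]

    return execute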
inspect_ai/model/_chat_message.py
CHANGED
@@ -59,10 +59,8 @@ class ChatMessageBase(BaseModel):
         if isinstance(self.content, str):
             self.content = text
         else:
-
-
-            ]
-            self.content = [ContentText(text=text)] + all_images
+            all_other = [content for content in self.content if content.type != "text"]
+            self.content = [ContentText(text=text)] + all_other
 
 
 class ChatMessageSystem(ChatMessageBase):
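This generalizes the `text` setter: previously only images survived a text replacement; now any non-text content (audio, video) is preserved too. The effect, as a sketch:

from inspect_ai.model import ChatMessageUser, ContentAudio, ContentText

msg = ChatMessageUser(
    content=[
        ContentText(text="old text"),
        ContentAudio(audio="clip.wav", format="wav"),
    ]
)
msg.text = "new text"
# the audio item is retained alongside the replacement text
assert any(c.type == "audio" for c in msg.content)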
inspect_ai/model/{_trace.py → _conversation.py}
CHANGED
@@ -3,7 +3,8 @@ from rich.text import Text
 
 from inspect_ai._util.rich import lines_display
 from inspect_ai._util.transcript import transcript_markdown
-from inspect_ai.util.
+from inspect_ai.util._conversation import conversation_panel
+from inspect_ai.util._display import display_type
 
 from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
 from ._render import messages_preceding_assistant, render_tool_calls
@@ -11,25 +12,25 @@ from ._render import messages_preceding_assistant, render_tool_calls
 MESSAGE_TITLE = "Message"
 
 
-def
-    if
+def conversation_tool_mesage(message: ChatMessageTool) -> None:
+    if display_type() == "conversation":
         # truncate output to 100 lines
         output = message.error.message if message.error else message.text.strip()
         content = lines_display(output, 100)
 
-
+        conversation_panel(
             title=f"Tool Output: {message.function}",
             content=content,
         )
 
 
-def
+def conversation_assistant_message(
     input: list[ChatMessage], message: ChatMessageAssistant
 ) -> None:
-    if
+    if display_type() == "conversation":
         # print precding messages that aren't tool or assistant
         for m in messages_preceding_assistant(input):
-
+            conversation_panel(
                 title=m.role.capitalize(),
                 content=transcript_markdown(m.text, escape=True),
             )
@@ -45,4 +46,4 @@ def trace_assistant_message(
     content.extend(render_tool_calls(message.tool_calls))
 
     # print the assistant message
-
+    conversation_panel(title="Assistant", content=content)
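This rename is the heart of the trace → conversation change: the old trace guard becomes a check on the configured display type. Downstream code follows the same pattern (a sketch, using the private import path the diff itself uses):

from inspect_ai.util._display import display_type

def printing_conversation() -> bool:
    # equivalent of the old trace-enabled check
    return display_type() == "conversation"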
inspect_ai/model/_model.py
CHANGED
@@ -43,6 +43,7 @@ from ._chat_message import (
     ChatMessageTool,
     ChatMessageUser,
 )
+from ._conversation import conversation_assistant_message
 from ._generate_config import (
     GenerateConfig,
     active_generate_config,
@@ -50,7 +51,6 @@ from ._generate_config import (
 )
 from ._model_call import ModelCall
 from ._model_output import ModelOutput, ModelUsage
-from ._trace import trace_assistant_message
 
 logger = logging.getLogger(__name__)
 
@@ -487,7 +487,7 @@ class Model:
             updated_output: ModelOutput, updated_call: ModelCall | None
         ) -> None:
             # trace
-
+            conversation_assistant_message(input, updated_output.choices[0].message)
 
             # update event
             event.output = updated_output
inspect_ai/model/_providers/anthropic.py
CHANGED
@@ -28,11 +28,11 @@ from pydantic import JsonValue
 from typing_extensions import override
 
 from inspect_ai._util.constants import BASE_64_DATA_REMOVED, DEFAULT_MAX_RETRIES
-from inspect_ai._util.content import Content, ContentText
+from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import exception_message
-from inspect_ai._util.images import
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai._util.logger import warn_once
-from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
+from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64
 from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo
 
 from .._chat_message import (
@@ -584,11 +584,9 @@ async def message_param_content(
 ) -> TextBlockParam | ImageBlockParam:
     if isinstance(content, ContentText):
         return TextBlockParam(type="text", text=content.text or NO_CONTENT)
-
+    elif isinstance(content, ContentImage):
         # resolve to url
-        image = content.image
-        if not is_data_uri(image):
-            image = await image_as_data_uri(image)
+        image = await file_as_data_uri(content.image)
 
         # resolve mime type and base64 content
         media_type = data_uri_mime_type(image) or "image/png"
@@ -601,6 +599,10 @@ async def message_param_content(
             type="image",
             source=dict(type="base64", media_type=cast(Any, media_type), data=image),
         )
+    else:
+        raise RuntimeError(
+            "Anthropic models do not currently support audio or video inputs."
+        )
 
 
 def model_call_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
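Several providers switch from image-specific helpers to generalized file helpers in `inspect_ai._util.images`. Their signatures, as inferred from the call sites in this diff (an assumption, not documented API): `file_as_data_uri()` returns a `data:` URI string, and `file_as_data()` returns a `(data, type)` pair. Sketch:

import asyncio

from inspect_ai._util.images import file_as_data, file_as_data_uri

async def main() -> None:
    # hypothetical local file; both helpers are awaited at their call sites
    uri = await file_as_data_uri("chart.png")  # e.g. "data:image/png;base64,..."
    image_data, image_type = await file_as_data("chart.png")

asyncio.run(main())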
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -31,8 +31,8 @@ from azure.core.exceptions import AzureError, HttpResponseError
 from typing_extensions import override
 
 from inspect_ai._util.constants import DEFAULT_MAX_TOKENS
-from inspect_ai._util.content import Content, ContentText
-from inspect_ai._util.images import
+from inspect_ai._util.content import Content, ContentImage, ContentText
+from inspect_ai._util.images import file_as_data_uri
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolFunction
@@ -312,12 +312,14 @@ async def chat_request_message(
 async def chat_content_item(content: Content) -> ContentItem:
     if isinstance(content, ContentText):
         return TextContentItem(text=content.text)
-
+    elif isinstance(content, ContentImage):
         return ImageContentItem(
             image_url=ImageUrl(
-                url=await
+                url=await file_as_data_uri(content.image), detail=content.detail
             )
         )
+    else:
+        raise RuntimeError("Azure AI models do not support audio or video inputs.")
 
 
 def chat_tool_call(tool_call: ToolCall) -> ChatCompletionsToolCall:
inspect_ai/model/_providers/bedrock.py
CHANGED
@@ -11,7 +11,7 @@ from inspect_ai._util.constants import (
 )
 from inspect_ai._util.content import Content, ContentImage, ContentText
 from inspect_ai._util.error import pip_dependency_error
-from inspect_ai._util.images import
+from inspect_ai._util.images import file_as_data
 from inspect_ai._util.version import verify_required_version
 from inspect_ai.tool import ToolChoice, ToolInfo
 from inspect_ai.tool._tool_call import ToolCall
@@ -430,7 +430,9 @@ def model_output_from_response(
             content.append(ContentText(type="text", text=c.text))
         elif c.image is not None:
             base64_image = base64.b64encode(c.image.source.bytes).decode("utf-8")
-            content.append(
+            content.append(
+                ContentImage(image=f"data:image/{c.image.format};base64,{base64_image}")
+            )
         elif c.toolUse is not None:
             tool_calls.append(
                 ToolCall(
@@ -565,7 +567,7 @@ async def converse_chat_message(
         if c.type == "text":
             tool_result_content.append(ConverseToolResultContent(text=c.text))
         elif c.type == "image":
-            image_data, image_type = await
+            image_data, image_type = await file_as_data(c.image)
             tool_result_content.append(
                 ConverseToolResultContent(
                     image=ConverseImage(
@@ -604,7 +606,7 @@ async def converse_contents(
     result: list[ConverseMessageContent] = []
     for c in content:
         if c.type == "image":
-            image_data, image_type = await
+            image_data, image_type = await file_as_data(c.image)
             result.append(
                 ConverseMessageContent(
                     image=ConverseImage(