inspect-ai 0.3.82__py3-none-any.whl → 0.3.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +3 -4
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +48 -0
- inspect_ai/_eval/eval.py +36 -24
- inspect_ai/_eval/evalset.py +17 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +8 -13
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +2 -0
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/file.py +13 -0
- inspect_ai/_util/json.py +2 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/www/App.css +6 -0
- inspect_ai/_view/www/dist/assets/index.css +304 -128
- inspect_ai/_view/www/dist/assets/index.js +47495 -27519
- inspect_ai/_view/www/log-schema.json +124 -31
- inspect_ai/_view/www/package.json +3 -0
- inspect_ai/_view/www/src/App.tsx +12 -0
- inspect_ai/_view/www/src/appearance/icons.ts +1 -0
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
- inspect_ai/_view/www/src/state/hooks.ts +5 -3
- inspect_ai/_view/www/src/state/logPolling.ts +5 -1
- inspect_ai/_view/www/src/state/logSlice.ts +10 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
- inspect_ai/_view/www/src/types/log.d.ts +34 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
- inspect_ai/_view/www/yarn.lock +94 -1
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_log.py +11 -2
- inspect_ai/log/_transcript.py +13 -9
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +256 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_model.py +113 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +2 -2
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/RECORD +179 -153
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/top_level.txt +0 -0
inspect_ai/model/_call_tools.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import inspect
|
2
2
|
import json
|
3
|
-
import sys
|
4
3
|
import types
|
4
|
+
from copy import copy
|
5
5
|
from dataclasses import is_dataclass
|
6
6
|
from logging import getLogger
|
7
7
|
from textwrap import dedent
|
@@ -16,15 +16,13 @@ from typing import (
|
|
16
16
|
Tuple,
|
17
17
|
Type,
|
18
18
|
Union,
|
19
|
+
cast,
|
19
20
|
get_args,
|
20
21
|
get_origin,
|
21
22
|
get_type_hints,
|
22
23
|
is_typeddict,
|
23
24
|
)
|
24
25
|
|
25
|
-
if sys.version_info < (3, 11):
|
26
|
-
from exceptiongroup import ExceptionGroup
|
27
|
-
|
28
26
|
import anyio
|
29
27
|
import yaml
|
30
28
|
from anyio.streams.memory import MemoryObjectSendStream
|
@@ -39,42 +37,69 @@ from inspect_ai._util.content import (
|
|
39
37
|
ContentVideo,
|
40
38
|
)
|
41
39
|
from inspect_ai._util.format import format_function_call
|
40
|
+
from inspect_ai._util.logger import warn_once
|
41
|
+
from inspect_ai._util.registry import registry_unqualified_name
|
42
42
|
from inspect_ai._util.text import truncate_string_to_bytes
|
43
43
|
from inspect_ai._util.trace import trace_action
|
44
44
|
from inspect_ai._util.working import sample_waiting_time
|
45
|
-
from inspect_ai.model.
|
45
|
+
from inspect_ai.model._display import display_conversation_message
|
46
|
+
from inspect_ai.model._model_output import ModelOutput
|
46
47
|
from inspect_ai.tool import Tool, ToolCall, ToolError, ToolInfo
|
47
|
-
from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError
|
48
|
+
from inspect_ai.tool._tool import ToolApprovalError, ToolParsingError, ToolResult
|
48
49
|
from inspect_ai.tool._tool_call import ToolCallContent, ToolCallError
|
49
50
|
from inspect_ai.tool._tool_def import ToolDef, tool_defs
|
50
51
|
from inspect_ai.tool._tool_info import parse_docstring
|
51
52
|
from inspect_ai.tool._tool_params import ToolParams
|
52
53
|
from inspect_ai.util import OutputLimitExceededError
|
53
|
-
|
54
|
-
|
54
|
+
from inspect_ai.util._anyio import inner_exception
|
55
|
+
|
56
|
+
from ._chat_message import (
|
57
|
+
ChatMessage,
|
58
|
+
ChatMessageAssistant,
|
59
|
+
ChatMessageSystem,
|
60
|
+
ChatMessageTool,
|
61
|
+
ChatMessageUser,
|
62
|
+
)
|
55
63
|
from ._generate_config import active_generate_config
|
56
64
|
|
57
65
|
logger = getLogger(__name__)
|
58
66
|
|
59
67
|
|
60
|
-
|
61
|
-
message
|
68
|
+
class ExecuteToolsResult(NamedTuple):
|
69
|
+
"""Result from executing tools in the last assistant message.
|
70
|
+
|
71
|
+
In conventional tool calling scenarios there will be only a list
|
72
|
+
of `ChatMessageTool` appended and no-output. However, if there
|
73
|
+
are `handoff()` tools (used in multi-agent systems) then other
|
74
|
+
messages may be appended and an `output` may be available as well.
|
75
|
+
"""
|
76
|
+
|
77
|
+
messages: list[ChatMessage]
|
78
|
+
"""Messages added to conversation."""
|
79
|
+
|
80
|
+
output: ModelOutput | None = None
|
81
|
+
"""Model output if a generation occurred within the conversation."""
|
82
|
+
|
83
|
+
|
84
|
+
async def execute_tools(
|
85
|
+
messages: list[ChatMessage],
|
62
86
|
tools: list[Tool] | list[ToolDef] | list[Tool | ToolDef],
|
63
87
|
max_output: int | None = None,
|
64
|
-
) ->
|
65
|
-
"""Perform tool calls in assistant message.
|
88
|
+
) -> ExecuteToolsResult:
|
89
|
+
"""Perform tool calls in the last assistant message.
|
66
90
|
|
67
91
|
Args:
|
68
|
-
|
92
|
+
messages: Current message list
|
69
93
|
tools (list[Tool]): Available tools
|
70
94
|
max_output (int | None): Maximum output length (in bytes).
|
71
95
|
Defaults to max_tool_output from active GenerateConfig
|
72
96
|
(16 * 1024 by default).
|
73
97
|
|
74
98
|
Returns:
|
75
|
-
|
99
|
+
Messages added to the conversation and final model output (if any)
|
76
100
|
"""
|
77
|
-
|
101
|
+
message = messages[-1]
|
102
|
+
if isinstance(message, ChatMessageAssistant) and message.tool_calls:
|
78
103
|
from inspect_ai.log._transcript import (
|
79
104
|
ToolEvent,
|
80
105
|
Transcript,
|
@@ -87,16 +112,31 @@ async def call_tools(
|
|
87
112
|
|
88
113
|
async def call_tool_task(
|
89
114
|
call: ToolCall,
|
90
|
-
|
115
|
+
conversation: list[ChatMessage],
|
116
|
+
send_stream: MemoryObjectSendStream[
|
117
|
+
tuple[ExecuteToolsResult, ToolEvent, Exception | None]
|
118
|
+
],
|
91
119
|
) -> None:
|
92
120
|
# create a transript for this call
|
93
121
|
init_transcript(Transcript(name=call.function))
|
94
122
|
|
95
|
-
result:
|
123
|
+
result: ToolResult = ""
|
124
|
+
messages: list[ChatMessage] = []
|
125
|
+
output: ModelOutput | None = None
|
126
|
+
agent: str | None = None
|
96
127
|
tool_error: ToolCallError | None = None
|
128
|
+
tool_exception: Exception | None = None
|
97
129
|
try:
|
98
130
|
with track_store_changes():
|
99
|
-
|
131
|
+
try:
|
132
|
+
result, messages, output, agent = await call_tool(
|
133
|
+
tdefs, message.text, call, conversation
|
134
|
+
)
|
135
|
+
# unwrap exception group
|
136
|
+
except Exception as ex:
|
137
|
+
inner_ex = inner_exception(ex)
|
138
|
+
raise inner_ex.with_traceback(inner_ex.__traceback__)
|
139
|
+
|
100
140
|
except TimeoutError:
|
101
141
|
tool_error = ToolCallError(
|
102
142
|
"timeout", "Command timed out before completing."
|
@@ -133,6 +173,8 @@ async def call_tools(
|
|
133
173
|
tool_error = ToolCallError("approval", ex.message)
|
134
174
|
except ToolError as ex:
|
135
175
|
tool_error = ToolCallError("unknown", ex.message)
|
176
|
+
except Exception as ex:
|
177
|
+
tool_exception = ex
|
136
178
|
|
137
179
|
# massage result, leave list[Content] alone, convert all other
|
138
180
|
# types to string as that is what the model APIs accept
|
@@ -167,31 +209,39 @@ async def call_tools(
|
|
167
209
|
id=call.id,
|
168
210
|
function=call.function,
|
169
211
|
arguments=call.arguments,
|
170
|
-
internal_name=call.internal_name,
|
171
212
|
result=content,
|
172
213
|
truncated=truncated,
|
173
214
|
view=call.view,
|
174
215
|
error=tool_error,
|
175
216
|
events=list(transcript().events),
|
217
|
+
agent=agent,
|
176
218
|
)
|
177
219
|
|
178
220
|
# yield message and event
|
179
221
|
async with send_stream:
|
180
222
|
await send_stream.send(
|
181
223
|
(
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
224
|
+
ExecuteToolsResult(
|
225
|
+
messages=[
|
226
|
+
ChatMessageTool(
|
227
|
+
content=content,
|
228
|
+
tool_call_id=call.id,
|
229
|
+
function=call.function,
|
230
|
+
error=tool_error,
|
231
|
+
internal=call.internal,
|
232
|
+
)
|
233
|
+
]
|
234
|
+
+ messages,
|
235
|
+
output=output,
|
188
236
|
),
|
189
237
|
event,
|
238
|
+
tool_exception,
|
190
239
|
)
|
191
240
|
)
|
192
241
|
|
193
242
|
# call tools
|
194
|
-
|
243
|
+
result_messages: list[ChatMessage] = []
|
244
|
+
result_output: ModelOutput | None = None
|
195
245
|
for call in message.tool_calls:
|
196
246
|
# create pending tool event and add it to the transcript
|
197
247
|
# (record the waiting time for the sample so we can compare
|
@@ -202,8 +252,8 @@ async def call_tools(
|
|
202
252
|
id=call.id,
|
203
253
|
function=call.function,
|
204
254
|
arguments=call.arguments,
|
205
|
-
internal_name=call.internal_name,
|
206
255
|
view=call.view,
|
256
|
+
internal=call.internal,
|
207
257
|
pending=True,
|
208
258
|
)
|
209
259
|
transcript()._event(event)
|
@@ -211,22 +261,23 @@ async def call_tools(
|
|
211
261
|
# execute the tool call. if the operator cancels the
|
212
262
|
# tool call then synthesize the appropriate message/event
|
213
263
|
send_stream, receive_stream = anyio.create_memory_object_stream[
|
214
|
-
tuple[
|
264
|
+
tuple[ExecuteToolsResult, ToolEvent, Exception | None]
|
215
265
|
]()
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
266
|
+
|
267
|
+
async with anyio.create_task_group() as tg:
|
268
|
+
tg.start_soon(call_tool_task, call, messages, send_stream)
|
269
|
+
event._set_cancel_fn(tg.cancel_scope.cancel)
|
270
|
+
async with receive_stream:
|
271
|
+
(
|
272
|
+
result,
|
273
|
+
result_event,
|
274
|
+
result_exception,
|
275
|
+
) = await receive_stream.receive()
|
224
276
|
|
225
277
|
if event.cancelled:
|
226
278
|
tool_message = ChatMessageTool(
|
227
279
|
content="",
|
228
280
|
function=call.function,
|
229
|
-
internal_name=call.internal_name,
|
230
281
|
tool_call_id=call.id,
|
231
282
|
error=ToolCallError(
|
232
283
|
"timeout", "Command timed out before completing."
|
@@ -236,7 +287,6 @@ async def call_tools(
|
|
236
287
|
id=call.id,
|
237
288
|
function=call.function,
|
238
289
|
arguments=call.arguments,
|
239
|
-
internal_name=call.internal_name,
|
240
290
|
result=tool_message.content,
|
241
291
|
truncated=None,
|
242
292
|
view=call.view,
|
@@ -246,12 +296,14 @@ async def call_tools(
|
|
246
296
|
transcript().info(
|
247
297
|
f"Tool call '{call.function}' was cancelled by operator."
|
248
298
|
)
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
299
|
+
result_messages.append(tool_message)
|
300
|
+
display_conversation_message(tool_message)
|
301
|
+
else:
|
302
|
+
for message in result.messages:
|
303
|
+
result_messages.append(message)
|
304
|
+
display_conversation_message(message)
|
305
|
+
if result.output is not None:
|
306
|
+
result_output = result.output
|
255
307
|
|
256
308
|
# update the event with the results
|
257
309
|
waiting_time_end = sample_waiting_time()
|
@@ -261,17 +313,29 @@ async def call_tools(
|
|
261
313
|
error=result_event.error,
|
262
314
|
events=result_event.events,
|
263
315
|
waiting_time=waiting_time_end - waiting_time_start,
|
316
|
+
agent=result_event.agent,
|
317
|
+
failed=True if result_exception else None,
|
264
318
|
)
|
265
319
|
transcript()._event_updated(event)
|
266
320
|
|
321
|
+
# if there was an exception then re-raise it -- we do this
|
322
|
+
# after updating the event so that we flush the transcript
|
323
|
+
# for the event
|
324
|
+
if result_exception is not None:
|
325
|
+
raise result_exception
|
326
|
+
|
267
327
|
# return tool messages
|
268
|
-
return
|
328
|
+
return ExecuteToolsResult(result_messages, result_output)
|
269
329
|
|
270
330
|
else:
|
271
|
-
return []
|
331
|
+
return ExecuteToolsResult([])
|
332
|
+
|
272
333
|
|
334
|
+
async def call_tool(
|
335
|
+
tools: list[ToolDef], message: str, call: ToolCall, conversation: list[ChatMessage]
|
336
|
+
) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str | None]:
|
337
|
+
from inspect_ai.agent._handoff import AgentTool
|
273
338
|
|
274
|
-
async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
|
275
339
|
# if there was an error parsing the ToolCall, raise that
|
276
340
|
if call.parse_error:
|
277
341
|
raise ToolParsingError(call.parse_error)
|
@@ -302,10 +366,122 @@ async def call_tool(tools: list[ToolDef], message: str, call: ToolCall) -> Any:
|
|
302
366
|
with trace_action(
|
303
367
|
logger, "Tool Call", format_function_call(tool_def.name, arguments, width=1000)
|
304
368
|
):
|
305
|
-
|
369
|
+
# agent tools get special handling
|
370
|
+
if isinstance(tool_def.tool, AgentTool):
|
371
|
+
return await agent_handoff(tool_def, call, conversation)
|
306
372
|
|
307
|
-
|
308
|
-
|
373
|
+
# normal tool call
|
374
|
+
else:
|
375
|
+
arguments = tool_params(call.arguments, tool_def.tool)
|
376
|
+
result: ToolResult = await tool_def.tool(**arguments)
|
377
|
+
return result, [], None, None
|
378
|
+
|
379
|
+
|
380
|
+
async def agent_handoff(
|
381
|
+
tool_def: ToolDef, call: ToolCall, conversation: list[ChatMessage]
|
382
|
+
) -> tuple[ToolResult, list[ChatMessage], ModelOutput | None, str]:
|
383
|
+
from inspect_ai.agent._agent import AgentState
|
384
|
+
from inspect_ai.agent._handoff import AgentTool
|
385
|
+
|
386
|
+
# alias agent tool and get agent name
|
387
|
+
agent_tool = cast(AgentTool, tool_def.tool)
|
388
|
+
agent_name = registry_unqualified_name(agent_tool.agent)
|
389
|
+
|
390
|
+
# copy list
|
391
|
+
agent_conversation = copy(conversation)
|
392
|
+
|
393
|
+
# remove other tool calls from the assistant message so the
|
394
|
+
# conversation remains valid (the model may have called multiple
|
395
|
+
# tools in parallel and we won't be handling the other calls)
|
396
|
+
last_message = agent_conversation[-1]
|
397
|
+
if isinstance(last_message, ChatMessageAssistant) and last_message.tool_calls:
|
398
|
+
agent_conversation[-1] = agent_conversation[-1].model_copy(
|
399
|
+
update=dict(
|
400
|
+
tool_calls=[
|
401
|
+
tool_call
|
402
|
+
for tool_call in last_message.tool_calls
|
403
|
+
if tool_call.id == call.id
|
404
|
+
]
|
405
|
+
)
|
406
|
+
)
|
407
|
+
|
408
|
+
# ammend the conversation with a ChatMessageTool to indicate
|
409
|
+
# to the downstream agent that we satisfied the call
|
410
|
+
tool_result = f"Successfully transferred to {agent_name}."
|
411
|
+
agent_conversation.append(
|
412
|
+
ChatMessageTool(
|
413
|
+
content=tool_result,
|
414
|
+
tool_call_id=call.id,
|
415
|
+
function=call.function,
|
416
|
+
internal=call.internal,
|
417
|
+
)
|
418
|
+
)
|
419
|
+
|
420
|
+
# run input filter if we have one
|
421
|
+
if agent_tool.input_filter is not None:
|
422
|
+
agent_conversation = await agent_tool.input_filter(agent_conversation)
|
423
|
+
|
424
|
+
# remove system messages (as they can refer to tools or other special
|
425
|
+
# instructions that don't apply to the sub-agent)
|
426
|
+
agent_conversation = [
|
427
|
+
m for m in agent_conversation if not isinstance(m, ChatMessageSystem)
|
428
|
+
]
|
429
|
+
|
430
|
+
# inject curried args
|
431
|
+
arguments = {**call.arguments, **agent_tool.kwargs}
|
432
|
+
|
433
|
+
# parse arguments
|
434
|
+
arguments = tool_params(arguments, agent_tool.agent)
|
435
|
+
del arguments["state"]
|
436
|
+
|
437
|
+
# make the call
|
438
|
+
agent_state = AgentState(messages=copy(agent_conversation))
|
439
|
+
agent_state = await agent_tool.agent(agent_state, **arguments)
|
440
|
+
|
441
|
+
# determine which messages are new and return only those (but exclude new
|
442
|
+
# system messages as they an internal matter for the handed off to agent.
|
443
|
+
# also, inject the agent's name as a prefix in assistant messages
|
444
|
+
conversation_message_ids = [message.id for message in agent_conversation]
|
445
|
+
agent_messages: list[ChatMessage] = []
|
446
|
+
for m in agent_state.messages:
|
447
|
+
if m.id not in conversation_message_ids:
|
448
|
+
if isinstance(m, ChatMessageAssistant):
|
449
|
+
m = prepend_agent_name(m, agent_name)
|
450
|
+
if not isinstance(m, ChatMessageSystem):
|
451
|
+
agent_messages.append(m)
|
452
|
+
|
453
|
+
# run output filter if we have one
|
454
|
+
if agent_tool.output_filter is not None:
|
455
|
+
agent_messages = await agent_tool.output_filter(agent_messages)
|
456
|
+
|
457
|
+
# if we end with an assistant message then add a user message
|
458
|
+
# so that the calling agent carries on
|
459
|
+
if len(agent_messages) == 0 or isinstance(agent_messages[-1], ChatMessageAssistant):
|
460
|
+
agent_messages.append(
|
461
|
+
ChatMessageUser(content=f"The {agent_name} agent has completed its work.")
|
462
|
+
)
|
463
|
+
|
464
|
+
return (tool_result, agent_messages, agent_state.output, agent_name)
|
465
|
+
|
466
|
+
|
467
|
+
def prepend_agent_name(
|
468
|
+
message: ChatMessageAssistant, agent_name: str
|
469
|
+
) -> ChatMessageAssistant:
|
470
|
+
if isinstance(message.content, str):
|
471
|
+
return message.model_copy(
|
472
|
+
update=dict(content=f"[{agent_name}] {message.content}")
|
473
|
+
)
|
474
|
+
else:
|
475
|
+
content = copy(message.content)
|
476
|
+
for i in range(0, len(content)):
|
477
|
+
if isinstance(content[i], ContentText):
|
478
|
+
content[i] = content[i].model_copy(
|
479
|
+
update=dict(
|
480
|
+
text=f"[{agent_name}] {cast(ContentText, content[i]).text}"
|
481
|
+
)
|
482
|
+
)
|
483
|
+
break
|
484
|
+
return message.model_copy(update=dict(content=content))
|
309
485
|
|
310
486
|
|
311
487
|
def tools_info(
|
@@ -441,7 +617,7 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
|
|
441
617
|
else:
|
442
618
|
return input
|
443
619
|
elif origin is Union or origin is types.UnionType:
|
444
|
-
if args[1] is type(None):
|
620
|
+
if args[1] is type(None) and input is not None:
|
445
621
|
return tool_param(args[0], input)
|
446
622
|
else:
|
447
623
|
return input
|
@@ -559,6 +735,34 @@ def parse_tool_call(
|
|
559
735
|
id=id,
|
560
736
|
function=function,
|
561
737
|
arguments=arguments_dict,
|
562
|
-
type="function",
|
563
738
|
parse_error=error,
|
564
739
|
)
|
740
|
+
|
741
|
+
|
742
|
+
async def call_tools(
|
743
|
+
message: ChatMessageAssistant,
|
744
|
+
tools: list[Tool] | list[ToolDef] | list[Tool | ToolDef],
|
745
|
+
max_output: int | None = None,
|
746
|
+
) -> list[ChatMessageTool]:
|
747
|
+
"""Perform tool calls in assistant message.
|
748
|
+
|
749
|
+
This method is deprecated. Use the `execute_tools()` method instead
|
750
|
+
(which correctly handles agent `handoff()` tools).
|
751
|
+
|
752
|
+
Args:
|
753
|
+
message: Assistant message.
|
754
|
+
tools (list[Tool]): Available tools
|
755
|
+
max_output (int | None): Maximum output length (in bytes).
|
756
|
+
Defaults to max_tool_output from active GenerateConfig
|
757
|
+
(16 * 1024 by default).
|
758
|
+
|
759
|
+
Returns:
|
760
|
+
Messages added to the conversation.
|
761
|
+
"""
|
762
|
+
warn_once(
|
763
|
+
logger,
|
764
|
+
"call_tools is deprecated -- please use execute_tools instead (as it supports agent handoff tools)",
|
765
|
+
)
|
766
|
+
|
767
|
+
messages, _ = await execute_tools([message], tools, max_output)
|
768
|
+
return [m for m in messages if isinstance(m, ChatMessageTool)]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from logging import getLogger
|
2
2
|
from typing import Any, Literal, Type, Union
|
3
3
|
|
4
|
-
from pydantic import BaseModel, Field, model_validator
|
4
|
+
from pydantic import BaseModel, Field, JsonValue, model_validator
|
5
5
|
from shortuuid import uuid
|
6
6
|
|
7
7
|
from inspect_ai._util.constants import DESERIALIZING
|
@@ -26,6 +26,9 @@ class ChatMessageBase(BaseModel):
|
|
26
26
|
source: Literal["input", "generate"] | None = Field(default=None)
|
27
27
|
"""Source of message."""
|
28
28
|
|
29
|
+
internal: JsonValue | None = Field(default=None)
|
30
|
+
"""Model provider specific payload - typically used to aid transformation back to model types."""
|
31
|
+
|
29
32
|
def model_post_init(self, __context: Any) -> None:
|
30
33
|
# check if deserializing
|
31
34
|
is_deserializing = isinstance(__context, dict) and __context.get(
|
@@ -105,6 +108,9 @@ class ChatMessageAssistant(ChatMessageBase):
|
|
105
108
|
tool_calls: list[ToolCall] | None = Field(default=None)
|
106
109
|
"""Tool calls made by the model."""
|
107
110
|
|
111
|
+
model: str | None = Field(default=None)
|
112
|
+
"""Model used to generate assistant message."""
|
113
|
+
|
108
114
|
# Some OpenAI compatible REST endpoints include reasoning as a field alongside
|
109
115
|
# content, however since this field doesn't exist in the OpenAI interface,
|
110
116
|
# hosting providers (so far we've seen this with Together and Groq) may
|
@@ -158,9 +164,6 @@ class ChatMessageTool(ChatMessageBase):
|
|
158
164
|
function: str | None = Field(default=None)
|
159
165
|
"""Name of function called."""
|
160
166
|
|
161
|
-
internal_name: str | None = Field(default=None)
|
162
|
-
"""Internal name for tool (if any)."""
|
163
|
-
|
164
167
|
error: ToolCallError | None = Field(default=None)
|
165
168
|
"""Error which occurred during tool call."""
|
166
169
|
|
@@ -1,67 +1,18 @@
|
|
1
|
-
from
|
2
|
-
from rich.text import Text
|
1
|
+
from typing import Protocol
|
3
2
|
|
4
|
-
from
|
5
|
-
from
|
6
|
-
from inspect_ai._util.transcript import transcript_markdown, transcript_reasoning
|
7
|
-
from inspect_ai.util._conversation import conversation_panel
|
8
|
-
from inspect_ai.util._display import display_type
|
3
|
+
from ._chat_message import ChatMessage
|
4
|
+
from ._model_output import ModelOutput
|
9
5
|
|
10
|
-
from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
|
11
|
-
from ._render import messages_preceding_assistant, render_tool_calls
|
12
6
|
|
13
|
-
|
7
|
+
class ModelConversation(Protocol):
|
8
|
+
"""Model conversation."""
|
14
9
|
|
10
|
+
@property
|
11
|
+
def messages(self) -> list[ChatMessage]:
|
12
|
+
"""Conversation history."""
|
13
|
+
...
|
15
14
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
message.error.message.strip() if message.error else message.text.strip()
|
21
|
-
)
|
22
|
-
if output:
|
23
|
-
content = lines_display(output, 50)
|
24
|
-
|
25
|
-
conversation_panel(
|
26
|
-
title=f"Tool Output: {message.function}",
|
27
|
-
content=content,
|
28
|
-
)
|
29
|
-
|
30
|
-
|
31
|
-
def conversation_assistant_message(
|
32
|
-
input: list[ChatMessage], message: ChatMessageAssistant
|
33
|
-
) -> None:
|
34
|
-
if display_type() == "conversation":
|
35
|
-
# print precding messages that aren't tool or assistant
|
36
|
-
for m in messages_preceding_assistant(input):
|
37
|
-
conversation_panel(
|
38
|
-
title=m.role.capitalize(),
|
39
|
-
content=transcript_markdown(m.text, escape=True),
|
40
|
-
)
|
41
|
-
|
42
|
-
# build content
|
43
|
-
content: list[RenderableType] = []
|
44
|
-
|
45
|
-
# deal with plain text or with content blocks
|
46
|
-
if isinstance(message.content, str):
|
47
|
-
content.extend([transcript_markdown(message.text.strip(), escape=True)])
|
48
|
-
else:
|
49
|
-
for c in message.content:
|
50
|
-
if isinstance(c, ContentReasoning):
|
51
|
-
content.extend(transcript_reasoning(c))
|
52
|
-
elif isinstance(c, ContentText) and c.text:
|
53
|
-
content.extend([transcript_markdown(c.text.strip(), escape=True)])
|
54
|
-
|
55
|
-
# print tool calls
|
56
|
-
if message.tool_calls:
|
57
|
-
if content:
|
58
|
-
content.append(Text())
|
59
|
-
content.extend(render_tool_calls(message.tool_calls))
|
60
|
-
|
61
|
-
# print the assistant message
|
62
|
-
conversation_panel(title="Assistant", content=content)
|
63
|
-
|
64
|
-
|
65
|
-
def conversation_assistant_error(error: Exception) -> None:
|
66
|
-
if display_type() == "conversation":
|
67
|
-
conversation_panel(title="Assistant", content=repr(error))
|
15
|
+
@property
|
16
|
+
def output(self) -> ModelOutput:
|
17
|
+
"""Model output."""
|
18
|
+
...
|
@@ -0,0 +1,85 @@
|
|
1
|
+
from rich.console import RenderableType
|
2
|
+
from rich.text import Text
|
3
|
+
|
4
|
+
from inspect_ai._util.content import ContentReasoning, ContentText
|
5
|
+
from inspect_ai._util.rich import lines_display
|
6
|
+
from inspect_ai._util.transcript import transcript_markdown, transcript_reasoning
|
7
|
+
from inspect_ai.util._conversation import conversation_panel
|
8
|
+
from inspect_ai.util._display import display_type
|
9
|
+
|
10
|
+
from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool
|
11
|
+
from ._render import messages_preceding_assistant, render_tool_calls
|
12
|
+
|
13
|
+
MESSAGE_TITLE = "Message"
|
14
|
+
|
15
|
+
|
16
|
+
def display_conversation_message(message: ChatMessage) -> None:
|
17
|
+
if display_type() == "conversation":
|
18
|
+
if isinstance(message, ChatMessageTool):
|
19
|
+
display_conversation_tool_message(message)
|
20
|
+
elif isinstance(message, ChatMessageAssistant):
|
21
|
+
display_conversation_assistant_message(message)
|
22
|
+
else:
|
23
|
+
conversation_panel(
|
24
|
+
title=message.role.capitalize(),
|
25
|
+
content=transcript_markdown(message.text, escape=True),
|
26
|
+
)
|
27
|
+
|
28
|
+
|
29
|
+
def display_conversation_tool_message(message: ChatMessageTool) -> None:
|
30
|
+
if display_type() == "conversation":
|
31
|
+
# truncate output to 100 lines
|
32
|
+
output = (
|
33
|
+
message.error.message.strip() if message.error else message.text.strip()
|
34
|
+
)
|
35
|
+
if output:
|
36
|
+
content = lines_display(output, 50)
|
37
|
+
|
38
|
+
conversation_panel(
|
39
|
+
title=f"Tool Output: {message.function}",
|
40
|
+
content=content,
|
41
|
+
)
|
42
|
+
|
43
|
+
|
44
|
+
def display_conversation_assistant_message(message: ChatMessageAssistant) -> None:
|
45
|
+
# build content
|
46
|
+
content: list[RenderableType] = []
|
47
|
+
|
48
|
+
# deal with plain text or with content blocks
|
49
|
+
if isinstance(message.content, str):
|
50
|
+
content.extend([transcript_markdown(message.text.strip(), escape=True)])
|
51
|
+
else:
|
52
|
+
for c in message.content:
|
53
|
+
if isinstance(c, ContentReasoning):
|
54
|
+
content.extend(transcript_reasoning(c))
|
55
|
+
elif isinstance(c, ContentText) and c.text:
|
56
|
+
content.extend([transcript_markdown(c.text.strip(), escape=True)])
|
57
|
+
|
58
|
+
# print tool calls
|
59
|
+
if message.tool_calls:
|
60
|
+
if content:
|
61
|
+
content.append(Text())
|
62
|
+
content.extend(render_tool_calls(message.tool_calls))
|
63
|
+
|
64
|
+
# print the assistant message
|
65
|
+
conversation_panel(title="Assistant", content=content)
|
66
|
+
|
67
|
+
|
68
|
+
def display_conversation_assistant(
|
69
|
+
input: list[ChatMessage], message: ChatMessageAssistant
|
70
|
+
) -> None:
|
71
|
+
if display_type() == "conversation":
|
72
|
+
# print precding messages that aren't tool or assistant
|
73
|
+
for m in messages_preceding_assistant(input):
|
74
|
+
conversation_panel(
|
75
|
+
title=m.role.capitalize(),
|
76
|
+
content=transcript_markdown(m.text, escape=True),
|
77
|
+
)
|
78
|
+
|
79
|
+
# show assistant message
|
80
|
+
display_conversation_assistant_message(message)
|
81
|
+
|
82
|
+
|
83
|
+
def display_conversation_assistant_error(error: Exception) -> None:
|
84
|
+
if display_type() == "conversation":
|
85
|
+
conversation_panel(title="Assistant", content=repr(error))
|