inspect-ai 0.3.82__py3-none-any.whl → 0.3.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/__init__.py +2 -1
- inspect_ai/_display/textual/app.py +14 -3
- inspect_ai/_display/textual/display.py +4 -0
- inspect_ai/_display/textual/widgets/samples.py +9 -3
- inspect_ai/_display/textual/widgets/task_detail.py +3 -4
- inspect_ai/_display/textual/widgets/tasks.py +17 -1
- inspect_ai/_display/textual/widgets/vscode.py +48 -0
- inspect_ai/_eval/eval.py +36 -24
- inspect_ai/_eval/evalset.py +17 -18
- inspect_ai/_eval/loader.py +34 -11
- inspect_ai/_eval/run.py +8 -13
- inspect_ai/_eval/score.py +13 -3
- inspect_ai/_eval/task/generate.py +8 -9
- inspect_ai/_eval/task/log.py +2 -0
- inspect_ai/_eval/task/task.py +23 -9
- inspect_ai/_util/file.py +13 -0
- inspect_ai/_util/json.py +2 -1
- inspect_ai/_util/registry.py +1 -0
- inspect_ai/_util/vscode.py +37 -0
- inspect_ai/_view/www/App.css +6 -0
- inspect_ai/_view/www/dist/assets/index.css +304 -128
- inspect_ai/_view/www/dist/assets/index.js +47495 -27519
- inspect_ai/_view/www/log-schema.json +124 -31
- inspect_ai/_view/www/package.json +3 -0
- inspect_ai/_view/www/src/App.tsx +12 -0
- inspect_ai/_view/www/src/appearance/icons.ts +1 -0
- inspect_ai/_view/www/src/components/Card.tsx +6 -4
- inspect_ai/_view/www/src/components/LinkButton.module.css +16 -0
- inspect_ai/_view/www/src/components/LinkButton.tsx +33 -0
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +1 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +113 -23
- inspect_ai/_view/www/src/components/Modal.module.css +38 -0
- inspect_ai/_view/www/src/components/Modal.tsx +77 -0
- inspect_ai/_view/www/src/plan/DetailStep.module.css +4 -0
- inspect_ai/_view/www/src/plan/DetailStep.tsx +6 -3
- inspect_ai/_view/www/src/plan/SolverDetailView.module.css +2 -1
- inspect_ai/_view/www/src/samples/InlineSampleDisplay.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDialog.tsx +7 -0
- inspect_ai/_view/www/src/samples/SampleDisplay.tsx +11 -34
- inspect_ai/_view/www/src/samples/SampleSummaryView.module.css +6 -0
- inspect_ai/_view/www/src/samples/SampleSummaryView.tsx +2 -2
- inspect_ai/_view/www/src/samples/SamplesTools.tsx +12 -0
- inspect_ai/_view/www/src/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/samples/chat/messages.ts +3 -1
- inspect_ai/_view/www/src/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/samples/descriptor/samplesDescriptor.tsx +9 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.module.css +3 -3
- inspect_ai/_view/www/src/samples/descriptor/score/BooleanScoreDescriptor.tsx +1 -1
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.module.css +4 -4
- inspect_ai/_view/www/src/samples/descriptor/score/ObjectScoreDescriptor.tsx +10 -11
- inspect_ai/_view/www/src/samples/list/SampleFooter.module.css +2 -1
- inspect_ai/_view/www/src/samples/list/SampleFooter.tsx +7 -1
- inspect_ai/_view/www/src/samples/list/SampleList.tsx +25 -8
- inspect_ai/_view/www/src/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/samples/scores/SampleScores.tsx +11 -22
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.module.css +38 -0
- inspect_ai/_view/www/src/samples/scores/SampleScoresGrid.tsx +118 -0
- inspect_ai/_view/www/src/samples/scores/{SampleScoreView.module.css → SampleScoresView.module.css} +10 -1
- inspect_ai/_view/www/src/samples/scores/SampleScoresView.tsx +78 -0
- inspect_ai/_view/www/src/samples/transcript/SampleLimitEventView.tsx +3 -3
- inspect_ai/_view/www/src/samples/transcript/ToolEventView.tsx +25 -4
- inspect_ai/_view/www/src/samples/transcript/event/EventPanel.tsx +29 -2
- inspect_ai/_view/www/src/samples/transcript/state/StateEventRenderers.tsx +0 -1
- inspect_ai/_view/www/src/state/hooks.ts +5 -3
- inspect_ai/_view/www/src/state/logPolling.ts +5 -1
- inspect_ai/_view/www/src/state/logSlice.ts +10 -0
- inspect_ai/_view/www/src/state/samplePolling.ts +4 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +13 -0
- inspect_ai/_view/www/src/types/log.d.ts +34 -26
- inspect_ai/_view/www/src/types/markdown-it-katex.d.ts +21 -0
- inspect_ai/_view/www/src/utils/json-worker.ts +79 -12
- inspect_ai/_view/www/src/workspace/WorkSpace.tsx +18 -16
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.module.css +16 -0
- inspect_ai/_view/www/src/workspace/navbar/ResultsPanel.tsx +68 -71
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.module.css +35 -0
- inspect_ai/_view/www/src/workspace/navbar/ScoreGrid.tsx +117 -0
- inspect_ai/_view/www/src/workspace/navbar/SecondaryBar.tsx +1 -1
- inspect_ai/_view/www/src/workspace/sidebar/Sidebar.module.css +3 -2
- inspect_ai/_view/www/src/workspace/tabs/SamplesTab.tsx +18 -0
- inspect_ai/_view/www/yarn.lock +94 -1
- inspect_ai/agent/__init__.py +36 -0
- inspect_ai/agent/_agent.py +268 -0
- inspect_ai/agent/_as_solver.py +72 -0
- inspect_ai/agent/_as_tool.py +122 -0
- inspect_ai/{solver → agent}/_bridge/bridge.py +23 -37
- inspect_ai/{solver → agent}/_bridge/patch.py +9 -8
- inspect_ai/agent/_filter.py +46 -0
- inspect_ai/agent/_handoff.py +93 -0
- inspect_ai/{solver/_human_agent → agent/_human}/agent.py +11 -12
- inspect_ai/{solver/_human_agent → agent/_human}/commands/__init__.py +2 -3
- inspect_ai/{solver/_human_agent → agent/_human}/commands/clock.py +3 -1
- inspect_ai/{solver/_human_agent → agent/_human}/commands/score.py +5 -5
- inspect_ai/{solver/_human_agent → agent/_human}/install.py +6 -3
- inspect_ai/{solver/_human_agent → agent/_human}/service.py +7 -3
- inspect_ai/{solver/_human_agent → agent/_human}/state.py +5 -5
- inspect_ai/agent/_react.py +241 -0
- inspect_ai/agent/_run.py +36 -0
- inspect_ai/agent/_types.py +81 -0
- inspect_ai/log/_log.py +11 -2
- inspect_ai/log/_transcript.py +13 -9
- inspect_ai/model/__init__.py +7 -1
- inspect_ai/model/_call_tools.py +256 -52
- inspect_ai/model/_chat_message.py +7 -4
- inspect_ai/model/_conversation.py +13 -62
- inspect_ai/model/_display.py +85 -0
- inspect_ai/model/_model.py +113 -14
- inspect_ai/model/_model_output.py +14 -9
- inspect_ai/model/_openai.py +16 -4
- inspect_ai/model/_openai_computer_use.py +162 -0
- inspect_ai/model/_openai_responses.py +319 -165
- inspect_ai/model/_providers/anthropic.py +20 -21
- inspect_ai/model/_providers/azureai.py +24 -13
- inspect_ai/model/_providers/bedrock.py +1 -7
- inspect_ai/model/_providers/cloudflare.py +3 -3
- inspect_ai/model/_providers/goodfire.py +2 -6
- inspect_ai/model/_providers/google.py +11 -10
- inspect_ai/model/_providers/groq.py +6 -3
- inspect_ai/model/_providers/hf.py +7 -3
- inspect_ai/model/_providers/mistral.py +7 -10
- inspect_ai/model/_providers/openai.py +47 -17
- inspect_ai/model/_providers/openai_o1.py +11 -4
- inspect_ai/model/_providers/openai_responses.py +12 -14
- inspect_ai/model/_providers/providers.py +2 -2
- inspect_ai/model/_providers/together.py +12 -2
- inspect_ai/model/_providers/util/chatapi.py +7 -2
- inspect_ai/model/_providers/util/hf_handler.py +4 -2
- inspect_ai/model/_providers/util/llama31.py +4 -2
- inspect_ai/model/_providers/vertex.py +11 -9
- inspect_ai/model/_providers/vllm.py +4 -4
- inspect_ai/scorer/__init__.py +2 -0
- inspect_ai/scorer/_metrics/__init__.py +2 -0
- inspect_ai/scorer/_metrics/grouped.py +84 -0
- inspect_ai/scorer/_score.py +26 -6
- inspect_ai/solver/__init__.py +2 -2
- inspect_ai/solver/_basic_agent.py +22 -9
- inspect_ai/solver/_bridge.py +31 -0
- inspect_ai/solver/_chain.py +20 -12
- inspect_ai/solver/_fork.py +5 -1
- inspect_ai/solver/_human_agent.py +52 -0
- inspect_ai/solver/_prompt.py +3 -1
- inspect_ai/solver/_run.py +59 -0
- inspect_ai/solver/_solver.py +14 -4
- inspect_ai/solver/_task_state.py +5 -3
- inspect_ai/tool/_tool_call.py +15 -8
- inspect_ai/tool/_tool_def.py +17 -12
- inspect_ai/tool/_tool_support_helpers.py +2 -2
- inspect_ai/tool/_tool_with.py +14 -11
- inspect_ai/tool/_tools/_bash_session.py +11 -2
- inspect_ai/tool/_tools/_computer/_common.py +18 -2
- inspect_ai/tool/_tools/_computer/_computer.py +18 -2
- inspect_ai/tool/_tools/_computer/_resources/tool/_constants.py +2 -0
- inspect_ai/tool/_tools/_computer/_resources/tool/_x11_client.py +17 -0
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_browser/_web_browser.py +100 -61
- inspect_ai/util/__init__.py +2 -0
- inspect_ai/util/_anyio.py +27 -0
- inspect_ai/util/_sandbox/__init__.py +2 -1
- inspect_ai/util/_sandbox/context.py +32 -7
- inspect_ai/util/_sandbox/docker/cleanup.py +4 -0
- inspect_ai/util/_sandbox/docker/compose.py +2 -2
- inspect_ai/util/_sandbox/docker/docker.py +12 -1
- inspect_ai/util/_store_model.py +30 -7
- inspect_ai/util/_subprocess.py +13 -3
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/RECORD +179 -153
- inspect_ai/_view/www/src/samples/scores/SampleScoreView.tsx +0 -167
- /inspect_ai/{solver → agent}/_bridge/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/__init__.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/command.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/instructions.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/note.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/status.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/commands/submit.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/panel.py +0 -0
- /inspect_ai/{solver/_human_agent → agent/_human}/view.py +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.82.dist-info → inspect_ai-0.3.84.dist-info}/top_level.txt +0 -0
inspect_ai/model/_model.py
CHANGED
@@ -6,7 +6,7 @@ import logging
|
|
6
6
|
import os
|
7
7
|
import time
|
8
8
|
from contextvars import ContextVar
|
9
|
-
from copy import deepcopy
|
9
|
+
from copy import copy, deepcopy
|
10
10
|
from datetime import datetime
|
11
11
|
from types import TracebackType
|
12
12
|
from typing import Any, AsyncIterator, Callable, Literal, Type, cast
|
@@ -45,11 +45,17 @@ from inspect_ai._util.retry import report_http_retry
|
|
45
45
|
from inspect_ai._util.trace import trace_action
|
46
46
|
from inspect_ai._util.working import report_sample_waiting_time, sample_working_time
|
47
47
|
from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
|
48
|
+
from inspect_ai.tool._tool_call import ToolCallModelInputHints
|
48
49
|
from inspect_ai.tool._tool_def import ToolDef, tool_defs
|
49
50
|
from inspect_ai.util import concurrency
|
50
51
|
|
51
52
|
from ._cache import CacheEntry, CachePolicy, cache_fetch, cache_store
|
52
|
-
from ._call_tools import
|
53
|
+
from ._call_tools import (
|
54
|
+
disable_parallel_tools,
|
55
|
+
execute_tools,
|
56
|
+
tool_call_view,
|
57
|
+
tools_info,
|
58
|
+
)
|
53
59
|
from ._chat_message import (
|
54
60
|
ChatMessage,
|
55
61
|
ChatMessageAssistant,
|
@@ -57,7 +63,10 @@ from ._chat_message import (
|
|
57
63
|
ChatMessageTool,
|
58
64
|
ChatMessageUser,
|
59
65
|
)
|
60
|
-
from .
|
66
|
+
from ._display import (
|
67
|
+
display_conversation_assistant,
|
68
|
+
display_conversation_assistant_error,
|
69
|
+
)
|
61
70
|
from ._generate_config import (
|
62
71
|
GenerateConfig,
|
63
72
|
active_generate_config,
|
@@ -123,9 +132,20 @@ class ModelAPI(abc.ABC):
|
|
123
132
|
# set any explicitly specified api key
|
124
133
|
self.api_key = api_key
|
125
134
|
|
126
|
-
async def
|
127
|
-
"""
|
128
|
-
|
135
|
+
async def aclose(self) -> None:
|
136
|
+
"""Async close method for closing any client allocated for the model."""
|
137
|
+
self.close()
|
138
|
+
|
139
|
+
def close(self) -> None:
|
140
|
+
"""Sync close method for closing any client allocated for the model."""
|
141
|
+
# if this is is called and aclose is implemented by a subclass then
|
142
|
+
# raise a runtime error (as this model reuqires async close)
|
143
|
+
aclose_method = getattr(self.__class__, "aclose")
|
144
|
+
base_aclose_method = getattr(ModelAPI, "aclose")
|
145
|
+
if aclose_method != base_aclose_method:
|
146
|
+
raise RuntimeError(
|
147
|
+
f"{self.__class__.__name__} models require an async close / context manager."
|
148
|
+
)
|
129
149
|
|
130
150
|
@abc.abstractmethod
|
131
151
|
async def generate(
|
@@ -201,6 +221,10 @@ class ModelAPI(abc.ABC):
|
|
201
221
|
"""Tool results can contain images"""
|
202
222
|
return False
|
203
223
|
|
224
|
+
def disable_computer_screenshot_truncation(self) -> bool:
|
225
|
+
"""Some models do not support truncation of computer screenshots."""
|
226
|
+
return False
|
227
|
+
|
204
228
|
def emulate_reasoning_history(self) -> bool:
|
205
229
|
"""Chat message assistant messages with reasoning should playback reasoning with emulation (.e.g. <think> tags)"""
|
206
230
|
return True
|
@@ -255,10 +279,23 @@ class Model:
|
|
255
279
|
# get hit before score() or eval() so we activate nest_asyncio
|
256
280
|
platform_init()
|
257
281
|
|
258
|
-
|
282
|
+
def __enter__(self: "Model") -> "Model":
|
259
283
|
self._context_bound = True
|
260
284
|
return self
|
261
285
|
|
286
|
+
async def __aenter__(self: "Model") -> "Model":
|
287
|
+
return self.__enter__()
|
288
|
+
|
289
|
+
def __exit__(
|
290
|
+
self,
|
291
|
+
exc_type: type[BaseException] | None,
|
292
|
+
exc: BaseException | None,
|
293
|
+
exc_tb: TracebackType | None,
|
294
|
+
) -> None:
|
295
|
+
if not self._closed:
|
296
|
+
self.api.close()
|
297
|
+
self._closed = True
|
298
|
+
|
262
299
|
async def __aexit__(
|
263
300
|
self,
|
264
301
|
exc_type: type[BaseException] | None,
|
@@ -266,7 +303,7 @@ class Model:
|
|
266
303
|
exc_tb: TracebackType | None,
|
267
304
|
) -> None:
|
268
305
|
if not self._closed:
|
269
|
-
await self.api.
|
306
|
+
await self.api.aclose()
|
270
307
|
self._closed = True
|
271
308
|
|
272
309
|
@property
|
@@ -373,6 +410,55 @@ class Model:
|
|
373
410
|
# return output
|
374
411
|
return output
|
375
412
|
|
413
|
+
async def generate_loop(
|
414
|
+
self,
|
415
|
+
input: str | list[ChatMessage],
|
416
|
+
tools: list[Tool] | list[ToolDef] | list[Tool | ToolDef] = [],
|
417
|
+
config: GenerateConfig = GenerateConfig(),
|
418
|
+
cache: bool | CachePolicy = False,
|
419
|
+
) -> tuple[list[ChatMessage], ModelOutput]:
|
420
|
+
"""Generate output from the model, looping as long as the model calls tools.
|
421
|
+
|
422
|
+
Similar to `generate()`, but runs in a loop resolving model tool calls.
|
423
|
+
The loop terminates when the model stops calling tools. The final `ModelOutput`
|
424
|
+
as well the message list for the conversation are returned as a tuple.
|
425
|
+
|
426
|
+
Args:
|
427
|
+
input: Chat message input (if a `str` is passed it is converted
|
428
|
+
to a `ChatMessageUser`).
|
429
|
+
tools: Tools available for the model to call.
|
430
|
+
config: Model configuration.
|
431
|
+
cache: Caching behavior for generate responses (defaults to no caching).
|
432
|
+
|
433
|
+
Returns:
|
434
|
+
Tuple of list[ChatMessage], ModelOutput
|
435
|
+
"""
|
436
|
+
# initialise messages
|
437
|
+
input = [ChatMessageUser(content=input)] if isinstance(input, str) else input
|
438
|
+
messages = copy(input)
|
439
|
+
while True:
|
440
|
+
# call model
|
441
|
+
output = await self.generate(
|
442
|
+
input=messages,
|
443
|
+
tools=tools, # type:ignore[arg-type]
|
444
|
+
config=config,
|
445
|
+
cache=cache,
|
446
|
+
)
|
447
|
+
|
448
|
+
# append to new messages
|
449
|
+
messages.append(output.message)
|
450
|
+
|
451
|
+
# make tool calls or terminate if there are none
|
452
|
+
if output.message.tool_calls:
|
453
|
+
tools_messages, tools_output = await execute_tools(
|
454
|
+
messages, tools, config.max_tool_output
|
455
|
+
)
|
456
|
+
messages.extend(tools_messages)
|
457
|
+
if tools_output is not None:
|
458
|
+
output = tools_output
|
459
|
+
else:
|
460
|
+
return messages[len(input) :], output
|
461
|
+
|
376
462
|
async def _generate(
|
377
463
|
self,
|
378
464
|
input: list[ChatMessage],
|
@@ -414,7 +500,13 @@ class Model:
|
|
414
500
|
input = resolve_reasoning_history(input, config, self.api)
|
415
501
|
|
416
502
|
# apply any tool model_input handlers
|
417
|
-
input = resolve_tool_model_input(
|
503
|
+
input = resolve_tool_model_input(
|
504
|
+
tdefs,
|
505
|
+
input,
|
506
|
+
ToolCallModelInputHints(
|
507
|
+
disable_computer_screenshot_truncation=self.api.disable_computer_screenshot_truncation()
|
508
|
+
),
|
509
|
+
)
|
418
510
|
|
419
511
|
# break tool image content out into user messages if the model doesn't
|
420
512
|
# support tools returning images
|
@@ -664,10 +756,10 @@ class Model:
|
|
664
756
|
# trace
|
665
757
|
if isinstance(result, ModelOutput):
|
666
758
|
if result.choices:
|
667
|
-
|
759
|
+
display_conversation_assistant(input, result.choices[0].message)
|
668
760
|
event.output = result
|
669
761
|
else:
|
670
|
-
|
762
|
+
display_conversation_assistant_error(result)
|
671
763
|
event.error = repr(result)
|
672
764
|
|
673
765
|
event.call = updated_call
|
@@ -1034,7 +1126,7 @@ def resolve_reasoning_history(
|
|
1034
1126
|
|
1035
1127
|
|
1036
1128
|
def resolve_tool_model_input(
|
1037
|
-
tdefs: list[ToolDef], messages: list[ChatMessage]
|
1129
|
+
tdefs: list[ToolDef], messages: list[ChatMessage], hints: ToolCallModelInputHints
|
1038
1130
|
) -> list[ChatMessage]:
|
1039
1131
|
# filter on tooldefs that have a model input handler
|
1040
1132
|
tdefs = [tdef for tdef in tdefs if tdef.model_input is not None]
|
@@ -1060,7 +1152,7 @@ def resolve_tool_model_input(
|
|
1060
1152
|
# call the function for each tool, passing the index, total, and content
|
1061
1153
|
for index, message in enumerate(tdef_tool_messages):
|
1062
1154
|
message.content = tdef.model_input(
|
1063
|
-
index, len(tool_messages), message.content
|
1155
|
+
index, len(tool_messages), message.content, hints
|
1064
1156
|
)
|
1065
1157
|
|
1066
1158
|
# return modified messages
|
@@ -1116,7 +1208,7 @@ def tool_result_images_reducer(
|
|
1116
1208
|
content=edited_tool_message_content,
|
1117
1209
|
tool_call_id=message.tool_call_id,
|
1118
1210
|
function=message.function,
|
1119
|
-
|
1211
|
+
internal=message.internal,
|
1120
1212
|
)
|
1121
1213
|
],
|
1122
1214
|
pending_content + new_user_message_content,
|
@@ -1219,6 +1311,13 @@ def consecutive_message_reducer(
|
|
1219
1311
|
def combine_messages(
|
1220
1312
|
a: ChatMessage, b: ChatMessage, message_type: Type[ChatMessage]
|
1221
1313
|
) -> ChatMessage:
|
1314
|
+
# TODO: Although unlikely to happen based on the current call sites, these
|
1315
|
+
# fabricated messages drop interesting fields from the source messages -
|
1316
|
+
# such as `internal_name`, `tool_calls`, etc.
|
1317
|
+
# To be more specific, since all `ChatMessageXxx` fields other than `id` and
|
1318
|
+
# `content` have default values, it's more the case that they're reset to
|
1319
|
+
# default values rather than dropped.
|
1320
|
+
|
1222
1321
|
if isinstance(a.content, str) and isinstance(b.content, str):
|
1223
1322
|
return message_type(id=a.id, content=f"{a.content}\n{b.content}")
|
1224
1323
|
elif isinstance(a.content, list) and isinstance(b.content, list):
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import uuid
|
2
2
|
from typing import Any, Literal, Type
|
3
3
|
|
4
|
-
from pydantic import BaseModel, Field, model_validator
|
4
|
+
from pydantic import BaseModel, Field, JsonValue, model_validator
|
5
5
|
|
6
6
|
from inspect_ai.tool._tool_call import ToolCall
|
7
7
|
|
@@ -123,6 +123,10 @@ class ModelOutput(BaseModel):
|
|
123
123
|
error: str | None = Field(default=None)
|
124
124
|
"""Error message in the case of content moderation refusals."""
|
125
125
|
|
126
|
+
@property
|
127
|
+
def empty(self) -> bool:
|
128
|
+
return len(self.choices) == 0
|
129
|
+
|
126
130
|
@property
|
127
131
|
def stop_reason(self) -> StopReason:
|
128
132
|
"""First message stop reason."""
|
@@ -153,7 +157,8 @@ class ModelOutput(BaseModel):
|
|
153
157
|
else:
|
154
158
|
self.choices.append(
|
155
159
|
ChatCompletionChoice(
|
156
|
-
message=ChatMessageAssistant(content=completion
|
160
|
+
message=ChatMessageAssistant(content=completion, model=self.model),
|
161
|
+
stop_reason="stop",
|
157
162
|
)
|
158
163
|
)
|
159
164
|
|
@@ -176,7 +181,9 @@ class ModelOutput(BaseModel):
|
|
176
181
|
model=model,
|
177
182
|
choices=[
|
178
183
|
ChatCompletionChoice(
|
179
|
-
message=ChatMessageAssistant(
|
184
|
+
message=ChatMessageAssistant(
|
185
|
+
content=content, model=model, source="generate"
|
186
|
+
),
|
180
187
|
stop_reason=stop_reason,
|
181
188
|
)
|
182
189
|
],
|
@@ -188,10 +195,9 @@ class ModelOutput(BaseModel):
|
|
188
195
|
model: str,
|
189
196
|
tool_name: str,
|
190
197
|
tool_arguments: dict[str, Any],
|
191
|
-
|
198
|
+
internal: JsonValue | None = None,
|
192
199
|
tool_call_id: str | None = None,
|
193
200
|
content: str | None = None,
|
194
|
-
type: str = "function",
|
195
201
|
) -> "ModelOutput":
|
196
202
|
"""
|
197
203
|
Returns a ModelOutput for requesting a tool call.
|
@@ -199,8 +205,7 @@ class ModelOutput(BaseModel):
|
|
199
205
|
Args:
|
200
206
|
model: model name
|
201
207
|
tool_name: The name of the tool.
|
202
|
-
|
203
|
-
type: The model's type for the tool. e.g. "function", "computer_use_preview"
|
208
|
+
internal: The model's internal info for the tool (if any).
|
204
209
|
tool_arguments: The arguments passed to the tool.
|
205
210
|
tool_call_id: Optional ID for the tool call. Defaults to a random UUID.
|
206
211
|
content: Optional content to include in the message. Defaults to "tool call for tool {tool_name}".
|
@@ -220,14 +225,14 @@ class ModelOutput(BaseModel):
|
|
220
225
|
ChatCompletionChoice(
|
221
226
|
message=ChatMessageAssistant(
|
222
227
|
content=content,
|
228
|
+
model=model,
|
223
229
|
source="generate",
|
224
230
|
tool_calls=[
|
225
231
|
ToolCall(
|
226
232
|
id=tool_call_id,
|
227
233
|
function=tool_name,
|
228
|
-
|
234
|
+
internal=internal,
|
229
235
|
arguments=tool_arguments,
|
230
|
-
type=type,
|
231
236
|
)
|
232
237
|
],
|
233
238
|
),
|
inspect_ai/model/_openai.py
CHANGED
@@ -83,6 +83,10 @@ def is_o1_preview(name: str) -> bool:
|
|
83
83
|
return "o1-preview" in name
|
84
84
|
|
85
85
|
|
86
|
+
def is_computer_use_preview(name: str) -> bool:
|
87
|
+
return "computer-use-preview" in name
|
88
|
+
|
89
|
+
|
86
90
|
def is_gpt(name: str) -> bool:
|
87
91
|
return "gpt" in name
|
88
92
|
|
@@ -100,13 +104,12 @@ def openai_chat_tool_call(tool_call: ToolCall) -> ChatCompletionMessageToolCall:
|
|
100
104
|
def openai_chat_tool_call_param(
|
101
105
|
tool_call: ToolCall,
|
102
106
|
) -> ChatCompletionMessageToolCallParam:
|
103
|
-
assert tool_call.type == "function", f"Unexpected tool call type {tool_call.type}"
|
104
107
|
return ChatCompletionMessageToolCallParam(
|
105
108
|
id=tool_call.id,
|
106
109
|
function=dict(
|
107
110
|
name=tool_call.function, arguments=json.dumps(tool_call.arguments)
|
108
111
|
),
|
109
|
-
type="function",
|
112
|
+
type="function",
|
110
113
|
)
|
111
114
|
|
112
115
|
|
@@ -308,6 +311,7 @@ def chat_tool_calls_from_openai(
|
|
308
311
|
|
309
312
|
|
310
313
|
def chat_messages_from_openai(
|
314
|
+
model: str,
|
311
315
|
messages: list[ChatCompletionMessageParam],
|
312
316
|
) -> list[ChatMessage]:
|
313
317
|
# track tool names by id
|
@@ -386,6 +390,8 @@ def chat_messages_from_openai(
|
|
386
390
|
ChatMessageAssistant(
|
387
391
|
content=content,
|
388
392
|
tool_calls=tool_calls or None,
|
393
|
+
model=model,
|
394
|
+
source="generate",
|
389
395
|
)
|
390
396
|
)
|
391
397
|
elif message["role"] == "tool":
|
@@ -464,7 +470,7 @@ def content_from_openai(
|
|
464
470
|
|
465
471
|
|
466
472
|
def chat_message_assistant_from_openai(
|
467
|
-
message: ChatCompletionMessage, tools: list[ToolInfo]
|
473
|
+
model: str, message: ChatCompletionMessage, tools: list[ToolInfo]
|
468
474
|
) -> ChatMessageAssistant:
|
469
475
|
refusal = getattr(message, "refusal", None)
|
470
476
|
reasoning = getattr(message, "reasoning_content", None) or getattr(
|
@@ -484,6 +490,7 @@ def chat_message_assistant_from_openai(
|
|
484
490
|
|
485
491
|
return ChatMessageAssistant(
|
486
492
|
content=content,
|
493
|
+
model=model,
|
487
494
|
source="generate",
|
488
495
|
tool_calls=chat_tool_calls_from_openai(message, tools),
|
489
496
|
)
|
@@ -496,7 +503,9 @@ def chat_choices_from_openai(
|
|
496
503
|
choices.sort(key=lambda c: c.index)
|
497
504
|
return [
|
498
505
|
ChatCompletionChoice(
|
499
|
-
message=chat_message_assistant_from_openai(
|
506
|
+
message=chat_message_assistant_from_openai(
|
507
|
+
response.model, choice.message, tools
|
508
|
+
),
|
500
509
|
stop_reason=as_stop_reason(choice.finish_reason),
|
501
510
|
logprobs=(
|
502
511
|
Logprobs(**choice.logprobs.model_dump())
|
@@ -538,6 +547,9 @@ def openai_handle_bad_request(
|
|
538
547
|
|
539
548
|
def openai_media_filter(key: JsonValue | None, value: JsonValue) -> JsonValue:
|
540
549
|
# remove images from raw api call
|
550
|
+
if key == "output" and isinstance(value, dict) and "image_url" in value:
|
551
|
+
value = copy(value)
|
552
|
+
value.update(image_url=BASE_64_DATA_REMOVED)
|
541
553
|
if key == "image_url" and isinstance(value, dict) and "url" in value:
|
542
554
|
url = str(value.get("url"))
|
543
555
|
if url.startswith("data:"):
|
@@ -0,0 +1,162 @@
|
|
1
|
+
from openai.types.responses import (
|
2
|
+
ComputerToolParam,
|
3
|
+
ResponseComputerToolCall,
|
4
|
+
ResponseComputerToolCallOutputScreenshotParam,
|
5
|
+
)
|
6
|
+
from openai.types.responses.response_input_item_param import ComputerCallOutput
|
7
|
+
|
8
|
+
from inspect_ai._util.content import Content, ContentImage
|
9
|
+
from inspect_ai.model._chat_message import ChatMessageTool
|
10
|
+
from inspect_ai.tool._tool_call import ToolCall
|
11
|
+
from inspect_ai.tool._tool_info import ToolInfo
|
12
|
+
|
13
|
+
|
14
|
+
def tool_call_from_openai_computer_tool_call(
|
15
|
+
output: ResponseComputerToolCall,
|
16
|
+
) -> ToolCall:
|
17
|
+
return ToolCall(
|
18
|
+
id=output.call_id,
|
19
|
+
function="computer",
|
20
|
+
arguments=_parse_computer_tool_call_arguments(output),
|
21
|
+
internal=output.model_dump(),
|
22
|
+
)
|
23
|
+
|
24
|
+
|
25
|
+
def maybe_computer_use_preview_tool(tool: ToolInfo) -> ComputerToolParam | None:
|
26
|
+
# check for compatible 'computer' tool
|
27
|
+
return (
|
28
|
+
ComputerToolParam(
|
29
|
+
type="computer_use_preview",
|
30
|
+
# The OpenAI model is ahead of the sdk — "ubuntu" -> "linux"
|
31
|
+
environment="linux", # type: ignore
|
32
|
+
# Note: The dimensions passed here for display_width and display_height should
|
33
|
+
# match the dimensions of screenshots returned by the tool.
|
34
|
+
# Those dimensions will always be one of the values in MAX_SCALING_TARGETS
|
35
|
+
# in _x11_client.py.
|
36
|
+
# TODO: enhance this code to calculate the dimensions based on the scaled screen
|
37
|
+
# size used by the container.
|
38
|
+
display_width=1366,
|
39
|
+
display_height=768,
|
40
|
+
)
|
41
|
+
if tool.name == "computer"
|
42
|
+
and (
|
43
|
+
sorted(tool.parameters.properties.keys())
|
44
|
+
== sorted(
|
45
|
+
[
|
46
|
+
"action",
|
47
|
+
"coordinate",
|
48
|
+
"duration",
|
49
|
+
"scroll_amount",
|
50
|
+
"scroll_direction",
|
51
|
+
"start_coordinate",
|
52
|
+
"text",
|
53
|
+
]
|
54
|
+
)
|
55
|
+
)
|
56
|
+
else None
|
57
|
+
)
|
58
|
+
|
59
|
+
|
60
|
+
def computer_call_output(
|
61
|
+
message: ChatMessageTool,
|
62
|
+
# internal is passed in despite being within message to avoid an extra
|
63
|
+
# validation step
|
64
|
+
internal: ResponseComputerToolCall,
|
65
|
+
) -> ComputerCallOutput:
|
66
|
+
return ComputerCallOutput(
|
67
|
+
call_id=internal.call_id,
|
68
|
+
type="computer_call_output",
|
69
|
+
output=ResponseComputerToolCallOutputScreenshotParam(
|
70
|
+
type="computer_screenshot",
|
71
|
+
image_url=_content_image(message.content),
|
72
|
+
),
|
73
|
+
)
|
74
|
+
|
75
|
+
|
76
|
+
def _parse_computer_tool_call_arguments(
|
77
|
+
output: ResponseComputerToolCall,
|
78
|
+
) -> dict[str, object]:
|
79
|
+
action = output.action
|
80
|
+
|
81
|
+
if action.type == "click":
|
82
|
+
coordinate = [action.x, action.y]
|
83
|
+
match action.button:
|
84
|
+
case "left":
|
85
|
+
return {"action": "left_click", "coordinate": coordinate}
|
86
|
+
case "right":
|
87
|
+
return {"action": "right_click", "coordinate": coordinate}
|
88
|
+
case "wheel":
|
89
|
+
return {"action": "middle_click", "coordinate": coordinate}
|
90
|
+
case "back":
|
91
|
+
return {"action": "back_click", "coordinate": coordinate}
|
92
|
+
case "forward":
|
93
|
+
return {"action": "forward_click", "coordinate": coordinate}
|
94
|
+
elif action.type == "double_click":
|
95
|
+
return {"action": "double_click", "coordinate": [action.x, action.y]}
|
96
|
+
elif action.type == "drag":
|
97
|
+
# TODO: For now, we go directly from the first to the last coordinate in
|
98
|
+
# the path. Ultimately, we'll need to extend the tool to support all of
|
99
|
+
# the intermediate coordinates in the path.
|
100
|
+
path = action.path
|
101
|
+
assert len(path) >= 2
|
102
|
+
start = path[0]
|
103
|
+
end = path[-1]
|
104
|
+
return {
|
105
|
+
"action": "left_click_drag",
|
106
|
+
"start_coordinate": [start.x, start.y],
|
107
|
+
"coordinate": [end.x, end.y],
|
108
|
+
}
|
109
|
+
elif action.type == "keypress":
|
110
|
+
# TODO: This mapping logic is copied from their example, but seems incomplete
|
111
|
+
mapping = {
|
112
|
+
"ENTER": "Return",
|
113
|
+
"LEFT": "Left",
|
114
|
+
"RIGHT": "Right",
|
115
|
+
"UP": "Up",
|
116
|
+
"DOWN": "Down",
|
117
|
+
"ESC": "Escape",
|
118
|
+
"SPACE": "space",
|
119
|
+
"BACKSPACE": "BackSpace",
|
120
|
+
"TAB": "Tab",
|
121
|
+
}
|
122
|
+
return {
|
123
|
+
"action": "key",
|
124
|
+
"text": "+".join([mapping.get(key, key) for key in action.keys]),
|
125
|
+
}
|
126
|
+
elif action.type == "move":
|
127
|
+
return {"action": "mouse_move", "coordinate": [action.x, action.y]}
|
128
|
+
elif action.type == "screenshot":
|
129
|
+
return {"action": "screenshot"}
|
130
|
+
elif action.type == "scroll":
|
131
|
+
# TODO: OpenAI spec's with x/y distances. Their example code treats the
|
132
|
+
# unit of measurement as a "click" of the scroll wheel. Since it's not
|
133
|
+
# really a thing to scroll both horizontally and vertically at the same
|
134
|
+
# time, we'll just pick one of the potentially two directions and
|
135
|
+
# scroll along that dimension.
|
136
|
+
(scroll_direction, scroll_amount) = (
|
137
|
+
("right" if action.scroll_x > 0 else "left", abs(action.scroll_x))
|
138
|
+
if action.scroll_x
|
139
|
+
else ("down" if action.scroll_y > 0 else "up", abs(action.scroll_y))
|
140
|
+
)
|
141
|
+
return {
|
142
|
+
"action": "scroll",
|
143
|
+
"coordinate": [action.x, action.y],
|
144
|
+
"scroll_direction": scroll_direction,
|
145
|
+
"scroll_amount": scroll_amount,
|
146
|
+
}
|
147
|
+
elif action.type == "type":
|
148
|
+
return {"action": "type", "text": action.text}
|
149
|
+
elif action.type == "wait":
|
150
|
+
return {"action": "wait", "duration": 1}
|
151
|
+
|
152
|
+
assert False, f"Unexpected action type: {action.type}"
|
153
|
+
|
154
|
+
|
155
|
+
def _content_image(input: str | list[Content]) -> str:
|
156
|
+
result = (
|
157
|
+
next((item.image for item in input if isinstance(item, ContentImage)), None)
|
158
|
+
if isinstance(input, list)
|
159
|
+
else None
|
160
|
+
)
|
161
|
+
assert result, "Must find image in content"
|
162
|
+
return result
|