inspect-ai 0.3.99__py3-none-any.whl → 0.3.100__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/core/config.py +11 -5
- inspect_ai/_display/core/panel.py +66 -2
- inspect_ai/_display/core/textual.py +5 -2
- inspect_ai/_display/plain/display.py +1 -0
- inspect_ai/_display/rich/display.py +2 -2
- inspect_ai/_display/textual/widgets/transcript.py +37 -9
- inspect_ai/_eval/score.py +2 -4
- inspect_ai/_eval/task/run.py +59 -81
- inspect_ai/_util/content.py +11 -6
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/text.py +7 -0
- inspect_ai/_util/working.py +8 -37
- inspect_ai/_view/__init__.py +0 -0
- inspect_ai/_view/schema.py +2 -1
- inspect_ai/_view/www/CLAUDE.md +15 -0
- inspect_ai/_view/www/dist/assets/index.css +263 -159
- inspect_ai/_view/www/dist/assets/index.js +22153 -19093
- inspect_ai/_view/www/log-schema.json +77 -3
- inspect_ai/_view/www/package.json +5 -1
- inspect_ai/_view/www/src/@types/log.d.ts +9 -0
- inspect_ai/_view/www/src/app/App.tsx +1 -15
- inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
- inspect_ai/_view/www/src/app/routing/url.ts +84 -4
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
- inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
- inspect_ai/_view/www/src/app/types.ts +5 -1
- inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
- inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
- inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
- inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
- inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
- inspect_ai/_view/www/src/state/hooks.ts +52 -2
- inspect_ai/_view/www/src/state/logSlice.ts +4 -3
- inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
- inspect_ai/_view/www/src/state/scrolling.ts +152 -0
- inspect_ai/_view/www/src/utils/attachments.ts +7 -0
- inspect_ai/_view/www/src/utils/python.ts +18 -0
- inspect_ai/_view/www/yarn.lock +269 -6
- inspect_ai/agent/_react.py +12 -7
- inspect_ai/agent/_run.py +2 -3
- inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_recorders/file.py +2 -9
- inspect_ai/log/_transcript.py +1 -1
- inspect_ai/model/_call_tools.py +6 -2
- inspect_ai/model/_openai.py +1 -1
- inspect_ai/model/_openai_responses.py +78 -39
- inspect_ai/model/_openai_web_search.py +31 -0
- inspect_ai/model/_providers/azureai.py +72 -3
- inspect_ai/model/_providers/openai.py +2 -1
- inspect_ai/scorer/_metric.py +1 -2
- inspect_ai/solver/_task_state.py +2 -2
- inspect_ai/tool/_tool.py +6 -2
- inspect_ai/tool/_tool_def.py +27 -4
- inspect_ai/tool/_tool_info.py +2 -0
- inspect_ai/tool/_tools/_web_search/_google.py +15 -4
- inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
- inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_json.py +3 -0
- inspect_ai/util/_limit.py +230 -20
- inspect_ai/util/_sandbox/docker/compose.py +20 -11
- inspect_ai/util/_span.py +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +120 -106
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
inspect_ai/model/_openai_responses.py
CHANGED
@@ -1,6 +1,5 @@
 import json
-from
-from typing import TypedDict, cast
+from typing import Sequence, TypedDict, cast

 from openai.types.responses import (
     FunctionToolParam,
@@ -8,6 +7,8 @@ from openai.types.responses import (
     ResponseComputerToolCallParam,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseFunctionWebSearch,
+    ResponseFunctionWebSearchParam,
     ResponseInputContentParam,
     ResponseInputImageParam,
     ResponseInputItemParam,
@@ -51,6 +52,7 @@ from inspect_ai.model._openai_computer_use import (
     maybe_computer_use_preview_tool,
     tool_call_from_openai_computer_tool_call,
 )
+from inspect_ai.model._openai_web_search import maybe_web_search_tool
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolChoice
 from inspect_ai.tool._tool_info import ToolInfo
@@ -174,6 +176,12 @@ def openai_responses_chat_choices(
     return [ChatCompletionChoice(message=message, stop_reason=stop_reason)]


+def is_native_tool_configured(
+    tools: Sequence[ToolInfo], config: GenerateConfig
+) -> bool:
+    return any(_maybe_native_tool_param(tool, config) is not None for tool in tools)
+
+
 # The next two function perform transformations between OpenAI types an Inspect
 # ChatMessageAssistant. Here is a diagram that helps visualize the transforms.
 # ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐
@@ -207,7 +215,6 @@ def openai_responses_chat_choices(


 class _AssistantInternal(TypedDict):
-    output_message_id: str | None
     tool_message_ids: dict[str, str]


@@ -237,17 +244,17 @@ def _chat_message_assistant_from_openai_response(
     # collect output and tool calls
     message_content: list[Content] = []
     tool_calls: list[ToolCall] = []
-    internal = _AssistantInternal(
+    internal = _AssistantInternal(tool_message_ids={})
     for output in response.output:
         match output:
             case ResponseOutputMessage(content=content, id=id):
-                assert internal["output_message_id"] is None, "Multiple message outputs"
-                internal["output_message_id"] = id
                 message_content.extend(
                     [
-                        ContentText(text=c.text)
+                        ContentText(text=c.text, internal={"id": id})
                         if isinstance(c, ResponseOutputText)
-                        else ContentText(
+                        else ContentText(
+                            text=c.refusal, refusal=True, internal={"id": id}
+                        )
                         for c in content
                     ]
                 )
@@ -277,6 +284,13 @@ def _chat_message_assistant_from_openai_response(
                 tool_calls.append(
                     tool_call_from_openai_computer_tool_call(output)
                 )
+            case ResponseFunctionWebSearch():
+                # We don't currently capture this since the model did the
+                # "tool call" internally. It's conceivable that could be
+                # forced to include it in `.internal` in the future, but
+                # for now we just ignore it.
+                # {"id":"ws_682cdcec3fa88198bc10b38fafefbd5e077e89e31fd4a3d5","status":"completed","type":"web_search_call"}
+                pass
             case _:
                 raise ValueError(f"Unexpected output type: {output.__class__}")

@@ -304,25 +318,39 @@ def _openai_input_items_from_chat_message_assistant(
     field of the `ChatMessageAssistant` to help it provide the proper id's the
     items in the returned list.
     """
-
+    tool_message_ids = _ids_from_assistant_internal(message)

     # we want to prevent yielding output messages in the case where we have an
     # 'internal' field (so the message came from the model API as opposed to
-    # being user synthesized) AND there
-    # when reading the message from the server we didn't find output).
-    # happen e.g. when a react() agent sets the output.completion in response
+    # being user synthesized) AND there are no ContentText items with message IDs
+    # (indicating that when reading the message from the server we didn't find output).
+    # this could happen e.g. when a react() agent sets the output.completion in response
     # to a submit() tool call
-
+    content_items: list[ContentText | ContentReasoning] = (
+        [ContentText(text=message.content)]
+        if isinstance(message.content, str)
+        else [
+            c for c in message.content if isinstance(c, ContentText | ContentReasoning)
+        ]
+    )
+    has_content_with_ids = any(
+        isinstance(c, ContentText)
+        and isinstance(c.internal, dict)
+        and "id" in c.internal
+        for c in content_items
+    )
+    suppress_output_message = message.internal is not None and not has_content_with_ids

     # if we are not storing messages on the server then blank these out
     if not store:
-        output_message_id = None
         tool_message_ids = {}

-    # items to return
-    # additional content on to it)
+    # items to return
     items: list[ResponseInputItemParam] = []
-
+    # group content by message ID
+    messages_by_id: dict[
+        str | None, list[ResponseOutputTextParam | ResponseOutputRefusalParam]
+    ] = {}

     for content in (
         list[ContentText | ContentReasoning]([ContentText(text=message.content)])
@@ -352,6 +380,14 @@ def _openai_input_items_from_chat_message_assistant(
         if suppress_output_message:
             continue

+        # get the message ID from ContentText.modelJson
+        content_message_id: str | None = None
+        if isinstance(content.internal, dict) and "id" in content.internal:
+            id_value = content.internal["id"]
+            content_message_id = id_value if isinstance(id_value, str) else None
+        else:
+            content_message_id = None
+
         new_content = (
             ResponseOutputRefusalParam(type="refusal", refusal=text)
             if refusal
@@ -359,22 +395,24 @@ def _openai_input_items_from_chat_message_assistant(
                 type="output_text", text=text, annotations=[]
             )
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        if content_message_id not in messages_by_id:
+            messages_by_id[content_message_id] = []
+        messages_by_id[content_message_id].append(new_content)
+
+    # create ResponseOutputMessage for each unique ID
+    for msg_id, content_list in messages_by_id.items():
+        output_message = ResponseOutputMessageParam(
+            type="message",
+            role="assistant",
+            # this actually can be `None`, and it will in fact be `None` when the
+            # assistant message is synthesized by the scaffold as opposed to being
+            # replayed from the model (or when store=False)
+            id=msg_id,  # type: ignore[typeddict-item]
+            content=content_list,
+            status="completed",
+        )
+        items.append(output_message)

     return items + _tool_call_items_from_assistant_message(message, tool_message_ids)

@@ -399,7 +437,7 @@ def _maybe_native_tool_param(
 ) -> ToolParam | None:
     return (
         (
-            maybe_computer_use_preview_tool(tool)
+            maybe_computer_use_preview_tool(tool) or maybe_web_search_tool(tool)
             # or self.text_editor_tool_param(tool)
             # or self.bash_tool_param(tool)
         )
@@ -442,22 +480,23 @@ def _tool_call_items_from_assistant_message(

 def _ids_from_assistant_internal(
     message: ChatMessageAssistant,
-) ->
+) -> dict[str, str]:
     if message.internal is not None:
         assert isinstance(message.internal, dict), (
             "OpenAI ChatMessageAssistant internal must be an _AssistantInternal"
         )
         internal = cast(_AssistantInternal, message.internal)
-        return
+        return internal["tool_message_ids"]
     else:
-        return
+        return {}


 _ResponseToolCallParam = (
-    ResponseFunctionToolCallParam
+    ResponseFunctionToolCallParam
+    | ResponseComputerToolCallParam
+    | ResponseFunctionWebSearchParam
     # | ResponseFileSearchToolCallParam
     # | ResponseFunctionToolCallParam
-    # | ResponseFunctionWebSearchParam
 )
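Note on the hunks above: output-message IDs now travel with each ContentText (in its internal field) rather than in a single output_message_id slot, so replay can rebuild one ResponseOutputMessageParam per originating message, with scaffold-synthesized content (no ID) grouped under None. A minimal sketch of the grouping step, using plain dicts as stand-ins for the Inspect/OpenAI types (illustrative only, not package code):

from collections import defaultdict

# stand-ins for ContentText items; "internal" carries the originating message id
contents = [
    {"text": "Hello", "internal": {"id": "msg_1"}},
    {"text": "World", "internal": {"id": "msg_1"}},
    {"text": "synthesized by scaffold", "internal": None},
]

# group content by message ID, mirroring messages_by_id in the diff above
messages_by_id: dict[str | None, list[str]] = defaultdict(list)
for c in contents:
    internal = c["internal"]
    msg_id = internal["id"] if isinstance(internal, dict) and "id" in internal else None
    messages_by_id[msg_id].append(c["text"])

# one output message per unique id; None collects scaffold-synthesized content
assert dict(messages_by_id) == {
    "msg_1": ["Hello", "World"],
    None: ["synthesized by scaffold"],
}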
inspect_ai/model/_openai_web_search.py
ADDED
@@ -0,0 +1,31 @@
+from typing import cast
+
+from openai.types.responses import WebSearchTool, WebSearchToolParam
+
+from inspect_ai.tool._tool_info import ToolInfo
+
+
+def maybe_web_search_tool(tool: ToolInfo) -> WebSearchToolParam | None:
+    return (
+        _web_search_tool(tool.options["openai"])
+        if tool.name == "web_search" and tool.options and "openai" in tool.options
+        else None
+    )
+
+
+def _web_search_tool(maybe_openai_options: object) -> WebSearchToolParam:
+    if maybe_openai_options is None:
+        maybe_openai_options = {}
+    elif not isinstance(maybe_openai_options, dict):
+        raise TypeError(
+            f"Expected a dictionary for openai_options, got {type(maybe_openai_options)}"
+        )
+    openai_options = (
+        WebSearchTool.model_validate(
+            {"type": "web_search_preview", **maybe_openai_options}
+        )
+        if maybe_openai_options
+        else WebSearchTool(type="web_search_preview")
+    )
+
+    return cast(WebSearchToolParam, openai_options.model_dump(exclude_none=True))
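This new module is what lets a web_search tool declare an OpenAI-native implementation via its options bag. A hedged usage sketch (relying only on what the module above shows; an empty options dict falls through to the web_search_preview default):

from inspect_ai.model._openai_web_search import maybe_web_search_tool
from inspect_ai.tool._tool_info import ToolInfo

# a web_search tool carrying an "openai" options entry maps to the native tool
tool = ToolInfo(name="web_search", description="Search the web", options={"openai": {}})
print(maybe_web_search_tool(tool))  # {'type': 'web_search_preview'}

# options for other providers are ignored by the OpenAI mapping
other = ToolInfo(name="web_search", description="Search the web", options={"tavily": {}})
assert maybe_web_search_tool(other) is None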
inspect_ai/model/_providers/azureai.py
CHANGED
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 from copy import copy
@@ -151,7 +152,7 @@ class AzureAIAPI(ModelAPI):

         # prepare request
         request = dict(
-            messages=await chat_request_messages(input, handler),
+            messages=await chat_request_messages(input, handler, self.is_mistral()),
             **self.completion_params(config),
         )
         # newer versions of vllm reject requests with tools or tool_choice if the
@@ -280,9 +281,77 @@ class AzureAIAPI(ModelAPI):


 async def chat_request_messages(
-    messages: list[ChatMessage],
+    messages: list[ChatMessage],
+    handler: ChatAPIHandler | None,
+    is_mistral: bool = False,
+) -> list[ChatRequestMessage]:
+    chat_messages = [
+        await chat_request_message(message, handler) for message in messages
+    ]
+    if is_mistral:
+        chat_messages = functools.reduce(mistral_message_reducer, chat_messages, [])
+
+    return chat_messages
+
+
+def mistral_message_reducer(
+    messages: list[ChatRequestMessage],
+    message: ChatRequestMessage,
 ) -> list[ChatRequestMessage]:
-
+    """Fold any user messages found immediately after tool messages into the last tool message."""
+    if (
+        len(messages) > 0
+        and isinstance(messages[-1], ToolMessage)
+        and isinstance(message, UserMessage)
+    ):
+        messages[-1] = fold_user_message_into_tool_message(messages[-1], message)
+    else:
+        messages.append(message)
+
+    return messages
+
+
+def fold_user_message_into_tool_message(
+    tool_message: ToolMessage,
+    user_message: UserMessage,
+) -> ToolMessage:
+    def convert_content_items_to_string(list_content: list[ContentItem]) -> str:
+        if not all(
+            isinstance(item, (TextContentItem | ImageContentItem))
+            for item in list_content
+        ):
+            raise TypeError(
+                "Expected all items to be TextContentItem or ImageContentItem"
+            )
+
+        parts = []
+        for item in list_content:
+            if isinstance(item, TextContentItem):
+                parts.append(item.text)
+            elif isinstance(item, ImageContentItem):
+                parts.append(f"[Image: {item.image_url.url}]")
+            else:
+                raise ValueError("Unexpected content item type")
+        return "".join(parts)
+
+    def normalise_content(
+        content: str | list[ContentItem] | None,
+    ) -> str | None:
+        return (
+            None
+            if content is None
+            else convert_content_items_to_string(content)
+            if isinstance(content, list)
+            else content
+        )
+
+    tool_content = normalise_content(tool_message.content)
+    user_content = normalise_content(user_message.content)
+
+    return ToolMessage(
+        content=(tool_content or "") + (user_content or ""),
+        tool_call_id=tool_message.tool_call_id,
+    )


 async def chat_request_message(
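Aside on the reducer pattern above: functools.reduce threads the accumulated list through mistral_message_reducer, so a user message that immediately follows a tool message is folded into it as the list is rebuilt. A self-contained sketch of the same fold with stand-in dataclasses rather than the azure.ai.inference types:

import functools
from dataclasses import dataclass


@dataclass
class ToolMsg:  # stand-in for azure.ai.inference.models.ToolMessage
    content: str
    tool_call_id: str


@dataclass
class UserMsg:  # stand-in for azure.ai.inference.models.UserMessage
    content: str


def reducer(messages: list, message: object) -> list:
    # fold a user message that immediately follows a tool message into it
    if messages and isinstance(messages[-1], ToolMsg) and isinstance(message, UserMsg):
        last = messages[-1]
        messages[-1] = ToolMsg(last.content + message.content, last.tool_call_id)
    else:
        messages.append(message)
    return messages


msgs = [ToolMsg("result: 42", "call_1"), UserMsg(" please continue")]
assert functools.reduce(reducer, msgs, []) == [
    ToolMsg("result: 42 please continue", "call_1")
]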
inspect_ai/model/_providers/openai.py
CHANGED
@@ -42,6 +42,7 @@ from .._openai import (
     openai_media_filter,
     openai_should_retry,
 )
+from .._openai_responses import is_native_tool_configured
 from .openai_o1 import generate_o1
 from .util import environment_prerequisite_error, model_base_url

@@ -241,7 +242,7 @@ class OpenAIAPI(ModelAPI):
                 tools=tools,
                 **self.completion_params(config, False),
             )
-        elif self.responses_api:
+        elif self.responses_api or is_native_tool_configured(tools, config):
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
inspect_ai/scorer/_metric.py
CHANGED
@@ -7,7 +7,6 @@ from typing import (
     Protocol,
     Type,
     Union,
-    cast,
     overload,
     runtime_checkable,
 )
@@ -356,7 +355,7 @@ def metric(
             )
             return metric

-        return metric_register(
+        return metric_register(metric_wrapper, metric_name)

     # for decorators with an explicit name, one more wrapper for the name
     if isinstance(name, str):
inspect_ai/solver/_task_state.py
CHANGED
@@ -290,7 +290,7 @@ class TaskState:
         return self._tools

     @tools.setter
-    def tools(self, tools:
+    def tools(self, tools: Sequence[Tool | ToolDef]) -> None:
         self._tools.clear()
         for tool in tools:
             self._tools.append(tool if isinstance(tool, Tool) else tool.as_tool())
@@ -353,7 +353,7 @@ class TaskState:
     def completed(self) -> bool:
         """Is the task completed.

-        Additionally, checks
+        Additionally, checks for an operator interrupt of the sample.
         """
         from inspect_ai.log._samples import set_active_sample_total_messages

inspect_ai/tool/_tool.py
CHANGED
@@ -224,13 +224,15 @@ def tool(
         tool_parallel = parallel
         tool_viewer = viewer
         tool_model_input = model_input
+        tool_options: dict[str, object] | None = None
         if is_registry_object(tool):
-            _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
-                tool
+            _, _, reg_parallel, reg_viewer, reg_model_input, options = (
+                tool_registry_info(tool)
             )
             tool_parallel = parallel and reg_parallel
             tool_viewer = viewer or reg_viewer
             tool_model_input = model_input or reg_model_input
+            tool_options = options

         # tag the object
         registry_tag(
@@ -247,6 +249,7 @@ def tool(
                     tool_model_input
                     or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
                 ),
+                TOOL_OPTIONS: tool_options,
             },
         ),
         *args,
@@ -267,6 +270,7 @@ TOOL_PROMPT = "prompt"
 TOOL_PARALLEL = "parallel"
 TOOL_VIEWER = "viewer"
 TOOL_MODEL_INPUT = "model_input"
+TOOL_OPTIONS = "options"


 TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
inspect_ai/tool/_tool_def.py
CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.registry import (

 from ._tool import (
     TOOL_MODEL_INPUT,
+    TOOL_OPTIONS,
     TOOL_PARALLEL,
     TOOL_PROMPT,
     TOOL_VIEWER,
@@ -44,6 +45,7 @@ class ToolDef:
         parallel: bool | None = None,
         viewer: ToolCallViewer | None = None,
         model_input: ToolCallModelInput | None = None,
+        options: dict[str, object] | None = None,
     ) -> None:
         """Create a tool definition.

@@ -59,6 +61,8 @@ class ToolDef:
             viewer: Optional tool call viewer implementation.
             model_input: Optional function that determines how
                 tool call results are played back as model input.
+            options: Optional property bag that can be used by the model provider
+                to customize the implementation of the tool

         Returns:
             Tool definition.
@@ -82,6 +86,7 @@ class ToolDef:
             self.parallel = parallel if parallel is not None else tdef.parallel
             self.viewer = viewer or tdef.viewer
             self.model_input = model_input or tdef.model_input
+            self.options = options or tdef.options

         # if its not a tool then extract tool_info if all fields have not
         # been provided explicitly
@@ -112,6 +117,7 @@ class ToolDef:
             self.parallel = parallel is not False
             self.viewer = viewer
             self.model_input = model_input
+            self.options = options

     tool: Callable[..., Any]
     """Callable to execute tool."""
@@ -134,13 +140,20 @@ class ToolDef:
     model_input: ToolCallModelInput | None
     """Custom model input presenter for tool calls."""

+    options: dict[str, object] | None = None
+    """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
+
     def as_tool(self) -> Tool:
         """Convert a ToolDef to a Tool."""
         tool = self.tool
         info = RegistryInfo(
             type="tool",
             name=self.name,
-            metadata={
+            metadata={
+                TOOL_PARALLEL: self.parallel,
+                TOOL_VIEWER: self.viewer,
+                TOOL_OPTIONS: self.options,
+            },
         )
         set_registry_info(tool, info)
         set_registry_params(tool, {})
@@ -189,11 +202,12 @@ class ToolDefFields(NamedTuple):
     parallel: bool
     viewer: ToolCallViewer | None
     model_input: ToolCallModelInput | None
+    options: dict[str, object] | None


 def tool_def_fields(tool: Tool) -> ToolDefFields:
     # get tool_info
-    name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
+    name, prompt, parallel, viewer, model_input, options = tool_registry_info(tool)
     tool_info = parse_tool_info(tool)

     # if there is a description then append any prompt to the
@@ -234,19 +248,28 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
         parallel=parallel,
         viewer=viewer,
         model_input=model_input,
+        options=options,
     )


 def tool_registry_info(
     tool: Tool,
-) -> tuple[
+) -> tuple[
+    str,
+    str | None,
+    bool,
+    ToolCallViewer | None,
+    ToolCallModelInput | None,
+    dict[str, object] | None,
+]:
     info = registry_info(tool)
     name = info.name.split("/")[-1]
     prompt = info.metadata.get(TOOL_PROMPT, None)
     parallel = info.metadata.get(TOOL_PARALLEL, True)
     viewer = info.metadata.get(TOOL_VIEWER, None)
     model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
-
+    options = info.metadata.get(TOOL_OPTIONS, None)
+    return name, prompt, parallel, viewer, model_input, options


 def validate_tool_parameters(tool_name: str, parameters: dict[str, ToolParam]) -> None:
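With these changes, options round-trips through the registry: ToolDef(options=...) stores it under TOOL_OPTIONS in as_tool(), and tool_registry_info() returns it as the sixth tuple element for providers to inspect. A hedged sketch of attaching provider options to a tool definition (the executor below is a placeholder):

from inspect_ai.tool import ToolDef


async def search(query: str) -> str:
    """Search the web (placeholder executor).

    Args:
        query: Query to search for.
    """
    return f"results for {query}"


# provider options ride along in registry metadata via as_tool()
tdef = ToolDef(search, name="web_search", options={"openai": {}})
tool = tdef.as_tool()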
inspect_ai/tool/_tool_info.py
CHANGED
@@ -49,6 +49,8 @@ class ToolInfo(BaseModel):
     """Short description of tool."""
     parameters: ToolParams = Field(default_factory=ToolParams)
     """JSON Schema of tool parameters object."""
+    options: dict[str, object] | None = Field(default=None)
+    """Optional property bag that can be used by the model provider to customize the implementation of the tool"""


 def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
inspect_ai/tool/_tools/_web_search/_google.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Awaitable, Callable
 import anyio
 import httpx
 from bs4 import BeautifulSoup, NavigableString
+from pydantic import BaseModel
 from tenacity import (
     retry,
     retry_if_exception,
@@ -23,6 +24,13 @@ Page Content: {text}
 """


+class GoogleOptions(BaseModel):
+    num_results: int | None = None
+    max_provider_calls: int | None = None
+    max_connections: int | None = None
+    model: str | None = None
+
+
 class SearchLink:
     def __init__(self, url: str, snippet: str) -> None:
         self.url = url
@@ -42,11 +50,14 @@ def maybe_get_google_api_keys() -> tuple[str, str] | None:


 def google_search_provider(
-
-    max_provider_calls: int,
-    max_connections: int,
-    model: str | None,
+    in_options: dict[str, object] | None = None,
 ) -> Callable[[str], Awaitable[str | None]]:
+    options = GoogleOptions.model_validate(in_options) if in_options else None
+    num_results = (options.num_results if options else None) or 3
+    max_provider_calls = (options.max_provider_calls if options else None) or 3
+    max_connections = (options.max_connections if options else None) or 10
+    model = options.model if options else None
+
     keys = maybe_get_google_api_keys()
     if not keys:
         raise PrerequisiteError(
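The provider now takes a single options bag in place of positional arguments, with unset fields falling back to the previous defaults (3 results, 3 provider calls, 10 connections). A quick sketch of the fallback behavior, mirroring the validation above (GoogleOptions redefined here so the snippet stands alone):

from pydantic import BaseModel


class GoogleOptions(BaseModel):  # mirrors the class added in the diff
    num_results: int | None = None
    max_provider_calls: int | None = None
    max_connections: int | None = None
    model: str | None = None


options = GoogleOptions.model_validate({"num_results": 5})
num_results = (options.num_results if options else None) or 3  # -> 5
max_provider_calls = (options.max_provider_calls if options else None) or 3  # -> 3
assert (num_results, max_provider_calls) == (5, 3)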
inspect_ai/tool/_tools/_web_search/_tavily.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Awaitable, Callable
+from typing import Awaitable, Callable, Literal

 import httpx
 from pydantic import BaseModel, Field
@@ -16,6 +16,25 @@ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.util._concurrency import concurrency


+class TavilyOptions(BaseModel):
+    topic: Literal["general", "news"] | None = None
+    search_depth: Literal["basic", "advanced"] | None = None
+    chunks_per_source: Literal[1, 2, 3] | None = None
+    max_results: int | None = None
+    time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | None = (
+        None
+    )
+    days: int | None = None
+    include_answer: bool | Literal["basic", "advanced"] | None = None
+    include_raw_content: bool | None = None
+    include_images: bool | None = None
+    include_image_descriptions: bool | None = None
+    include_domains: list[str] | None = None
+    exclude_domains: list[str] | None = None
+    # max_connections is not a Tavily API option, but an inspect option
+    max_connections: int | None = None
+
+
 class TavilySearchResult(BaseModel):
     title: str
     url: str
@@ -32,17 +51,25 @@ class TavilySearchResponse(BaseModel):


 def tavily_search_provider(
-
+    in_options: dict[str, object] | None = None,
 ) -> Callable[[str], Awaitable[str | None]]:
+    options = TavilyOptions.model_validate(in_options) if in_options else None
+    # Separate max_connections (which is an inspect thing) from the rest of the
+    # options which will be passed in the request body
+    max_connections = (options.max_connections if options else None) or 10
+    api_options = (
+        options.model_dump(exclude={"max_connections"}, exclude_none=True)
+        if options
+        else {}
+    )
+    if not api_options.get("include_answer", False):
+        api_options["include_answer"] = True
+
     tavily_api_key = os.environ.get("TAVILY_API_KEY", None)
     if not tavily_api_key:
         raise PrerequisiteError(
             "TAVILY_API_KEY not set in the environment. Please ensure ths variable is defined to use Tavily with the web_search tool.\n\nLearn more about the Tavily web search provider at https://inspect.aisi.org.uk/tools.html#tavily-provider"
         )
-    if num_results > 20:
-        raise PrerequisiteError(
-            "The Tavily search provider is limited to 20 results per query."
-        )

     # Create the client within the provider
     client = httpx.AsyncClient(timeout=30)
@@ -52,12 +79,8 @@ def tavily_search_provider(
         headers = {
             "Authorization": f"Bearer {tavily_api_key}",
         }
-
-
-            "max_results": 10,  # num_results,
-            # "search_depth": "advanced",
-            "include_answer": "advanced",
-        }
+
+        body = {"query": query, **api_options}

         # retry up to 5 times over a period of up to 1 minute
         @retry(