inspect-ai 0.3.99__py3-none-any.whl → 0.3.100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. inspect_ai/_display/core/config.py +11 -5
  2. inspect_ai/_display/core/panel.py +66 -2
  3. inspect_ai/_display/core/textual.py +5 -2
  4. inspect_ai/_display/plain/display.py +1 -0
  5. inspect_ai/_display/rich/display.py +2 -2
  6. inspect_ai/_display/textual/widgets/transcript.py +37 -9
  7. inspect_ai/_eval/score.py +2 -4
  8. inspect_ai/_eval/task/run.py +59 -81
  9. inspect_ai/_util/content.py +11 -6
  10. inspect_ai/_util/interrupt.py +2 -2
  11. inspect_ai/_util/text.py +7 -0
  12. inspect_ai/_util/working.py +8 -37
  13. inspect_ai/_view/__init__.py +0 -0
  14. inspect_ai/_view/schema.py +2 -1
  15. inspect_ai/_view/www/CLAUDE.md +15 -0
  16. inspect_ai/_view/www/dist/assets/index.css +263 -159
  17. inspect_ai/_view/www/dist/assets/index.js +22153 -19093
  18. inspect_ai/_view/www/log-schema.json +77 -3
  19. inspect_ai/_view/www/package.json +5 -1
  20. inspect_ai/_view/www/src/@types/log.d.ts +9 -0
  21. inspect_ai/_view/www/src/app/App.tsx +1 -15
  22. inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
  23. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
  24. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
  25. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
  26. inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
  27. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
  28. inspect_ai/_view/www/src/app/routing/url.ts +84 -4
  29. inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
  30. inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
  31. inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
  32. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
  33. inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
  34. inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
  35. inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
  36. inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
  37. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
  38. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
  39. inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
  40. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
  41. inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
  42. inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
  43. inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
  44. inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
  45. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
  46. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  47. inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
  48. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
  49. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
  50. inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
  51. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
  52. inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
  53. inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
  54. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
  55. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
  56. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
  57. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
  58. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
  59. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
  61. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
  62. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
  63. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
  64. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
  65. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
  66. inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
  67. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
  68. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
  69. inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
  70. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
  71. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
  72. inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
  73. inspect_ai/_view/www/src/app/types.ts +5 -1
  74. inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
  75. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
  76. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
  77. inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
  78. inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
  79. inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
  80. inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
  81. inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
  82. inspect_ai/_view/www/src/state/hooks.ts +52 -2
  83. inspect_ai/_view/www/src/state/logSlice.ts +4 -3
  84. inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
  85. inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
  86. inspect_ai/_view/www/src/state/scrolling.ts +152 -0
  87. inspect_ai/_view/www/src/utils/attachments.ts +7 -0
  88. inspect_ai/_view/www/src/utils/python.ts +18 -0
  89. inspect_ai/_view/www/yarn.lock +269 -6
  90. inspect_ai/agent/_react.py +12 -7
  91. inspect_ai/agent/_run.py +2 -3
  92. inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
  93. inspect_ai/log/_log.py +1 -1
  94. inspect_ai/log/_recorders/file.py +2 -9
  95. inspect_ai/log/_transcript.py +1 -1
  96. inspect_ai/model/_call_tools.py +6 -2
  97. inspect_ai/model/_openai.py +1 -1
  98. inspect_ai/model/_openai_responses.py +78 -39
  99. inspect_ai/model/_openai_web_search.py +31 -0
  100. inspect_ai/model/_providers/azureai.py +72 -3
  101. inspect_ai/model/_providers/openai.py +2 -1
  102. inspect_ai/scorer/_metric.py +1 -2
  103. inspect_ai/solver/_task_state.py +2 -2
  104. inspect_ai/tool/_tool.py +6 -2
  105. inspect_ai/tool/_tool_def.py +27 -4
  106. inspect_ai/tool/_tool_info.py +2 -0
  107. inspect_ai/tool/_tools/_web_search/_google.py +15 -4
  108. inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
  109. inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
  110. inspect_ai/util/__init__.py +4 -0
  111. inspect_ai/util/_json.py +3 -0
  112. inspect_ai/util/_limit.py +230 -20
  113. inspect_ai/util/_sandbox/docker/compose.py +20 -11
  114. inspect_ai/util/_span.py +1 -1
  115. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
  116. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +120 -106
  117. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
  118. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
  119. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
  120. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
inspect_ai/model/_openai_responses.py CHANGED
@@ -1,6 +1,5 @@
 import json
-from itertools import chain
-from typing import TypedDict, cast
+from typing import Sequence, TypedDict, cast

 from openai.types.responses import (
     FunctionToolParam,
@@ -8,6 +7,8 @@ from openai.types.responses import (
     ResponseComputerToolCallParam,
     ResponseFunctionToolCall,
     ResponseFunctionToolCallParam,
+    ResponseFunctionWebSearch,
+    ResponseFunctionWebSearchParam,
     ResponseInputContentParam,
     ResponseInputImageParam,
     ResponseInputItemParam,
@@ -51,6 +52,7 @@ from inspect_ai.model._openai_computer_use import (
     maybe_computer_use_preview_tool,
     tool_call_from_openai_computer_tool_call,
 )
+from inspect_ai.model._openai_web_search import maybe_web_search_tool
 from inspect_ai.tool._tool_call import ToolCall
 from inspect_ai.tool._tool_choice import ToolChoice
 from inspect_ai.tool._tool_info import ToolInfo
@@ -174,6 +176,12 @@ def openai_responses_chat_choices(
     return [ChatCompletionChoice(message=message, stop_reason=stop_reason)]


+def is_native_tool_configured(
+    tools: Sequence[ToolInfo], config: GenerateConfig
+) -> bool:
+    return any(_maybe_native_tool_param(tool, config) is not None for tool in tools)
+
+
 # The next two function perform transformations between OpenAI types an Inspect
 # ChatMessageAssistant. Here is a diagram that helps visualize the transforms.
 # ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐
@@ -207,7 +215,6 @@ def openai_responses_chat_choices(


 class _AssistantInternal(TypedDict):
-    output_message_id: str | None
     tool_message_ids: dict[str, str]


@@ -237,17 +244,17 @@ def _chat_message_assistant_from_openai_response(
     # collect output and tool calls
     message_content: list[Content] = []
     tool_calls: list[ToolCall] = []
-    internal = _AssistantInternal(output_message_id=None, tool_message_ids={})
+    internal = _AssistantInternal(tool_message_ids={})
     for output in response.output:
         match output:
             case ResponseOutputMessage(content=content, id=id):
-                assert internal["output_message_id"] is None, "Multiple message outputs"
-                internal["output_message_id"] = id
                 message_content.extend(
                     [
-                        ContentText(text=c.text)
+                        ContentText(text=c.text, internal={"id": id})
                         if isinstance(c, ResponseOutputText)
-                        else ContentText(text=c.refusal, refusal=True)
+                        else ContentText(
+                            text=c.refusal, refusal=True, internal={"id": id}
+                        )
                         for c in content
                     ]
                 )
@@ -277,6 +284,13 @@ def _chat_message_assistant_from_openai_response(
                 tool_calls.append(
                     tool_call_from_openai_computer_tool_call(output)
                 )
+            case ResponseFunctionWebSearch():
+                # We don't currently capture this since the model did the
+                # "tool call" internally. It's conceivable that could be
+                # forced to include it in `.internal` in the future, but
+                # for now we just ignore it.
+                # {"id":"ws_682cdcec3fa88198bc10b38fafefbd5e077e89e31fd4a3d5","status":"completed","type":"web_search_call"}
+                pass
             case _:
                 raise ValueError(f"Unexpected output type: {output.__class__}")

@@ -304,25 +318,39 @@ def _openai_input_items_from_chat_message_assistant(
     field of the `ChatMessageAssistant` to help it provide the proper id's the
     items in the returned list.
     """
-    (output_message_id, tool_message_ids) = _ids_from_assistant_internal(message)
+    tool_message_ids = _ids_from_assistant_internal(message)

     # we want to prevent yielding output messages in the case where we have an
     # 'internal' field (so the message came from the model API as opposed to
-    # being user synthesized) AND there is no output_message_id (indicating that
-    # when reading the message from the server we didn't find output). this could
-    # happen e.g. when a react() agent sets the output.completion in response
+    # being user synthesized) AND there are no ContentText items with message IDs
+    # (indicating that when reading the message from the server we didn't find output).
+    # this could happen e.g. when a react() agent sets the output.completion in response
     # to a submit() tool call
-    suppress_output_message = message.internal is not None and output_message_id is None
+    content_items: list[ContentText | ContentReasoning] = (
+        [ContentText(text=message.content)]
+        if isinstance(message.content, str)
+        else [
+            c for c in message.content if isinstance(c, ContentText | ContentReasoning)
+        ]
+    )
+    has_content_with_ids = any(
+        isinstance(c, ContentText)
+        and isinstance(c.internal, dict)
+        and "id" in c.internal
+        for c in content_items
+    )
+    suppress_output_message = message.internal is not None and not has_content_with_ids

     # if we are not storing messages on the server then blank these out
     if not store:
-        output_message_id = None
         tool_message_ids = {}

-    # items to return -- ensure we use a single output message (and just chain
-    # additional content on to it)
+    # items to return
     items: list[ResponseInputItemParam] = []
-    output_message: ResponseOutputMessageParam | None = None
+    # group content by message ID
+    messages_by_id: dict[
+        str | None, list[ResponseOutputTextParam | ResponseOutputRefusalParam]
+    ] = {}

     for content in (
         list[ContentText | ContentReasoning]([ContentText(text=message.content)])
@@ -352,6 +380,14 @@
         if suppress_output_message:
             continue

+        # get the message ID from ContentText.modelJson
+        content_message_id: str | None = None
+        if isinstance(content.internal, dict) and "id" in content.internal:
+            id_value = content.internal["id"]
+            content_message_id = id_value if isinstance(id_value, str) else None
+        else:
+            content_message_id = None
+
         new_content = (
             ResponseOutputRefusalParam(type="refusal", refusal=text)
             if refusal
@@ -359,22 +395,24 @@
             else ResponseOutputTextParam(
                 type="output_text", text=text, annotations=[]
             )
         )
-        if output_message is None:
-            output_message = ResponseOutputMessageParam(
-                type="message",
-                role="assistant",
-                # this actually can be `None`, and it will in fact be `None` when the
-                # assistant message is synthesized by the scaffold as opposed to being
-                # replayed from the model (or when store=False)
-                id=output_message_id,  # type: ignore[typeddict-item]
-                content=[new_content],
-                status="completed",
-            )
-            items.append(output_message)
-        else:
-            output_message["content"] = chain(
-                output_message["content"], [new_content]
-            )
+
+        if content_message_id not in messages_by_id:
+            messages_by_id[content_message_id] = []
+        messages_by_id[content_message_id].append(new_content)
+
+    # create ResponseOutputMessage for each unique ID
+    for msg_id, content_list in messages_by_id.items():
+        output_message = ResponseOutputMessageParam(
+            type="message",
+            role="assistant",
+            # this actually can be `None`, and it will in fact be `None` when the
+            # assistant message is synthesized by the scaffold as opposed to being
+            # replayed from the model (or when store=False)
+            id=msg_id,  # type: ignore[typeddict-item]
+            content=content_list,
+            status="completed",
+        )
+        items.append(output_message)

     return items + _tool_call_items_from_assistant_message(message, tool_message_ids)

@@ -399,7 +437,7 @@ def _maybe_native_tool_param(
 ) -> ToolParam | None:
     return (
         (
-            maybe_computer_use_preview_tool(tool)
+            maybe_computer_use_preview_tool(tool) or maybe_web_search_tool(tool)
             # or self.text_editor_tool_param(tool)
             # or self.bash_tool_param(tool)
         )
@@ -442,22 +480,23 @@ def _tool_call_items_from_assistant_message(

 def _ids_from_assistant_internal(
     message: ChatMessageAssistant,
-) -> tuple[str | None, dict[str, str]]:
+) -> dict[str, str]:
     if message.internal is not None:
         assert isinstance(message.internal, dict), (
             "OpenAI ChatMessageAssistant internal must be an _AssistantInternal"
         )
         internal = cast(_AssistantInternal, message.internal)
-        return (internal["output_message_id"], internal["tool_message_ids"])
+        return internal["tool_message_ids"]
     else:
-        return None, {}
+        return {}


 _ResponseToolCallParam = (
-    ResponseFunctionToolCallParam | ResponseComputerToolCallParam
+    ResponseFunctionToolCallParam
+    | ResponseComputerToolCallParam
+    | ResponseFunctionWebSearchParam
     # | ResponseFileSearchToolCallParam
     # | ResponseFunctionToolCallParam
-    # | ResponseFunctionWebSearchParam
 )
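
The net effect of the changes above: each `ContentText` now records the id of the `ResponseOutputMessage` it came from in its `internal` field, and replay rebuilds one output message per unique id rather than chaining all content onto a single message (which is why `itertools.chain` and `output_message_id` are gone). A minimal sketch of that grouping, using plain dicts rather than the real OpenAI/Inspect types:

```python
from collections import defaultdict

# content items as the new code models them: each carries the id of the
# output message it came from (or no id, for scaffold-synthesized content)
content_items = [
    {"text": "first answer", "internal": {"id": "msg_a"}},
    {"text": "follow-up answer", "internal": {"id": "msg_b"}},
    {"text": "synthesized by scaffold", "internal": None},
]

# group content by originating message id (None groups synthesized content)
messages_by_id: defaultdict[str | None, list[str]] = defaultdict(list)
for item in content_items:
    internal = item["internal"]
    msg_id = internal["id"] if isinstance(internal, dict) and "id" in internal else None
    messages_by_id[msg_id].append(item["text"])

# one assistant output message per unique id, mirroring the replay loop above
for msg_id, texts in messages_by_id.items():
    print({"type": "message", "role": "assistant", "id": msg_id, "content": texts})
```
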
inspect_ai/model/_openai_web_search.py ADDED
@@ -0,0 +1,31 @@
+from typing import cast
+
+from openai.types.responses import WebSearchTool, WebSearchToolParam
+
+from inspect_ai.tool._tool_info import ToolInfo
+
+
+def maybe_web_search_tool(tool: ToolInfo) -> WebSearchToolParam | None:
+    return (
+        _web_search_tool(tool.options["openai"])
+        if tool.name == "web_search" and tool.options and "openai" in tool.options
+        else None
+    )
+
+
+def _web_search_tool(maybe_openai_options: object) -> WebSearchToolParam:
+    if maybe_openai_options is None:
+        maybe_openai_options = {}
+    elif not isinstance(maybe_openai_options, dict):
+        raise TypeError(
+            f"Expected a dictionary for openai_options, got {type(maybe_openai_options)}"
+        )
+    openai_options = (
+        WebSearchTool.model_validate(
+            {"type": "web_search_preview", **maybe_openai_options}
+        )
+        if maybe_openai_options
+        else WebSearchTool(type="web_search_preview")
+    )
+
+    return cast(WebSearchToolParam, openai_options.model_dump(exclude_none=True))
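
A hypothetical usage sketch of the new helper (`search_context_size` is a field of OpenAI's `WebSearchTool` model; treat the exact option values, and the expected dump shape, as illustrative):

```python
from inspect_ai.model._openai_web_search import maybe_web_search_tool
from inspect_ai.tool._tool_info import ToolInfo

# a web_search tool with an "openai" options entry maps to a native tool param
tool = ToolInfo(
    name="web_search",
    description="Search the web.",
    options={"openai": {"search_context_size": "low"}},
)
assert maybe_web_search_tool(tool) == {
    "type": "web_search_preview",
    "search_context_size": "low",
}

# tools without an "openai" options entry fall through to the function-call path
other = ToolInfo(name="web_search", description="Search the web.", options={"tavily": {}})
assert maybe_web_search_tool(other) is None
```
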
inspect_ai/model/_providers/azureai.py CHANGED
@@ -1,3 +1,4 @@
+import functools
 import json
 import os
 from copy import copy
@@ -151,7 +152,7 @@ class AzureAIAPI(ModelAPI):

         # prepare request
         request = dict(
-            messages=await chat_request_messages(input, handler),
+            messages=await chat_request_messages(input, handler, self.is_mistral()),
             **self.completion_params(config),
         )
         # newer versions of vllm reject requests with tools or tool_choice if the
@@ -280,9 +281,77 @@ class AzureAIAPI(ModelAPI):


 async def chat_request_messages(
-    messages: list[ChatMessage], handler: ChatAPIHandler | None
+    messages: list[ChatMessage],
+    handler: ChatAPIHandler | None,
+    is_mistral: bool = False,
+) -> list[ChatRequestMessage]:
+    chat_messages = [
+        await chat_request_message(message, handler) for message in messages
+    ]
+    if is_mistral:
+        chat_messages = functools.reduce(mistral_message_reducer, chat_messages, [])
+
+    return chat_messages
+
+
+def mistral_message_reducer(
+    messages: list[ChatRequestMessage],
+    message: ChatRequestMessage,
 ) -> list[ChatRequestMessage]:
-    return [await chat_request_message(message, handler) for message in messages]
+    """Fold any user messages found immediately after tool messages into the last tool message."""
+    if (
+        len(messages) > 0
+        and isinstance(messages[-1], ToolMessage)
+        and isinstance(message, UserMessage)
+    ):
+        messages[-1] = fold_user_message_into_tool_message(messages[-1], message)
+    else:
+        messages.append(message)
+
+    return messages
+
+
+def fold_user_message_into_tool_message(
+    tool_message: ToolMessage,
+    user_message: UserMessage,
+) -> ToolMessage:
+    def convert_content_items_to_string(list_content: list[ContentItem]) -> str:
+        if not all(
+            isinstance(item, (TextContentItem | ImageContentItem))
+            for item in list_content
+        ):
+            raise TypeError(
+                "Expected all items to be TextContentItem or ImageContentItem"
+            )
+
+        parts = []
+        for item in list_content:
+            if isinstance(item, TextContentItem):
+                parts.append(item.text)
+            elif isinstance(item, ImageContentItem):
+                parts.append(f"[Image: {item.image_url.url}]")
+            else:
+                raise ValueError("Unexpected content item type")
+        return "".join(parts)
+
+    def normalise_content(
+        content: str | list[ContentItem] | None,
+    ) -> str | None:
+        return (
+            None
+            if content is None
+            else convert_content_items_to_string(content)
+            if isinstance(content, list)
+            else content
+        )
+
+    tool_content = normalise_content(tool_message.content)
+    user_content = normalise_content(user_message.content)
+
+    return ToolMessage(
+        content=(tool_content or "") + (user_content or ""),
+        tool_call_id=tool_message.tool_call_id,
+    )


 async def chat_request_message(
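
The reducer relies on `functools.reduce` threading the accumulated message list through each step. A self-contained sketch of the same pattern, with stand-in dataclasses in place of the `azure.ai.inference` `ToolMessage`/`UserMessage` types:

```python
import functools
from dataclasses import dataclass


@dataclass
class ToolMsg:  # stand-in for azure.ai.inference.models.ToolMessage
    content: str
    tool_call_id: str


@dataclass
class UserMsg:  # stand-in for azure.ai.inference.models.UserMessage
    content: str


def reducer(messages: list, message) -> list:
    # fold a user message that immediately follows a tool message into that
    # tool message (the diff above applies this only for Mistral models)
    if messages and isinstance(messages[-1], ToolMsg) and isinstance(message, UserMsg):
        last = messages[-1]
        messages[-1] = ToolMsg(last.content + message.content, last.tool_call_id)
    else:
        messages.append(message)
    return messages


msgs = [ToolMsg("result: 42", "call_1"), UserMsg(" please continue")]
assert functools.reduce(reducer, msgs, []) == [
    ToolMsg("result: 42 please continue", "call_1")
]
```
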
inspect_ai/model/_providers/openai.py CHANGED
@@ -42,6 +42,7 @@ from .._openai import (
     openai_media_filter,
     openai_should_retry,
 )
+from .._openai_responses import is_native_tool_configured
 from .openai_o1 import generate_o1
 from .util import environment_prerequisite_error, model_base_url

@@ -241,7 +242,7 @@
                 tools=tools,
                 **self.completion_params(config, False),
             )
-        elif self.responses_api:
+        elif self.responses_api or is_native_tool_configured(tools, config):
             return await generate_responses(
                 client=self.client,
                 http_hooks=self._http_hooks,
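
A quick sketch of the new routing predicate, assuming a default `GenerateConfig` (native internal tools enabled); note `is_native_tool_configured` lives in a private module, so the import path mirrors the diff and may change:

```python
from inspect_ai.model import GenerateConfig
from inspect_ai.model._openai_responses import is_native_tool_configured
from inspect_ai.tool._tool_info import ToolInfo

config = GenerateConfig()

# a tool with native openai options now forces the Responses API path,
# even when the model was not created with responses_api=True
native = ToolInfo(name="web_search", description="Search.", options={"openai": {}})
assert is_native_tool_configured([native], config)

# ordinary function tools do not
plain = ToolInfo(name="lookup", description="Look up a value.")
assert not is_native_tool_configured([plain], config)
```
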
inspect_ai/scorer/_metric.py CHANGED
@@ -7,7 +7,6 @@ from typing import (
     Protocol,
     Type,
     Union,
-    cast,
     overload,
     runtime_checkable,
 )
@@ -356,7 +355,7 @@ def metric(
             )
             return metric

-        return metric_register(cast(Callable[P, Metric], metric_wrapper), metric_name)
+        return metric_register(metric_wrapper, metric_name)

     # for decorators with an explicit name, one more wrapper for the name
     if isinstance(name, str):
inspect_ai/solver/_task_state.py CHANGED
@@ -290,7 +290,7 @@ class TaskState:
         return self._tools

     @tools.setter
-    def tools(self, tools: list[Tool | ToolDef]) -> None:
+    def tools(self, tools: Sequence[Tool | ToolDef]) -> None:
         self._tools.clear()
         for tool in tools:
             self._tools.append(tool if isinstance(tool, Tool) else tool.as_tool())
@@ -353,7 +353,7 @@ class TaskState:
     def completed(self) -> bool:
         """Is the task completed.

-        Additionally, checks message and token limits and raises if they are exceeded, and also checks for an operator interrupt of the sample.
+        Additionally, checks for an operator interrupt of the sample.
         """
        from inspect_ai.log._samples import set_active_sample_total_messages

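
With the setter widened from `list[Tool | ToolDef]` to `Sequence[Tool | ToolDef]`, any sequence now type-checks. An illustrative solver (assuming the standard `bash()` and `python()` tools):

```python
from inspect_ai.solver import Generate, TaskState, solver
from inspect_ai.tool import bash, python


@solver
def with_tools():
    async def solve(state: TaskState, generate: Generate) -> TaskState:
        # a tuple (or any other Sequence) of Tool | ToolDef now satisfies the setter
        state.tools = (bash(), python())
        return state

    return solve
```
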
inspect_ai/tool/_tool.py CHANGED
@@ -224,13 +224,15 @@ def tool(
             tool_parallel = parallel
             tool_viewer = viewer
             tool_model_input = model_input
+            tool_options: dict[str, object] | None = None
             if is_registry_object(tool):
-                _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
-                    tool
+                _, _, reg_parallel, reg_viewer, reg_model_input, options = (
+                    tool_registry_info(tool)
                 )
                 tool_parallel = parallel and reg_parallel
                 tool_viewer = viewer or reg_viewer
                 tool_model_input = model_input or reg_model_input
+                tool_options = options

             # tag the object
             registry_tag(
@@ -247,6 +249,7 @@ def tool(
                         tool_model_input
                         or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
                     ),
+                    TOOL_OPTIONS: tool_options,
                 },
             ),
             *args,
@@ -267,6 +270,7 @@ TOOL_PROMPT = "prompt"
 TOOL_PARALLEL = "parallel"
 TOOL_VIEWER = "viewer"
 TOOL_MODEL_INPUT = "model_input"
+TOOL_OPTIONS = "options"


 TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
inspect_ai/tool/_tool_def.py CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.registry import (

 from ._tool import (
     TOOL_MODEL_INPUT,
+    TOOL_OPTIONS,
     TOOL_PARALLEL,
     TOOL_PROMPT,
     TOOL_VIEWER,
@@ -44,6 +45,7 @@ class ToolDef:
         parallel: bool | None = None,
         viewer: ToolCallViewer | None = None,
         model_input: ToolCallModelInput | None = None,
+        options: dict[str, object] | None = None,
     ) -> None:
         """Create a tool definition.

@@ -59,6 +61,8 @@
             viewer: Optional tool call viewer implementation.
             model_input: Optional function that determines how
                 tool call results are played back as model input.
+            options: Optional property bag that can be used by the model provider
+                to customize the implementation of the tool

         Returns:
             Tool definition.
@@ -82,6 +86,7 @@
             self.parallel = parallel if parallel is not None else tdef.parallel
             self.viewer = viewer or tdef.viewer
             self.model_input = model_input or tdef.model_input
+            self.options = options or tdef.options

         # if its not a tool then extract tool_info if all fields have not
         # been provided explicitly
@@ -112,6 +117,7 @@
             self.parallel = parallel is not False
             self.viewer = viewer
             self.model_input = model_input
+            self.options = options

     tool: Callable[..., Any]
     """Callable to execute tool."""
@@ -134,13 +140,20 @@
     model_input: ToolCallModelInput | None
     """Custom model input presenter for tool calls."""

+    options: dict[str, object] | None = None
+    """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
+
     def as_tool(self) -> Tool:
         """Convert a ToolDef to a Tool."""
         tool = self.tool
         info = RegistryInfo(
             type="tool",
             name=self.name,
-            metadata={TOOL_PARALLEL: self.parallel, TOOL_VIEWER: self.viewer},
+            metadata={
+                TOOL_PARALLEL: self.parallel,
+                TOOL_VIEWER: self.viewer,
+                TOOL_OPTIONS: self.options,
+            },
         )
         set_registry_info(tool, info)
         set_registry_params(tool, {})
@@ -189,11 +202,12 @@ class ToolDefFields(NamedTuple):
     parallel: bool
     viewer: ToolCallViewer | None
     model_input: ToolCallModelInput | None
+    options: dict[str, object] | None


 def tool_def_fields(tool: Tool) -> ToolDefFields:
     # get tool_info
-    name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
+    name, prompt, parallel, viewer, model_input, options = tool_registry_info(tool)
     tool_info = parse_tool_info(tool)

     # if there is a description then append any prompt to the
@@ -234,19 +248,28 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
         parallel=parallel,
         viewer=viewer,
         model_input=model_input,
+        options=options,
     )


 def tool_registry_info(
     tool: Tool,
-) -> tuple[str, str | None, bool, ToolCallViewer | None, ToolCallModelInput | None]:
+) -> tuple[
+    str,
+    str | None,
+    bool,
+    ToolCallViewer | None,
+    ToolCallModelInput | None,
+    dict[str, object] | None,
+]:
     info = registry_info(tool)
     name = info.name.split("/")[-1]
     prompt = info.metadata.get(TOOL_PROMPT, None)
     parallel = info.metadata.get(TOOL_PARALLEL, True)
     viewer = info.metadata.get(TOOL_VIEWER, None)
     model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
-    return name, prompt, parallel, viewer, model_input
+    options = info.metadata.get(TOOL_OPTIONS, None)
+    return name, prompt, parallel, viewer, model_input, options


 def validate_tool_parameters(tool_name: str, parameters: dict[str, ToolParam]) -> None:
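
A sketch of how the new `options` property bag flows from `ToolDef` into the registry metadata that `tool_registry_info` now returns as its sixth element (the tool name and option values are illustrative):

```python
from inspect_ai.tool import ToolDef


async def search(query: str) -> str:
    """Execute a web search.

    Args:
        query: Query to search for.
    """
    return f"results for {query}"


# hypothetical options payload; the model provider decides what keys it understands
tdef = ToolDef(search, options={"openai": {"search_context_size": "low"}})
tool = tdef.as_tool()  # registry metadata now carries TOOL_OPTIONS alongside
                       # TOOL_PARALLEL and TOOL_VIEWER
```
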
inspect_ai/tool/_tool_info.py CHANGED
@@ -49,6 +49,8 @@ class ToolInfo(BaseModel):
     """Short description of tool."""
     parameters: ToolParams = Field(default_factory=ToolParams)
     """JSON Schema of tool parameters object."""
+    options: dict[str, object] | None = Field(default=None)
+    """Optional property bag that can be used by the model provider to customize the implementation of the tool"""


 def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
inspect_ai/tool/_tools/_web_search/_google.py CHANGED
@@ -4,6 +4,7 @@ from typing import Awaitable, Callable
 import anyio
 import httpx
 from bs4 import BeautifulSoup, NavigableString
+from pydantic import BaseModel
 from tenacity import (
     retry,
     retry_if_exception,
@@ -23,6 +24,13 @@ Page Content: {text}
 """


+class GoogleOptions(BaseModel):
+    num_results: int | None = None
+    max_provider_calls: int | None = None
+    max_connections: int | None = None
+    model: str | None = None
+
+
 class SearchLink:
     def __init__(self, url: str, snippet: str) -> None:
         self.url = url
@@ -42,11 +50,14 @@ def maybe_get_google_api_keys() -> tuple[str, str] | None:


 def google_search_provider(
-    num_results: int,
-    max_provider_calls: int,
-    max_connections: int,
-    model: str | None,
+    in_options: dict[str, object] | None = None,
 ) -> Callable[[str], Awaitable[str | None]]:
+    options = GoogleOptions.model_validate(in_options) if in_options else None
+    num_results = (options.num_results if options else None) or 3
+    max_provider_calls = (options.max_provider_calls if options else None) or 3
+    max_connections = (options.max_connections if options else None) or 10
+    model = options.model if options else None
+
     keys = maybe_get_google_api_keys()
     if not keys:
         raise PrerequisiteError(
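
The provider now takes a single options dict validated against `GoogleOptions` instead of positional parameters, falling back to the previous defaults (3 results, 3 provider calls, 10 connections) when fields are omitted. A sketch of just the validation step (the private import path mirrors the diff and may change):

```python
from inspect_ai.tool._tools._web_search._google import GoogleOptions

options = GoogleOptions.model_validate({"num_results": 5})
num_results = (options.num_results if options else None) or 3              # -> 5
max_provider_calls = (options.max_provider_calls if options else None) or 3  # -> 3
max_connections = (options.max_connections if options else None) or 10       # -> 10
model = options.model if options else None                                   # -> None
```
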
inspect_ai/tool/_tools/_web_search/_tavily.py CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Awaitable, Callable
+from typing import Awaitable, Callable, Literal

 import httpx
 from pydantic import BaseModel, Field
@@ -16,6 +16,25 @@ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
 from inspect_ai.util._concurrency import concurrency


+class TavilyOptions(BaseModel):
+    topic: Literal["general", "news"] | None = None
+    search_depth: Literal["basic", "advanced"] | None = None
+    chunks_per_source: Literal[1, 2, 3] | None = None
+    max_results: int | None = None
+    time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | None = (
+        None
+    )
+    days: int | None = None
+    include_answer: bool | Literal["basic", "advanced"] | None = None
+    include_raw_content: bool | None = None
+    include_images: bool | None = None
+    include_image_descriptions: bool | None = None
+    include_domains: list[str] | None = None
+    exclude_domains: list[str] | None = None
+    # max_connections is not a Tavily API option, but an inspect option
+    max_connections: int | None = None
+
+
 class TavilySearchResult(BaseModel):
     title: str
     url: str
@@ -32,17 +51,25 @@ class TavilySearchResponse(BaseModel):


 def tavily_search_provider(
-    num_results: int, max_connections: int
+    in_options: dict[str, object] | None = None,
 ) -> Callable[[str], Awaitable[str | None]]:
+    options = TavilyOptions.model_validate(in_options) if in_options else None
+    # Separate max_connections (which is an inspect thing) from the rest of the
+    # options which will be passed in the request body
+    max_connections = (options.max_connections if options else None) or 10
+    api_options = (
+        options.model_dump(exclude={"max_connections"}, exclude_none=True)
+        if options
+        else {}
+    )
+    if not api_options.get("include_answer", False):
+        api_options["include_answer"] = True
+
     tavily_api_key = os.environ.get("TAVILY_API_KEY", None)
     if not tavily_api_key:
         raise PrerequisiteError(
             "TAVILY_API_KEY not set in the environment. Please ensure ths variable is defined to use Tavily with the web_search tool.\n\nLearn more about the Tavily web search provider at https://inspect.aisi.org.uk/tools.html#tavily-provider"
         )
-    if num_results > 20:
-        raise PrerequisiteError(
-            "The Tavily search provider is limited to 20 results per query."
-        )

     # Create the client within the provider
     client = httpx.AsyncClient(timeout=30)
@@ -52,12 +79,8 @@
         headers = {
             "Authorization": f"Bearer {tavily_api_key}",
         }
-        body = {
-            "query": query,
-            "max_results": 10,  # num_results,
-            # "search_depth": "advanced",
-            "include_answer": "advanced",
-        }
+
+        body = {"query": query, **api_options}

         # retry up to 5 times over a period of up to 1 minute
         @retry(
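
A sketch of how validated Tavily options become the request body: the inspect-only `max_connections` is stripped, `exclude_none` drops unset fields, and `include_answer` is defaulted on when the caller did not set it (the private import path mirrors the diff and may change):

```python
from inspect_ai.tool._tools._web_search._tavily import TavilyOptions

options = TavilyOptions.model_validate(
    {"max_results": 5, "search_depth": "advanced", "max_connections": 4}
)
# inspect-side concurrency limit, not sent to Tavily
max_connections = (options.max_connections if options else None) or 10  # -> 4
api_options = options.model_dump(exclude={"max_connections"}, exclude_none=True)
if not api_options.get("include_answer", False):
    api_options["include_answer"] = True

body = {"query": "inspect_ai evals", **api_options}
# -> {"query": "inspect_ai evals", "search_depth": "advanced",
#     "max_results": 5, "include_answer": True}
```
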