inspect-ai 0.3.98__py3-none-any.whl → 0.3.100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. inspect_ai/__init__.py +2 -0
  2. inspect_ai/_cli/log.py +1 -1
  3. inspect_ai/_display/core/config.py +11 -5
  4. inspect_ai/_display/core/panel.py +66 -2
  5. inspect_ai/_display/core/textual.py +5 -2
  6. inspect_ai/_display/plain/display.py +1 -0
  7. inspect_ai/_display/rich/display.py +2 -2
  8. inspect_ai/_display/textual/widgets/transcript.py +41 -1
  9. inspect_ai/_eval/run.py +12 -4
  10. inspect_ai/_eval/score.py +2 -4
  11. inspect_ai/_eval/task/log.py +1 -1
  12. inspect_ai/_eval/task/run.py +59 -81
  13. inspect_ai/_eval/task/task.py +1 -1
  14. inspect_ai/_util/_async.py +1 -1
  15. inspect_ai/_util/content.py +11 -6
  16. inspect_ai/_util/interrupt.py +2 -2
  17. inspect_ai/_util/text.py +7 -0
  18. inspect_ai/_util/working.py +8 -37
  19. inspect_ai/_view/__init__.py +0 -0
  20. inspect_ai/_view/schema.py +3 -1
  21. inspect_ai/_view/view.py +14 -0
  22. inspect_ai/_view/www/CLAUDE.md +15 -0
  23. inspect_ai/_view/www/dist/assets/index.css +273 -169
  24. inspect_ai/_view/www/dist/assets/index.js +20079 -17019
  25. inspect_ai/_view/www/log-schema.json +122 -8
  26. inspect_ai/_view/www/package.json +5 -1
  27. inspect_ai/_view/www/src/@types/log.d.ts +20 -2
  28. inspect_ai/_view/www/src/app/App.tsx +1 -15
  29. inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
  30. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
  31. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
  32. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +221 -205
  33. inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
  34. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
  35. inspect_ai/_view/www/src/app/routing/url.ts +84 -4
  36. inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
  37. inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
  39. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +26 -19
  40. inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
  41. inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
  42. inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
  43. inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
  44. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
  45. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
  46. inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
  47. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
  48. inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
  49. inspect_ai/_view/www/src/app/samples/scores/SampleScoresGrid.module.css +2 -2
  50. inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +2 -3
  51. inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
  52. inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
  53. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
  54. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  55. inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
  56. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
  57. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
  58. inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
  59. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
  61. inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
  62. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
  63. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
  64. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
  65. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
  66. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
  67. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
  68. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
  69. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
  70. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
  71. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
  72. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
  73. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
  74. inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
  75. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
  76. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
  77. inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
  78. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
  79. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
  80. inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
  81. inspect_ai/_view/www/src/app/types.ts +5 -1
  82. inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
  83. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
  84. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
  85. inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
  86. inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
  87. inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
  88. inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
  89. inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
  90. inspect_ai/_view/www/src/state/hooks.ts +52 -2
  91. inspect_ai/_view/www/src/state/logSlice.ts +4 -3
  92. inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
  93. inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
  94. inspect_ai/_view/www/src/state/scrolling.ts +152 -0
  95. inspect_ai/_view/www/src/utils/attachments.ts +7 -0
  96. inspect_ai/_view/www/src/utils/python.ts +18 -0
  97. inspect_ai/_view/www/yarn.lock +269 -6
  98. inspect_ai/agent/_react.py +12 -7
  99. inspect_ai/agent/_run.py +46 -11
  100. inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
  101. inspect_ai/log/_bundle.py +5 -3
  102. inspect_ai/log/_log.py +3 -3
  103. inspect_ai/log/_recorders/file.py +2 -9
  104. inspect_ai/log/_transcript.py +1 -1
  105. inspect_ai/model/_call_tools.py +6 -2
  106. inspect_ai/model/_openai.py +1 -1
  107. inspect_ai/model/_openai_responses.py +78 -39
  108. inspect_ai/model/_openai_web_search.py +31 -0
  109. inspect_ai/model/_providers/anthropic.py +3 -6
  110. inspect_ai/model/_providers/azureai.py +72 -3
  111. inspect_ai/model/_providers/openai.py +2 -1
  112. inspect_ai/model/_providers/providers.py +1 -1
  113. inspect_ai/scorer/_metric.py +1 -2
  114. inspect_ai/solver/_task_state.py +2 -2
  115. inspect_ai/tool/_tool.py +6 -2
  116. inspect_ai/tool/_tool_def.py +27 -4
  117. inspect_ai/tool/_tool_info.py +2 -0
  118. inspect_ai/tool/_tools/_web_search/_google.py +15 -4
  119. inspect_ai/tool/_tools/_web_search/_tavily.py +35 -12
  120. inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
  121. inspect_ai/util/__init__.py +6 -0
  122. inspect_ai/util/_json.py +3 -0
  123. inspect_ai/util/_limit.py +374 -141
  124. inspect_ai/util/_sandbox/docker/compose.py +20 -11
  125. inspect_ai/util/_span.py +1 -1
  126. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/METADATA +3 -3
  127. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/RECORD +131 -117
  128. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/WHEEL +1 -1
  129. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/entry_points.txt +0 -0
  130. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/licenses/LICENSE +0 -0
  131. {inspect_ai-0.3.98.dist-info → inspect_ai-0.3.100.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ import types
4
4
  from copy import copy
5
5
  from dataclasses import is_dataclass
6
6
  from datetime import date, datetime, time
7
+ from enum import EnumMeta
7
8
  from logging import getLogger
8
9
  from textwrap import dedent
9
10
  from types import UnionType
@@ -172,7 +173,7 @@ async def execute_tools(
172
173
  except LimitExceededError as ex:
173
174
  tool_error = ToolCallError(
174
175
  "limit",
175
- f"The tool exceeded its {ex.type} limit of {ex.limit}.",
176
+ f"The tool exceeded its {ex.type} limit of {ex.limit_str}.",
176
177
  )
177
178
  except ToolParsingError as ex:
178
179
  tool_error = ToolCallError("parsing", ex.message)
@@ -497,7 +498,7 @@ async def agent_handoff(
497
498
  ChatMessageUser(
498
499
  content=(
499
500
  f"The {agent_name} exceeded its {limit_error.type} limit of "
500
- f"{limit_error.limit}."
501
+ f"{limit_error.limit_str}."
501
502
  )
502
503
  )
503
504
  )
@@ -548,6 +549,7 @@ def tools_info(
548
549
  name=tool.name,
549
550
  description=tool.description,
550
551
  parameters=tool.parameters,
552
+ options=tool.options,
551
553
  )
552
554
  )
553
555
  return tools_info
@@ -652,6 +654,8 @@ def tool_param(type_hint: Type[Any], input: Any) -> Any:
652
654
  return type_hint(**dataclass_data)
653
655
  elif issubclass(type_hint, BaseModel):
654
656
  return type_hint(**input)
657
+ elif isinstance(type_hint, EnumMeta):
658
+ return type_hint(input)
655
659
  else:
656
660
  return input
657
661
  elif origin is list or origin is List:
@@ -594,7 +594,7 @@ def chat_choices_from_openai(
594
594
  stop_reason=as_stop_reason(choice.finish_reason),
595
595
  logprobs=(
596
596
  Logprobs(**choice.logprobs.model_dump())
597
- if choice.logprobs is not None
597
+ if choice.logprobs and choice.logprobs.content is not None
598
598
  else None
599
599
  ),
600
600
  )
@@ -1,6 +1,5 @@
1
1
  import json
2
- from itertools import chain
3
- from typing import TypedDict, cast
2
+ from typing import Sequence, TypedDict, cast
4
3
 
5
4
  from openai.types.responses import (
6
5
  FunctionToolParam,
@@ -8,6 +7,8 @@ from openai.types.responses import (
8
7
  ResponseComputerToolCallParam,
9
8
  ResponseFunctionToolCall,
10
9
  ResponseFunctionToolCallParam,
10
+ ResponseFunctionWebSearch,
11
+ ResponseFunctionWebSearchParam,
11
12
  ResponseInputContentParam,
12
13
  ResponseInputImageParam,
13
14
  ResponseInputItemParam,
@@ -51,6 +52,7 @@ from inspect_ai.model._openai_computer_use import (
51
52
  maybe_computer_use_preview_tool,
52
53
  tool_call_from_openai_computer_tool_call,
53
54
  )
55
+ from inspect_ai.model._openai_web_search import maybe_web_search_tool
54
56
  from inspect_ai.tool._tool_call import ToolCall
55
57
  from inspect_ai.tool._tool_choice import ToolChoice
56
58
  from inspect_ai.tool._tool_info import ToolInfo
@@ -174,6 +176,12 @@ def openai_responses_chat_choices(
174
176
  return [ChatCompletionChoice(message=message, stop_reason=stop_reason)]
175
177
 
176
178
 
179
+ def is_native_tool_configured(
180
+ tools: Sequence[ToolInfo], config: GenerateConfig
181
+ ) -> bool:
182
+ return any(_maybe_native_tool_param(tool, config) is not None for tool in tools)
183
+
184
+
177
185
  # The next two functions perform transformations between OpenAI types and Inspect
178
186
  # ChatMessageAssistant. Here is a diagram that helps visualize the transforms.
179
187
  # ┌───────────────────────────┐ ┌───────────────────────────┐ ┌───────────────────────────┐
@@ -207,7 +215,6 @@ def openai_responses_chat_choices(
207
215
 
208
216
 
209
217
  class _AssistantInternal(TypedDict):
210
- output_message_id: str | None
211
218
  tool_message_ids: dict[str, str]
212
219
 
213
220
 
@@ -237,17 +244,17 @@ def _chat_message_assistant_from_openai_response(
237
244
  # collect output and tool calls
238
245
  message_content: list[Content] = []
239
246
  tool_calls: list[ToolCall] = []
240
- internal = _AssistantInternal(output_message_id=None, tool_message_ids={})
247
+ internal = _AssistantInternal(tool_message_ids={})
241
248
  for output in response.output:
242
249
  match output:
243
250
  case ResponseOutputMessage(content=content, id=id):
244
- assert internal["output_message_id"] is None, "Multiple message outputs"
245
- internal["output_message_id"] = id
246
251
  message_content.extend(
247
252
  [
248
- ContentText(text=c.text)
253
+ ContentText(text=c.text, internal={"id": id})
249
254
  if isinstance(c, ResponseOutputText)
250
- else ContentText(text=c.refusal, refusal=True)
255
+ else ContentText(
256
+ text=c.refusal, refusal=True, internal={"id": id}
257
+ )
251
258
  for c in content
252
259
  ]
253
260
  )
@@ -277,6 +284,13 @@ def _chat_message_assistant_from_openai_response(
277
284
  tool_calls.append(
278
285
  tool_call_from_openai_computer_tool_call(output)
279
286
  )
287
+ case ResponseFunctionWebSearch():
288
+ # We don't currently capture this since the model did the
289
+ # "tool call" internally. It's conceivable that we could be
290
+ # forced to include it in `.internal` in the future, but
291
+ # for now we just ignore it.
292
+ # {"id":"ws_682cdcec3fa88198bc10b38fafefbd5e077e89e31fd4a3d5","status":"completed","type":"web_search_call"}
293
+ pass
280
294
  case _:
281
295
  raise ValueError(f"Unexpected output type: {output.__class__}")
282
296
 
@@ -304,25 +318,39 @@ def _openai_input_items_from_chat_message_assistant(
304
318
  field of the `ChatMessageAssistant` to help it provide the proper ids for the
305
319
  items in the returned list.
306
320
  """
307
- (output_message_id, tool_message_ids) = _ids_from_assistant_internal(message)
321
+ tool_message_ids = _ids_from_assistant_internal(message)
308
322
 
309
323
  # we want to prevent yielding output messages in the case where we have an
310
324
  # 'internal' field (so the message came from the model API as opposed to
311
- # being user synthesized) AND there is no output_message_id (indicating that
312
- # when reading the message from the server we didn't find output). this could
313
- # happen e.g. when a react() agent sets the output.completion in response
325
+ # being user synthesized) AND there are no ContentText items with message IDs
326
+ # (indicating that when reading the message from the server we didn't find output).
327
+ # this could happen e.g. when a react() agent sets the output.completion in response
314
328
  # to a submit() tool call
315
- suppress_output_message = message.internal is not None and output_message_id is None
329
+ content_items: list[ContentText | ContentReasoning] = (
330
+ [ContentText(text=message.content)]
331
+ if isinstance(message.content, str)
332
+ else [
333
+ c for c in message.content if isinstance(c, ContentText | ContentReasoning)
334
+ ]
335
+ )
336
+ has_content_with_ids = any(
337
+ isinstance(c, ContentText)
338
+ and isinstance(c.internal, dict)
339
+ and "id" in c.internal
340
+ for c in content_items
341
+ )
342
+ suppress_output_message = message.internal is not None and not has_content_with_ids
316
343
 
317
344
  # if we are not storing messages on the server then blank these out
318
345
  if not store:
319
- output_message_id = None
320
346
  tool_message_ids = {}
321
347
 
322
- # items to return -- ensure we use a single output message (and just chain
323
- # additional content on to it)
348
+ # items to return
324
349
  items: list[ResponseInputItemParam] = []
325
- output_message: ResponseOutputMessageParam | None = None
350
+ # group content by message ID
351
+ messages_by_id: dict[
352
+ str | None, list[ResponseOutputTextParam | ResponseOutputRefusalParam]
353
+ ] = {}
326
354
 
327
355
  for content in (
328
356
  list[ContentText | ContentReasoning]([ContentText(text=message.content)])
@@ -352,6 +380,14 @@ def _openai_input_items_from_chat_message_assistant(
352
380
  if suppress_output_message:
353
381
  continue
354
382
 
383
+ # get the message ID from ContentText.internal
384
+ content_message_id: str | None = None
385
+ if isinstance(content.internal, dict) and "id" in content.internal:
386
+ id_value = content.internal["id"]
387
+ content_message_id = id_value if isinstance(id_value, str) else None
388
+ else:
389
+ content_message_id = None
390
+
355
391
  new_content = (
356
392
  ResponseOutputRefusalParam(type="refusal", refusal=text)
357
393
  if refusal
@@ -359,22 +395,24 @@ def _openai_input_items_from_chat_message_assistant(
359
395
  type="output_text", text=text, annotations=[]
360
396
  )
361
397
  )
362
- if output_message is None:
363
- output_message = ResponseOutputMessageParam(
364
- type="message",
365
- role="assistant",
366
- # this actually can be `None`, and it will in fact be `None` when the
367
- # assistant message is synthesized by the scaffold as opposed to being
368
- # replayed from the model (or when store=False)
369
- id=output_message_id, # type: ignore[typeddict-item]
370
- content=[new_content],
371
- status="completed",
372
- )
373
- items.append(output_message)
374
- else:
375
- output_message["content"] = chain(
376
- output_message["content"], [new_content]
377
- )
398
+
399
+ if content_message_id not in messages_by_id:
400
+ messages_by_id[content_message_id] = []
401
+ messages_by_id[content_message_id].append(new_content)
402
+
403
+ # create ResponseOutputMessage for each unique ID
404
+ for msg_id, content_list in messages_by_id.items():
405
+ output_message = ResponseOutputMessageParam(
406
+ type="message",
407
+ role="assistant",
408
+ # this actually can be `None`, and it will in fact be `None` when the
409
+ # assistant message is synthesized by the scaffold as opposed to being
410
+ # replayed from the model (or when store=False)
411
+ id=msg_id, # type: ignore[typeddict-item]
412
+ content=content_list,
413
+ status="completed",
414
+ )
415
+ items.append(output_message)
378
416
 
379
417
  return items + _tool_call_items_from_assistant_message(message, tool_message_ids)
380
418
 
@@ -399,7 +437,7 @@ def _maybe_native_tool_param(
399
437
  ) -> ToolParam | None:
400
438
  return (
401
439
  (
402
- maybe_computer_use_preview_tool(tool)
440
+ maybe_computer_use_preview_tool(tool) or maybe_web_search_tool(tool)
403
441
  # or self.text_editor_tool_param(tool)
404
442
  # or self.bash_tool_param(tool)
405
443
  )
@@ -442,22 +480,23 @@ def _tool_call_items_from_assistant_message(
442
480
 
443
481
  def _ids_from_assistant_internal(
444
482
  message: ChatMessageAssistant,
445
- ) -> tuple[str | None, dict[str, str]]:
483
+ ) -> dict[str, str]:
446
484
  if message.internal is not None:
447
485
  assert isinstance(message.internal, dict), (
448
486
  "OpenAI ChatMessageAssistant internal must be an _AssistantInternal"
449
487
  )
450
488
  internal = cast(_AssistantInternal, message.internal)
451
- return (internal["output_message_id"], internal["tool_message_ids"])
489
+ return internal["tool_message_ids"]
452
490
  else:
453
- return None, {}
491
+ return {}
454
492
 
455
493
 
456
494
  _ResponseToolCallParam = (
457
- ResponseFunctionToolCallParam | ResponseComputerToolCallParam
495
+ ResponseFunctionToolCallParam
496
+ | ResponseComputerToolCallParam
497
+ | ResponseFunctionWebSearchParam
458
498
  # | ResponseFileSearchToolCallParam
459
499
  # | ResponseFunctionToolCallParam
460
- # | ResponseFunctionWebSearchParam
461
500
  )
462
501
 
463
502
 
@@ -0,0 +1,31 @@
1
+ from typing import cast
2
+
3
+ from openai.types.responses import WebSearchTool, WebSearchToolParam
4
+
5
+ from inspect_ai.tool._tool_info import ToolInfo
6
+
7
+
8
+ def maybe_web_search_tool(tool: ToolInfo) -> WebSearchToolParam | None:
9
+ return (
10
+ _web_search_tool(tool.options["openai"])
11
+ if tool.name == "web_search" and tool.options and "openai" in tool.options
12
+ else None
13
+ )
14
+
15
+
16
+ def _web_search_tool(maybe_openai_options: object) -> WebSearchToolParam:
17
+ if maybe_openai_options is None:
18
+ maybe_openai_options = {}
19
+ elif not isinstance(maybe_openai_options, dict):
20
+ raise TypeError(
21
+ f"Expected a dictionary for openai_options, got {type(maybe_openai_options)}"
22
+ )
23
+ openai_options = (
24
+ WebSearchTool.model_validate(
25
+ {"type": "web_search_preview", **maybe_openai_options}
26
+ )
27
+ if maybe_openai_options
28
+ else WebSearchTool(type="web_search_preview")
29
+ )
30
+
31
+ return cast(WebSearchToolParam, openai_options.model_dump(exclude_none=True))
@@ -356,12 +356,9 @@ class AnthropicAPI(ModelAPI):
356
356
  if isinstance(ex, APIStatusError):
357
357
  # for unknown reasons, anthropic does not always set status_code == 529
358
358
  # for "overloaded_error" so we check for it explicitly
359
- if (
360
- isinstance(ex.body, dict)
361
- and isinstance(ex.body.get("error", {}), dict)
362
- and ex.body.get("error", {}).get("type", "") == "overloaded_error"
363
- ):
364
- return True
359
+ if isinstance(ex.body, dict):
360
+ if "overloaded_error" in str(ex.body):
361
+ return True
365
362
 
366
363
  # standard http status code checking
367
364
  return is_retryable_http_status(ex.status_code)
@@ -1,3 +1,4 @@
1
+ import functools
1
2
  import json
2
3
  import os
3
4
  from copy import copy
@@ -151,7 +152,7 @@ class AzureAIAPI(ModelAPI):
151
152
 
152
153
  # prepare request
153
154
  request = dict(
154
- messages=await chat_request_messages(input, handler),
155
+ messages=await chat_request_messages(input, handler, self.is_mistral()),
155
156
  **self.completion_params(config),
156
157
  )
157
158
  # newer versions of vllm reject requests with tools or tool_choice if the
@@ -280,9 +281,77 @@ class AzureAIAPI(ModelAPI):
280
281
 
281
282
 
282
283
  async def chat_request_messages(
283
- messages: list[ChatMessage], handler: ChatAPIHandler | None
284
+ messages: list[ChatMessage],
285
+ handler: ChatAPIHandler | None,
286
+ is_mistral: bool = False,
287
+ ) -> list[ChatRequestMessage]:
288
+ chat_messages = [
289
+ await chat_request_message(message, handler) for message in messages
290
+ ]
291
+ if is_mistral:
292
+ chat_messages = functools.reduce(mistral_message_reducer, chat_messages, [])
293
+
294
+ return chat_messages
295
+
296
+
297
+ def mistral_message_reducer(
298
+ messages: list[ChatRequestMessage],
299
+ message: ChatRequestMessage,
284
300
  ) -> list[ChatRequestMessage]:
285
- return [await chat_request_message(message, handler) for message in messages]
301
+ """Fold any user messages found immediately after tool messages into the last tool message."""
302
+ if (
303
+ len(messages) > 0
304
+ and isinstance(messages[-1], ToolMessage)
305
+ and isinstance(message, UserMessage)
306
+ ):
307
+ messages[-1] = fold_user_message_into_tool_message(messages[-1], message)
308
+ else:
309
+ messages.append(message)
310
+
311
+ return messages
312
+
313
+
314
+ def fold_user_message_into_tool_message(
315
+ tool_message: ToolMessage,
316
+ user_message: UserMessage,
317
+ ) -> ToolMessage:
318
+ def convert_content_items_to_string(list_content: list[ContentItem]) -> str:
319
+ if not all(
320
+ isinstance(item, (TextContentItem | ImageContentItem))
321
+ for item in list_content
322
+ ):
323
+ raise TypeError(
324
+ "Expected all items to be TextContentItem or ImageContentItem"
325
+ )
326
+
327
+ parts = []
328
+ for item in list_content:
329
+ if isinstance(item, TextContentItem):
330
+ parts.append(item.text)
331
+ elif isinstance(item, ImageContentItem):
332
+ parts.append(f"[Image: {item.image_url.url}]")
333
+ else:
334
+ raise ValueError("Unexpected content item type")
335
+ return "".join(parts)
336
+
337
+ def normalise_content(
338
+ content: str | list[ContentItem] | None,
339
+ ) -> str | None:
340
+ return (
341
+ None
342
+ if content is None
343
+ else convert_content_items_to_string(content)
344
+ if isinstance(content, list)
345
+ else content
346
+ )
347
+
348
+ tool_content = normalise_content(tool_message.content)
349
+ user_content = normalise_content(user_message.content)
350
+
351
+ return ToolMessage(
352
+ content=(tool_content or "") + (user_content or ""),
353
+ tool_call_id=tool_message.tool_call_id,
354
+ )
286
355
 
287
356
 
288
357
  async def chat_request_message(
@@ -42,6 +42,7 @@ from .._openai import (
42
42
  openai_media_filter,
43
43
  openai_should_retry,
44
44
  )
45
+ from .._openai_responses import is_native_tool_configured
45
46
  from .openai_o1 import generate_o1
46
47
  from .util import environment_prerequisite_error, model_base_url
47
48
 
@@ -241,7 +242,7 @@ class OpenAIAPI(ModelAPI):
241
242
  tools=tools,
242
243
  **self.completion_params(config, False),
243
244
  )
244
- elif self.responses_api:
245
+ elif self.responses_api or is_native_tool_configured(tools, config):
245
246
  return await generate_responses(
246
247
  client=self.client,
247
248
  http_hooks=self._http_hooks,
@@ -105,7 +105,7 @@ def vertex() -> type[ModelAPI]:
105
105
  def google() -> type[ModelAPI]:
106
106
  FEATURE = "Google API"
107
107
  PACKAGE = "google-genai"
108
- MIN_VERSION = "1.12.1"
108
+ MIN_VERSION = "1.16.1"
109
109
 
110
110
  # verify we have the package
111
111
  try:
@@ -7,7 +7,6 @@ from typing import (
7
7
  Protocol,
8
8
  Type,
9
9
  Union,
10
- cast,
11
10
  overload,
12
11
  runtime_checkable,
13
12
  )
@@ -356,7 +355,7 @@ def metric(
356
355
  )
357
356
  return metric
358
357
 
359
- return metric_register(cast(Callable[P, Metric], metric_wrapper), metric_name)
358
+ return metric_register(metric_wrapper, metric_name)
360
359
 
361
360
  # for decorators with an explicit name, one more wrapper for the name
362
361
  if isinstance(name, str):
@@ -290,7 +290,7 @@ class TaskState:
290
290
  return self._tools
291
291
 
292
292
  @tools.setter
293
- def tools(self, tools: list[Tool | ToolDef]) -> None:
293
+ def tools(self, tools: Sequence[Tool | ToolDef]) -> None:
294
294
  self._tools.clear()
295
295
  for tool in tools:
296
296
  self._tools.append(tool if isinstance(tool, Tool) else tool.as_tool())
@@ -353,7 +353,7 @@ class TaskState:
353
353
  def completed(self) -> bool:
354
354
  """Is the task completed.
355
355
 
356
- Additionally, checks message and token limits and raises if they are exceeded, and also checks for an operator interrupt of the sample.
356
+ Additionally, checks for an operator interrupt of the sample.
357
357
  """
358
358
  from inspect_ai.log._samples import set_active_sample_total_messages
359
359
 
inspect_ai/tool/_tool.py CHANGED
@@ -224,13 +224,15 @@ def tool(
224
224
  tool_parallel = parallel
225
225
  tool_viewer = viewer
226
226
  tool_model_input = model_input
227
+ tool_options: dict[str, object] | None = None
227
228
  if is_registry_object(tool):
228
- _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
229
- tool
229
+ _, _, reg_parallel, reg_viewer, reg_model_input, options = (
230
+ tool_registry_info(tool)
230
231
  )
231
232
  tool_parallel = parallel and reg_parallel
232
233
  tool_viewer = viewer or reg_viewer
233
234
  tool_model_input = model_input or reg_model_input
235
+ tool_options = options
234
236
 
235
237
  # tag the object
236
238
  registry_tag(
@@ -247,6 +249,7 @@ def tool(
247
249
  tool_model_input
248
250
  or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
249
251
  ),
252
+ TOOL_OPTIONS: tool_options,
250
253
  },
251
254
  ),
252
255
  *args,
@@ -267,6 +270,7 @@ TOOL_PROMPT = "prompt"
267
270
  TOOL_PARALLEL = "parallel"
268
271
  TOOL_VIEWER = "viewer"
269
272
  TOOL_MODEL_INPUT = "model_input"
273
+ TOOL_OPTIONS = "options"
270
274
 
271
275
 
272
276
  TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
@@ -16,6 +16,7 @@ from inspect_ai._util.registry import (
16
16
 
17
17
  from ._tool import (
18
18
  TOOL_MODEL_INPUT,
19
+ TOOL_OPTIONS,
19
20
  TOOL_PARALLEL,
20
21
  TOOL_PROMPT,
21
22
  TOOL_VIEWER,
@@ -44,6 +45,7 @@ class ToolDef:
44
45
  parallel: bool | None = None,
45
46
  viewer: ToolCallViewer | None = None,
46
47
  model_input: ToolCallModelInput | None = None,
48
+ options: dict[str, object] | None = None,
47
49
  ) -> None:
48
50
  """Create a tool definition.
49
51
 
@@ -59,6 +61,8 @@ class ToolDef:
59
61
  viewer: Optional tool call viewer implementation.
60
62
  model_input: Optional function that determines how
61
63
  tool call results are played back as model input.
64
+ options: Optional property bag that can be used by the model provider
65
+ to customize the implementation of the tool
62
66
 
63
67
  Returns:
64
68
  Tool definition.
@@ -82,6 +86,7 @@ class ToolDef:
82
86
  self.parallel = parallel if parallel is not None else tdef.parallel
83
87
  self.viewer = viewer or tdef.viewer
84
88
  self.model_input = model_input or tdef.model_input
89
+ self.options = options or tdef.options
85
90
 
86
91
  # if its not a tool then extract tool_info if all fields have not
87
92
  # been provided explicitly
@@ -112,6 +117,7 @@ class ToolDef:
112
117
  self.parallel = parallel is not False
113
118
  self.viewer = viewer
114
119
  self.model_input = model_input
120
+ self.options = options
115
121
 
116
122
  tool: Callable[..., Any]
117
123
  """Callable to execute tool."""
@@ -134,13 +140,20 @@ class ToolDef:
134
140
  model_input: ToolCallModelInput | None
135
141
  """Custom model input presenter for tool calls."""
136
142
 
143
+ options: dict[str, object] | None = None
144
+ """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
145
+
137
146
  def as_tool(self) -> Tool:
138
147
  """Convert a ToolDef to a Tool."""
139
148
  tool = self.tool
140
149
  info = RegistryInfo(
141
150
  type="tool",
142
151
  name=self.name,
143
- metadata={TOOL_PARALLEL: self.parallel, TOOL_VIEWER: self.viewer},
152
+ metadata={
153
+ TOOL_PARALLEL: self.parallel,
154
+ TOOL_VIEWER: self.viewer,
155
+ TOOL_OPTIONS: self.options,
156
+ },
144
157
  )
145
158
  set_registry_info(tool, info)
146
159
  set_registry_params(tool, {})
@@ -189,11 +202,12 @@ class ToolDefFields(NamedTuple):
189
202
  parallel: bool
190
203
  viewer: ToolCallViewer | None
191
204
  model_input: ToolCallModelInput | None
205
+ options: dict[str, object] | None
192
206
 
193
207
 
194
208
  def tool_def_fields(tool: Tool) -> ToolDefFields:
195
209
  # get tool_info
196
- name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
210
+ name, prompt, parallel, viewer, model_input, options = tool_registry_info(tool)
197
211
  tool_info = parse_tool_info(tool)
198
212
 
199
213
  # if there is a description then append any prompt to the
@@ -234,19 +248,28 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
234
248
  parallel=parallel,
235
249
  viewer=viewer,
236
250
  model_input=model_input,
251
+ options=options,
237
252
  )
238
253
 
239
254
 
240
255
  def tool_registry_info(
241
256
  tool: Tool,
242
- ) -> tuple[str, str | None, bool, ToolCallViewer | None, ToolCallModelInput | None]:
257
+ ) -> tuple[
258
+ str,
259
+ str | None,
260
+ bool,
261
+ ToolCallViewer | None,
262
+ ToolCallModelInput | None,
263
+ dict[str, object] | None,
264
+ ]:
243
265
  info = registry_info(tool)
244
266
  name = info.name.split("/")[-1]
245
267
  prompt = info.metadata.get(TOOL_PROMPT, None)
246
268
  parallel = info.metadata.get(TOOL_PARALLEL, True)
247
269
  viewer = info.metadata.get(TOOL_VIEWER, None)
248
270
  model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
249
- return name, prompt, parallel, viewer, model_input
271
+ options = info.metadata.get(TOOL_OPTIONS, None)
272
+ return name, prompt, parallel, viewer, model_input, options
250
273
 
251
274
 
252
275
  def validate_tool_parameters(tool_name: str, parameters: dict[str, ToolParam]) -> None:
@@ -49,6 +49,8 @@ class ToolInfo(BaseModel):
49
49
  """Short description of tool."""
50
50
  parameters: ToolParams = Field(default_factory=ToolParams)
51
51
  """JSON Schema of tool parameters object."""
52
+ options: dict[str, object] | None = Field(default=None)
53
+ """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
52
54
 
53
55
 
54
56
  def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
@@ -4,6 +4,7 @@ from typing import Awaitable, Callable
4
4
  import anyio
5
5
  import httpx
6
6
  from bs4 import BeautifulSoup, NavigableString
7
+ from pydantic import BaseModel
7
8
  from tenacity import (
8
9
  retry,
9
10
  retry_if_exception,
@@ -23,6 +24,13 @@ Page Content: {text}
23
24
  """
24
25
 
25
26
 
27
+ class GoogleOptions(BaseModel):
28
+ num_results: int | None = None
29
+ max_provider_calls: int | None = None
30
+ max_connections: int | None = None
31
+ model: str | None = None
32
+
33
+
26
34
  class SearchLink:
27
35
  def __init__(self, url: str, snippet: str) -> None:
28
36
  self.url = url
@@ -42,11 +50,14 @@ def maybe_get_google_api_keys() -> tuple[str, str] | None:
42
50
 
43
51
 
44
52
  def google_search_provider(
45
- num_results: int,
46
- max_provider_calls: int,
47
- max_connections: int,
48
- model: str | None,
53
+ in_options: dict[str, object] | None = None,
49
54
  ) -> Callable[[str], Awaitable[str | None]]:
55
+ options = GoogleOptions.model_validate(in_options) if in_options else None
56
+ num_results = (options.num_results if options else None) or 3
57
+ max_provider_calls = (options.max_provider_calls if options else None) or 3
58
+ max_connections = (options.max_connections if options else None) or 10
59
+ model = options.model if options else None
60
+
50
61
  keys = maybe_get_google_api_keys()
51
62
  if not keys:
52
63
  raise PrerequisiteError(