inspect-ai 0.3.99__py3-none-any.whl → 0.3.101__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. inspect_ai/_cli/eval.py +2 -1
  2. inspect_ai/_display/core/config.py +11 -5
  3. inspect_ai/_display/core/panel.py +66 -2
  4. inspect_ai/_display/core/textual.py +5 -2
  5. inspect_ai/_display/plain/display.py +1 -0
  6. inspect_ai/_display/rich/display.py +2 -2
  7. inspect_ai/_display/textual/widgets/transcript.py +37 -9
  8. inspect_ai/_eval/eval.py +13 -1
  9. inspect_ai/_eval/evalset.py +3 -2
  10. inspect_ai/_eval/run.py +2 -0
  11. inspect_ai/_eval/score.py +2 -4
  12. inspect_ai/_eval/task/log.py +3 -1
  13. inspect_ai/_eval/task/run.py +59 -81
  14. inspect_ai/_util/content.py +11 -6
  15. inspect_ai/_util/interrupt.py +2 -2
  16. inspect_ai/_util/text.py +7 -0
  17. inspect_ai/_util/working.py +8 -37
  18. inspect_ai/_view/__init__.py +0 -0
  19. inspect_ai/_view/schema.py +2 -1
  20. inspect_ai/_view/www/CLAUDE.md +15 -0
  21. inspect_ai/_view/www/dist/assets/index.css +307 -171
  22. inspect_ai/_view/www/dist/assets/index.js +24733 -21641
  23. inspect_ai/_view/www/log-schema.json +77 -3
  24. inspect_ai/_view/www/package.json +9 -5
  25. inspect_ai/_view/www/src/@types/log.d.ts +9 -0
  26. inspect_ai/_view/www/src/app/App.tsx +1 -15
  27. inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
  28. inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
  29. inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
  30. inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
  31. inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
  32. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
  33. inspect_ai/_view/www/src/app/log-view/tabs/grouping.ts +4 -4
  34. inspect_ai/_view/www/src/app/routing/navigationHooks.ts +22 -25
  35. inspect_ai/_view/www/src/app/routing/url.ts +84 -4
  36. inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
  37. inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
  38. inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
  39. inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
  40. inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
  41. inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
  42. inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
  43. inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
  44. inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
  45. inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
  46. inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
  47. inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
  48. inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +17 -5
  49. inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
  50. inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
  51. inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
  52. inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
  53. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
  54. inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
  55. inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
  56. inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
  57. inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
  58. inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
  59. inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
  60. inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
  61. inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
  62. inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
  63. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
  64. inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
  65. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
  66. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
  67. inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
  68. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
  69. inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
  70. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
  71. inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
  72. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
  73. inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
  74. inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
  75. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
  76. inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
  77. inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
  78. inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
  79. inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
  80. inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
  81. inspect_ai/_view/www/src/app/types.ts +5 -1
  82. inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
  83. inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
  84. inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
  85. inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
  86. inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
  87. inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
  88. inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
  89. inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
  90. inspect_ai/_view/www/src/state/hooks.ts +52 -2
  91. inspect_ai/_view/www/src/state/logSlice.ts +4 -3
  92. inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
  93. inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
  94. inspect_ai/_view/www/src/state/scrolling.ts +152 -0
  95. inspect_ai/_view/www/src/utils/attachments.ts +7 -0
  96. inspect_ai/_view/www/src/utils/python.ts +18 -0
  97. inspect_ai/_view/www/yarn.lock +290 -33
  98. inspect_ai/agent/_react.py +12 -7
  99. inspect_ai/agent/_run.py +2 -3
  100. inspect_ai/analysis/beta/__init__.py +2 -0
  101. inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
  102. inspect_ai/dataset/_sources/csv.py +2 -6
  103. inspect_ai/dataset/_sources/hf.py +2 -6
  104. inspect_ai/dataset/_sources/json.py +2 -6
  105. inspect_ai/dataset/_util.py +23 -0
  106. inspect_ai/log/_log.py +1 -1
  107. inspect_ai/log/_recorders/eval.py +4 -3
  108. inspect_ai/log/_recorders/file.py +2 -9
  109. inspect_ai/log/_recorders/json.py +1 -0
  110. inspect_ai/log/_recorders/recorder.py +1 -0
  111. inspect_ai/log/_transcript.py +1 -1
  112. inspect_ai/model/_call_tools.py +6 -2
  113. inspect_ai/model/_openai.py +1 -1
  114. inspect_ai/model/_openai_responses.py +85 -41
  115. inspect_ai/model/_openai_web_search.py +38 -0
  116. inspect_ai/model/_providers/azureai.py +72 -3
  117. inspect_ai/model/_providers/openai.py +4 -1
  118. inspect_ai/model/_providers/openai_responses.py +5 -1
  119. inspect_ai/scorer/_metric.py +1 -2
  120. inspect_ai/scorer/_reducer/reducer.py +1 -1
  121. inspect_ai/solver/_task_state.py +2 -2
  122. inspect_ai/tool/_tool.py +6 -2
  123. inspect_ai/tool/_tool_def.py +27 -4
  124. inspect_ai/tool/_tool_info.py +2 -0
  125. inspect_ai/tool/_tools/_web_search/_google.py +43 -15
  126. inspect_ai/tool/_tools/_web_search/_tavily.py +46 -13
  127. inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
  128. inspect_ai/util/__init__.py +4 -0
  129. inspect_ai/util/_json.py +3 -0
  130. inspect_ai/util/_limit.py +230 -20
  131. inspect_ai/util/_sandbox/docker/compose.py +20 -11
  132. inspect_ai/util/_span.py +1 -1
  133. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/METADATA +3 -3
  134. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/RECORD +138 -124
  135. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/WHEEL +1 -1
  136. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/entry_points.txt +0 -0
  137. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/licenses/LICENSE +0 -0
  138. {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/top_level.txt +0 -0
inspect_ai/tool/_tool.py CHANGED
@@ -224,13 +224,15 @@ def tool(
224
224
  tool_parallel = parallel
225
225
  tool_viewer = viewer
226
226
  tool_model_input = model_input
227
+ tool_options: dict[str, object] | None = None
227
228
  if is_registry_object(tool):
228
- _, _, reg_parallel, reg_viewer, reg_model_input = tool_registry_info(
229
- tool
229
+ _, _, reg_parallel, reg_viewer, reg_model_input, options = (
230
+ tool_registry_info(tool)
230
231
  )
231
232
  tool_parallel = parallel and reg_parallel
232
233
  tool_viewer = viewer or reg_viewer
233
234
  tool_model_input = model_input or reg_model_input
235
+ tool_options = options
234
236
 
235
237
  # tag the object
236
238
  registry_tag(
@@ -247,6 +249,7 @@ def tool(
247
249
  tool_model_input
248
250
  or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
249
251
  ),
252
+ TOOL_OPTIONS: tool_options,
250
253
  },
251
254
  ),
252
255
  *args,
@@ -267,6 +270,7 @@ TOOL_PROMPT = "prompt"
267
270
  TOOL_PARALLEL = "parallel"
268
271
  TOOL_VIEWER = "viewer"
269
272
  TOOL_MODEL_INPUT = "model_input"
273
+ TOOL_OPTIONS = "options"
270
274
 
271
275
 
272
276
  TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
@@ -16,6 +16,7 @@ from inspect_ai._util.registry import (
16
16
 
17
17
  from ._tool import (
18
18
  TOOL_MODEL_INPUT,
19
+ TOOL_OPTIONS,
19
20
  TOOL_PARALLEL,
20
21
  TOOL_PROMPT,
21
22
  TOOL_VIEWER,
@@ -44,6 +45,7 @@ class ToolDef:
44
45
  parallel: bool | None = None,
45
46
  viewer: ToolCallViewer | None = None,
46
47
  model_input: ToolCallModelInput | None = None,
48
+ options: dict[str, object] | None = None,
47
49
  ) -> None:
48
50
  """Create a tool definition.
49
51
 
@@ -59,6 +61,8 @@ class ToolDef:
59
61
  viewer: Optional tool call viewer implementation.
60
62
  model_input: Optional function that determines how
61
63
  tool call results are played back as model input.
64
+ options: Optional property bag that can be used by the model provider
65
+ to customize the implementation of the tool
62
66
 
63
67
  Returns:
64
68
  Tool definition.
@@ -82,6 +86,7 @@ class ToolDef:
82
86
  self.parallel = parallel if parallel is not None else tdef.parallel
83
87
  self.viewer = viewer or tdef.viewer
84
88
  self.model_input = model_input or tdef.model_input
89
+ self.options = options or tdef.options
85
90
 
86
91
  # if its not a tool then extract tool_info if all fields have not
87
92
  # been provided explicitly
@@ -112,6 +117,7 @@ class ToolDef:
112
117
  self.parallel = parallel is not False
113
118
  self.viewer = viewer
114
119
  self.model_input = model_input
120
+ self.options = options
115
121
 
116
122
  tool: Callable[..., Any]
117
123
  """Callable to execute tool."""
@@ -134,13 +140,20 @@ class ToolDef:
134
140
  model_input: ToolCallModelInput | None
135
141
  """Custom model input presenter for tool calls."""
136
142
 
143
+ options: dict[str, object] | None = None
144
+ """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
145
+
137
146
  def as_tool(self) -> Tool:
138
147
  """Convert a ToolDef to a Tool."""
139
148
  tool = self.tool
140
149
  info = RegistryInfo(
141
150
  type="tool",
142
151
  name=self.name,
143
- metadata={TOOL_PARALLEL: self.parallel, TOOL_VIEWER: self.viewer},
152
+ metadata={
153
+ TOOL_PARALLEL: self.parallel,
154
+ TOOL_VIEWER: self.viewer,
155
+ TOOL_OPTIONS: self.options,
156
+ },
144
157
  )
145
158
  set_registry_info(tool, info)
146
159
  set_registry_params(tool, {})
@@ -189,11 +202,12 @@ class ToolDefFields(NamedTuple):
189
202
  parallel: bool
190
203
  viewer: ToolCallViewer | None
191
204
  model_input: ToolCallModelInput | None
205
+ options: dict[str, object] | None
192
206
 
193
207
 
194
208
  def tool_def_fields(tool: Tool) -> ToolDefFields:
195
209
  # get tool_info
196
- name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
210
+ name, prompt, parallel, viewer, model_input, options = tool_registry_info(tool)
197
211
  tool_info = parse_tool_info(tool)
198
212
 
199
213
  # if there is a description then append any prompt to the
@@ -234,19 +248,28 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
234
248
  parallel=parallel,
235
249
  viewer=viewer,
236
250
  model_input=model_input,
251
+ options=options,
237
252
  )
238
253
 
239
254
 
240
255
  def tool_registry_info(
241
256
  tool: Tool,
242
- ) -> tuple[str, str | None, bool, ToolCallViewer | None, ToolCallModelInput | None]:
257
+ ) -> tuple[
258
+ str,
259
+ str | None,
260
+ bool,
261
+ ToolCallViewer | None,
262
+ ToolCallModelInput | None,
263
+ dict[str, object] | None,
264
+ ]:
243
265
  info = registry_info(tool)
244
266
  name = info.name.split("/")[-1]
245
267
  prompt = info.metadata.get(TOOL_PROMPT, None)
246
268
  parallel = info.metadata.get(TOOL_PARALLEL, True)
247
269
  viewer = info.metadata.get(TOOL_VIEWER, None)
248
270
  model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
249
- return name, prompt, parallel, viewer, model_input
271
+ options = info.metadata.get(TOOL_OPTIONS, None)
272
+ return name, prompt, parallel, viewer, model_input, options
250
273
 
251
274
 
252
275
  def validate_tool_parameters(tool_name: str, parameters: dict[str, ToolParam]) -> None:
@@ -49,6 +49,8 @@ class ToolInfo(BaseModel):
49
49
  """Short description of tool."""
50
50
  parameters: ToolParams = Field(default_factory=ToolParams)
51
51
  """JSON Schema of tool parameters object."""
52
+ options: dict[str, object] | None = Field(default=None)
53
+ """Optional property bag that can be used by the model provider to customize the implementation of the tool"""
52
54
 
53
55
 
54
56
  def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
@@ -4,6 +4,7 @@ from typing import Awaitable, Callable
4
4
  import anyio
5
5
  import httpx
6
6
  from bs4 import BeautifulSoup, NavigableString
7
+ from pydantic import BaseModel
7
8
  from tenacity import (
8
9
  retry,
9
10
  retry_if_exception,
@@ -23,10 +24,18 @@ Page Content: {text}
23
24
  """
24
25
 
25
26
 
27
+ class GoogleOptions(BaseModel):
28
+ num_results: int | None = None
29
+ max_provider_calls: int | None = None
30
+ max_connections: int | None = None
31
+ model: str | None = None
32
+
33
+
26
34
  class SearchLink:
27
- def __init__(self, url: str, snippet: str) -> None:
35
+ def __init__(self, url: str, snippet: str, title: str) -> None:
28
36
  self.url = url
29
37
  self.snippet = snippet
38
+ self.title = title
30
39
 
31
40
 
32
41
  def maybe_get_google_api_keys() -> tuple[str, str] | None:
@@ -42,11 +51,14 @@ def maybe_get_google_api_keys() -> tuple[str, str] | None:
42
51
 
43
52
 
44
53
  def google_search_provider(
45
- num_results: int,
46
- max_provider_calls: int,
47
- max_connections: int,
48
- model: str | None,
54
+ in_options: dict[str, object] | None = None,
49
55
  ) -> Callable[[str], Awaitable[str | None]]:
56
+ options = GoogleOptions.model_validate(in_options) if in_options else None
57
+ num_results = (options.num_results if options else None) or 3
58
+ max_provider_calls = (options.max_provider_calls if options else None) or 3
59
+ max_connections = (options.max_connections if options else None) or 10
60
+ model = options.model if options else None
61
+
50
62
  keys = maybe_get_google_api_keys()
51
63
  if not keys:
52
64
  raise PrerequisiteError(
@@ -60,8 +72,7 @@ def google_search_provider(
60
72
  async def search(query: str) -> str | None:
61
73
  # limit number of concurrent searches
62
74
  page_contents: list[str] = []
63
- urls: list[str] = []
64
- snippets: list[str] = []
75
+ processed_links: list[SearchLink] = []
65
76
  search_calls = 0
66
77
 
67
78
  # Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
@@ -76,8 +87,7 @@ def google_search_provider(
76
87
  page = await page_if_relevant(link.url, query, model, client)
77
88
  if page:
78
89
  page_contents.append(page)
79
- urls.append(link.url)
80
- snippets.append(link.snippet)
90
+ processed_links.append(link)
81
91
  # exceptions fetching pages are very common!
82
92
  except Exception:
83
93
  pass
@@ -87,8 +97,18 @@ def google_search_provider(
87
97
 
88
98
  search_calls += 1
89
99
 
90
- all_page_contents = "\n\n".join(page_contents)
91
- return None if all_page_contents == "" else all_page_contents
100
+ return (
101
+ "\n\n".join(
102
+ "[{title}]({url}):\n{page_content}".format(
103
+ title=link.title, url=link.url, page_content=page_content
104
+ )
105
+ for link, page_content in zip(
106
+ processed_links, page_contents, strict=True
107
+ )
108
+ )
109
+ if processed_links
110
+ else None
111
+ )
92
112
 
93
113
  async def _search(query: str, start_idx: int) -> list[SearchLink]:
94
114
  # List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
@@ -110,13 +130,21 @@ def google_search_provider(
110
130
  before_sleep=log_httpx_retry_attempt(search_url),
111
131
  )
112
132
  async def execute_search() -> httpx.Response:
133
+ # See https://developers.google.com/custom-search/v1/reference/rest/v1/Search
113
134
  return await client.get(search_url)
114
135
 
115
136
  result = await execute_search()
116
137
  data = result.json()
117
138
 
118
139
  if "items" in data:
119
- return [SearchLink(item["link"], item["snippet"]) for item in data["items"]]
140
+ return [
141
+ SearchLink(
142
+ url=item["link"],
143
+ snippet=item.get("snippet", ""), # sometimes not present
144
+ title=item["title"],
145
+ )
146
+ for item in data["items"]
147
+ ]
120
148
  else:
121
149
  return []
122
150
 
@@ -124,13 +152,13 @@ def google_search_provider(
124
152
 
125
153
 
126
154
  async def page_if_relevant(
127
- link: str, query: str, relevance_model: str | None, client: httpx.AsyncClient
155
+ url: str, query: str, relevance_model: str | None, client: httpx.AsyncClient
128
156
  ) -> str | None:
129
157
  """
130
158
  Use parser model to determine if a web page contents is relevant to a query.
131
159
 
132
160
  Args:
133
- link (str): Web page link.
161
+ url (str): Web page url.
134
162
  query (str): Search query.
135
163
  relevance_model (Model): Model used to parse web pages for relevance.
136
164
  client: (httpx.Client): HTTP client to use to fetch the page
@@ -145,7 +173,7 @@ async def page_if_relevant(
145
173
 
146
174
  # retrieve document
147
175
  try:
148
- response = await client.get(link)
176
+ response = await client.get(url)
149
177
  response.raise_for_status()
150
178
  except httpx.HTTPError as exc:
151
179
  raise Exception(f"HTTP error occurred: {exc}")
@@ -1,5 +1,5 @@
1
1
  import os
2
- from typing import Awaitable, Callable
2
+ from typing import Awaitable, Callable, Literal
3
3
 
4
4
  import httpx
5
5
  from pydantic import BaseModel, Field
@@ -16,6 +16,25 @@ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
16
16
  from inspect_ai.util._concurrency import concurrency
17
17
 
18
18
 
19
+ class TavilyOptions(BaseModel):
20
+ topic: Literal["general", "news"] | None = None
21
+ search_depth: Literal["basic", "advanced"] | None = None
22
+ chunks_per_source: Literal[1, 2, 3] | None = None
23
+ max_results: int | None = None
24
+ time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | None = (
25
+ None
26
+ )
27
+ days: int | None = None
28
+ include_answer: bool | Literal["basic", "advanced"] | None = None
29
+ include_raw_content: bool | None = None
30
+ include_images: bool | None = None
31
+ include_image_descriptions: bool | None = None
32
+ include_domains: list[str] | None = None
33
+ exclude_domains: list[str] | None = None
34
+ # max_connections is not a Tavily API option, but an inspect option
35
+ max_connections: int | None = None
36
+
37
+
19
38
  class TavilySearchResult(BaseModel):
20
39
  title: str
21
40
  url: str
@@ -32,32 +51,37 @@ class TavilySearchResponse(BaseModel):
32
51
 
33
52
 
34
53
  def tavily_search_provider(
35
- num_results: int, max_connections: int
54
+ in_options: dict[str, object] | None = None,
36
55
  ) -> Callable[[str], Awaitable[str | None]]:
56
+ options = TavilyOptions.model_validate(in_options) if in_options else None
57
+ # Separate max_connections (which is an inspect thing) from the rest of the
58
+ # options which will be passed in the request body
59
+ max_connections = (options.max_connections if options else None) or 10
60
+ api_options = (
61
+ options.model_dump(exclude={"max_connections"}, exclude_none=True)
62
+ if options
63
+ else {}
64
+ )
65
+ if not api_options.get("include_answer", False):
66
+ api_options["include_answer"] = True
67
+
37
68
  tavily_api_key = os.environ.get("TAVILY_API_KEY", None)
38
69
  if not tavily_api_key:
39
70
  raise PrerequisiteError(
40
71
  "TAVILY_API_KEY not set in the environment. Please ensure ths variable is defined to use Tavily with the web_search tool.\n\nLearn more about the Tavily web search provider at https://inspect.aisi.org.uk/tools.html#tavily-provider"
41
72
  )
42
- if num_results > 20:
43
- raise PrerequisiteError(
44
- "The Tavily search provider is limited to 20 results per query."
45
- )
46
73
 
47
74
  # Create the client within the provider
48
75
  client = httpx.AsyncClient(timeout=30)
49
76
 
50
77
  async def search(query: str) -> str | None:
78
+ # See https://docs.tavily.com/documentation/api-reference/endpoint/search
51
79
  search_url = "https://api.tavily.com/search"
52
80
  headers = {
53
81
  "Authorization": f"Bearer {tavily_api_key}",
54
82
  }
55
- body = {
56
- "query": query,
57
- "max_results": 10, # num_results,
58
- # "search_depth": "advanced",
59
- "include_answer": "advanced",
60
- }
83
+
84
+ body = {"query": query, **api_options}
61
85
 
62
86
  # retry up to 5 times over a period of up to 1 minute
63
87
  @retry(
@@ -72,6 +96,15 @@ def tavily_search_provider(
72
96
  return response
73
97
 
74
98
  async with concurrency("tavily_web_search", max_connections):
75
- return TavilySearchResponse.model_validate((await _search()).json()).answer
99
+ tavily_search_response = TavilySearchResponse.model_validate(
100
+ (await _search()).json()
101
+ )
102
+ results_str = "\n\n".join(
103
+ [
104
+ f"[{result.title}]({result.url}):\n{result.content}"
105
+ for result in tavily_search_response.results
106
+ ]
107
+ )
108
+ return f"Answer: {tavily_search_response.answer}\n\n{results_str}"
76
109
 
77
110
  return search
@@ -1,68 +1,123 @@
1
- from typing import Literal
1
+ from typing import (
2
+ Any,
3
+ Awaitable,
4
+ Callable,
5
+ Literal,
6
+ TypeAlias,
7
+ TypedDict,
8
+ get_args,
9
+ )
10
+
11
+ from typing_extensions import Unpack
2
12
 
3
13
  from inspect_ai._util.deprecation import deprecation_warning
14
+ from inspect_ai.tool._tool_def import ToolDef
4
15
 
5
16
  from ..._tool import Tool, ToolResult, tool
6
- from ._google import google_search_provider, maybe_get_google_api_keys
7
- from ._tavily import tavily_search_provider
17
+ from ._google import GoogleOptions, google_search_provider
18
+ from ._tavily import TavilyOptions, tavily_search_provider
19
+
20
+ Provider: TypeAlias = Literal["openai", "tavily", "google"] # , "gemini", "anthropic"
21
+ valid_providers = set(get_args(Provider))
22
+
23
+
24
+ # It would have been nice if the values below were TypedDicts. The problem is
25
+ # that if the caller creates a literal dict variable (rather than passing the
26
+ # dict inline), the type checker will erase the type of the literal to something
27
+ # that doesn't conform the the required TypedDict when passed. This is lame, but
28
+ # we'll do runtime validation instead.
29
+ #
30
+ # If the caller uses this dict form and uses a value of `None`, it means that
31
+ # they want to use that provider and to use the default options.
32
+ class Providers(TypedDict, total=False):
33
+ google: dict[str, Any] | None
34
+ tavily: dict[str, Any] | None
35
+ openai: dict[str, Any] | None
36
+
37
+
38
+ class WebSearchDeprecatedArgs(TypedDict, total=False):
39
+ provider: Literal["tavily", "google"] | None
40
+ num_results: int | None
41
+ max_provider_calls: int | None
42
+ max_connections: int | None
43
+ model: str | None
8
44
 
9
45
 
10
46
  @tool
11
47
  def web_search(
12
- provider: Literal["tavily", "google"] | None = None,
13
- num_results: int = 3,
14
- max_provider_calls: int = 3,
15
- max_connections: int = 10,
16
- model: str | None = None,
48
+ providers: Provider | Providers | list[Provider | Providers] | None = None,
49
+ **deprecated: Unpack[WebSearchDeprecatedArgs],
17
50
  ) -> Tool:
18
51
  """Web search tool.
19
52
 
20
- A tool that can be registered for use by models to search the web. Use
21
- the `use_tools()` solver to make the tool available (e.g.
22
- `use_tools(web_search(provider="tavily"))`))
53
+ Web searches are executed using a provider. Providers are split
54
+ into two categories:
23
55
 
24
- A web search is conducted using the specified provider.
25
- - When using Tavily, all logic for relevance and summarization is handled by
26
- the Tavily API.
27
- - When using Google, the results are parsed for relevance using the specified
28
- model, and the top 'num_results' relevant pages are returned.
56
+ - Internal providers: "openai" - these use the model's built-in search
57
+ capability and do not require separate API keys. These work only for
58
+ their respective model provider (e.g. the "openai" search provider
59
+ works only for `openai/*` models).
60
+
61
+ - External providers: "tavily" and "google". These are external services
62
+ that work with any model and require separate accounts and API keys.
63
+
64
+ Internal providers will be prioritized if running on the corresponding model
65
+ (e.g., "openai" provider will be used when running on `openai` models). If an
66
+ internal provider is specified but the evaluation is run with a different
67
+ model, a fallback external provider must also be specified.
29
68
 
30
69
  See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
31
70
 
32
71
  Args:
33
- provider: Search provider to use:
34
- - "tavily": Uses Tavily's Research API.
35
- - "google": Uses Google Custom Search.
36
- Note: The `| None` type is only for backwards compatibility. Passing
37
- `None` is deprecated.
38
- num_results: The number of search result pages used to provide information
39
- back to the model.
40
- max_provider_calls: Maximum number of search calls to make to the search
41
- provider.
42
- max_connections: Maximum number of concurrent connections to API endpoint
43
- of search provider.
44
- model: Model used to parse web pages for relevance - used only by the
45
- `google` provider.
72
+ providers: Configuration for the search providers to use. Currently supported
73
+ providers are "openai", "tavily", and "google". The `providers` parameter
74
+ supports several formats based on either a `str` specifying a provider or
75
+ a `dict` whose keys are the provider names and whose values are the
76
+ provider-specific options. A single value or a list of these can be passed.
77
+ This arg is optional just for backwards compatibility. New code should
78
+ always provide this argument.
79
+
80
+ Single provider:
81
+ ```
82
+ web_search("tavily")
83
+ web_search({"tavily": {"max_results": 5}}) # Tavily-specific options
84
+ ```
85
+
86
+ Multiple providers:
87
+ ```
88
+ # "openai" used for OpenAI models, "tavily" as fallback
89
+ web_search(["openai", "tavily"])
90
+
91
+ # The None value means to use the provider with default options
92
+ web_search({"openai": None, "tavily": {"max_results": 5}})
93
+ ```
94
+
95
+ Mixed format:
96
+ ```
97
+ web_search(["openai", {"tavily": {"max_results": 5}}])
98
+ ```
99
+
100
+ When specified in the `dict` format, the `None` value for a provider means
101
+ to use the provider with default options.
102
+
103
+ Provider-specific options:
104
+ - openai: Supports OpenAI's web search parameters.
105
+ See https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses
106
+
107
+ - tavily: Supports options like `max_results`, `search_depth`, etc.
108
+ See https://docs.tavily.com/documentation/api-reference/endpoint/search
109
+
110
+ - google: Supports options like `num_results`, `max_provider_calls`,
111
+ `max_connections`, and `model`
112
+
113
+ **deprecated: Deprecated arguments.
46
114
 
47
115
  Returns:
48
116
  A tool that can be registered for use by models to search the web.
49
117
  """
50
- if provider is None:
51
- if maybe_get_google_api_keys():
52
- deprecation_warning(
53
- "The `google` `web_search` provider was inferred based on the presence of environment variables. Please specify the provider explicitly to avoid this warning."
54
- )
55
- provider = "google"
56
- else:
57
- raise ValueError(
58
- "Omitting `provider` is no longer supported. Please specify the `web_search` provider explicitly to avoid this error."
59
- )
118
+ normalized_providers = _normalize_config(providers, **deprecated)
60
119
 
61
- search_provider = (
62
- google_search_provider(num_results, max_provider_calls, max_connections, model)
63
- if provider == "google"
64
- else tavily_search_provider(num_results, max_connections)
65
- )
120
+ search_provider: Callable[[str], Awaitable[str | None]] | None = None
66
121
 
67
122
  async def execute(query: str) -> ToolResult:
68
123
  """
@@ -71,6 +126,9 @@ def web_search(
71
126
  Args:
72
127
  query (str): Search query.
73
128
  """
129
+ nonlocal search_provider
130
+ if not search_provider:
131
+ search_provider = _create_external_provider(normalized_providers)
74
132
  search_result = await search_provider(query)
75
133
 
76
134
  return (
@@ -82,4 +140,115 @@ def web_search(
82
140
  else ("I'm sorry, I couldn't find any relevant information on the web.")
83
141
  )
84
142
 
85
- return execute
143
+ return ToolDef(
144
+ execute, name="web_search", options=dict(normalized_providers)
145
+ ).as_tool()
146
+
147
+
148
+ def _normalize_config(
149
+ providers: Provider | Providers | list[Provider | Providers] | None,
150
+ **deprecated: Unpack[WebSearchDeprecatedArgs],
151
+ ) -> Providers:
152
+ """
153
+ Deal with breaking changes in the web_search parameter list.
154
+
155
+ This function adapts (hopefully) all of the old variants of how the tool
156
+ factory may have been called and converts them to the new config format.
157
+ """
158
+ # Cases to handle:
159
+ # 1. Both deprecated_provider and providers are set
160
+ # ValueError
161
+ # 2. Neither deprecated_provider nor providers is set
162
+ # act as if they passed provider="google"
163
+ # 3. Only providers is set
164
+ # if any of the other deprecated parameters is set, then ValueError
165
+ # else Happy path
166
+ # 4. Only deprecated_provider is set
167
+ # convert to new config format - including processing old other params
168
+
169
+ deprecated_provider = deprecated.get("provider", None)
170
+ # Case 1.
171
+ if deprecated_provider and providers:
172
+ raise ValueError("`provider` is deprecated. Please only specify `providers`.")
173
+
174
+ # Case 2.
175
+ if providers is None and deprecated_provider is None:
176
+ deprecated_provider = "google"
177
+
178
+ num_results = deprecated.get("num_results", None)
179
+ max_provider_calls = deprecated.get("max_provider_calls", None)
180
+ max_connections = deprecated.get("max_connections", None)
181
+ model = deprecated.get("model", None)
182
+
183
+ # Getting here means that we have either a providers or a deprecated_provider
184
+ if deprecated_provider:
185
+ return _get_config_via_back_compat(
186
+ deprecated_provider,
187
+ num_results=num_results,
188
+ max_provider_calls=max_provider_calls,
189
+ max_connections=max_connections,
190
+ model=model,
191
+ )
192
+
193
+ assert providers, "providers should not be None here"
194
+ normalized: Providers = {}
195
+ for entry in providers if isinstance(providers, list) else [providers]:
196
+ if isinstance(entry, str):
197
+ if entry not in valid_providers:
198
+ raise ValueError(f"Invalid provider: '{entry}'")
199
+ normalized[entry] = None # type: ignore
200
+ else:
201
+ for key, value in entry.items():
202
+ if key not in valid_providers:
203
+ raise ValueError(f"Invalid provider: '{key}'")
204
+ normalized[key] = value # type: ignore
205
+ return normalized
206
+
207
+
208
+ def _get_config_via_back_compat(
209
+ provider: Literal["tavily", "google"],
210
+ num_results: int | None,
211
+ max_provider_calls: int | None,
212
+ max_connections: int | None,
213
+ model: str | None,
214
+ ) -> Providers:
215
+ if (
216
+ num_results is None
217
+ and max_provider_calls is None
218
+ and max_connections is None
219
+ and model is None
220
+ ):
221
+ return {"google": None} if provider == "google" else {"tavily": None}
222
+
223
+ # If we get here, we have at least one old school parameter
224
+ deprecation_warning(
225
+ "The `num_results`, `max_provider_calls`, `max_connections`, and `model` parameters are deprecated. Please use the `config` parameter instead."
226
+ )
227
+
228
+ if provider == "google":
229
+ return {
230
+ "google": GoogleOptions(
231
+ num_results=num_results,
232
+ max_provider_calls=max_provider_calls,
233
+ max_connections=max_connections,
234
+ model=model,
235
+ ).model_dump(exclude_none=True)
236
+ }
237
+ else:
238
+ return {
239
+ "tavily": TavilyOptions(
240
+ max_results=num_results, max_connections=max_connections
241
+ ).model_dump(exclude_none=True)
242
+ }
243
+
244
+
245
+ def _create_external_provider(
246
+ providers: Providers,
247
+ ) -> Callable[[str], Awaitable[str | None]]:
248
+ if "tavily" in providers:
249
+ return tavily_search_provider(providers.get("tavily", None))
250
+
251
+ if "google" in providers:
252
+ return google_search_provider(providers.get("google", None))
253
+
254
+ raise ValueError("No valid provider found.")