inspect-ai 0.3.99__py3-none-any.whl → 0.3.101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +2 -1
- inspect_ai/_display/core/config.py +11 -5
- inspect_ai/_display/core/panel.py +66 -2
- inspect_ai/_display/core/textual.py +5 -2
- inspect_ai/_display/plain/display.py +1 -0
- inspect_ai/_display/rich/display.py +2 -2
- inspect_ai/_display/textual/widgets/transcript.py +37 -9
- inspect_ai/_eval/eval.py +13 -1
- inspect_ai/_eval/evalset.py +3 -2
- inspect_ai/_eval/run.py +2 -0
- inspect_ai/_eval/score.py +2 -4
- inspect_ai/_eval/task/log.py +3 -1
- inspect_ai/_eval/task/run.py +59 -81
- inspect_ai/_util/content.py +11 -6
- inspect_ai/_util/interrupt.py +2 -2
- inspect_ai/_util/text.py +7 -0
- inspect_ai/_util/working.py +8 -37
- inspect_ai/_view/__init__.py +0 -0
- inspect_ai/_view/schema.py +2 -1
- inspect_ai/_view/www/CLAUDE.md +15 -0
- inspect_ai/_view/www/dist/assets/index.css +307 -171
- inspect_ai/_view/www/dist/assets/index.js +24733 -21641
- inspect_ai/_view/www/log-schema.json +77 -3
- inspect_ai/_view/www/package.json +9 -5
- inspect_ai/_view/www/src/@types/log.d.ts +9 -0
- inspect_ai/_view/www/src/app/App.tsx +1 -15
- inspect_ai/_view/www/src/app/appearance/icons.ts +4 -1
- inspect_ai/_view/www/src/app/content/MetaDataGrid.tsx +24 -6
- inspect_ai/_view/www/src/app/content/MetadataGrid.module.css +0 -5
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +220 -205
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +2 -1
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +5 -0
- inspect_ai/_view/www/src/app/log-view/tabs/grouping.ts +4 -4
- inspect_ai/_view/www/src/app/routing/navigationHooks.ts +22 -25
- inspect_ai/_view/www/src/app/routing/url.ts +84 -4
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +0 -5
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +7 -0
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +24 -17
- inspect_ai/_view/www/src/app/samples/SampleSummaryView.module.css +1 -2
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +8 -6
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +0 -4
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/chat/MessageContent.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/MessageContents.tsx +2 -0
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +1 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -0
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +17 -5
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +3 -2
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +4 -5
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +1 -3
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +1 -2
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.module.css +42 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptPanel.tsx +77 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +27 -71
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +13 -3
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +27 -2
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +1 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +21 -22
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.module.css +45 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/OutlineRow.tsx +223 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.module.css +10 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/TranscriptOutline.tsx +258 -0
- inspect_ai/_view/www/src/app/samples/transcript/outline/tree-visitors.ts +187 -0
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventRenderers.tsx +8 -1
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +3 -4
- inspect_ai/_view/www/src/app/samples/transcript/transform/hooks.ts +78 -0
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +340 -135
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +3 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +2 -0
- inspect_ai/_view/www/src/app/types.ts +5 -1
- inspect_ai/_view/www/src/client/api/api-browser.ts +2 -2
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +6 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +1 -1
- inspect_ai/_view/www/src/components/PopOver.tsx +422 -0
- inspect_ai/_view/www/src/components/PulsingDots.module.css +9 -9
- inspect_ai/_view/www/src/components/PulsingDots.tsx +4 -1
- inspect_ai/_view/www/src/components/StickyScroll.tsx +183 -0
- inspect_ai/_view/www/src/components/TabSet.tsx +4 -0
- inspect_ai/_view/www/src/state/hooks.ts +52 -2
- inspect_ai/_view/www/src/state/logSlice.ts +4 -3
- inspect_ai/_view/www/src/state/samplePolling.ts +8 -0
- inspect_ai/_view/www/src/state/sampleSlice.ts +53 -9
- inspect_ai/_view/www/src/state/scrolling.ts +152 -0
- inspect_ai/_view/www/src/utils/attachments.ts +7 -0
- inspect_ai/_view/www/src/utils/python.ts +18 -0
- inspect_ai/_view/www/yarn.lock +290 -33
- inspect_ai/agent/_react.py +12 -7
- inspect_ai/agent/_run.py +2 -3
- inspect_ai/analysis/beta/__init__.py +2 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +19 -18
- inspect_ai/dataset/_sources/csv.py +2 -6
- inspect_ai/dataset/_sources/hf.py +2 -6
- inspect_ai/dataset/_sources/json.py +2 -6
- inspect_ai/dataset/_util.py +23 -0
- inspect_ai/log/_log.py +1 -1
- inspect_ai/log/_recorders/eval.py +4 -3
- inspect_ai/log/_recorders/file.py +2 -9
- inspect_ai/log/_recorders/json.py +1 -0
- inspect_ai/log/_recorders/recorder.py +1 -0
- inspect_ai/log/_transcript.py +1 -1
- inspect_ai/model/_call_tools.py +6 -2
- inspect_ai/model/_openai.py +1 -1
- inspect_ai/model/_openai_responses.py +85 -41
- inspect_ai/model/_openai_web_search.py +38 -0
- inspect_ai/model/_providers/azureai.py +72 -3
- inspect_ai/model/_providers/openai.py +4 -1
- inspect_ai/model/_providers/openai_responses.py +5 -1
- inspect_ai/scorer/_metric.py +1 -2
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/solver/_task_state.py +2 -2
- inspect_ai/tool/_tool.py +6 -2
- inspect_ai/tool/_tool_def.py +27 -4
- inspect_ai/tool/_tool_info.py +2 -0
- inspect_ai/tool/_tools/_web_search/_google.py +43 -15
- inspect_ai/tool/_tools/_web_search/_tavily.py +46 -13
- inspect_ai/tool/_tools/_web_search/_web_search.py +214 -45
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_json.py +3 -0
- inspect_ai/util/_limit.py +230 -20
- inspect_ai/util/_sandbox/docker/compose.py +20 -11
- inspect_ai/util/_span.py +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/METADATA +3 -3
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/RECORD +138 -124
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/WHEEL +1 -1
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.99.dist-info → inspect_ai-0.3.101.dist-info}/top_level.txt +0 -0
inspect_ai/tool/_tool.py
CHANGED
@@ -224,13 +224,15 @@ def tool(
|
|
224
224
|
tool_parallel = parallel
|
225
225
|
tool_viewer = viewer
|
226
226
|
tool_model_input = model_input
|
227
|
+
tool_options: dict[str, object] | None = None
|
227
228
|
if is_registry_object(tool):
|
228
|
-
_, _, reg_parallel, reg_viewer, reg_model_input =
|
229
|
-
tool
|
229
|
+
_, _, reg_parallel, reg_viewer, reg_model_input, options = (
|
230
|
+
tool_registry_info(tool)
|
230
231
|
)
|
231
232
|
tool_parallel = parallel and reg_parallel
|
232
233
|
tool_viewer = viewer or reg_viewer
|
233
234
|
tool_model_input = model_input or reg_model_input
|
235
|
+
tool_options = options
|
234
236
|
|
235
237
|
# tag the object
|
236
238
|
registry_tag(
|
@@ -247,6 +249,7 @@ def tool(
|
|
247
249
|
tool_model_input
|
248
250
|
or getattr(tool, TOOL_INIT_MODEL_INPUT, None)
|
249
251
|
),
|
252
|
+
TOOL_OPTIONS: tool_options,
|
250
253
|
},
|
251
254
|
),
|
252
255
|
*args,
|
@@ -267,6 +270,7 @@ TOOL_PROMPT = "prompt"
|
|
267
270
|
TOOL_PARALLEL = "parallel"
|
268
271
|
TOOL_VIEWER = "viewer"
|
269
272
|
TOOL_MODEL_INPUT = "model_input"
|
273
|
+
TOOL_OPTIONS = "options"
|
270
274
|
|
271
275
|
|
272
276
|
TOOL_INIT_MODEL_INPUT = "__TOOL_INIT_MODEL_INPUT__"
|
inspect_ai/tool/_tool_def.py
CHANGED
@@ -16,6 +16,7 @@ from inspect_ai._util.registry import (
|
|
16
16
|
|
17
17
|
from ._tool import (
|
18
18
|
TOOL_MODEL_INPUT,
|
19
|
+
TOOL_OPTIONS,
|
19
20
|
TOOL_PARALLEL,
|
20
21
|
TOOL_PROMPT,
|
21
22
|
TOOL_VIEWER,
|
@@ -44,6 +45,7 @@ class ToolDef:
|
|
44
45
|
parallel: bool | None = None,
|
45
46
|
viewer: ToolCallViewer | None = None,
|
46
47
|
model_input: ToolCallModelInput | None = None,
|
48
|
+
options: dict[str, object] | None = None,
|
47
49
|
) -> None:
|
48
50
|
"""Create a tool definition.
|
49
51
|
|
@@ -59,6 +61,8 @@ class ToolDef:
|
|
59
61
|
viewer: Optional tool call viewer implementation.
|
60
62
|
model_input: Optional function that determines how
|
61
63
|
tool call results are played back as model input.
|
64
|
+
options: Optional property bag that can be used by the model provider
|
65
|
+
to customize the implementation of the tool
|
62
66
|
|
63
67
|
Returns:
|
64
68
|
Tool definition.
|
@@ -82,6 +86,7 @@ class ToolDef:
|
|
82
86
|
self.parallel = parallel if parallel is not None else tdef.parallel
|
83
87
|
self.viewer = viewer or tdef.viewer
|
84
88
|
self.model_input = model_input or tdef.model_input
|
89
|
+
self.options = options or tdef.options
|
85
90
|
|
86
91
|
# if its not a tool then extract tool_info if all fields have not
|
87
92
|
# been provided explicitly
|
@@ -112,6 +117,7 @@ class ToolDef:
|
|
112
117
|
self.parallel = parallel is not False
|
113
118
|
self.viewer = viewer
|
114
119
|
self.model_input = model_input
|
120
|
+
self.options = options
|
115
121
|
|
116
122
|
tool: Callable[..., Any]
|
117
123
|
"""Callable to execute tool."""
|
@@ -134,13 +140,20 @@ class ToolDef:
|
|
134
140
|
model_input: ToolCallModelInput | None
|
135
141
|
"""Custom model input presenter for tool calls."""
|
136
142
|
|
143
|
+
options: dict[str, object] | None = None
|
144
|
+
"""Optional property bag that can be used by the model provider to customize the implementation of the tool"""
|
145
|
+
|
137
146
|
def as_tool(self) -> Tool:
|
138
147
|
"""Convert a ToolDef to a Tool."""
|
139
148
|
tool = self.tool
|
140
149
|
info = RegistryInfo(
|
141
150
|
type="tool",
|
142
151
|
name=self.name,
|
143
|
-
metadata={
|
152
|
+
metadata={
|
153
|
+
TOOL_PARALLEL: self.parallel,
|
154
|
+
TOOL_VIEWER: self.viewer,
|
155
|
+
TOOL_OPTIONS: self.options,
|
156
|
+
},
|
144
157
|
)
|
145
158
|
set_registry_info(tool, info)
|
146
159
|
set_registry_params(tool, {})
|
@@ -189,11 +202,12 @@ class ToolDefFields(NamedTuple):
|
|
189
202
|
parallel: bool
|
190
203
|
viewer: ToolCallViewer | None
|
191
204
|
model_input: ToolCallModelInput | None
|
205
|
+
options: dict[str, object] | None
|
192
206
|
|
193
207
|
|
194
208
|
def tool_def_fields(tool: Tool) -> ToolDefFields:
|
195
209
|
# get tool_info
|
196
|
-
name, prompt, parallel, viewer, model_input = tool_registry_info(tool)
|
210
|
+
name, prompt, parallel, viewer, model_input, options = tool_registry_info(tool)
|
197
211
|
tool_info = parse_tool_info(tool)
|
198
212
|
|
199
213
|
# if there is a description then append any prompt to the
|
@@ -234,19 +248,28 @@ def tool_def_fields(tool: Tool) -> ToolDefFields:
|
|
234
248
|
parallel=parallel,
|
235
249
|
viewer=viewer,
|
236
250
|
model_input=model_input,
|
251
|
+
options=options,
|
237
252
|
)
|
238
253
|
|
239
254
|
|
240
255
|
def tool_registry_info(
|
241
256
|
tool: Tool,
|
242
|
-
) -> tuple[
|
257
|
+
) -> tuple[
|
258
|
+
str,
|
259
|
+
str | None,
|
260
|
+
bool,
|
261
|
+
ToolCallViewer | None,
|
262
|
+
ToolCallModelInput | None,
|
263
|
+
dict[str, object] | None,
|
264
|
+
]:
|
243
265
|
info = registry_info(tool)
|
244
266
|
name = info.name.split("/")[-1]
|
245
267
|
prompt = info.metadata.get(TOOL_PROMPT, None)
|
246
268
|
parallel = info.metadata.get(TOOL_PARALLEL, True)
|
247
269
|
viewer = info.metadata.get(TOOL_VIEWER, None)
|
248
270
|
model_input = info.metadata.get(TOOL_MODEL_INPUT, None)
|
249
|
-
|
271
|
+
options = info.metadata.get(TOOL_OPTIONS, None)
|
272
|
+
return name, prompt, parallel, viewer, model_input, options
|
250
273
|
|
251
274
|
|
252
275
|
def validate_tool_parameters(tool_name: str, parameters: dict[str, ToolParam]) -> None:
|
inspect_ai/tool/_tool_info.py
CHANGED
@@ -49,6 +49,8 @@ class ToolInfo(BaseModel):
|
|
49
49
|
"""Short description of tool."""
|
50
50
|
parameters: ToolParams = Field(default_factory=ToolParams)
|
51
51
|
"""JSON Schema of tool parameters object."""
|
52
|
+
options: dict[str, object] | None = Field(default=None)
|
53
|
+
"""Optional property bag that can be used by the model provider to customize the implementation of the tool"""
|
52
54
|
|
53
55
|
|
54
56
|
def parse_tool_info(func: Callable[..., Any]) -> ToolInfo:
|
@@ -4,6 +4,7 @@ from typing import Awaitable, Callable
|
|
4
4
|
import anyio
|
5
5
|
import httpx
|
6
6
|
from bs4 import BeautifulSoup, NavigableString
|
7
|
+
from pydantic import BaseModel
|
7
8
|
from tenacity import (
|
8
9
|
retry,
|
9
10
|
retry_if_exception,
|
@@ -23,10 +24,18 @@ Page Content: {text}
|
|
23
24
|
"""
|
24
25
|
|
25
26
|
|
27
|
+
class GoogleOptions(BaseModel):
|
28
|
+
num_results: int | None = None
|
29
|
+
max_provider_calls: int | None = None
|
30
|
+
max_connections: int | None = None
|
31
|
+
model: str | None = None
|
32
|
+
|
33
|
+
|
26
34
|
class SearchLink:
|
27
|
-
def __init__(self, url: str, snippet: str) -> None:
|
35
|
+
def __init__(self, url: str, snippet: str, title: str) -> None:
|
28
36
|
self.url = url
|
29
37
|
self.snippet = snippet
|
38
|
+
self.title = title
|
30
39
|
|
31
40
|
|
32
41
|
def maybe_get_google_api_keys() -> tuple[str, str] | None:
|
@@ -42,11 +51,14 @@ def maybe_get_google_api_keys() -> tuple[str, str] | None:
|
|
42
51
|
|
43
52
|
|
44
53
|
def google_search_provider(
|
45
|
-
|
46
|
-
max_provider_calls: int,
|
47
|
-
max_connections: int,
|
48
|
-
model: str | None,
|
54
|
+
in_options: dict[str, object] | None = None,
|
49
55
|
) -> Callable[[str], Awaitable[str | None]]:
|
56
|
+
options = GoogleOptions.model_validate(in_options) if in_options else None
|
57
|
+
num_results = (options.num_results if options else None) or 3
|
58
|
+
max_provider_calls = (options.max_provider_calls if options else None) or 3
|
59
|
+
max_connections = (options.max_connections if options else None) or 10
|
60
|
+
model = options.model if options else None
|
61
|
+
|
50
62
|
keys = maybe_get_google_api_keys()
|
51
63
|
if not keys:
|
52
64
|
raise PrerequisiteError(
|
@@ -60,8 +72,7 @@ def google_search_provider(
|
|
60
72
|
async def search(query: str) -> str | None:
|
61
73
|
# limit number of concurrent searches
|
62
74
|
page_contents: list[str] = []
|
63
|
-
|
64
|
-
snippets: list[str] = []
|
75
|
+
processed_links: list[SearchLink] = []
|
65
76
|
search_calls = 0
|
66
77
|
|
67
78
|
# Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
|
@@ -76,8 +87,7 @@ def google_search_provider(
|
|
76
87
|
page = await page_if_relevant(link.url, query, model, client)
|
77
88
|
if page:
|
78
89
|
page_contents.append(page)
|
79
|
-
|
80
|
-
snippets.append(link.snippet)
|
90
|
+
processed_links.append(link)
|
81
91
|
# exceptions fetching pages are very common!
|
82
92
|
except Exception:
|
83
93
|
pass
|
@@ -87,8 +97,18 @@ def google_search_provider(
|
|
87
97
|
|
88
98
|
search_calls += 1
|
89
99
|
|
90
|
-
|
91
|
-
|
100
|
+
return (
|
101
|
+
"\n\n".join(
|
102
|
+
"[{title}]({url}):\n{page_content}".format(
|
103
|
+
title=link.title, url=link.url, page_content=page_content
|
104
|
+
)
|
105
|
+
for link, page_content in zip(
|
106
|
+
processed_links, page_contents, strict=True
|
107
|
+
)
|
108
|
+
)
|
109
|
+
if processed_links
|
110
|
+
else None
|
111
|
+
)
|
92
112
|
|
93
113
|
async def _search(query: str, start_idx: int) -> list[SearchLink]:
|
94
114
|
# List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
|
@@ -110,13 +130,21 @@ def google_search_provider(
|
|
110
130
|
before_sleep=log_httpx_retry_attempt(search_url),
|
111
131
|
)
|
112
132
|
async def execute_search() -> httpx.Response:
|
133
|
+
# See https://developers.google.com/custom-search/v1/reference/rest/v1/Search
|
113
134
|
return await client.get(search_url)
|
114
135
|
|
115
136
|
result = await execute_search()
|
116
137
|
data = result.json()
|
117
138
|
|
118
139
|
if "items" in data:
|
119
|
-
return [
|
140
|
+
return [
|
141
|
+
SearchLink(
|
142
|
+
url=item["link"],
|
143
|
+
snippet=item.get("snippet", ""), # sometimes not present
|
144
|
+
title=item["title"],
|
145
|
+
)
|
146
|
+
for item in data["items"]
|
147
|
+
]
|
120
148
|
else:
|
121
149
|
return []
|
122
150
|
|
@@ -124,13 +152,13 @@ def google_search_provider(
|
|
124
152
|
|
125
153
|
|
126
154
|
async def page_if_relevant(
|
127
|
-
|
155
|
+
url: str, query: str, relevance_model: str | None, client: httpx.AsyncClient
|
128
156
|
) -> str | None:
|
129
157
|
"""
|
130
158
|
Use parser model to determine if a web page contents is relevant to a query.
|
131
159
|
|
132
160
|
Args:
|
133
|
-
|
161
|
+
url (str): Web page url.
|
134
162
|
query (str): Search query.
|
135
163
|
relevance_model (Model): Model used to parse web pages for relevance.
|
136
164
|
client (httpx.AsyncClient): HTTP client to use to fetch the page
|
@@ -145,7 +173,7 @@ async def page_if_relevant(
|
|
145
173
|
|
146
174
|
# retrieve document
|
147
175
|
try:
|
148
|
-
response = await client.get(
|
176
|
+
response = await client.get(url)
|
149
177
|
response.raise_for_status()
|
150
178
|
except httpx.HTTPError as exc:
|
151
179
|
raise Exception(f"HTTP error occurred: {exc}")
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import os
|
2
|
-
from typing import Awaitable, Callable
|
2
|
+
from typing import Awaitable, Callable, Literal
|
3
3
|
|
4
4
|
import httpx
|
5
5
|
from pydantic import BaseModel, Field
|
@@ -16,6 +16,25 @@ from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt
|
|
16
16
|
from inspect_ai.util._concurrency import concurrency
|
17
17
|
|
18
18
|
|
19
|
+
class TavilyOptions(BaseModel):
|
20
|
+
topic: Literal["general", "news"] | None = None
|
21
|
+
search_depth: Literal["basic", "advanced"] | None = None
|
22
|
+
chunks_per_source: Literal[1, 2, 3] | None = None
|
23
|
+
max_results: int | None = None
|
24
|
+
time_range: Literal["day", "week", "month", "year", "d", "w", "m", "y"] | None = (
|
25
|
+
None
|
26
|
+
)
|
27
|
+
days: int | None = None
|
28
|
+
include_answer: bool | Literal["basic", "advanced"] | None = None
|
29
|
+
include_raw_content: bool | None = None
|
30
|
+
include_images: bool | None = None
|
31
|
+
include_image_descriptions: bool | None = None
|
32
|
+
include_domains: list[str] | None = None
|
33
|
+
exclude_domains: list[str] | None = None
|
34
|
+
# max_connections is not a Tavily API option, but an inspect option
|
35
|
+
max_connections: int | None = None
|
36
|
+
|
37
|
+
|
19
38
|
class TavilySearchResult(BaseModel):
|
20
39
|
title: str
|
21
40
|
url: str
|
@@ -32,32 +51,37 @@ class TavilySearchResponse(BaseModel):
|
|
32
51
|
|
33
52
|
|
34
53
|
def tavily_search_provider(
|
35
|
-
|
54
|
+
in_options: dict[str, object] | None = None,
|
36
55
|
) -> Callable[[str], Awaitable[str | None]]:
|
56
|
+
options = TavilyOptions.model_validate(in_options) if in_options else None
|
57
|
+
# Separate max_connections (which is an inspect thing) from the rest of the
|
58
|
+
# options which will be passed in the request body
|
59
|
+
max_connections = (options.max_connections if options else None) or 10
|
60
|
+
api_options = (
|
61
|
+
options.model_dump(exclude={"max_connections"}, exclude_none=True)
|
62
|
+
if options
|
63
|
+
else {}
|
64
|
+
)
|
65
|
+
if not api_options.get("include_answer", False):
|
66
|
+
api_options["include_answer"] = True
|
67
|
+
|
37
68
|
tavily_api_key = os.environ.get("TAVILY_API_KEY", None)
|
38
69
|
if not tavily_api_key:
|
39
70
|
raise PrerequisiteError(
|
40
71
|
"TAVILY_API_KEY not set in the environment. Please ensure ths variable is defined to use Tavily with the web_search tool.\n\nLearn more about the Tavily web search provider at https://inspect.aisi.org.uk/tools.html#tavily-provider"
|
41
72
|
)
|
42
|
-
if num_results > 20:
|
43
|
-
raise PrerequisiteError(
|
44
|
-
"The Tavily search provider is limited to 20 results per query."
|
45
|
-
)
|
46
73
|
|
47
74
|
# Create the client within the provider
|
48
75
|
client = httpx.AsyncClient(timeout=30)
|
49
76
|
|
50
77
|
async def search(query: str) -> str | None:
|
78
|
+
# See https://docs.tavily.com/documentation/api-reference/endpoint/search
|
51
79
|
search_url = "https://api.tavily.com/search"
|
52
80
|
headers = {
|
53
81
|
"Authorization": f"Bearer {tavily_api_key}",
|
54
82
|
}
|
55
|
-
|
56
|
-
|
57
|
-
"max_results": 10, # num_results,
|
58
|
-
# "search_depth": "advanced",
|
59
|
-
"include_answer": "advanced",
|
60
|
-
}
|
83
|
+
|
84
|
+
body = {"query": query, **api_options}
|
61
85
|
|
62
86
|
# retry up to 5 times over a period of up to 1 minute
|
63
87
|
@retry(
|
@@ -72,6 +96,15 @@ def tavily_search_provider(
|
|
72
96
|
return response
|
73
97
|
|
74
98
|
async with concurrency("tavily_web_search", max_connections):
|
75
|
-
|
99
|
+
tavily_search_response = TavilySearchResponse.model_validate(
|
100
|
+
(await _search()).json()
|
101
|
+
)
|
102
|
+
results_str = "\n\n".join(
|
103
|
+
[
|
104
|
+
f"[{result.title}]({result.url}):\n{result.content}"
|
105
|
+
for result in tavily_search_response.results
|
106
|
+
]
|
107
|
+
)
|
108
|
+
return f"Answer: {tavily_search_response.answer}\n\n{results_str}"
|
76
109
|
|
77
110
|
return search
|
@@ -1,68 +1,123 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import (
|
2
|
+
Any,
|
3
|
+
Awaitable,
|
4
|
+
Callable,
|
5
|
+
Literal,
|
6
|
+
TypeAlias,
|
7
|
+
TypedDict,
|
8
|
+
get_args,
|
9
|
+
)
|
10
|
+
|
11
|
+
from typing_extensions import Unpack
|
2
12
|
|
3
13
|
from inspect_ai._util.deprecation import deprecation_warning
|
14
|
+
from inspect_ai.tool._tool_def import ToolDef
|
4
15
|
|
5
16
|
from ..._tool import Tool, ToolResult, tool
|
6
|
-
from ._google import
|
7
|
-
from ._tavily import tavily_search_provider
|
17
|
+
from ._google import GoogleOptions, google_search_provider
|
18
|
+
from ._tavily import TavilyOptions, tavily_search_provider
|
19
|
+
|
20
|
+
Provider: TypeAlias = Literal["openai", "tavily", "google"] # , "gemini", "anthropic"
|
21
|
+
valid_providers = set(get_args(Provider))
|
22
|
+
|
23
|
+
|
24
|
+
# It would have been nice if the values below were TypedDicts. The problem is
|
25
|
+
# that if the caller creates a literal dict variable (rather than passing the
|
26
|
+
# dict inline), the type checker will erase the type of the literal to something
|
27
|
+
# that doesn't conform to the required TypedDict when passed. This is lame, but
|
28
|
+
# we'll do runtime validation instead.
|
29
|
+
#
|
30
|
+
# If the caller uses this dict form and uses a value of `None`, it means that
|
31
|
+
# they want to use that provider and to use the default options.
|
32
|
+
class Providers(TypedDict, total=False):
|
33
|
+
google: dict[str, Any] | None
|
34
|
+
tavily: dict[str, Any] | None
|
35
|
+
openai: dict[str, Any] | None
|
36
|
+
|
37
|
+
|
38
|
+
class WebSearchDeprecatedArgs(TypedDict, total=False):
|
39
|
+
provider: Literal["tavily", "google"] | None
|
40
|
+
num_results: int | None
|
41
|
+
max_provider_calls: int | None
|
42
|
+
max_connections: int | None
|
43
|
+
model: str | None
|
8
44
|
|
9
45
|
|
10
46
|
@tool
|
11
47
|
def web_search(
|
12
|
-
|
13
|
-
|
14
|
-
max_provider_calls: int = 3,
|
15
|
-
max_connections: int = 10,
|
16
|
-
model: str | None = None,
|
48
|
+
providers: Provider | Providers | list[Provider | Providers] | None = None,
|
49
|
+
**deprecated: Unpack[WebSearchDeprecatedArgs],
|
17
50
|
) -> Tool:
|
18
51
|
"""Web search tool.
|
19
52
|
|
20
|
-
|
21
|
-
|
22
|
-
`use_tools(web_search(provider="tavily"))`))
|
53
|
+
Web searches are executed using a provider. Providers are split
|
54
|
+
into two categories:
|
23
55
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
56
|
+
- Internal providers: "openai" - these use the model's built-in search
|
57
|
+
capability and do not require separate API keys. These work only for
|
58
|
+
their respective model provider (e.g. the "openai" search provider
|
59
|
+
works only for `openai/*` models).
|
60
|
+
|
61
|
+
- External providers: "tavily" and "google". These are external services
|
62
|
+
that work with any model and require separate accounts and API keys.
|
63
|
+
|
64
|
+
Internal providers will be prioritized if running on the corresponding model
|
65
|
+
(e.g., "openai" provider will be used when running on `openai` models). If an
|
66
|
+
internal provider is specified but the evaluation is run with a different
|
67
|
+
model, a fallback external provider must also be specified.
|
29
68
|
|
30
69
|
See further documentation at <https://inspect.aisi.org.uk/tools-standard.html#sec-web-search>.
|
31
70
|
|
32
71
|
Args:
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
provider
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
72
|
+
providers: Configuration for the search providers to use. Currently supported
|
73
|
+
providers are "openai", "tavily", and "google". The `providers` parameter
|
74
|
+
supports several formats based on either a `str` specifying a provider or
|
75
|
+
a `dict` whose keys are the provider names and whose values are the
|
76
|
+
provider-specific options. A single value or a list of these can be passed.
|
77
|
+
This arg is optional just for backwards compatibility. New code should
|
78
|
+
always provide this argument.
|
79
|
+
|
80
|
+
Single provider:
|
81
|
+
```
|
82
|
+
web_search("tavily")
|
83
|
+
web_search({"tavily": {"max_results": 5}}) # Tavily-specific options
|
84
|
+
```
|
85
|
+
|
86
|
+
Multiple providers:
|
87
|
+
```
|
88
|
+
# "openai" used for OpenAI models, "tavily" as fallback
|
89
|
+
web_search(["openai", "tavily"])
|
90
|
+
|
91
|
+
# The None value means to use the provider with default options
|
92
|
+
web_search({"openai": None, "tavily": {"max_results": 5}})
|
93
|
+
```
|
94
|
+
|
95
|
+
Mixed format:
|
96
|
+
```
|
97
|
+
web_search(["openai", {"tavily": {"max_results": 5}}])
|
98
|
+
```
|
99
|
+
|
100
|
+
When specified in the `dict` format, the `None` value for a provider means
|
101
|
+
to use the provider with default options.
|
102
|
+
|
103
|
+
Provider-specific options:
|
104
|
+
- openai: Supports OpenAI's web search parameters.
|
105
|
+
See https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses
|
106
|
+
|
107
|
+
- tavily: Supports options like `max_results`, `search_depth`, etc.
|
108
|
+
See https://docs.tavily.com/documentation/api-reference/endpoint/search
|
109
|
+
|
110
|
+
- google: Supports options like `num_results`, `max_provider_calls`,
|
111
|
+
`max_connections`, and `model`
|
112
|
+
|
113
|
+
**deprecated: Deprecated arguments.
|
46
114
|
|
47
115
|
Returns:
|
48
116
|
A tool that can be registered for use by models to search the web.
|
49
117
|
"""
|
50
|
-
|
51
|
-
if maybe_get_google_api_keys():
|
52
|
-
deprecation_warning(
|
53
|
-
"The `google` `web_search` provider was inferred based on the presence of environment variables. Please specify the provider explicitly to avoid this warning."
|
54
|
-
)
|
55
|
-
provider = "google"
|
56
|
-
else:
|
57
|
-
raise ValueError(
|
58
|
-
"Omitting `provider` is no longer supported. Please specify the `web_search` provider explicitly to avoid this error."
|
59
|
-
)
|
118
|
+
normalized_providers = _normalize_config(providers, **deprecated)
|
60
119
|
|
61
|
-
search_provider =
|
62
|
-
google_search_provider(num_results, max_provider_calls, max_connections, model)
|
63
|
-
if provider == "google"
|
64
|
-
else tavily_search_provider(num_results, max_connections)
|
65
|
-
)
|
120
|
+
search_provider: Callable[[str], Awaitable[str | None]] | None = None
|
66
121
|
|
67
122
|
async def execute(query: str) -> ToolResult:
|
68
123
|
"""
|
@@ -71,6 +126,9 @@ def web_search(
|
|
71
126
|
Args:
|
72
127
|
query (str): Search query.
|
73
128
|
"""
|
129
|
+
nonlocal search_provider
|
130
|
+
if not search_provider:
|
131
|
+
search_provider = _create_external_provider(normalized_providers)
|
74
132
|
search_result = await search_provider(query)
|
75
133
|
|
76
134
|
return (
|
@@ -82,4 +140,115 @@ def web_search(
|
|
82
140
|
else ("I'm sorry, I couldn't find any relevant information on the web.")
|
83
141
|
)
|
84
142
|
|
85
|
-
return
|
143
|
+
return ToolDef(
|
144
|
+
execute, name="web_search", options=dict(normalized_providers)
|
145
|
+
).as_tool()
|
146
|
+
|
147
|
+
|
148
|
+
def _normalize_config(
|
149
|
+
providers: Provider | Providers | list[Provider | Providers] | None,
|
150
|
+
**deprecated: Unpack[WebSearchDeprecatedArgs],
|
151
|
+
) -> Providers:
|
152
|
+
"""
|
153
|
+
Deal with breaking changes in the web_search parameter list.
|
154
|
+
|
155
|
+
This function adapts (hopefully) all of the old variants of how the tool
|
156
|
+
factory may have been called and converts them to the new config format.
|
157
|
+
"""
|
158
|
+
# Cases to handle:
|
159
|
+
# 1. Both deprecated_provider and providers are set
|
160
|
+
# ValueError
|
161
|
+
# 2. Neither deprecated_provider nor providers is set
|
162
|
+
# act as if they passed provider="google"
|
163
|
+
# 3. Only providers is set
|
164
|
+
# if any of the other deprecated parameters is set, then ValueError
|
165
|
+
# else Happy path
|
166
|
+
# 4. Only deprecated_provider is set
|
167
|
+
# convert to new config format - including processing old other params
|
168
|
+
|
169
|
+
deprecated_provider = deprecated.get("provider", None)
|
170
|
+
# Case 1.
|
171
|
+
if deprecated_provider and providers:
|
172
|
+
raise ValueError("`provider` is deprecated. Please only specify `providers`.")
|
173
|
+
|
174
|
+
# Case 2.
|
175
|
+
if providers is None and deprecated_provider is None:
|
176
|
+
deprecated_provider = "google"
|
177
|
+
|
178
|
+
num_results = deprecated.get("num_results", None)
|
179
|
+
max_provider_calls = deprecated.get("max_provider_calls", None)
|
180
|
+
max_connections = deprecated.get("max_connections", None)
|
181
|
+
model = deprecated.get("model", None)
|
182
|
+
|
183
|
+
# Getting here means that we have either a providers or a deprecated_provider
|
184
|
+
if deprecated_provider:
|
185
|
+
return _get_config_via_back_compat(
|
186
|
+
deprecated_provider,
|
187
|
+
num_results=num_results,
|
188
|
+
max_provider_calls=max_provider_calls,
|
189
|
+
max_connections=max_connections,
|
190
|
+
model=model,
|
191
|
+
)
|
192
|
+
|
193
|
+
assert providers, "providers should not be None here"
|
194
|
+
normalized: Providers = {}
|
195
|
+
for entry in providers if isinstance(providers, list) else [providers]:
|
196
|
+
if isinstance(entry, str):
|
197
|
+
if entry not in valid_providers:
|
198
|
+
raise ValueError(f"Invalid provider: '{entry}'")
|
199
|
+
normalized[entry] = None # type: ignore
|
200
|
+
else:
|
201
|
+
for key, value in entry.items():
|
202
|
+
if key not in valid_providers:
|
203
|
+
raise ValueError(f"Invalid provider: '{key}'")
|
204
|
+
normalized[key] = value # type: ignore
|
205
|
+
return normalized
|
206
|
+
|
207
|
+
|
208
|
+
def _get_config_via_back_compat(
|
209
|
+
provider: Literal["tavily", "google"],
|
210
|
+
num_results: int | None,
|
211
|
+
max_provider_calls: int | None,
|
212
|
+
max_connections: int | None,
|
213
|
+
model: str | None,
|
214
|
+
) -> Providers:
|
215
|
+
if (
|
216
|
+
num_results is None
|
217
|
+
and max_provider_calls is None
|
218
|
+
and max_connections is None
|
219
|
+
and model is None
|
220
|
+
):
|
221
|
+
return {"google": None} if provider == "google" else {"tavily": None}
|
222
|
+
|
223
|
+
# If we get here, we have at least one old school parameter
|
224
|
+
deprecation_warning(
|
225
|
+
"The `num_results`, `max_provider_calls`, `max_connections`, and `model` parameters are deprecated. Please use the `config` parameter instead."
|
226
|
+
)
|
227
|
+
|
228
|
+
if provider == "google":
|
229
|
+
return {
|
230
|
+
"google": GoogleOptions(
|
231
|
+
num_results=num_results,
|
232
|
+
max_provider_calls=max_provider_calls,
|
233
|
+
max_connections=max_connections,
|
234
|
+
model=model,
|
235
|
+
).model_dump(exclude_none=True)
|
236
|
+
}
|
237
|
+
else:
|
238
|
+
return {
|
239
|
+
"tavily": TavilyOptions(
|
240
|
+
max_results=num_results, max_connections=max_connections
|
241
|
+
).model_dump(exclude_none=True)
|
242
|
+
}
|
243
|
+
|
244
|
+
|
245
|
+
def _create_external_provider(
|
246
|
+
providers: Providers,
|
247
|
+
) -> Callable[[str], Awaitable[str | None]]:
|
248
|
+
if "tavily" in providers:
|
249
|
+
return tavily_search_provider(providers.get("tavily", None))
|
250
|
+
|
251
|
+
if "google" in providers:
|
252
|
+
return google_search_provider(providers.get("google", None))
|
253
|
+
|
254
|
+
raise ValueError("No valid provider found.")
|