inspect-ai 0.3.100__py3-none-any.whl → 0.3.101__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_cli/eval.py +2 -1
- inspect_ai/_eval/eval.py +13 -1
- inspect_ai/_eval/evalset.py +3 -2
- inspect_ai/_eval/run.py +2 -0
- inspect_ai/_eval/task/log.py +3 -1
- inspect_ai/_view/www/dist/assets/index.css +44 -12
- inspect_ai/_view/www/dist/assets/index.js +1499 -1467
- inspect_ai/_view/www/package.json +4 -4
- inspect_ai/_view/www/src/app/log-view/tabs/grouping.ts +4 -4
- inspect_ai/_view/www/src/app/routing/navigationHooks.ts +22 -25
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +17 -5
- inspect_ai/_view/www/src/state/hooks.ts +1 -1
- inspect_ai/_view/www/yarn.lock +21 -27
- inspect_ai/analysis/beta/__init__.py +2 -0
- inspect_ai/dataset/_sources/csv.py +2 -6
- inspect_ai/dataset/_sources/hf.py +2 -6
- inspect_ai/dataset/_sources/json.py +2 -6
- inspect_ai/dataset/_util.py +23 -0
- inspect_ai/log/_recorders/eval.py +4 -3
- inspect_ai/log/_recorders/json.py +1 -0
- inspect_ai/log/_recorders/recorder.py +1 -0
- inspect_ai/model/_openai_responses.py +11 -6
- inspect_ai/model/_openai_web_search.py +9 -2
- inspect_ai/model/_providers/openai.py +3 -1
- inspect_ai/model/_providers/openai_responses.py +5 -1
- inspect_ai/scorer/_reducer/reducer.py +1 -1
- inspect_ai/tool/_tools/_web_search/_google.py +28 -11
- inspect_ai/tool/_tools/_web_search/_tavily.py +11 -1
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/RECORD +34 -34
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/WHEEL +0 -0
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.100.dist-info → inspect_ai-0.3.101.dist-info}/top_level.txt +0 -0
@@ -67,8 +67,8 @@
|
|
67
67
|
"@popperjs/core": "^2.11.8",
|
68
68
|
"ansi-output": "^0.0.9",
|
69
69
|
"asciinema-player": "^3.9.0",
|
70
|
-
"bootstrap": "^5.3.
|
71
|
-
"bootstrap-icons": "^1.
|
70
|
+
"bootstrap": "^5.3.6",
|
71
|
+
"bootstrap-icons": "^1.12.1",
|
72
72
|
"clipboard": "^2.0.11",
|
73
73
|
"clsx": "^2.1.1",
|
74
74
|
"codemirror": "^6.0.1",
|
@@ -89,8 +89,8 @@
|
|
89
89
|
"react": "^19.0.0",
|
90
90
|
"react-dom": "^19.0.0",
|
91
91
|
"react-popper": "^2.3.0",
|
92
|
-
"react-router-dom": "^7.
|
93
|
-
"react-virtuoso": "^4.12.
|
92
|
+
"react-router-dom": "^7.6.0",
|
93
|
+
"react-virtuoso": "^4.12.7",
|
94
94
|
"zustand": "^5.0.5",
|
95
95
|
"use-resize-observer": "^9.1.0"
|
96
96
|
}
|
@@ -47,7 +47,7 @@ const noGrouping = (
|
|
47
47
|
const itemCount = counter.item();
|
48
48
|
return [
|
49
49
|
{
|
50
|
-
label: `Sample ${
|
50
|
+
label: `Sample ${sample.id}`,
|
51
51
|
number: itemCount,
|
52
52
|
index: index,
|
53
53
|
data: sample,
|
@@ -107,10 +107,10 @@ const groupBySample = (
|
|
107
107
|
if (sample.id !== lastId) {
|
108
108
|
counter.incrementGroup();
|
109
109
|
results.push({
|
110
|
-
label: `Sample ${
|
110
|
+
label: `Sample ${sample.id}`,
|
111
111
|
number: counter.group(),
|
112
112
|
index: index,
|
113
|
-
data: `Sample ${
|
113
|
+
data: `Sample ${sample.id}`,
|
114
114
|
type: "separator",
|
115
115
|
} as SeparatorListItem);
|
116
116
|
counter.resetItem();
|
@@ -175,7 +175,7 @@ const groupByEpoch = (
|
|
175
175
|
// Compute the index within the epoch
|
176
176
|
counter.incrementItem();
|
177
177
|
results.push({
|
178
|
-
label: `Sample ${
|
178
|
+
label: `Sample ${sample.id} (Epoch ${sample.epoch})`,
|
179
179
|
number: counter.item(),
|
180
180
|
index: index,
|
181
181
|
data: sample,
|
@@ -130,29 +130,26 @@ export const useSampleNavigation = () => {
|
|
130
130
|
|
131
131
|
// Navigate to a specific sample with index
|
132
132
|
const showSample = useCallback(
|
133
|
-
(
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
// Navigate to the sample URL
|
154
|
-
navigate(url);
|
155
|
-
}
|
133
|
+
(
|
134
|
+
index: number,
|
135
|
+
id: string | number,
|
136
|
+
epoch: number,
|
137
|
+
specifiedSampleTabId?: string,
|
138
|
+
) => {
|
139
|
+
const resolvedPath = resolveLogPath();
|
140
|
+
|
141
|
+
if (resolvedPath) {
|
142
|
+
// Update internal state
|
143
|
+
selectSample(index);
|
144
|
+
setShowingSampleDialog(true);
|
145
|
+
|
146
|
+
// Use specified sampleTabId if provided, otherwise use current sampleTabId from URL params
|
147
|
+
const currentSampleTabId = specifiedSampleTabId || sampleTabId;
|
148
|
+
|
149
|
+
const url = sampleUrl(resolvedPath, id, epoch, currentSampleTabId);
|
150
|
+
|
151
|
+
// Navigate to the sample URL
|
152
|
+
navigate(url);
|
156
153
|
}
|
157
154
|
},
|
158
155
|
[
|
@@ -171,7 +168,7 @@ export const useSampleNavigation = () => {
|
|
171
168
|
const itemsCount = sampleSummaries.length;
|
172
169
|
const next = Math.min(selectedSampleIndex + 1, itemsCount - 1);
|
173
170
|
if (next > -1) {
|
174
|
-
|
171
|
+
selectSample(next);
|
175
172
|
}
|
176
173
|
}, [selectedSampleIndex, showSample, sampleTabId]);
|
177
174
|
|
@@ -179,7 +176,7 @@ export const useSampleNavigation = () => {
|
|
179
176
|
const previousSample = useCallback(() => {
|
180
177
|
const prev = selectedSampleIndex - 1;
|
181
178
|
if (prev > -1) {
|
182
|
-
|
179
|
+
selectSample(prev);
|
183
180
|
}
|
184
181
|
}, [selectedSampleIndex, showSample, sampleTabId]);
|
185
182
|
|
@@ -113,11 +113,19 @@ export const SampleList: FC<SampleListProps> = memo((props) => {
|
|
113
113
|
e.preventDefault();
|
114
114
|
e.stopPropagation();
|
115
115
|
break;
|
116
|
-
case "Enter":
|
117
|
-
|
118
|
-
|
119
|
-
|
116
|
+
case "Enter": {
|
117
|
+
const item = items[selectedSampleIndex];
|
118
|
+
if (item.type === "sample") {
|
119
|
+
sampleNavigation.showSample(
|
120
|
+
item.index,
|
121
|
+
item.data.id,
|
122
|
+
item.data.epoch,
|
123
|
+
);
|
124
|
+
e.preventDefault();
|
125
|
+
e.stopPropagation();
|
126
|
+
}
|
120
127
|
break;
|
128
|
+
}
|
121
129
|
}
|
122
130
|
},
|
123
131
|
[
|
@@ -150,7 +158,11 @@ export const SampleList: FC<SampleListProps> = memo((props) => {
|
|
150
158
|
item.data.epoch,
|
151
159
|
)}
|
152
160
|
showSample={() => {
|
153
|
-
sampleNavigation.showSample(
|
161
|
+
sampleNavigation.showSample(
|
162
|
+
item.index,
|
163
|
+
item.data.id,
|
164
|
+
item.data.epoch,
|
165
|
+
);
|
154
166
|
}}
|
155
167
|
/>
|
156
168
|
);
|
@@ -277,7 +277,7 @@ export const useCollapseSampleEvent = (
|
|
277
277
|
const collapseEvent = useStore((state) => state.sampleActions.collapseEvent);
|
278
278
|
|
279
279
|
return useMemo(() => {
|
280
|
-
const isCollapsed = collapsed !== null && collapsed[scope][id] === true;
|
280
|
+
const isCollapsed = collapsed !== null && collapsed[scope]?.[id] === true;
|
281
281
|
const set = (value: boolean) => {
|
282
282
|
log.debug("Set collapsed", id, value);
|
283
283
|
collapseEvent(scope, id, value);
|
inspect_ai/_view/www/yarn.lock
CHANGED
@@ -2339,15 +2339,15 @@ balanced-match@^1.0.0:
|
|
2339
2339
|
resolved "https://registry.yarnpkg.com/balanced-match/-/balanced-match-1.0.2.tgz#e83e3a7e3f300b34cb9d87f615fa0cbf357690ee"
|
2340
2340
|
integrity sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==
|
2341
2341
|
|
2342
|
-
bootstrap-icons@^1.
|
2343
|
-
version "1.
|
2344
|
-
resolved "https://registry.yarnpkg.com/bootstrap-icons/-/bootstrap-icons-1.
|
2345
|
-
integrity sha512
|
2342
|
+
bootstrap-icons@^1.12.1:
|
2343
|
+
version "1.13.1"
|
2344
|
+
resolved "https://registry.yarnpkg.com/bootstrap-icons/-/bootstrap-icons-1.13.1.tgz#0aad3f5b55b67402990e729ce3883416f9cef6c5"
|
2345
|
+
integrity sha512-ijombt4v6bv5CLeXvRWKy7CuM3TRTuPEuGaGKvTV5cz65rQSY8RQ2JcHt6b90cBBAC7s8fsf2EkQDldzCoXUjw==
|
2346
2346
|
|
2347
|
-
bootstrap@^5.3.
|
2348
|
-
version "5.3.
|
2349
|
-
resolved "https://registry.yarnpkg.com/bootstrap/-/bootstrap-5.3.
|
2350
|
-
integrity sha512-
|
2347
|
+
bootstrap@^5.3.6:
|
2348
|
+
version "5.3.6"
|
2349
|
+
resolved "https://registry.yarnpkg.com/bootstrap/-/bootstrap-5.3.6.tgz#fbd91ebaff093f5b191a1c01a8c866d24f9fa6e1"
|
2350
|
+
integrity sha512-jX0GAcRzvdwISuvArXn3m7KZscWWFAf1MKBcnzaN02qWMb3jpMoUX4/qgeiGzqyIb4ojulRzs89UCUmGcFSzTA==
|
2351
2351
|
|
2352
2352
|
brace-expansion@^1.1.7:
|
2353
2353
|
version "1.1.11"
|
@@ -4478,21 +4478,20 @@ react-refresh@^0.17.0:
|
|
4478
4478
|
resolved "https://registry.yarnpkg.com/react-refresh/-/react-refresh-0.17.0.tgz#b7e579c3657f23d04eccbe4ad2e58a8ed51e7e53"
|
4479
4479
|
integrity sha512-z6F7K9bV85EfseRCp2bzrpyQ0Gkw1uLoCel9XBVWPg/TjRj94SkJzUTGfOa4bs7iJvBWtQG0Wq7wnI0syw3EBQ==
|
4480
4480
|
|
4481
|
-
react-router-dom@^7.
|
4482
|
-
version "7.
|
4483
|
-
resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-7.
|
4484
|
-
integrity sha512-
|
4481
|
+
react-router-dom@^7.6.0:
|
4482
|
+
version "7.6.1"
|
4483
|
+
resolved "https://registry.yarnpkg.com/react-router-dom/-/react-router-dom-7.6.1.tgz#263c9102e96b58d336258a51d68080b40c28f526"
|
4484
|
+
integrity sha512-vxU7ei//UfPYQ3iZvHuO1D/5fX3/JOqhNTbRR+WjSBWxf9bIvpWK+ftjmdfJHzPOuMQKe2fiEdG+dZX6E8uUpA==
|
4485
4485
|
dependencies:
|
4486
|
-
react-router "7.
|
4486
|
+
react-router "7.6.1"
|
4487
4487
|
|
4488
|
-
react-router@7.
|
4489
|
-
version "7.
|
4490
|
-
resolved "https://registry.yarnpkg.com/react-router/-/react-router-7.
|
4491
|
-
integrity sha512-
|
4488
|
+
react-router@7.6.1:
|
4489
|
+
version "7.6.1"
|
4490
|
+
resolved "https://registry.yarnpkg.com/react-router/-/react-router-7.6.1.tgz#a54f9b980b94594bcb4b7f26611612a9f6e17461"
|
4491
|
+
integrity sha512-hPJXXxHJZEsPFNVbtATH7+MMX43UDeOauz+EAU4cgqTn7ojdI9qQORqS8Z0qmDlL1TclO/6jLRYUEtbWidtdHQ==
|
4492
4492
|
dependencies:
|
4493
4493
|
cookie "^1.0.1"
|
4494
4494
|
set-cookie-parser "^2.6.0"
|
4495
|
-
turbo-stream "2.4.0"
|
4496
4495
|
|
4497
4496
|
react-transition-group@^4.4.5:
|
4498
4497
|
version "4.4.5"
|
@@ -4504,10 +4503,10 @@ react-transition-group@^4.4.5:
|
|
4504
4503
|
loose-envify "^1.4.0"
|
4505
4504
|
prop-types "^15.6.2"
|
4506
4505
|
|
4507
|
-
react-virtuoso@^4.12.
|
4508
|
-
version "4.12.
|
4509
|
-
resolved "https://registry.yarnpkg.com/react-virtuoso/-/react-virtuoso-4.12.
|
4510
|
-
integrity sha512-
|
4506
|
+
react-virtuoso@^4.12.7:
|
4507
|
+
version "4.12.8"
|
4508
|
+
resolved "https://registry.yarnpkg.com/react-virtuoso/-/react-virtuoso-4.12.8.tgz#db1dbba617f91c1dcd760aa90e09ef991e65a356"
|
4509
|
+
integrity sha512-NMMKfDBr/+xZZqCQF3tN1SZsh6FwOJkYgThlfnsPLkaEhdyQo0EuWUzu3ix6qjnI7rYwJhMwRGoJBi+aiDfGsA==
|
4511
4510
|
|
4512
4511
|
react@^19.0.0:
|
4513
4512
|
version "19.1.0"
|
@@ -4922,11 +4921,6 @@ ts-jest@^29.3.2:
|
|
4922
4921
|
type-fest "^4.39.1"
|
4923
4922
|
yargs-parser "^21.1.1"
|
4924
4923
|
|
4925
|
-
turbo-stream@2.4.0:
|
4926
|
-
version "2.4.0"
|
4927
|
-
resolved "https://registry.yarnpkg.com/turbo-stream/-/turbo-stream-2.4.0.tgz#1e4fca6725e90fa14ac4adb782f2d3759a5695f0"
|
4928
|
-
integrity sha512-FHncC10WpBd2eOmGwpmQsWLDoK4cqsA/UT/GqNoaKOQnT8uzhtCbg3EoUDMvqpOSAI0S26mr0rkjzbOO6S3v1g==
|
4929
|
-
|
4930
4924
|
type-check@^0.4.0, type-check@~0.4.0:
|
4931
4925
|
version "0.4.0"
|
4932
4926
|
resolved "https://registry.yarnpkg.com/type-check/-/type-check-0.4.0.tgz#07b8203bfa7056c0657050e3ccd2c37730bab8f1"
|
@@ -7,6 +7,7 @@ from ._dataframe.evals.columns import (
|
|
7
7
|
EvalColumn,
|
8
8
|
EvalColumns,
|
9
9
|
EvalConfig,
|
10
|
+
EvalDataset,
|
10
11
|
EvalInfo,
|
11
12
|
EvalModel,
|
12
13
|
EvalResults,
|
@@ -41,6 +42,7 @@ __all__ = [
|
|
41
42
|
"EvalModel",
|
42
43
|
"EvalColumns",
|
43
44
|
"EvalConfig",
|
45
|
+
"EvalDataset",
|
44
46
|
"EvalResults",
|
45
47
|
"EvalScores",
|
46
48
|
"samples_df",
|
@@ -14,7 +14,7 @@ from .._dataset import (
|
|
14
14
|
MemoryDataset,
|
15
15
|
RecordToSample,
|
16
16
|
)
|
17
|
-
from .._util import data_to_samples, record_to_sample_fn
|
17
|
+
from .._util import data_to_samples, record_to_sample_fn, shuffle_choices_if_requested
|
18
18
|
|
19
19
|
|
20
20
|
def csv_dataset(
|
@@ -88,11 +88,7 @@ def csv_dataset(
|
|
88
88
|
if shuffle:
|
89
89
|
dataset.shuffle(seed=seed)
|
90
90
|
|
91
|
-
|
92
|
-
if isinstance(shuffle_choices, int):
|
93
|
-
dataset.shuffle_choices(seed=shuffle_choices)
|
94
|
-
elif shuffle_choices is True:
|
95
|
-
dataset.shuffle_choices()
|
91
|
+
shuffle_choices_if_requested(dataset, shuffle_choices)
|
96
92
|
|
97
93
|
# limit if requested
|
98
94
|
if limit:
|
@@ -16,7 +16,7 @@ from .._dataset import (
|
|
16
16
|
MemoryDataset,
|
17
17
|
RecordToSample,
|
18
18
|
)
|
19
|
-
from .._util import data_to_samples, record_to_sample_fn
|
19
|
+
from .._util import data_to_samples, record_to_sample_fn, shuffle_choices_if_requested
|
20
20
|
|
21
21
|
|
22
22
|
def hf_dataset(
|
@@ -125,10 +125,6 @@ def hf_dataset(
|
|
125
125
|
location=path,
|
126
126
|
)
|
127
127
|
|
128
|
-
|
129
|
-
if isinstance(shuffle_choices, int):
|
130
|
-
memory_dataset.shuffle_choices(seed=shuffle_choices)
|
131
|
-
elif shuffle_choices is True:
|
132
|
-
memory_dataset.shuffle_choices()
|
128
|
+
shuffle_choices_if_requested(memory_dataset, shuffle_choices)
|
133
129
|
|
134
130
|
return memory_dataset
|
@@ -15,7 +15,7 @@ from .._dataset import (
|
|
15
15
|
MemoryDataset,
|
16
16
|
RecordToSample,
|
17
17
|
)
|
18
|
-
from .._util import data_to_samples, record_to_sample_fn
|
18
|
+
from .._util import data_to_samples, record_to_sample_fn, shuffle_choices_if_requested
|
19
19
|
from .util import resolve_sample_files
|
20
20
|
|
21
21
|
|
@@ -88,11 +88,7 @@ def json_dataset(
|
|
88
88
|
if shuffle:
|
89
89
|
dataset.shuffle(seed=seed)
|
90
90
|
|
91
|
-
|
92
|
-
if isinstance(shuffle_choices, int):
|
93
|
-
dataset.shuffle_choices(seed=shuffle_choices)
|
94
|
-
elif shuffle_choices is True:
|
95
|
-
dataset.shuffle_choices()
|
91
|
+
shuffle_choices_if_requested(dataset, shuffle_choices)
|
96
92
|
|
97
93
|
# limit if requested
|
98
94
|
if limit:
|
inspect_ai/dataset/_util.py
CHANGED
@@ -13,6 +13,7 @@ from inspect_ai.model import (
|
|
13
13
|
from inspect_ai.util._sandbox.environment import SandboxEnvironmentSpec
|
14
14
|
|
15
15
|
from ._dataset import (
|
16
|
+
Dataset,
|
16
17
|
DatasetRecord,
|
17
18
|
FieldSpec,
|
18
19
|
RecordToSample,
|
@@ -225,3 +226,25 @@ def read_files(files: Any | None) -> dict[str, str] | None:
|
|
225
226
|
raise ValueError(f"Unexpected type for 'files' field: {type(files)}")
|
226
227
|
else:
|
227
228
|
return None
|
229
|
+
|
230
|
+
|
231
|
+
def shuffle_choices_if_requested(
|
232
|
+
dataset: Dataset, shuffle_choices: bool | int | None
|
233
|
+
) -> None:
|
234
|
+
"""
|
235
|
+
Shuffle the choices in the dataset if requested.
|
236
|
+
|
237
|
+
The `shuffle_choices` parameter passed to `json_dataset`, `csv_dataset`,
|
238
|
+
and `hf_dataset` can be a boolean, an integer, or `None` (default).
|
239
|
+
If it is a boolean, it will shuffle the choices if the value is `True`,
|
240
|
+
and do nothing if it is `False`.
|
241
|
+
If it is an integer, it will shuffle the choices using the integer as the seed.
|
242
|
+
"""
|
243
|
+
# Note that `isinstance(x, int)` returns True if x is True or False,
|
244
|
+
# so we need to check for both explicitly
|
245
|
+
if shuffle_choices is True:
|
246
|
+
dataset.shuffle_choices()
|
247
|
+
elif shuffle_choices is False:
|
248
|
+
pass
|
249
|
+
elif isinstance(shuffle_choices, int):
|
250
|
+
dataset.shuffle_choices(seed=shuffle_choices)
|
@@ -133,6 +133,7 @@ class EvalRecorder(FileRecorder):
|
|
133
133
|
results: EvalResults | None,
|
134
134
|
reductions: list[EvalSampleReductions] | None,
|
135
135
|
error: EvalError | None = None,
|
136
|
+
header_only: bool = False,
|
136
137
|
) -> EvalLog:
|
137
138
|
# get the key and log
|
138
139
|
key = self._log_file_key(eval)
|
@@ -174,7 +175,7 @@ class EvalRecorder(FileRecorder):
|
|
174
175
|
|
175
176
|
# flush and write the results
|
176
177
|
await log.flush()
|
177
|
-
return await log.close()
|
178
|
+
return await log.close(header_only)
|
178
179
|
|
179
180
|
@classmethod
|
180
181
|
@override
|
@@ -321,12 +322,12 @@ class ZipLogFile:
|
|
321
322
|
# re-open zip file w/ self.temp_file pointer at end
|
322
323
|
self._open()
|
323
324
|
|
324
|
-
async def close(self) -> EvalLog:
|
325
|
+
async def close(self, header_only: bool) -> EvalLog:
|
325
326
|
async with self._lock:
|
326
327
|
# read the log from the temp file then close it
|
327
328
|
try:
|
328
329
|
self._temp_file.seek(0)
|
329
|
-
return _read_log(self._temp_file, self._file)
|
330
|
+
return _read_log(self._temp_file, self._file, header_only=header_only)
|
330
331
|
finally:
|
331
332
|
self._temp_file.close()
|
332
333
|
if self._zip:
|
@@ -96,6 +96,7 @@ class JSONRecorder(FileRecorder):
|
|
96
96
|
results: EvalResults | None,
|
97
97
|
reductions: list[EvalSampleReductions] | None,
|
98
98
|
error: EvalError | None = None,
|
99
|
+
header_only: bool = False,
|
99
100
|
) -> EvalLog:
|
100
101
|
log = self.data[self._log_file_key(spec)]
|
101
102
|
log.data.status = status
|
@@ -162,9 +162,9 @@ def openai_responses_tool_choice(
|
|
162
162
|
|
163
163
|
|
164
164
|
def openai_responses_tools(
|
165
|
-
tools: list[ToolInfo], config: GenerateConfig
|
165
|
+
tools: list[ToolInfo], model_name: str, config: GenerateConfig
|
166
166
|
) -> list[ToolParam]:
|
167
|
-
return [_tool_param_for_tool_info(tool, config) for tool in tools]
|
167
|
+
return [_tool_param_for_tool_info(tool, model_name, config) for tool in tools]
|
168
168
|
|
169
169
|
|
170
170
|
def openai_responses_chat_choices(
|
@@ -177,9 +177,11 @@ def openai_responses_chat_choices(
|
|
177
177
|
|
178
178
|
|
179
179
|
def is_native_tool_configured(
|
180
|
-
tools: Sequence[ToolInfo], config: GenerateConfig
|
180
|
+
tools: Sequence[ToolInfo], model_name: str, config: GenerateConfig
|
181
181
|
) -> bool:
|
182
|
-
return any(
|
182
|
+
return any(
|
183
|
+
_maybe_native_tool_param(tool, model_name, config) is not None for tool in tools
|
184
|
+
)
|
183
185
|
|
184
186
|
|
185
187
|
# The next two function perform transformations between OpenAI types an Inspect
|
@@ -433,11 +435,13 @@ def _model_tool_call_for_internal(
|
|
433
435
|
|
434
436
|
def _maybe_native_tool_param(
|
435
437
|
tool: ToolInfo,
|
438
|
+
model_name: str,
|
436
439
|
config: GenerateConfig,
|
437
440
|
) -> ToolParam | None:
|
438
441
|
return (
|
439
442
|
(
|
440
|
-
maybe_computer_use_preview_tool(tool)
|
443
|
+
maybe_computer_use_preview_tool(tool)
|
444
|
+
or maybe_web_search_tool(model_name, tool)
|
441
445
|
# or self.text_editor_tool_param(tool)
|
442
446
|
# or self.bash_tool_param(tool)
|
443
447
|
)
|
@@ -502,11 +506,12 @@ _ResponseToolCallParam = (
|
|
502
506
|
|
503
507
|
def _tool_param_for_tool_info(
|
504
508
|
tool: ToolInfo,
|
509
|
+
model_name: str,
|
505
510
|
config: GenerateConfig,
|
506
511
|
) -> ToolParam:
|
507
512
|
# Use a native tool implementation when available. Otherwise, use the
|
508
513
|
# standard tool implementation
|
509
|
-
return _maybe_native_tool_param(tool, config) or FunctionToolParam(
|
514
|
+
return _maybe_native_tool_param(tool, model_name, config) or FunctionToolParam(
|
510
515
|
type="function",
|
511
516
|
name=_responses_tool_alias(tool.name),
|
512
517
|
description=tool.description,
|
@@ -4,11 +4,18 @@ from openai.types.responses import WebSearchTool, WebSearchToolParam
|
|
4
4
|
|
5
5
|
from inspect_ai.tool._tool_info import ToolInfo
|
6
6
|
|
7
|
+
COMPATIBLE_MODELS = ["gpt-4o", "gpt-4o-mini", "gpt-4.1"]
|
7
8
|
|
8
|
-
|
9
|
+
|
10
|
+
def maybe_web_search_tool(model_name: str, tool: ToolInfo) -> WebSearchToolParam | None:
|
9
11
|
return (
|
10
12
|
_web_search_tool(tool.options["openai"])
|
11
|
-
if
|
13
|
+
if (
|
14
|
+
tool.name == "web_search"
|
15
|
+
and tool.options
|
16
|
+
and "openai" in tool.options
|
17
|
+
and model_name in COMPATIBLE_MODELS
|
18
|
+
)
|
12
19
|
else None
|
13
20
|
)
|
14
21
|
|
@@ -242,7 +242,9 @@ class OpenAIAPI(ModelAPI):
|
|
242
242
|
tools=tools,
|
243
243
|
**self.completion_params(config, False),
|
244
244
|
)
|
245
|
-
elif self.responses_api or is_native_tool_configured(
|
245
|
+
elif self.responses_api or is_native_tool_configured(
|
246
|
+
tools, self.model_name, config
|
247
|
+
):
|
246
248
|
return await generate_responses(
|
247
249
|
client=self.client,
|
248
250
|
http_hooks=self._http_hooks,
|
@@ -59,7 +59,11 @@ async def generate_responses(
|
|
59
59
|
)
|
60
60
|
|
61
61
|
# prepare request (we do this so we can log the ModelCall)
|
62
|
-
tool_params =
|
62
|
+
tool_params = (
|
63
|
+
openai_responses_tools(tools, model_name, config)
|
64
|
+
if len(tools) > 0
|
65
|
+
else NOT_GIVEN
|
66
|
+
)
|
63
67
|
request = dict(
|
64
68
|
input=await openai_responses_inputs(input, model_name, store),
|
65
69
|
tools=tool_params,
|
@@ -121,7 +121,7 @@ def pass_at(
|
|
121
121
|
def reduce(scores: list[Score]) -> Score:
|
122
122
|
def pass_at_k(values: list[float]) -> float:
|
123
123
|
total = len(scores)
|
124
|
-
correct = sum(1 for v in values if v
|
124
|
+
correct = sum(1 for v in values if v >= value)
|
125
125
|
if total - correct < k:
|
126
126
|
return 1.0
|
127
127
|
else:
|
@@ -32,9 +32,10 @@ class GoogleOptions(BaseModel):
|
|
32
32
|
|
33
33
|
|
34
34
|
class SearchLink:
|
35
|
-
def __init__(self, url: str, snippet: str) -> None:
|
35
|
+
def __init__(self, url: str, snippet: str, title: str) -> None:
|
36
36
|
self.url = url
|
37
37
|
self.snippet = snippet
|
38
|
+
self.title = title
|
38
39
|
|
39
40
|
|
40
41
|
def maybe_get_google_api_keys() -> tuple[str, str] | None:
|
@@ -71,8 +72,7 @@ def google_search_provider(
|
|
71
72
|
async def search(query: str) -> str | None:
|
72
73
|
# limit number of concurrent searches
|
73
74
|
page_contents: list[str] = []
|
74
|
-
|
75
|
-
snippets: list[str] = []
|
75
|
+
processed_links: list[SearchLink] = []
|
76
76
|
search_calls = 0
|
77
77
|
|
78
78
|
# Paginate through search results until we have successfully extracted num_results pages or we have reached max_provider_calls
|
@@ -87,8 +87,7 @@ def google_search_provider(
|
|
87
87
|
page = await page_if_relevant(link.url, query, model, client)
|
88
88
|
if page:
|
89
89
|
page_contents.append(page)
|
90
|
-
|
91
|
-
snippets.append(link.snippet)
|
90
|
+
processed_links.append(link)
|
92
91
|
# exceptions fetching pages are very common!
|
93
92
|
except Exception:
|
94
93
|
pass
|
@@ -98,8 +97,18 @@ def google_search_provider(
|
|
98
97
|
|
99
98
|
search_calls += 1
|
100
99
|
|
101
|
-
|
102
|
-
|
100
|
+
return (
|
101
|
+
"\n\n".join(
|
102
|
+
"[{title}]({url}):\n{page_content}".format(
|
103
|
+
title=link.title, url=link.url, page_content=page_content
|
104
|
+
)
|
105
|
+
for link, page_content in zip(
|
106
|
+
processed_links, page_contents, strict=True
|
107
|
+
)
|
108
|
+
)
|
109
|
+
if processed_links
|
110
|
+
else None
|
111
|
+
)
|
103
112
|
|
104
113
|
async def _search(query: str, start_idx: int) -> list[SearchLink]:
|
105
114
|
# List of allowed parameters can be found https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list
|
@@ -121,13 +130,21 @@ def google_search_provider(
|
|
121
130
|
before_sleep=log_httpx_retry_attempt(search_url),
|
122
131
|
)
|
123
132
|
async def execute_search() -> httpx.Response:
|
133
|
+
# See https://developers.google.com/custom-search/v1/reference/rest/v1/Search
|
124
134
|
return await client.get(search_url)
|
125
135
|
|
126
136
|
result = await execute_search()
|
127
137
|
data = result.json()
|
128
138
|
|
129
139
|
if "items" in data:
|
130
|
-
return [
|
140
|
+
return [
|
141
|
+
SearchLink(
|
142
|
+
url=item["link"],
|
143
|
+
snippet=item.get("snippet", ""), # sometimes not present
|
144
|
+
title=item["title"],
|
145
|
+
)
|
146
|
+
for item in data["items"]
|
147
|
+
]
|
131
148
|
else:
|
132
149
|
return []
|
133
150
|
|
@@ -135,13 +152,13 @@ def google_search_provider(
|
|
135
152
|
|
136
153
|
|
137
154
|
async def page_if_relevant(
|
138
|
-
|
155
|
+
url: str, query: str, relevance_model: str | None, client: httpx.AsyncClient
|
139
156
|
) -> str | None:
|
140
157
|
"""
|
141
158
|
Use parser model to determine if a web page contents is relevant to a query.
|
142
159
|
|
143
160
|
Args:
|
144
|
-
|
161
|
+
url (str): Web page url.
|
145
162
|
query (str): Search query.
|
146
163
|
relevance_model (Model): Model used to parse web pages for relevance.
|
147
164
|
client: (httpx.Client): HTTP client to use to fetch the page
|
@@ -156,7 +173,7 @@ async def page_if_relevant(
|
|
156
173
|
|
157
174
|
# retrieve document
|
158
175
|
try:
|
159
|
-
response = await client.get(
|
176
|
+
response = await client.get(url)
|
160
177
|
response.raise_for_status()
|
161
178
|
except httpx.HTTPError as exc:
|
162
179
|
raise Exception(f"HTTP error occurred: {exc}")
|
@@ -75,6 +75,7 @@ def tavily_search_provider(
|
|
75
75
|
client = httpx.AsyncClient(timeout=30)
|
76
76
|
|
77
77
|
async def search(query: str) -> str | None:
|
78
|
+
# See https://docs.tavily.com/documentation/api-reference/endpoint/search
|
78
79
|
search_url = "https://api.tavily.com/search"
|
79
80
|
headers = {
|
80
81
|
"Authorization": f"Bearer {tavily_api_key}",
|
@@ -95,6 +96,15 @@ def tavily_search_provider(
|
|
95
96
|
return response
|
96
97
|
|
97
98
|
async with concurrency("tavily_web_search", max_connections):
|
98
|
-
|
99
|
+
tavily_search_response = TavilySearchResponse.model_validate(
|
100
|
+
(await _search()).json()
|
101
|
+
)
|
102
|
+
results_str = "\n\n".join(
|
103
|
+
[
|
104
|
+
f"[{result.title}]({result.url}):\n{result.content}"
|
105
|
+
for result in tavily_search_response.results
|
106
|
+
]
|
107
|
+
)
|
108
|
+
return f"Answer: {tavily_search_response.answer}\n\n{results_str}"
|
99
109
|
|
100
110
|
return search
|