inspect-ai 0.3.96__py3-none-any.whl → 0.3.97__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_eval/eval.py +10 -2
- inspect_ai/_eval/task/util.py +32 -3
- inspect_ai/_util/registry.py +7 -0
- inspect_ai/_util/timer.py +13 -0
- inspect_ai/_view/www/dist/assets/index.css +275 -195
- inspect_ai/_view/www/dist/assets/index.js +8568 -7376
- inspect_ai/_view/www/src/app/App.css +1 -0
- inspect_ai/_view/www/src/app/App.tsx +27 -10
- inspect_ai/_view/www/src/app/appearance/icons.ts +5 -0
- inspect_ai/_view/www/src/app/content/RecordTree.module.css +22 -0
- inspect_ai/_view/www/src/app/content/RecordTree.tsx +370 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.module.css +5 -0
- inspect_ai/_view/www/src/app/content/RenderedContent.tsx +32 -19
- inspect_ai/_view/www/src/app/content/record_processors/store.ts +101 -0
- inspect_ai/_view/www/src/app/content/record_processors/types.ts +3 -0
- inspect_ai/_view/www/src/app/content/types.ts +5 -0
- inspect_ai/_view/www/src/app/log-view/LogView.tsx +1 -0
- inspect_ai/_view/www/src/app/log-view/LogViewContainer.tsx +35 -28
- inspect_ai/_view/www/src/app/log-view/LogViewLayout.tsx +1 -8
- inspect_ai/_view/www/src/app/log-view/navbar/PrimaryBar.tsx +2 -4
- inspect_ai/_view/www/src/app/log-view/navbar/ResultsPanel.tsx +13 -3
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.module.css +15 -0
- inspect_ai/_view/www/src/app/log-view/navbar/ScoreGrid.tsx +14 -10
- inspect_ai/_view/www/src/app/log-view/tabs/InfoTab.tsx +9 -3
- inspect_ai/_view/www/src/app/log-view/tabs/JsonTab.tsx +1 -3
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +8 -2
- inspect_ai/_view/www/src/app/log-view/types.ts +1 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.module.css +7 -0
- inspect_ai/_view/www/src/app/plan/ModelCard.tsx +5 -2
- inspect_ai/_view/www/src/app/plan/PlanCard.tsx +13 -8
- inspect_ai/_view/www/src/app/routing/navigationHooks.ts +63 -8
- inspect_ai/_view/www/src/app/routing/url.ts +45 -0
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/InlineSampleDisplay.tsx +15 -8
- inspect_ai/_view/www/src/app/samples/SampleDialog.module.css +3 -0
- inspect_ai/_view/www/src/app/samples/SampleDialog.tsx +16 -5
- inspect_ai/_view/www/src/app/samples/SampleDisplay.module.css +9 -1
- inspect_ai/_view/www/src/app/samples/SampleDisplay.tsx +68 -31
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.module.css +12 -7
- inspect_ai/_view/www/src/app/samples/chat/ChatMessage.tsx +17 -5
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.module.css +9 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatMessageRow.tsx +48 -18
- inspect_ai/_view/www/src/app/samples/chat/ChatView.tsx +0 -1
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.module.css +4 -0
- inspect_ai/_view/www/src/app/samples/chat/ChatViewVirtualList.tsx +41 -1
- inspect_ai/_view/www/src/app/samples/chat/messages.ts +7 -0
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.module.css +0 -3
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolCallView.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolInput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/chat/tools/ToolOutput.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/NumericScoreDescriptor.tsx +5 -1
- inspect_ai/_view/www/src/app/samples/descriptor/score/PassFailScoreDescriptor.tsx +11 -6
- inspect_ai/_view/www/src/app/samples/list/SampleList.tsx +7 -0
- inspect_ai/_view/www/src/app/samples/list/SampleRow.tsx +5 -18
- inspect_ai/_view/www/src/app/samples/sample-tools/SortFilter.tsx +1 -1
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresGrid.tsx +18 -5
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.module.css +0 -6
- inspect_ai/_view/www/src/app/samples/scores/SampleScoresView.tsx +4 -1
- inspect_ai/_view/www/src/app/samples/transcript/ApprovalEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ErrorEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/InfoEventView.tsx +13 -6
- inspect_ai/_view/www/src/app/samples/transcript/InputEventView.tsx +6 -4
- inspect_ai/_view/www/src/app/samples/transcript/LoggerEventView.tsx +4 -2
- inspect_ai/_view/www/src/app/samples/transcript/ModelEventView.tsx +11 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleInitEventView.tsx +14 -8
- inspect_ai/_view/www/src/app/samples/transcript/SampleLimitEventView.tsx +13 -8
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.tsx +25 -16
- inspect_ai/_view/www/src/app/samples/transcript/ScoreEventView.tsx +7 -5
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +11 -28
- inspect_ai/_view/www/src/app/samples/transcript/StepEventView.tsx +12 -20
- inspect_ai/_view/www/src/app/samples/transcript/SubtaskEventView.tsx +12 -31
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +25 -29
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualList.tsx +297 -0
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +0 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.tsx +43 -25
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.module.css +43 -0
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +109 -43
- inspect_ai/_view/www/src/app/samples/transcript/state/StateEventView.tsx +19 -8
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +128 -60
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +14 -4
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +6 -4
- inspect_ai/_view/www/src/app/types.ts +12 -1
- inspect_ai/_view/www/src/components/Card.css +6 -3
- inspect_ai/_view/www/src/components/Card.tsx +15 -2
- inspect_ai/_view/www/src/components/CopyButton.tsx +4 -6
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +20 -14
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +17 -22
- inspect_ai/_view/www/src/components/LargeModal.tsx +5 -1
- inspect_ai/_view/www/src/components/LiveVirtualList.tsx +25 -1
- inspect_ai/_view/www/src/components/MarkdownDiv.css +4 -0
- inspect_ai/_view/www/src/components/MarkdownDiv.tsx +2 -2
- inspect_ai/_view/www/src/components/TabSet.module.css +6 -1
- inspect_ai/_view/www/src/components/TabSet.tsx +8 -2
- inspect_ai/_view/www/src/state/hooks.ts +83 -13
- inspect_ai/_view/www/src/state/logPolling.ts +2 -2
- inspect_ai/_view/www/src/state/logSlice.ts +1 -2
- inspect_ai/_view/www/src/state/logsSlice.ts +9 -9
- inspect_ai/_view/www/src/state/samplePolling.ts +1 -1
- inspect_ai/_view/www/src/state/sampleSlice.ts +134 -7
- inspect_ai/_view/www/src/state/scoring.ts +1 -1
- inspect_ai/_view/www/src/state/scrolling.ts +39 -6
- inspect_ai/_view/www/src/state/store.ts +5 -0
- inspect_ai/_view/www/src/state/store_filter.ts +47 -44
- inspect_ai/_view/www/src/utils/debugging.ts +95 -0
- inspect_ai/_view/www/src/utils/format.ts +2 -2
- inspect_ai/_view/www/src/utils/json.ts +29 -0
- inspect_ai/agent/__init__.py +2 -1
- inspect_ai/agent/_agent.py +12 -0
- inspect_ai/agent/_react.py +184 -48
- inspect_ai/agent/_types.py +14 -1
- inspect_ai/analysis/beta/__init__.py +0 -2
- inspect_ai/analysis/beta/_dataframe/columns.py +11 -16
- inspect_ai/analysis/beta/_dataframe/evals/table.py +65 -40
- inspect_ai/analysis/beta/_dataframe/events/table.py +24 -36
- inspect_ai/analysis/beta/_dataframe/messages/table.py +24 -15
- inspect_ai/analysis/beta/_dataframe/progress.py +35 -5
- inspect_ai/analysis/beta/_dataframe/record.py +13 -9
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +1 -1
- inspect_ai/analysis/beta/_dataframe/samples/table.py +156 -46
- inspect_ai/analysis/beta/_dataframe/util.py +14 -12
- inspect_ai/model/_call_tools.py +1 -1
- inspect_ai/model/_providers/anthropic.py +18 -5
- inspect_ai/model/_providers/azureai.py +7 -2
- inspect_ai/model/_providers/util/llama31.py +3 -3
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/METADATA +1 -1
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/RECORD +131 -126
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/WHEEL +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.module.css +0 -48
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +0 -276
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.96.dist-info → inspect_ai-0.3.97.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,22 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import multiprocessing as mp
|
4
|
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
3
5
|
from dataclasses import dataclass
|
4
6
|
from functools import lru_cache
|
7
|
+
from itertools import chain
|
5
8
|
from typing import (
|
6
9
|
TYPE_CHECKING,
|
7
10
|
Callable,
|
8
11
|
Generator,
|
9
12
|
Literal,
|
13
|
+
Sequence,
|
14
|
+
cast,
|
10
15
|
overload,
|
11
16
|
)
|
12
17
|
|
13
18
|
from inspect_ai._util.hash import mm3_hash
|
14
|
-
from inspect_ai.
|
15
|
-
from inspect_ai.analysis.beta._dataframe.progress import import_progress
|
19
|
+
from inspect_ai.analysis.beta._dataframe.progress import import_progress, no_progress
|
16
20
|
from inspect_ai.log._file import (
|
17
21
|
list_eval_logs,
|
18
22
|
read_eval_log_sample_summaries,
|
@@ -22,7 +26,7 @@ from inspect_ai.log._log import EvalSample, EvalSampleSummary
|
|
22
26
|
from inspect_ai.log._transcript import Event
|
23
27
|
from inspect_ai.model._chat_message import ChatMessage
|
24
28
|
|
25
|
-
from ..columns import Column,
|
29
|
+
from ..columns import Column, ColumnError, ColumnType
|
26
30
|
from ..evals.columns import EvalColumn
|
27
31
|
from ..evals.table import EVAL_ID, EVAL_SUFFIX, _read_evals_df, ensure_eval_id
|
28
32
|
from ..events.columns import EventColumn
|
@@ -51,24 +55,30 @@ SAMPLE_SUFFIX = "_sample"
|
|
51
55
|
@overload
|
52
56
|
def samples_df(
|
53
57
|
logs: LogPaths = list_eval_logs(),
|
54
|
-
columns:
|
58
|
+
columns: Sequence[Column] = SampleSummary,
|
55
59
|
strict: Literal[True] = True,
|
60
|
+
parallel: bool | int = False,
|
61
|
+
quiet: bool = False,
|
56
62
|
) -> "pd.DataFrame": ...
|
57
63
|
|
58
64
|
|
59
65
|
@overload
|
60
66
|
def samples_df(
|
61
67
|
logs: LogPaths = list_eval_logs(),
|
62
|
-
columns:
|
68
|
+
columns: Sequence[Column] = SampleSummary,
|
63
69
|
strict: Literal[False] = False,
|
64
|
-
|
70
|
+
parallel: bool | int = False,
|
71
|
+
quiet: bool = False,
|
72
|
+
) -> tuple["pd.DataFrame", list[ColumnError]]: ...
|
65
73
|
|
66
74
|
|
67
75
|
def samples_df(
|
68
76
|
logs: LogPaths = list_eval_logs(),
|
69
|
-
columns:
|
77
|
+
columns: Sequence[Column] = SampleSummary,
|
70
78
|
strict: bool = True,
|
71
|
-
|
79
|
+
parallel: bool | int = False,
|
80
|
+
quiet: bool = False,
|
81
|
+
) -> "pd.DataFrame" | tuple["pd.DataFrame", list[ColumnError]]:
|
72
82
|
"""Read a dataframe containing samples from a set of evals.
|
73
83
|
|
74
84
|
Args:
|
@@ -78,41 +88,130 @@ def samples_df(
|
|
78
88
|
columns: Specification for what columns to read from log files.
|
79
89
|
strict: Raise import errors immediately. Defaults to `True`.
|
80
90
|
If `False` then a tuple of `DataFrame` and errors is returned.
|
91
|
+
parallel: If `True`, use `ProcessPoolExecutor` to read logs in parallel
|
92
|
+
(with workers based on `mp.cpu_count()`, capped at 8). If `int`, read
|
93
|
+
in parallel with the specified number of workers. If `False` (the default)
|
94
|
+
do not read in parallel.
|
95
|
+
quiet: If `True` do not print any output or progress (defaults to `False`).
|
81
96
|
|
82
97
|
Returns:
|
83
98
|
For `strict`, a Pandas `DataFrame` with information for the specified logs.
|
84
99
|
For `strict=False`, a tuple of Pandas `DataFrame` and a dictionary of errors
|
85
100
|
encountered (by log file) during import.
|
86
101
|
"""
|
87
|
-
|
102
|
+
verify_prerequisites()
|
103
|
+
|
104
|
+
return _read_samples_df(
|
105
|
+
logs, columns, strict=strict, progress=not quiet, parallel=parallel
|
106
|
+
)
|
88
107
|
|
89
108
|
|
90
109
|
@dataclass
|
91
110
|
class MessagesDetail:
|
92
111
|
name: str = "message"
|
93
112
|
col_type = MessageColumn
|
94
|
-
filter: Callable[[ChatMessage], bool]
|
113
|
+
filter: Callable[[ChatMessage], bool] | None = None
|
95
114
|
|
96
115
|
|
97
116
|
@dataclass
|
98
117
|
class EventsDetail:
|
99
118
|
name: str = "event"
|
100
119
|
col_type = EventColumn
|
101
|
-
filter: Callable[[Event], bool]
|
120
|
+
filter: Callable[[Event], bool] | None = None
|
102
121
|
|
103
122
|
|
104
123
|
def _read_samples_df(
|
105
124
|
logs: LogPaths,
|
106
|
-
columns:
|
125
|
+
columns: Sequence[Column],
|
107
126
|
*,
|
108
127
|
strict: bool = True,
|
109
128
|
detail: MessagesDetail | EventsDetail | None = None,
|
110
|
-
|
111
|
-
|
129
|
+
progress: bool = True,
|
130
|
+
parallel: bool | int = False,
|
131
|
+
) -> "pd.DataFrame" | tuple["pd.DataFrame", list[ColumnError]]:
|
132
|
+
import pandas as pd
|
112
133
|
|
113
134
|
# resolve logs
|
114
135
|
logs = resolve_logs(logs)
|
115
136
|
|
137
|
+
if parallel:
|
138
|
+
# resolve number of workers (cap at 8 as eventually we run into disk/memory contention)
|
139
|
+
if parallel is True:
|
140
|
+
parallel = max(min(mp.cpu_count(), 8), 2)
|
141
|
+
|
142
|
+
# flatted out list of logs
|
143
|
+
logs = resolve_logs(logs)
|
144
|
+
|
145
|
+
# establish progress
|
146
|
+
entity = detail.name if detail else "sample"
|
147
|
+
progress_cm = (
|
148
|
+
import_progress(f"reading {entity}s", total=len(logs))
|
149
|
+
if progress
|
150
|
+
else no_progress()
|
151
|
+
)
|
152
|
+
|
153
|
+
# run the parallel reads (setup arrays for holding results in order)
|
154
|
+
df_results: list[pd.DataFrame | None] = [None] * len(logs)
|
155
|
+
error_results: list[list[ColumnError] | None] = [None] * len(logs)
|
156
|
+
executor = ProcessPoolExecutor(max_workers=parallel)
|
157
|
+
try:
|
158
|
+
with progress_cm as p:
|
159
|
+
futures = {
|
160
|
+
executor.submit(
|
161
|
+
_read_samples_df_serial, # type: ignore[arg-type]
|
162
|
+
logs=[log],
|
163
|
+
columns=columns,
|
164
|
+
strict=strict,
|
165
|
+
detail=detail,
|
166
|
+
progress=False,
|
167
|
+
): idx
|
168
|
+
for idx, log in enumerate(logs)
|
169
|
+
}
|
170
|
+
for fut in as_completed(futures):
|
171
|
+
idx = futures[fut]
|
172
|
+
if strict:
|
173
|
+
df_results[idx] = cast(pd.DataFrame, fut.result())
|
174
|
+
else:
|
175
|
+
df, errs = cast(
|
176
|
+
tuple[pd.DataFrame, list[ColumnError]], fut.result()
|
177
|
+
)
|
178
|
+
df_results[idx] = df
|
179
|
+
error_results[idx] = errs
|
180
|
+
p.update()
|
181
|
+
finally:
|
182
|
+
executor.shutdown(wait=False, cancel_futures=True)
|
183
|
+
|
184
|
+
# recombine df
|
185
|
+
df = pd.concat(df_results, ignore_index=True)
|
186
|
+
subset = f"{detail.name}_id" if detail else SAMPLE_ID
|
187
|
+
df.drop_duplicates(subset=subset, ignore_index=True, inplace=True)
|
188
|
+
|
189
|
+
# recombine errors
|
190
|
+
errors: list[ColumnError] = list(
|
191
|
+
chain.from_iterable(e for e in error_results if e)
|
192
|
+
)
|
193
|
+
|
194
|
+
# return as required
|
195
|
+
if strict:
|
196
|
+
return df
|
197
|
+
else:
|
198
|
+
return df, errors
|
199
|
+
|
200
|
+
# non-parallel
|
201
|
+
else:
|
202
|
+
return _read_samples_df_serial(
|
203
|
+
logs=logs, columns=columns, strict=strict, detail=detail, progress=progress
|
204
|
+
)
|
205
|
+
|
206
|
+
|
207
|
+
def _read_samples_df_serial(
|
208
|
+
logs: list[str],
|
209
|
+
columns: Sequence[Column],
|
210
|
+
*,
|
211
|
+
strict: bool = True,
|
212
|
+
detail: MessagesDetail | EventsDetail | None = None,
|
213
|
+
progress: bool = True,
|
214
|
+
) -> "pd.DataFrame" | tuple["pd.DataFrame", list[ColumnError]]:
|
116
215
|
# split columns by type
|
117
216
|
columns_eval: list[Column] = []
|
118
217
|
columns_sample: list[Column] = []
|
@@ -141,52 +240,56 @@ def _read_samples_df(
|
|
141
240
|
)
|
142
241
|
|
143
242
|
# make sure eval_id is present
|
144
|
-
ensure_eval_id(columns_eval)
|
145
|
-
|
146
|
-
# determine how we will allocate progress
|
147
|
-
with import_progress("scanning logs", total=len(logs)) as (
|
148
|
-
p,
|
149
|
-
task_id,
|
150
|
-
):
|
243
|
+
columns_eval = list(ensure_eval_id(columns_eval))
|
151
244
|
|
152
|
-
|
153
|
-
|
245
|
+
# establish progress
|
246
|
+
progress_cm = (
|
247
|
+
import_progress("scanning logs", total=len(logs)) if progress else no_progress()
|
248
|
+
)
|
154
249
|
|
250
|
+
# determine how we will allocate progress
|
251
|
+
with progress_cm as p:
|
155
252
|
# read samples from each log
|
156
253
|
sample_records: list[dict[str, ColumnType]] = []
|
157
254
|
detail_records: list[dict[str, ColumnType]] = []
|
158
|
-
all_errors =
|
255
|
+
all_errors: list[ColumnError] = []
|
159
256
|
|
160
257
|
# read logs and note total samples
|
161
|
-
evals_table, total_samples = _read_evals_df(
|
162
|
-
logs, columns=columns_eval, strict=True, progress=
|
258
|
+
evals_table, eval_logs, total_samples = _read_evals_df(
|
259
|
+
logs, columns=columns_eval, strict=True, progress=p.update
|
163
260
|
)
|
164
261
|
|
165
262
|
# update progress now that we know the total samples
|
166
263
|
entity = detail.name if detail else "sample"
|
167
|
-
p.reset(
|
168
|
-
task_id, description=f"reading {entity}s", completed=0, total=total_samples
|
169
|
-
)
|
264
|
+
p.reset(description=f"reading {entity}s", completed=0, total=total_samples)
|
170
265
|
|
171
266
|
# read samples
|
172
|
-
for eval_id,
|
267
|
+
for eval_id, eval_log in zip(evals_table[EVAL_ID].to_list(), eval_logs):
|
173
268
|
# get a generator for the samples (might require reading the full log
|
174
269
|
# or might be fine to just read the summaries)
|
175
270
|
if require_full_samples:
|
176
271
|
samples: Generator[EvalSample | EvalSampleSummary, None, None] = (
|
177
272
|
read_eval_log_samples(
|
178
|
-
|
273
|
+
eval_log.location,
|
274
|
+
all_samples_required=False,
|
275
|
+
resolve_attachments=True,
|
179
276
|
)
|
180
277
|
)
|
181
278
|
else:
|
182
|
-
samples = (
|
279
|
+
samples = (
|
280
|
+
summary
|
281
|
+
for summary in read_eval_log_sample_summaries(eval_log.location)
|
282
|
+
)
|
183
283
|
for sample in samples:
|
184
284
|
if strict:
|
185
|
-
record = import_record(
|
285
|
+
record = import_record(
|
286
|
+
eval_log, sample, columns_sample, strict=True
|
287
|
+
)
|
186
288
|
else:
|
187
|
-
record, errors = import_record(
|
188
|
-
|
189
|
-
|
289
|
+
record, errors = import_record(
|
290
|
+
eval_log, sample, columns_sample, strict=False
|
291
|
+
)
|
292
|
+
all_errors.extend(errors)
|
190
293
|
|
191
294
|
# inject ids
|
192
295
|
sample_id = sample.uuid or auto_sample_id(eval_id, sample)
|
@@ -207,7 +310,11 @@ def _read_samples_df(
|
|
207
310
|
sample_messages_from_events(sample.events, detail.filter)
|
208
311
|
)
|
209
312
|
elif isinstance(detail, EventsDetail):
|
210
|
-
detail_items = [
|
313
|
+
detail_items = [
|
314
|
+
e
|
315
|
+
for e in sample.events
|
316
|
+
if detail.filter is None or detail.filter(e)
|
317
|
+
]
|
211
318
|
else:
|
212
319
|
detail_items = []
|
213
320
|
|
@@ -215,16 +322,13 @@ def _read_samples_df(
|
|
215
322
|
for index, item in enumerate(detail_items):
|
216
323
|
if strict:
|
217
324
|
detail_record = import_record(
|
218
|
-
item, columns_detail, strict=True
|
325
|
+
eval_log, item, columns_detail, strict=True
|
219
326
|
)
|
220
327
|
else:
|
221
328
|
detail_record, errors = import_record(
|
222
|
-
item, columns_detail, strict=False
|
223
|
-
)
|
224
|
-
error_key = (
|
225
|
-
f"{pretty_path(log)} [{sample.id}, {sample.epoch}]"
|
329
|
+
eval_log, item, columns_detail, strict=False
|
226
330
|
)
|
227
|
-
all_errors
|
331
|
+
all_errors.extend(errors)
|
228
332
|
|
229
333
|
# inject ids
|
230
334
|
detail_id = detail_record.get(
|
@@ -238,14 +342,20 @@ def _read_samples_df(
|
|
238
342
|
|
239
343
|
# record sample record
|
240
344
|
sample_records.append(record)
|
241
|
-
|
345
|
+
p.update()
|
242
346
|
|
243
347
|
# normalize records and produce samples table
|
244
348
|
samples_table = records_to_pandas(sample_records)
|
349
|
+
samples_table.drop_duplicates(
|
350
|
+
"sample_id", keep="first", inplace=True, ignore_index=True
|
351
|
+
)
|
245
352
|
|
246
353
|
# if we have detail records then join them into the samples table
|
247
354
|
if detail is not None:
|
248
355
|
details_table = records_to_pandas(detail_records)
|
356
|
+
details_table.drop_duplicates(
|
357
|
+
f"{detail.name}_id", keep="first", inplace=True, ignore_index=True
|
358
|
+
)
|
249
359
|
samples_table = details_table.merge(
|
250
360
|
samples_table,
|
251
361
|
on=SAMPLE_ID,
|
@@ -275,7 +385,7 @@ def _read_samples_df(
|
|
275
385
|
|
276
386
|
|
277
387
|
def sample_messages_from_events(
|
278
|
-
events: list[Event], filter: Callable[[ChatMessage], bool]
|
388
|
+
events: list[Event], filter: Callable[[ChatMessage], bool] | None
|
279
389
|
) -> list[ChatMessage]:
|
280
390
|
# don't yield the same event twice
|
281
391
|
ids: set[str] = set()
|
@@ -295,7 +405,7 @@ def sample_messages_from_events(
|
|
295
405
|
ids.add(id)
|
296
406
|
|
297
407
|
# then apply the filter
|
298
|
-
return [message for message in messages if filter(message)]
|
408
|
+
return [message for message in messages if filter is None or filter(message)]
|
299
409
|
|
300
410
|
|
301
411
|
@lru_cache(maxsize=100)
|
@@ -39,7 +39,7 @@ def verify_prerequisites() -> None:
|
|
39
39
|
raise pip_dependency_error("inspect_ai.analysis", required_packages)
|
40
40
|
|
41
41
|
# enforce version constraints
|
42
|
-
verify_required_version("inspect_ai.analysis", "pandas", "2.
|
42
|
+
verify_required_version("inspect_ai.analysis", "pandas", "2.1.0")
|
43
43
|
verify_required_version("inspect_ai.analysis", "pyarrow", "10.0.1")
|
44
44
|
|
45
45
|
|
@@ -141,20 +141,22 @@ def add_unreferenced_columns(
|
|
141
141
|
def records_to_pandas(records: list[dict[str, ColumnType]]) -> "pd.DataFrame":
|
142
142
|
import pyarrow as pa
|
143
143
|
|
144
|
+
# create arrow table
|
144
145
|
records = normalize_records(records)
|
145
|
-
table = pa.Table.from_pylist(records)
|
146
|
-
return table
|
146
|
+
table = pa.Table.from_pylist(records)
|
147
147
|
|
148
|
+
# convert arrow to pandas
|
149
|
+
df = table.to_pandas(types_mapper=arrow_types_mapper)
|
148
150
|
|
149
|
-
|
150
|
-
|
151
|
-
|
151
|
+
# swap numpy-backed nullable columns for arrow-backed equivalents
|
152
|
+
# df = df.convert_dtypes(dtype_backend="pyarrow")
|
153
|
+
return df
|
154
|
+
|
155
|
+
|
156
|
+
def arrow_types_mapper(arrow_type: pa.DataType) -> pd.ArrowDtype:
|
152
157
|
import pandas as pd
|
153
158
|
import pyarrow as pa
|
154
159
|
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
# default conversion for other types
|
159
|
-
else:
|
160
|
-
return None
|
160
|
+
if pa.types.is_null(arrow_type):
|
161
|
+
arrow_type = pa.string()
|
162
|
+
return pd.ArrowDtype(arrow_type)
|
inspect_ai/model/_call_tools.py
CHANGED
@@ -303,7 +303,7 @@ async def execute_tools(
|
|
303
303
|
)
|
304
304
|
result_messages.append(tool_message)
|
305
305
|
display_conversation_message(tool_message)
|
306
|
-
|
306
|
+
elif result is not None:
|
307
307
|
for message in result.messages:
|
308
308
|
result_messages.append(message)
|
309
309
|
display_conversation_message(message)
|
@@ -276,13 +276,25 @@ class AnthropicAPI(ModelAPI):
|
|
276
276
|
params = dict(model=self.service_model_name(), max_tokens=max_tokens)
|
277
277
|
headers: dict[str, str] = {}
|
278
278
|
betas: list[str] = []
|
279
|
-
|
280
|
-
|
281
|
-
|
279
|
+
|
280
|
+
# temperature not compatible with extended thinking
|
281
|
+
THINKING_WARNING = "anthropic models do not support the '{parameter}' parameter when using extended thinking."
|
282
|
+
if config.temperature is not None:
|
283
|
+
if self.is_using_thinking(config):
|
284
|
+
warn_once(logger, THINKING_WARNING.format(parameter="temperature"))
|
285
|
+
else:
|
282
286
|
params["temperature"] = config.temperature
|
283
|
-
|
287
|
+
# top_p not compatible with extended thinking
|
288
|
+
if config.top_p is not None:
|
289
|
+
if self.is_using_thinking(config):
|
290
|
+
warn_once(logger, THINKING_WARNING.format(parameter="top_p"))
|
291
|
+
else:
|
284
292
|
params["top_p"] = config.top_p
|
285
|
-
|
293
|
+
# top_k not compatible with extended thinking
|
294
|
+
if config.top_k is not None:
|
295
|
+
if self.is_using_thinking(config):
|
296
|
+
warn_once(logger, THINKING_WARNING.format(parameter="top_k"))
|
297
|
+
else:
|
286
298
|
params["top_k"] = config.top_k
|
287
299
|
|
288
300
|
# some thinking-only stuff
|
@@ -346,6 +358,7 @@ class AnthropicAPI(ModelAPI):
|
|
346
358
|
# for "overloaded_error" so we check for it explicitly
|
347
359
|
if (
|
348
360
|
isinstance(ex.body, dict)
|
361
|
+
and isinstance(ex.body.get("error", {}), dict)
|
349
362
|
and ex.body.get("error", {}).get("type", "") == "overloaded_error"
|
350
363
|
):
|
351
364
|
return True
|
@@ -138,6 +138,7 @@ class AzureAIAPI(ModelAPI):
|
|
138
138
|
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
|
139
139
|
# emulate tools (auto for llama, opt-in for others)
|
140
140
|
if self.emulate_tools is None and self.is_llama():
|
141
|
+
self.emulate_tools = True
|
141
142
|
handler: ChatAPIHandler | None = Llama31Handler(self.model_name)
|
142
143
|
elif self.emulate_tools:
|
143
144
|
handler = Llama31Handler(self.model_name)
|
@@ -151,10 +152,14 @@ class AzureAIAPI(ModelAPI):
|
|
151
152
|
# prepare request
|
152
153
|
request = dict(
|
153
154
|
messages=await chat_request_messages(input, handler),
|
154
|
-
tools=chat_tools(tools) if len(tools) > 0 else None,
|
155
|
-
tool_choice=chat_tool_choice(tool_choice) if len(tools) > 0 else None,
|
156
155
|
**self.completion_params(config),
|
157
156
|
)
|
157
|
+
# newer versions of vllm reject requests with tools or tool_choice if the
|
158
|
+
# server hasn't been started explicitly with the --tool-call-parser and
|
159
|
+
# --enable-auto-tool-choice flags
|
160
|
+
if (not self.emulate_tools) and len(tools) > 0:
|
161
|
+
request["tools"] = chat_tools(tools)
|
162
|
+
request["tool_choice"] = chat_tool_choice(tool_choice)
|
158
163
|
|
159
164
|
# create client (note the client needs to be created and closed
|
160
165
|
# with each call so it can be cleaned up and not end up on another
|
@@ -79,7 +79,7 @@ class Llama31Handler(ChatAPIHandler):
|
|
79
79
|
prompt that asks the model to use the <tool_call>...</tool_call> syntax)
|
80
80
|
"""
|
81
81
|
# extract tool calls
|
82
|
-
tool_call_regex = rf"<{TOOL_CALL}
|
82
|
+
tool_call_regex = rf"<{TOOL_CALL}s?>((?:.|\n)*?)</{TOOL_CALL}s?>"
|
83
83
|
tool_calls_content: list[str] = re.findall(tool_call_regex, response)
|
84
84
|
|
85
85
|
# if there are tool calls proceed with parsing
|
@@ -93,7 +93,7 @@ class Llama31Handler(ChatAPIHandler):
|
|
93
93
|
]
|
94
94
|
|
95
95
|
# find other content that exists outside tool calls
|
96
|
-
tool_call_content_regex = rf"<{TOOL_CALL}
|
96
|
+
tool_call_content_regex = rf"<{TOOL_CALL}s?>(?:.|\n)*?</{TOOL_CALL}s?>"
|
97
97
|
other_content = re.split(tool_call_content_regex, response, flags=re.DOTALL)
|
98
98
|
other_content = [
|
99
99
|
str(content).strip()
|
@@ -164,7 +164,7 @@ def parse_tool_call_content(content: str, tools: list[ToolInfo]) -> ToolCall:
|
|
164
164
|
# see if we can get the fields (if not report error)
|
165
165
|
name = tool_call_data.get("name", None)
|
166
166
|
arguments = tool_call_data.get("arguments", None)
|
167
|
-
if not name or
|
167
|
+
if not name or (arguments is None):
|
168
168
|
raise ValueError(
|
169
169
|
"Required 'name' and 'arguments' not provided in JSON dictionary."
|
170
170
|
)
|