inspect-ai 0.3.93__py3-none-any.whl → 0.3.95__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inspect_ai/_display/textual/widgets/samples.py +3 -3
- inspect_ai/_display/textual/widgets/transcript.py +3 -29
- inspect_ai/_eval/loader.py +1 -1
- inspect_ai/_eval/task/run.py +21 -12
- inspect_ai/_util/answer.py +26 -0
- inspect_ai/_util/constants.py +0 -1
- inspect_ai/_util/exception.py +4 -0
- inspect_ai/_util/hash.py +39 -0
- inspect_ai/_util/local_server.py +51 -21
- inspect_ai/_util/path.py +22 -0
- inspect_ai/_util/trace.py +1 -1
- inspect_ai/_util/working.py +4 -0
- inspect_ai/_view/www/dist/assets/index.css +23 -22
- inspect_ai/_view/www/dist/assets/index.js +517 -204
- inspect_ai/_view/www/log-schema.json +375 -0
- inspect_ai/_view/www/package.json +1 -1
- inspect_ai/_view/www/src/@types/log.d.ts +90 -12
- inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
- inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
- inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
- inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
- inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
- inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
- inspect_ai/_view/www/src/app/samples/transcript/SandboxEventView.module.css +2 -1
- inspect_ai/_view/www/src/app/samples/transcript/SpanEventView.tsx +174 -0
- inspect_ai/_view/www/src/app/samples/transcript/ToolEventView.tsx +8 -8
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptView.tsx +12 -2
- inspect_ai/_view/www/src/app/samples/transcript/TranscriptVirtualListComponent.module.css +1 -1
- inspect_ai/_view/www/src/app/samples/transcript/event/EventPanel.tsx +0 -3
- inspect_ai/_view/www/src/app/samples/transcript/transform/fixups.ts +87 -25
- inspect_ai/_view/www/src/app/samples/transcript/transform/treeify.ts +229 -17
- inspect_ai/_view/www/src/app/samples/transcript/transform/utils.ts +11 -0
- inspect_ai/_view/www/src/app/samples/transcript/types.ts +5 -1
- inspect_ai/_view/www/src/app/types.ts +12 -2
- inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
- inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
- inspect_ai/_view/www/src/state/hooks.ts +19 -3
- inspect_ai/_view/www/src/state/logSlice.ts +23 -5
- inspect_ai/_view/www/yarn.lock +9 -9
- inspect_ai/agent/_as_solver.py +3 -1
- inspect_ai/agent/_as_tool.py +6 -4
- inspect_ai/agent/_bridge/patch.py +1 -3
- inspect_ai/agent/_handoff.py +5 -1
- inspect_ai/agent/_react.py +4 -3
- inspect_ai/agent/_run.py +6 -1
- inspect_ai/agent/_types.py +9 -0
- inspect_ai/analysis/__init__.py +0 -0
- inspect_ai/analysis/beta/__init__.py +57 -0
- inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
- inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
- inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
- inspect_ai/analysis/beta/_dataframe/evals/table.py +140 -0
- inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/events/columns.py +37 -0
- inspect_ai/analysis/beta/_dataframe/events/table.py +14 -0
- inspect_ai/analysis/beta/_dataframe/extract.py +54 -0
- inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
- inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
- inspect_ai/analysis/beta/_dataframe/messages/table.py +87 -0
- inspect_ai/analysis/beta/_dataframe/record.py +377 -0
- inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
- inspect_ai/analysis/beta/_dataframe/samples/columns.py +73 -0
- inspect_ai/analysis/beta/_dataframe/samples/extract.py +82 -0
- inspect_ai/analysis/beta/_dataframe/samples/table.py +329 -0
- inspect_ai/analysis/beta/_dataframe/util.py +157 -0
- inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
- inspect_ai/dataset/_dataset.py +6 -3
- inspect_ai/log/__init__.py +10 -0
- inspect_ai/log/_convert.py +4 -9
- inspect_ai/log/_file.py +1 -1
- inspect_ai/log/_log.py +21 -1
- inspect_ai/log/_samples.py +14 -17
- inspect_ai/log/_transcript.py +77 -35
- inspect_ai/log/_tree.py +118 -0
- inspect_ai/model/_call_tools.py +44 -35
- inspect_ai/model/_model.py +51 -44
- inspect_ai/model/_openai_responses.py +17 -18
- inspect_ai/model/_providers/anthropic.py +30 -5
- inspect_ai/model/_providers/hf.py +27 -1
- inspect_ai/model/_providers/providers.py +1 -1
- inspect_ai/model/_providers/sglang.py +8 -2
- inspect_ai/model/_providers/vllm.py +6 -2
- inspect_ai/scorer/_choice.py +1 -2
- inspect_ai/solver/_chain.py +1 -1
- inspect_ai/solver/_fork.py +1 -1
- inspect_ai/solver/_multiple_choice.py +9 -23
- inspect_ai/solver/_plan.py +2 -2
- inspect_ai/solver/_task_state.py +7 -3
- inspect_ai/solver/_transcript.py +6 -7
- inspect_ai/tool/_mcp/_context.py +3 -5
- inspect_ai/tool/_mcp/_mcp.py +6 -5
- inspect_ai/tool/_mcp/server.py +1 -1
- inspect_ai/tool/_tools/_execute.py +4 -1
- inspect_ai/tool/_tools/_think.py +1 -1
- inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
- inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
- inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
- inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
- inspect_ai/util/__init__.py +4 -0
- inspect_ai/util/_anyio.py +11 -0
- inspect_ai/util/_collect.py +50 -0
- inspect_ai/util/_sandbox/events.py +3 -2
- inspect_ai/util/_span.py +58 -0
- inspect_ai/util/_subtask.py +27 -42
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/METADATA +8 -1
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/RECORD +114 -82
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/WHEEL +1 -1
- inspect_ai/_display/core/group.py +0 -79
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/entry_points.txt +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/licenses/LICENSE +0 -0
- {inspect_ai-0.3.93.dist-info → inspect_ai-0.3.95.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,82 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
|
3
|
+
from jsonpath_ng import JSONPath # type: ignore
|
4
|
+
from pydantic import JsonValue
|
5
|
+
|
6
|
+
from inspect_ai.analysis.beta._dataframe.extract import auto_id
|
7
|
+
from inspect_ai.log._log import EvalSample, EvalSampleSummary
|
8
|
+
from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool
|
9
|
+
|
10
|
+
|
11
|
+
def sample_messages_as_str(sample: EvalSample) -> str:
    """Render a sample's messages as a plain-text transcript.

    Each message becomes a ``"<role>:\\n<stripped text>\\n"`` entry. Assistant
    messages whose ``tool_calls`` is not ``None`` additionally list every call's
    function name and arguments (one ``key: value`` line per argument when the
    arguments are a dict). Tool messages carrying an error append the failing
    function name (``"unknown"`` if absent) and the error message.

    Args:
        sample: The sample whose ``messages`` are rendered.

    Returns:
        All entries joined with newlines.
    """
    entries: list[str] = []
    for message in sample.messages:
        body = message.text.strip() if message.text else ""
        header = f"{message.role}:\n{body}\n"

        if isinstance(message, ChatMessageAssistant) and message.tool_calls is not None:
            call_parts: list[str] = []
            for call in message.tool_calls:
                arguments = call.arguments
                if isinstance(arguments, dict):
                    rendered = "\n".join(
                        f"{key}: {value}" for key, value in arguments.items()
                    )
                    call_parts.append(
                        f"\nTool Call: {call.function}\nArguments:\n{rendered}"
                    )
                else:
                    call_parts.append(
                        f"\nTool Call: {call.function}\nArguments: {arguments}"
                    )
            entries.append(header + "".join(call_parts))
        elif isinstance(message, ChatMessageTool) and message.error is not None:
            function_name = message.function or "unknown"
            entries.append(
                f"{header}\nError in tool call '{function_name}':\n"
                f"{message.error.message}\n"
            )
        else:
            entries.append(header)

    return "\n".join(entries)
|
45
|
+
|
46
|
+
|
47
|
+
def sample_path_requires_full(
    path: str
    | JSONPath
    | Callable[[EvalSampleSummary], JsonValue]
    | Callable[[EvalSample], JsonValue],
) -> bool:
    """Decide whether extracting ``path`` needs a full ``EvalSample``.

    Callables never force a full read. String and JSONPath paths force one when
    they start with one of the field names listed below. A leading JSONPath
    root (``$.``) is ignored, so ``"$.messages"`` and ``"messages"`` are
    classified the same way.

    Args:
        path: A path string, a parsed JSONPath, or an extraction callable.

    Returns:
        ``True`` if the full sample must be read, ``False`` otherwise.
    """
    if callable(path):
        return False

    # fields whose paths cannot be served from the sample summary
    full_sample_prefixes = (
        "choices",
        "sandbox",
        "files",
        "setup",
        "messages",
        "output",
        "store",
        "events",
        "uuid",
        "error_retries",
        "attachments",
    )
    # normalize away an explicit JSONPath root so rooted paths match too
    path_text = str(path).removeprefix("$.")
    return path_text.startswith(full_sample_prefixes)
|
75
|
+
|
76
|
+
|
77
|
+
def auto_sample_id(eval_id: str, sample: EvalSample | EvalSampleSummary) -> str:
    """Derive a sample id from ``eval_id`` and the sample's id and epoch.

    Delegates to ``auto_id`` with ``"<sample.id>_<sample.epoch>"`` as the key.
    """
    sample_key = f"{sample.id}_{sample.epoch}"
    return auto_id(eval_id, sample_key)
|
79
|
+
|
80
|
+
|
81
|
+
def auto_detail_id(sample_id: str, name: str, index: int) -> str:
    """Derive an id for the ``index``-th detail item named ``name`` of a sample.

    Delegates to ``auto_id`` with ``"<name>_<index>"`` as the key.
    """
    detail_key = f"{name}_{index}"
    return auto_id(sample_id, detail_key)
|
@@ -0,0 +1,329 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import (
|
5
|
+
TYPE_CHECKING,
|
6
|
+
Callable,
|
7
|
+
Generator,
|
8
|
+
Literal,
|
9
|
+
overload,
|
10
|
+
)
|
11
|
+
|
12
|
+
from inspect_ai._display import display
|
13
|
+
from inspect_ai._util.path import pretty_path
|
14
|
+
from inspect_ai.analysis.beta._dataframe.events.columns import EventColumn
|
15
|
+
from inspect_ai.analysis.beta._dataframe.messages.columns import MessageColumn
|
16
|
+
from inspect_ai.log._file import (
|
17
|
+
read_eval_log_sample_summaries,
|
18
|
+
read_eval_log_samples,
|
19
|
+
)
|
20
|
+
from inspect_ai.log._log import EvalSample, EvalSampleSummary
|
21
|
+
from inspect_ai.log._transcript import BaseEvent, Event
|
22
|
+
from inspect_ai.model._chat_message import ChatMessage
|
23
|
+
|
24
|
+
from ..columns import Column, ColumnErrors, ColumnType
|
25
|
+
from ..evals.columns import EvalColumn
|
26
|
+
from ..evals.table import EVAL_ID, EVAL_SUFFIX, ensure_eval_id, evals_df
|
27
|
+
from ..record import import_record, resolve_duplicate_columns
|
28
|
+
from ..util import (
|
29
|
+
LogPaths,
|
30
|
+
add_unreferenced_columns,
|
31
|
+
records_to_pandas,
|
32
|
+
resolve_columns,
|
33
|
+
resolve_logs,
|
34
|
+
verify_prerequisites,
|
35
|
+
)
|
36
|
+
from .columns import SampleColumn, SampleSummary
|
37
|
+
from .extract import auto_detail_id, auto_sample_id
|
38
|
+
|
39
|
+
if TYPE_CHECKING:
|
40
|
+
import pandas as pd
|
41
|
+
|
42
|
+
|
43
|
+
# Name of the id column injected into every sample record (and every detail
# record, where it is the key used to join details back onto samples).
SAMPLE_ID = "sample_id"
# Suffix that pandas merges apply to overlapping column names coming from the
# sample side when samples are joined with detail or eval tables.
SAMPLE_SUFFIX = "_sample"
|
45
|
+
|
46
|
+
|
47
|
+
@overload
def samples_df(
    logs: LogPaths,
    columns: list[Column] = SampleSummary,
    recursive: bool = True,
    reverse: bool = False,
    strict: Literal[True] = True,
) -> "pd.DataFrame": ...


@overload
def samples_df(
    logs: LogPaths,
    columns: list[Column] = SampleSummary,
    recursive: bool = True,
    reverse: bool = False,
    strict: Literal[False] = False,
) -> tuple["pd.DataFrame", ColumnErrors]: ...


def samples_df(
    logs: LogPaths,
    columns: list[Column] = SampleSummary,
    recursive: bool = True,
    reverse: bool = False,
    strict: bool = True,
) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
    """Read a dataframe with one row per sample across a set of eval logs.

    Args:
        logs: One or more log files and/or log directories.
        columns: Which eval and sample columns to read.
        recursive: Descend into subdirectories when a directory is given
            (default `True`).
        reverse: Reverse the row order (by default rows run from the oldest
            log to the newest).
        strict: When `True` (the default), import errors are raised
            immediately. When `False`, errors are collected and returned
            alongside the `DataFrame`.

    Returns:
        With `strict=True`, a Pandas `DataFrame` describing the samples. With
        `strict=False`, a `(DataFrame, errors)` tuple where `errors` maps each
        log file to the import errors encountered while reading it.
    """
    # samples_df has no detail expansion; it is the plain sample reader
    return _read_samples_df(
        logs,
        columns,
        recursive=recursive,
        reverse=reverse,
        strict=strict,
    )
|
93
|
+
|
94
|
+
|
95
|
+
@dataclass
class MessagesDetail:
    """Detail spec that expands each sample into one row per chat message.

    Attributes:
        name: Detail name, used for the ``{name}_id`` column and as the merge
            suffix for overlapping detail columns.
        col_type: Column class treated as a detail column for this spec.
        filter: Predicate choosing which messages become rows.
    """

    name: str = "message"
    # unannotated, so this is a plain class attribute rather than a dataclass field
    col_type = MessageColumn
    # default predicate keeps every message
    filter: Callable[[ChatMessage], bool] = lambda _message: True
|
100
|
+
|
101
|
+
|
102
|
+
@dataclass
class EventsDetail:
    """Detail spec that expands each sample into one row per transcript event.

    Attributes:
        name: Detail name, used for the ``{name}_id`` column and as the merge
            suffix for overlapping detail columns. Defaults to ``"event"`` so
            event rows are not labeled as messages.
        col_type: Column class treated as a detail column for this spec.
        filter: Predicate choosing which events become rows.
    """

    name: str = "event"
    # unannotated, so this is a plain class attribute rather than a dataclass field
    col_type = EventColumn
    # default predicate keeps every event
    filter: Callable[[BaseEvent], bool] = lambda e: True
|
107
|
+
|
108
|
+
|
109
|
+
def _read_samples_df(
    logs: LogPaths,
    columns: list[Column],
    *,
    recursive: bool = True,
    reverse: bool = False,
    strict: bool = True,
    detail: MessagesDetail | EventsDetail | None = None,
) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
    """Read samples (optionally expanded into detail rows) into a DataFrame.

    Args:
        logs: One or more paths to log files or log directories.
        columns: Eval, sample, and (when ``detail`` is given) detail columns.
        recursive: Include recursive contents of directories.
        reverse: Reverse the log order (default is oldest to newest).
        strict: Raise import errors immediately. If ``False``, errors are
            collected and returned with the frame.
        detail: Optional messages/events detail spec. When given, the result
            has one row per detail item, joined with its sample and eval
            columns.

    Returns:
        The DataFrame (``strict=True``), or a tuple of the DataFrame and the
        ``ColumnErrors`` collected during import (``strict=False``).

    Raises:
        ValueError: If a column is not an eval, sample, or detail column.
    """
    verify_prerequisites()

    # resolve logs
    logs = resolve_logs(logs, recursive=recursive, reverse=reverse)

    # split columns by type
    columns_eval: list[Column] = []
    columns_sample: list[Column] = []
    columns_detail: list[Column] = []
    for column in columns:
        if isinstance(column, EvalColumn):
            columns_eval.append(column)
        elif isinstance(column, SampleColumn):
            columns_sample.append(column)
        elif detail is not None and isinstance(column, detail.col_type):
            columns_detail.append(column)
        else:
            raise ValueError(
                f"Unexpected column type passed to samples_df: {type(column)}"
            )

    # resolve duplicates
    columns_eval = resolve_duplicate_columns(columns_eval)
    columns_sample = resolve_duplicate_columns(columns_sample)
    columns_detail = resolve_duplicate_columns(columns_detail)

    # Detail expansion reads sample.messages / sample.events (asserted to be an
    # EvalSample below), so any detail spec needs full samples. This holds even
    # when no detail columns were passed.
    require_full_samples = detail is not None or any(
        isinstance(column, SampleColumn) and column._full for column in columns_sample
    )

    # make sure eval_id is present
    ensure_eval_id(columns_eval)

    # The loop below pairs eval ids with `logs` positionally (zip). evals_df is
    # presumably re-resolving/re-sorting the paths it is given, so pass the same
    # `reverse` to keep its row order aligned with `logs`.
    evals_table = evals_df(logs, columns=columns_eval, reverse=reverse)

    # read samples from each log
    sample_records: list[dict[str, ColumnType]] = []
    detail_records: list[dict[str, ColumnType]] = []
    all_errors = ColumnErrors()
    with display().progress(total=len(evals_table)) as p:
        for eval_id, log in zip(evals_table[EVAL_ID].to_list(), logs):
            # get a generator for the samples (might require reading the full
            # log or might be fine to just read the summaries)
            if require_full_samples:
                samples: Generator[EvalSample | EvalSampleSummary, None, None] = (
                    read_eval_log_samples(
                        log, all_samples_required=False, resolve_attachments=True
                    )
                )
            else:
                samples = (summary for summary in read_eval_log_sample_summaries(log))

            for sample in samples:
                if strict:
                    record = import_record(sample, columns_sample, strict=True)
                else:
                    record, errors = import_record(sample, columns_sample, strict=False)
                    error_key = f"{pretty_path(log)} [{sample.id}, {sample.epoch}]"
                    all_errors[error_key] = errors

                # inject ids
                sample_id = sample.uuid or auto_sample_id(eval_id, sample)
                ids: dict[str, ColumnType] = {
                    EVAL_ID: eval_id,
                    SAMPLE_ID: sample_id,
                }

                # record with ids
                record = ids | record

                # if there is a detail spec then we blow out these records w/ detail
                if detail is not None:
                    # filter detail items
                    assert isinstance(sample, EvalSample)
                    if isinstance(detail, MessagesDetail):
                        detail_items: list[ChatMessage] | list[Event] = [
                            m for m in sample.messages if detail.filter(m)
                        ]
                    elif isinstance(detail, EventsDetail):
                        detail_items = [e for e in sample.events if detail.filter(e)]
                    else:
                        detail_items = []

                    # read detail records (provide auto-ids)
                    for index, item in enumerate(detail_items):
                        if strict:
                            detail_record = import_record(
                                item, columns_detail, strict=True
                            )
                        else:
                            detail_record, errors = import_record(
                                item, columns_detail, strict=False
                            )
                            error_key = (
                                f"{pretty_path(log)} [{sample.id}, {sample.epoch}]"
                            )
                            # accumulate under the sample's key instead of
                            # overwriting the sample's (or a sibling item's) errors
                            all_errors[error_key] = (
                                list(all_errors.get(error_key, [])) + list(errors)
                            )

                        # inject ids (fall back to an auto id when "id" is
                        # missing or explicitly None)
                        detail_id = detail_record.get("id")
                        if detail_id is None:
                            detail_id = auto_detail_id(sample_id, detail.name, index)
                        ids = {SAMPLE_ID: sample_id, f"{detail.name}_id": detail_id}

                        # append detail record
                        detail_records.append(ids | detail_record)

                # record sample record
                sample_records.append(record)
            p.update()

    # Normalize records and produce the samples table. An empty record list
    # would give a frame with no columns, which cannot be merged on its ids.
    if sample_records:
        samples_table = records_to_pandas(sample_records)
    else:
        samples_table = _empty_id_frame([EVAL_ID, SAMPLE_ID])

    # if we have a detail spec then join detail rows onto the samples table
    if detail is not None:
        if detail_records:
            details_table = records_to_pandas(detail_records)
        else:
            details_table = _empty_id_frame([SAMPLE_ID, f"{detail.name}_id"])
        samples_table = details_table.merge(
            samples_table,
            on=SAMPLE_ID,
            how="left",
            suffixes=(f"_{detail.name}", SAMPLE_SUFFIX),
        )

    # join eval_records
    samples_table = samples_table.merge(
        evals_table, on=EVAL_ID, how="left", suffixes=(SAMPLE_SUFFIX, EVAL_SUFFIX)
    )

    # re-order based on original specification
    samples_table = reorder_samples_df_columns(
        samples_table,
        columns_eval,
        columns_sample,
        columns_detail,
        detail.name if detail else "",
    )

    # return
    if strict:
        return samples_table
    else:
        return samples_table, all_errors


def _empty_id_frame(id_columns: list[str]) -> "pd.DataFrame":
    """Return an empty DataFrame with string-typed ``id_columns`` so merges on them work."""
    import pandas as pd

    return pd.DataFrame({name: pd.Series(dtype="string") for name in id_columns})
|
263
|
+
|
264
|
+
|
265
|
+
def reorder_samples_df_columns(
    df: "pd.DataFrame",
    eval_columns: list[Column],
    sample_columns: list[Column],
    detail_columns: list[Column],
    details_name: str,
) -> "pd.DataFrame":
    """Reorder columns in the merged DataFrame.

    Order with:
      1. ``{details_name}_id`` (only when ``details_name`` is non-empty)
      2. sample_id
      3. eval_id
      4. eval columns
      5. sample columns
      6. detail columns
      7. any remaining columns, in the order ``add_unreferenced_columns`` gives

    Only columns actually present in ``df`` are selected. Wildcards and
    merge suffixes are handled by ``resolve_columns``.

    Args:
        df: The merged samples (and optional detail) frame.
        eval_columns: Eval column specs, in requested order.
        sample_columns: Sample column specs, in requested order.
        detail_columns: Detail column specs, in requested order.
        details_name: Detail name (e.g. ``"message"``) or ``""`` for none.

    Returns:
        ``df`` with its columns selected in the order above.
    """
    actual_columns = list(df.columns)
    ordered_columns: list[str] = []

    # leading id columns, skipping any that are absent so selection can't KeyError
    leading_ids = [SAMPLE_ID, EVAL_ID]
    if details_name:
        leading_ids.insert(0, f"{details_name}_id")
    ordered_columns.extend(c for c in leading_ids if c in actual_columns)

    # then eval, sample and detail columns, each with its merge suffix
    column_groups = (
        (eval_columns, EVAL_SUFFIX),
        (sample_columns, SAMPLE_SUFFIX),
        (detail_columns, f"_{details_name}"),
    )
    for group, suffix in column_groups:
        for column in group:
            if column.name in (EVAL_ID, SAMPLE_ID):
                continue  # already handled
            ordered_columns.extend(
                resolve_columns(column.name, suffix, actual_columns, ordered_columns)
            )

    # add any unreferenced columns
    ordered_columns = add_unreferenced_columns(actual_columns, ordered_columns)

    # reorder the DataFrame
    return df[ordered_columns]
|
@@ -0,0 +1,157 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import re
|
4
|
+
from os import PathLike
|
5
|
+
from pathlib import Path
|
6
|
+
from re import Pattern
|
7
|
+
from typing import TYPE_CHECKING, Sequence, TypeAlias
|
8
|
+
|
9
|
+
from inspect_ai._util.error import pip_dependency_error
|
10
|
+
from inspect_ai._util.file import FileInfo, filesystem
|
11
|
+
from inspect_ai._util.version import verify_required_version
|
12
|
+
from inspect_ai.log._file import log_files_from_ls
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
import pandas as pd
|
16
|
+
import pyarrow as pa
|
17
|
+
|
18
|
+
from .columns import ColumnType
|
19
|
+
|
20
|
+
# One or more log locations (files or directories), each given as a `str` or
# path-like object; `resolve_logs` expands directories into log files.
LogPaths: TypeAlias = PathLike[str] | str | Sequence[PathLike[str] | str]
|
21
|
+
|
22
|
+
|
23
|
+
def verify_prerequisites() -> None:
    """Ensure the optional packages needed by ``inspect_ai.analysis`` are usable.

    Raises the error built by ``pip_dependency_error`` if pandas and/or pyarrow
    cannot be imported. Then checks their versions against "2.0.0" and
    "10.0.1" via ``verify_required_version`` (presumably as minimums).
    """
    missing: list[str] = []

    try:
        import pandas  # noqa: F401
    except ImportError:
        missing.append("pandas")

    try:
        import pyarrow  # noqa: F401
    except ImportError:
        missing.append("pyarrow")

    if missing:
        raise pip_dependency_error("inspect_ai.analysis", missing)

    # enforce version constraints
    for package, version in (("pandas", "2.0.0"), ("pyarrow", "10.0.1")):
        verify_required_version("inspect_ai.analysis", package, version)
|
42
|
+
|
43
|
+
|
44
|
+
def resolve_logs(logs: LogPaths, recursive: bool, reverse: bool) -> list[str]:
    """Expand ``logs`` into a list of log file names.

    Path-like inputs are converted to POSIX strings. Directories are expanded
    into the files they contain (descending into subdirectories when
    ``recursive``). The collected entries are then passed through
    ``log_files_from_ls`` with ``descending=reverse``.

    Args:
        logs: A single path or a sequence of paths (files or directories).
        recursive: Whether directory listings include nested files.
        reverse: Passed to ``log_files_from_ls`` as ``descending``.

    Returns:
        The ``name`` of every log file reported by ``log_files_from_ls``.
    """
    # normalize to a list of str
    entries = [logs] if isinstance(logs, str | PathLike) else logs
    paths = [
        Path(entry).as_posix() if isinstance(entry, PathLike) else entry
        for entry in entries
    ]

    # expand directories
    file_infos: list[FileInfo] = []
    for path in paths:
        fs = filesystem(path)
        info = fs.info(path)
        if info.type != "directory":
            file_infos.append(info)
            continue
        file_infos.extend(
            entry_info
            for entry_info in fs.ls(info.name, recursive=recursive)
            if entry_info.type == "file"
        )

    ordered = log_files_from_ls(file_infos, descending=reverse)
    return [log_file.name for log_file in ordered]
|
69
|
+
|
70
|
+
|
71
|
+
def normalize_records(
    records: list[dict[str, ColumnType]],
) -> list[dict[str, ColumnType]]:
    """Give every record the same keys, filling missing values with ``None``.

    Keys are ordered by first appearance across ``records``. Collecting them in
    a ``set`` would make the key (and therefore column) order vary between
    interpreter runs, because string hashing is randomized.

    Args:
        records: Records that may have differing key sets.

    Returns:
        New dicts (inputs are not mutated), all with the same ordered keys.
    """
    all_keys = list(dict.fromkeys(key for record in records for key in record))
    return [{key: record.get(key) for key in all_keys} for record in records]
|
82
|
+
|
83
|
+
|
84
|
+
def resolve_columns(
    col_pattern: str, suffix: str, columns: list[str], processed_columns: list[str]
) -> list[str]:
    """Resolve a column name or ``*`` pattern against a frame's actual columns.

    For a plain name, the suffixed form (``col_pattern + suffix``) is preferred
    and the bare name is the fallback. Each is used only if it is present in
    ``columns`` and not already in ``processed_columns``. For a wildcard
    pattern, every unprocessed column matching the pattern with or without the
    suffix is returned, de-duplicated and sorted.

    Args:
        col_pattern: Column name, possibly containing ``*`` wildcards.
        suffix: Merge suffix that may have been appended to the column.
        columns: Columns actually present.
        processed_columns: Columns already placed; these are skipped.

    Returns:
        The resolved column names (possibly empty).
    """

    def is_available(name: str) -> bool:
        return name in columns and name not in processed_columns

    if "*" in col_pattern:
        matches = set(
            match_col_pattern(col_pattern + suffix, columns, processed_columns)
        )
        matches.update(match_col_pattern(col_pattern, columns, processed_columns))
        return sorted(matches)

    suffixed = f"{col_pattern}{suffix}"
    if is_available(suffixed):
        return [suffixed]
    if is_available(col_pattern):
        return [col_pattern]
    return []
|
112
|
+
|
113
|
+
|
114
|
+
def match_col_pattern(
    pattern: str, columns: list[str], processed_columns: list[str]
) -> list[str]:
    """Return the unprocessed ``columns`` that fully match wildcard ``pattern``.

    The original order of ``columns`` is preserved.
    """
    compiled = _col_pattern_to_regex(pattern)
    matched: list[str] = []
    for candidate in columns:
        if candidate not in processed_columns and compiled.match(candidate):
            matched.append(candidate)
    return matched
|
119
|
+
|
120
|
+
|
121
|
+
def _col_pattern_to_regex(pattern: str) -> Pattern[str]:
|
122
|
+
parts = []
|
123
|
+
for part in re.split(r"(\*)", pattern):
|
124
|
+
if part == "*":
|
125
|
+
parts.append(".*")
|
126
|
+
else:
|
127
|
+
parts.append(re.escape(part))
|
128
|
+
return re.compile("^" + "".join(parts) + "$")
|
129
|
+
|
130
|
+
|
131
|
+
def add_unreferenced_columns(
    columns: list[str], referenced_columns: list[str]
) -> list[str]:
    """Append every column not yet referenced, sorted by name.

    Membership is checked against a set, so the cost is linear rather than
    proportional to ``len(columns) * len(referenced_columns)``.

    Args:
        columns: All available column names.
        referenced_columns: Columns already placed, in their desired order.

    Returns:
        A new list: ``referenced_columns`` followed by the remaining columns in
        sorted order. Neither argument is mutated.
    """
    referenced = set(referenced_columns)
    unreferenced_columns = sorted(c for c in columns if c not in referenced)
    return referenced_columns + unreferenced_columns
|
136
|
+
|
137
|
+
|
138
|
+
def records_to_pandas(records: list[dict[str, ColumnType]]) -> "pd.DataFrame":
    """Convert records into a pandas DataFrame by way of a pyarrow Table.

    Records are first normalized to a common key set. Arrow types are then
    mapped by ``arrow_types_mapper`` during conversion.
    """
    import pyarrow as pa

    normalized = normalize_records(records)
    arrow_table = pa.Table.from_pylist(normalized)
    return arrow_table.to_pandas(types_mapper=arrow_types_mapper)
|
144
|
+
|
145
|
+
|
146
|
+
def arrow_types_mapper(
    arrow_type: "pa.DataType",
) -> "pd.api.extensions.ExtensionDtype" | None:
    """Map arrow string types to pandas ``StringDtype``.

    For every other type it returns ``None``, which leaves the default
    conversion in place.
    """
    import pandas as pd
    import pyarrow as pa

    return pd.StringDtype() if pa.types.is_string(arrow_type) else None
|