inspect-ai 0.3.94__py3-none-any.whl → 0.3.96__py3-none-any.whl

This diff reflects the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (75)
  1. inspect_ai/_eval/loader.py +1 -1
  2. inspect_ai/_eval/task/run.py +12 -6
  3. inspect_ai/_util/exception.py +4 -0
  4. inspect_ai/_util/hash.py +39 -0
  5. inspect_ai/_util/local_server.py +16 -0
  6. inspect_ai/_util/path.py +22 -0
  7. inspect_ai/_util/trace.py +1 -1
  8. inspect_ai/_util/working.py +4 -0
  9. inspect_ai/_view/www/dist/assets/index.css +9 -9
  10. inspect_ai/_view/www/dist/assets/index.js +117 -120
  11. inspect_ai/_view/www/package.json +1 -1
  12. inspect_ai/_view/www/src/app/log-view/navbar/SecondaryBar.tsx +2 -2
  13. inspect_ai/_view/www/src/app/log-view/tabs/SamplesTab.tsx +1 -4
  14. inspect_ai/_view/www/src/app/samples/SamplesTools.tsx +3 -13
  15. inspect_ai/_view/www/src/app/samples/sample-tools/SelectScorer.tsx +45 -48
  16. inspect_ai/_view/www/src/app/samples/sample-tools/filters.ts +16 -15
  17. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/SampleFilter.tsx +47 -75
  18. inspect_ai/_view/www/src/app/samples/sample-tools/sample-filter/completions.ts +9 -9
  19. inspect_ai/_view/www/src/app/types.ts +12 -2
  20. inspect_ai/_view/www/src/components/ExpandablePanel.module.css +1 -1
  21. inspect_ai/_view/www/src/components/ExpandablePanel.tsx +5 -5
  22. inspect_ai/_view/www/src/state/hooks.ts +19 -3
  23. inspect_ai/_view/www/src/state/logSlice.ts +23 -5
  24. inspect_ai/_view/www/yarn.lock +9 -9
  25. inspect_ai/agent/_bridge/patch.py +1 -3
  26. inspect_ai/agent/_types.py +1 -1
  27. inspect_ai/analysis/__init__.py +0 -0
  28. inspect_ai/analysis/beta/__init__.py +67 -0
  29. inspect_ai/analysis/beta/_dataframe/__init__.py +0 -0
  30. inspect_ai/analysis/beta/_dataframe/columns.py +145 -0
  31. inspect_ai/analysis/beta/_dataframe/evals/__init__.py +0 -0
  32. inspect_ai/analysis/beta/_dataframe/evals/columns.py +132 -0
  33. inspect_ai/analysis/beta/_dataframe/evals/extract.py +23 -0
  34. inspect_ai/analysis/beta/_dataframe/evals/table.py +177 -0
  35. inspect_ai/analysis/beta/_dataframe/events/__init__.py +0 -0
  36. inspect_ai/analysis/beta/_dataframe/events/columns.py +87 -0
  37. inspect_ai/analysis/beta/_dataframe/events/extract.py +26 -0
  38. inspect_ai/analysis/beta/_dataframe/events/table.py +100 -0
  39. inspect_ai/analysis/beta/_dataframe/extract.py +73 -0
  40. inspect_ai/analysis/beta/_dataframe/messages/__init__.py +0 -0
  41. inspect_ai/analysis/beta/_dataframe/messages/columns.py +60 -0
  42. inspect_ai/analysis/beta/_dataframe/messages/extract.py +21 -0
  43. inspect_ai/analysis/beta/_dataframe/messages/table.py +79 -0
  44. inspect_ai/analysis/beta/_dataframe/progress.py +26 -0
  45. inspect_ai/analysis/beta/_dataframe/record.py +377 -0
  46. inspect_ai/analysis/beta/_dataframe/samples/__init__.py +0 -0
  47. inspect_ai/analysis/beta/_dataframe/samples/columns.py +77 -0
  48. inspect_ai/analysis/beta/_dataframe/samples/extract.py +54 -0
  49. inspect_ai/analysis/beta/_dataframe/samples/table.py +370 -0
  50. inspect_ai/analysis/beta/_dataframe/util.py +160 -0
  51. inspect_ai/analysis/beta/_dataframe/validate.py +171 -0
  52. inspect_ai/log/_file.py +10 -3
  53. inspect_ai/log/_log.py +21 -1
  54. inspect_ai/model/_call_tools.py +2 -1
  55. inspect_ai/model/_model.py +6 -4
  56. inspect_ai/model/_openai_responses.py +17 -18
  57. inspect_ai/model/_providers/anthropic.py +30 -5
  58. inspect_ai/model/_providers/providers.py +1 -1
  59. inspect_ai/solver/_multiple_choice.py +4 -1
  60. inspect_ai/solver/_task_state.py +8 -4
  61. inspect_ai/tool/_mcp/_context.py +3 -5
  62. inspect_ai/tool/_mcp/_sandbox.py +17 -14
  63. inspect_ai/tool/_mcp/server.py +1 -1
  64. inspect_ai/tool/_tools/_think.py +1 -1
  65. inspect_ai/tool/_tools/_web_search/__init__.py +3 -0
  66. inspect_ai/tool/_tools/{_web_search.py → _web_search/_google.py} +56 -103
  67. inspect_ai/tool/_tools/_web_search/_tavily.py +77 -0
  68. inspect_ai/tool/_tools/_web_search/_web_search.py +85 -0
  69. inspect_ai/util/_sandbox/events.py +3 -2
  70. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/METADATA +9 -2
  71. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/RECORD +75 -46
  72. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/WHEEL +1 -1
  73. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/entry_points.txt +0 -0
  74. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/licenses/LICENSE +0 -0
  75. {inspect_ai-0.3.94.dist-info → inspect_ai-0.3.96.dist-info}/top_level.txt +0 -0
inspect_ai/analysis/beta/_dataframe/samples/table.py (new file, +370)
@@ -0,0 +1,370 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import (
+     TYPE_CHECKING,
+     Callable,
+     Generator,
+     Literal,
+     overload,
+ )
+
+ from inspect_ai._util.hash import mm3_hash
+ from inspect_ai._util.path import pretty_path
+ from inspect_ai.analysis.beta._dataframe.progress import import_progress
+ from inspect_ai.log._file import (
+     list_eval_logs,
+     read_eval_log_sample_summaries,
+     read_eval_log_samples,
+ )
+ from inspect_ai.log._log import EvalSample, EvalSampleSummary
+ from inspect_ai.log._transcript import Event
+ from inspect_ai.model._chat_message import ChatMessage
+
+ from ..columns import Column, ColumnErrors, ColumnType
+ from ..evals.columns import EvalColumn
+ from ..evals.table import EVAL_ID, EVAL_SUFFIX, _read_evals_df, ensure_eval_id
+ from ..events.columns import EventColumn
+ from ..extract import message_as_str
+ from ..messages.columns import MessageColumn
+ from ..record import import_record, resolve_duplicate_columns
+ from ..util import (
+     LogPaths,
+     add_unreferenced_columns,
+     records_to_pandas,
+     resolve_columns,
+     resolve_logs,
+     verify_prerequisites,
+ )
+ from .columns import SampleColumn, SampleSummary
+ from .extract import auto_detail_id, auto_sample_id
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ SAMPLE_ID = "sample_id"
+ SAMPLE_SUFFIX = "_sample"
+
+
+ @overload
+ def samples_df(
+     logs: LogPaths = list_eval_logs(),
+     columns: list[Column] = SampleSummary,
+     strict: Literal[True] = True,
+ ) -> "pd.DataFrame": ...
+
+
+ @overload
+ def samples_df(
+     logs: LogPaths = list_eval_logs(),
+     columns: list[Column] = SampleSummary,
+     strict: Literal[False] = False,
+ ) -> tuple["pd.DataFrame", ColumnErrors]: ...
+
+
+ def samples_df(
+     logs: LogPaths = list_eval_logs(),
+     columns: list[Column] = SampleSummary,
+     strict: bool = True,
+ ) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
+     """Read a dataframe containing samples from a set of evals.
+
+     Args:
+         logs: One or more paths to log files or log directories.
+             Defaults to the contents of the currently active log directory
+             (e.g. ./logs or INSPECT_LOG_DIR).
+         columns: Specification for what columns to read from log files.
+         strict: Raise import errors immediately. Defaults to `True`.
+             If `False` then a tuple of `DataFrame` and errors is returned.
+
+     Returns:
+         For `strict=True`, a Pandas `DataFrame` with information for the specified logs.
+         For `strict=False`, a tuple of Pandas `DataFrame` and a dictionary of errors
+         encountered (by log file) during import.
+     """
+     return _read_samples_df(logs, columns, strict=strict)
+
+
+ @dataclass
+ class MessagesDetail:
+     name: str = "message"
+     col_type = MessageColumn
+     filter: Callable[[ChatMessage], bool] = lambda m: True
+
+
+ @dataclass
+ class EventsDetail:
+     name: str = "event"
+     col_type = EventColumn
+     filter: Callable[[Event], bool] = lambda e: True
+
+
+ def _read_samples_df(
+     logs: LogPaths,
+     columns: list[Column],
+     *,
+     strict: bool = True,
+     detail: MessagesDetail | EventsDetail | None = None,
+ ) -> "pd.DataFrame" | tuple["pd.DataFrame", ColumnErrors]:
+     verify_prerequisites()
+
+     # resolve logs
+     logs = resolve_logs(logs)
+
+     # split columns by type
+     columns_eval: list[Column] = []
+     columns_sample: list[Column] = []
+     columns_detail: list[Column] = []
+     for column in columns:
+         if isinstance(column, EvalColumn):
+             columns_eval.append(column)
+         elif isinstance(column, SampleColumn):
+             columns_sample.append(column)
+             if column._full:
+                 require_full_samples = True
+         elif detail and isinstance(column, detail.col_type):
+             columns_detail.append(column)
+         else:
+             raise ValueError(
+                 f"Unexpected column type passed to samples_df: {type(column)}"
+             )
+     # resolve duplicates
+     columns_eval = resolve_duplicate_columns(columns_eval)
+     columns_sample = resolve_duplicate_columns(columns_sample)
+     columns_detail = resolve_duplicate_columns(columns_detail)
+
+     # determine if we require full samples
+     require_full_samples = len(columns_detail) > 0 or any(
+         [isinstance(column, SampleColumn) and column._full for column in columns_sample]
+     )
+
+     # make sure eval_id is present
+     ensure_eval_id(columns_eval)
+
+     # determine how we will allocate progress
+     with import_progress("scanning logs", total=len(logs)) as (
+         p,
+         task_id,
+     ):
+
+         def progress() -> None:
+             p.update(task_id, advance=1)
+
+         # read samples from each log
+         sample_records: list[dict[str, ColumnType]] = []
+         detail_records: list[dict[str, ColumnType]] = []
+         all_errors = ColumnErrors()
+
+         # read logs and note total samples
+         evals_table, total_samples = _read_evals_df(
+             logs, columns=columns_eval, strict=True, progress=progress
+         )
+
+         # update progress now that we know the total samples
+         entity = detail.name if detail else "sample"
+         p.reset(
+             task_id, description=f"reading {entity}s", completed=0, total=total_samples
+         )
+
+         # read samples
+         for eval_id, log in zip(evals_table[EVAL_ID].to_list(), logs):
+             # get a generator for the samples (might require reading the full log
+             # or might be fine to just read the summaries)
+             if require_full_samples:
+                 samples: Generator[EvalSample | EvalSampleSummary, None, None] = (
+                     read_eval_log_samples(
+                         log, all_samples_required=False, resolve_attachments=True
+                     )
+                 )
+             else:
+                 samples = (summary for summary in read_eval_log_sample_summaries(log))
+             for sample in samples:
+                 if strict:
+                     record = import_record(sample, columns_sample, strict=True)
+                 else:
+                     record, errors = import_record(sample, columns_sample, strict=False)
+                     error_key = f"{pretty_path(log)} [{sample.id}, {sample.epoch}]"
+                     all_errors[error_key] = errors
+
+                 # inject ids
+                 sample_id = sample.uuid or auto_sample_id(eval_id, sample)
+                 ids: dict[str, ColumnType] = {
+                     EVAL_ID: eval_id,
+                     SAMPLE_ID: sample_id,
+                 }
+
+                 # record with ids
+                 record = ids | record
+
+                 # if there are detail columns then we blow out these records w/ detail
+                 if detail is not None:
+                     # filter detail records
+                     assert isinstance(sample, EvalSample)
+                     if isinstance(detail, MessagesDetail):
+                         detail_items: list[ChatMessage] | list[Event] = (
+                             sample_messages_from_events(sample.events, detail.filter)
+                         )
+                     elif isinstance(detail, EventsDetail):
+                         detail_items = [e for e in sample.events if detail.filter(e)]
+                     else:
+                         detail_items = []
+
+                     # read detail records (provide auto-ids)
+                     for index, item in enumerate(detail_items):
+                         if strict:
+                             detail_record = import_record(
+                                 item, columns_detail, strict=True
+                             )
+                         else:
+                             detail_record, errors = import_record(
+                                 item, columns_detail, strict=False
+                             )
+                             error_key = (
+                                 f"{pretty_path(log)} [{sample.id}, {sample.epoch}]"
+                             )
+                             all_errors[error_key] = errors
+
+                         # inject ids
+                         detail_id = detail_record.get(
+                             "id", auto_detail_id(sample_id, detail.name, index)
+                         )
+                         ids = {SAMPLE_ID: sample_id, f"{detail.name}_id": detail_id}
+                         detail_record = ids | detail_record
+
+                         # append detail record
+                         detail_records.append(detail_record)
+
+                 # record sample record
+                 sample_records.append(record)
+                 progress()
+
+         # normalize records and produce samples table
+         samples_table = records_to_pandas(sample_records)
+
+         # if we have detail records then join them into the samples table
+         if detail is not None:
+             details_table = records_to_pandas(detail_records)
+             samples_table = details_table.merge(
+                 samples_table,
+                 on=SAMPLE_ID,
+                 how="left",
+                 suffixes=(f"_{detail.name}", SAMPLE_SUFFIX),
+             )
+
+         # join eval_records
+         samples_table = samples_table.merge(
+             evals_table, on=EVAL_ID, how="left", suffixes=(SAMPLE_SUFFIX, EVAL_SUFFIX)
+         )
+
+         # re-order based on original specification
+         samples_table = reorder_samples_df_columns(
+             samples_table,
+             columns_eval,
+             columns_sample,
+             columns_detail,
+             detail.name if detail else "",
+         )
+
+         # return
+         if strict:
+             return samples_table
+         else:
+             return samples_table, all_errors
+
+
+ def sample_messages_from_events(
+     events: list[Event], filter: Callable[[ChatMessage], bool]
+ ) -> list[ChatMessage]:
+     # don't yield the same event twice
+     ids: set[str] = set()
+
+     # we need to look at the full input to every model event and add
+     # messages we haven't seen before
+     messages: list[ChatMessage] = []
+     for event in events:
+         if event.event == "model":
+             event_messages = event.input + (
+                 [event.output.message] if not event.output.empty else []
+             )
+             for message in event_messages:
+                 id = message.id or message_hash(message_as_str(message))
+                 if id not in ids:
+                     messages.append(message)
+                     ids.add(id)
+
+     # then apply the filter
+     return [message for message in messages if filter(message)]
+
+
+ @lru_cache(maxsize=100)
+ def message_hash(message: str) -> str:
+     return mm3_hash(message)
+
+
+ def reorder_samples_df_columns(
+     df: "pd.DataFrame",
+     eval_columns: list[Column],
+     sample_columns: list[Column],
+     detail_columns: list[Column],
+     details_name: str,
+ ) -> "pd.DataFrame":
+     """Reorder columns in the merged DataFrame.
+
+     Order with:
+     1. sample_id first
+     2. eval_id second
+     3. eval columns
+     4. sample columns
+     5. any remaining columns
+     """
+     actual_columns = list(df.columns)
+     ordered_columns: list[str] = []
+
+     # detail first if we have detail
+     if details_name:
+         ordered_columns.append(f"{details_name}_id")
+
+     # sample_id first
+     if SAMPLE_ID in actual_columns:
+         ordered_columns.append(SAMPLE_ID)
+
+     # eval_id next
+     if EVAL_ID in actual_columns:
+         ordered_columns.append(EVAL_ID)
+
+     # eval columns
+     for column in eval_columns:
+         if column.name == EVAL_ID or column.name == SAMPLE_ID:
+             continue  # Already handled
+
+         ordered_columns.extend(
+             resolve_columns(column.name, EVAL_SUFFIX, actual_columns, ordered_columns)
+         )
+
+     # then sample columns
+     for column in sample_columns:
+         if column.name == EVAL_ID or column.name == SAMPLE_ID:
+             continue  # Already handled
+
+         ordered_columns.extend(
+             resolve_columns(column.name, SAMPLE_SUFFIX, actual_columns, ordered_columns)
+         )
+
+     # then detail columns
+     for column in detail_columns:
+         if column.name == EVAL_ID or column.name == SAMPLE_ID:
+             continue  # Already handled
+
+         ordered_columns.extend(
+             resolve_columns(
+                 column.name, f"_{details_name}", actual_columns, ordered_columns
+             )
+         )
+
+     # add any unreferenced columns
+     ordered_columns = add_unreferenced_columns(actual_columns, ordered_columns)
+
+     # reorder the DataFrame
+     return df[ordered_columns]
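For orientation (not part of the diff): a minimal usage sketch of the new samples_df() API, assuming pandas and pyarrow are installed, logs exist under ./logs, and samples_df is re-exported from inspect_ai.analysis.beta as the new beta __init__.py suggests.

    # usage sketch only; assumes samples_df is re-exported from inspect_ai.analysis.beta
    from inspect_ai.analysis.beta import samples_df

    # strict=True (the default) raises on the first import error and returns a DataFrame
    df = samples_df(logs="./logs")
    print(df[["sample_id", "eval_id"]].head())

    # strict=False returns a (DataFrame, ColumnErrors) tuple instead of raising
    df, errors = samples_df(logs="./logs", strict=False)
    print(errors)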
inspect_ai/analysis/beta/_dataframe/util.py (new file, +160)
@@ -0,0 +1,160 @@
+ from __future__ import annotations
+
+ import re
+ from os import PathLike
+ from pathlib import Path
+ from re import Pattern
+ from typing import TYPE_CHECKING, Sequence, TypeAlias
+
+ from inspect_ai._util.error import pip_dependency_error
+ from inspect_ai._util.file import FileInfo, filesystem
+ from inspect_ai._util.version import verify_required_version
+ from inspect_ai.log._file import EvalLogInfo, log_files_from_ls
+
+ if TYPE_CHECKING:
+     import pandas as pd
+     import pyarrow as pa
+
+ from .columns import ColumnType
+
+ LogPaths: TypeAlias = (
+     PathLike[str] | str | EvalLogInfo | Sequence[PathLike[str] | str | EvalLogInfo]
+ )
+
+
+ def verify_prerequisites() -> None:
+     # ensure we have all of the optional packages we need
+     required_packages: list[str] = []
+     try:
+         import pandas  # noqa: F401
+     except ImportError:
+         required_packages.append("pandas")
+
+     try:
+         import pyarrow  # noqa: F401
+     except ImportError:
+         required_packages.append("pyarrow")
+
+     if len(required_packages) > 0:
+         raise pip_dependency_error("inspect_ai.analysis", required_packages)
+
+     # enforce version constraints
+     verify_required_version("inspect_ai.analysis", "pandas", "2.0.0")
+     verify_required_version("inspect_ai.analysis", "pyarrow", "10.0.1")
+
+
+ def resolve_logs(logs: LogPaths) -> list[str]:
+     # normalize to list of str
+     logs = [logs] if isinstance(logs, str | PathLike | EvalLogInfo) else logs
+     logs_str = [
+         Path(log).as_posix()
+         if isinstance(log, PathLike)
+         else log.name
+         if isinstance(log, EvalLogInfo)
+         else log
+         for log in logs
+     ]
+
+     # expand directories
+     log_paths: list[FileInfo] = []
+     for log_str in logs_str:
+         fs = filesystem(log_str)
+         info = fs.info(log_str)
+         if info.type == "directory":
+             log_paths.extend(
+                 [fi for fi in fs.ls(info.name, recursive=True) if fi.type == "file"]
+             )
+         else:
+             log_paths.append(info)
+
+     log_files = log_files_from_ls(log_paths, sort=False)
+     return [log_file.name for log_file in log_files]
+
+
+ def normalize_records(
+     records: list[dict[str, ColumnType]],
+ ) -> list[dict[str, ColumnType]]:
+     all_keys: set[str] = set()
+     for record in records:
+         all_keys.update(record.keys())
+     normalized_records = []
+     for record in records:
+         normalized_record = {key: record.get(key, None) for key in all_keys}
+         normalized_records.append(normalized_record)
+     return normalized_records
+
+
+ def resolve_columns(
+     col_pattern: str, suffix: str, columns: list[str], processed_columns: list[str]
+ ) -> list[str]:
+     resolved_columns: list[str] = []
+
+     if "*" not in col_pattern:
+         # Regular column - check with suffix
+         col_with_suffix = f"{col_pattern}{suffix}"
+         if col_with_suffix in columns and col_with_suffix not in processed_columns:
+             resolved_columns.append(col_with_suffix)
+         # Then without suffix
+         elif col_pattern in columns and col_pattern not in processed_columns:
+             resolved_columns.append(col_pattern)
+     else:
+         # Wildcard pattern - check both with and without suffix
+         suffix_pattern = col_pattern + suffix
+         matching_with_suffix = match_col_pattern(
+             suffix_pattern, columns, processed_columns
+         )
+         matching_without_suffix = match_col_pattern(
+             col_pattern, columns, processed_columns
+         )
+
+         # Add all matches
+         matched_columns = sorted(set(matching_with_suffix + matching_without_suffix))
+         resolved_columns.extend(matched_columns)
+
+     return resolved_columns
+
+
+ def match_col_pattern(
+     pattern: str, columns: list[str], processed_columns: list[str]
+ ) -> list[str]:
+     regex = _col_pattern_to_regex(pattern)
+     return [c for c in columns if regex.match(c) and c not in processed_columns]
+
+
+ def _col_pattern_to_regex(pattern: str) -> Pattern[str]:
+     parts = []
+     for part in re.split(r"(\*)", pattern):
+         if part == "*":
+             parts.append(".*")
+         else:
+             parts.append(re.escape(part))
+     return re.compile("^" + "".join(parts) + "$")
+
+
+ def add_unreferenced_columns(
+     columns: list[str], referenced_columns: list[str]
+ ) -> list[str]:
+     unreferenced_columns = sorted([c for c in columns if c not in referenced_columns])
+     return referenced_columns + unreferenced_columns
+
+
+ def records_to_pandas(records: list[dict[str, ColumnType]]) -> "pd.DataFrame":
+     import pyarrow as pa
+
+     records = normalize_records(records)
+     table = pa.Table.from_pylist(records).to_pandas(types_mapper=arrow_types_mapper)
+     return table
+
+
+ def arrow_types_mapper(
+     arrow_type: "pa.DataType",
+ ) -> "pd.api.extensions.ExtensionDtype" | None:
+     import pandas as pd
+     import pyarrow as pa
+
+     # convert str => str
+     if pa.types.is_string(arrow_type):
+         return pd.StringDtype()
+     # default conversion for other types
+     else:
+         return None
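A quick illustration of the wildcard handling in resolve_columns() above; the column names are made up and the module path is internal, so treat this as a reading aid rather than a public API.

    # illustration only; column names below are hypothetical
    from inspect_ai.analysis.beta._dataframe.util import resolve_columns

    columns = ["score_choice", "score_headline_value_sample", "eval_id"]

    # exact names are tried with the suffix first, then without it
    print(resolve_columns("score_choice", "_sample", columns, []))
    # -> ['score_choice']

    # wildcard patterns match both suffixed and unsuffixed column names
    print(resolve_columns("score_*", "_sample", columns, []))
    # -> ['score_choice', 'score_headline_value_sample']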
inspect_ai/analysis/beta/_dataframe/validate.py (new file, +171)
@@ -0,0 +1,171 @@
+ from __future__ import annotations
+
+ from logging import getLogger
+ from typing import Any, Iterator, Mapping, Type
+
+ import jsonref  # type: ignore
+ from jsonpath_ng import Fields, Index, JSONPath, Slice, Where, WhereNot  # type: ignore
+ from jsonpath_ng.ext.filter import Filter  # type: ignore
+ from pydantic import BaseModel
+
+ logger = getLogger(__name__)
+
+ Schema = Mapping[str, Any]
+
+
+ def resolved_schema(model: Type[BaseModel]) -> Schema:
+     schema_dict = model.model_json_schema()
+     base = "file:///memory/inspect_schema.json"
+     schema: Schema = jsonref.replace_refs(
+         schema_dict, base_uri=base, jsonschema=True, proxies=False
+     )
+     return schema
+
+
+ def jsonpath_in_schema(expr: JSONPath, schema: Schema) -> bool:
+     # don't validate unsupported constructs
+     if find_unsupported(expr):
+         return True
+
+     def descend(sch: Schema, tok: str | int | None) -> list[Schema]:
+         # First, branch through anyOf/oneOf/allOf
+         outs: list[Schema] = []
+         for branch in _expand_union(sch):
+             outs.extend(descend_concrete(branch, tok))
+         return outs
+
+     def descend_concrete(sch: Schema, tok: str | int | None) -> list[Schema]:
+         # totally open object – accept any child
+         if sch == {}:
+             return [{}]  # stay alive, accept any key
+
+         outs: list[Schema] = []
+
+         def open_dict(node: Schema) -> None:
+             """Append the schema that governs unknown keys.
+
+             - None / missing -> open object -> {}
+             - True -> open object -> {}
+             - Mapping -> that mapping (could be {} or a real subschema)
+             - False -> closed object -> (do nothing)
+             """
+             if "additionalProperties" not in node:
+                 if not node.get("properties"):
+                     outs.append({})
+             else:
+                 ap = node["additionalProperties"]
+                 if ap is True:
+                     outs.append({})
+                 elif isinstance(ap, Mapping):  # {} or {...}
+                     outs.append(ap)
+                 # ap is False -> closed dict -> ignore
+
+         # Wildcard -----------------------------------------------------------
+         if tok is None:
+             if "properties" in sch:
+                 outs.extend(sch["properties"].values())
+             if "object" in _types(sch):
+                 open_dict(sch)
+             if "array" in _types(sch) and "items" in sch:
+                 outs.extend(_normalize_items(sch["items"]))
+             return outs
+
+         # Property access ----------------------------------------------------
+         if isinstance(tok, str):
+             if "properties" in sch and tok in sch["properties"]:
+                 outs.append(sch["properties"][tok])
+             elif "additionalProperties" in sch:  # PRESENCE, not truthiness
+                 open_dict(sch)
+             elif "object" in _types(sch):
+                 open_dict(sch)
+
+         # Array index --------------------------------------------------------
+         else:  # tok is int or None from an Index node
+             if "array" in _types(sch) and "items" in sch:
+                 outs.extend(_normalize_items(sch["items"], index=tok))
+
+         return outs
+
+     def _types(sch: Schema) -> set[str]:
+         t = sch.get("type")
+         return set(t) if isinstance(t, list) else {t} if t else set()
+
+     def _normalize_items(items: Any, index: int | None = None) -> list[Schema]:
+         if isinstance(items, list):
+             if index is None:  # wildcard/slice
+                 return items
+             if 0 <= index < len(items):
+                 return [items[index]]
+             return []
+         if isinstance(items, Mapping):
+             return [items]
+         return []
+
+     states = [schema]
+     for tok in iter_tokens(expr):
+         next_states: list[Schema] = []
+         for st in states:
+             next_states.extend(descend(st, tok))
+         if not next_states:  # nothing matched this segment
+             return False
+         states = next_states
+     return True  # every segment found at least one schema
+
+
+ def iter_tokens(node: JSONPath) -> Iterator[str | int | None]:
+     """Linearise a jsonpath-ng AST into a stream of tokens we care about."""
+     if hasattr(node, "left"):  # Child, Descendants, etc.
+         yield from iter_tokens(node.left)
+         yield from iter_tokens(node.right)
+     elif isinstance(node, Fields):
+         yield from node.fields  # e.g. ["foo"]
+     elif isinstance(node, Index):
+         yield node.index  # 0 / -1 / None for wildcard
+     elif isinstance(node, Slice):
+         yield None  # treat any slice as wildcard
+
+
+ COMBINATORS = ("anyOf", "oneOf", "allOf")
+
+
+ def _expand_union(sch: Schema) -> list[Schema]:
+     """Return sch itself or the list of subschemas if it is a combinator."""
+     for key in COMBINATORS:
+         if key in sch:
+             subs: list[Schema] = []
+             for sub in sch[key]:
+                 # a sub-schema might itself be an anyOf/oneOf/allOf
+                 subs.extend(_expand_union(sub))
+             return subs
+     return [sch]
+
+
+ UNSUPPORTED: tuple[type[JSONPath], ...] = (
+     Filter,  # [?foo > 0]
+     Where,  # .foo[(@.bar < 42)]
+     WhereNot,
+     Slice,  # [1:5] (wildcard “[*]” is Index/None, not Slice)
+ )
+
+
+ def find_unsupported(node: JSONPath) -> list[type[JSONPath]]:
+     """Return a list of node types present in `node` that we do not validate."""
+     bad: list[type[JSONPath]] = []
+     stack: list[JSONPath] = [node]
+     while stack:
+         n = stack.pop()
+         if isinstance(n, UNSUPPORTED):
+             bad.append(type(n))
+         # Drill into children (jsonpath-ng uses .left / .right / .child attributes)
+         for attr in ("left", "right", "child", "expression"):
+             stack.extend(
+                 [getattr(n, attr)]
+                 if hasattr(n, attr) and isinstance(getattr(n, attr), JSONPath)
+                 else []
+             )
+         # handle containers like Fields(fields=[...]) and Index(index=[...])
+         if hasattr(n, "__dict__"):
+             for v in n.__dict__.values():
+                 if isinstance(v, list):
+                     stack.extend(x for x in v if isinstance(x, JSONPath))
+     return bad
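Finally, a sketch of how the validation helpers above fit together; the pydantic models are invented for illustration, and the exact True/False results assume pydantic's default JSON schema output (no additionalProperties emitted for plain models).

    # illustration only; ScoreInfo/SampleInfo are hypothetical models
    from jsonpath_ng import parse
    from pydantic import BaseModel

    from inspect_ai.analysis.beta._dataframe.validate import (
        jsonpath_in_schema,
        resolved_schema,
    )

    class ScoreInfo(BaseModel):
        value: float
        answer: str | None = None

    class SampleInfo(BaseModel):
        id: str
        scores: dict[str, ScoreInfo]

    schema = resolved_schema(SampleInfo)
    print(jsonpath_in_schema(parse("scores.*.value"), schema))    # True: path exists in the schema
    print(jsonpath_in_schema(parse("scores.*.missing"), schema))  # False: ScoreInfo has no 'missing' field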