eval-toolkit 0.27.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,238 @@
1
+ """eval-toolkit — reusable evaluation contracts for binary classification.
2
+
3
+ Public API remains available from ``eval_toolkit`` and from submodules:
4
+
5
+ from eval_toolkit import pr_auc, bootstrap_ci, BootstrapCI
6
+ from eval_toolkit.metrics import pr_auc
7
+
8
+ The package root uses lazy exports so importing ``eval_toolkit`` does not
9
+ eagerly import optional-heavy modules such as plotting, loaders, or harnesses.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from importlib import import_module
15
+ from typing import Any
16
+
17
+ from eval_toolkit._version import __version__
18
+
19
# Lazily exported public names, grouped by the submodule that defines them.
# Group order and in-group order are significant: together they fix the
# iteration order of ``_EXPORTS`` and therefore the order of ``__all__``.
_EXPORTS_BY_MODULE: dict[str, tuple[str, ...]] = {
    "eval_toolkit.analysis": (
        "CsvPredictionReader",
        "JsonlPredictionReader",
        "PredictionArrays",
        "bootstrap_metric_from_predictions",
        "load_prediction_arrays",
        "paired_diff_from_prediction_refs",
    ),
    "eval_toolkit.artifacts": (
        "MetricState",
        "PredictionArtifactRef",
        "PredictionColumns",
        "error_metric",
        "sanitize_for_json",
        "skipped_metric",
        "validate_manifest",
        "validate_payload",
        "validate_prediction_artifact_ref",
        "validate_results",
        "write_json_strict",
    ),
    "eval_toolkit.bootstrap": (
        "DEFAULT_CONFIDENCE",
        "DEFAULT_METHOD",
        "DEFAULT_N_RESAMPLES",
        "DEFAULT_SEED",
        "BootstrapCI",
        "DeLongResult",
        "MDEEstimate",
        "MetricFn",
        "PairedBootstrapCI",
        "ThresholdedMetricFn",
        "ThresholdFn",
        "bootstrap_ci",
        "cross_validate_metric",
        "cv_clt_ci",
        "delong_roc_variance",
        "mde_from_ci",
        "paired_bootstrap_diff",
        "paired_bootstrap_ece_diff",
        "paired_bootstrap_op_point_diff",
        "paired_mde",
    ),
    "eval_toolkit.calibration": (
        "DEFAULT_FN_COST",
        "DEFAULT_FP_COST",
        "DEFAULT_N_BINS",
        "DEFAULT_PRIOR",
        "DEFAULT_STRATEGY",
        "CostMatrix",
        "bayes_optimal_threshold",
        "fit_beta_calibrator",
        "fit_isotonic_calibrator",
        "fit_platt_calibrator",
        "fit_temperature",
        "fit_temperature_oracle",
        "reliability_curve",
        "reliability_diagram_data",
    ),
    "eval_toolkit.claims": (
        "ClaimReport",
        "ClaimSpec",
        "EvidenceGate",
        "GateResult",
        "evaluate_claims",
        "external_diagnostic_gate",
        "headline_present_gate",
        "low_fpr_feasibility_gate",
        "metric_threshold_gate",
        "minimum_slice_size_gate",
        "no_leakage_errors_gate",
        "no_scorer_errors_gate",
        "paired_diff_present_gate",
        "required_metric_gate",
        "required_scorer_gate",
        "required_slice_gate",
        "source_role_gate",
        "strict_artifact_gate",
    ),
    "eval_toolkit.config": (
        "from_yaml",
        "frozen_config",
    ),
    "eval_toolkit.docs": (
        "ANCHOR_RE",
        "DEFAULT_FORMATTERS",
        "render_files",
        "render_text",
        "walk_path",
    ),
    "eval_toolkit.evidence": (
        "AggregateEvidence",
        "EvidenceAxis",
        "PairingMetadata",
        "RECOMMENDED_SOURCE_ROLES",
    ),
    "eval_toolkit.harness": (
        "DEFAULT_BOOTSTRAP_RESAMPLES",
        "RUN_RESULT_SCHEMA_VERSION",
        "EvalSlice",
        "RunResult",
        "evaluate",
        "evaluate_folded",
        "evaluate_scorer_on_slice",
        "with_claim_report",
        "write_run_result",
    ),
    "eval_toolkit.leakage": (
        "CrossSplitLeakageCheck",
        "ExactDuplicateCheck",
        "GroupLeakageCheck",
        "LabelConflictCheck",
        "LeakageCheck",
        "LeakageFinding",
        "LeakageReport",
        "NearDuplicateCheck",
        "NormalizedFormLeakageCheck",
        "TemporalLeakageCheck",
        "run_leakage_checks",
    ),
    "eval_toolkit.loaders": (
        "DataFrameLoader",
        "DatasetLoader",
        "HFDatasetsLoader",
        "ParquetGlobLoader",
        "SingleSliceLoader",
    ),
    "eval_toolkit.manifest": (
        "MANIFEST_SCHEMA_VERSION",
        "RunManifest",
        "SourceRoleRecord",
        "build_manifest",
        "validate_source_roles",
        "write_manifest",
    ),
    "eval_toolkit.metrics": (
        "DEFAULT_ASSUMED_PRIORS",
        "ThresholdResult",
        "brier_decomposition",
        "brier_score",
        "expected_calibration_error",
        "expected_calibration_error_debiased",
        "expected_calibration_error_equal_mass",
        "expected_calibration_error_l2",
        "expected_calibration_error_l2_debiased",
        "headline_metrics",
        "metrics_at_threshold",
        "pr_auc",
        "precision_at_prior",
        "quantile_stratified_pr_auc",
        "quantile_stratified_report",
        "roc_auc",
        "score_distribution_summary",
        "single_class_threshold_metrics",
        "stratified_recall",
    ),
    "eval_toolkit.operating_points": (
        "FittedOperatingPoint",
        "OperatingPointSpec",
        "apply_operating_points",
        "fit_operating_points",
    ),
    "eval_toolkit.paths": (
        "path_for_config",
        "resolve_repo_path",
        "split_provenance_config",
    ),
    "eval_toolkit.plotting": (
        "DEFAULT_FIGSIZE",
        "PALETTE",
        "PLOT_STYLE",
        "make_palette",
        "plot_bootstrap_distribution",
        "plot_confusion_matrix_grid",
        "plot_lift_ci",
        "plot_metric_bars",
        "plot_pr_curve",
        "plot_reliability_diagram",
        "plot_score_histograms",
        "save_figure",
        "set_plot_style",
    ),
    "eval_toolkit.provenance": (
        "FileHash",
        "FileHashMissing",
        "capture_git_sha",
        "compute_file_hash",
        "figure_metadata",
        "file_sha256",
        "make_run_dir",
    ),
    "eval_toolkit.protocols": (
        "EvalSliceLike",
        "PredictionReader",
        "Scorer",
        "SliceAwareScorer",
        "Versioned",
    ),
    "eval_toolkit.seeds": (
        "set_global_seeds",
    ),
    "eval_toolkit.splits": (
        "GroupKFoldSplitter",
        "HoldoutSplitter",
        "PoolBuilder",
        "SourceDisjointKFoldSplitter",
        "Splitter",
        "StratifiedKFoldSplitter",
        "TimeSeriesSplitter",
        "iter_folds_with_pool",
    ),
    "eval_toolkit.text_dedup": (
        "DEFAULT_DEDUP_THRESHOLD",
        "DedupReport",
        "EmbeddingCosineStrategy",
        "ExactNormalizedHashStrategy",
        "JaccardNgramStrategy",
        "MinHashLSHStrategy",
        "SimilarityAuditFinding",
        "SimilarityAuditReport",
        "SimilarityStrategy",
        "TfidfCosineStrategy",
        "audit_source_label_similarity",
        "cross_dedup",
        "near_dedup",
        "normalize_text_for_dedup",
        "sha256_text",
    ),
    "eval_toolkit.thresholds": (
        "CISafeThresholdSelector",
        "CostSensitiveSelector",
        "MaxF1Selector",
        "TargetFPRSelector",
        "TargetPrecisionSelector",
        "TargetRecallSelector",
        "ThresholdPolicyMetadata",
        "ThresholdSelector",
        "WilsonInterval",
        "YoudenJSelector",
        "select_threshold",
        "wilson_interval",
    ),
}

# Flat lookup used by the lazy ``__getattr__`` hook:
# public name -> fully qualified submodule that defines it.
_EXPORTS: dict[str, str] = {
    public_name: module_name
    for module_name, public_names in _EXPORTS_BY_MODULE.items()
    for public_name in public_names
}

__all__ = ["__version__", *_EXPORTS]
221
+
222
+
223
def __getattr__(name: str) -> Any:
    """Resolve a lazily exported public symbol (PEP 562 module hook).

    Python calls this only when ``name`` is not already a module global. The
    submodule listed for ``name`` in ``_EXPORTS`` is imported on demand. The
    resolved object is then stored in this module's globals, so later
    accesses skip the hook.

    Raises
    ------
    AttributeError
        If ``name`` is not a lazily exported symbol.
    """
    if name == "__version__":
        return __version__
    if (owner := _EXPORTS.get(name)) is None:
        raise AttributeError(f"module 'eval_toolkit' has no attribute {name!r}")
    resolved = getattr(import_module(owner), name)
    # Cache on the package so the import/getattr cost is paid once per name.
    globals()[name] = resolved
    return resolved
234
+
235
+
236
def __dir__() -> list[str]:
    """Return ``__all__`` (eager and lazy public names) sorted, for ``dir()``."""
    public_names = list(__all__)
    public_names.sort()
    return public_names
@@ -0,0 +1,156 @@
1
+ """eval_toolkit CLI entry: ``python -m eval_toolkit ...``.
2
+
3
+ Subcommands:
4
+ schemas list List bundled JSON schema names.
5
+ schemas show <name> Pretty-print a single schema.
6
+ schemas check Meta-validate every bundled schema against
7
+ Draft 2020-12. Used by CI to assert
8
+ schema integrity as a package invariant.
9
+ validate <file> <schema-name> Validate a JSON file against a bundled schema
10
+ (requires [validation] extra).
11
+
12
+ Exit codes:
13
+ 0 — ok
14
+ 1 — validation failed (payload or schema meta-validation)
15
+ 2 — bad arg (file or schema not found; empty schemas directory)
16
+ 3 — missing optional dependency (jsonschema for validate)
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import argparse
22
+ import json
23
+ import sys
24
+ from pathlib import Path
25
+
26
+ import eval_toolkit
27
+
28
+
29
def _schemas_dir() -> Path:
    """Return the ``schemas`` directory that sits next to the package ``__init__``."""
    package_init = Path(eval_toolkit.__file__)
    # Swap the ``__init__.py`` component for ``schemas`` (same as parent / "schemas").
    return package_init.with_name("schemas")
31
+
32
+
33
def _cmd_schemas_list(_args: argparse.Namespace) -> int:
    """Print the name (file stem) of each bundled ``*.json`` schema.

    Files are ordered by path. Nothing is printed when the directory holds
    no schemas. Always returns exit code 0.
    """
    schema_paths = sorted(_schemas_dir().glob("*.json"))
    for schema_name in (path.stem for path in schema_paths):
        print(schema_name)
    return 0
37
+
38
+
39
def _cmd_schemas_show(args: argparse.Namespace) -> int:
    """Pretty-print one bundled schema to stdout.

    ``args.name`` may be the bare schema name (``results.v1``) or the full
    filename (``results.v1.json``). The schema is re-serialized with 2-space
    indentation and sorted keys.

    Returns 0 on success, or 2 when no matching bundled schema file exists.
    """
    name: str = args.name
    schemas_dir = _schemas_dir()
    candidate = schemas_dir / f"{name}.json"
    if not candidate.is_file():
        # Tolerate users typing the full filename (e.g., results.v1.json).
        candidate_with_ext = schemas_dir / name
        if candidate_with_ext.suffix == ".json" and candidate_with_ext.is_file():
            candidate = candidate_with_ext
        else:
            print(f"unknown schema: {name}", file=sys.stderr)
            return 2
    # JSON text is UTF-8 (RFC 8259). Reading with the platform locale would
    # garble non-ASCII schema text (e.g. on Windows code pages).
    schema = json.loads(candidate.read_text(encoding="utf-8"))
    print(json.dumps(schema, indent=2, sort_keys=True))
    return 0
52
+
53
+
54
def _cmd_schemas_check(_args: argparse.Namespace) -> int:
    """Meta-validate every bundled schema against Draft 2020-12.

    Prints ``<file>: OK`` for each passing schema. All failures are collected
    and reported on stderr at the end, so one bad schema does not hide the
    others.

    Returns
    -------
    int
        0 when every schema passes.
        1 when any file is not valid UTF-8 JSON or fails meta-validation.
        2 when the schemas directory holds no ``*.json`` files.
        3 when ``jsonschema`` cannot be imported.
    """
    try:
        from jsonschema import (  # type: ignore[import-untyped] # noqa: PLC0415
            Draft202012Validator,
        )
        from jsonschema.exceptions import (  # type: ignore[import-untyped] # noqa: PLC0415
            SchemaError,
        )
    except ImportError:
        # Match ``validate``: a missing dependency is exit 3, not a traceback.
        print("schemas check requires the 'jsonschema' package", file=sys.stderr)
        return 3

    schemas_dir = _schemas_dir()
    files = sorted(schemas_dir.glob("*.json"))
    if not files:
        print(f"ERROR: no schemas found in {schemas_dir}", file=sys.stderr)
        return 2
    failures: list[str] = []
    for f in files:
        try:
            # JSON is UTF-8 (RFC 8259). JSONDecodeError and UnicodeDecodeError
            # are both ValueError subclasses, so either is reported as a failure.
            schema = json.loads(f.read_text(encoding="utf-8"))
            Draft202012Validator.check_schema(schema)
        except (ValueError, SchemaError) as exc:
            failures.append(f"{f.name}: {exc}")
        else:
            print(f" {f.name}: OK")
    if failures:
        print("\nSchema validation failures:", file=sys.stderr)
        for line in failures:
            print(f" {line}", file=sys.stderr)
        return 1
    return 0
87
+
88
+
89
def _cmd_validate(args: argparse.Namespace) -> int:
    """Validate the JSON payload ``args.file`` against bundled schema ``args.schema``.

    ``args.schema`` may be the bare schema name or the full ``.json`` filename.
    Only the first validation error is reported.

    Returns
    -------
    int
        0 when the payload is valid.
        1 when the payload is not UTF-8 JSON or violates the schema.
        2 for an unknown schema or a missing payload file.
        3 when ``jsonschema`` is not installed.
    """
    try:
        import jsonschema  # noqa: PLC0415
    except ImportError:
        print(
            "validate requires the [validation] extra: pip install 'eval-toolkit[validation]'",
            file=sys.stderr,
        )
        return 3
    schema_name: str = args.schema
    schemas_dir = _schemas_dir()
    schema_path = schemas_dir / f"{schema_name}.json"
    if not schema_path.is_file():
        # Tolerate users typing the full filename.
        alt = schemas_dir / schema_name
        if alt.suffix == ".json" and alt.is_file():
            schema_path = alt
        else:
            print(f"unknown schema: {schema_name}", file=sys.stderr)
            return 2
    file_path = Path(args.file)
    if not file_path.is_file():
        print(f"file not found: {args.file}", file=sys.stderr)
        return 2
    # JSON is UTF-8 (RFC 8259); never depend on the platform locale.
    schema = json.loads(schema_path.read_text(encoding="utf-8"))
    try:
        payload = json.loads(file_path.read_text(encoding="utf-8"))
    except ValueError as exc:
        # JSONDecodeError / UnicodeDecodeError: a payload that is not JSON
        # fails validation (exit 1) instead of escaping as a traceback.
        print(f"INVALID JSON in {args.file}: {exc}", file=sys.stderr)
        return 1
    try:
        jsonschema.validate(payload, schema)
    except jsonschema.ValidationError as exc:
        loc = "/".join(str(p) for p in exc.absolute_path) or "(root)"
        print(f"VALIDATION ERROR at {loc}: {exc.message}", file=sys.stderr)
        return 1
    print(f"{args.file}: OK against {schema_name}")
    return 0
124
+
125
+
126
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv``, run the chosen subcommand, and return its exit code.

    Serves both ``python -m eval_toolkit`` and the ``eval-toolkit`` script.
    When ``argv`` is None, argparse falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        prog="eval-toolkit",
        description="eval-toolkit CLI: schema discovery + payload validation",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # ``schemas`` group: list / show / check.
    schemas_parser = commands.add_parser("schemas", help="bundled JSON schema discovery")
    schema_actions = schemas_parser.add_subparsers(dest="schemas_cmd", required=True)
    list_parser = schema_actions.add_parser("list", help="list bundled schema names")
    list_parser.set_defaults(func=_cmd_schemas_list)
    show_parser = schema_actions.add_parser("show", help="pretty-print a single schema")
    show_parser.add_argument("name", help="schema name (e.g., 'results.v1')")
    show_parser.set_defaults(func=_cmd_schemas_show)
    check_parser = schema_actions.add_parser(
        "check", help="meta-validate every bundled schema against Draft 2020-12"
    )
    check_parser.set_defaults(func=_cmd_schemas_check)

    # ``validate <file> <schema>``.
    validate_parser = commands.add_parser(
        "validate", help="validate a JSON file against a bundled schema"
    )
    validate_parser.add_argument("file", help="path to JSON file to validate")
    validate_parser.add_argument("schema", help="schema name (e.g., 'results.v1')")
    validate_parser.set_defaults(func=_cmd_validate)

    parsed = parser.parse_args(argv)
    handler = parsed.func
    return int(handler(parsed))
153
+
154
+
155
if __name__ == "__main__":  # pragma: no cover
    # Use the subcommand's return value as the process exit status.
    raise SystemExit(main())
@@ -0,0 +1,5 @@
1
+ """Single lightweight version source."""
2
+
3
# Only the version string is exported from this module.
__all__ = ["__version__"]

# NOTE(review): build/release tooling may read this assignment textually
# (e.g. a version regex configured in pyproject.toml); keep it a plain
# ``__version__ = "..."`` string literal — confirm before changing its form.
__version__ = "0.27.1"
@@ -0,0 +1,196 @@
1
+ """Post-run analysis over retained prediction artifacts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import json
7
+ from collections.abc import Mapping, Sequence
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+
14
+ from eval_toolkit.bootstrap import bootstrap_ci, paired_bootstrap_diff
15
+ from eval_toolkit.metrics import pr_auc
16
+ from eval_toolkit.protocols import PredictionReader
17
+
18
# Public API of this module. The package root (``eval_toolkit``) also
# re-exports each of these names lazily from ``eval_toolkit.analysis``.
__all__ = [
    "CsvPredictionReader",
    "JsonlPredictionReader",
    "PredictionArrays",
    "bootstrap_metric_from_predictions",
    "load_prediction_arrays",
    "paired_diff_from_prediction_refs",
]
26
+
27
+
28
+ @dataclass(frozen=True, slots=True)
29
+ class PredictionArrays:
30
+ """Numeric arrays loaded from a prediction artifact."""
31
+
32
+ labels: np.ndarray
33
+ scores: np.ndarray
34
+ row_ids: tuple[str, ...] = ()
35
+ content_hashes: tuple[str, ...] = ()
36
+
37
+ def __post_init__(self) -> None:
38
+ """Validate array shape."""
39
+ if self.labels.shape != self.scores.shape:
40
+ raise ValueError("labels and scores must have identical shape")
41
+
42
+
43
class CsvPredictionReader:
    """Read CSV prediction files into a column-oriented mapping.

    Values are returned as the raw strings produced by :class:`csv.DictReader`.
    """

    def read_predictions(
        self,
        uri: str,
        *,
        columns: Mapping[str, str],
    ) -> Mapping[str, Sequence[object]]:
        """Read a local CSV file that has a header row.

        Parameters
        ----------
        uri
            Local filesystem path to the CSV file.
        columns
            Mapping of logical role (e.g. ``"label"``) to CSV header name.
            Only the header names (the mapping's values) are used here.

        Returns
        -------
        Mapping[str, Sequence[object]]
            One list per requested header name, in file row order. A header
            that is absent from the file yields ``""`` for every row.
        """
        wanted = set(columns.values())
        out: dict[str, list[object]] = {col: [] for col in wanted}
        # ``utf-8-sig`` decodes plain UTF-8 and also strips a leading BOM (as
        # written by e.g. Excel). Otherwise the BOM would corrupt the first
        # header name and that column would silently read as "".
        # ``newline=""`` is what the csv module requires.
        with Path(uri).open(newline="", encoding="utf-8-sig") as fh:
            for row in csv.DictReader(fh):
                for col in wanted:
                    out[col].append(row.get(col, ""))
        return out
61
+
62
+
63
class JsonlPredictionReader:
    """Read JSON Lines prediction files into a column-oriented mapping."""

    def read_predictions(
        self,
        uri: str,
        *,
        columns: Mapping[str, str],
    ) -> Mapping[str, Sequence[object]]:
        """Read a local JSONL file (one JSON object per line).

        Parameters
        ----------
        uri
            Local filesystem path to the JSONL file.
        columns
            Mapping of logical role to key name. Only the key names (the
            mapping's values) are used here.

        Returns
        -------
        Mapping[str, Sequence[object]]
            One list per requested key, in file order. Blank lines are
            skipped, and a key missing from a row yields ``None``.

        Raises
        ------
        json.JSONDecodeError
            If a non-blank line is not valid JSON.
        ValueError
            If a line decodes to something other than a JSON object. The
            message carries the 1-based line number.
        """
        wanted = set(columns.values())
        out: dict[str, list[object]] = {col: [] for col in wanted}
        # The JSON Lines format mandates UTF-8; don't rely on the locale.
        with Path(uri).open(encoding="utf-8") as fh:
            for lineno, line in enumerate(fh, start=1):
                if not line.strip():
                    continue
                row = json.loads(line)
                if not isinstance(row, dict):
                    # Without this check, ``row.get`` would fail with a
                    # confusing AttributeError on arrays or scalars.
                    raise ValueError(
                        f"{uri}:{lineno}: expected a JSON object, got {type(row).__name__}"
                    )
                for col in wanted:
                    out[col].append(row.get(col))
        return out
83
+
84
+
85
def load_prediction_arrays(
    ref: Mapping[str, Any],
    *,
    reader: PredictionReader | None = None,
) -> PredictionArrays:
    """Load labels and scores from a prediction artifact reference.

    Parameters
    ----------
    ref
        Artifact reference with a non-empty ``uri`` and a ``columns`` mapping.
        The mapping must name the ``label`` and ``score`` columns and may
        also name ``row_id`` / ``content_hash`` columns.
    reader
        Reader to use. When ``None``, one is chosen from the ref via
        :func:`_reader_for_ref`.

    Raises
    ------
    ValueError
        If any of the following holds:

        - ``ref`` lacks a ``columns`` mapping or a non-empty ``uri``;
        - the ``columns`` mapping is missing the ``label`` / ``score`` keys
          (re-raised from :func:`_required_column`);
        - no built-in reader fits the ref;
        - the reader output lacks the label/score column, or holds values
          that cannot be converted to int labels / float scores.
    """
    columns = ref.get("columns")
    if not isinstance(columns, Mapping):
        raise ValueError("prediction ref must include a columns mapping")
    label_col = _required_column(columns, "label")
    score_col = _required_column(columns, "score")
    uri = ref.get("uri")
    if not isinstance(uri, str) or not uri:
        raise ValueError("prediction ref must include a non-empty uri")
    # ``is None`` rather than truthiness, so a caller-supplied reader is never
    # replaced just because the object happens to be falsy.
    selected_reader = reader if reader is not None else _reader_for_ref(ref)
    reader_columns = {str(k): str(v) for k, v in columns.items() if isinstance(v, str)}
    table = selected_reader.read_predictions(uri, columns=reader_columns)
    labels = _column_as_array(table, label_col, int, uri)
    scores = _column_as_array(table, score_col, float, uri)
    row_id_col = columns.get("row_id")
    hash_col = columns.get("content_hash")
    # Optional metadata: absent columns give empty tuples ("not available").
    row_ids = tuple(str(v) for v in table.get(str(row_id_col), ())) if row_id_col else ()
    hashes = tuple(str(v) for v in table.get(str(hash_col), ())) if hash_col else ()
    return PredictionArrays(labels=labels, scores=scores, row_ids=row_ids, content_hashes=hashes)


def _column_as_array(
    table: Mapping[str, Sequence[object]],
    column: str,
    dtype: type,
    uri: str,
) -> np.ndarray:
    """Convert one reader column to an array, reporting every failure as ValueError."""
    try:
        values = table[column]
    except KeyError:
        raise ValueError(f"prediction reader returned no {column!r} column for {uri}") from None
    try:
        return np.asarray(values, dtype=dtype)
    except (TypeError, ValueError) as exc:
        # e.g. None from a JSONL row missing the key (TypeError), or "" from a
        # CSV file missing the header (ValueError).
        raise ValueError(
            f"column {column!r} in {uri} cannot be converted to {dtype.__name__}: {exc}"
        ) from exc
117
+
118
+
119
def bootstrap_metric_from_predictions(
    ref: Mapping[str, Any],
    *,
    reader: PredictionReader | None = None,
    n_resamples: int = 1000,
    seed: int = 42,
) -> dict[str, object]:
    """Bootstrap a PR-AUC confidence interval from a single prediction ref.

    The ref is loaded with :func:`load_prediction_arrays`, so this raises the
    same ``ValueError`` cases. The loaded arrays go to :func:`bootstrap_ci`
    with :func:`pr_auc` as the metric. The result is returned as its
    ``to_dict()`` serialization.
    """
    loaded = load_prediction_arrays(ref, reader=reader)
    interval = bootstrap_ci(
        loaded.labels,
        loaded.scores,
        pr_auc,
        n_resamples=n_resamples,
        seed=seed,
    )
    return interval.to_dict()
135
+
136
+
137
def paired_diff_from_prediction_refs(
    baseline_ref: Mapping[str, Any],
    candidate_ref: Mapping[str, Any],
    *,
    baseline_reader: PredictionReader | None = None,
    candidate_reader: PredictionReader | None = None,
    n_resamples: int = 1000,
    seed: int = 42,
) -> dict[str, object]:
    """Compute a paired PR-AUC delta (candidate vs. baseline) from two refs.

    Both refs must describe the same rows in the same order. Labels are always
    compared. ``row_ids`` and ``content_hashes`` are compared only when both
    sides provide them.

    Raises
    ------
    ValueError
        If the two refs disagree on row count, label values, ``row_ids``,
        or ``content_hashes``; or if either ref is malformed (re-raised
        from :func:`load_prediction_arrays`).
    """
    baseline = load_prediction_arrays(baseline_ref, reader=baseline_reader)
    candidate = load_prediction_arrays(candidate_ref, reader=candidate_reader)

    if baseline.labels.shape != candidate.labels.shape:
        raise ValueError("prediction refs must have the same number of rows")
    if not np.array_equal(baseline.labels, candidate.labels):
        raise ValueError("prediction refs must have identical labels for paired comparison")
    for attr in ("row_ids", "content_hashes"):
        base_meta = getattr(baseline, attr)
        cand_meta = getattr(candidate, attr)
        # Empty metadata on either side means "unknown": skip the check.
        if base_meta and cand_meta and base_meta != cand_meta:
            raise ValueError(f"prediction refs must have identical {attr} for paired comparison")

    diff = paired_bootstrap_diff(
        baseline.labels,
        baseline.scores,
        candidate.scores,
        pr_auc,
        n_resamples=n_resamples,
        seed=seed,
    )
    return diff.to_dict()
177
+
178
+
179
+ def _reader_for_ref(ref: Mapping[str, Any]) -> PredictionReader:
180
+ media_type = str(ref.get("media_type", ""))
181
+ uri = str(ref.get("uri", ""))
182
+ if media_type in {"text/csv", "application/csv"} or uri.endswith(".csv"):
183
+ return CsvPredictionReader()
184
+ if media_type in {"application/jsonl", "application/x-ndjson"} or uri.endswith(".jsonl"):
185
+ return JsonlPredictionReader()
186
+ raise ValueError(
187
+ "no built-in prediction reader for this artifact; pass a PredictionReader "
188
+ "or use CSV/JSONL"
189
+ )
190
+
191
+
192
+ def _required_column(columns: Mapping[str, object], key: str) -> str:
193
+ value = columns.get(key)
194
+ if not isinstance(value, str) or not value:
195
+ raise ValueError(f"prediction columns must include {key!r}")
196
+ return value