eval-toolkit 0.27.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1052 @@
1
+ """Leakage detection: pluggable :class:`LeakageCheck` Protocol + reference impls.
2
+
3
+ A :class:`LeakageCheck` validates a dict of :class:`~eval_toolkit.harness.EvalSlice`
4
+ (the shape produced by :class:`~eval_toolkit.loaders.DatasetLoader.load_splits`)
5
+ and returns a :class:`LeakageFinding` describing rows / pairs to drop and a
6
+ severity. Multiple checks compose into a :class:`LeakageReport`.
7
+
8
+ Within-split checks (:class:`ExactDuplicateCheck`, :class:`NearDuplicateCheck`,
9
+ :class:`NormalizedFormLeakageCheck`, :class:`LabelConflictCheck`) read one
10
+ split at a time. Cross-split checks (:class:`CrossSplitLeakageCheck`,
11
+ :class:`GroupLeakageCheck`, :class:`TemporalLeakageCheck`) read multiple
12
+ splits at once. The contract is uniform — every check takes a
13
+ ``Mapping[str, EvalSlice]``.
14
+
15
+ Two integration paths:
16
+
17
+ - Standalone :func:`run_leakage_checks` for offline / CI use.
18
+ - Pass ``leakage_checks=[...]`` directly to
19
+ :func:`eval_toolkit.harness.evaluate` for inline enforcement
20
+ (``on_leakage="raise"`` by default).
21
+
22
+ Either way the report is captured in :class:`~eval_toolkit.manifest.RunManifest`
23
+ so even non-failing runs are auditable.
24
+
25
+ References
26
+ ----------
27
+ .. [1] Kapoor, S. & Narayanan, A. "Leakage and the reproducibility crisis in
28
+ machine-learning-based science." Patterns 4(9), 2023.
29
+ https://arxiv.org/abs/2207.07048
30
+ .. [2] PI_HackAPrompt_SQuAD (2025): encoding-obfuscated dupes detect at 21.3 %
31
+ under naive dedup but achieve 76.2 % attack success rate, motivating
32
+ the :class:`NormalizedFormLeakageCheck` reference impl.
33
+ https://arxiv.org/html/2505.04806v1
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ import unicodedata
39
+ from collections import defaultdict
40
+ from collections.abc import Mapping, Sequence
41
+ from dataclasses import dataclass, field
42
+ from typing import Literal, Protocol, runtime_checkable
43
+
44
+ import numpy as np
45
+ import pandas as pd
46
+
47
+ from eval_toolkit.harness import EvalSlice
48
+ from eval_toolkit.text_dedup import (
49
+ DEFAULT_DEDUP_THRESHOLD,
50
+ ExactNormalizedHashStrategy,
51
+ SimilarityStrategy,
52
+ TfidfCosineStrategy,
53
+ cross_dedup,
54
+ cross_dedup_pairs,
55
+ near_dedup,
56
+ sha256_text,
57
+ )
58
+
59
# Public API of this module. Entries are kept in ASCII-sorted order
# (capitalized names first, then lowercase).
__all__ = [
    "CrossSplitLeakageCheck",
    "ExactDuplicateCheck",
    "GroupLeakageCheck",
    "LabelConflictCheck",
    "LeakageCheck",
    "LeakageFinding",
    "LeakageReport",
    "NearDuplicateCheck",
    "NormalizedFormLeakageCheck",
    "Severity",
    "TemporalLeakageCheck",
    "Versioned",
    "run_leakage_checks",
]

# Closed set of severity levels a finding may carry. Only "error" findings
# with ``n_affected > 0`` make ``LeakageReport.has_errors()`` return True.
Severity = Literal["error", "warning", "info"]
76
+
77
+
78
@runtime_checkable
class Versioned(Protocol):
    """Structural type for any object that exposes ``version: str``.

    :class:`~eval_toolkit.manifest.RunManifest` uses it to record the
    per-object version of Tier-2 implementations (Scorer, LeakageCheck,
    Splitter, ThresholdSelector, DatasetLoader). The idea follows the
    ``VERSION`` field of lm-evaluation-harness, where a version bump
    invalidates metric comparisons across versions.

    Setting ``version`` is opt-in; implementations are free to omit it.
    """

    @property
    def version(self) -> str:  # pragma: no cover
        """Return the stable version string of this implementation."""
        ...
93
+
94
+
95
+ # ---------------------------------------------------------------------------
96
+ # Report types
97
+ # ---------------------------------------------------------------------------
98
+
99
+
100
+ @dataclass(frozen=True, slots=True)
101
+ class LeakageFinding:
102
+ """A single :class:`LeakageCheck`'s finding.
103
+
104
+ Mirrors the audit-trail discipline of
105
+ :class:`eval_toolkit.text_dedup.DedupReport` but is composable across
106
+ multiple checks via :class:`LeakageReport`.
107
+
108
+ Parameters
109
+ ----------
110
+ check_name : str
111
+ Stable identifier for the check (e.g. ``"ExactDuplicateCheck"``).
112
+ severity : {"error", "warning", "info"}
113
+ Gating semantics for ``on_leakage="raise"``: only ``"error"``
114
+ findings raise; warnings and info are recorded but don't fail.
115
+ drop_indices : dict[str, list[int]] or None
116
+ Per-split row indices the check recommends dropping. ``None``
117
+ (v0.18.0) signals a pair-count finding that has not localized
118
+ rows to drop — useful for emitters that only know
119
+ ``n_affected`` (e.g. V4's SHA256 disjointness audit and other
120
+ pair-tally audits). Distinguishes "no rows to drop because the
121
+ check found nothing" (``{}``) from "this check doesn't track
122
+ rows" (``None``).
123
+ evidence : dict[str, object]
124
+ Check-specific structured evidence (e.g. dropped-pairs lists,
125
+ violating group ids, time-boundary intervals).
126
+ message : str
127
+ Human-readable summary; appears in error messages and logs.
128
+ n_affected : int
129
+ Total rows affected (sum across splits) — quick-scan field.
130
+ """
131
+
132
+ check_name: str
133
+ severity: Severity
134
+ drop_indices: dict[str, list[int]] | None
135
+ evidence: dict[str, object]
136
+ message: str
137
+ n_affected: int
138
+
139
+ def to_dict(self) -> dict[str, object]:
140
+ """JSON-serializable representation for the manifest."""
141
+ drop: object = dict(self.drop_indices) if self.drop_indices is not None else None
142
+ return {
143
+ "check_name": self.check_name,
144
+ "severity": self.severity,
145
+ "drop_indices": drop,
146
+ "evidence": dict(self.evidence),
147
+ "message": self.message,
148
+ "n_affected": self.n_affected,
149
+ }
150
+
151
+
152
+ @dataclass(frozen=True, slots=True)
153
+ class LeakageReport:
154
+ """Aggregate of one or more :class:`LeakageCheck` results.
155
+
156
+ Parameters
157
+ ----------
158
+ findings : list[LeakageFinding]
159
+ One entry per check that ran. Empty findings list = clean run.
160
+ """
161
+
162
+ findings: list[LeakageFinding] = field(default_factory=list)
163
+
164
+ def has_errors(self) -> bool:
165
+ """True iff any finding is ``"error"`` severity AND actually flagged rows.
166
+
167
+ A clean run produces error-severity findings with ``n_affected=0``
168
+ ("the check ran and found nothing") — those don't gate the run.
169
+ """
170
+ return any(f.severity == "error" and f.n_affected > 0 for f in self.findings)
171
+
172
+ def errors(self) -> list[LeakageFinding]:
173
+ """Subset of findings with severity ``"error"`` AND ``n_affected > 0``."""
174
+ return [f for f in self.findings if f.severity == "error" and f.n_affected > 0]
175
+
176
+ def warnings(self) -> list[LeakageFinding]:
177
+ """Subset of findings with severity ``"warning"`` AND ``n_affected > 0``."""
178
+ return [f for f in self.findings if f.severity == "warning" and f.n_affected > 0]
179
+
180
+ def merged_drop_indices(self) -> dict[str, list[int]]:
181
+ """Union of all findings' ``drop_indices``, sorted, deduped per split.
182
+
183
+ Skips findings whose ``drop_indices`` is ``None`` (pair-tally
184
+ findings that don't localize rows). See v0.18.0 changelog.
185
+ """
186
+ merged: dict[str, set[int]] = defaultdict(set)
187
+ for f in self.findings:
188
+ if f.drop_indices is None:
189
+ continue
190
+ for split_name, idxs in f.drop_indices.items():
191
+ merged[split_name].update(idxs)
192
+ return {name: sorted(idxs) for name, idxs in merged.items()}
193
+
194
+ def to_dict(self) -> dict[str, object]:
195
+ """JSON-serializable representation for the manifest."""
196
+ return {"findings": [f.to_dict() for f in self.findings]}
197
+
198
+
199
+ # ---------------------------------------------------------------------------
200
+ # Protocol
201
+ # ---------------------------------------------------------------------------
202
+
203
+
204
@runtime_checkable
class LeakageCheck(Protocol):
    """Pluggable validator operating on a dict of splits.

    Every check receives the same input shape: a
    ``Mapping[str, EvalSlice]`` as returned by
    :class:`~eval_toolkit.loaders.DatasetLoader.load_splits`. Within-split
    checks ignore keys they do not target, while cross-split checks read
    several keys. Implementations are expected to be pure: no IO and no
    mutation.

    Attributes
    ----------
    name : str
        Stable identifier of the check. It is written into the resulting
        :class:`LeakageFinding` and used for downstream filtering.
    """

    name: str

    def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:  # pragma: no cover
        """Validate ``splits`` and return a single :class:`LeakageFinding`."""
        ...
225
+
226
+
227
+ # ---------------------------------------------------------------------------
228
+ # Helpers
229
+ # ---------------------------------------------------------------------------
230
+
231
+
232
+ def _aggressive_normalize(text: str) -> str:
233
+ """NFKC + zero-width / control / Symbol-Other strip + lowercase.
234
+
235
+ Used by :class:`NormalizedFormLeakageCheck` to detect encoding-obfuscated
236
+ duplicates that naive lowercase + whitespace dedup misses (the dominant
237
+ unfixed leakage class in prompt-injection corpora — see
238
+ [PI_HackAPrompt_SQuAD 2025]).
239
+ """
240
+ nfkc = unicodedata.normalize("NFKC", text)
241
+ # Drop control chars (Cc, Cf — zero-width formatters), surrogate (Cs),
242
+ # private-use (Co), and Symbol-Other (So — most emoji + decorative pads).
243
+ drop_categories = {"Cc", "Cf", "Cs", "Co", "So"}
244
+ cleaned = "".join(c for c in nfkc if unicodedata.category(c) not in drop_categories)
245
+ # Collapse whitespace + lowercase
246
+ return " ".join(cleaned.split()).lower()
247
+
248
+
249
+ def _select_targets(
250
+ splits: Mapping[str, EvalSlice], target_splits: Sequence[str] | None
251
+ ) -> list[str]:
252
+ """Return the split names to scan; default = all keys in insertion order."""
253
+ if target_splits is None:
254
+ return list(splits.keys())
255
+ missing = [name for name in target_splits if name not in splits]
256
+ if missing:
257
+ raise KeyError(f"target_splits not in splits: {missing}")
258
+ return list(target_splits)
259
+
260
+
261
+ # ---------------------------------------------------------------------------
262
+ # Within-split checks
263
+ # ---------------------------------------------------------------------------
264
+
265
+
266
+ @dataclass(frozen=True, slots=True)
267
+ class ExactDuplicateCheck:
268
+ """Within-split exact-duplicate detection.
269
+
270
+ Wraps :class:`eval_toolkit.text_dedup.ExactNormalizedHashStrategy`
271
+ (whitespace-normalized SHA-256 buckets). Severity defaults to
272
+ ``"warning"`` because exact dupes are common in real corpora; opt into
273
+ ``"error"`` for strict-mode CI gates.
274
+
275
+ Parameters
276
+ ----------
277
+ target_splits : sequence of str or None, optional
278
+ Splits to scan. ``None`` = all splits.
279
+ severity : {"error", "warning", "info"}, optional
280
+ Gating severity. Default ``"warning"``.
281
+ """
282
+
283
+ target_splits: Sequence[str] | None = None
284
+ severity: Severity = "warning"
285
+
286
+ @property
287
+ def name(self) -> str:
288
+ """Stable check identifier."""
289
+ return "ExactDuplicateCheck"
290
+
291
+ def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
292
+ """Drop exact-normalized duplicates within each target split."""
293
+ targets = _select_targets(splits, self.target_splits)
294
+ drop: dict[str, list[int]] = {}
295
+ evidence_pairs: dict[str, list[tuple[int, int, float]]] = {}
296
+ n_affected = 0
297
+ for split_name in targets:
298
+ slice_ = splits[split_name]
299
+ texts = slice_.features
300
+ if len(texts) <= 1:
301
+ continue
302
+ report = near_dedup(
303
+ texts,
304
+ threshold=0.5, # any positive threshold; hash strategy returns 0.0/1.0
305
+ strategy=ExactNormalizedHashStrategy(),
306
+ )
307
+ dropped = [pair[0] for pair in report.dropped_pairs]
308
+ if dropped:
309
+ drop[split_name] = sorted(set(dropped))
310
+ evidence_pairs[split_name] = list(report.dropped_pairs)
311
+ n_affected += len(drop[split_name])
312
+ return LeakageFinding(
313
+ check_name=self.name,
314
+ severity=self.severity,
315
+ drop_indices=drop,
316
+ evidence={"dropped_pairs_by_split": evidence_pairs},
317
+ message=(
318
+ f"exact-duplicate dedup affected {n_affected} rows across " f"{len(drop)} split(s)"
319
+ if n_affected
320
+ else "no exact duplicates found"
321
+ ),
322
+ n_affected=n_affected,
323
+ )
324
+
325
+
326
@dataclass(frozen=True, slots=True)
class NearDuplicateCheck:
    """Detect near-duplicates inside each split with a pluggable strategy.

    The default backend is :class:`eval_toolkit.text_dedup.TfidfCosineStrategy`
    (lexical n-gram TF-IDF). Supply an :class:`EmbeddingCosineStrategy` for
    semantic dedup or a :class:`MinHashLSHStrategy` for cheap approximate
    dedup.

    Parameters
    ----------
    threshold : float, optional
        Similarity threshold in (0, 1). Default 0.9.
    strategy : SimilarityStrategy or None, optional
        Backend. ``None`` instantiates :class:`TfidfCosineStrategy`.
    target_splits : sequence of str or None, optional
        Splits to scan. ``None`` = all.
    severity : {"error", "warning", "info"}, optional
        Default ``"warning"``.
    label_aware : bool, optional
        When ``True``, :meth:`validate_label_split` emits two findings per
        check — one for same-label pairs and one for cross-label pairs —
        each with its own severity (``severity_same_label`` /
        ``severity_cross_label``). Cross-label near-duplicates within a
        split are a label-noise signal (the same text carries conflicting
        supervision) and usually deserve a stricter severity. Default
        ``False`` preserves the single-finding contract.
    severity_same_label : {"error", "warning", "info"}, optional
        Severity for the same-label finding when ``label_aware=True``.
        Default ``"warning"``.
    severity_cross_label : {"error", "warning", "info"}, optional
        Severity for the cross-label finding when ``label_aware=True``.
        Default ``"error"``.
    """

    threshold: float = DEFAULT_DEDUP_THRESHOLD
    strategy: SimilarityStrategy | None = None
    target_splits: Sequence[str] | None = None
    severity: Severity = "warning"
    label_aware: bool = False
    severity_same_label: Severity = "warning"
    severity_cross_label: Severity = "error"

    @property
    def name(self) -> str:
        """Stable check identifier."""
        return "NearDuplicateCheck"

    def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
        """Flag near-duplicates within each target split.

        Splits with at most one row are skipped. The first element of every
        dropped pair becomes a row to drop; the pair list is kept as
        evidence together with the threshold and the strategy class name.
        """
        split_names = _select_targets(splits, self.target_splits)
        backend = TfidfCosineStrategy() if self.strategy is None else self.strategy
        rows_to_drop: dict[str, list[int]] = {}
        pairs_by_split: dict[str, list[tuple[int, int, float]]] = {}
        for split_name in split_names:
            texts = splits[split_name].features
            if len(texts) <= 1:
                continue
            report = near_dedup(texts, threshold=self.threshold, strategy=backend)
            flagged = sorted({pair[0] for pair in report.dropped_pairs})
            if not flagged:
                continue
            rows_to_drop[split_name] = flagged
            pairs_by_split[split_name] = list(report.dropped_pairs)

        total = sum(len(rows) for rows in rows_to_drop.values())
        if total:
            summary = (
                f"near-duplicate dedup affected {total} rows "
                f"(threshold={self.threshold:.2f})"
            )
        else:
            summary = f"no near-duplicates above threshold={self.threshold:.2f}"
        return LeakageFinding(
            check_name=self.name,
            severity=self.severity,
            drop_indices=rows_to_drop,
            evidence={
                "threshold": self.threshold,
                "strategy": type(backend).__name__,
                "dropped_pairs_by_split": pairs_by_split,
            },
            message=summary,
            n_affected=total,
        )

    def validate_label_split(
        self, splits: Mapping[str, EvalSlice]
    ) -> tuple[LeakageFinding, LeakageFinding]:
        """Return ``(same_label, cross_label)`` findings for label-aware mode.

        Every near-duplicate pair inside a split is classified by whether
        its two rows carry the same ``y_true`` label. Requires the slice to
        carry binary labels (eval-toolkit's standard EvalSlice contract).
        The method works regardless of the ``label_aware`` flag — callers
        opt in by calling it instead of :meth:`validate`.
        """
        split_names = _select_targets(splits, self.target_splits)
        backend = TfidfCosineStrategy() if self.strategy is None else self.strategy
        drops: dict[str, dict[str, list[int]]] = {"same": {}, "cross": {}}
        evidence_pairs: dict[str, dict[str, list[tuple[int, int, float]]]] = {
            "same": {},
            "cross": {},
        }
        for split_name in split_names:
            current = splits[split_name]
            texts = current.features
            if len(texts) <= 1:
                continue
            labels = current.y_true
            report = near_dedup(texts, threshold=self.threshold, strategy=backend)
            buckets: dict[str, list[tuple[int, int, float]]] = {"same": [], "cross": []}
            for i, j, sim in report.dropped_pairs:
                polarity = "same" if labels[i] == labels[j] else "cross"
                buckets[polarity].append((i, j, sim))
            for polarity, pairs in buckets.items():
                if not pairs:
                    continue
                drops[polarity][split_name] = sorted({pair[0] for pair in pairs})
                evidence_pairs[polarity][split_name] = pairs

        n_same = sum(len(rows) for rows in drops["same"].values())
        n_cross = sum(len(rows) for rows in drops["cross"].values())

        if n_same:
            same_message = (
                f"same-label near-duplicates affected {n_same} rows "
                f"(threshold={self.threshold:.2f})"
            )
        else:
            same_message = (
                f"no same-label near-duplicates above threshold={self.threshold:.2f}"
            )
        if n_cross:
            cross_message = (
                f"cross-label near-duplicates affected {n_cross} rows "
                f"(threshold={self.threshold:.2f}) — label-conflict signal"
            )
        else:
            cross_message = (
                f"no cross-label near-duplicates above threshold={self.threshold:.2f}"
            )

        same_finding = LeakageFinding(
            check_name=f"{self.name}.same_label",
            severity=self.severity_same_label,
            drop_indices=drops["same"],
            evidence={
                "threshold": self.threshold,
                "strategy": type(backend).__name__,
                "dropped_pairs_by_split": evidence_pairs["same"],
                "label_polarity": "same",
            },
            message=same_message,
            n_affected=n_same,
        )
        cross_finding = LeakageFinding(
            check_name=f"{self.name}.cross_label",
            severity=self.severity_cross_label,
            drop_indices=drops["cross"],
            evidence={
                "threshold": self.threshold,
                "strategy": type(backend).__name__,
                "dropped_pairs_by_split": evidence_pairs["cross"],
                "label_polarity": "cross",
            },
            message=cross_message,
            n_affected=n_cross,
        )
        return same_finding, cross_finding
487
+
488
+
489
+ @dataclass(frozen=True, slots=True)
490
+ class NormalizedFormLeakageCheck:
491
+ """Encoding-obfuscated duplicate detection.
492
+
493
+ NFKC + zero-width / Symbol-Other strip + whitespace collapse + lowercase
494
+ before hashing. Catches the dominant unfixed leakage class in
495
+ prompt-injection corpora (encoding-obfuscated dupes detect at 21.3 %
496
+ under naive dedup but achieve 76.2 % attack success rate per
497
+ [PI_HackAPrompt_SQuAD 2025]). Severity defaults to ``"error"`` —
498
+ encoding-obfuscated overlap is dangerous.
499
+
500
+ Parameters
501
+ ----------
502
+ target_splits : sequence of str or None, optional
503
+ Splits to scan. ``None`` = all.
504
+ severity : {"error", "warning", "info"}, optional
505
+ Default ``"error"`` (this is the obfuscation-aware variant — opt
506
+ out via ``"warning"`` only when consumer trusts upstream cleaning).
507
+ """
508
+
509
+ target_splits: Sequence[str] | None = None
510
+ severity: Severity = "error"
511
+
512
+ @property
513
+ def name(self) -> str:
514
+ """Stable check identifier."""
515
+ return "NormalizedFormLeakageCheck"
516
+
517
+ def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
518
+ """Drop rows whose aggressively-normalized text collides with another."""
519
+ targets = _select_targets(splits, self.target_splits)
520
+ drop: dict[str, list[int]] = {}
521
+ collisions: dict[str, list[tuple[int, int]]] = {}
522
+ n_affected = 0
523
+ for split_name in targets:
524
+ slice_ = splits[split_name]
525
+ texts = slice_.features
526
+ if len(texts) <= 1:
527
+ continue
528
+ seen: dict[str, int] = {}
529
+ split_drops: list[int] = []
530
+ split_collisions: list[tuple[int, int]] = []
531
+ for i, text in enumerate(texts):
532
+ key = sha256_text(_aggressive_normalize(text), normalize=False)
533
+ if key in seen:
534
+ split_drops.append(i)
535
+ split_collisions.append((i, seen[key]))
536
+ else:
537
+ seen[key] = i
538
+ if split_drops:
539
+ drop[split_name] = sorted(set(split_drops))
540
+ collisions[split_name] = split_collisions
541
+ n_affected += len(drop[split_name])
542
+ return LeakageFinding(
543
+ check_name=self.name,
544
+ severity=self.severity,
545
+ drop_indices=drop,
546
+ evidence={"collisions_by_split": collisions},
547
+ message=(
548
+ f"encoding-obfuscated duplicates: {n_affected} rows collide "
549
+ f"after NFKC / zero-width / Symbol-Other strip"
550
+ if n_affected
551
+ else "no encoding-obfuscated duplicates found"
552
+ ),
553
+ n_affected=n_affected,
554
+ )
555
+
556
+
557
+ @dataclass(frozen=True, slots=True)
558
+ class LabelConflictCheck:
559
+ """Cross-source label conflicts: same text, different labels across splits.
560
+
561
+ The conflict resolution that ``prompt-injection-sdd`` and
562
+ ``prompt_injection_detector`` reimplement today (~50 LOC each). Severity
563
+ defaults to ``"error"``: the same prompt receiving different labels in
564
+ train vs. test poisons evaluation regardless of which label is "correct".
565
+
566
+ Parameters
567
+ ----------
568
+ target_splits : sequence of str or None, optional
569
+ Splits to compare. ``None`` = all keys; pairs (a, b) with a < b are
570
+ scanned. Use ``("train", "test")`` for the most common case.
571
+ severity : {"error", "warning", "info"}, optional
572
+ Default ``"error"``.
573
+ """
574
+
575
+ target_splits: Sequence[str] | None = None
576
+ severity: Severity = "error"
577
+
578
+ @property
579
+ def name(self) -> str:
580
+ """Stable check identifier."""
581
+ return "LabelConflictCheck"
582
+
583
+ def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
584
+ """Find rows whose normalized text appears with different labels."""
585
+ targets = _select_targets(splits, self.target_splits)
586
+ # Build text -> {(split, idx, label)} lookup
587
+ text_to_rows: dict[str, list[tuple[str, int, int]]] = defaultdict(list)
588
+ for split_name in targets:
589
+ slice_ = splits[split_name]
590
+ texts = slice_.features
591
+ labels = slice_.y_true
592
+ for i, text in enumerate(texts):
593
+ key = sha256_text(text, normalize=True)
594
+ text_to_rows[key].append((split_name, i, int(labels[i])))
595
+ # Collect rows where the same key carries multiple distinct labels
596
+ drop: dict[str, list[int]] = defaultdict(list)
597
+ conflicts: list[dict[str, object]] = []
598
+ for rows in text_to_rows.values():
599
+ distinct_labels = {label for _, _, label in rows}
600
+ if len(distinct_labels) > 1:
601
+ conflicts.append(
602
+ {
603
+ "rows": [{"split": s, "index": i, "label": lbl} for s, i, lbl in rows],
604
+ "labels": sorted(distinct_labels),
605
+ }
606
+ )
607
+ for split_name, idx, _ in rows:
608
+ drop[split_name].append(idx)
609
+ return LeakageFinding(
610
+ check_name=self.name,
611
+ severity=self.severity,
612
+ drop_indices={k: sorted(set(v)) for k, v in drop.items()},
613
+ evidence={"conflicts": conflicts},
614
+ message=(
615
+ f"{len(conflicts)} text(s) carry conflicting labels across splits"
616
+ if conflicts
617
+ else "no cross-split label conflicts"
618
+ ),
619
+ n_affected=sum(len(v) for v in drop.values()),
620
+ )
621
+
622
+
623
+ # ---------------------------------------------------------------------------
624
+ # Cross-split checks
625
+ # ---------------------------------------------------------------------------
626
+
627
+
628
@dataclass(frozen=True, slots=True)
class CrossSplitLeakageCheck:
    """Train↔eval near-duplicate leakage (the genuinely dangerous one).

    Wraps :func:`eval_toolkit.text_dedup.cross_dedup` to drop rows in any
    eval split that are near-duplicate to any row in the training split.
    Severity is ``"error"``: cross-split dup is the canonical leakage that
    inflates eval metrics.

    Parameters
    ----------
    train_split : str, optional
        Reference split (the corpus the eval rows must NOT echo). Default ``"train"``.
    eval_splits : sequence of str or None, optional
        Splits to scrub. ``None`` = every key except ``train_split``.
    threshold : float, optional
        Similarity threshold in (0, 1). Default 0.9.
    strategy : SimilarityStrategy or None, optional
        Backend. ``None`` instantiates :class:`TfidfCosineStrategy`.
    severity : {"error", "warning", "info"}, optional
        Default ``"error"``.
    label_aware : bool, optional
        When ``True``, :meth:`validate_label_split` emits two findings — one
        for eval rows whose matched train neighbor shares the eval row's
        label (same-label leakage = memorization risk) and one for
        cross-label matches (supervision conflict + memorization). Default
        ``False`` preserves the single-finding contract.
    severity_same_label : {"error", "warning", "info"}, optional
        Severity for the same-label finding when ``label_aware=True``.
        Default ``"warning"``.
    severity_cross_label : {"error", "warning", "info"}, optional
        Severity for the cross-label finding when ``label_aware=True``.
        Default ``"error"``.
    """

    train_split: str = "train"
    eval_splits: Sequence[str] | None = None
    threshold: float = DEFAULT_DEDUP_THRESHOLD
    strategy: SimilarityStrategy | None = None
    severity: Severity = "error"
    label_aware: bool = False
    severity_same_label: Severity = "warning"
    severity_cross_label: Severity = "error"

    @property
    def name(self) -> str:
        """Stable check identifier."""
        return "CrossSplitLeakageCheck"

    def _resolve_eval_targets(self, splits: Mapping[str, EvalSlice]) -> list[str]:
        """Return the eval split names to scrub, validating every name first.

        All names are checked before any similarity work starts, so a typo
        in ``eval_splits`` fails immediately instead of after earlier eval
        splits have already been compared against the train split.

        Raises
        ------
        KeyError
            If ``self.train_split`` or any requested eval split is not a
            key of ``splits``.
        """
        if self.train_split not in splits:
            raise KeyError(f"train_split {self.train_split!r} not in splits")
        if self.eval_splits is None:
            return [k for k in splits if k != self.train_split]
        eval_targets = list(self.eval_splits)
        for eval_name in eval_targets:
            if eval_name not in splits:
                raise KeyError(f"eval split {eval_name!r} not in splits")
        return eval_targets

    def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
        """Find eval-side rows that near-duplicate any train-side row.

        Eval splits are skipped when either side has no rows. For the
        others, every eval index that :func:`cross_dedup` does not keep is
        recorded as a row to drop.

        Raises
        ------
        KeyError
            If ``self.train_split`` is missing from ``splits``, or if any
            entry in ``self.eval_splits`` is missing.
        """
        eval_targets = self._resolve_eval_targets(splits)
        train_texts = splits[self.train_split].features
        active = self.strategy if self.strategy is not None else TfidfCosineStrategy()
        drop: dict[str, list[int]] = {}
        n_affected = 0
        for eval_name in eval_targets:
            eval_texts = splits[eval_name].features
            # Use len() like the within-split checks do: plain truthiness
            # would raise on array-like containers (e.g. NumPy arrays).
            if len(train_texts) == 0 or len(eval_texts) == 0:
                continue
            kept = cross_dedup(
                train_texts,
                eval_texts,
                threshold=self.threshold,
                strategy=active,
            )
            kept_set = set(kept)
            dropped = [i for i in range(len(eval_texts)) if i not in kept_set]
            if dropped:
                drop[eval_name] = dropped
                n_affected += len(dropped)
        return LeakageFinding(
            check_name=self.name,
            severity=self.severity,
            drop_indices=drop,
            evidence={
                "train_split": self.train_split,
                "eval_splits": eval_targets,
                "threshold": self.threshold,
                "strategy": type(active).__name__,
            },
            message=(
                f"cross-split leakage: {n_affected} eval rows near-duplicate "
                f"to {self.train_split!r} (threshold={self.threshold:.2f})"
                if n_affected
                else f"no cross-split leakage above threshold={self.threshold:.2f}"
            ),
            n_affected=n_affected,
        )

    def validate_label_split(
        self, splits: Mapping[str, EvalSlice]
    ) -> tuple[LeakageFinding, LeakageFinding]:
        """Emit (same_label, cross_label) findings for cross-split leakage.

        Pairs each cross-split near-duplicate by whether the matched train
        and eval rows carry the same label. Uses
        :func:`cross_dedup_pairs` (v0.17.0) so the train-side index of each
        match is preserved. Same-label = memorization risk; cross-label =
        supervision conflict + memorization.

        Raises
        ------
        KeyError
            If ``self.train_split`` is missing from ``splits``, or if any
            entry in ``self.eval_splits`` is missing.
        """
        eval_targets = self._resolve_eval_targets(splits)
        train_slice = splits[self.train_split]
        train_texts = train_slice.features
        train_labels = train_slice.y_true
        active = self.strategy if self.strategy is not None else TfidfCosineStrategy()
        same_drop: dict[str, list[int]] = {}
        cross_drop: dict[str, list[int]] = {}
        same_pairs_by_split: dict[str, list[tuple[int, int, float]]] = {}
        cross_pairs_by_split: dict[str, list[tuple[int, int, float]]] = {}
        n_same = 0
        n_cross = 0
        for eval_name in eval_targets:
            eval_slice = splits[eval_name]
            eval_texts = eval_slice.features
            eval_labels = eval_slice.y_true
            # Use len() like the within-split checks do: plain truthiness
            # would raise on array-like containers (e.g. NumPy arrays).
            if len(train_texts) == 0 or len(eval_texts) == 0:
                continue
            pairs = cross_dedup_pairs(
                train_texts,
                eval_texts,
                threshold=self.threshold,
                strategy=active,
            )
            same_pairs: list[tuple[int, int, float]] = []
            cross_pairs: list[tuple[int, int, float]] = []
            for eval_idx, train_idx, sim in pairs:
                if eval_labels[eval_idx] == train_labels[train_idx]:
                    same_pairs.append((eval_idx, train_idx, sim))
                else:
                    cross_pairs.append((eval_idx, train_idx, sim))
            if same_pairs:
                same_drop[eval_name] = sorted({p[0] for p in same_pairs})
                same_pairs_by_split[eval_name] = same_pairs
                n_same += len(same_drop[eval_name])
            if cross_pairs:
                cross_drop[eval_name] = sorted({p[0] for p in cross_pairs})
                cross_pairs_by_split[eval_name] = cross_pairs
                n_cross += len(cross_drop[eval_name])
        # Evidence entries shared by both findings.
        evidence_base: dict[str, object] = {
            "train_split": self.train_split,
            "eval_splits": eval_targets,
            "threshold": self.threshold,
            "strategy": type(active).__name__,
        }
        same_finding = LeakageFinding(
            check_name=f"{self.name}.same_label",
            severity=self.severity_same_label,
            drop_indices=same_drop,
            evidence={
                **evidence_base,
                "pairs_by_split": same_pairs_by_split,
                "label_polarity": "same",
            },
            message=(
                f"same-label cross-split leakage: {n_same} eval rows near-duplicate "
                f"to {self.train_split!r} sharing label (threshold={self.threshold:.2f})"
                if n_same
                else f"no same-label cross-split leakage above threshold={self.threshold:.2f}"
            ),
            n_affected=n_same,
        )
        cross_finding = LeakageFinding(
            check_name=f"{self.name}.cross_label",
            severity=self.severity_cross_label,
            drop_indices=cross_drop,
            evidence={
                **evidence_base,
                "pairs_by_split": cross_pairs_by_split,
                "label_polarity": "cross",
            },
            message=(
                f"cross-label cross-split leakage: {n_cross} eval rows near-duplicate "
                f"to {self.train_split!r} with opposing label "
                f"(threshold={self.threshold:.2f}) — supervision conflict"
                if n_cross
                else f"no cross-label cross-split leakage above threshold={self.threshold:.2f}"
            ),
            n_affected=n_cross,
        )
        return same_finding, cross_finding
838
+
839
+
840
@dataclass(frozen=True, slots=True)
class GroupLeakageCheck:
    """Detect group ids that are shared by two or more splits.

    Typical failure mode in clinical / user-level / source-level evals, where
    all rows of one patient / user / source have to stay inside one fold.
    Every target split's dataframe must carry a column named ``group_col``.

    Parameters
    ----------
    group_col : str
        Name of the dataframe column that holds the group id.
    target_splits : sequence of str or None, optional
        Splits to compare. ``None`` = all keys.
    severity : {"error", "warning", "info"}, optional
        Severity attached to the returned finding. Default ``"error"``.
    """

    group_col: str
    target_splits: Sequence[str] | None = None
    severity: Severity = "error"

    @property
    def name(self) -> str:
        """Stable check identifier."""
        return "GroupLeakageCheck"

    def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
        """Report every row whose group id occurs in more than one target split.

        Returns
        -------
        LeakageFinding
            ``drop_indices`` maps each affected split to the sorted row
            positions belonging to a shared group; ``evidence`` lists the
            shared groups together with the splits they occur in.

        Raises
        ------
        KeyError
            If any target split's dataframe lacks ``self.group_col``.
        """
        splits_by_group: dict[object, set[str]] = {}
        rows_by_key: dict[tuple[object, str], list[int]] = {}
        for name in _select_targets(splits, self.target_splits):
            frame = splits[name].df
            if self.group_col not in frame.columns:
                raise KeyError(f"split {name!r}: missing group column {self.group_col!r}")
            for position, group_id in enumerate(frame[self.group_col].tolist()):
                splits_by_group.setdefault(group_id, set()).add(name)
                rows_by_key.setdefault((group_id, name), []).append(position)

        # A group leaks as soon as it has been seen in at least two splits.
        offending = {
            group_id: members
            for group_id, members in splits_by_group.items()
            if len(members) >= 2
        }

        # Every row of a leaking group is flagged, in each split it occurs in.
        drop: dict[str, list[int]] = {}
        for (group_id, name), positions in rows_by_key.items():
            if group_id not in offending:
                continue
            drop.setdefault(name, []).extend(positions)

        if offending:
            message = (
                f"{len(offending)} group(s) span multiple splits "
                f"(group_col={self.group_col!r})"
            )
        else:
            message = f"no group-id leakage on {self.group_col!r}"

        violating_groups = {
            str(group_id): sorted(members) for group_id, members in offending.items()
        }
        return LeakageFinding(
            check_name=self.name,
            severity=self.severity,
            drop_indices={name: sorted(set(rows)) for name, rows in drop.items()},
            evidence={
                "group_col": self.group_col,
                "violating_groups": violating_groups,
            },
            message=message,
            n_affected=sum(len(rows) for rows in drop.values()),
        )
906
+
907
+
908
@dataclass(frozen=True, slots=True)
class TemporalLeakageCheck:
    """Temporal ordering invariant: every earlier split's max(time) ≤ next's min(time).

    Required for honest time-series eval — failure means the model could
    train on data from the future and "validate" on data from the past.
    Severity ``"error"``.

    Parameters
    ----------
    time_col : str
        Column name carrying a sortable timestamp / date / index in every
        target split's dataframe.
    split_order : sequence of str
        Required temporal order, e.g. ``("train", "val", "test")``. The
        check asserts that for every adjacent pair ``(a, b)``,
        ``max(a[time_col]) <= min(b[time_col])``.
    severity : {"error", "warning", "info"}, optional
        Default ``"error"``.
    """

    time_col: str
    split_order: Sequence[str]
    severity: Severity = "error"

    @property
    def name(self) -> str:
        """Stable check identifier."""
        return "TemporalLeakageCheck"

    def validate(self, splits: Mapping[str, EvalSlice]) -> LeakageFinding:
        """Check the temporal ordering invariant pairwise.

        For each adjacent pair ``(earlier, later)`` in ``self.split_order``
        whose ordering is violated, the rows of ``later`` that lie strictly
        before ``max(earlier[time_col])`` are reported for dropping. Rows
        equal to that maximum are kept because the invariant permits ties.
        Pairs where either split has an empty time column are skipped.

        Returns
        -------
        LeakageFinding
            ``drop_indices`` maps each later split to the sorted row
            *positions* (not dataframe index labels) to drop; ``evidence``
            holds one entry per violated boundary.

        Raises
        ------
        KeyError
            If any entry in ``self.split_order`` is missing from ``splits``,
            or if any target split's dataframe lacks ``self.time_col``.
        """
        for split_name in self.split_order:
            if split_name not in splits:
                raise KeyError(f"split {split_name!r} (in split_order) not in splits")
            if self.time_col not in splits[split_name].df.columns:
                raise KeyError(f"split {split_name!r}: missing time column {self.time_col!r}")

        violations: list[dict[str, object]] = []
        drop: dict[str, list[int]] = defaultdict(list)
        for earlier, later in zip(self.split_order, self.split_order[1:], strict=False):
            t_earlier = splits[earlier].df[self.time_col]
            t_later = splits[later].df[self.time_col]
            if t_earlier.empty or t_later.empty:
                continue
            max_earlier = t_earlier.max()
            min_later = t_later.min()
            if max_earlier > min_later:
                # Work on a positional boolean mask instead of index labels:
                # this stays correct for named, duplicated or non-default
                # indexes and avoids a quadratic label lookup. Missing
                # timestamps compare as "not offending".
                offending = (t_later < max_earlier).to_numpy(dtype=bool, na_value=False)
                positional = np.flatnonzero(offending).tolist()
                drop[later].extend(positional)
                violations.append(
                    {
                        "earlier_split": earlier,
                        "later_split": later,
                        "max_earlier": str(max_earlier),
                        "min_later": str(min_later),
                        "n_offending_in_later": len(positional),
                    }
                )

        drop_indices = {k: sorted(set(v)) for k, v in drop.items()}
        return LeakageFinding(
            check_name=self.name,
            severity=self.severity,
            drop_indices=drop_indices,
            evidence={
                "time_col": self.time_col,
                "split_order": list(self.split_order),
                "violations": violations,
            },
            message=(
                f"temporal ordering violated at {len(violations)} boundary/boundaries"
                if violations
                else f"temporal ordering OK on {self.time_col!r}"
            ),
            # Count the de-duplicated positions so the total matches drop_indices.
            n_affected=sum(len(v) for v in drop_indices.values()),
        )
994
+
995
+
996
+ # ---------------------------------------------------------------------------
997
+ # Aggregator
998
+ # ---------------------------------------------------------------------------
999
+
1000
+
1001
def run_leakage_checks(
    checks: Sequence[LeakageCheck], splits: Mapping[str, EvalSlice]
) -> LeakageReport:
    """Apply each :class:`LeakageCheck` to ``splits`` and bundle the findings.

    The checks are invoked one after another, in the order supplied, and
    each result of ``check.validate(splits)`` becomes one entry of the
    report. Suitable for offline / CI audits; the same ``checks`` list can
    also be handed to :func:`eval_toolkit.harness.evaluate` for inline
    enforcement.

    Parameters
    ----------
    checks : sequence of LeakageCheck
        Checks to run, in execution order.
    splits : Mapping[str, EvalSlice]
        Output of :class:`~eval_toolkit.loaders.DatasetLoader.load_splits`.

    Returns
    -------
    LeakageReport
        Report whose ``findings`` follow the order of ``checks``.

    Raises
    ------
    KeyError
        If a check references a split / column not present in ``splits``.

    Examples
    --------
    >>> import pandas as pd
    >>> from eval_toolkit.harness import EvalSlice
    >>> from eval_toolkit.leakage import (
    ...     ExactDuplicateCheck, run_leakage_checks,
    ... )
    >>> df_train = pd.DataFrame({"text": ["a", "a", "b"], "label": [0, 0, 1]})
    >>> df_test = pd.DataFrame({"text": ["c", "d"], "label": [0, 1]})
    >>> splits = {
    ...     "train": EvalSlice(name="train", df=df_train),
    ...     "test": EvalSlice(name="test", df=df_test),
    ... }
    >>> report = run_leakage_checks([ExactDuplicateCheck()], splits)
    >>> report.has_errors()
    False
    >>> sum(f.n_affected for f in report.findings)
    1
    """
    collected: list[LeakageFinding] = [check.validate(splits) for check in checks]
    return LeakageReport(findings=collected)
1049
+
1050
+
1051
# Reference ``pd`` and ``np`` once at module level so linters do not report the
# imports as unused; per the original note they are referenced in docstring
# Examples blocks.
_ = (pd, np)