if-split 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ifsplit/split.py ADDED
@@ -0,0 +1,177 @@
1
+ """Stage 6 - Deterministic component -> split assignment (the reproducibility core).
2
+
3
+ Each *component* (a leakage-safe group of sequence clusters joined by shared
4
+ multi-chain entries; see cluster.py) is assigned to a split by
5
+ ``blake2b(salt + ':' + component_key)`` mapped onto the cumulative
6
+ ``split_fractions``. Same salt + same key -> same split, forever, independent of
7
+ how many other components exist - so a larger snapshot only *adds* components and
8
+ never moves existing ones.
9
+
10
+ An optional ``registry`` (component_key -> split) pins prior assignments: if a
11
+ key is already in the registry its recorded split wins over the hash, so growth
12
+ is stable even if a component's canonical key shifts (e.g. a smaller-id member
13
+ joins later).
14
+
15
+ **No-leakage is structural, not heuristic.** Because every entity an entry
16
+ touches lives in the same component (union-find merged them), and a component
17
+ maps to exactly one split, two splits can never share a sequence cluster. The
18
+ ``check_no_leakage`` invariant re-derives this from the cluster membership and
19
+ fails loudly if it is ever violated - a real guard, not a tautology.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import hashlib
25
+ from dataclasses import dataclass, field
26
+
27
+ from .cluster import ClusterResult
28
+ from .config import Config
29
+
30
# Every split name this module can assign; also the key order used for the
# per-split ``counts`` / ``cluster_counts`` tallies in ``assign_splits``.
SPLITS: tuple[str, ...] = ("train", "val", "test")
31
+
32
+
33
def bucket(key: str, salt: str) -> float:
    """Map ``salt:key`` to a stable, uniformly distributed float in ``[0, 1)``.

    The 8-byte BLAKE2b digest of the UTF-8 encoding of ``f"{salt}:{key}"`` is
    read as a big-endian unsigned integer and divided by ``2**64``. The result
    depends only on ``salt`` and ``key``.

    Args:
        key: Identifier to place (e.g. a component key).
        salt: Namespace string; a different salt reshuffles every key.

    Returns:
        A float ``b`` with ``0.0 <= b < 1.0``.
    """
    digest = hashlib.blake2b(f"{salt}:{key}".encode(), digest_size=8).digest()
    value = int.from_bytes(digest, "big") / 2**64
    # A 64-bit integer carries more precision than a float's 53-bit mantissa,
    # and true division rounds to nearest: the largest 1024 possible digests
    # therefore come out as exactly 1.0, outside the half-open range. Clamp
    # those to the largest float below 1.0 (1 - 2**-53). All other digests are
    # unaffected, so existing hash-based assignments do not move.
    return min(value, 1.0 - 2**-53)
37
+
38
+
39
def split_for_key(key: str, cfg: Config) -> str:
    """Return the split (``"train"``, ``"val"`` or ``"test"``) a key hashes into.

    The key's :func:`bucket` value under ``cfg.split_salt`` is compared with
    the cumulative edges ``train`` and ``train + val`` from
    ``cfg.split_fractions``. Anything at or past the second edge is ``"test"``:
    the ``test`` fraction itself is never read, it is implied as the remainder.
    """
    position = bucket(key, cfg.split_salt)
    fractions = cfg.split_fractions
    train_edge = fractions.train
    if position < train_edge:
        return "train"
    return "val" if position < train_edge + fractions.val else "test"
48
+
49
+
50
@dataclass
class SplitResult:
    """Outcome of :func:`assign_splits`.

    Attributes:
        cluster_split: component key -> split name.
        entry_split: entry id -> split name.
        counts: split name -> number of entries assigned to that split.
        cluster_counts: split name -> number of components assigned to it.
        minimum_shortfalls: class -> how many test entries its requested
            per-class floor is still missing. Empty when no floors were
            requested or every floor was met; shortfalls are reported, never
            forced.
    """

    cluster_split: dict[str, str]
    entry_split: dict[str, str]
    counts: dict[str, int]
    cluster_counts: dict[str, int]
    minimum_shortfalls: dict[str, int] = field(default_factory=dict)
59
+
60
+
61
def _enforce_test_minimums(
    cluster_split: dict[str, str],
    clusters: ClusterResult,
    cfg: Config,
    entry_classes: dict[str, list[str]],
    registry: dict[str, str],
) -> tuple[dict[str, str], dict[str, int]]:
    """Top up the test split, whole components at a time, to meet class floors.

    Classes with a positive floor in ``cfg.test_min_per_class`` are processed
    in sorted order. For each, components that are not yet in test, contain
    at least one entry listing that class, and are not keys of ``registry``
    are moved into test in ascending ``(bucket(key, salt), key)`` order until
    the class's test entry count reaches its floor or candidates run out.

    Moving whole components keeps the change leakage-safe, the hash ordering
    makes it deterministic, and skipping registry keys keeps pinned
    assignments stable. The input ``cluster_split`` is never mutated.

    Returns:
        ``(updated_cluster_split, shortfalls)``; ``shortfalls`` maps each class
        whose floor could not be met to the number of entries still missing.
        With no positive floors, the input mapping object itself is returned
        together with an empty shortfall map.
    """
    floors = {cls: n for cls, n in cfg.test_min_per_class.items() if n > 0}
    if not floors:
        return cluster_split, {}

    updated = dict(cluster_split)

    # Component key -> {class: number of member entries listing that class}.
    per_component: dict[str, dict[str, int]] = {}
    for comp, members in clusters.cluster_members.items():
        tally: dict[str, int] = {}
        for entry_id in members:
            for label in entry_classes.get(entry_id, []):
                tally[label] = tally.get(label, 0) + 1
        per_component[comp] = tally

    # Running per-class entry totals over everything currently in test, kept in
    # sync as components are recruited so each floor check is a dict lookup.
    in_test: dict[str, int] = {}

    def absorb(comp: str) -> None:
        """Fold one component's class tallies into ``in_test``."""
        for label, n in per_component[comp].items():
            in_test[label] = in_test.get(label, 0) + n

    for comp, split in updated.items():
        if split == "test":
            absorb(comp)

    shortfalls: dict[str, int] = {}
    for cls in sorted(floors):
        need = floors[cls]
        # Recruitable components for this class, in deterministic hash order
        # (the key itself breaks any tie).
        candidates = sorted(
            (
                comp
                for comp in clusters.cluster_members
                if updated[comp] != "test"
                and per_component[comp].get(cls, 0) > 0
                and comp not in registry
            ),
            key=lambda comp: (bucket(comp, cfg.split_salt), comp),
        )
        for comp in candidates:
            if in_test.get(cls, 0) >= need:
                break
            updated[comp] = "test"
            absorb(comp)
        missing = need - in_test.get(cls, 0)
        if missing > 0:
            shortfalls[cls] = missing
    return updated, shortfalls
116
+
117
+
118
def assign_splits(
    clusters: ClusterResult,
    cfg: Config,
    registry: dict[str, str] | None = None,
    entry_classes: dict[str, list[str]] | None = None,
) -> SplitResult:
    """Assign every component (and therefore every entry) to a split.

    Each component key in ``clusters.cluster_members`` takes its pinned split
    from ``registry`` when present, otherwise :func:`split_for_key`. If
    ``cfg.test_min_per_class`` is set and ``entry_classes`` is provided, a
    deterministic top-up then recruits whole components into test to meet
    per-class floors (see :func:`_enforce_test_minimums`). Entries inherit the
    split of their component through ``clusters.entry_to_cluster``.

    Args:
        clusters: Provides ``cluster_members`` (component key -> entries) and
            ``entry_to_cluster`` (entry id -> component key).
        cfg: Supplies the split salt/fractions and optional per-class floors.
        registry: Optional ``component_key -> split`` pins from earlier runs.
        entry_classes: Optional ``entry_id -> class labels``; only used for
            per-class test floors.

    Returns:
        A :class:`SplitResult` whose key-indexed mappings are sorted by key and
        whose per-split tallies follow ``SPLITS`` order.

    Raises:
        ValueError: if ``registry`` pins one of the current components to a
            split name that is not in ``SPLITS``.
    """
    registry = registry or {}

    cluster_split: dict[str, str] = {}
    for key in clusters.cluster_members:
        if key in registry:
            pinned = registry[key]
            # Reject unknown names here with a clear message; otherwise they
            # only surface as a bare KeyError while tallying counts below.
            if pinned not in SPLITS:
                raise ValueError(
                    f"registry pins component {key!r} to unknown split "
                    f"{pinned!r}; expected one of {SPLITS}"
                )
            cluster_split[key] = pinned
        else:
            # Hash only unpinned keys; pinned ones never need the bucket.
            cluster_split[key] = split_for_key(key, cfg)

    shortfalls: dict[str, int] = {}
    if cfg.test_min_per_class and entry_classes is not None:
        cluster_split, shortfalls = _enforce_test_minimums(
            cluster_split, clusters, cfg, entry_classes, registry
        )

    # Every entry inherits its component's split.
    counts = {s: 0 for s in SPLITS}
    entry_split: dict[str, str] = {}
    for entry, key in clusters.entry_to_cluster.items():
        s = cluster_split[key]
        entry_split[entry] = s
        counts[s] += 1

    cluster_counts = {s: 0 for s in SPLITS}
    for s in cluster_split.values():
        cluster_counts[s] += 1

    return SplitResult(
        cluster_split=dict(sorted(cluster_split.items())),
        entry_split=dict(sorted(entry_split.items())),
        counts=counts,
        cluster_counts=cluster_counts,
        minimum_shortfalls=dict(sorted(shortfalls.items())),
    )
160
+
161
+
162
def check_no_leakage(result: SplitResult, clusters: ClusterResult) -> None:
    """Fail loudly if any raw sequence cluster is shared by two splits.

    Walks ``clusters.entry_raw_clusters`` (entry id -> raw cluster keys). For
    each raw cluster it remembers the split of the first entry seen touching
    it, taken from ``result.entry_split``, and compares every later entry's
    split against that. This re-derives the invariant from raw membership
    rather than from the component assignment, so a failure points at a real
    upstream bug.

    Raises:
        AssertionError: for the first raw cluster found in two different
            splits. It is raised explicitly, so ``python -O`` does not strip it.
        KeyError: if an entry listed in ``entry_raw_clusters`` has no split in
            ``result.entry_split``.
    """
    first_split_of: dict[str, str] = {}
    for entry_id, raw_clusters in clusters.entry_raw_clusters.items():
        split = result.entry_split[entry_id]
        for rk in raw_clusters:
            if rk not in first_split_of:
                first_split_of[rk] = split
                continue
            prior = first_split_of[rk]
            if prior != split:
                raise AssertionError(
                    f"leakage: raw cluster {rk} appears in both {prior!r} and {split!r}"
                )