if-split 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- if_split-0.1.0.dist-info/METADATA +312 -0
- if_split-0.1.0.dist-info/RECORD +20 -0
- if_split-0.1.0.dist-info/WHEEL +4 -0
- if_split-0.1.0.dist-info/entry_points.txt +2 -0
- if_split-0.1.0.dist-info/licenses/LICENSE +21 -0
- ifsplit/__init__.py +8 -0
- ifsplit/__main__.py +8 -0
- ifsplit/cli.py +317 -0
- ifsplit/cluster.py +130 -0
- ifsplit/config.py +146 -0
- ifsplit/dataset.py +112 -0
- ifsplit/download.py +229 -0
- ifsplit/enumerate.py +111 -0
- ifsplit/hydrate.py +216 -0
- ifsplit/ligands.py +267 -0
- ifsplit/manifest.py +417 -0
- ifsplit/parse.py +111 -0
- ifsplit/rcsb.py +251 -0
- ifsplit/schema.py +241 -0
- ifsplit/split.py +177 -0
ifsplit/split.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Stage 6 - Deterministic component -> split assignment (the reproducibility core).
|
|
2
|
+
|
|
3
|
+
Each *component* (a leakage-safe group of sequence clusters joined by shared
|
|
4
|
+
multi-chain entries; see cluster.py) is assigned to a split by
|
|
5
|
+
``blake2b(salt + ':' + component_key)`` mapped onto the cumulative
|
|
6
|
+
``split_fractions``. Same salt + same key -> same split, forever, independent of
|
|
7
|
+
how many other components exist - so a larger snapshot only *adds* components and
|
|
8
|
+
never moves existing ones.
|
|
9
|
+
|
|
10
|
+
An optional ``registry`` (component_key -> split) pins prior assignments: if a
|
|
11
|
+
key is already in the registry its recorded split wins over the hash, so growth
|
|
12
|
+
is stable even if a component's canonical key shifts (e.g. a smaller-id member
|
|
13
|
+
joins later).
|
|
14
|
+
|
|
15
|
+
**No-leakage is structural, not heuristic.** Because every entity an entry
|
|
16
|
+
touches lives in the same component (union-find merged them), and a component
|
|
17
|
+
maps to exactly one split, two splits can never share a sequence cluster. The
|
|
18
|
+
``check_no_leakage`` invariant re-derives this from the cluster membership and
|
|
19
|
+
fails loudly if it is ever violated - a real guard, not a tautology.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import hashlib
|
|
25
|
+
from dataclasses import dataclass, field
|
|
26
|
+
|
|
27
|
+
from .cluster import ClusterResult
|
|
28
|
+
from .config import Config
|
|
29
|
+
|
|
30
|
+
# The split names returned by split_for_key, in the order it lays them out on
# [0, 1). assign_splits seeds its per-split entry and component counters from
# this tuple, so every name appears (possibly with a zero count) in the
# resulting SplitResult.counts / cluster_counts.
SPLITS = ("train", "val", "test")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def bucket(key: str, salt: str) -> float:
    """Map ``salt:key`` to a stable, uniformly distributed float in ``[0, 1)``.

    The string ``f"{salt}:{key}"`` is UTF-8 encoded, hashed with an 8-byte
    BLAKE2b digest, read as a big-endian unsigned integer and divided by
    ``2**64``. The result depends only on ``key`` and ``salt``, so unlike
    ``hash()`` it is reproducible across runs, machines and processes.

    Args:
        key: The item to place (e.g. a component key).
        salt: Namespace string; changing it reshuffles every key.

    Returns:
        A float ``b`` with ``0.0 <= b < 1.0``.
    """
    digest = hashlib.blake2b(f"{salt}:{key}".encode(), digest_size=8).digest()
    value = int.from_bytes(digest, "big") / 2**64
    # int / int true division rounds to the nearest double. The ~1024 largest
    # digests (>= 2**64 - 2**10) therefore round up to exactly 1.0, which breaks
    # the half-open contract. In split_for_key, such a key would go to "test"
    # even when the test fraction is 0. Clamp to the largest double below 1.0.
    # Every other digest still yields the same float as before, so existing
    # assignments are unchanged.
    return min(value, 1.0 - 2**-53)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def split_for_key(key: str, cfg: Config) -> str:
    """Return the split a component key hashes into for ``cfg``.

    The key's position is ``bucket(key, cfg.split_salt)``. It is laid on the
    cumulative fractions: ``[0, train)`` -> ``"train"``,
    ``[train, train + val)`` -> ``"val"``, and anything at or above
    ``train + val`` -> ``"test"``. For a fixed salt the answer never changes.
    """
    position = bucket(key, cfg.split_salt)
    fractions = cfg.split_fractions
    if position < fractions.train:
        return "train"
    # "test" is the fall-through, so it also absorbs any rounding slack when
    # the configured fractions do not sum to exactly 1.
    return "val" if position < fractions.train + fractions.val else "test"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class SplitResult:
    """Split assignment for one run, as returned by :func:`assign_splits`.

    When built by ``assign_splits``, ``cluster_split`` and ``entry_split`` are
    ordered by key, and both count maps hold an entry (possibly zero) for every
    name in ``SPLITS``.
    """

    cluster_split: dict[str, str]  # component key -> split
    entry_split: dict[str, str]  # entry_id -> split
    counts: dict[str, int]  # split -> entry count
    cluster_counts: dict[str, int]  # split -> component count
    # Per-class test floors that could not be fully met (class -> shortfall). Empty
    # when no minimums were requested or all were satisfied. Reported, never forced.
    minimum_shortfalls: dict[str, int] = field(default_factory=dict)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _enforce_test_minimums(
    cluster_split: dict[str, str],
    clusters: ClusterResult,
    cfg: Config,
    entry_classes: dict[str, list[str]],
    registry: dict[str, str],
) -> tuple[dict[str, str], dict[str, int]]:
    """Move whole components into ``"test"`` until per-class floors are reached.

    Only floors greater than zero in ``cfg.test_min_per_class`` count. If none
    remain, ``cluster_split`` itself (not a copy) is returned together with an
    empty shortfall map. Otherwise the function works on a copy and handles
    the classes in sorted order:

    * Candidates are components that are not yet in test, carry the class at
      least once, and are absent from ``registry``. Pinned components are
      never moved, which keeps growth stable.
    * Candidates are tried in ``(bucket(key, cfg.split_salt), key)`` order, so
      the outcome is deterministic and independent of dict ordering.
    * Recruiting stops as soon as the class's test total reaches its floor.
    * A recruited component credits *all* of its class counts to the test
      totals, which can also satisfy classes handled later.

    Components move as a whole and entries are never split off, so leakage
    safety is preserved.

    Args:
        cluster_split: Current ``component key -> split`` assignment.
        clusters: Provides ``cluster_members`` (component key -> entry ids).
        cfg: Provides ``test_min_per_class`` and ``split_salt``.
        entry_classes: ``entry_id -> class labels``. Entries missing here
            contribute no class counts.
        registry: Pinned ``component key -> split`` assignments (never moved).

    Returns:
        ``(updated cluster_split, {class: shortfall})``. The shortfall is the
        floor minus the achieved test count, reported only where positive.
    """
    floors = {
        label: floor
        for label, floor in cfg.test_min_per_class.items()
        if floor > 0
    }
    if not floors:
        return cluster_split, {}

    assigned = dict(cluster_split)  # never mutate the caller's mapping

    # Per component: label -> number of occurrences across its members'
    # entry_classes lists, so test totals can be updated in O(labels).
    tallies: dict[str, dict[str, int]] = {}
    for comp_key, members in clusters.cluster_members.items():
        tally: dict[str, int] = {}
        for entry_id in members:
            for label in entry_classes.get(entry_id, ()):
                tally[label] = tally.get(label, 0) + 1
        tallies[comp_key] = tally

    in_test: dict[str, int] = {}

    def credit(comp_key: str) -> None:
        """Add every class count of ``comp_key`` to the running test totals."""
        for cls_name, n in tallies[comp_key].items():
            in_test[cls_name] = in_test.get(cls_name, 0) + n

    for comp_key, split in assigned.items():
        if split == "test":
            credit(comp_key)

    shortfalls: dict[str, int] = {}
    for label in sorted(floors):
        floor = floors[label]
        candidates = sorted(
            (
                comp_key
                for comp_key in clusters.cluster_members
                if assigned[comp_key] != "test"
                and tallies[comp_key].get(label, 0) > 0
                and comp_key not in registry  # pinned: growth stability
            ),
            key=lambda k: (bucket(k, cfg.split_salt), k),
        )
        for comp_key in candidates:
            if in_test.get(label, 0) >= floor:
                break
            assigned[comp_key] = "test"
            credit(comp_key)
        missing = floor - in_test.get(label, 0)
        if missing > 0:
            shortfalls[label] = missing
    return assigned, shortfalls
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def assign_splits(
    clusters: ClusterResult,
    cfg: Config,
    registry: dict[str, str] | None = None,
    entry_classes: dict[str, list[str]] | None = None,
) -> SplitResult:
    """Give every component, and through it every entry, a split.

    Each key in ``clusters.cluster_members`` takes its pinned split from
    ``registry`` when present, and otherwise the hash-based split from
    :func:`split_for_key`. When ``cfg.test_min_per_class`` is truthy and
    ``entry_classes`` is given, :func:`_enforce_test_minimums` then recruits
    whole unpinned components into test to meet per-class floors.

    Args:
        clusters: Provides ``cluster_members`` (component keys) and
            ``entry_to_cluster`` (entry id -> component key).
        cfg: Split configuration (salt, fractions, optional test floors).
        registry: Optional ``component key -> split`` pins; ``None`` means none.
        entry_classes: Optional ``entry_id -> class labels``, used only for
            the test floors.

    Returns:
        A :class:`SplitResult` whose split maps are sorted by key and whose
        count maps contain every name in ``SPLITS``.

    Raises:
        KeyError: If an assigned split (e.g. a registry value) is not one of
            ``SPLITS``, or an entry maps to a component key that has no
            assignment.
    """
    pins = registry or {}

    cluster_split: dict[str, str] = {}
    for comp_key in clusters.cluster_members:
        # Evaluated for every key (pinned or not), like an eager dict.get default.
        hashed = split_for_key(comp_key, cfg)
        cluster_split[comp_key] = pins.get(comp_key, hashed)

    shortfalls: dict[str, int] = {}
    wants_floors = cfg.test_min_per_class and entry_classes is not None
    if wants_floors:
        cluster_split, shortfalls = _enforce_test_minimums(
            cluster_split, clusters, cfg, entry_classes, pins
        )

    entry_split: dict[str, str] = {}
    counts = dict.fromkeys(SPLITS, 0)
    for entry_id, comp_key in clusters.entry_to_cluster.items():
        split = entry_split[entry_id] = cluster_split[comp_key]
        counts[split] += 1

    cluster_counts = dict.fromkeys(SPLITS, 0)
    for split in cluster_split.values():
        cluster_counts[split] += 1

    return SplitResult(
        cluster_split=dict(sorted(cluster_split.items())),
        entry_split=dict(sorted(entry_split.items())),
        counts=counts,
        cluster_counts=cluster_counts,
        minimum_shortfalls=dict(sorted(shortfalls.items())),
    )
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def check_no_leakage(result: SplitResult, clusters: ClusterResult) -> None:
    """Fail loudly if any raw sequence cluster is shared by two splits.

    Walks ``clusters.entry_raw_clusters`` (entry id -> raw cluster keys). It
    records the split of the first entry seen for each raw cluster and checks
    that every later entry touching that cluster has the same split in
    ``result.entry_split``. Union-find merging should make this always hold,
    so a failure points to a real upstream bug.

    Raises:
        AssertionError: If a raw cluster is reached from two different splits.
        KeyError: If an entry in ``entry_raw_clusters`` has no split in
            ``result.entry_split``.
    """
    first_split_of: dict[str, str] = {}
    for entry_id, touched in clusters.entry_raw_clusters.items():
        split = result.entry_split[entry_id]
        for rk in touched:
            prior = first_split_of.setdefault(rk, split)
            if prior == split:
                continue
            raise AssertionError(
                f"leakage: raw cluster {rk} appears in both {prior!r} and {split!r}"
            )
|