if-split 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- if_split-0.1.0.dist-info/METADATA +312 -0
- if_split-0.1.0.dist-info/RECORD +20 -0
- if_split-0.1.0.dist-info/WHEEL +4 -0
- if_split-0.1.0.dist-info/entry_points.txt +2 -0
- if_split-0.1.0.dist-info/licenses/LICENSE +21 -0
- ifsplit/__init__.py +8 -0
- ifsplit/__main__.py +8 -0
- ifsplit/cli.py +317 -0
- ifsplit/cluster.py +130 -0
- ifsplit/config.py +146 -0
- ifsplit/dataset.py +112 -0
- ifsplit/download.py +229 -0
- ifsplit/enumerate.py +111 -0
- ifsplit/hydrate.py +216 -0
- ifsplit/ligands.py +267 -0
- ifsplit/manifest.py +417 -0
- ifsplit/parse.py +111 -0
- ifsplit/rcsb.py +251 -0
- ifsplit/schema.py +241 -0
- ifsplit/split.py +177 -0
ifsplit/cli.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""IF-Split command-line interface.
|
|
2
|
+
|
|
3
|
+
`build` runs the full pipeline: Stage 1 enumerate (RCSB Search + Data API ->
|
|
4
|
+
candidates.jsonl + dataset.lock), Stage 3 filter, Stage 4 ligand classification,
|
|
5
|
+
Stage 5 cluster, Stage 6 deterministic split, Stage 7 manifest + registry. No
|
|
6
|
+
structure coordinates are downloaded. `verify` re-enumerates from a lock and
|
|
7
|
+
reports drift; `stats` summarizes a manifest.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import sys
|
|
14
|
+
|
|
15
|
+
from pydantic import ValidationError
|
|
16
|
+
|
|
17
|
+
from . import __version__
|
|
18
|
+
from .config import load_config
|
|
19
|
+
from .rcsb import RcsbError
|
|
20
|
+
|
|
21
|
+
# Config file `build` reads when --config is not given (resolved against the CWD).
DEFAULT_CONFIG: str = "config/default.yaml"

# Split names accepted by `fetch --split`.
SPLITS_CHOICES: tuple[str, ...] = ("train", "val", "test")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _print_build_header(args: argparse.Namespace, cfg) -> None:
    """Echo the effective build settings (version, config, filters, split policy)."""
    fractions = cfg.split_fractions
    if cfg.use_biological_assembly:
        unit = "biological (assembly 1)"
    else:
        unit = "asymmetric unit"
    print(f"IF-Split {__version__}")
    print(f" config file: {args.config}")
    print(f" config hash: {cfg.config_hash()}")
    print(f" dataset: {cfg.dataset_version}")
    print(f" snapshot_date: {cfg.snapshot_date} (selects release_date <= this)")
    print(f" methods: {', '.join(cfg.experimental_methods)}")
    print(f" resolution: <= {cfg.resolution_max_A} A")
    print(f" max residues: < {cfg.max_total_residues + 1}")
    print(f" assembly: {unit}")
    print(f" identity: {cfg.identity_threshold}")
    print(f" clustering: {cfg.clustering_backend}")
    print(
        f" splits: train={fractions.train} val={fractions.val} "
        f"test={fractions.test} salt={cfg.split_salt!r}"
    )
    if args.limit is not None:
        print(f" limit: {args.limit} (dev: first N by sorted entry id)")
    print()


def cmd_build(args: argparse.Namespace) -> int:
    """Handle ``if-split build``: run Stages 1 and 3-7 and write all outputs.

    Stage 1 enumerates candidates and writes the lock. Stage 3 filters on
    metadata. Stage 4 classifies ligands. Stage 5 clusters into components.
    Stage 6 assigns splits and checks for leakage. Stage 7 writes the manifest,
    split files, clusters, classes, registry and tiers under ``args.out``.
    Always returns 0; failures surface as exceptions.
    """
    from .cluster import build_clusters
    from .enumerate import enumerate_candidates, make_console_progress
    from .ligands import classify_components
    from .manifest import (
        build_lock,
        build_manifest,
        build_tiers_doc,
        read_registry,
        write_classes,
        write_clusters,
        write_lock,
        write_manifest,
        write_registry,
        write_split_files,
        write_tiers,
    )
    from .parse import drop_summary, filter_candidates
    from .split import assign_splits, check_no_leakage

    cfg = load_config(args.config)
    _print_build_header(args, cfg)

    # -- Stage 1: enumerate candidates and pin them in a lock file.
    print("Stage 1 - enumerate candidates (Search + Data API, no coordinates):")
    progress = make_console_progress()  # timestamped + line-flushed (survives redirect)
    records, _, sha = enumerate_candidates(cfg, args.out, limit=args.limit, progress=progress)
    lock = build_lock(
        cfg,
        entry_ids=[rec.entry_id for rec in records],
        candidates_sha256=sha,
        limit=args.limit,
    )
    lock_path = write_lock(lock, args.out)
    print(f" wrote {lock_path}")

    # -- Stage 3: metadata-only filtering.
    print("Stage 3 - filter (metadata only):")
    kept, drops = filter_candidates(records, cfg)
    drop_counts = drop_summary(drops)
    print(f" kept {len(kept)} / {len(records)}; dropped {len(drops)} {drop_counts or ''}")

    # -- Stage 4: per-entry ligand classes.
    print("Stage 4 - ligand classification + curation:")
    class_map = {}
    for rec in kept:
        class_map[rec.entry_id] = classify_components(rec, cfg)
    n_artifact = sum(bool(info["purification_artifact"]) for info in class_map.values())
    print(f" classified {len(class_map)} entries; {n_artifact} purification artifact(s) flagged")

    # -- Stage 5: leakage-safe components.
    print(f"Stage 5 - cluster ({cfg.clustering_backend} @ {cfg.identity_level}%):")
    clusters = build_clusters(kept, cfg)
    n_merged = len(clusters.multichain_entries)
    print(
        f" {clusters.n_clusters} leakage-safe components from "
        f"{clusters.n_raw_clusters} raw clusters ({n_merged} multi-chain merged)"
    )

    # -- Stage 6: deterministic split assignment (+ optional registry pins).
    print("Stage 6 - assign splits (deterministic hash):")
    registry = read_registry(args.registry) if args.registry else {}
    entry_classes = {eid: info["classes"] for eid, info in class_map.items()}
    splits = assign_splits(clusters, cfg, registry=registry, entry_classes=entry_classes)
    check_no_leakage(splits, clusters)  # structural guarantee; raises on violation
    counts = splits.counts
    print(
        f" train={counts['train']} val={counts['val']} test={counts['test']} "
        "(no cross-split leakage)"
    )
    if cfg.test_min_per_class:
        shortfalls = splits.minimum_shortfalls
        if not shortfalls:
            print(" test minimums: applied; all per-class floors met")
        else:
            short = ", ".join(f"{cls}:{n}" for cls, n in shortfalls.items())
            print(f" test minimums: applied; SHORTFALL (not enough supply) -> {short}")

    # -- Stage 7: manifest, per-split files and the registry for future growth.
    print("Stage 7 - manifest + registry:")
    manifest = build_manifest(
        cfg,
        candidates_sha256=sha,
        n_candidates=len(records),
        drops=drops,
        drop_counts=drop_counts,
        clusters=clusters,
        splits=splits,
        class_map=class_map,
    )
    mpath = write_manifest(manifest, args.out)
    split_paths = write_split_files(splits, class_map, args.out)
    write_clusters(clusters.entry_to_cluster, args.out)
    write_classes(class_map, args.out)
    write_registry(splits.cluster_split, args.out)
    write_tiers(build_tiers_doc(class_map), args.out)
    print(f" wrote {mpath} (provenance + counts)")
    for name in ("train", "val", "test"):
        print(f" wrote {split_paths[name]}")
    for path in sorted(p for key, p in split_paths.items() if key.startswith("test:")):
        print(f" wrote {path}")
    print(" wrote clusters.json, ligands.classes.json, ligands.tiers.json, splits.registry.json")
    print()
    print(f"Build complete: {len(kept)} structures across train/val/test.")
    print(f"Run `if-split stats {mpath}`.")
    return 0
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def cmd_verify(args: argparse.Namespace) -> int:
    """Handle ``if-split verify LOCK``: re-enumerate from a lock and report drift.

    The work is delegated to ``manifest.verify_lock``; its return value is the
    process exit code.
    """
    from .manifest import verify_lock as _verify

    lock_file = args.lock
    return _verify(lock_file)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def cmd_stats(args: argparse.Namespace) -> int:
    """Handle ``if-split stats MANIFEST``: summarize a built manifest.

    The work is delegated to ``manifest.summarize_manifest``; its return value
    is the process exit code.
    """
    from .manifest import summarize_manifest as _summarize

    manifest_path = args.manifest
    return _summarize(manifest_path)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Confirm before a pull larger than this many structures unless --yes is given.
|
|
155
|
+
_FETCH_CONFIRM_THRESHOLD = 1000
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def cmd_fetch(args: argparse.Namespace) -> int:
    """Handle ``if-split fetch``: download the structures listed in a manifest.

    Exactly one scope must be given: ``--all`` or one or more ``--split``.
    Repeated ``--split`` values are collapsed, keeping the first occurrence in
    command-line order, so a split is never selected twice.

    Returns the process exit code:
    0 on success or when there is nothing to fetch;
    2 on a usage error;
    5 when a pull larger than ``_FETCH_CONFIRM_THRESHOLD`` structures is
    refused for lack of ``--yes``;
    6 when at least one structure failed.
    """
    from .download import SPLITS, StructureFetcher
    from .hydrate import hydrate, select_targets
    from .manifest import read_manifest

    if args.all and args.split:
        print("error: use either --all or --split, not both", file=sys.stderr)
        return 2
    if not args.all and not args.split:
        print(
            "error: choose a scope: --split test (repeatable) or --all.\n"
            " (explicit by design — `fetch` can pull a lot of data)",
            file=sys.stderr,
        )
        return 2

    # dict.fromkeys dedupes while preserving order: `--split test --split test`
    # must not hand "test" to select_targets/hydrate (and the report) twice.
    splits = list(SPLITS) if args.all else list(dict.fromkeys(args.split))
    unknown = [s for s in splits if s not in SPLITS]
    if unknown:
        print(f"error: unknown split(s): {', '.join(unknown)}", file=sys.stderr)
        return 2

    manifest = read_manifest(args.manifest)
    targets = select_targets(manifest, splits)
    if not targets:
        print(f"nothing to fetch for split(s): {', '.join(splits)}")
        return 0

    assembly = not args.asymmetric_unit
    print(f"fetch: {len(targets)} structures across {', '.join(splits)} -> {args.out}")

    # Size estimate + large-pull guard (no accidental terabyte).
    unit_label = "assembly 1" if assembly else "asymmetric unit"
    with StructureFetcher(assembly=assembly, workers=args.workers) as fetcher:
        est = fetcher.estimate_bytes([e for e, _ in targets])
        if est is not None:
            print(f" estimated download: ~{est / 1e9:.2f} GB ({unit_label})")
        if len(targets) > _FETCH_CONFIRM_THRESHOLD and not args.yes:
            print(
                f" refusing to fetch {len(targets)} structures without --yes "
                f"(threshold {_FETCH_CONFIRM_THRESHOLD}).",
                file=sys.stderr,
            )
            return 5
        summary = hydrate(
            args.manifest,
            args.out,
            splits=splits,
            assembly=assembly,
            workers=args.workers,
            fetcher=fetcher,
            progress=lambda m: print(f" {m}"),
        )

    print(
        f"done: {summary['fetched']} fetched, {summary['skipped']} cached, "
        f"{len(summary['failed'])} failed"
    )
    if summary["failed"]:
        # Show at most the first ten failures to keep stderr readable.
        for eid, reason in summary["failed"][:10]:
            print(f" ! {eid}: {reason}", file=sys.stderr)
    for kind, path in summary["index"].items():
        print(f" index ({kind}): {path}")
    print(f" dataset card: {args.out}/DATASET_CARD.md")
    return 0 if not summary["failed"] else 6
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _positive_int(text: str) -> int:
    """Argparse ``type=`` for counts that must be >= 1 (``--limit``, ``--workers``)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Build the ``if-split`` argument parser with subcommands build/verify/stats/fetch.

    Each subparser stores its handler in ``func`` via ``set_defaults``.
    ``--limit`` and ``--workers`` are validated as positive integers. A zero or
    negative value is rejected by argparse with a usage error instead of
    reaching the pipeline or the download pool.
    """
    p = argparse.ArgumentParser(
        prog="if-split",
        description="Reproducible, ligand-aware PDB train/val/test splitter (LigandMPNN-style).",
    )
    p.add_argument("--version", action="version", version=f"if-split {__version__}")
    sub = p.add_subparsers(dest="command", required=True)

    pb = sub.add_parser("build", help="Run the pipeline (Stages 1-7) and emit manifest + lock.")
    pb.add_argument(
        "--config",
        default=DEFAULT_CONFIG,
        help=f"Path to config YAML (default: {DEFAULT_CONFIG}).",
    )
    pb.add_argument(
        "--out",
        default="data/out",
        help="Output dir for candidates.jsonl + dataset.lock (default: data/out).",
    )
    pb.add_argument(
        "--limit",
        type=_positive_int,
        default=None,
        help="Dev only: cap to the first N candidates by sorted entry id (reproducible).",
    )
    pb.add_argument(
        "--registry",
        default=None,
        help="Optional prior splits.registry.json to pin existing cluster->split "
        "assignments (growth stability).",
    )
    pb.set_defaults(func=cmd_build)

    pv = sub.add_parser(
        "verify", help="Re-enumerate from a lock file and report drift vs the live PDB."
    )
    pv.add_argument("lock", help="Path to dataset.lock")
    pv.set_defaults(func=cmd_verify)

    ps = sub.add_parser(
        "stats", help="Report split sizes and per-class test counts from a manifest."
    )
    ps.add_argument("manifest", help="Path to manifest.json")
    ps.set_defaults(func=cmd_stats)

    pf = sub.add_parser(
        "fetch",
        help="OPTIONAL: download structures for a built manifest into an ML-ready tree.",
    )
    pf.add_argument("manifest", help="Path to manifest.json")
    pf.add_argument(
        "--out", default="data/structures", help="Output root (default: data/structures)."
    )
    pf.add_argument(
        "--split",
        action="append",
        choices=list(SPLITS_CHOICES),
        help="Split to fetch (repeatable, e.g. --split test --split val).",
    )
    pf.add_argument("--all", action="store_true", help="Fetch all splits.")
    pf.add_argument(
        "--asymmetric-unit",
        action="store_true",
        help="Fetch the asymmetric unit instead of biological assembly 1.",
    )
    pf.add_argument(
        "--workers", type=_positive_int, default=8, help="Concurrent downloads (default: 8)."
    )
    pf.add_argument("--yes", action="store_true", help="Proceed without confirming large pulls.")
    pf.set_defaults(func=cmd_fetch)

    return p
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def main(argv: list[str] | None = None) -> int:
    """Parse *argv* (``sys.argv[1:]`` when None) and run the chosen subcommand.

    Known failure types are turned into a short stderr message and an exit code
    instead of a traceback. They are checked in this order:
    FileNotFoundError gives 2; pydantic ValidationError gives 2;
    NotImplementedError gives 3; RcsbError gives 4.
    Otherwise the subcommand's return value is the exit code.
    """
    args = build_parser().parse_args(argv)
    # (exception type, stderr prefix, exit code), checked in order.
    handled = (
        (FileNotFoundError, "error: ", 2),
        (ValidationError, "invalid config:\n", 2),
        (NotImplementedError, "not implemented yet: ", 3),
        (RcsbError, "RCSB request failed: ", 4),
    )
    try:
        return args.func(args)
    except tuple(kind for kind, _, _ in handled) as exc:
        for kind, prefix, code in handled:
            if isinstance(exc, kind):
                print(f"{prefix}{exc}", file=sys.stderr)
                return code
        raise  # unreachable: exc matched one of the handled types above
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
if __name__ == "__main__":
    # Equivalent to sys.exit(main()): the return value becomes the exit status.
    raise SystemExit(main())
|
ifsplit/cluster.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Stage 5 - Cluster protein entities into leakage-safe groups.
|
|
2
|
+
|
|
3
|
+
The ``precomputed`` backend reads each protein entity's RCSB cluster id at the
|
|
4
|
+
configured identity level from ``PolymerEntity.cluster_ids`` (captured in Stage 1
|
|
5
|
+
from the Data API ``rcsb_cluster_membership`` field) - no file download, no
|
|
6
|
+
mmseqs2 binary.
|
|
7
|
+
|
|
8
|
+
A *raw cluster* is the set of protein entities sharing an RCSB cluster id. But an
|
|
9
|
+
entry with several protein chains can touch several raw clusters, so raw clusters
|
|
10
|
+
alone are NOT a leakage-safe split unit: if entry X has chain a (raw cluster A)
|
|
11
|
+
and chain b (raw cluster B), then A and B must land in the same split or X's b
|
|
12
|
+
sequence leaks across splits. We therefore merge raw clusters joined by a shared
|
|
13
|
+
entry into **components** (connected components, union-find). The component is the
|
|
14
|
+
unit Stage 6 assigns to a split, which makes cross-split sequence overlap
|
|
15
|
+
impossible by construction - no heuristic, no after-the-fact audit.
|
|
16
|
+
|
|
17
|
+
A component's canonical key is the lexicographically smallest entity id across all
|
|
18
|
+
its members. (Equivalently the smallest raw-cluster key, since each raw key is
|
|
19
|
+
itself a min-entity-id - so min-of-mins = global min.) Keying the split hash on a
|
|
20
|
+
stable member id, not RCSB's volatile integer cluster id, keeps assignments
|
|
21
|
+
stable as the dataset grows (PLAN.md §6). Sub-10-aa peptides that RCSB does not
|
|
22
|
+
cluster become their own singleton components.
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from dataclasses import dataclass, field
|
|
28
|
+
|
|
29
|
+
from .config import Config
|
|
30
|
+
from .schema import CandidateRecord
|
|
31
|
+
|
|
32
|
+
# Prefix of the component key given to an entry none of whose protein entities
# carries an RCSB cluster id at the configured identity level.
SINGLETON_PREFIX: str = "singleton:"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ClusterResult:
    """Result of Stage 5: entries grouped into leakage-safe components.

    Raw sequence clusters that share an entry are merged, so each component
    key below is the unit later assigned to a split.
    """

    # Identity level (integer percent) the clusters were read at.
    identity: int
    # entry_id -> key of the component (split unit) it belongs to.
    entry_to_cluster: dict[str, str]
    # component key -> sorted entry_ids in that component.
    cluster_members: dict[str, list[str]]
    # entry_id -> raw cluster keys its protein entities touch.
    entry_raw_clusters: dict[str, list[str]]
    # Entries whose proteins span more than one raw cluster (these force merges).
    multichain_entries: list[str] = field(default_factory=list)
    # Entries with no protein clustered at this level (given singleton keys).
    unclustered_entries: list[str] = field(default_factory=list)
    # Number of distinct raw clusters seen before merging.
    n_raw_clusters: int = 0

    @property
    def n_clusters(self) -> int:
        """How many components (split units) there are."""
        members = self.cluster_members
        return len(members)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def build_clusters(records: list[CandidateRecord], cfg: Config) -> ClusterResult:
    """Group filtered records into leakage-safe components at ``cfg.identity_level``.

    How the grouping works:
    1. Each protein entity's RCSB cluster id at the configured level defines a
       raw cluster, keyed by its smallest member entity id.
    2. Raw clusters that co-occur in one entry are merged with union-find.
       Each component is keyed by its smallest raw key.
    3. An entry whose proteins are all unclustered at this level gets its own
       ``SINGLETON_PREFIX`` key.
    4. Entries with no protein entity are left out of the result.

    Raises:
        NotImplementedError: if ``cfg.clustering_backend`` is not "precomputed".
    """
    backend = cfg.clustering_backend
    if backend != "precomputed":
        raise NotImplementedError(
            f"clustering_backend {backend!r} not implemented "
            "(only 'precomputed' is available)."
        )
    level = cfg.identity_level

    def protein_entities(rec):
        return [ent for ent in rec.polymer_entities if ent.is_protein]

    # Raw clusters: RCSB cluster id -> member entity ids -> smallest id as key.
    ids_by_cluster: dict[int, set[str]] = {}
    for rec in records:
        for ent in protein_entities(rec):
            if level in ent.cluster_ids:
                ids_by_cluster.setdefault(ent.cluster_ids[level], set()).add(ent.entity_id)
    raw_key = {cid: min(ids) for cid, ids in ids_by_cluster.items()}

    # Per entry: the sorted raw keys it touches, or a singleton key if none.
    entry_raw: dict[str, list[str]] = {}
    multichain: list[str] = []
    unclustered: list[str] = []
    all_keys = set(raw_key.values())
    for rec in records:
        proteins = protein_entities(rec)
        if not proteins:
            continue  # defensive; Stage 3 already drops no-protein entries
        touched = sorted(
            {raw_key[ent.cluster_ids[level]] for ent in proteins if level in ent.cluster_ids}
        )
        if touched:
            if len(touched) > 1:
                multichain.append(rec.entry_id)
        else:
            lone = SINGLETON_PREFIX + min(ent.entity_id for ent in proteins)
            touched = [lone]
            all_keys.add(lone)
            unclustered.append(rec.entry_id)
        entry_raw[rec.entry_id] = touched

    # Union-find: the smaller root always wins, so every root is its
    # component's minimum key, whatever order the merges happen in.
    parent = {key: key for key in all_keys}

    def root_of(key: str) -> str:
        while parent[key] != key:
            parent[key] = parent[parent[key]]  # path halving
            key = parent[key]
        return key

    for touched in entry_raw.values():
        head = touched[0]
        for other in touched[1:]:
            a, b = root_of(head), root_of(other)
            if a != b:
                parent[max(a, b)] = min(a, b)

    # Materialize: entry -> component root, and root -> member entries.
    entry_to_cluster = {entry: root_of(touched[0]) for entry, touched in entry_raw.items()}
    members: dict[str, set[str]] = {}
    for entry, comp in entry_to_cluster.items():
        members.setdefault(comp, set()).add(entry)

    return ClusterResult(
        identity=level,
        entry_to_cluster=dict(sorted(entry_to_cluster.items())),
        cluster_members={comp: sorted(ents) for comp, ents in sorted(members.items())},
        entry_raw_clusters=dict(sorted(entry_raw.items())),
        multichain_entries=sorted(multichain),
        unclustered_entries=sorted(unclustered),
        n_raw_clusters=len(raw_key),
    )
|
ifsplit/config.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Load, validate, and hash an IF-Split run configuration.
|
|
2
|
+
|
|
3
|
+
The config is the single source of truth for a build. Its canonical hash is
|
|
4
|
+
embedded in the manifest so that two manifests sharing a config hash are
|
|
5
|
+
guaranteed to have used identical, output-affecting settings.
|
|
6
|
+
|
|
7
|
+
The hash is computed over the *validated, normalized* settings (not the raw YAML
|
|
8
|
+
text), so comments and formatting do not change it — only values do.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import hashlib
|
|
14
|
+
import json
|
|
15
|
+
from datetime import date
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Any, Literal
|
|
18
|
+
|
|
19
|
+
import yaml
|
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SplitFractions(BaseModel):
    """Train/val/test partition fractions.

    Rules enforced:
    - each fraction lies strictly between 0 and 1;
    - the three fractions sum to 1.0 within 1e-9;
    - unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid")

    train: float = Field(gt=0, lt=1)
    val: float = Field(gt=0, lt=1)
    test: float = Field(gt=0, lt=1)

    @model_validator(mode="after")
    def _sum_to_one(self) -> SplitFractions:
        total = sum((self.train, self.val, self.test))
        drift = abs(total - 1.0)
        if drift > 1e-9:
            raise ValueError(f"split_fractions must sum to 1.0, got {total}")
        return self
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Config(BaseModel):
    """A fully-validated IF-Split build configuration (unknown keys are rejected)."""

    model_config = ConfigDict(extra="forbid")

    # --- snapshot definition (the reproducibility anchor) ---
    snapshot_date: date
    experimental_methods: list[str] = Field(min_length=1)
    resolution_max_A: float = Field(gt=0)
    max_total_residues: int = Field(gt=0)
    excluded_het: list[str] = Field(default_factory=list)
    use_biological_assembly: bool = True

    # --- Stage 4 curation: purification-artifact detection ---
    # Ni/Co held by a poly-His tag is a purification artifact rather than a
    # biological metal site. An entry whose only metal is a purification metal
    # and that carries a His-tag is flagged so it can be kept out of the metal
    # class. An empty purification_metals list turns the heuristic off.
    purification_metals: list[str] = Field(default_factory=lambda: ["NI", "CO"])
    histag_min_run: int = Field(default=6, gt=0)
    exclude_purification_artifacts: bool = True

    # --- clustering + split ---
    # NOTE(review): with the "precomputed" backend, clustering looks up
    # cluster_ids[identity_level]. A level RCSB publishes no clusters for would
    # leave every entry an unclustered singleton; consider validating it.
    identity_threshold: float = Field(gt=0, le=1)
    # "precomputed" reuses RCSB's entity clusters (no external binary);
    # "mmseqs2" means clustering the snapshot's sequences ourselves.
    clustering_backend: Literal["precomputed", "mmseqs2"] = "precomputed"
    split_fractions: SplitFractions
    split_salt: str = Field(min_length=1)
    seed: int = Field(ge=0)

    # --- opt-in test-set minimums ---
    # Per-class floor on test entries, e.g. {"metal": 500, "nucleotide": 200}.
    # Empty (the default) means the pure deterministic hash with no top-up.
    # When set, any class below its floor after the base assignment recruits
    # whole components into test (never single entries, so no leakage). They
    # are taken in deterministic hash order, skipping registry-pinned ones.
    # A floor above the available supply is met as far as possible, and the
    # shortfall is reported rather than forced.
    test_min_per_class: dict[str, int] = Field(default_factory=dict)

    # --- Stage 3 quality filters (wwPDB validation-report metadata) ---
    # Every cap is optional; None disables it. An entry is dropped only when the
    # metric is present and exceeds the cap. A missing metric never drops an
    # entry. Geometry caps (clashscore / Ramachandran / rotamer) apply to X-ray
    # and cryo-EM; diffraction caps (R-free / RSRZ) are no-ops for EM entries.
    max_clashscore: float | None = Field(default=None, gt=0)
    max_rfree: float | None = Field(default=None, gt=0)
    max_ramachandran_outlier_pct: float | None = Field(default=None, ge=0)
    max_rotamer_outlier_pct: float | None = Field(default=None, ge=0)
    max_rsrz_outlier_pct: float | None = Field(default=None, ge=0)
    require_validation_report: bool = False

    # --- featurization (downstream-only; not part of the split definition) ---
    ligand_context_radius_A: float = Field(gt=0)
    max_ligand_atoms: int = Field(gt=0)

    @field_validator("experimental_methods")
    @classmethod
    def _normalize_methods(cls, value: list[str]) -> list[str]:
        # Trim and upper-case so "x-ray diffraction " matches "X-RAY DIFFRACTION".
        return [method.strip().upper() for method in value]

    @field_validator("excluded_het", "purification_metals")
    @classmethod
    def _normalize_codes(cls, value: list[str]) -> list[str]:
        # Component codes are compared upper-case and whitespace-free.
        return [code.strip().upper() for code in value]

    @field_validator("test_min_per_class")
    @classmethod
    def _check_minimums(cls, value: dict[str, int]) -> dict[str, int]:
        offender = next(((name, n) for name, n in value.items() if n < 0), None)
        if offender is not None:
            name, n = offender
            raise ValueError(f"test_min_per_class[{name!r}] must be >= 0, got {n}")
        return value

    @property
    def dataset_version(self) -> str:
        """Versioned dataset name derived from the snapshot, e.g. 'IF-Split-2026.05.30'."""
        return "IF-Split-" + self.snapshot_date.strftime("%Y.%m.%d")

    @property
    def identity_level(self) -> int:
        """``identity_threshold`` rounded to an integer percent (0.30 -> 30)."""
        percent = self.identity_threshold * 100
        return round(percent)

    def canonical_dict(self) -> dict[str, Any]:
        """JSON-mode dump (dates become ISO strings) used for hashing and manifests."""
        return self.model_dump(mode="json")

    def config_hash(self) -> str:
        """Return a 32-hex-char BLAKE2b digest of the key-sorted, compact JSON dump."""
        payload = json.dumps(
            self.canonical_dict(), separators=(",", ":"), sort_keys=True
        ).encode("utf-8")
        digest = hashlib.blake2b(payload, digest_size=16)
        return digest.hexdigest()
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def load_config(path: str | Path) -> Config:
    """Load a YAML file and validate it into a :class:`Config`.

    Args:
        path: location of the YAML config file.

    Returns:
        The validated configuration.

    Raises:
        FileNotFoundError: if *path* does not exist or is not a regular file
            (e.g. a directory).
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the YAML document is not a mapping (e.g. an empty file).
        pydantic.ValidationError: if the mapping fails :class:`Config` validation.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Config not found: {path}")
    if not path.is_file():
        # A directory passes exists() but then fails in open() with
        # IsADirectoryError, which is not a FileNotFoundError. Report it as a
        # FileNotFoundError so callers that handle missing configs handle this too.
        raise FileNotFoundError(f"Config is not a regular file: {path}")
    with path.open("r", encoding="utf-8") as fh:
        raw = yaml.safe_load(fh)  # safe_load: never constructs arbitrary objects
    if not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping, got {type(raw).__name__}: {path}")
    return Config.model_validate(raw)
|