if-split 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ifsplit/cli.py ADDED
@@ -0,0 +1,317 @@
1
+ """IF-Split command-line interface.
2
+
3
+ `build` runs the full pipeline: Stage 1 enumerate (RCSB Search + Data API ->
4
+ candidates.jsonl + dataset.lock), Stage 3 filter, Stage 4 ligand classification,
5
+ Stage 5 cluster, Stage 6 deterministic split, Stage 7 manifest + registry. No
6
+ structure coordinates are downloaded. `verify` re-enumerates from a lock and
7
+ reports drift; `stats` summarizes a manifest; `fetch` (optional, separate from
+ `build`) downloads structures for a built manifest.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import sys
14
+
15
+ from pydantic import ValidationError
16
+
17
+ from . import __version__
18
+ from .config import load_config
19
+ from .rcsb import RcsbError
20
+
21
# Config path used by `build` when --config is not given (resolved against the CWD).
DEFAULT_CONFIG: str = "config/default.yaml"
# Split names accepted by `fetch --split`.
SPLITS_CHOICES: tuple[str, ...] = ("train", "val", "test")
23
+
24
+
25
def cmd_build(args: argparse.Namespace) -> int:
    """Handle ``if-split build``: run the metadata-only pipeline end to end.

    Stages, in order: 1 enumerate candidates and write ``dataset.lock``;
    3 filter; 4 classify ligands; 5 cluster into leakage-safe components;
    6 assign splits (optionally pinned by a prior ``--registry``) and check for
    leakage; 7 write the manifest, per-split files, clusters, classes, tiers and
    registry under ``args.out``. Progress is printed to stdout.

    Returns:
        0 once every artifact is written. Failures propagate as exceptions
        (``check_no_leakage`` raises on a violation) for :func:`main` to map.
    """
    from .cluster import build_clusters
    from .enumerate import enumerate_candidates, make_console_progress
    from .ligands import classify_components
    from .manifest import (
        build_lock,
        build_manifest,
        build_tiers_doc,
        read_registry,
        write_classes,
        write_clusters,
        write_lock,
        write_manifest,
        write_registry,
        write_split_files,
        write_tiers,
    )
    from .parse import drop_summary, filter_candidates
    from .split import assign_splits, check_no_leakage

    cfg = load_config(args.config)
    fractions = cfg.split_fractions
    if cfg.use_biological_assembly:
        unit = "biological (assembly 1)"
    else:
        unit = "asymmetric unit"

    # Run header: the key settings of this build, followed by a blank line.
    header = [
        f"IF-Split {__version__}",
        f" config file: {args.config}",
        f" config hash: {cfg.config_hash()}",
        f" dataset: {cfg.dataset_version}",
        f" snapshot_date: {cfg.snapshot_date} (selects release_date <= this)",
        f" methods: {', '.join(cfg.experimental_methods)}",
        f" resolution: <= {cfg.resolution_max_A} A",
        f" max residues: < {cfg.max_total_residues + 1}",
        f" assembly: {unit}",
        f" identity: {cfg.identity_threshold}",
        f" clustering: {cfg.clustering_backend}",
        f" splits: train={fractions.train} val={fractions.val} test={fractions.test}"
        f" salt={cfg.split_salt!r}",
    ]
    if args.limit is not None:
        header.append(f" limit: {args.limit} (dev: first N by sorted entry id)")
    header.append("")  # trailing blank line
    print("\n".join(header))

    # Stage 1: enumerate candidates and pin them in the lock file.
    print("Stage 1 - enumerate candidates (Search + Data API, no coordinates):")
    progress = make_console_progress()  # timestamped + line-flushed (survives redirect)
    records, _candidates_path, sha = enumerate_candidates(
        cfg, args.out, limit=args.limit, progress=progress
    )
    lock = build_lock(
        cfg,
        entry_ids=[rec.entry_id for rec in records],
        candidates_sha256=sha,
        limit=args.limit,
    )
    lock_path = write_lock(lock, args.out)
    print(f" wrote {lock_path}")

    # Stage 3: metadata filters.
    print("Stage 3 - filter (metadata only):")
    kept, drops = filter_candidates(records, cfg)
    drop_counts = drop_summary(drops)
    print(f" kept {len(kept)} / {len(records)}; dropped {len(drops)} {drop_counts or ''}")

    # Stage 4: per-entry ligand classes and purification-artifact flags.
    print("Stage 4 - ligand classification + curation:")
    class_map = {}
    for rec in kept:
        class_map[rec.entry_id] = classify_components(rec, cfg)
    flagged = sum(bool(info["purification_artifact"]) for info in class_map.values())
    print(f" classified {len(class_map)} entries; {flagged} purification artifact(s) flagged")

    # Stage 5: sequence clusters merged into leakage-safe components.
    print(f"Stage 5 - cluster ({cfg.clustering_backend} @ {cfg.identity_level}%):")
    clusters = build_clusters(kept, cfg)
    n_merged = len(clusters.multichain_entries)
    print(
        f" {clusters.n_clusters} leakage-safe components "
        f"from {clusters.n_raw_clusters} raw clusters "
        f"({n_merged} multi-chain merged)"
    )

    # Stage 6: deterministic split assignment, then the leakage check.
    print("Stage 6 - assign splits (deterministic hash):")
    registry = {} if not args.registry else read_registry(args.registry)
    entry_classes = {eid: info["classes"] for eid, info in class_map.items()}
    splits = assign_splits(clusters, cfg, registry=registry, entry_classes=entry_classes)
    check_no_leakage(splits, clusters)  # structural guarantee; raises on violation
    counts = splits.counts
    print(
        f" train={counts['train']} val={counts['val']} test={counts['test']}"
        " (no cross-split leakage)"
    )
    if cfg.test_min_per_class:
        shortfalls = splits.minimum_shortfalls
        if not shortfalls:
            print(" test minimums: applied; all per-class floors met")
        else:
            short = ", ".join(f"{cls}:{n}" for cls, n in shortfalls.items())
            print(f" test minimums: applied; SHORTFALL (not enough supply) -> {short}")

    # Stage 7: write every artifact.
    print("Stage 7 - manifest + registry:")
    manifest = build_manifest(
        cfg,
        candidates_sha256=sha,
        n_candidates=len(records),
        drops=drops,
        drop_counts=drop_counts,
        clusters=clusters,
        splits=splits,
        class_map=class_map,
    )
    mpath = write_manifest(manifest, args.out)
    split_paths = write_split_files(splits, class_map, args.out)
    write_clusters(clusters.entry_to_cluster, args.out)
    write_classes(class_map, args.out)
    write_registry(splits.cluster_split, args.out)
    write_tiers(build_tiers_doc(class_map), args.out)

    print(f" wrote {mpath} (provenance + counts)")
    for name in ("train", "val", "test"):
        print(f" wrote {split_paths[name]}")
    # Per-class test files are the "test:<class>" keys; list them in sorted order.
    per_class = sorted(path for key, path in split_paths.items() if key.startswith("test:"))
    for path in per_class:
        print(f" wrote {path}")
    print(" wrote clusters.json, ligands.classes.json, ligands.tiers.json, splits.registry.json")
    print()
    print(f"Build complete: {len(kept)} structures across train/val/test.")
    print(f"Run `if-split stats {mpath}`.")
    return 0
140
+
141
+
142
def cmd_verify(args: argparse.Namespace) -> int:
    """Handle ``if-split verify``: re-check the lock file at ``args.lock``.

    Delegates to ``manifest.verify_lock`` and uses its return value as the exit code.
    """
    # Deferred import: the manifest module is only loaded for this subcommand.
    from .manifest import verify_lock as _verify_lock

    lock_file = args.lock
    status = _verify_lock(lock_file)
    return status
146
+
147
+
148
def cmd_stats(args: argparse.Namespace) -> int:
    """Handle ``if-split stats``: summarize the manifest at ``args.manifest``.

    Delegates to ``manifest.summarize_manifest`` and uses its return value as the
    exit code.
    """
    # Deferred import: the manifest module is only loaded for this subcommand.
    from .manifest import summarize_manifest as _summarize

    manifest_file = args.manifest
    status = _summarize(manifest_file)
    return status
152
+
153
+
154
+ # Confirm before a pull larger than this many structures unless --yes is given.
155
+ _FETCH_CONFIRM_THRESHOLD = 1000
156
+
157
+
158
def cmd_fetch(args: argparse.Namespace) -> int:
    """Handle ``if-split fetch``: download structures for splits of a built manifest.

    The scope must be explicit: either ``--all`` or one or more ``--split``.
    Repeated ``--split`` values are collapsed, keeping first-seen order. A pull
    larger than ``_FETCH_CONFIRM_THRESHOLD`` structures requires ``--yes``.

    Returns:
        0 on success (or when there is nothing to fetch), 2 on a usage error
        (bad scope, unknown split, ``--workers`` < 1), 5 when a large pull is
        refused without ``--yes``, 6 when any download failed.
    """
    from .download import SPLITS, StructureFetcher
    from .hydrate import hydrate, select_targets
    from .manifest import read_manifest

    if args.all and args.split:
        print("error: use either --all or --split, not both", file=sys.stderr)
        return 2
    if not args.all and not args.split:
        print(
            "error: choose a scope: --split test (repeatable) or --all.\n"
            " (explicit by design — `fetch` can pull a lot of data)",
            file=sys.stderr,
        )
        return 2
    if args.workers < 1:
        # A non-positive worker count cannot make progress; reject it up front.
        print(f"error: --workers must be >= 1, got {args.workers}", file=sys.stderr)
        return 2

    # De-duplicate repeated --split flags (first-seen order kept) so no split is
    # selected, size-estimated, or hydrated twice.
    splits = list(SPLITS) if args.all else list(dict.fromkeys(args.split))
    unknown = [s for s in splits if s not in SPLITS]
    if unknown:
        print(f"error: unknown split(s): {', '.join(unknown)}", file=sys.stderr)
        return 2

    manifest = read_manifest(args.manifest)
    targets = select_targets(manifest, splits)
    if not targets:
        print(f"nothing to fetch for split(s): {', '.join(splits)}")
        return 0

    assembly = not args.asymmetric_unit
    print(f"fetch: {len(targets)} structures across {', '.join(splits)} -> {args.out}")

    # Size estimate + large-pull guard (no accidental terabyte).
    unit_label = "assembly 1" if assembly else "asymmetric unit"
    with StructureFetcher(assembly=assembly, workers=args.workers) as fetcher:
        est = fetcher.estimate_bytes([e for e, _ in targets])
        if est is not None:
            print(f" estimated download: ~{est / 1e9:.2f} GB ({unit_label})")
        if len(targets) > _FETCH_CONFIRM_THRESHOLD and not args.yes:
            print(
                f" refusing to fetch {len(targets)} structures without --yes "
                f"(threshold {_FETCH_CONFIRM_THRESHOLD}).",
                file=sys.stderr,
            )
            return 5
        summary = hydrate(
            args.manifest,
            args.out,
            splits=splits,
            assembly=assembly,
            workers=args.workers,
            fetcher=fetcher,
            progress=lambda m: print(f" {m}"),
        )

    failed = summary["failed"]
    print(
        f"done: {summary['fetched']} fetched, {summary['skipped']} cached, "
        f"{len(failed)} failed"
    )
    max_shown = 10  # keep stderr readable; say how many were left out
    for eid, reason in failed[:max_shown]:
        print(f" ! {eid}: {reason}", file=sys.stderr)
    if len(failed) > max_shown:
        print(f" ... and {len(failed) - max_shown} more", file=sys.stderr)
    for kind, path in summary["index"].items():
        print(f" index ({kind}): {path}")
    print(f" dataset card: {args.out}/DATASET_CARD.md")
    return 6 if failed else 0
223
+
224
+
225
+ def build_parser() -> argparse.ArgumentParser:
226
+ p = argparse.ArgumentParser(
227
+ prog="if-split",
228
+ description="Reproducible, ligand-aware PDB train/val/test splitter (LigandMPNN-style).",
229
+ )
230
+ p.add_argument("--version", action="version", version=f"if-split {__version__}")
231
+ sub = p.add_subparsers(dest="command", required=True)
232
+
233
+ pb = sub.add_parser("build", help="Run the pipeline (Stages 1-7) and emit manifest + lock.")
234
+ pb.add_argument(
235
+ "--config",
236
+ default=DEFAULT_CONFIG,
237
+ help=f"Path to config YAML (default: {DEFAULT_CONFIG}).",
238
+ )
239
+ pb.add_argument(
240
+ "--out",
241
+ default="data/out",
242
+ help="Output dir for candidates.jsonl + dataset.lock (default: data/out).",
243
+ )
244
+ pb.add_argument(
245
+ "--limit",
246
+ type=int,
247
+ default=None,
248
+ help="Dev only: cap to the first N candidates by sorted entry id (reproducible).",
249
+ )
250
+ pb.add_argument(
251
+ "--registry",
252
+ default=None,
253
+ help="Optional prior splits.registry.json to pin existing cluster->split "
254
+ "assignments (growth stability).",
255
+ )
256
+ pb.set_defaults(func=cmd_build)
257
+
258
+ pv = sub.add_parser(
259
+ "verify", help="Re-enumerate from a lock file and report drift vs the live PDB."
260
+ )
261
+ pv.add_argument("lock", help="Path to dataset.lock")
262
+ pv.set_defaults(func=cmd_verify)
263
+
264
+ ps = sub.add_parser(
265
+ "stats", help="Report split sizes and per-class test counts from a manifest."
266
+ )
267
+ ps.add_argument("manifest", help="Path to manifest.json")
268
+ ps.set_defaults(func=cmd_stats)
269
+
270
+ pf = sub.add_parser(
271
+ "fetch",
272
+ help="OPTIONAL: download structures for a built manifest into an ML-ready tree.",
273
+ )
274
+ pf.add_argument("manifest", help="Path to manifest.json")
275
+ pf.add_argument(
276
+ "--out", default="data/structures", help="Output root (default: data/structures)."
277
+ )
278
+ pf.add_argument(
279
+ "--split",
280
+ action="append",
281
+ choices=list(SPLITS_CHOICES),
282
+ help="Split to fetch (repeatable, e.g. --split test --split val).",
283
+ )
284
+ pf.add_argument("--all", action="store_true", help="Fetch all splits.")
285
+ pf.add_argument(
286
+ "--asymmetric-unit",
287
+ action="store_true",
288
+ help="Fetch the asymmetric unit instead of biological assembly 1.",
289
+ )
290
+ pf.add_argument("--workers", type=int, default=8, help="Concurrent downloads (default: 8).")
291
+ pf.add_argument("--yes", action="store_true", help="Proceed without confirming large pulls.")
292
+ pf.set_defaults(func=cmd_fetch)
293
+
294
+ return p
295
+
296
+
297
def main(argv: list[str] | None = None) -> int:
    """Parse ``argv`` and dispatch to the selected subcommand.

    Expected failures are printed to stderr and mapped to exit codes instead of
    tracebacks:

    * 2 - missing file, invalid config (schema violation or malformed YAML)
    * 3 - requested feature not implemented yet
    * 4 - RCSB request failed
    * 130 - interrupted by the user (Ctrl-C; shell convention 128 + SIGINT)

    Otherwise the subcommand's own return value is the exit code
    (e.g. ``fetch`` returns 5 or 6).
    """
    # YAML syntax errors raise yaml.YAMLError, which is not a ValueError and so is
    # not covered by the pydantic ValidationError handler below.
    import yaml

    parser = build_parser()
    args = parser.parse_args(argv)
    try:
        return args.func(args)
    except FileNotFoundError as e:
        print(f"error: {e}", file=sys.stderr)
        return 2
    except ValidationError as e:
        print(f"invalid config:\n{e}", file=sys.stderr)
        return 2
    except yaml.YAMLError as e:
        print(f"invalid config YAML:\n{e}", file=sys.stderr)
        return 2
    except NotImplementedError as e:
        print(f"not implemented yet: {e}", file=sys.stderr)
        return 3
    except RcsbError as e:
        print(f"RCSB request failed: {e}", file=sys.stderr)
        return 4
    except KeyboardInterrupt:
        # Long enumerations/downloads are routinely aborted; exit cleanly.
        print("interrupted", file=sys.stderr)
        return 130
314
+
315
+
316
if __name__ == "__main__":
    # Direct execution (e.g. `python -m ifsplit.cli`): main()'s return is the exit status.
    raise SystemExit(main())
ifsplit/cluster.py ADDED
@@ -0,0 +1,130 @@
1
+ """Stage 5 - Cluster protein entities into leakage-safe groups.
2
+
3
+ The ``precomputed`` backend reads each protein entity's RCSB cluster id at the
4
+ configured identity level from ``PolymerEntity.cluster_ids`` (captured in Stage 1
5
+ from the Data API ``rcsb_cluster_membership`` field) - no file download, no
6
+ mmseqs2 binary.
7
+
8
+ A *raw cluster* is the set of protein entities sharing an RCSB cluster id. But an
9
+ entry with several protein chains can touch several raw clusters, so raw clusters
10
+ alone are NOT a leakage-safe split unit: if entry X has chain a (raw cluster A)
11
+ and chain b (raw cluster B), then A and B must land in the same split or X's b
12
+ sequence leaks across splits. We therefore merge raw clusters joined by a shared
13
+ entry into **components** (connected components, union-find). The component is the
14
+ unit Stage 6 assigns to a split, which makes cross-split sequence overlap
15
+ impossible by construction - no heuristic, no after-the-fact audit.
16
+
17
+ A component's canonical key is the lexicographically smallest entity id across all
18
+ its members. (Equivalently the smallest raw-cluster key, since each raw key is
19
+ itself a min-entity-id - so min-of-mins = global min.) Keying the split hash on a
20
+ stable member id, not RCSB's volatile integer cluster id, keeps assignments
21
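+ stable as the dataset grows (PLAN.md §6). An entry none of whose protein chains
22
+ is clustered at this level (e.g. only sub-10-aa peptides, which RCSB does not
+ cluster) becomes its own singleton component; unclustered chains in an entry
+ that also has clustered chains do not affect its component.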
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from dataclasses import dataclass, field
28
+
29
+ from .config import Config
30
+ from .schema import CandidateRecord
31
+
32
# Prefix of the component key for an entry none of whose protein chains is
# clustered at the configured identity level (key = prefix + smallest entity id).
SINGLETON_PREFIX: str = "singleton:"
33
+
34
+
35
@dataclass
class ClusterResult:
    """Output of Stage 5: raw sequence clusters merged into leakage-safe components.

    A *component* is the unit that Stage 6 assigns to a split. ``identity`` is
    the integer-percent identity level the clustering was run at.
    """

    identity: int
    # entry_id -> key of the component (split unit) the entry belongs to
    entry_to_cluster: dict[str, str]
    # component key -> the entry_ids in that component, sorted
    cluster_members: dict[str, list[str]]
    # entry_id -> every raw cluster key touched by the entry's protein chains
    entry_raw_clusters: dict[str, list[str]]
    # entries whose chains span more than one raw cluster
    multichain_entries: list[str] = field(default_factory=list)
    # entries with no protein chain clustered at this level (singleton components)
    unclustered_entries: list[str] = field(default_factory=list)
    # number of distinct raw clusters before merging
    n_raw_clusters: int = 0

    @property
    def n_clusters(self) -> int:
        """Count of components, i.e. the number of split units."""
        members = self.cluster_members
        return len(members)
51
+
52
+
53
def build_clusters(records: list[CandidateRecord], cfg: Config) -> ClusterResult:
    """Group filtered entries into leakage-safe components at ``cfg.identity_level``.

    Only the ``precomputed`` backend exists. It reads each protein entity's RCSB
    cluster id at the configured level from ``PolymerEntity.cluster_ids``.

    Args:
        records: Stage 3 survivors.
        cfg: Build config; ``clustering_backend`` and ``identity_level`` are read.

    Returns:
        A :class:`ClusterResult` whose mappings and lists are sorted by key/value.

    Raises:
        NotImplementedError: if ``cfg.clustering_backend`` is not ``"precomputed"``.
    """
    backend = cfg.clustering_backend
    if backend != "precomputed":
        raise NotImplementedError(
            f"clustering_backend {backend!r} not implemented "
            "(only 'precomputed' is available)."
        )
    level = cfg.identity_level

    # Step 1 - RCSB cluster id -> member protein entity ids; the canonical raw key
    # of a cluster is its smallest member entity id (stable across releases).
    cluster_entities: dict[int, set[str]] = {}
    for rec in records:
        for ent in rec.polymer_entities:
            if not (ent.is_protein and level in ent.cluster_ids):
                continue
            cid = ent.cluster_ids[level]
            if cid not in cluster_entities:
                cluster_entities[cid] = set()
            cluster_entities[cid].add(ent.entity_id)
    canonical = {cid: min(ids) for cid, ids in cluster_entities.items()}

    # Step 2 - each entry -> the sorted raw keys it touches. An entry with no
    # protein chain clustered at this level gets a private singleton key.
    touched: dict[str, list[str]] = {}
    multi: list[str] = []
    lonely: list[str] = []
    universe: set[str] = set(canonical.values())
    for rec in records:
        prot = [ent for ent in rec.polymer_entities if ent.is_protein]
        if not prot:
            continue  # defensive; Stage 3 already drops protein-free entries
        raw = sorted(
            {canonical[ent.cluster_ids[level]] for ent in prot if level in ent.cluster_ids}
        )
        if not raw:
            solo = SINGLETON_PREFIX + min(ent.entity_id for ent in prot)
            raw = [solo]
            universe.add(solo)
            lonely.append(rec.entry_id)
        elif len(raw) > 1:
            multi.append(rec.entry_id)
        touched[rec.entry_id] = raw

    # Step 3 - union-find over raw keys. The smaller root always wins, so every
    # component's root is its lexicographically smallest key regardless of order.
    parent = {key: key for key in universe}

    def find(key: str) -> str:
        trail = []
        while parent[key] != key:
            trail.append(key)
            key = parent[key]
        for node in trail:  # point every visited node straight at the root
            parent[node] = key
        return key

    for raw in touched.values():
        head, *rest = raw
        for other in rest:
            ra, rb = find(head), find(other)
            if ra != rb:
                parent[max(ra, rb)] = min(ra, rb)

    # Step 4 - materialize: entry -> component root, component -> its entries.
    component_of = {entry: find(raw[0]) for entry, raw in touched.items()}
    grouped: dict[str, set[str]] = {}
    for entry, comp in component_of.items():
        grouped.setdefault(comp, set()).add(entry)

    return ClusterResult(
        identity=level,
        entry_to_cluster={e: component_of[e] for e in sorted(component_of)},
        cluster_members={c: sorted(grouped[c]) for c in sorted(grouped)},
        entry_raw_clusters={e: touched[e] for e in sorted(touched)},
        multichain_entries=sorted(multi),
        unclustered_entries=sorted(lonely),
        n_raw_clusters=len(canonical),
    )
ifsplit/config.py ADDED
@@ -0,0 +1,146 @@
1
+ """Load, validate, and hash an IF-Split run configuration.
2
+
3
+ The config is the single source of truth for a build. Its canonical hash is
4
+ embedded in the manifest so that two manifests sharing a config hash are
5
+ guaranteed to have used identical, output-affecting settings.
6
+
7
+ The hash is computed over the *validated, normalized* settings (not the raw YAML
8
+ text), so comments and formatting do not change it — only values do.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import hashlib
14
+ import json
15
+ from datetime import date
16
+ from pathlib import Path
17
+ from typing import Any, Literal
18
+
19
+ import yaml
20
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
21
+
22
+
23
class SplitFractions(BaseModel):
    """Train / val / test partition fractions.

    Each fraction lies strictly between 0 and 1, and the three must add up to 1.0
    within an absolute tolerance of 1e-9 (to absorb float rounding).
    """

    model_config = ConfigDict(extra="forbid")  # unknown keys are a config error

    train: float = Field(gt=0, lt=1)
    val: float = Field(gt=0, lt=1)
    test: float = Field(gt=0, lt=1)

    @model_validator(mode="after")
    def _sum_to_one(self) -> SplitFractions:
        total = sum((self.train, self.val, self.test))
        deviation = abs(total - 1.0)
        if deviation <= 1e-9:
            return self
        raise ValueError(f"split_fractions must sum to 1.0, got {total}")
38
+
39
+
40
# Identity levels (integer percent) at which RCSB publishes precomputed entity
# sequence clusters. With the ``precomputed`` backend, any other level matches no
# entity's ``cluster_ids``, so Stage 5 would silently turn every entry into its
# own singleton component and lose all sequence-leakage protection.
# NOTE(review): confirm this set against the current RCSB sequence-clustering docs.
_RCSB_CLUSTER_LEVELS = frozenset({30, 40, 50, 70, 90, 95, 100})


class Config(BaseModel):
    """A fully-validated IF-Split build configuration.

    Unknown keys are rejected. Beyond per-field constraints, validation
    normalizes method / HET / metal codes to stripped upper case, rejects
    negative ``test_min_per_class`` floors, and, for the ``precomputed``
    clustering backend, requires ``identity_threshold`` to map to an identity
    level RCSB actually publishes.
    """

    model_config = ConfigDict(extra="forbid")

    # --- snapshot definition (reproducibility anchor) ---
    snapshot_date: date
    experimental_methods: list[str] = Field(min_length=1)
    resolution_max_A: float = Field(gt=0)
    max_total_residues: int = Field(gt=0)
    excluded_het: list[str] = Field(default_factory=list)
    use_biological_assembly: bool = True

    # --- curation: purification-artifact detection (Stage 4) ---
    # A poly-His tag coordinating Ni/Co is a purification artifact, not a
    # biological metal site (a known blemish in the LigandMPNN metal set). An
    # entry whose *only* metal is a purification metal AND that carries a His-tag
    # is flagged so it can be excluded from the metal class. Empty
    # purification_metals disables the heuristic.
    purification_metals: list[str] = Field(default_factory=lambda: ["NI", "CO"])
    histag_min_run: int = Field(default=6, gt=0)
    exclude_purification_artifacts: bool = True

    # --- clustering + split ---
    # Fraction in (0, 1]; ``identity_level`` converts it to an integer percent.
    identity_threshold: float = Field(gt=0, le=1)
    # "precomputed": reuse RCSB's entity clusters (default, no external binary).
    # "mmseqs2": run our own over the snapshot's sequences (accepted here, but
    # Stage 5 currently raises NotImplementedError for it).
    clustering_backend: Literal["precomputed", "mmseqs2"] = "precomputed"
    split_fractions: SplitFractions
    split_salt: str = Field(min_length=1)
    seed: int = Field(ge=0)

    # --- test-set minimums (opt-in stratification top-up) ---
    # Floor on the number of test entries carrying each functional ligand class,
    # e.g. {"metal": 500, "nucleotide": 200}. Empty (default) = pure deterministic
    # hash, no top-up. When set, after the base assignment any class below its floor
    # recruits *whole components* (never individual entries, so no leakage) into
    # test in deterministic hash order, skipping components already pinned by a
    # registry (so growth stays stable). A floor larger than the available supply
    # is satisfied as far as possible and the shortfall is reported, not forced.
    test_min_per_class: dict[str, int] = Field(default_factory=dict)

    # --- quality filters (Stage 3): wwPDB validation-report metrics ---
    # Metadata only (no coordinates). Each cap is optional (None disables it). An
    # entry is dropped only when the metric is present AND violates the cap;
    # entries whose report lacks a metric are kept (never penalized for an absent
    # metric). Geometry caps (clashscore/Ramachandran/rotamer) apply to X-ray and
    # cryo-EM; diffraction caps (R-free/RSRZ) naturally no-op on EM entries.
    max_clashscore: float | None = Field(default=None, gt=0)
    max_rfree: float | None = Field(default=None, gt=0)
    max_ramachandran_outlier_pct: float | None = Field(default=None, ge=0)
    max_rotamer_outlier_pct: float | None = Field(default=None, ge=0)
    max_rsrz_outlier_pct: float | None = Field(default=None, ge=0)
    require_validation_report: bool = False

    # --- featurization (downstream-optional; not part of the split definition) ---
    ligand_context_radius_A: float = Field(gt=0)
    max_ligand_atoms: int = Field(gt=0)

    @field_validator("experimental_methods")
    @classmethod
    def _normalize_methods(cls, v: list[str]) -> list[str]:
        return [m.strip().upper() for m in v]

    @field_validator("excluded_het", "purification_metals")
    @classmethod
    def _normalize_codes(cls, v: list[str]) -> list[str]:
        return [h.strip().upper() for h in v]

    @field_validator("test_min_per_class")
    @classmethod
    def _check_minimums(cls, v: dict[str, int]) -> dict[str, int]:
        for k, n in v.items():
            if n < 0:
                raise ValueError(f"test_min_per_class[{k!r}] must be >= 0, got {n}")
        return v

    @model_validator(mode="after")
    def _check_identity_level(self) -> Config:
        """Reject identity levels the ``precomputed`` backend cannot serve.

        Raises:
            ValueError: if ``clustering_backend`` is ``"precomputed"`` and
                ``identity_level`` is not in ``_RCSB_CLUSTER_LEVELS``.
        """
        if (
            self.clustering_backend == "precomputed"
            and self.identity_level not in _RCSB_CLUSTER_LEVELS
        ):
            allowed = ", ".join(f"{lvl / 100:g}" for lvl in sorted(_RCSB_CLUSTER_LEVELS))
            raise ValueError(
                f"identity_threshold={self.identity_threshold} "
                f"({self.identity_level}%) has no precomputed RCSB clusters; "
                f"with clustering_backend 'precomputed' use one of: {allowed}"
            )
        return self

    @property
    def dataset_version(self) -> str:
        """Versioned dataset name, e.g. 'IF-Split-2026.05.30'."""
        return f"IF-Split-{self.snapshot_date:%Y.%m.%d}"

    @property
    def identity_level(self) -> int:
        """``identity_threshold`` as an integer percent (e.g. 0.30 -> 30)."""
        return round(self.identity_threshold * 100)

    def canonical_dict(self) -> dict[str, Any]:
        """JSON-mode dump (dates -> ISO strings) used for hashing and manifests."""
        return self.model_dump(mode="json")

    def config_hash(self) -> str:
        """Deterministic, formatting-independent hash of the settings.

        Keys are sorted and separators fixed, so equal settings give equal hashes.
        """
        canonical = json.dumps(self.canonical_dict(), sort_keys=True, separators=(",", ":"))
        return hashlib.blake2b(canonical.encode("utf-8"), digest_size=16).hexdigest()
135
+
136
+
137
def load_config(path: str | Path) -> Config:
    """Load and validate a YAML config file into a :class:`Config`.

    Args:
        path: Location of the YAML file.

    Raises:
        FileNotFoundError: if ``path`` does not exist or is not a regular file
            (e.g. a directory), so callers handling a missing config also handle
            this case instead of hitting an OS error when opening it.
        ValueError: if the file is empty or its top level is not a mapping.
        yaml.YAMLError: if the file is not valid YAML (raised by ``safe_load``).
        pydantic.ValidationError: if a value violates the :class:`Config` schema.
    """
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Config not found: {path}")
    with path.open("r", encoding="utf-8") as fh:
        # safe_load: never construct arbitrary Python objects from the file.
        raw = yaml.safe_load(fh)
    if raw is None:
        raise ValueError(f"Config file is empty: {path}")
    if not isinstance(raw, dict):
        raise ValueError(f"Config must be a YAML mapping, got {type(raw).__name__}: {path}")
    return Config.model_validate(raw)