if-split 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ifsplit/ligands.py ADDED
@@ -0,0 +1,267 @@
1
+ """Stage 4 - Tier and classify non-protein components (metadata only).
2
+
3
+ Curation comes *before* classification, and quality is **annotated, never
4
+ destroyed**: no structure is dropped for ligand-quality reasons (a protein with a
5
+ junk ion is still a good training backbone). Each non-protein component is tiered
6
+
7
+ - ``functional`` : real, biologically meaningful ligand/site
8
+ - ``ambiguous`` : present and contacting, but unconfirmed
9
+ - ``artifact`` : buffer / cryoprotectant / counterion / purification tag
10
+
11
+ with a machine-readable reason. Ligand-*class* labels (metal / small_molecule /
12
+ nucleotide) derive from the tier via a config threshold (default: only
13
+ ``functional`` sets a class label; ``ambiguous`` is reported but not labelled;
14
+ ``artifact`` is excluded). A downstream featurizer reads the same per-component
15
+ tier to decide what counts as real ligand context — this is the lever that
16
+ improves *training* quality, not just test reporting.
17
+
18
+ Signals used (all from the Data API, no coordinates):
19
+ - ``bound_components`` : the comp actually contacts the protein (buffer gate)
20
+ - ``affinity_comp_ids`` : a measured binding affinity exists (strong positive)
21
+ - chem-comp ``formula`` : metal-only vs organic
22
+ - His-tag + Ni/Co : the IMAC purification-artifact pattern (existing rule)
23
+ - protein_na_interface_count: a protein<->nucleic-acid assembly interface (>0)
24
+ verifies a real contact (holo gate for nucleotide)
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import re
30
+
31
+ from .config import Config
32
+ from .schema import CandidateRecord, NonpolymerComp
33
+
34
# Ligand classes IF-Split tracks.
CLASS_METAL = "metal"
CLASS_NUCLEOTIDE = "nucleotide"
CLASS_SMALL_MOLECULE = "small_molecule"

# Confidence tiers.
TIER_FUNCTIONAL = "functional"
TIER_AMBIGUOUS = "ambiguous"
TIER_ARTIFACT = "artifact"

# Elements that count as a "metal ion" when a component's formula is made up
# only of these. A comp like HEM (C34 H32 Fe N4 O4) is NOT a metal ion — its
# formula contains C/H/N/O — so it falls through to small-molecule.
# Symbols are stored upper-cased, matching what elements_in_formula() returns.
METAL_ELEMENTS: frozenset[str] = frozenset(
    {
        "LI", "BE", "NA", "MG", "AL", "K", "CA", "SC", "TI", "V", "CR", "MN",
        "FE", "CO", "NI", "CU", "ZN", "GA", "RB", "SR", "Y", "ZR", "NB", "MO",
        "TC", "RU", "RH", "PD", "AG", "CD", "IN", "SN", "CS", "BA", "LA", "CE",
        "PR", "ND", "PM", "SM", "EU", "GD", "TB", "DY", "HO", "ER", "TM", "YB",
        "LU", "HF", "TA", "W", "RE", "OS", "IR", "PT", "AU", "HG", "TL", "PB",
        "BI",
    }
)  # fmt: skip

# Monatomic counterions / phasing atoms: even though some are technically metals,
# as lone ions they are crystallization/phasing additives, not biological sites.
# NOTE: the halide ids listed here (CL, BR, I, IOD, F) are not in METAL_ELEMENTS,
# so any check gated on is_metal_ion() never sees them.
COUNTERION_COMPS: frozenset[str] = frozenset(
    {"NA", "CL", "K", "BR", "I", "IOD", "CS", "RB", "F", "LI"}
)

# Common crystallization additives / buffer junk. Config-extensible via
# excluded_het; this is the always-on baseline.
DEFAULT_ADDITIVE_BLACKLIST: frozenset[str] = frozenset(
    {
        "HOH", "DOD",  # water
        "GOL", "EDO", "PEG", "PG4", "PGE", "1PE", "2PE", "P6G", "PE4", "MPD",
        "SO4", "PO4", "ACT", "ACY", "FMT", "EPE", "MES", "TRS", "BME", "DTT",
        "IMD", "DMS", "BOG", "OLC", "LDA", "SCN", "AZI", "NO3", "CO3", "FLC",
        "TLA", "CIT", "MLI", "IOD", "GLC", "BCN", "MRD", "BU3", "P33",
    }
)  # fmt: skip

# One element symbol: an upper-case letter optionally followed by a single
# lower-case letter (so "Fe" is one match, "FE" would be two).
_ELEMENT_RE = re.compile(r"[A-Z][a-z]?")
77
+
78
+
79
def elements_in_formula(formula: str | None) -> set[str]:
    """Return the upper-cased element symbols found in a chem-comp formula.

    ``"C34 H32 Fe N4 O4"`` -> ``{"C", "H", "FE", "N", "O"}``; ``"Zn"`` -> ``{"ZN"}``.
    Digit runs (with an optional trailing ``+``/``-`` charge sign) are blanked
    out before matching, so counts and charges never contribute. A missing or
    empty formula yields an empty set.
    """
    if formula is None or formula == "":
        return set()
    # Replace "<digits>[+-]" with a space so adjacent symbols stay separated.
    without_counts = re.sub(r"[0-9]+[+-]?", " ", formula)
    symbols = _ELEMENT_RE.findall(without_counts)
    return {symbol.upper() for symbol in symbols}
89
+
90
+
91
def is_metal_ion(comp: NonpolymerComp) -> bool:
    """Report whether a component consists solely of metal element(s).

    The formula decides when it yields at least one element symbol; when it
    yields none, the component id itself is looked up in ``METAL_ELEMENTS``.
    """
    symbols = elements_in_formula(comp.formula)
    if symbols:
        # Every symbol must be a metal: one C/H/N/O disqualifies (e.g. HEM).
        return symbols <= METAL_ELEMENTS
    # No usable formula: fall back to the comp id as a bare element symbol.
    return comp.comp_id in METAL_ELEMENTS
97
+
98
+
99
def longest_residue_run(seq: str, residue: str = "H") -> int:
    """Return the length of the longest unbroken stretch of ``residue`` in ``seq``.

    Comparison is case-insensitive (both sides are upper-cased). An empty
    sequence gives 0.
    """
    if not seq:
        return 0
    wanted = residue.upper()
    longest = streak = 0
    for symbol in seq.upper():
        # Extend the current streak on a match, otherwise start over.
        streak = streak + 1 if symbol == wanted else 0
        if streak > longest:
            longest = streak
    return longest
113
+
114
+
115
def has_histag(seq: str, min_run: int) -> bool:
    """Report whether ``seq`` has a run of at least ``min_run`` consecutive histidines."""
    his_run = longest_residue_run(seq, "H")
    return his_run >= min_run
118
+
119
+
120
def has_protein_na_interface(record: CandidateRecord) -> bool:
    """Report whether the entry has at least one protein<->nucleic-acid interface.

    Uses the record's ``protein_na_interface_count`` (described in this module as
    coming from RCSB's precomputed ``num_prot_na_interface_entities``): a positive
    count is taken as evidence that the protein contacts the DNA/RNA rather than
    merely being co-deposited with it. No coordinates are read.
    """
    interface_count = record.protein_na_interface_count
    return interface_count > 0
128
+
129
+
130
def metal_comps(record: CandidateRecord) -> list[NonpolymerComp]:
    """Return the entry's non-polymer components that are metal ions, in record order."""
    return list(filter(is_metal_ion, record.nonpolymer_comps))
132
+
133
+
134
def is_purification_artifact(
    record: CandidateRecord,
    *,
    purification_metals: set[str],
    histag_min_run: int,
) -> bool:
    """Detect the His-tag-binds-Ni/Co purification-artifact pattern.

    All three conditions must hold:
      (a) the entry contains at least one metal ion,
      (b) every metal ion it contains is listed in ``purification_metals``,
      (c) at least one protein chain carries a poly-His run of ``histag_min_run``
          or more.

    A genuine metal site (e.g. a catalytic Zn) next to a tag breaks (b), so such
    entries are not flagged. An empty ``purification_metals`` disables the rule.
    """
    if not purification_metals:
        return False

    metal_ids = {comp.comp_id for comp in metal_comps(record)}
    only_purification_metals = bool(metal_ids) and metal_ids <= purification_metals
    if not only_purification_metals:
        return False

    for entity in record.polymer_entities:
        if entity.is_protein and has_histag(entity.seq, histag_min_run):
            return True
    return False
152
+
153
+
154
def tier_component(
    comp: NonpolymerComp,
    record: CandidateRecord,
    cfg: Config,
    *,
    blacklist: set[str],
    purification: set[str],
    is_artifact_entry: bool,
) -> tuple[str, str, str | None]:
    """Tier a single non-protein component.

    Rules are applied in priority order:

    1. Lone counterion / phasing atom (``COUNTERION_COMPS``) -> ``artifact``,
       regardless of binding and regardless of whether it is a metal.
    2. Metal ion: His-tag purification metal -> ``artifact``; measured affinity
       or protein contact -> ``functional``; otherwise ``ambiguous``.
    3. Anything else: blacklisted additive -> ``artifact``; measured affinity or
       protein contact -> ``functional``; otherwise ``ambiguous``.

    Args:
        comp: The component to tier (its ``comp_id`` and formula are read).
        record: The owning entry; ``bound_components`` and ``affinity_comp_ids``
            are consulted for membership of ``comp.comp_id``.
        cfg: Run configuration; only ``exclude_purification_artifacts`` is read.
        blacklist: Component ids treated as crystallization additives.
        purification: Component ids of IMAC purification metals.
        is_artifact_entry: Whether the entry as a whole matched the
            His-tag/purification-metal pattern.

    Returns:
        ``(tier, reason, class_label_or_None)``. The class label is set only for
        the ``functional`` tier; ``ambiguous`` and ``artifact`` return ``None``
        so they are reported but not labelled.
    """
    cid = comp.comp_id

    # Counterions are checked BEFORE the metal/non-metal split. COUNTERION_COMPS
    # holds halides (CL, BR, I, IOD, F) as well as alkali metals; halides are not
    # metal ions, so a check nested inside the metal branch never saw them and a
    # bound chloride was labelled a functional small molecule.
    # IOD is also on the additive blacklist; it now reports "counterion".
    if cid in COUNTERION_COMPS:
        return TIER_ARTIFACT, "counterion", None

    # Plain membership tests: no need to build a throwaway set per component.
    has_affinity = cid in record.affinity_comp_ids
    bound = cid in record.bound_components

    if is_metal_ion(comp):
        # His-tag/Ni(Co) IMAC purification artifact.
        if is_artifact_entry and cfg.exclude_purification_artifacts and cid in purification:
            return TIER_ARTIFACT, "histag_metal", None
        if has_affinity:
            return TIER_FUNCTIONAL, "metal_affinity", CLASS_METAL
        if bound:
            return TIER_FUNCTIONAL, "metal_bound", CLASS_METAL
        # "Trust biological metals" but require contact: an unbound metal far from
        # the protein is adventitious -> ambiguous, not functional.
        return TIER_AMBIGUOUS, "metal_unbound", None

    # Non-metal small molecule.
    if cid in blacklist:
        return TIER_ARTIFACT, "additive", None
    if has_affinity:
        return TIER_FUNCTIONAL, "ligand_affinity", CLASS_SMALL_MOLECULE
    if bound:
        return TIER_FUNCTIONAL, "ligand_bound", CLASS_SMALL_MOLECULE
    return TIER_AMBIGUOUS, "ligand_unbound", None
196
+
197
+
198
def classify_components(record: CandidateRecord, cfg: Config) -> dict:
    """Tier and classify one entry's non-protein components (metadata only).

    Nothing is dropped here: the entry is always kept and only annotated. The
    returned dict carries the per-component tiers, the functional class labels,
    the ambiguous (reported-but-unlabelled) classes, the functional metal and
    small-molecule ids, a nucleic-acid presence flag, and the purification
    artifact flag.
    """
    additive_ids = DEFAULT_ADDITIVE_BLACKLIST | set(cfg.excluded_het)
    imac_metals = set(cfg.purification_metals)
    artifact_entry = is_purification_artifact(
        record,
        purification_metals=imac_metals,
        histag_min_run=cfg.histag_min_run,
    )

    tiers: dict[str, dict[str, str]] = {}
    # Functional component ids, bucketed by the class label they earned.
    functional_ids: dict[str, list[str]] = {CLASS_METAL: [], CLASS_SMALL_MOLECULE: []}
    ambiguous: set[str] = set()

    for comp in record.nonpolymer_comps:
        tier, reason, label = tier_component(
            comp,
            record,
            cfg,
            blacklist=additive_ids,
            purification=imac_metals,
            is_artifact_entry=artifact_entry,
        )
        tiers[comp.comp_id] = {"tier": tier, "reason": reason}
        if tier == TIER_FUNCTIONAL:
            if label in functional_ids:
                functional_ids[label].append(comp.comp_id)
        elif tier == TIER_AMBIGUOUS:
            # Record the would-be class for reporting (metal vs small molecule).
            would_be = CLASS_METAL if is_metal_ion(comp) else CLASS_SMALL_MOLECULE
            ambiguous.add(would_be)

    # Nucleotide class = DNA/RNA polymer chains. It is functional only when the
    # protein actually interfaces the nucleic acid; a co-deposited NA chain with
    # no interface is reported as ambiguous. Non-polymer nucleotides
    # (ATP/GTP/NAD) were already handled in the loop as small molecules.
    has_nucleotide = any(entity.is_nucleic for entity in record.polymer_entities)
    nucleotide_functional = False
    if has_nucleotide:
        nucleotide_functional = has_protein_na_interface(record)
        if nucleotide_functional:
            tiers["nucleic_acid"] = {"tier": TIER_FUNCTIONAL, "reason": "protein_na_interface"}
        else:
            ambiguous.add(CLASS_NUCLEOTIDE)
            tiers["nucleic_acid"] = {"tier": TIER_AMBIGUOUS, "reason": "no_protein_na_interface"}

    classes: set[str] = {label for label, ids in functional_ids.items() if ids}
    if nucleotide_functional:
        classes.add(CLASS_NUCLEOTIDE)

    return {
        "entry_id": record.entry_id,
        "classes": sorted(classes),  # functional-tier class labels
        "ambiguous_classes": sorted(ambiguous - classes),
        "metals": sorted(functional_ids[CLASS_METAL]),
        "small_molecules": sorted(functional_ids[CLASS_SMALL_MOLECULE]),
        "has_nucleotide": has_nucleotide,
        "tiers": dict(sorted(tiers.items())),
        "purification_artifact": artifact_entry,
    }
ifsplit/manifest.py ADDED
@@ -0,0 +1,417 @@
1
+ """Stage 7 - Snapshot lock, manifest, split registry, and verify/stats commands.
2
+
3
+ Artifacts written to the output dir:
4
+
5
+ - ``dataset.lock`` - reproduction anchor: embedded config + canonical
6
+ ``candidates.jsonl`` hash + entry-id list. ``verify`` re-enumerates from it and
7
+ reports drift (added/removed entries, hash match), warning not failing.
8
+ - ``manifest.json`` - small (~KB) provenance record: config, drop log, per-split
9
+ + per-class counts, cluster/component stats, and a ``files`` index pointing at
10
+ the data files below. No per-entry arrays, so it stays tiny at any scale. Built
11
+ as a pure function of its inputs (no wall-clock fields) -> byte-identical across
12
+ runs of the same config.
13
+
14
+ The split itself is plain lists of PDB ids, each its own file:
15
+ - ``train.json`` / ``val.json`` / ``test.json`` - the entry ids in each split
16
+ (one id per line; grepable and trivially loadable).
17
+ - ``test/<class>_test.json`` - the test ids carrying each functional ligand class
18
+ (``metal`` / ``small_molecule`` / ``nucleotide``), for per-class evaluation.
19
+
20
+ Supporting maps (only needed for sampling / curation, not to read the split):
21
+ - ``clusters.json`` - entry_id -> component key (for cluster-balanced sampling).
22
+ - ``ligands.classes.json`` - entry_id -> functional class labels.
23
+ - ``ligands.tiers.json`` - per-component curation *audit trail* (tier + reason);
24
+ bulky (~24 MB at full-PDB scale), read only by ``fetch`` and curation audits.
25
+ - ``splits.registry.json`` - canonical_key -> split, so a later, larger snapshot
26
+ reuses prior assignments instead of re-hashing (growth stability).
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import tempfile
33
+ from pathlib import Path
34
+ from typing import Any
35
+
36
+ from . import __version__
37
+ from .config import Config
38
+
39
# Schema identifiers embedded in the documents this module writes
# (dataset.lock, manifest.json, splits.registry.json, ligands.tiers.json).
LOCK_SCHEMA = "if-split/lock@1"
MANIFEST_SCHEMA = "if-split/manifest@2"
REGISTRY_SCHEMA = "if-split/registry@1"
TIERS_SCHEMA = "if-split/tiers@1"

# Data files written next to manifest.json (referenced by manifest["files"]).
TIERS_FILENAME = "ligands.tiers.json"
CLASSES_FILENAME = "ligands.classes.json"
CLUSTERS_FILENAME = "clusters.json"
# Split name -> file name of that split's id list.
SPLIT_FILES = {"train": "train.json", "val": "val.json", "test": "test.json"}
TEST_SUBDIR = "test"  # per-class test-id lists live here: test/<class>_test.json
50
+
51
+
52
+ # --------------------------------------------------------------------------- #
53
+ # Lock (Stage 1 reproduction anchor)
54
+ # --------------------------------------------------------------------------- #
55
def build_lock(
    cfg: Config,
    *,
    entry_ids: list[str],
    candidates_sha256: str,
    limit: int | None,
) -> dict[str, Any]:
    """Build the lock document in memory; nothing is written to disk.

    The lock embeds the canonical config and its hash, the enumeration ``limit``
    and the candidate set (count, content hash, sorted entry ids).
    """
    selection = {"limit": limit}
    candidates = {
        "count": len(entry_ids),
        "sha256": candidates_sha256,
        # Sorted so the lock does not depend on enumeration order.
        "entry_ids": sorted(entry_ids),
    }
    return {
        "lock_schema": LOCK_SCHEMA,
        "dataset_version": cfg.dataset_version,
        "if_split_version": __version__,
        "config_hash": cfg.config_hash(),
        "config": cfg.canonical_dict(),
        "selection": selection,
        "candidates": candidates,
    }
76
+
77
+
78
+ def _write_json(obj: dict[str, Any], path: Path, *, compact: bool = False) -> Path:
79
+ path.parent.mkdir(parents=True, exist_ok=True)
80
+ if compact:
81
+ text = json.dumps(obj, sort_keys=True, separators=(",", ":"))
82
+ else:
83
+ text = json.dumps(obj, indent=2, sort_keys=True)
84
+ path.write_text(text + "\n", encoding="utf-8")
85
+ return path
86
+
87
+
88
def write_lock(lock: dict[str, Any], out_dir: str | Path) -> Path:
    """Write ``lock`` to ``<out_dir>/dataset.lock`` (pretty-printed) and return the path."""
    target = Path(out_dir) / "dataset.lock"
    return _write_json(lock, target)
90
+
91
+
92
def read_lock(path: str | Path) -> dict[str, Any]:
    """Load and parse a lock file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    lock_file = Path(path)
    if lock_file.exists():
        return json.loads(lock_file.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Lock file not found: {lock_file}")
97
+
98
+
99
+ # --------------------------------------------------------------------------- #
100
+ # Split registry (growth stability)
101
+ # --------------------------------------------------------------------------- #
102
def read_registry(path: str | Path) -> dict[str, str]:
    """Return the canonical_key -> split map stored at ``path``.

    A missing file, or a document without an ``"assignments"`` key, yields ``{}``.
    """
    registry_file = Path(path)
    if not registry_file.exists():
        return {}
    document = json.loads(registry_file.read_text(encoding="utf-8"))
    assignments = document.get("assignments", {})
    return dict(assignments)
109
+
110
+
111
def write_registry(cluster_split: dict[str, str], out_dir: str | Path) -> Path:
    """Write the canonical_key -> split registry to ``<out_dir>/splits.registry.json``.

    Assignments are stored in key order; the written path is returned.
    """
    ordered = {key: cluster_split[key] for key in sorted(cluster_split)}
    document = {"registry_schema": REGISTRY_SCHEMA, "assignments": ordered}
    return _write_json(document, Path(out_dir) / "splits.registry.json")
117
+
118
+
119
+ # --------------------------------------------------------------------------- #
120
+ # Manifest (human-facing, deterministic)
121
+ # --------------------------------------------------------------------------- #
122
def build_manifest(
    cfg: Config,
    *,
    candidates_sha256: str,
    n_candidates: int,
    drops: list[dict],
    drop_counts: dict[str, int],
    clusters,
    splits,
    class_map: dict[str, dict],
) -> dict[str, Any]:
    """Build the manifest.json document purely from the build outputs.

    The result holds provenance only (config, filter and clustering stats,
    per-split and per-class counts, and pointers to the data files); no
    per-entry arrays are included.
    """
    from .split import SPLITS

    # Bucket entry ids by the split they were assigned to.
    members: dict[str, list[str]] = {name: [] for name in SPLITS}
    for entry_id, name in splits.entry_split.items():
        members[name].append(entry_id)
    for ids in members.values():
        ids.sort()

    def _tally(entry_ids, field):
        # Count how many of the given entries carry each class under `field`.
        tally: dict[str, int] = {}
        for entry_id in entry_ids:
            for label in class_map.get(entry_id, {}).get(field, []):
                tally[label] = tally.get(label, 0) + 1
        return {label: tally[label] for label in sorted(tally)}

    # Functional-tier counts (the test-quality view) and the ambiguous counts,
    # so under-/over-confidence is visible next to them.
    per_split_class_counts = {name: _tally(members[name], "classes") for name in SPLITS}
    per_split_ambiguous_counts = {
        name: _tally(members[name], "ambiguous_classes") for name in SPLITS
    }
    n_artifacts = len(
        [info for info in class_map.values() if info.get("purification_artifact")]
    )

    # Only classes that actually occur in the test split get a per-class file.
    test_class_files = {
        label: f"{TEST_SUBDIR}/{label}_test.json" for label in per_split_class_counts["test"]
    }

    filter_section = {
        "kept": len(splits.entry_split),
        "dropped": len(drops),
        "drop_counts": dict(sorted(drop_counts.items())),
    }
    clustering_section = {
        "backend": cfg.clustering_backend,
        "identity": clusters.identity,
        "n_clusters": clusters.n_clusters,
        "n_raw_clusters": clusters.n_raw_clusters,
        "multichain_entries": len(clusters.multichain_entries),
        "unclustered_entries": len(clusters.unclustered_entries),
    }
    splits_section = {
        "entry_counts": dict(sorted(splits.counts.items())),
        "cluster_counts": dict(sorted(splits.cluster_counts.items())),
        "per_split_class_counts": per_split_class_counts,
        "per_split_ambiguous_counts": per_split_ambiguous_counts,
        "test_minimum_shortfalls": dict(sorted(splits.minimum_shortfalls.items())),
    }
    # Pointers to the data files written alongside the manifest.
    files_section = {
        "splits": dict(SPLIT_FILES),
        "test_by_class": dict(sorted(test_class_files.items())),
        "clusters": CLUSTERS_FILENAME,
        "ligand_classes": CLASSES_FILENAME,
        "ligand_tiers": TIERS_FILENAME,
    }

    return {
        "manifest_schema": MANIFEST_SCHEMA,
        "dataset_version": cfg.dataset_version,
        "if_split_version": __version__,
        "config_hash": cfg.config_hash(),
        "config": cfg.canonical_dict(),
        "candidates": {"count": n_candidates, "sha256": candidates_sha256},
        "filter": filter_section,
        "clustering": clustering_section,
        "splits": splits_section,
        "ligands": {"n_purification_artifacts": n_artifacts},
        "files": files_section,
    }
202
+
203
+
204
+ # --------------------------------------------------------------------------- #
205
+ # Split data files (the actual lists of PDB ids + supporting maps)
206
+ # --------------------------------------------------------------------------- #
207
+ def _write_id_list(ids: list[str], path: Path) -> Path:
208
+ """Write a JSON array of ids, one per line (compact yet grepable)."""
209
+ path.parent.mkdir(parents=True, exist_ok=True)
210
+ body = ",\n".join(json.dumps(i) for i in ids)
211
+ path.write_text(f"[\n{body}\n]\n" if ids else "[]\n", encoding="utf-8")
212
+ return path
213
+
214
+
215
def write_split_files(splits, class_map: dict[str, dict], out_dir: str | Path) -> dict[str, Path]:
    """Write the train/val/test id lists and the per-class test id lists.

    Returns a name -> path map of every file written: split names for the three
    split files, ``"test:<class>"`` for each per-class test list. Ids are sorted
    before writing, so the output is byte-stable for the same inputs.
    """
    out_root = Path(out_dir)

    members: dict[str, list[str]] = {name: [] for name in SPLIT_FILES}
    for entry_id, name in splits.entry_split.items():
        members[name].append(entry_id)
    for ids in members.values():
        ids.sort()

    written: dict[str, Path] = {
        name: _write_id_list(members[name], out_root / filename)
        for name, filename in SPLIT_FILES.items()
    }

    # Per-class test-id lists: test entries carrying each functional class.
    by_class: dict[str, list[str]] = {}
    for entry_id in members["test"]:
        for label in class_map.get(entry_id, {}).get("classes", []):
            by_class.setdefault(label, []).append(entry_id)
    for label, ids in by_class.items():
        target = out_root / TEST_SUBDIR / f"{label}_test.json"
        written[f"test:{label}"] = _write_id_list(sorted(ids), target)

    return written
242
+
243
+
244
def write_clusters(entry_to_cluster: dict[str, str], out_dir: str | Path) -> Path:
    """Write the entry_id -> component key map (used for cluster-balanced sampling).

    Stored compactly, in entry-id order, under ``"entry_clusters"``.
    """
    ordered = {eid: entry_to_cluster[eid] for eid in sorted(entry_to_cluster)}
    document = {"clusters_schema": "if-split/clusters@1", "entry_clusters": ordered}
    target = Path(out_dir) / CLUSTERS_FILENAME
    return _write_json(document, target, compact=True)
251
+
252
+
253
def read_clusters(path: str | Path) -> dict[str, str]:
    """Return the entry_id -> component key map at ``path``, or ``{}`` if the file is absent."""
    clusters_file = Path(path)
    if not clusters_file.exists():
        return {}
    document = json.loads(clusters_file.read_text(encoding="utf-8"))
    return dict(document.get("entry_clusters", {}))
258
+
259
+
260
def write_classes(class_map: dict[str, dict], out_dir: str | Path) -> Path:
    """Write the entry_id -> functional class labels map, compactly, in entry-id order.

    Every value in ``class_map`` must have a ``"classes"`` key.
    """
    labels_by_entry = {eid: class_map[eid]["classes"] for eid in sorted(class_map)}
    document = {"classes_schema": "if-split/classes@1", "classes": labels_by_entry}
    target = Path(out_dir) / CLASSES_FILENAME
    return _write_json(document, target, compact=True)
265
+
266
+
267
def read_classes(path: str | Path) -> dict[str, list[str]]:
    """Return the entry_id -> class labels map at ``path``, or ``{}`` if the file is absent."""
    classes_file = Path(path)
    if not classes_file.exists():
        return {}
    document = json.loads(classes_file.read_text(encoding="utf-8"))
    return dict(document.get("classes", {}))
272
+
273
+
274
def read_id_list(path: str | Path) -> list[str]:
    """Load a split id-list file (train.json etc.); a missing file gives ``[]``."""
    list_file = Path(path)
    if not list_file.exists():
        return []
    parsed = json.loads(list_file.read_text(encoding="utf-8"))
    return list(parsed)
280
+
281
+
282
def write_manifest(manifest: dict[str, Any], out_dir: str | Path) -> Path:
    """Write ``manifest`` to ``<out_dir>/manifest.json`` and return the path."""
    # Pretty-printed on purpose: the manifest is small provenance meant to be read.
    target = Path(out_dir) / "manifest.json"
    return _write_json(manifest, target)
285
+
286
+
287
def read_manifest(path: str | Path) -> dict[str, Any]:
    """Load and parse a manifest file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    manifest_file = Path(path)
    if manifest_file.exists():
        return json.loads(manifest_file.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Manifest not found: {manifest_file}")
292
+
293
+
294
+ # --------------------------------------------------------------------------- #
295
+ # Ligand-tier audit sidecar (bulky; off the load path)
296
+ # --------------------------------------------------------------------------- #
297
def build_tiers_doc(class_map: dict[str, dict]) -> dict[str, Any]:
    """Build the per-entry, per-component tier + reason audit document.

    Depends only on ``class_map`` and emits entries in id order, so the result
    is deterministic. Entries without a ``"tiers"`` key map to ``{}``.
    """
    tiers_by_entry = {eid: class_map[eid].get("tiers", {}) for eid in sorted(class_map)}
    return {"tiers_schema": TIERS_SCHEMA, "tiers": tiers_by_entry}
306
+
307
+
308
def write_tiers(doc: dict[str, Any], out_dir: str | Path) -> Path:
    """Write the tier audit document compactly to ``<out_dir>/ligands.tiers.json``."""
    target = Path(out_dir) / TIERS_FILENAME
    return _write_json(doc, target, compact=True)
310
+
311
+
312
def read_tiers(path: str | Path) -> dict[str, dict]:
    """Return the tier map stored in a sidecar file; ``{}`` when the file is absent."""
    tiers_file = Path(path)
    if not tiers_file.exists():
        return {}
    document = json.loads(tiers_file.read_text(encoding="utf-8"))
    tier_map = document.get("tiers", {})
    return dict(tier_map)
319
+
320
+
321
+ # --------------------------------------------------------------------------- #
322
+ # verify / stats commands
323
+ # --------------------------------------------------------------------------- #
324
def verify_lock(lock_path: str | Path, *, client=None) -> int:
    """Re-run enumeration from a lock's embedded config and report any drift.

    The candidate set is re-enumerated into a temporary directory and compared
    with the locked entry ids and candidates hash; findings are printed.

    Returns a process exit code: 0 when the snapshot reproduces exactly, 1 when
    drift is detected. ``client`` is forwarded to the enumerator so tests can
    run offline; production passes None.
    """
    from .enumerate import enumerate_candidates

    lock = read_lock(lock_path)
    schema = lock.get("lock_schema")
    if schema != LOCK_SCHEMA:
        # A schema mismatch is only warned about; verification still proceeds.
        print(f"warning: unexpected lock_schema {schema!r}")

    cfg = Config.model_validate(lock["config"])
    selection = lock.get("selection") or {}
    limit = selection.get("limit")
    locked = lock["candidates"]
    locked_ids = set(locked["entry_ids"])
    locked_sha = locked["sha256"]

    print(f"verifying {lock['dataset_version']} (config {cfg.config_hash()})")
    print(f" locked: {locked['count']} entries, candidates sha256={locked_sha[:12]}...")

    def _progress(m):
        print(f" {m}")

    with tempfile.TemporaryDirectory() as tmp:
        records, _, sha = enumerate_candidates(
            cfg, tmp, limit=limit, client=client, progress=_progress
        )

    now_ids = {r.entry_id for r in records}
    added = sorted(now_ids - locked_ids)
    removed = sorted(locked_ids - now_ids)  # obsoleted / withdrawn
    hash_matches = sha == locked_sha
    membership_changed = bool(added) or bool(removed)

    if hash_matches and not membership_changed:
        print(f"OK: reproduced exactly ({len(records)} entries, hashes match).")
        return 0

    print("DRIFT detected:")
    if not hash_matches:
        print(f" candidates sha256 differs: now {sha[:12]}... vs locked {locked_sha[:12]}...")
    if removed:
        print(f" {len(removed)} entries no longer present (obsoleted/withdrawn):")
        print(f" {', '.join(removed[:20])}{' ...' if len(removed) > 20 else ''}")
    if added:
        print(f" {len(added)} new entries match the snapshot filters:")
        print(f" {', '.join(added[:20])}{' ...' if len(added) > 20 else ''}")
    if not membership_changed:
        print(" entry set unchanged, but per-entry metadata changed (see hash).")
    return 1
370
+
371
+
372
def summarize_manifest(manifest_path: str | Path) -> int:
    """`stats` command: print a summary of a manifest.

    Covers candidate/filter counts, clustering stats, split sizes, per-class
    (functional) and ambiguous test counts, test-minimum shortfalls, the number
    of purification artifacts, and the data-file names. Always returns 0.
    """
    manifest = read_manifest(manifest_path)
    print(f"{manifest['dataset_version']} (config {manifest['config_hash']})")

    filt = manifest["filter"]
    print(
        f" candidates: {manifest['candidates']['count']} kept: {filt['kept']} dropped: {filt['dropped']}"
    )
    for reason, count in filt["drop_counts"].items():
        print(f" - {reason}: {count}")

    clustering = manifest["clustering"]
    print(
        f" clustering: {clustering['backend']} @ {clustering['identity']}% "
        f"components={clustering['n_clusters']} (from {clustering.get('n_raw_clusters', '?')} raw) "
        f"multichain={clustering['multichain_entries']}"
    )

    split_info = manifest["splits"]
    print(" splits (entries / components):")
    for name in ("train", "val", "test"):
        entries = split_info["entry_counts"].get(name, 0)
        components = split_info["cluster_counts"].get(name, 0)
        print(f" {name:5s}: {entries:>7} entries {components:>7} components")

    print(" test set by ligand class (functional tier):")
    functional_counts = split_info["per_split_class_counts"].get("test", {})
    for label, count in functional_counts.items():
        print(f" {label}: {count}")

    ambiguous_counts = split_info.get("per_split_ambiguous_counts", {}).get("test", {})
    if ambiguous_counts:
        print(" test set ambiguous (reported, not labelled):")
        for label, count in ambiguous_counts.items():
            print(f" {label}: {count}")

    shortfalls = split_info.get("test_minimum_shortfalls", {})
    if shortfalls:
        print(" test minimum shortfalls (floor exceeded available supply):")
        for label, count in shortfalls.items():
            print(f" {label}: short by {count}")

    ligands = manifest["ligands"]
    # If the count key is absent, fall back to the length of a
    # "purification_artifacts" list when one is present.
    n_arts = ligands.get(
        "n_purification_artifacts", len(ligands.get("purification_artifacts", []))
    )
    print(f" His-tag/Ni purification artifacts flagged: {n_arts}")

    files = manifest.get("files", {})
    if not files:
        return 0
    split_files = files.get("splits", {})
    per_class_files = files.get("test_by_class", {})
    print(f" split files: {', '.join(split_files.values())}")
    if per_class_files:
        print(f" per-class test files: {', '.join(per_class_files.values())}")
    return 0