if-split 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- if_split-0.1.0.dist-info/METADATA +312 -0
- if_split-0.1.0.dist-info/RECORD +20 -0
- if_split-0.1.0.dist-info/WHEEL +4 -0
- if_split-0.1.0.dist-info/entry_points.txt +2 -0
- if_split-0.1.0.dist-info/licenses/LICENSE +21 -0
- ifsplit/__init__.py +8 -0
- ifsplit/__main__.py +8 -0
- ifsplit/cli.py +317 -0
- ifsplit/cluster.py +130 -0
- ifsplit/config.py +146 -0
- ifsplit/dataset.py +112 -0
- ifsplit/download.py +229 -0
- ifsplit/enumerate.py +111 -0
- ifsplit/hydrate.py +216 -0
- ifsplit/ligands.py +267 -0
- ifsplit/manifest.py +417 -0
- ifsplit/parse.py +111 -0
- ifsplit/rcsb.py +251 -0
- ifsplit/schema.py +241 -0
- ifsplit/split.py +177 -0
ifsplit/ligands.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Stage 4 - Tier and classify non-protein components (metadata only).
|
|
2
|
+
|
|
3
|
+
Curation comes *before* classification, and quality is **annotated, never
|
|
4
|
+
destroyed**: no structure is dropped for ligand-quality reasons (a protein with a
|
|
5
|
+
junk ion is still a good training backbone). Each non-protein component is tiered
|
|
6
|
+
|
|
7
|
+
- ``functional`` : real, biologically meaningful ligand/site
|
|
8
|
+
- ``ambiguous`` : present and contacting, but unconfirmed
|
|
9
|
+
- ``artifact`` : buffer / cryoprotectant / counterion / purification tag
|
|
10
|
+
|
|
11
|
+
with a machine-readable reason. Ligand-*class* labels (metal / small_molecule /
|
|
12
|
+
nucleotide) derive from the tier via a config threshold (default: only
|
|
13
|
+
``functional`` sets a class label; ``ambiguous`` is reported but not labelled;
|
|
14
|
+
``artifact`` is excluded). A downstream featurizer reads the same per-component
|
|
15
|
+
tier to decide what counts as real ligand context — this is the lever that
|
|
16
|
+
improves *training* quality, not just test reporting.
|
|
17
|
+
|
|
18
|
+
Signals used (all from the Data API, no coordinates):
|
|
19
|
+
- ``bound_components`` : the comp actually contacts the protein (buffer gate)
|
|
20
|
+
- ``affinity_comp_ids`` : a measured binding affinity exists (strong positive)
|
|
21
|
+
- chem-comp ``formula`` : metal-only vs organic
|
|
22
|
+
- His-tag + Ni/Co : the IMAC purification-artifact pattern (existing rule)
|
|
23
|
+
- protein_na_interface_count: a protein<->nucleic-acid assembly interface (>0)
|
|
24
|
+
verifies a real contact (holo gate for nucleotide)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
import re
|
|
30
|
+
|
|
31
|
+
from .config import Config
|
|
32
|
+
from .schema import CandidateRecord, NonpolymerComp
|
|
33
|
+
|
|
34
|
+
# Ligand classes IF-Split tracks.
|
|
35
|
+
CLASS_METAL = "metal"
|
|
36
|
+
CLASS_NUCLEOTIDE = "nucleotide"
|
|
37
|
+
CLASS_SMALL_MOLECULE = "small_molecule"
|
|
38
|
+
|
|
39
|
+
# Confidence tiers.
|
|
40
|
+
TIER_FUNCTIONAL = "functional"
|
|
41
|
+
TIER_AMBIGUOUS = "ambiguous"
|
|
42
|
+
TIER_ARTIFACT = "artifact"
|
|
43
|
+
|
|
44
|
+
# Elements that count as a "metal ion" when a component's formula is made up
|
|
45
|
+
# only of these. A comp like HEM (C34 H32 Fe N4 O4) is NOT a metal ion — its
|
|
46
|
+
# formula contains C/H/N/O — so it falls through to small-molecule.
|
|
47
|
+
METAL_ELEMENTS: frozenset[str] = frozenset(
|
|
48
|
+
{
|
|
49
|
+
"LI", "BE", "NA", "MG", "AL", "K", "CA", "SC", "TI", "V", "CR", "MN",
|
|
50
|
+
"FE", "CO", "NI", "CU", "ZN", "GA", "RB", "SR", "Y", "ZR", "NB", "MO",
|
|
51
|
+
"TC", "RU", "RH", "PD", "AG", "CD", "IN", "SN", "CS", "BA", "LA", "CE",
|
|
52
|
+
"PR", "ND", "PM", "SM", "EU", "GD", "TB", "DY", "HO", "ER", "TM", "YB",
|
|
53
|
+
"LU", "HF", "TA", "W", "RE", "OS", "IR", "PT", "AU", "HG", "TL", "PB",
|
|
54
|
+
"BI",
|
|
55
|
+
}
|
|
56
|
+
) # fmt: skip
|
|
57
|
+
|
|
58
|
+
# Monatomic counterions / phasing atoms: even though some are technically metals,
|
|
59
|
+
# as lone ions they are crystallization/phasing additives, not biological sites.
|
|
60
|
+
COUNTERION_COMPS: frozenset[str] = frozenset(
|
|
61
|
+
{"NA", "CL", "K", "BR", "I", "IOD", "CS", "RB", "F", "LI"}
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
# Common crystallization additives / buffer junk. Config-extensible via
|
|
65
|
+
# excluded_het; this is the always-on baseline.
|
|
66
|
+
DEFAULT_ADDITIVE_BLACKLIST: frozenset[str] = frozenset(
|
|
67
|
+
{
|
|
68
|
+
"HOH", "DOD", # water
|
|
69
|
+
"GOL", "EDO", "PEG", "PG4", "PGE", "1PE", "2PE", "P6G", "PE4", "MPD",
|
|
70
|
+
"SO4", "PO4", "ACT", "ACY", "FMT", "EPE", "MES", "TRS", "BME", "DTT",
|
|
71
|
+
"IMD", "DMS", "BOG", "OLC", "LDA", "SCN", "AZI", "NO3", "CO3", "FLC",
|
|
72
|
+
"TLA", "CIT", "MLI", "IOD", "GLC", "BCN", "MRD", "BU3", "P33",
|
|
73
|
+
}
|
|
74
|
+
) # fmt: skip
|
|
75
|
+
|
|
76
|
+
_ELEMENT_RE = re.compile(r"[A-Z][a-z]?")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def elements_in_formula(formula: str | None) -> set[str]:
    """Return the upper-cased element symbols that occur in a chem-comp formula.

    Runs of digits (together with one optional trailing ``+``/``-`` sign) are
    blanked out first, so counts and charge tokens never reach the symbol
    scan. Every remaining capital letter, optionally followed by a single
    lower-case letter, is then taken as one element symbol.

    Examples: ``"C34 H32 Fe N4 O4"`` -> ``{"C", "H", "FE", "N", "O"}`` and
    ``"Zn"`` -> ``{"ZN"}``. A missing or empty formula yields an empty set.
    """
    if not formula:
        return set()
    without_counts = re.sub(r"[0-9]+[+-]?", " ", formula)
    symbols = _ELEMENT_RE.findall(without_counts)
    return {symbol.upper() for symbol in symbols}
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def is_metal_ion(comp: NonpolymerComp) -> bool:
    """Tell whether a component consists solely of metal element(s).

    When the formula yields at least one element symbol, the component is a
    metal ion exactly when all of those symbols are in ``METAL_ELEMENTS``.
    When the formula yields nothing (missing/empty), the decision falls back
    to whether the component id itself is a metal element symbol.
    """
    symbols = elements_in_formula(comp.formula)
    if symbols:
        return symbols <= METAL_ELEMENTS
    # No usable formula: fall back to the component id.
    return comp.comp_id in METAL_ELEMENTS
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def longest_residue_run(seq: str, residue: str = "H") -> int:
    """Return the length of the longest unbroken stretch of ``residue`` in ``seq``.

    The comparison is case-insensitive (both sides are upper-cased) and is
    done one sequence character at a time, so a ``residue`` that is not
    exactly one character long never matches and the result is 0. An empty
    or missing sequence also gives 0.
    """
    if not seq:
        return 0
    target = residue.upper()
    longest = 0
    current = 0
    for symbol in seq.upper():
        # Extend the current streak on a match, otherwise start over.
        current = current + 1 if symbol == target else 0
        if current > longest:
            longest = current
    return longest
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def has_histag(seq: str, min_run: int) -> bool:
    """Report whether ``seq`` has at least ``min_run`` consecutive histidines."""
    longest_his_stretch = longest_residue_run(seq, "H")
    return longest_his_stretch >= min_run
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def has_protein_na_interface(record: CandidateRecord) -> bool:
    """Report whether the record counts at least one protein<->nucleic-acid interface.

    Only ``record.protein_na_interface_count`` is consulted (metadata, no
    coordinates). A positive count is treated as evidence that the protein
    actually contacts the DNA/RNA rather than merely being co-deposited.

    NOTE(review): the count is presumably filled from RCSB's precomputed
    ``num_prot_na_interface_entities`` field - confirm against the parser.
    """
    interface_count = record.protein_na_interface_count
    return interface_count > 0
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def metal_comps(record: CandidateRecord) -> list[NonpolymerComp]:
    """Return the record's non-polymer components that are metal ions, in order."""
    return list(filter(is_metal_ion, record.nonpolymer_comps))
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def is_purification_artifact(
    record: CandidateRecord,
    *,
    purification_metals: set[str],
    histag_min_run: int,
) -> bool:
    """Detect the "His-tag binds Ni/Co" IMAC purification-artifact pattern.

    The entry is flagged only when all three conditions hold:

    1. it contains at least one metal-ion component,
    2. every metal-ion component id is in ``purification_metals``, and
    3. at least one protein polymer entity has a poly-His run of at least
       ``histag_min_run`` residues.

    A non-purification metal (e.g. a catalytic Zn) alongside a tag therefore
    prevents the flag. An empty ``purification_metals`` disables the rule.
    """
    if not purification_metals:
        return False

    present_metals = {comp.comp_id for comp in metal_comps(record)}
    if not present_metals:
        return False
    if not present_metals <= purification_metals:
        return False

    for entity in record.polymer_entities:
        if entity.is_protein and has_histag(entity.seq, histag_min_run):
            return True
    return False
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def tier_component(
    comp: NonpolymerComp,
    record: CandidateRecord,
    cfg: Config,
    *,
    blacklist: set[str],
    purification: set[str],
    is_artifact_entry: bool,
) -> tuple[str, str, str | None]:
    """Tier a single non-protein component.

    Args:
        comp: the component to tier.
        record: the entry it belongs to; ``bound_components`` and
            ``affinity_comp_ids`` are read from it.
        cfg: only ``exclude_purification_artifacts`` is consulted here.
        blacklist: component ids treated as crystallization additives.
        purification: component ids of IMAC purification metals.
        is_artifact_entry: whether the whole entry matched the
            His-tag/purification-metal pattern.

    Returns:
        ``(tier, reason, class_label_or_None)``. The class label is set only
        for the ``functional`` tier; ``ambiguous`` and ``artifact`` components
        carry ``None`` so they are reported but not labelled.

    Monatomic counterions are artifacts whether or not they are metals.
    ``COUNTERION_COMPS`` also lists halides (CL, BR, I, F), which never pass
    ``is_metal_ion``, so the counterion check is applied on the non-metal path
    as well. The blacklist is checked first there, so a component in both sets
    (IOD) keeps the ``additive`` reason.
    """
    cid = comp.comp_id
    bound = cid in record.bound_components
    has_affinity = cid in record.affinity_comp_ids

    if is_metal_ion(comp):
        # Lone counterion / phasing atom -> artifact regardless of binding.
        if cid in COUNTERION_COMPS:
            return TIER_ARTIFACT, "counterion", None
        # His-tag/Ni(Co) IMAC purification artifact.
        if is_artifact_entry and cfg.exclude_purification_artifacts and cid in purification:
            return TIER_ARTIFACT, "histag_metal", None
        if has_affinity:
            return TIER_FUNCTIONAL, "metal_affinity", CLASS_METAL
        if bound:
            return TIER_FUNCTIONAL, "metal_bound", CLASS_METAL
        # "Trust biological metals" but require contact: an unbound metal is
        # treated as adventitious -> ambiguous, not functional.
        return TIER_AMBIGUOUS, "metal_unbound", None

    # Non-metal component.
    if cid in blacklist:
        return TIER_ARTIFACT, "additive", None
    # Non-metal counterions (halides) are artifacts too, even when bound.
    if cid in COUNTERION_COMPS:
        return TIER_ARTIFACT, "counterion", None
    if has_affinity:
        return TIER_FUNCTIONAL, "ligand_affinity", CLASS_SMALL_MOLECULE
    if bound:
        return TIER_FUNCTIONAL, "ligand_bound", CLASS_SMALL_MOLECULE
    return TIER_AMBIGUOUS, "ligand_unbound", None
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def classify_components(record: CandidateRecord, cfg: Config) -> dict:
    """Tier every component of one entry and derive its ligand-class labels.

    Nothing is dropped here; the function only annotates. The returned dict
    holds:

    - ``entry_id``: copied from the record
    - ``classes``: sorted class labels backed by a functional-tier component
      (or by a protein/nucleic-acid interface for ``nucleotide``)
    - ``ambiguous_classes``: sorted would-be classes seen only at the
      ambiguous tier (anything already in ``classes`` is removed)
    - ``metals`` / ``small_molecules``: sorted functional component ids
    - ``has_nucleotide``: whether any polymer entity is nucleic
    - ``tiers``: component id (plus ``"nucleic_acid"``) -> tier and reason
    - ``purification_artifact``: the entry-level His-tag/metal flag
    """
    blacklist = DEFAULT_ADDITIVE_BLACKLIST | set(cfg.excluded_het)
    purification = set(cfg.purification_metals)
    is_artifact_entry = is_purification_artifact(
        record,
        purification_metals=purification,
        histag_min_run=cfg.histag_min_run,
    )

    tiers: dict[str, dict[str, str]] = {}
    functional_by_class: dict[str, list[str]] = {CLASS_METAL: [], CLASS_SMALL_MOLECULE: []}
    ambiguous_classes: set[str] = set()

    for comp in record.nonpolymer_comps:
        tier, reason, label = tier_component(
            comp,
            record,
            cfg,
            blacklist=blacklist,
            purification=purification,
            is_artifact_entry=is_artifact_entry,
        )
        tiers[comp.comp_id] = {"tier": tier, "reason": reason}
        if tier == TIER_AMBIGUOUS:
            # Keep the would-be class so it can be reported (not labelled).
            would_be = CLASS_METAL if is_metal_ion(comp) else CLASS_SMALL_MOLECULE
            ambiguous_classes.add(would_be)
        elif tier == TIER_FUNCTIONAL and label in functional_by_class:
            functional_by_class[label].append(comp.comp_id)

    # Nucleotide class = DNA/RNA polymer chains. It is functional only when the
    # record reports a protein/nucleic-acid interface; otherwise the chain is
    # reported as ambiguous. Non-polymer nucleotides (ATP/GTP/NAD) were already
    # handled by the component loop.
    has_nucleotide = any(entity.is_nucleic for entity in record.polymer_entities)
    nucleotide_functional = has_nucleotide and has_protein_na_interface(record)
    if nucleotide_functional:
        tiers["nucleic_acid"] = {"tier": TIER_FUNCTIONAL, "reason": "protein_na_interface"}
    elif has_nucleotide:
        ambiguous_classes.add(CLASS_NUCLEOTIDE)
        tiers["nucleic_acid"] = {"tier": TIER_AMBIGUOUS, "reason": "no_protein_na_interface"}

    classes: set[str] = {cls for cls, ids in functional_by_class.items() if ids}
    if nucleotide_functional:
        classes.add(CLASS_NUCLEOTIDE)

    functional_metals = functional_by_class[CLASS_METAL]
    functional_sms = functional_by_class[CLASS_SMALL_MOLECULE]
    return {
        "entry_id": record.entry_id,
        "classes": sorted(classes),  # functional-tier class labels
        "ambiguous_classes": sorted(ambiguous_classes - classes),
        "metals": sorted(functional_metals),
        "small_molecules": sorted(functional_sms),
        "has_nucleotide": has_nucleotide,
        "tiers": dict(sorted(tiers.items())),
        "purification_artifact": is_artifact_entry,
    }
|
ifsplit/manifest.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
1
|
+
"""Stage 7 - Snapshot lock, manifest, split registry, and verify/stats commands.
|
|
2
|
+
|
|
3
|
+
Artifacts written to the output dir:
|
|
4
|
+
|
|
5
|
+
- ``dataset.lock`` - reproduction anchor: embedded config + canonical
|
|
6
|
+
``candidates.jsonl`` hash + entry-id list. ``verify`` re-enumerates from it and
|
|
7
|
+
reports drift (added/removed entries, hash mismatch) as printed warnings and a non-zero exit code.
|
|
8
|
+
- ``manifest.json`` - small (~KB) provenance record: config, drop log, per-split
|
|
9
|
+
+ per-class counts, cluster/component stats, and a ``files`` index pointing at
|
|
10
|
+
the data files below. No per-entry arrays, so it stays tiny at any scale. Built
|
|
11
|
+
as a pure function of its inputs (no wall-clock fields) -> byte-identical across
|
|
12
|
+
runs of the same config.
|
|
13
|
+
|
|
14
|
+
The split itself is plain lists of PDB ids, each its own file:
|
|
15
|
+
- ``train.json`` / ``val.json`` / ``test.json`` - the entry ids in each split
|
|
16
|
+
(one id per line; grepable and trivially loadable).
|
|
17
|
+
- ``test/<class>_test.json`` - the test ids carrying each functional ligand class
|
|
18
|
+
(``metal`` / ``small_molecule`` / ``nucleotide``), for per-class evaluation.
|
|
19
|
+
|
|
20
|
+
Supporting maps (only needed for sampling / curation, not to read the split):
|
|
21
|
+
- ``clusters.json`` - entry_id -> component key (for cluster-balanced sampling).
|
|
22
|
+
- ``ligands.classes.json`` - entry_id -> functional class labels.
|
|
23
|
+
- ``ligands.tiers.json`` - per-component curation *audit trail* (tier + reason);
|
|
24
|
+
bulky (~24 MB at full-PDB scale), read only by ``fetch`` and curation audits.
|
|
25
|
+
- ``splits.registry.json`` - canonical_key -> split, so a later, larger snapshot
|
|
26
|
+
reuses prior assignments instead of re-hashing (growth stability).
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import json
|
|
32
|
+
import tempfile
|
|
33
|
+
from pathlib import Path
|
|
34
|
+
from typing import Any
|
|
35
|
+
|
|
36
|
+
from . import __version__
|
|
37
|
+
from .config import Config
|
|
38
|
+
|
|
39
|
+
LOCK_SCHEMA = "if-split/lock@1"
|
|
40
|
+
MANIFEST_SCHEMA = "if-split/manifest@2"
|
|
41
|
+
REGISTRY_SCHEMA = "if-split/registry@1"
|
|
42
|
+
TIERS_SCHEMA = "if-split/tiers@1"
|
|
43
|
+
|
|
44
|
+
# Data files written next to manifest.json (referenced by manifest["files"]).
|
|
45
|
+
TIERS_FILENAME = "ligands.tiers.json"
|
|
46
|
+
CLASSES_FILENAME = "ligands.classes.json"
|
|
47
|
+
CLUSTERS_FILENAME = "clusters.json"
|
|
48
|
+
SPLIT_FILES = {"train": "train.json", "val": "val.json", "test": "test.json"}
|
|
49
|
+
TEST_SUBDIR = "test" # per-class test-id lists live here: test/<class>_test.json
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# --------------------------------------------------------------------------- #
|
|
53
|
+
# Lock (Stage 1 reproduction anchor)
|
|
54
|
+
# --------------------------------------------------------------------------- #
|
|
55
|
+
def build_lock(
    cfg: Config,
    *,
    entry_ids: list[str],
    candidates_sha256: str,
    limit: int | None,
) -> dict[str, Any]:
    """Build the lock document in memory; nothing is written to disk.

    The document embeds the canonical config and its hash, the selection
    ``limit``, and a ``candidates`` section with the entry count, the
    candidates hash and the sorted entry-id list.
    """
    candidates_section = {
        "count": len(entry_ids),
        "sha256": candidates_sha256,
        "entry_ids": sorted(entry_ids),
    }
    lock: dict[str, Any] = {
        "lock_schema": LOCK_SCHEMA,
        "dataset_version": cfg.dataset_version,
        "if_split_version": __version__,
        "config_hash": cfg.config_hash(),
        "config": cfg.canonical_dict(),
        "selection": {"limit": limit},
        "candidates": candidates_section,
    }
    return lock
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _write_json(obj: dict[str, Any], path: Path, *, compact: bool = False) -> Path:
    """Serialize ``obj`` to ``path`` as key-sorted JSON and return ``path``.

    Parent directories are created as needed. ``compact=True`` emits the
    tightest separators with no indentation; otherwise the output is
    indented by two spaces. The file always ends with one newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    layout: dict[str, Any] = {"separators": (",", ":")} if compact else {"indent": 2}
    serialized = json.dumps(obj, sort_keys=True, **layout)
    path.write_text(f"{serialized}\n", encoding="utf-8")
    return path
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def write_lock(lock: dict[str, Any], out_dir: str | Path) -> Path:
    """Write the lock document to ``<out_dir>/dataset.lock`` and return its path."""
    target = Path(out_dir) / "dataset.lock"
    return _write_json(lock, target)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def read_lock(path: str | Path) -> dict[str, Any]:
    """Load a lock document from ``path``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    lock_path = Path(path)
    if lock_path.exists():
        raw = lock_path.read_text(encoding="utf-8")
        return json.loads(raw)
    raise FileNotFoundError(f"Lock file not found: {lock_path}")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# --------------------------------------------------------------------------- #
|
|
100
|
+
# Split registry (growth stability)
|
|
101
|
+
# --------------------------------------------------------------------------- #
|
|
102
|
+
def read_registry(path: str | Path) -> dict[str, str]:
    """Return the ``assignments`` map (canonical_key -> split) stored at ``path``.

    A missing file, or a document without an ``assignments`` key, yields ``{}``.
    """
    registry_path = Path(path)
    if registry_path.exists():
        document = json.loads(registry_path.read_text(encoding="utf-8"))
        assignments = document.get("assignments", {})
        return dict(assignments)
    return {}
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def write_registry(cluster_split: dict[str, str], out_dir: str | Path) -> Path:
    """Write the key -> split registry to ``<out_dir>/splits.registry.json``.

    Assignments are stored sorted by key; the written path is returned.
    """
    ordered_assignments = dict(sorted(cluster_split.items()))
    document = {
        "registry_schema": REGISTRY_SCHEMA,
        "assignments": ordered_assignments,
    }
    target = Path(out_dir) / "splits.registry.json"
    return _write_json(document, target)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# --------------------------------------------------------------------------- #
|
|
120
|
+
# Manifest (human-facing, deterministic)
|
|
121
|
+
# --------------------------------------------------------------------------- #
|
|
122
|
+
def build_manifest(
    cfg: Config,
    *,
    candidates_sha256: str,
    n_candidates: int,
    drops: list[dict],
    drop_counts: dict[str, int],
    clusters,
    splits,
    class_map: dict[str, dict],
) -> dict[str, Any]:
    """Assemble the manifest.json document from the build outputs.

    Only aggregate numbers and file pointers are recorded; split membership
    and the per-entry maps live in the files listed under ``"files"``.
    No clock or filesystem state is read here.
    """
    from .split import SPLITS

    # Group entry ids by split, each group sorted.
    members: dict[str, list[str]] = {name: [] for name in SPLITS}
    for entry_id, split_name in splits.entry_split.items():
        members[split_name].append(entry_id)
    members = {name: sorted(ids) for name, ids in members.items()}

    def _tally(entry_ids, field):
        """Count how many of ``entry_ids`` carry each label listed under ``field``."""
        tally: dict[str, int] = {}
        for entry_id in entry_ids:
            labels = class_map.get(entry_id, {}).get(field, [])
            for label in labels:
                tally[label] = tally.get(label, 0) + 1
        return dict(sorted(tally.items()))

    # Functional-tier counts (the test-quality view) and ambiguous counts, so
    # under-/over-confidence stays visible.
    per_split_class_counts = {name: _tally(members[name], "classes") for name in SPLITS}
    per_split_ambiguous_counts = {
        name: _tally(members[name], "ambiguous_classes") for name in SPLITS
    }
    n_artifacts = sum(
        bool(info.get("purification_artifact")) for info in class_map.values()
    )

    # One per-class test-id file is listed for each class present in test.
    test_class_files = {
        cls: f"{TEST_SUBDIR}/{cls}_test.json" for cls in per_split_class_counts["test"]
    }

    filter_section = {
        "kept": len(splits.entry_split),
        "dropped": len(drops),
        "drop_counts": dict(sorted(drop_counts.items())),
    }
    clustering_section = {
        "backend": cfg.clustering_backend,
        "identity": clusters.identity,
        "n_clusters": clusters.n_clusters,
        "n_raw_clusters": clusters.n_raw_clusters,
        "multichain_entries": len(clusters.multichain_entries),
        "unclustered_entries": len(clusters.unclustered_entries),
    }
    splits_section = {
        "entry_counts": dict(sorted(splits.counts.items())),
        "cluster_counts": dict(sorted(splits.cluster_counts.items())),
        "per_split_class_counts": per_split_class_counts,
        "per_split_ambiguous_counts": per_split_ambiguous_counts,
        "test_minimum_shortfalls": dict(sorted(splits.minimum_shortfalls.items())),
    }
    # Pointers to the data files written alongside the manifest.
    files_section = {
        "splits": dict(SPLIT_FILES),
        "test_by_class": dict(sorted(test_class_files.items())),
        "clusters": CLUSTERS_FILENAME,
        "ligand_classes": CLASSES_FILENAME,
        "ligand_tiers": TIERS_FILENAME,
    }

    return {
        "manifest_schema": MANIFEST_SCHEMA,
        "dataset_version": cfg.dataset_version,
        "if_split_version": __version__,
        "config_hash": cfg.config_hash(),
        "config": cfg.canonical_dict(),
        "candidates": {"count": n_candidates, "sha256": candidates_sha256},
        "filter": filter_section,
        "clustering": clustering_section,
        "splits": splits_section,
        "ligands": {"n_purification_artifacts": n_artifacts},
        "files": files_section,
    }
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
# --------------------------------------------------------------------------- #
|
|
205
|
+
# Split data files (the actual lists of PDB ids + supporting maps)
|
|
206
|
+
# --------------------------------------------------------------------------- #
|
|
207
|
+
def _write_id_list(ids: list[str], path: Path) -> Path:
    """Write ``ids`` to ``path`` as a JSON array with one element per line.

    Parent directories are created as needed. An empty list is written as
    ``[]``. Returns ``path``.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if ids:
        rows = ",\n".join(map(json.dumps, ids))
        payload = f"[\n{rows}\n]\n"
    else:
        payload = "[]\n"
    path.write_text(payload, encoding="utf-8")
    return path
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def write_split_files(splits, class_map: dict[str, dict], out_dir: str | Path) -> dict[str, Path]:
    """Write the train/val/test id lists and the per-class test id lists.

    Each file named in ``SPLIT_FILES`` receives the sorted entry ids of its
    split. For every class listed under ``"classes"`` for at least one test
    entry, ``<TEST_SUBDIR>/<class>_test.json`` receives the sorted test ids
    carrying that class.

    Returns a map of everything written: split name -> path, plus
    ``"test:<class>"`` -> path for the per-class lists.
    """
    root = Path(out_dir)

    buckets: dict[str, list[str]] = {name: [] for name in SPLIT_FILES}
    for entry_id, split_name in splits.entry_split.items():
        buckets[split_name].append(entry_id)

    written: dict[str, Path] = {}
    for split_name, filename in SPLIT_FILES.items():
        buckets[split_name].sort()
        written[split_name] = _write_id_list(buckets[split_name], root / filename)

    # Per-class test-id lists: test entries carrying each functional class.
    ids_by_class: dict[str, list[str]] = {}
    for entry_id in buckets["test"]:
        for label in class_map.get(entry_id, {}).get("classes", []):
            ids_by_class.setdefault(label, []).append(entry_id)
    for label, class_ids in ids_by_class.items():
        target = root / TEST_SUBDIR / f"{label}_test.json"
        written[f"test:{label}"] = _write_id_list(sorted(class_ids), target)

    return written
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def write_clusters(entry_to_cluster: dict[str, str], out_dir: str | Path) -> Path:
    """Write the entry_id -> component key map as compact JSON.

    The map is stored sorted by entry id in ``<out_dir>/<CLUSTERS_FILENAME>``;
    the written path is returned.
    """
    ordered = dict(sorted(entry_to_cluster.items()))
    document = {
        "clusters_schema": "if-split/clusters@1",
        "entry_clusters": ordered,
    }
    target = Path(out_dir) / CLUSTERS_FILENAME
    return _write_json(document, target, compact=True)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def read_clusters(path: str | Path) -> dict[str, str]:
    """Return the ``entry_clusters`` map stored at ``path``.

    A missing file, or a document without that key, yields ``{}``.
    """
    clusters_path = Path(path)
    if clusters_path.exists():
        document = json.loads(clusters_path.read_text(encoding="utf-8"))
        return dict(document.get("entry_clusters", {}))
    return {}
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def write_classes(class_map: dict[str, dict], out_dir: str | Path) -> Path:
    """Write the entry_id -> functional class labels map as compact JSON.

    Every entry in ``class_map`` must carry a ``"classes"`` key. Entries are
    stored sorted by id in ``<out_dir>/<CLASSES_FILENAME>``; the written path
    is returned.
    """
    labels_by_entry = {}
    for entry_id, info in sorted(class_map.items()):
        labels_by_entry[entry_id] = info["classes"]
    document = {"classes_schema": "if-split/classes@1", "classes": labels_by_entry}
    target = Path(out_dir) / CLASSES_FILENAME
    return _write_json(document, target, compact=True)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def read_classes(path: str | Path) -> dict[str, list[str]]:
    """Return the ``classes`` map stored at ``path``.

    A missing file, or a document without that key, yields ``{}``.
    """
    classes_path = Path(path)
    if classes_path.exists():
        document = json.loads(classes_path.read_text(encoding="utf-8"))
        return dict(document.get("classes", {}))
    return {}
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def read_id_list(path: str | Path) -> list[str]:
    """Load a split id-list file (e.g. ``train.json``) as a list of ids.

    A missing file yields an empty list.
    """
    list_path = Path(path)
    if list_path.exists():
        parsed = json.loads(list_path.read_text(encoding="utf-8"))
        return list(parsed)
    return []
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def write_manifest(manifest: dict[str, Any], out_dir: str | Path) -> Path:
    """Write the manifest to ``<out_dir>/manifest.json`` and return its path.

    The non-compact (indented) layout is used because the manifest is small
    and meant to be read by people.
    """
    target = Path(out_dir) / "manifest.json"
    return _write_json(manifest, target)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def read_manifest(path: str | Path) -> dict[str, Any]:
    """Load a manifest document from ``path``.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    manifest_path = Path(path)
    if manifest_path.exists():
        raw = manifest_path.read_text(encoding="utf-8")
        return json.loads(raw)
    raise FileNotFoundError(f"Manifest not found: {manifest_path}")
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# --------------------------------------------------------------------------- #
|
|
295
|
+
# Ligand-tier audit sidecar (bulky; off the load path)
|
|
296
|
+
# --------------------------------------------------------------------------- #
|
|
297
|
+
def build_tiers_doc(class_map: dict[str, dict]) -> dict[str, Any]:
    """Collect each entry's per-component tiers into the audit-trail document.

    Entries are emitted sorted by id. An entry without a ``"tiers"`` key
    contributes an empty mapping. The result depends only on ``class_map``.
    """
    tiers_by_entry: dict[str, Any] = {}
    for entry_id, info in sorted(class_map.items()):
        tiers_by_entry[entry_id] = info.get("tiers", {})
    return {"tiers_schema": TIERS_SCHEMA, "tiers": tiers_by_entry}
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def write_tiers(doc: dict[str, Any], out_dir: str | Path) -> Path:
    """Write the tiers document as compact JSON to ``<out_dir>/<TIERS_FILENAME>``."""
    target = Path(out_dir) / TIERS_FILENAME
    return _write_json(doc, target, compact=True)
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def read_tiers(path: str | Path) -> dict[str, dict]:
    """Return the ``tiers`` map stored in a sidecar file.

    A missing file, or a document without that key, yields ``{}``.
    """
    tiers_path = Path(path)
    if tiers_path.exists():
        document = json.loads(tiers_path.read_text(encoding="utf-8"))
        return dict(document.get("tiers", {}))
    return {}
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
# --------------------------------------------------------------------------- #
|
|
322
|
+
# verify / stats commands
|
|
323
|
+
# --------------------------------------------------------------------------- #
|
|
324
|
+
def verify_lock(lock_path: str | Path, *, client=None) -> int:
    """Re-run enumeration from a lock's embedded config and report any drift.

    The lock's config and selection limit drive a fresh enumeration into a
    temporary directory. The resulting entry ids and candidates hash are
    compared with the locked ones, and the findings are printed.

    Args:
        lock_path: path to the lock file.
        client: forwarded to ``enumerate_candidates``; injectable for
            offline testing, ``None`` in production.

    Returns:
        A process exit code: 0 when the hash and entry set both match,
        1 when drift was detected.
    """
    from .enumerate import enumerate_candidates

    lock = read_lock(lock_path)
    if lock.get("lock_schema") != LOCK_SCHEMA:
        print(f"warning: unexpected lock_schema {lock.get('lock_schema')!r}")

    cfg = Config.model_validate(lock["config"])
    limit = (lock.get("selection") or {}).get("limit")
    locked = lock["candidates"]
    locked_ids = set(locked["entry_ids"])
    locked_sha = locked["sha256"]

    print(f"verifying {lock['dataset_version']} (config {cfg.config_hash()})")
    print(f" locked: {locked['count']} entries, candidates sha256={locked_sha[:12]}...")

    def _echo(m):
        """Progress callback: print each message with a one-space indent."""
        print(f" {m}")

    with tempfile.TemporaryDirectory() as tmp:
        records, _, sha = enumerate_candidates(
            cfg, tmp, limit=limit, client=client, progress=_echo
        )

    current_ids = {r.entry_id for r in records}
    added = sorted(current_ids - locked_ids)
    removed = sorted(locked_ids - current_ids)  # obsoleted / withdrawn
    entry_set_changed = bool(added or removed)

    if sha == locked_sha and not entry_set_changed:
        print(f"OK: reproduced exactly ({len(records)} entries, hashes match).")
        return 0

    print("DRIFT detected:")
    if sha != locked_sha:
        print(f" candidates sha256 differs: now {sha[:12]}... vs locked {locked_sha[:12]}...")
    if removed:
        print(f" {len(removed)} entries no longer present (obsoleted/withdrawn):")
        print(f" {', '.join(removed[:20])}{' ...' if len(removed) > 20 else ''}")
    if added:
        print(f" {len(added)} new entries match the snapshot filters:")
        print(f" {', '.join(added[:20])}{' ...' if len(added) > 20 else ''}")
    if not entry_set_changed:
        print(" entry set unchanged, but per-entry metadata changed (see hash).")
    return 1
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def summarize_manifest(manifest_path: str | Path) -> int:
    """`stats` command: print a human-readable summary of a manifest.

    Prints the dataset version and config hash, filter counts with drop
    reasons, clustering stats, per-split entry/component counts, functional
    and (if any) ambiguous per-class test counts, test-minimum shortfalls
    (if any), the purification-artifact count and the data-file names.
    Always returns 0.
    """
    m = read_manifest(manifest_path)
    print(f"{m['dataset_version']} (config {m['config_hash']})")

    flt = m["filter"]
    print(
        f" candidates: {m['candidates']['count']} kept: {flt['kept']} dropped: {flt['dropped']}"
    )
    for reason, n in flt["drop_counts"].items():
        print(f" - {reason}: {n}")

    cl = m["clustering"]
    print(
        f" clustering: {cl['backend']} @ {cl['identity']}% "
        f"components={cl['n_clusters']} (from {cl.get('n_raw_clusters', '?')} raw) "
        f"multichain={cl['multichain_entries']}"
    )

    sp = m["splits"]
    print(" splits (entries / components):")
    for s in ("train", "val", "test"):
        ec = sp["entry_counts"].get(s, 0)
        cc = sp["cluster_counts"].get(s, 0)
        print(f" {s:5s}: {ec:>7} entries {cc:>7} components")

    print(" test set by ligand class (functional tier):")
    functional_test_counts = sp["per_split_class_counts"].get("test", {})
    for cls, n in functional_test_counts.items():
        print(f" {cls}: {n}")

    ambiguous_test_counts = sp.get("per_split_ambiguous_counts", {}).get("test", {})
    if ambiguous_test_counts:
        print(" test set ambiguous (reported, not labelled):")
        for cls, n in ambiguous_test_counts.items():
            print(f" {cls}: {n}")

    shortfalls = sp.get("test_minimum_shortfalls", {})
    if shortfalls:
        print(" test minimum shortfalls (floor exceeded available supply):")
        for cls, n in shortfalls.items():
            print(f" {cls}: short by {n}")

    lig = m["ligands"]
    # Prefer the stored count; when it is absent, fall back to the length of a
    # "purification_artifacts" list if the manifest carries one.
    n_arts = lig.get("n_purification_artifacts", len(lig.get("purification_artifacts", [])))
    print(f" His-tag/Ni purification artifacts flagged: {n_arts}")

    files = m.get("files", {})
    if not files:
        return 0
    sf = files.get("splits", {})
    tbc = files.get("test_by_class", {})
    print(f" split files: {', '.join(sf.values())}")
    if tbc:
        print(f" per-class test files: {', '.join(tbc.values())}")
    return 0
|