if-split 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- if_split-0.1.0.dist-info/METADATA +312 -0
- if_split-0.1.0.dist-info/RECORD +20 -0
- if_split-0.1.0.dist-info/WHEEL +4 -0
- if_split-0.1.0.dist-info/entry_points.txt +2 -0
- if_split-0.1.0.dist-info/licenses/LICENSE +21 -0
- ifsplit/__init__.py +8 -0
- ifsplit/__main__.py +8 -0
- ifsplit/cli.py +317 -0
- ifsplit/cluster.py +130 -0
- ifsplit/config.py +146 -0
- ifsplit/dataset.py +112 -0
- ifsplit/download.py +229 -0
- ifsplit/enumerate.py +111 -0
- ifsplit/hydrate.py +216 -0
- ifsplit/ligands.py +267 -0
- ifsplit/manifest.py +417 -0
- ifsplit/parse.py +111 -0
- ifsplit/rcsb.py +251 -0
- ifsplit/schema.py +241 -0
- ifsplit/split.py +177 -0
ifsplit/dataset.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Stage 8 - Thin loader over a manifest, with cluster-balanced sampling.
|
|
2
|
+
|
|
3
|
+
Reads ``manifest.json`` and exposes train/val/test views (entry ids, ligand
|
|
4
|
+
classes, and the entry->cluster map). Coordinate/featurization loading is
|
|
5
|
+
intentionally *not* here - it's the optional downstream concern (PLAN.md §1.5);
|
|
6
|
+
a model repo plugs its own featurizer onto these entry lists.
|
|
7
|
+
|
|
8
|
+
The PDB is heavily redundant (thousands of near-identical lysozyme / kinase
|
|
9
|
+
co-crystals). Training by sampling entries uniformly drowns the model in
|
|
10
|
+
over-represented folds. ``sample_by_cluster`` draws one entry per sequence
|
|
11
|
+
cluster per epoch, which is the bigger training-quality lever than ligand tiering
|
|
12
|
+
- and it is free here because the clusters are already computed. Sampling is
|
|
13
|
+
deterministic given a seed (no global RNG), so an epoch is reproducible.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import hashlib
from dataclasses import dataclass
from pathlib import Path

from .manifest import (
    CLASSES_FILENAME,
    CLUSTERS_FILENAME,
    SPLIT_FILES,
    read_classes,
    read_clusters,
    read_id_list,
    read_manifest,
)
|
|
23
|
+
|
|
24
|
+
# Split names this loader accepts; ``IFSplitDataset.split`` raises KeyError for
# any other name.
SPLITS = ("train", "val", "test")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _stable_rank(key: str, seed: int) -> int:
|
|
28
|
+
"""Deterministic pseudo-random rank for ``key`` under ``seed`` (no global RNG)."""
|
|
29
|
+
digest = hashlib.blake2b(f"{seed}:{key}".encode(), digest_size=8).digest()
|
|
30
|
+
return int.from_bytes(digest, "big")
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class SplitView:
    """Entry ids of one split plus the per-entry metadata needed to use them.

    ``ligand_classes`` and ``entry_clusters`` are keyed by entry id. An entry
    absent from ``entry_clusters`` is treated as its own singleton cluster.
    """

    name: str
    entry_ids: list[str]
    ligand_classes: dict[str, list[str]]  # entry_id -> classes
    entry_clusters: dict[str, str]  # entry_id -> cluster key

    def __len__(self) -> int:
        """Number of entry ids in the split."""
        return len(self.entry_ids)

    def with_class(self, cls: str) -> list[str]:
        """Entry ids (in split order) whose ligand classes include ``cls``."""
        tagged: list[str] = []
        for entry_id in self.entry_ids:
            if cls in self.ligand_classes.get(entry_id, []):
                tagged.append(entry_id)
        return tagged

    def cluster_groups(self) -> dict[str, list[str]]:
        """Cluster key -> sorted member ids; keys are returned in sorted order."""
        buckets: dict[str, list[str]] = {}
        for entry_id in self.entry_ids:
            cluster = self.entry_clusters.get(entry_id, entry_id)
            if cluster not in buckets:
                buckets[cluster] = []
            buckets[cluster].append(entry_id)
        return {cluster: sorted(buckets[cluster]) for cluster in sorted(buckets)}

    def sample_by_cluster(self, seed: int = 0) -> list[str]:
        """Pick one representative entry per cluster, reproducibly for ``seed``.

        Each sequence cluster contributes exactly one member: the one with the
        lowest seeded rank (ties broken by entry id). Over-represented folds
        therefore cannot dominate an epoch. Changing ``seed`` between epochs
        rotates which member is drawn. The result is ordered by the seeded rank
        of each cluster key, so it is deterministic as well.
        """

        def member_rank(entry_id: str) -> tuple[int, str]:
            return (_stable_rank(entry_id, seed), entry_id)

        ranked: list[tuple[int, str]] = []
        for cluster, members in self.cluster_groups().items():
            representative = min(members, key=member_rank)
            ranked.append((_stable_rank(cluster, seed), representative))
        ranked.sort()
        return [entry_id for _rank, entry_id in ranked]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class IFSplitDataset:
    """Read-only view over a built manifest's train/val/test partition.

    Split membership and the supporting maps (ligand classes, entry->cluster)
    live in sidecar files next to ``manifest.json``. Their names come from the
    manifest's ``files`` section. When that section omits a name, the loader
    falls back to the shared ``ifsplit.manifest`` defaults, the same ones the
    hydration stage uses.
    """

    def __init__(self, manifest_path: str | Path) -> None:
        """Load the manifest at ``manifest_path`` and its class/cluster sidecars."""
        self._m = read_manifest(manifest_path)
        self._dir = Path(manifest_path).parent
        self.dataset_version: str = self._m["dataset_version"]
        self.config_hash: str = self._m["config_hash"]
        files = self._m.get("files", {})
        # Use the shared manifest constants, not local literals, so this loader
        # and ``hydrate`` always resolve the same sidecar paths.
        self._split_files: dict[str, str] = files.get("splits", SPLIT_FILES)
        self._classes = read_classes(
            self._dir / files.get("ligand_classes", CLASSES_FILENAME)
        )
        self._entry_clusters = read_clusters(
            self._dir / files.get("clusters", CLUSTERS_FILENAME)
        )

    def split(self, name: str) -> SplitView:
        """Build a ``SplitView`` for ``name`` from its id-list file.

        The id list is re-read on every call. Entries without a cluster
        assignment map to themselves (singleton cluster).

        Raises:
            KeyError: if ``name`` is not one of ``SPLITS``.
        """
        if name not in SPLITS:
            raise KeyError(f"unknown split {name!r}; expected one of {SPLITS}")
        fname = self._split_files.get(name, f"{name}.json")
        ids = read_id_list(self._dir / fname)
        return SplitView(
            name=name,
            entry_ids=ids,
            ligand_classes={e: self._classes.get(e, []) for e in ids},
            entry_clusters={e: self._entry_clusters.get(e, e) for e in ids},
        )

    @property
    def train(self) -> SplitView:
        """The ``train`` split (re-read on each access)."""
        return self.split("train")

    @property
    def val(self) -> SplitView:
        """The ``val`` split (re-read on each access)."""
        return self.split("val")

    @property
    def test(self) -> SplitView:
        """The ``test`` split (re-read on each access)."""
        return self.split("test")
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def load_dataset(manifest_path: str | Path) -> IFSplitDataset:
    """Open the built manifest at ``manifest_path``.

    Convenience alias for constructing ``IFSplitDataset`` directly.
    """
    dataset = IFSplitDataset(manifest_path)
    return dataset
|
ifsplit/download.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Stage 2 - Optional structure hydration (the only stage that touches coordinates).
|
|
2
|
+
|
|
3
|
+
`build` never calls this. `fetch` materializes the mmCIF files for a *built*
|
|
4
|
+
manifest into an MLOps-friendly tree that anyone can pick up and train on:
|
|
5
|
+
|
|
6
|
+
<root>/
|
|
7
|
+
structures/
|
|
8
|
+
train/ hh/4hhb-assembly1.cif.gz # split-partitioned (browsable),
|
|
9
|
+
val/ ... # sharded by the PDB middle-two
|
|
10
|
+
test/ ab/1abc-assembly1.cif.gz # chars (PDB "divided" scheme)
|
|
11
|
+
index.jsonl # one row/structure (zero-dep)
|
|
12
|
+
index.parquet # same, columnar (if pyarrow present)
|
|
13
|
+
manifest.json # copy of the source split manifest
|
|
14
|
+
DATASET_CARD.md # provenance + how-to-load
|
|
15
|
+
|
|
16
|
+
Design choices that make it "pristine":
|
|
17
|
+
- **Content-addressed integrity:** every file's SHA-256 is recorded in the index,
|
|
18
|
+
so a re-fetch / transfer can be verified against it, and the pull is resumable
(files already present at their destination are reused rather than re-downloaded).
|
|
20
|
+
- **Deterministic paths:** path is a pure function of (split, entry_id, assembly),
|
|
21
|
+
so two people fetching the same manifest get byte-identical trees.
|
|
22
|
+
- **Explicit scope:** the caller must choose --split / --all; large pulls require
|
|
23
|
+
--yes. No accidental terabyte (the lightweight-by-default contract).
|
|
24
|
+
- **No coordinates in the split itself:** this is downstream of `build`; the
|
|
25
|
+
manifest + lock remain tiny and coordinate-free.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
from __future__ import annotations
|
|
29
|
+
|
|
30
|
+
import gzip
|
|
31
|
+
import hashlib
|
|
32
|
+
import time
|
|
33
|
+
from collections.abc import Callable, Iterable
|
|
34
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
35
|
+
from dataclasses import dataclass, field
|
|
36
|
+
from pathlib import Path
|
|
37
|
+
|
|
38
|
+
import httpx
|
|
39
|
+
|
|
40
|
+
from . import __version__
|
|
41
|
+
|
|
42
|
+
# Base URL that ``url_for`` appends download filenames to.
FILE_BASE = "https://files.rcsb.org/download"
# Canonical split names (also imported by the hydration stage).
SPLITS = ("train", "val", "test")
# Callback that receives human-readable progress lines.
ProgressFn = Callable[[str], None]

# RCSB "divided" sharding: middle two characters of the 4-char core id.
# 4HHB -> "hh"; for extended ids (pdb_0000XXXX) we shard on the core's chars 2-3.
# ``core_id`` strips this prefix from extended ids.
_EXTENDED_PREFIX = "pdb_0000"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def core_id(entry_id: str) -> str:
    """Lowercase 'core' id used for filenames/sharding (legacy or extended).

    ``pdb_0000xxxx`` -> ``xxxx``. Any other ``pdb_``-prefixed id loses only the
    ``pdb_`` prefix. Legacy ids such as ``4HHB`` are simply lowercased.
    """
    lowered = entry_id.lower()
    extended_len = len(_EXTENDED_PREFIX)
    # The full extended prefix is stripped only when something follows it.
    if len(lowered) > extended_len and lowered.startswith(_EXTENDED_PREFIX):
        return lowered[extended_len:]
    short_prefix = "pdb_"
    if lowered.startswith(short_prefix):
        return lowered[len(short_prefix):]
    return lowered
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def shard_for(entry_id: str) -> str:
    """Two-character shard directory for ``entry_id`` (PDB "divided" scheme).

    Uses characters 2-3 of the core id (``4hhb`` -> ``hh``). Cores shorter than
    three characters fall back to a stable 2-hex-char BLAKE2b bucket.
    """
    core = core_id(entry_id)
    if len(core) < 3:
        return hashlib.blake2b(core.encode(), digest_size=1).hexdigest()
    return core[1:3]
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def filename_for(entry_id: str, *, assembly: bool) -> str:
    """RCSB gzipped-mmCIF filename: assembly 1 or the asymmetric unit."""
    stem = core_id(entry_id)
    if assembly:
        return stem + "-assembly1.cif.gz"
    return stem + ".cif.gz"
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def url_for(entry_id: str, *, assembly: bool) -> str:
    """Absolute RCSB download URL for ``entry_id``."""
    name = filename_for(entry_id, assembly=assembly)
    return "/".join((FILE_BASE, name))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def rel_path_for(entry_id: str, split: str, *, assembly: bool) -> Path:
    """Root-relative destination: ``structures/<split>/<shard>/<filename>``."""
    shard = shard_for(entry_id)
    filename = filename_for(entry_id, assembly=assembly)
    return Path("structures", split, shard, filename)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class FetchResult:
    """Outcome of one ``StructureFetcher.fetch`` run.

    Each target's entry id lands in exactly one of ``fetched``, ``skipped`` or
    ``failed``. ``index_rows`` holds one row per successful (fetched or skipped)
    target.
    """

    fetched: list[str] = field(default_factory=list)  # downloaded during this run
    skipped: list[str] = field(default_factory=list)  # already on disk; reused, not re-downloaded
    failed: list[tuple[str, str]] = field(default_factory=list)  # (entry_id, reason)
    index_rows: list[dict] = field(default_factory=list)  # per-file rows; sorted by (split, entry_id) at the end of fetch
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _sha256_file(path: Path) -> str:
|
|
93
|
+
h = hashlib.sha256()
|
|
94
|
+
with path.open("rb") as fh:
|
|
95
|
+
for chunk in iter(lambda: fh.read(1 << 20), b""):
|
|
96
|
+
h.update(chunk)
|
|
97
|
+
return h.hexdigest()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class StructureFetcher:
    """Polite, resumable mmCIF downloader (assembly 1 or asymmetric unit).

    One ``httpx.Client`` (fixed User-Agent, redirects followed) is shared by
    the worker threads of :meth:`fetch`. Use the fetcher as a context manager,
    or call :meth:`close`, to release it.
    """

    def __init__(
        self,
        *,
        assembly: bool = True,
        workers: int = 8,
        timeout: float = 120.0,
        max_retries: int = 4,
        backoff_base: float = 1.5,
        sleep=time.sleep,
    ) -> None:
        """Configure the fetcher.

        Args:
            assembly: download biological assembly 1 (True) or the asymmetric
                unit (False).
            workers: thread-pool size used by :meth:`fetch`.
            timeout: request timeout in seconds, passed to ``httpx.Client``.
            max_retries: extra attempts after the first for retryable failures.
            backoff_base: sleep ``backoff_base**attempt`` seconds between attempts.
            sleep: injectable sleep function (e.g. a no-op in tests).
        """
        self.assembly = assembly
        self.workers = workers
        self._timeout = timeout
        self._max_retries = max_retries
        self._backoff_base = backoff_base
        self._sleep = sleep
        self._client = httpx.Client(
            timeout=timeout,
            headers={"User-Agent": f"IF-Split/{__version__} (structure fetch)"},
            follow_redirects=True,
        )

    def __enter__(self) -> StructureFetcher:
        return self

    def __exit__(self, *exc) -> None:
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def _get(self, url: str) -> bytes:
        """GET ``url`` and return the body, retrying transient failures.

        Transport errors and any status other than 200/404 are retried up to
        ``max_retries`` times, sleeping ``backoff_base**attempt`` seconds between
        attempts.

        Raises:
            FileNotFoundError: on a 404 (permanent, not retried).
            RuntimeError: when all attempts fail; chained to the last error.
        """
        last: Exception | None = None
        for attempt in range(self._max_retries + 1):
            try:
                resp = self._client.get(url)
            except httpx.HTTPError as exc:
                last = exc
            else:
                if resp.status_code == 200:
                    return resp.content
                if resp.status_code == 404:
                    raise FileNotFoundError(f"not on RCSB (404): {url}")
                last = RuntimeError(f"HTTP {resp.status_code}: {url}")
            if attempt < self._max_retries:
                self._sleep(self._backoff_base**attempt)
        raise RuntimeError(f"download failed after retries: {url} ({last})") from last

    def estimate_bytes(self, entry_ids: list[str], sample: int = 12) -> int | None:
        """Rough total-size estimate from HEAD on a sample (None if unavailable).

        Averages the ``content-length`` of up to ``sample`` successful HEAD
        responses and scales the average by ``len(entry_ids)``.
        """
        sizes: list[int] = []
        for eid in entry_ids[:sample]:
            try:
                r = self._client.head(url_for(eid, assembly=self.assembly))
                cl = r.headers.get("content-length")
                if r.status_code == 200 and cl:
                    sizes.append(int(cl))
            except httpx.HTTPError:
                continue
        if not sizes:
            return None
        avg = sum(sizes) / len(sizes)
        return int(avg * len(entry_ids))

    @staticmethod
    def _is_intact_gzip(path: Path) -> bool:
        """True if ``path`` decompresses cleanly as gzip (CRC and length checked)."""
        import zlib  # only needed to recognise corrupt-deflate errors

        try:
            with gzip.open(path, "rb") as fh:
                while fh.read(1 << 20):
                    pass
        except (OSError, EOFError, zlib.error):
            # BadGzipFile is an OSError; truncation raises EOFError.
            return False
        return True

    def _fetch_one(self, entry_id: str, split: str, root: Path) -> dict:
        """Materialize one structure under ``root`` and return its index row.

        A file already at the destination is reused (status ``"skipped"``) only
        if it is still a complete, valid gzip stream. A truncated or corrupt
        file is downloaded again instead of being indexed as good. A new
        download is gzip-validated, written to a ``.part`` file and renamed into
        place, so the final path never holds a half-written file.
        """
        rel = rel_path_for(entry_id, split, assembly=self.assembly)
        dest = root / rel
        if dest.exists() and self._is_intact_gzip(dest):
            return {
                "entry_id": entry_id,
                "split": split,
                "path": str(rel),
                "sha256": _sha256_file(dest),
                "status": "skipped",
            }
        dest.parent.mkdir(parents=True, exist_ok=True)
        data = self._get(url_for(entry_id, assembly=self.assembly))
        gzip.decompress(data)  # integrity check: must be a complete, valid gzip
        tmp = dest.with_suffix(dest.suffix + ".part")
        try:
            tmp.write_bytes(data)
            tmp.replace(dest)
        except OSError:
            tmp.unlink(missing_ok=True)  # don't leave stray partial files behind
            raise
        return {
            "entry_id": entry_id,
            "split": split,
            "path": str(rel),
            "sha256": hashlib.sha256(data).hexdigest(),
            "status": "fetched",
        }

    def fetch(
        self,
        targets: Iterable[tuple[str, str]],  # (entry_id, split)
        root: Path,
        *,
        progress: ProgressFn | None = None,
    ) -> FetchResult:
        """Fetch all ``(entry_id, split)`` targets under ``root`` concurrently.

        Per-target exceptions are recorded in ``FetchResult.failed`` rather
        than raised. Progress is reported every 100 completions and at the end.
        ``index_rows`` is returned sorted by ``(split, entry_id)``.
        """
        targets = list(targets)
        result = FetchResult()
        done = 0
        total = len(targets)

        def say(msg: str) -> None:
            if progress:
                progress(msg)

        with ThreadPoolExecutor(max_workers=self.workers) as pool:
            futs = {
                pool.submit(self._fetch_one, eid, split, root): (eid, split)
                for eid, split in targets
            }
            for fut in as_completed(futs):
                eid, _split = futs[fut]
                done += 1
                try:
                    row = fut.result()
                except Exception as exc:
                    result.failed.append((eid, str(exc)))
                else:
                    result.index_rows.append(row)
                    (result.skipped if row["status"] == "skipped" else result.fetched).append(eid)
                if done % 100 == 0 or done == total:
                    say(
                        f"{done}/{total} ({len(result.fetched)} new, "
                        f"{len(result.skipped)} cached, {len(result.failed)} failed)"
                    )
        result.index_rows.sort(key=lambda r: (r["split"], r["entry_id"]))
        return result
|
ifsplit/enumerate.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Stage 1 - Enumerate candidates from RCSB (Search v2 + Data API).
|
|
2
|
+
|
|
3
|
+
Selects entries by ``release_date <= snapshot_date`` plus the method/resolution
|
|
4
|
+
filters, enriches each via the Data API (sequences, ligand comps, residue
|
|
5
|
+
counts, assemblies), and writes the byte-stable ``candidates.jsonl`` -- the
|
|
6
|
+
snapshot definition. No coordinates are downloaded (PLAN.md §1.5).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import sys
|
|
12
|
+
import time
|
|
13
|
+
from collections.abc import Callable
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
|
|
16
|
+
from .config import Config
|
|
17
|
+
from .rcsb import RcsbClient
|
|
18
|
+
from .schema import CandidateRecord, canonical_jsonl_bytes, sha256_hex
|
|
19
|
+
|
|
20
|
+
# Callback that receives human-readable progress lines.
ProgressFn = Callable[[str], None]

# How often (in enriched records) to emit a progress line.
_REPORT_EVERY = 1000
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def fmt_duration(seconds: float) -> str:
    """Compact duration string: ``45s``, ``3m29s`` or ``1h02m``.

    Negative inputs clamp to zero, and fractional seconds are truncated.
    """
    total = int(max(0, seconds))
    minutes, secs = divmod(total, 60)
    hours, minutes = divmod(minutes, 60)
    if hours:
        return f"{hours}h{minutes:02d}m"
    if minutes:
        return f"{minutes}m{secs:02d}s"
    return f"{secs}s"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def progress_line(label: str, n: int, total: int, t0: float) -> str:
    """Format ``label: n/total (pct)``, adding rate and ETA once work has started.

    ``t0`` must be a ``time.monotonic()`` timestamp taken when the work began.
    """
    elapsed = time.monotonic() - t0
    pct = 100 * n / total if total else 0.0
    parts = [f"{label}: {n}/{total} ({pct:.0f}%)"]
    if n > 0 and elapsed > 0:
        rate = n / elapsed
        eta = 0.0 if rate <= 0 else (total - n) / rate
        parts.append(f"{rate:.0f}/s eta {fmt_duration(eta)}")
    return " ".join(parts)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def make_console_progress(stream=None) -> ProgressFn:
    """Return a progress callback that prints timestamped lines to ``stream``.

    ``stream`` defaults to ``sys.stdout`` (resolved when this factory is
    called). Every line is flushed immediately. When stdout is redirected to a
    file, Python block-buffers it, and unflushed progress would stay invisible
    until the process exits.
    """
    sink = stream if stream else sys.stdout

    def emit(msg: str) -> None:
        stamp = time.strftime("%H:%M:%S")
        print(f" [{stamp}] {msg}", file=sink, flush=True)

    return emit
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def enumerate_candidates(
    cfg: Config,
    out_dir: str | Path,
    *,
    limit: int | None = None,
    client: RcsbClient | None = None,
    progress: ProgressFn | None = None,
) -> tuple[list[CandidateRecord], Path, str]:
    """Run Stage 1.

    Searches RCSB for entry ids matching ``cfg`` (optionally capped by
    ``limit``) and enriches each one via the Data API into a
    ``CandidateRecord``. The records are then written to ``out_dir`` as
    ``candidates.jsonl`` in canonical (byte-stable) form. A client created here
    is closed before returning; a caller-supplied ``client`` is left open.

    Returns ``(records, candidates_path, sha256)``, with ``records`` sorted by
    ``entry_id``.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    def say(msg: str) -> None:
        if progress:
            progress(msg)

    owns_client = client is None
    client = client or RcsbClient()
    try:
        ids = client.search_entry_ids(cfg, limit=limit, progress=progress)
        say(f"search: {len(ids)} entries match the snapshot")

        records: list[CandidateRecord] = []
        total = len(ids)
        t0 = time.monotonic()
        for raw in client.fetch_entries(ids):
            records.append(CandidateRecord.from_data_api(raw))
            if len(records) % _REPORT_EVERY == 0:
                say(progress_line("enriched", len(records), total, t0))
        say(progress_line("enriched", len(records), total, t0) + " (done)")
    finally:
        if owns_client:
            client.close()

    data = canonical_jsonl_bytes(records)
    sha = sha256_hex(data)
    candidates_path = out_dir / "candidates.jsonl"
    # Write to a sibling temp file, then rename it into place. An interrupted
    # run must not leave a truncated candidates.jsonl (the snapshot definition)
    # behind.
    tmp_path = candidates_path.with_name(candidates_path.name + ".part")
    try:
        tmp_path.write_bytes(data)
        tmp_path.replace(candidates_path)
    except OSError:
        tmp_path.unlink(missing_ok=True)
        raise
    say(f"wrote {candidates_path} ({len(records)} records, sha256={sha[:12]}...)")

    # Return in canonical (entry_id-sorted) order so callers see what was written.
    records.sort(key=lambda r: r.entry_id)
    return records, candidates_path, sha
|
ifsplit/hydrate.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Stage 2 orchestration: turn a built manifest into a hydrated, ML-ready dataset.
|
|
2
|
+
|
|
3
|
+
Ties together split selection, the structure fetcher, the dual index
|
|
4
|
+
(JSONL + optional Parquet), a copy of the manifest, and a generated dataset card.
|
|
5
|
+
Kept separate from ``download.py`` (the transport layer) so the I/O policy and
|
|
6
|
+
the MLOps packaging are easy to read and test independently.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
from collections.abc import Callable
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from . import __version__
|
|
17
|
+
from .download import SPLITS, StructureFetcher
|
|
18
|
+
from .manifest import (
|
|
19
|
+
CLASSES_FILENAME,
|
|
20
|
+
CLUSTERS_FILENAME,
|
|
21
|
+
SPLIT_FILES,
|
|
22
|
+
TIERS_FILENAME,
|
|
23
|
+
read_classes,
|
|
24
|
+
read_clusters,
|
|
25
|
+
read_id_list,
|
|
26
|
+
read_manifest,
|
|
27
|
+
read_tiers,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Callback that receives human-readable progress lines (forwarded to the fetcher).
ProgressFn = Callable[[str], None]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def select_targets(
    manifest: dict[str, Any], splits: list[str], base_dir: Path | None = None
) -> list[tuple[str, str]]:
    """(entry_id, split) pairs for the requested splits, deterministically ordered.

    Each split's ids come from its sidecar file (``train.json`` etc.), named
    by the manifest's ``files.splits`` map or, failing that, ``SPLIT_FILES``.
    File names are resolved against ``base_dir``. This function sees only the
    parsed manifest, not its location, so ``None`` means the current working
    directory; callers should pass the manifest's own directory.

    Pairs follow the order of ``splits``, and ids within a split are sorted.
    Duplicate ids within a split are collapsed, so the same destination file is
    never fetched by two workers at once.
    """
    base = Path(".") if base_dir is None else Path(base_dir)
    split_files = manifest.get("files", {}).get("splits", SPLIT_FILES)
    targets: list[tuple[str, str]] = []
    for split in splits:
        fname = split_files.get(split, f"{split}.json")
        unique_ids = sorted(set(read_id_list(base / fname)))
        targets.extend((eid, split) for eid in unique_ids)
    return targets
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _index_rows(
|
|
52
|
+
fetch_rows: list[dict],
|
|
53
|
+
classes: dict[str, list[str]],
|
|
54
|
+
entry_clusters: dict[str, str],
|
|
55
|
+
tiers: dict[str, dict] | None = None,
|
|
56
|
+
) -> list[dict]:
|
|
57
|
+
"""Enrich each fetched file with split metadata for the ML index."""
|
|
58
|
+
tiers = tiers or {}
|
|
59
|
+
rows: list[dict] = []
|
|
60
|
+
for r in fetch_rows:
|
|
61
|
+
eid = r["entry_id"]
|
|
62
|
+
rows.append(
|
|
63
|
+
{
|
|
64
|
+
"entry_id": eid,
|
|
65
|
+
"split": r["split"],
|
|
66
|
+
"path": r["path"],
|
|
67
|
+
"sha256": r["sha256"],
|
|
68
|
+
"cluster": entry_clusters.get(eid),
|
|
69
|
+
"ligand_classes": classes.get(eid, []),
|
|
70
|
+
"ligand_tiers": tiers.get(eid, {}),
|
|
71
|
+
}
|
|
72
|
+
)
|
|
73
|
+
rows.sort(key=lambda r: (r["split"], r["entry_id"]))
|
|
74
|
+
return rows
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def write_index(rows: list[dict], root: Path) -> dict[str, Path]:
    """Serialize ``rows`` to ``root/index.jsonl`` and, when possible, ``index.parquet``.

    The JSONL file (one compact, key-sorted JSON object per line) is always
    written. The Parquet copy is written only if ``pyarrow`` can be imported.
    Its nested fields are flattened: classes are comma-joined and tiers are
    stored as a JSON string.

    Returns a mapping of format name (``"jsonl"`` / ``"parquet"``) to the file written.
    """
    written: dict[str, Path] = {}

    jsonl_path = root / "index.jsonl"
    with jsonl_path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(row, sort_keys=True, separators=(",", ":")) + "\n" for row in rows
        )
    written["jsonl"] = jsonl_path

    try:
        import pyarrow as pa
        import pyarrow.parquet as pq
    except ImportError:
        return written

    # Scalar columns first, then the flattened nested ones, in a fixed order.
    columns: dict[str, list] = {
        key: [row[key] for row in rows]
        for key in ("entry_id", "split", "path", "sha256", "cluster")
    }
    columns["ligand_classes"] = [",".join(row["ligand_classes"]) for row in rows]
    columns["ligand_tiers_json"] = [
        json.dumps(row["ligand_tiers"], sort_keys=True) for row in rows
    ]

    parquet_path = root / "index.parquet"
    pq.write_table(pa.table(columns), parquet_path)
    written["parquet"] = parquet_path
    return written
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _dataset_card(manifest: dict[str, Any], rows: list[dict], *, assembly: bool) -> str:
    """Render DATASET_CARD.md for a hydrated dataset.

    Args:
        manifest: the source split manifest. ``dataset_version``,
            ``config_hash``, ``clustering`` (``backend``/``identity``) and
            ``candidates.sha256`` are required; ``config`` is optional.
        rows: index rows; only their ``split`` field is used, for the counts.
        assembly: whether files are biological assembly 1 (else asymmetric unit).

    The Integrity section includes a runnable check that re-hashes every file
    listed in ``index.jsonl`` against its recorded sha256.
    """
    counts = {s: sum(1 for r in rows if r["split"] == s) for s in SPLITS}
    cfg = manifest.get("config", {})
    unit = "biological assembly 1" if assembly else "asymmetric unit"
    cl = manifest["clustering"]
    clustering = f"{cl['backend']} @ {cl['identity']}% identity"
    return f"""# {manifest["dataset_version"]} — structures

Generated by IF-Split {__version__}. Coordinates fetched from RCSB
({unit}, gzipped mmCIF). Reproducible from the source split.

## Provenance
- **config hash:** `{manifest["config_hash"]}`
- **snapshot date:** `{cfg.get("snapshot_date", "?")}` (entries with release_date <= this)
- **clustering:** {clustering}
- **candidates.sha256:** `{manifest["candidates"]["sha256"]}`

## Splits (hydrated)
| split | structures |
|-------|-----------:|
| train | {counts["train"]:>10,} |
| val   | {counts["val"]:>10,} |
| test  | {counts["test"]:>10,} |

## Layout
```
structures/<split>/<shard>/<id>{"-assembly1" if assembly else ""}.cif.gz
```
`<shard>` is the PDB "divided" scheme — the middle two characters of the entry id
(e.g. `4hhb` -> `hh/`) — so no single split directory holds an unwieldy number of
files. Browsable by hand, scalable to the whole PDB.

## Files
- `index.jsonl` / `index.parquet` — one row per structure: entry_id, split, path,
  sha256, cluster, ligand_classes, ligand_tiers. The `sha256` lets you verify
  integrity and de-duplicate; `cluster` lets you sample cluster-balanced batches.
- `manifest.json` — the full source split manifest (config, drop log, per-class
  counts, component stats).

## Load (Python)
```python
import pandas as pd
df = pd.read_parquet("index.parquet")  # or pd.read_json("index.jsonl", lines=True)
train = df[df.split == "train"]
# one representative per cluster (de-redundified):
epoch = train.sort_values("entry_id").groupby("cluster").head(1)
```

## Integrity
```bash
# every file's sha256 is in the index; re-fetch is resumable and verifiable.
# Run from the dataset root:
python - <<'EOF'
import hashlib, json
bad = 0
with open("index.jsonl", encoding="utf-8") as index:
    for line in index:
        row = json.loads(line)
        with open(row["path"], "rb") as fh:
            if hashlib.sha256(fh.read()).hexdigest() != row["sha256"]:
                bad += 1
                print("MISMATCH", row["path"])
print("all files verified" if bad == 0 else "%d file(s) failed verification" % bad)
EOF
```
"""
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def hydrate(
    manifest_path: str | Path,
    root: str | Path,
    *,
    splits: list[str] | None = None,
    assembly: bool = True,
    workers: int = 8,
    fetcher: StructureFetcher | None = None,
    progress: ProgressFn | None = None,
) -> dict[str, Any]:
    """Download structures for a manifest and write the ML-ready package.

    Reads the manifest and its sidecar maps (ligand classes, clusters, tiers).
    It then fetches every entry of the requested ``splits`` (default: all of
    ``SPLITS``) under ``root/structures/``. Finally it writes ``index.jsonl``
    (plus ``index.parquet`` when pyarrow is available), a re-serialized
    ``manifest.json`` and ``DATASET_CARD.md`` into ``root``.

    Args:
        manifest_path: path to the built ``manifest.json``.
        root: output directory (created if missing).
        splits: split names to hydrate; ``None``/empty means all of ``SPLITS``.
        assembly, workers: settings for the internally created fetcher.
        fetcher: optional pre-configured fetcher. It is not closed here, its
            own settings govern the download, and the card describes its
            ``assembly`` setting.
        progress: optional callback for progress lines.

    Returns a summary dict. Idempotent: re-running skips files already present.
    NOTE: the index files and the card are rewritten (not merged) on each call,
    so they describe only the splits requested in that call.
    """
    manifest = read_manifest(manifest_path)
    src = Path(manifest_path).parent
    files = manifest.get("files", {})
    # Split membership + supporting maps live in sidecar files next to the manifest.
    classes = read_classes(src / files.get("ligand_classes", CLASSES_FILENAME))
    entry_clusters = read_clusters(src / files.get("clusters", CLUSTERS_FILENAME))
    tiers = read_tiers(src / files.get("ligand_tiers", TIERS_FILENAME))
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    splits = splits or list(SPLITS)

    targets = select_targets(manifest, splits, base_dir=src)

    owns = fetcher is None
    if fetcher is None:
        fetcher = StructureFetcher(assembly=assembly, workers=workers)
    # The fetcher decides which files are actually downloaded; describe that in
    # the card rather than the ``assembly`` argument, which a caller-supplied
    # fetcher ignores.
    fetched_assembly = getattr(fetcher, "assembly", assembly)
    try:
        res = fetcher.fetch(targets, root, progress=progress)
    finally:
        if owns:
            fetcher.close()

    rows = _index_rows(res.index_rows, classes, entry_clusters, tiers)
    index_paths = write_index(rows, root)
    (root / "manifest.json").write_text(
        json.dumps(manifest, indent=2, sort_keys=True) + "\n", encoding="utf-8"
    )
    (root / "DATASET_CARD.md").write_text(
        _dataset_card(manifest, rows, assembly=fetched_assembly), encoding="utf-8"
    )

    return {
        "root": str(root),
        "requested": len(targets),
        "fetched": len(res.fetched),
        "skipped": len(res.skipped),
        "failed": res.failed,
        "index": {k: str(v) for k, v in index_paths.items()},
    }
|