if-split 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ifsplit/dataset.py ADDED
@@ -0,0 +1,112 @@
1
+ """Stage 8 - Thin loader over a manifest, with cluster-balanced sampling.
2
+
3
+ Reads ``manifest.json`` and exposes train/val/test views (entry ids, ligand
4
+ classes, and the entry->cluster map). Coordinate/featurization loading is
5
+ intentionally *not* here - it's the optional downstream concern (PLAN.md §1.5);
6
+ a model repo plugs its own featurizer onto these entry lists.
7
+
8
+ The PDB is heavily redundant (thousands of near-identical lysozyme / kinase
9
+ co-crystals). Training by sampling entries uniformly drowns the model in
10
+ over-represented folds. ``sample_by_cluster`` draws one entry per sequence
11
+ cluster per epoch, which is the bigger training-quality lever than ligand tiering
12
+ - and it is free here because the clusters are already computed. Sampling is
13
+ deterministic given a seed (no global RNG), so an epoch is reproducible.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import hashlib
19
+ from dataclasses import dataclass
20
+ from pathlib import Path
21
+
22
+ from .manifest import read_classes, read_clusters, read_id_list, read_manifest
23
+
24
+ SPLITS = ("train", "val", "test")
25
+
26
+
27
def _stable_rank(key: str, seed: int) -> int:
    """Map ``(seed, key)`` to a reproducible 64-bit integer rank.

    The rank is the big-endian integer value of an 8-byte BLAKE2b digest of the
    text ``"<seed>:<key>"``, so it depends only on its two arguments and never
    on global RNG state.
    """
    payload = f"{seed}:{key}".encode()
    hasher = hashlib.blake2b(payload, digest_size=8)
    return int.from_bytes(hasher.digest(), byteorder="big")
31
+
32
+
33
@dataclass
class SplitView:
    """One named split of the dataset, exposed as plain id lists and lookup maps."""

    name: str
    entry_ids: list[str]
    # entry_id -> ligand classes tagged on that entry
    ligand_classes: dict[str, list[str]]
    # entry_id -> cluster key
    entry_clusters: dict[str, str]

    def __len__(self) -> int:
        """Number of entries in this split."""
        return len(self.entry_ids)

    def with_class(self, cls: str) -> list[str]:
        """Return the entry ids (in split order) that carry ligand class ``cls``.

        Entries without a ``ligand_classes`` record are treated as having none.
        """
        matches: list[str] = []
        for entry in self.entry_ids:
            if cls in self.ligand_classes.get(entry, []):
                matches.append(entry)
        return matches

    def cluster_groups(self) -> dict[str, list[str]]:
        """Group this split's entries by cluster key.

        Keys appear in sorted order and each member list is sorted. An entry
        with no cluster mapping forms its own group keyed by its entry id.
        """
        by_cluster: dict[str, list[str]] = {}
        for entry in self.entry_ids:
            cluster = self.entry_clusters.get(entry, entry)
            by_cluster.setdefault(cluster, []).append(entry)
        return {cluster: sorted(by_cluster[cluster]) for cluster in sorted(by_cluster)}

    def sample_by_cluster(self, seed: int = 0) -> list[str]:
        """Pick one representative entry per cluster, deterministically for ``seed``.

        Within a cluster the member with the smallest stable rank wins (ties
        broken by entry id). The returned list is ordered by the stable rank of
        each cluster key (ties broken by the representative's id), so the same
        seed always yields the same list; a different seed rotates both the
        chosen members and the order.
        """

        def member_rank(entry: str) -> tuple[int, str]:
            return (_stable_rank(entry, seed), entry)

        ranked = [
            (_stable_rank(cluster, seed), min(members, key=member_rank))
            for cluster, members in self.cluster_groups().items()
        ]
        ranked.sort()
        return [representative for _, representative in ranked]
68
+
69
+
70
class IFSplitDataset:
    """Read-only access to the train/val/test partition described by a manifest."""

    def __init__(self, manifest_path: str | Path) -> None:
        """Load the manifest and the sidecar maps it references.

        Sidecar files (ligand classes, clusters) are resolved relative to the
        manifest's directory; the names come from the manifest's ``files``
        section, with fixed fallbacks when that section omits them.
        """
        self._m = read_manifest(manifest_path)
        self._dir = Path(manifest_path).parent
        self.dataset_version: str = self._m["dataset_version"]
        self.config_hash: str = self._m["config_hash"]

        sidecars = self._m.get("files", {})
        self._split_files: dict[str, str] = sidecars.get("splits", {})
        classes_name = sidecars.get("ligand_classes", "ligands.classes.json")
        clusters_name = sidecars.get("clusters", "clusters.json")
        self._classes = read_classes(self._dir / classes_name)
        self._entry_clusters = read_clusters(self._dir / clusters_name)

    def split(self, name: str) -> SplitView:
        """Build the view for split ``name``.

        The split's id list is read from disk on every call. Raises
        ``KeyError`` when ``name`` is not one of ``SPLITS``.
        """
        if name not in SPLITS:
            raise KeyError(f"unknown split {name!r}; expected one of {SPLITS}")
        id_file = self._dir / self._split_files.get(name, f"{name}.json")
        ids = read_id_list(id_file)
        # Restrict both maps to this split; entries missing from a map get an
        # empty class list / themselves as a singleton cluster key.
        classes_here = {e: self._classes.get(e, []) for e in ids}
        clusters_here = {e: self._entry_clusters.get(e, e) for e in ids}
        return SplitView(
            name=name,
            entry_ids=ids,
            ligand_classes=classes_here,
            entry_clusters=clusters_here,
        )

    @property
    def train(self) -> SplitView:
        """The ``"train"`` split view."""
        return self.split("train")

    @property
    def val(self) -> SplitView:
        """The ``"val"`` split view."""
        return self.split("val")

    @property
    def test(self) -> SplitView:
        """The ``"test"`` split view."""
        return self.split("test")
109
+
110
+
111
def load_dataset(manifest_path: str | Path) -> IFSplitDataset:
    """Convenience constructor: open the dataset described by ``manifest_path``."""
    dataset = IFSplitDataset(manifest_path)
    return dataset
ifsplit/download.py ADDED
@@ -0,0 +1,229 @@
1
+ """Stage 2 - Optional structure hydration (the only stage that touches coordinates).
2
+
3
+ `build` never calls this. `fetch` materializes the mmCIF files for a *built*
4
+ manifest into an MLOps-friendly tree that anyone can pick up and train on:
5
+
6
+ <root>/
7
+ structures/
8
+ train/ hh/4hhb-assembly1.cif.gz # split-partitioned (browsable),
9
+ val/ ... # sharded by the PDB middle-two
10
+ test/ ab/1abc-assembly1.cif.gz # chars (PDB "divided" scheme)
11
+ index.jsonl # one row/structure (zero-dep)
12
+ index.parquet # same, columnar (if pyarrow present)
13
+ manifest.json # copy of the source split manifest
14
+ DATASET_CARD.md # provenance + how-to-load
15
+
16
+ Design choices that make it "pristine":
17
+ - **Content-addressed integrity:** every file's SHA-256 is recorded in the index,
18
+ so a re-fetch / transfer can be verified and the pull is resumable (files
19
+ already on disk are skipped, not re-downloaded, and re-hashed into the index).
20
+ - **Deterministic paths:** path is a pure function of (split, entry_id, assembly),
21
+ so two people fetching the same manifest get byte-identical trees.
22
+ - **Explicit scope:** the caller must choose --split / --all; large pulls require
23
+ --yes. No accidental terabyte (the lightweight-by-default contract).
24
+ - **No coordinates in the split itself:** this is downstream of `build`; the
25
+ manifest + lock remain tiny and coordinate-free.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import gzip
31
+ import hashlib
32
+ import time
33
+ from collections.abc import Callable, Iterable
34
+ from concurrent.futures import ThreadPoolExecutor, as_completed
35
+ from dataclasses import dataclass, field
36
+ from pathlib import Path
37
+
38
+ import httpx
39
+
40
+ from . import __version__
41
+
42
+ FILE_BASE = "https://files.rcsb.org/download"
43
+ SPLITS = ("train", "val", "test")
44
+ ProgressFn = Callable[[str], None]
45
+
46
+ # RCSB "divided" sharding: middle two characters of the 4-char core id.
47
+ # 4HHB -> "hh"; for extended ids (pdb_0000XXXX) we shard on the core's chars 2-3.
48
_EXTENDED_PREFIX = "pdb_0000"


def core_id(entry_id: str) -> str:
    """Return the lowercase core id used for filenames and sharding.

    Legacy ids are simply lowercased. An id that starts with ``pdb_0000`` and
    has something after it loses that whole prefix; any other ``pdb_``-prefixed
    id loses just ``pdb_``.
    """
    lowered = entry_id.lower()
    # Extended id with a non-empty remainder: drop the full "pdb_0000" prefix.
    if len(lowered) > len(_EXTENDED_PREFIX) and lowered.startswith(_EXTENDED_PREFIX):
        return lowered[len(_EXTENDED_PREFIX):]
    # Anything else in the "pdb_" namespace: drop only the "pdb_" part.
    short_prefix = "pdb_"
    if lowered.startswith(short_prefix):
        return lowered[len(short_prefix):]
    return lowered
59
+
60
+
61
def shard_for(entry_id: str) -> str:
    """Return the two-character shard directory for ``entry_id``.

    Uses characters 2-3 of the core id when it has at least three characters;
    shorter ids fall back to a 1-byte BLAKE2b digest rendered as two hex chars.
    """
    core = core_id(entry_id)
    if len(core) < 3:
        # Too short to slice: derive a stable two-hex-char shard from a hash.
        return hashlib.blake2b(core.encode(), digest_size=1).hexdigest()
    return core[1:3]
67
+
68
+
69
def filename_for(entry_id: str, *, assembly: bool) -> str:
    """Return the gzipped mmCIF filename for ``entry_id`` (assembly 1 or plain entry)."""
    stem = core_id(entry_id)
    if assembly:
        return f"{stem}-assembly1.cif.gz"
    return f"{stem}.cif.gz"
72
+
73
+
74
def url_for(entry_id: str, *, assembly: bool) -> str:
    """Return the download URL for ``entry_id`` under ``FILE_BASE``."""
    remote_name = filename_for(entry_id, assembly=assembly)
    return f"{FILE_BASE}/{remote_name}"
76
+
77
+
78
def rel_path_for(entry_id: str, split: str, *, assembly: bool) -> Path:
    """Return ``structures/<split>/<shard>/<filename>`` for ``entry_id``, relative to the root."""
    shard_dir = Path("structures", split, shard_for(entry_id))
    return shard_dir / filename_for(entry_id, assembly=assembly)
82
+
83
+
84
@dataclass
class FetchResult:
    """Outcome of one fetch run, bucketed by what happened to each entry."""

    # Entry ids whose row came back with a status other than "skipped".
    fetched: list[str] = field(default_factory=list)
    # Entry ids whose file was already present on disk.
    skipped: list[str] = field(default_factory=list)
    # (entry_id, reason) for every entry whose fetch raised.
    failed: list[tuple[str, str]] = field(default_factory=list)
    # One dict per successfully handled file (fetched or skipped).
    index_rows: list[dict] = field(default_factory=list)
90
+
91
+
92
def _sha256_file(path: Path) -> str:
    """Return the hex SHA-256 of the file at ``path``, read in 1 MiB chunks."""
    chunk_size = 1 << 20
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                break
            digest.update(block)
    return digest.hexdigest()
98
+
99
+
100
class StructureFetcher:
    """Polite, resumable mmCIF downloader (assembly 1 or asymmetric unit).

    Owns a single ``httpx.Client`` shared by all worker threads. Use it as a
    context manager, or call :meth:`close` when finished.
    """

    def __init__(
        self,
        *,
        assembly: bool = True,
        workers: int = 8,
        timeout: float = 120.0,
        max_retries: int = 4,
        backoff_base: float = 1.5,
        sleep=time.sleep,
    ) -> None:
        """Configure the fetcher and open its HTTP client.

        Args:
            assembly: fetch the ``-assembly1`` file when true, the plain entry
                file otherwise.
            workers: thread-pool size used by :meth:`fetch`.
            timeout: per-request timeout passed to the HTTP client.
            max_retries: number of retries after the first attempt.
            backoff_base: base of the exponential delay between attempts.
            sleep: callable used to wait between attempts (injectable for tests).
        """
        self.assembly = assembly
        self.workers = workers
        self._timeout = timeout
        self._max_retries = max_retries
        self._backoff_base = backoff_base
        self._sleep = sleep
        self._client = httpx.Client(
            timeout=timeout,
            headers={"User-Agent": f"IF-Split/{__version__} (structure fetch)"},
            follow_redirects=True,
        )

    def __enter__(self) -> StructureFetcher:
        return self

    def __exit__(self, *exc) -> None:
        # Single shutdown path: leaving the context is the same as close().
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client."""
        self._client.close()

    def _get(self, url: str) -> bytes:
        """GET ``url`` and return the response body, retrying on failure.

        Makes up to ``max_retries + 1`` attempts, sleeping
        ``backoff_base ** attempt`` between them. A 404 is final and raises
        ``FileNotFoundError`` immediately; transport errors and any other
        non-200 status are retried. Raises ``RuntimeError`` once the attempts
        are exhausted.
        """
        last: Exception | None = None
        for attempt in range(self._max_retries + 1):
            try:
                resp = self._client.get(url)
            except httpx.HTTPError as exc:
                last = exc
            else:
                if resp.status_code == 200:
                    return resp.content
                if resp.status_code == 404:
                    raise FileNotFoundError(f"not on RCSB (404): {url}")
                last = RuntimeError(f"HTTP {resp.status_code}: {url}")
            # No point sleeping after the final attempt.
            if attempt < self._max_retries:
                self._sleep(self._backoff_base**attempt)
        raise RuntimeError(f"download failed after retries: {url} ({last})")

    def estimate_bytes(self, entry_ids: list[str], sample: int = 12) -> int | None:
        """Rough total-size estimate from HEAD on a sample (None if unavailable).

        Issues HEAD requests for the first ``sample`` ids, averages the
        ``Content-Length`` of the 200 responses, and scales that average to
        ``len(entry_ids)``. Samples that fail, or that report a missing or
        non-numeric length, are ignored.
        """
        sizes: list[int] = []
        for eid in entry_ids[:sample]:
            try:
                r = self._client.head(url_for(eid, assembly=self.assembly))
            except httpx.HTTPError:
                continue
            cl = r.headers.get("content-length")
            if r.status_code != 200 or not cl:
                continue
            try:
                sizes.append(int(cl))
            except ValueError:
                # Best-effort estimate: a malformed header just drops the sample.
                continue
        if not sizes:
            return None
        avg = sum(sizes) / len(sizes)
        return int(avg * len(entry_ids))

    def _fetch_one(self, entry_id: str, split: str, root: Path) -> dict:
        """Ensure one structure file exists under ``root`` and return its index row.

        An existing destination file is kept as-is and reported as
        ``"skipped"`` (its hash is recomputed from disk). Otherwise the file is
        downloaded, checked to be valid gzip, written to a ``.part`` sibling and
        renamed into place, and reported as ``"fetched"``. Any failure
        propagates to the caller.
        """
        rel = rel_path_for(entry_id, split, assembly=self.assembly)
        dest = root / rel
        # Always record forward-slash paths so the index is identical no matter
        # which platform produced it.
        rel_posix = rel.as_posix()
        if dest.exists():  # resume: trust an existing, readable file
            return {
                "entry_id": entry_id,
                "split": split,
                "path": rel_posix,
                "sha256": _sha256_file(dest),
                "status": "skipped",
            }
        dest.parent.mkdir(parents=True, exist_ok=True)
        data = self._get(url_for(entry_id, assembly=self.assembly))
        gzip.decompress(data)  # integrity check: must be valid gzip
        # Write to a temporary sibling, then rename, so a partial download is
        # never visible under the final name.
        tmp = dest.with_suffix(dest.suffix + ".part")
        tmp.write_bytes(data)
        tmp.replace(dest)
        return {
            "entry_id": entry_id,
            "split": split,
            "path": rel_posix,
            "sha256": hashlib.sha256(data).hexdigest(),
            "status": "fetched",
        }

    def fetch(
        self,
        targets: Iterable[tuple[str, str]],  # (entry_id, split)
        root: Path,
        *,
        progress: ProgressFn | None = None,
    ) -> FetchResult:
        """Fetch every ``(entry_id, split)`` target into ``root`` using a thread pool.

        Per-entry exceptions are collected in ``FetchResult.failed`` rather
        than raised. ``progress`` (if given) receives a summary line every 100
        completed targets and once at the end. ``index_rows`` is returned
        sorted by ``(split, entry_id)``.
        """
        targets = list(targets)
        result = FetchResult()
        done = 0
        total = len(targets)

        def say(msg: str) -> None:
            if progress:
                progress(msg)

        with ThreadPoolExecutor(max_workers=self.workers) as pool:
            futs = {
                pool.submit(self._fetch_one, eid, split, root): (eid, split)
                for eid, split in targets
            }
            for fut in as_completed(futs):
                eid, _split = futs[fut]
                done += 1
                try:
                    row = fut.result()
                except Exception as exc:
                    # One bad entry must not abort the whole pull; record it.
                    result.failed.append((eid, str(exc)))
                else:
                    result.index_rows.append(row)
                    bucket = result.skipped if row["status"] == "skipped" else result.fetched
                    bucket.append(eid)
                if done % 100 == 0 or done == total:
                    say(
                        f"{done}/{total} ({len(result.fetched)} new, "
                        f"{len(result.skipped)} cached, {len(result.failed)} failed)"
                    )
        result.index_rows.sort(key=lambda r: (r["split"], r["entry_id"]))
        return result
ifsplit/enumerate.py ADDED
@@ -0,0 +1,111 @@
1
+ """Stage 1 - Enumerate candidates from RCSB (Search v2 + Data API).
2
+
3
+ Selects entries by ``release_date <= snapshot_date`` plus the method/resolution
4
+ filters, enriches each via the Data API (sequences, ligand comps, residue
5
+ counts, assemblies), and writes the byte-stable ``candidates.jsonl`` -- the
6
+ snapshot definition. No coordinates are downloaded (PLAN.md §1.5).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ import time
13
+ from collections.abc import Callable
14
+ from pathlib import Path
15
+
16
+ from .config import Config
17
+ from .rcsb import RcsbClient
18
+ from .schema import CandidateRecord, canonical_jsonl_bytes, sha256_hex
19
+
20
+ ProgressFn = Callable[[str], None]
21
+
22
+ # How often (in enriched records) to emit a progress line.
23
+ _REPORT_EVERY = 1000
24
+
25
+
26
def fmt_duration(seconds: float) -> str:
    """Render a duration compactly: ``45s``, ``3m29s`` or ``1h02m``.

    Negative inputs are clamped to zero and fractions are truncated. Once the
    duration reaches an hour the seconds are dropped.
    """
    whole = int(max(0, seconds))
    minutes, secs = divmod(whole, 60)
    if minutes == 0:
        return f"{secs}s"
    hours, mins = divmod(minutes, 60)
    if hours == 0:
        return f"{mins}m{secs:02d}s"
    return f"{hours}h{mins:02d}m"
36
+
37
+
38
def progress_line(label: str, n: int, total: int, t0: float) -> str:
    """Format ``label: n/total (pct)`` plus rate and ETA when they can be computed.

    ``t0`` is a ``time.monotonic()`` start stamp. The rate/ETA suffix is only
    added once some time has elapsed and at least one item is done; a zero
    ``total`` reports 0%.
    """
    elapsed = time.monotonic() - t0
    percent = (100 * n / total) if total else 0.0
    parts = [f"{label}: {n}/{total} ({percent:.0f}%)"]
    if elapsed > 0 and n > 0:
        rate = n / elapsed
        remaining = (total - n) / rate if rate > 0 else 0.0
        parts.append(f" {rate:.0f}/s eta {fmt_duration(remaining)}")
    return "".join(parts)
48
+
49
+
50
def make_console_progress(stream=None) -> ProgressFn:
    """A timestamped, line-flushed progress printer for long CLI runs.

    The flush is the important part: when stdout is redirected to a file, Python
    block-buffers it, so unflushed progress lines stay invisible until the process
    exits. This forces each line out immediately.

    Args:
        stream: file-like object to write to. Only ``None`` selects
            ``sys.stdout`` (resolved when this factory is called); any other
            object is used as given, even if it happens to be falsy.

    Returns:
        A callable that prints one ``[HH:MM:SS] message`` line per call.
    """
    # Compare against None rather than relying on truthiness: a valid stream
    # object may define __len__/__bool__ and evaluate false.
    out = sys.stdout if stream is None else stream

    def say(msg: str) -> None:
        print(f" [{time.strftime('%H:%M:%S')}] {msg}", file=out, flush=True)

    return say
63
+
64
+
65
def enumerate_candidates(
    cfg: Config,
    out_dir: str | Path,
    *,
    limit: int | None = None,
    client: RcsbClient | None = None,
    progress: ProgressFn | None = None,
) -> tuple[list[CandidateRecord], Path, str]:
    """Run Stage 1: search, enrich, and write ``candidates.jsonl``.

    Searches for matching entry ids, converts each fetched entry into a
    ``CandidateRecord``, and writes the canonical serialization to
    ``out_dir/candidates.jsonl`` (the directory is created if needed). A
    client created here is closed before writing; a caller-supplied client is
    left open.

    Returns ``(records, candidates_path, sha256)`` with ``records`` sorted by
    ``entry_id`` and ``sha256`` computed over the written bytes.
    """
    destination = Path(out_dir)
    destination.mkdir(parents=True, exist_ok=True)

    def notify(message: str) -> None:
        # Progress reporting is optional; stay silent without a callback.
        if progress:
            progress(message)

    created_here = client is None
    if not client:
        client = RcsbClient()

    records: list[CandidateRecord] = []
    try:
        ids = client.search_entry_ids(cfg, limit=limit, progress=progress)
        notify(f"search: {len(ids)} entries match the snapshot")

        expected = len(ids)
        started = time.monotonic()
        for count, raw in enumerate(client.fetch_entries(ids), start=1):
            records.append(CandidateRecord.from_data_api(raw))
            if count % _REPORT_EVERY == 0:
                notify(progress_line("enriched", count, expected, started))
        notify(progress_line("enriched", len(records), expected, started) + " (done)")
    finally:
        # Only close a client this function opened.
        if created_here:
            client.close()

    payload = canonical_jsonl_bytes(records)
    digest = sha256_hex(payload)
    candidates_path = destination / "candidates.jsonl"
    candidates_path.write_bytes(payload)
    notify(f"wrote {candidates_path} ({len(records)} records, sha256={digest[:12]}...)")

    # Hand the records back sorted by entry_id.
    records.sort(key=lambda rec: rec.entry_id)
    return records, candidates_path, digest
ifsplit/hydrate.py ADDED
@@ -0,0 +1,216 @@
1
+ """Stage 2 orchestration: turn a built manifest into a hydrated, ML-ready dataset.
2
+
3
+ Ties together split selection, the structure fetcher, the dual index
4
+ (JSONL + optional Parquet), a copy of the manifest, and a generated dataset card.
5
+ Kept separate from ``download.py`` (the transport layer) so the I/O policy and
6
+ the MLOps packaging are easy to read and test independently.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ from collections.abc import Callable
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ from . import __version__
17
+ from .download import SPLITS, StructureFetcher
18
+ from .manifest import (
19
+ CLASSES_FILENAME,
20
+ CLUSTERS_FILENAME,
21
+ SPLIT_FILES,
22
+ TIERS_FILENAME,
23
+ read_classes,
24
+ read_clusters,
25
+ read_id_list,
26
+ read_manifest,
27
+ read_tiers,
28
+ )
29
+
30
+ ProgressFn = Callable[[str], None]
31
+
32
+
33
def select_targets(
    manifest: dict[str, Any], splits: list[str], base_dir: Path | None = None
) -> list[tuple[str, str]]:
    """Return ``(entry_id, split)`` pairs for the requested splits.

    Splits are visited in the order given and ids within a split are sorted,
    so the result is deterministic. Each split's id list is read from the file
    named in the manifest's ``files.splits`` map (falling back to
    ``<split>.json``), resolved against ``base_dir``. When ``base_dir`` is
    ``None`` the current working directory is used, so pass the manifest's
    directory explicitly if the files live next to it.
    """
    base = Path(".") if base_dir is None else Path(base_dir)
    split_files = manifest.get("files", {}).get("splits", SPLIT_FILES)
    return [
        (eid, split)
        for split in splits
        for eid in sorted(read_id_list(base / split_files.get(split, f"{split}.json")))
    ]
49
+
50
+
51
def _index_rows(
    fetch_rows: list[dict],
    classes: dict[str, list[str]],
    entry_clusters: dict[str, str],
    tiers: dict[str, dict] | None = None,
) -> list[dict]:
    """Build the ML index rows from the fetcher's per-file rows.

    Each output row keeps ``entry_id``, ``split``, ``path`` and ``sha256`` and
    adds ``cluster`` (``None`` if unmapped), ``ligand_classes`` (``[]`` if
    unmapped) and ``ligand_tiers`` (``{}`` if unmapped). Rows are returned
    sorted by ``(split, entry_id)``.
    """
    tier_map = tiers or {}
    enriched = [
        {
            "entry_id": row["entry_id"],
            "split": row["split"],
            "path": row["path"],
            "sha256": row["sha256"],
            "cluster": entry_clusters.get(row["entry_id"]),
            "ligand_classes": classes.get(row["entry_id"], []),
            "ligand_tiers": tier_map.get(row["entry_id"], {}),
        }
        for row in fetch_rows
    ]
    return sorted(enriched, key=lambda item: (item["split"], item["entry_id"]))
75
+
76
+
77
def write_index(rows: list[dict], root: Path) -> dict[str, Path]:
    """Write index.jsonl (always) and index.parquet (if pyarrow present).

    Args:
        rows: index rows; each must carry ``entry_id``, ``split``, ``path``,
            ``sha256``, ``cluster``, ``ligand_classes`` and ``ligand_tiers``.
        root: existing directory that receives the index files.

    Returns:
        ``{"jsonl": path}`` plus ``"parquet": path`` when pyarrow could be
        imported.
    """
    out: dict[str, Path] = {}
    jsonl = root / "index.jsonl"
    # newline="\n" turns off platform newline translation, so the file has the
    # same bytes on every OS (text mode would emit "\r\n" on Windows).
    with jsonl.open("w", encoding="utf-8", newline="\n") as fh:
        fh.writelines(
            json.dumps(r, sort_keys=True, separators=(",", ":")) + "\n" for r in rows
        )
    out["jsonl"] = jsonl

    # Parquet is optional: without pyarrow only the JSONL index is produced.
    try:
        import pyarrow as pa
        import pyarrow.parquet as pq
    except ImportError:
        return out

    # Stringify nested fields for a clean, portable columnar schema.
    table = pa.table(
        {
            "entry_id": [r["entry_id"] for r in rows],
            "split": [r["split"] for r in rows],
            "path": [r["path"] for r in rows],
            "sha256": [r["sha256"] for r in rows],
            "cluster": [r["cluster"] for r in rows],
            "ligand_classes": [",".join(r["ligand_classes"]) for r in rows],
            "ligand_tiers_json": [json.dumps(r["ligand_tiers"], sort_keys=True) for r in rows],
        }
    )
    parquet = root / "index.parquet"
    pq.write_table(table, parquet)
    out["parquet"] = parquet
    return out
108
+
109
+
110
def _dataset_card(manifest: dict[str, Any], rows: list[dict], *, assembly: bool) -> str:
    """Render the Markdown text of the dataset card for a hydrated dataset.

    Args:
        manifest: the split manifest. ``dataset_version``, ``config_hash``,
            ``clustering`` (with ``backend`` and ``identity``) and
            ``candidates.sha256`` are read with plain indexing, so a manifest
            missing any of them raises ``KeyError``; ``config`` is optional.
        rows: index rows; only each row's ``split`` is used, to count
            structures per split.
        assembly: whether the card should describe assembly-1 files or the
            asymmetric unit (affects the unit wording and the layout line).

    Returns:
        The full card as one string.
    """
    # Per-split structure counts, taken from the index rows actually written.
    counts = {s: sum(1 for r in rows if r["split"] == s) for s in SPLITS}
    cfg = manifest.get("config", {})
    unit = "biological assembly 1" if assembly else "asymmetric unit"
    cl = manifest["clustering"]
    clustering = f"{cl['backend']} @ {cl['identity']}% identity"
    return f"""# {manifest["dataset_version"]} — structures

Generated by IF-Split {__version__}. Coordinates fetched from RCSB
({unit}, gzipped mmCIF). Reproducible from the source split.

## Provenance
- **config hash:** `{manifest["config_hash"]}`
- **snapshot date:** `{cfg.get("snapshot_date", "?")}` (entries with release_date <= this)
- **clustering:** {clustering}
- **candidates.sha256:** `{manifest["candidates"]["sha256"]}`

## Splits (hydrated)
| split | structures |
|-------|-----------:|
| train | {counts["train"]:>10,} |
| val | {counts["val"]:>10,} |
| test | {counts["test"]:>10,} |

## Layout
```
structures/<split>/<shard>/<id>{"-assembly1" if assembly else ""}.cif.gz
```
`<shard>` is the PDB "divided" scheme — the middle two characters of the entry id
(e.g. `4hhb` -> `hh/`) — so no single split directory holds an unwieldy number of
files. Browsable by hand, scalable to the whole PDB.

## Files
- `index.jsonl` / `index.parquet` — one row per structure: entry_id, split, path,
sha256, cluster, ligand_classes, ligand_tiers. The `sha256` lets you verify
integrity and de-duplicate; `cluster` lets you sample cluster-balanced batches.
- `manifest.json` — the full source split manifest (config, drop log, per-class
counts, component stats).

## Load (Python)
```python
import pandas as pd
df = pd.read_parquet("index.parquet") # or pd.read_json("index.jsonl", lines=True)
train = df[df.split == "train"]
# one representative per cluster (de-redundified):
epoch = train.sort_values("entry_id").groupby("cluster").head(1)
```

## Integrity
```bash
# every file's sha256 is in the index; re-fetch is resumable and verifiable.
```
"""
163
+
164
+
165
def hydrate(
    manifest_path: str | Path,
    root: str | Path,
    *,
    splits: list[str] | None = None,
    assembly: bool = True,
    workers: int = 8,
    fetcher: StructureFetcher | None = None,
    progress: ProgressFn | None = None,
) -> dict[str, Any]:
    """Download structures for a manifest and write the ML-ready package.

    Args:
        manifest_path: path to the built split manifest; sidecar files are
            resolved relative to its directory.
        root: output directory (created if missing).
        splits: splits to hydrate; ``None`` or empty means all of ``SPLITS``.
        assembly: fetch assembly-1 files when true. Only used to construct the
            default fetcher; an injected ``fetcher`` keeps its own setting.
        workers: thread count for the default fetcher.
        fetcher: optional pre-built fetcher. It is not closed here.
        progress: optional callback receiving progress lines.

    Returns:
        A summary dict with ``root``, ``requested``, ``fetched``, ``skipped``
        (counts), ``failed`` (list of ``(entry_id, reason)``) and ``index``
        (paths of the index files written). Re-running skips files already
        present under ``root``.
    """
    manifest = read_manifest(manifest_path)
    src = Path(manifest_path).parent
    files = manifest.get("files", {})
    # Split membership + supporting maps live in sidecar files next to the manifest.
    classes = read_classes(src / files.get("ligand_classes", CLASSES_FILENAME))
    entry_clusters = read_clusters(src / files.get("clusters", CLUSTERS_FILENAME))
    tiers = read_tiers(src / files.get("ligand_tiers", TIERS_FILENAME))
    root = Path(root)
    root.mkdir(parents=True, exist_ok=True)
    splits = splits or list(SPLITS)

    targets = select_targets(manifest, splits, base_dir=src)

    owns = fetcher is None
    if fetcher is None:
        fetcher = StructureFetcher(assembly=assembly, workers=workers)
    # The card must describe what was actually downloaded: an injected fetcher
    # may use a different ``assembly`` setting than this function's argument.
    card_assembly = getattr(fetcher, "assembly", assembly)
    try:
        res = fetcher.fetch(targets, root, progress=progress)
    finally:
        # Only close a fetcher this function created.
        if owns:
            fetcher.close()

    rows = _index_rows(res.index_rows, classes, entry_clusters, tiers)
    index_paths = write_index(rows, root)

    # Write as bytes so newlines are not translated per-platform and the
    # packaged files are byte-identical everywhere.
    manifest_text = json.dumps(manifest, indent=2, sort_keys=True) + "\n"
    (root / "manifest.json").write_bytes(manifest_text.encode("utf-8"))
    card_text = _dataset_card(manifest, rows, assembly=card_assembly)
    (root / "DATASET_CARD.md").write_bytes(card_text.encode("utf-8"))

    return {
        "root": str(root),
        "requested": len(targets),
        "fetched": len(res.fetched),
        "skipped": len(res.skipped),
        "failed": res.failed,
        "index": {k: str(v) for k, v in index_paths.items()},
    }