@platforma-open/milaboratories.3d-structure-prediction.software 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.turbo/turbo-build.log +265 -0
- package/dist/artifacts/py-archive/archive.json +1 -0
- package/dist/artifacts/py-archive/docker_x64.json +1 -0
- package/dist/docker/Dockerfile-py-archive +22 -0
- package/dist/tengo/software/immunebuilder-predict.sw.json +1 -0
- package/package.json +41 -0
- package/pkg-platforma-open-milaboratories.3d-structure-prediction.software-py-archive-1.0.0.tgz +0 -0
- package/src_python/numbering.py +103 -0
- package/src_python/pdb_writer.py +137 -0
- package/src_python/requirements.txt +4 -0
- package/src_python/run_immunebuilder.py +647 -0
- package/src_python/sanitize.py +150 -0
|
@@ -0,0 +1,647 @@
|
|
|
1
|
+
"""ImmuneBuilder batch runner for the Platforma 3D Structure Prediction block.
|
|
2
|
+
|
|
3
|
+
Reads a batch TSV of clonotypes and predicts structures via ABodyBuilder2 or
|
|
4
|
+
NanoBodyBuilder2 (spec R22). Emits:
|
|
5
|
+
|
|
6
|
+
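- Per-clonotype PDB files named `pdb_<NNNNN>.pdb` after the zero-padded input row index (see `indexed_filename`); the `<sha1(clonotypeKey)>.pdb` naming from R30 survives only as the `sha1_filename` helper.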
|
|
7
|
+
- `manifest.tsv` : (clonotypeKey, pdb_filename).
|
|
8
|
+
- `confidence.tsv`: aggregate + per-residue confidence (Å error, R32-R36)
|
|
9
|
+
plus failureReason (R40) and warning columns.
|
|
10
|
+
|
|
11
|
+
Dependencies (ImmuneBuilder, torch) ride the venv that pl-pkg's install-deps
|
|
12
|
+
creates. ANARCI and pdbfixer are not on PyPI; the atls runenv builds them
|
|
13
|
+
from source and stages them in the runenv's site-packages. The SDK's venv
|
|
14
|
+
is created without --system-site-packages, so we bootstrap the runenv's
|
|
15
|
+
site-packages onto sys.path here before any project-level imports run.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import os as _os
|
|
21
|
+
import sys as _sys
|
|
22
|
+
import sysconfig as _sysconfig
|
|
23
|
+
|
|
24
|
+
# Root of the Python run environment; nothing below runs unless PYTHONHOME is set.
_runenv_root = _os.environ.get("PYTHONHOME")
if _runenv_root:
    # Packages staged into the runenv (e.g. `{site-packages}/anarci`,
    # `{site-packages}/pdbfixer`) live outside the venv, so append the runenv's
    # site-packages directories to the import path when they exist.
    _py_ver_short = "python{0.major}.{0.minor}".format(_sys.version_info)
    _candidates = [
        _os.path.join(_runenv_root, "lib", _py_ver_short, "site-packages"),
        # Windows layout
        _os.path.join(_runenv_root, "Lib", "site-packages"),
    ]
    _platlib = _sysconfig.get_paths().get("platlib")
    if _platlib:
        _candidates += [_platlib]
    for _candidate in _candidates:
        if _candidate in _sys.path or not _os.path.isdir(_candidate):
            continue
        _sys.path.append(_candidate)

    # ANARCI shells out to `hmmscan` (HMMER), which ships in the runenv's bin/;
    # prepend that directory to PATH unless it is already listed.
    _runenv_bin = _os.path.join(_runenv_root, "bin")
    if _os.path.isdir(_runenv_bin):
        _path_env = _os.environ.get("PATH", "")
        if _runenv_bin not in _path_env.split(_os.pathsep):
            _os.environ["PATH"] = _os.pathsep.join((_runenv_bin, _path_env))
|
|
47
|
+
|
|
48
|
+
import argparse
|
|
49
|
+
import csv
|
|
50
|
+
import hashlib
|
|
51
|
+
import json
|
|
52
|
+
import os
|
|
53
|
+
import sys
|
|
54
|
+
import time
|
|
55
|
+
import traceback
|
|
56
|
+
from dataclasses import asdict, dataclass, field
|
|
57
|
+
from datetime import datetime, timezone
|
|
58
|
+
from pathlib import Path
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _log(message: str) -> None:
|
|
62
|
+
"""Line-buffered log entry to stderr.
|
|
63
|
+
|
|
64
|
+
Workflow uses `printErrStreamToStdout` + `saveStdoutStream`; the resulting
|
|
65
|
+
log handle is consumed by `PlLogView` in the UI. Every line is timestamped
|
|
66
|
+
so users can track progress across long-running batches.
|
|
67
|
+
"""
|
|
68
|
+
ts = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
|
69
|
+
print(f"[{ts}] {message}", file=sys.stderr, flush=True)
|
|
70
|
+
|
|
71
|
+
from numbering import cdrh3_length as imgt_cdrh3_length
|
|
72
|
+
from numbering import extract_numbered_residues, vhh_hallmarks_present
|
|
73
|
+
from pdb_writer import augment_pdb
|
|
74
|
+
from sanitize import sanitize_pair
|
|
75
|
+
|
|
76
|
+
# `clonotypeKey` is a placeholder — the actual column name we write back is
|
|
77
|
+
# the same name we read from the input TSV (the orchestrator uses the axis
|
|
78
|
+
# spec name, e.g. `pl7.app/vdj/clonotypeKey`, as the column header). See
|
|
79
|
+
# build_confidence_fields().
|
|
80
|
+
KEY_COLUMN_PLACEHOLDER = "clonotypeKey"

# Column order of confidence.tsv. The first entry is the placeholder that is
# replaced by the real key-column header when the file is written.
CONFIDENCE_FIELDS = [
    # identity
    KEY_COLUMN_PLACEHOLDER,
    "clonotypeLabel",
    # aggregate errors
    "meanError",
    "cdrh1Error",
    "cdrh2Error",
    "cdrh3Error",
    "cdrl1Error",
    "cdrl2Error",
    "cdrl3Error",
    # per-residue detail and CDR-H3 length
    "perResidueError",
    "cdrh3Length",
    # failure / warning codes with their human-readable counterparts
    "failureReason",
    "failureReasonText",
    "warning",
    "warningText",
]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def build_confidence_fields(key_col: str) -> list[str]:
    """Return the confidence.tsv header with the key placeholder replaced by *key_col*."""
    header = []
    for name in CONFIDENCE_FIELDS:
        header.append(key_col if name == KEY_COLUMN_PLACEHOLDER else name)
    return header
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# Lookup tables from machine-readable codes to the text shown to users.
# The code columns are kept alongside (hidden by default) so downstream
# blocks / future failure-stats logic can group by enum value; the *Text
# columns carry the labels below.
FAILURE_REASON_LABELS: dict[str, str] = {
    # sequence sanitation failures
    "empty_sequence": "Empty sequence",
    "stop_codon_mid_sequence": "Stop codon in sequence",
    "non_standard_aa_only_after_strip": "No standard amino acids after cleanup",
    "non_standard_aa_residue": "Non-standard amino acid residue",
    "length_out_of_range": "Length outside expected VH/VL range",
    # mode / input mismatch
    "light_chain_missing_in_paired_mode": (
        "Light chain missing — switch to NanoBodyBuilder2 or pick a light chain"
    ),
}

WARNING_LABELS: dict[str, str] = {
    "probable_signal_peptide": "Possible N-terminal signal peptide",
    "long_cdrh3": "Long CDR-H3 (≥20 aa) — confidence may be reduced",
    "vhh_hallmarks_missing": "VHH hallmark residues not detected",
}
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _failure_reason_label(code: str) -> str:
|
|
126
|
+
if not code:
|
|
127
|
+
return ""
|
|
128
|
+
if code in FAILURE_REASON_LABELS:
|
|
129
|
+
return FAILURE_REASON_LABELS[code]
|
|
130
|
+
# Structured prefixes — preserve the suffix for triage rather than
|
|
131
|
+
# collapsing to a generic message.
|
|
132
|
+
if code.startswith("immunebuilder_exception:"):
|
|
133
|
+
return f"ImmuneBuilder error: {code.split(':', 1)[1]}"
|
|
134
|
+
if code.startswith("unknown_mode:"):
|
|
135
|
+
return f"Unknown prediction mode: {code.split(':', 1)[1]}"
|
|
136
|
+
return code
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _warning_label(code: str) -> str:
    """Return the label for a warning code, or the code itself when unlisted."""
    try:
        return WARNING_LABELS[code]
    except KeyError:
        return code
|
|
141
|
+
|
|
142
|
+
# Column order of manifest.tsv; the first entry is the key-column placeholder.
MANIFEST_FIELDS = [
    KEY_COLUMN_PLACEHOLDER,
    "pdb_filename",
]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def indexed_filename(row_idx: int) -> str:
    """Return the PDB file name for an input row: ``pdb_`` + index zero-padded to 5 digits."""
    return "pdb_{:05d}.pdb".format(row_idx)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def sha1_filename(key: str) -> str:
    """Return ``<hex SHA-1 of the UTF-8 encoded key>.pdb``.

    Retained for manifest metadata — downstream blocks can SHA-1 keys to
    cross-reference.
    """
    digest = hashlib.sha1(key.encode("utf-8"))
    return "{}.pdb".format(digest.hexdigest())
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_immunebuilder_version() -> str:
    """Return the installed ImmuneBuilder distribution version, or ``"unknown"``.

    Any failure while looking the version up (package missing, metadata
    unavailable, ...) is deliberately reduced to ``"unknown"``.
    """
    installed = "unknown"
    try:
        import importlib.metadata as _metadata

        installed = _metadata.version("ImmuneBuilder")
    except Exception:
        pass
    return installed
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def get_block_version() -> str:
    """Return the ``BLOCK_VERSION`` environment variable, or ``"unknown"`` when unset."""
    return os.getenv("BLOCK_VERSION", "unknown")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def pick_key_column(fieldnames: list[str], default: str = "clonotypeKey") -> str:
    """Choose the column that holds the clonotype key.

    Args:
        fieldnames: Header names of the input TSV, in file order.
        default: Name returned when *fieldnames* is empty (e.g. the input
            file has no header line at all).

    Returns:
        The first header whose name ends with ``clonotypeKey`` (compared
        case-insensitively, so both ``clonotypeKey`` and
        ``pl7.app/vdj/clonotypeKey`` match); otherwise the first header;
        otherwise *default*.
    """
    for name in fieldnames:
        # `endswith` also covers an exact match, so one test is enough.
        if name.lower().endswith("clonotypekey"):
            return name
    # An empty header used to raise IndexError here; fall back instead so an
    # empty input file can still produce (empty) outputs.
    return fieldnames[0] if fieldnames else default
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass
class RowResult:
    """Outcome of one input row, holding every value already formatted as text.

    All fields except ``clonotype_key`` default to empty, so a failed row
    simply leaves its metric columns blank.
    """

    # identity
    clonotype_key: str
    clonotype_label: str = ""
    # aggregate and per-CDR errors
    mean_error: str = ""
    cdrh1: str = ""
    cdrh2: str = ""
    cdrh3: str = ""
    cdrl1: str = ""
    cdrl2: str = ""
    cdrl3: str = ""
    # per-residue JSON payload and CDR-H3 length
    per_residue_json: str = ""
    cdrh3_len: str = ""
    # failure code, warning codes, and the PDB written for this row (if any)
    failure_reason: str = ""
    warnings: list[str] = field(default_factory=list)
    pdb_filename: str = ""

    @property
    def warning_str(self) -> str:
        """Warning codes joined with ``;``."""
        separator = ";"
        return separator.join(self.warnings)

    @property
    def warning_text(self) -> str:
        """Human-readable warning labels joined with ``; ``."""
        labels = [_warning_label(code) for code in self.warnings]
        return "; ".join(labels)

    def to_tsv_row(self, key_col: str) -> dict[str, str]:
        """Return this result as a confidence.tsv row keyed by column name.

        *key_col* is the header under which the clonotype key is written.
        """
        row = {key_col: self.clonotype_key}
        row["clonotypeLabel"] = self.clonotype_label
        row["meanError"] = self.mean_error
        row["cdrh1Error"] = self.cdrh1
        row["cdrh2Error"] = self.cdrh2
        row["cdrh3Error"] = self.cdrh3
        row["cdrl1Error"] = self.cdrl1
        row["cdrl2Error"] = self.cdrl2
        row["cdrl3Error"] = self.cdrl3
        row["perResidueError"] = self.per_residue_json
        row["cdrh3Length"] = self.cdrh3_len
        row["failureReason"] = self.failure_reason
        row["failureReasonText"] = _failure_reason_label(self.failure_reason)
        row["warning"] = self.warning_str
        row["warningText"] = self.warning_text
        return row
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _mean(values: list[float]) -> float | None:
|
|
219
|
+
valid = [v for v in values if v is not None]
|
|
220
|
+
return sum(valid) / len(valid) if valid else None
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _region_errors(per_residue, chain: str, cdr_name: str) -> list[float]:
|
|
224
|
+
target = f"CDR{cdr_name[-1]}" # "CDR1" or similar
|
|
225
|
+
return [
|
|
226
|
+
r["errorAngstroms"]
|
|
227
|
+
for r in per_residue
|
|
228
|
+
if r["chain"] == chain and r.get("_region") == target
|
|
229
|
+
]
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _load_predictor(mode: str):
|
|
233
|
+
if mode == "ABodyBuilder2":
|
|
234
|
+
from ImmuneBuilder import ABodyBuilder2
|
|
235
|
+
return ABodyBuilder2()
|
|
236
|
+
if mode == "NanoBodyBuilder2":
|
|
237
|
+
from ImmuneBuilder import NanoBodyBuilder2
|
|
238
|
+
return NanoBodyBuilder2()
|
|
239
|
+
raise ValueError(f"unknown mode: {mode}")
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _set_seed(seed: int) -> None:
|
|
243
|
+
import random
|
|
244
|
+
import numpy as np
|
|
245
|
+
random.seed(seed)
|
|
246
|
+
np.random.seed(seed)
|
|
247
|
+
try:
|
|
248
|
+
import torch
|
|
249
|
+
torch.manual_seed(seed)
|
|
250
|
+
if torch.cuda.is_available():
|
|
251
|
+
torch.cuda.manual_seed_all(seed)
|
|
252
|
+
except Exception:
|
|
253
|
+
pass
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _predict_one(predictor, mode: str, vh: str, vl: str | None):
|
|
257
|
+
sequences = {"H": vh}
|
|
258
|
+
if mode == "ABodyBuilder2" and vl:
|
|
259
|
+
sequences["L"] = vl
|
|
260
|
+
return predictor.predict(sequences)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _per_residue_records(antibody) -> list[dict]:
    """Build one error record per numbered residue of a predicted structure.

    Residues come from ``extract_numbered_residues(antibody)``. When the
    prediction exposes ``error_estimates``, the per-residue error is
    ``error_estimates.mean(0).sqrt()`` converted to a NumPy array and read by
    residue index. That is the ensemble disagreement in Å, in the flattened
    order of the numbered sequences (heavy chain, then light).

    Each record has the shape
        {"pos": ..., "chain": ..., "errorAngstroms": <float>, "_region": ...}
    where ``_region`` is an internal marker used for per-CDR aggregation and
    is dropped before JSON emission.

    NOTE(review): when the estimates are missing, cannot be converted, or are
    shorter than the residue list, the error is reported as 0.0 Å, which is
    indistinguishable from a perfectly confident residue. Confirm that this is
    the intended fallback.
    """
    numbered = extract_numbered_residues(antibody)

    errors = None
    if hasattr(antibody, "error_estimates"):
        try:
            errors = antibody.error_estimates.mean(0).sqrt().cpu().numpy()
        except Exception:
            # Best effort: fall back to "no estimates" on any conversion error.
            errors = None

    def error_for(index: int) -> float:
        if errors is not None and index < len(errors):
            return float(errors[index])
        return 0.0

    records: list[dict] = []
    for index, residue in enumerate(numbered):
        error_value = error_for(index)
        records.append(
            dict(
                pos=residue.pos,
                chain=residue.chain,
                errorAngstroms=error_value,
                _region=residue.region,
            )
        )
    return records
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _json_safe_per_residue(records: list[dict]) -> str:
|
|
302
|
+
cleaned = [
|
|
303
|
+
{"pos": r["pos"], "chain": r["chain"], "errorAngstroms": r["errorAngstroms"]}
|
|
304
|
+
for r in records
|
|
305
|
+
]
|
|
306
|
+
return json.dumps(cleaned, separators=(",", ":"))
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _format_number(value: float | None, digits: int = 3) -> str:
|
|
310
|
+
if value is None:
|
|
311
|
+
return ""
|
|
312
|
+
return f"{value:.{digits}f}"
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def _format_int(value: int | None) -> str:
|
|
316
|
+
return "" if value is None else str(value)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def _metric_value(result: RowResult, metric: str) -> float | None:
|
|
320
|
+
src = result.cdrh3 if metric == "cdrh3Mean" else result.mean_error
|
|
321
|
+
if not src:
|
|
322
|
+
return None
|
|
323
|
+
try:
|
|
324
|
+
return float(src)
|
|
325
|
+
except ValueError:
|
|
326
|
+
return None
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _build_summary(
|
|
330
|
+
results: list[RowResult],
|
|
331
|
+
metric: str,
|
|
332
|
+
threshold: float,
|
|
333
|
+
) -> dict:
|
|
334
|
+
"""Aggregate per-row stats into the schema the model + UI consume.
|
|
335
|
+
|
|
336
|
+
Schema is part of the contract between this script and the block model
|
|
337
|
+
(model.failureStats); keep the field names stable.
|
|
338
|
+
"""
|
|
339
|
+
by_failure: dict[str, int] = {}
|
|
340
|
+
by_warning: dict[str, int] = {}
|
|
341
|
+
metric_values: list[float] = []
|
|
342
|
+
confident = 0
|
|
343
|
+
succeeded = 0
|
|
344
|
+
|
|
345
|
+
for r in results:
|
|
346
|
+
if r.failure_reason:
|
|
347
|
+
by_failure[r.failure_reason] = by_failure.get(r.failure_reason, 0) + 1
|
|
348
|
+
else:
|
|
349
|
+
succeeded += 1
|
|
350
|
+
v = _metric_value(r, metric)
|
|
351
|
+
if v is not None:
|
|
352
|
+
metric_values.append(v)
|
|
353
|
+
if v <= threshold:
|
|
354
|
+
confident += 1
|
|
355
|
+
for w in r.warnings:
|
|
356
|
+
by_warning[w] = by_warning.get(w, 0) + 1
|
|
357
|
+
|
|
358
|
+
summary: dict = {
|
|
359
|
+
"totalRows": len(results),
|
|
360
|
+
"succeeded": succeeded,
|
|
361
|
+
"failed": len(results) - succeeded,
|
|
362
|
+
"byFailureReason": by_failure,
|
|
363
|
+
"byWarning": by_warning,
|
|
364
|
+
"metric": metric,
|
|
365
|
+
"thresholdAngstroms": threshold,
|
|
366
|
+
"confidentCount": confident,
|
|
367
|
+
}
|
|
368
|
+
if metric_values:
|
|
369
|
+
summary["metricMean"] = sum(metric_values) / len(metric_values)
|
|
370
|
+
summary["metricMin"] = min(metric_values)
|
|
371
|
+
summary["metricMax"] = max(metric_values)
|
|
372
|
+
return summary
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def process_batch(
    input_tsv: Path,
    pdb_dir: Path,
    manifest_tsv: Path,
    confidence_tsv: Path,
    summary_json: Path | None,
    mode: str,
    seed: int,
    metric: str,
    threshold: float,
) -> None:
    """Predict a structure for every row of *input_tsv* and write the batch outputs.

    For each row the heavy chain (plus the light chain in ``ABodyBuilder2``
    mode) is sanitized, predicted, saved as ``pdb_<row index>.pdb`` inside
    *pdb_dir*, and reduced to overall / per-CDR error means. Rows with the same
    sanitized sequence pair share one prediction. A row that fails
    sanitization, prediction or saving is recorded with a ``failureReason``
    and the batch continues.

    Args:
        input_tsv: Tab-separated input. The columns read are the key column
            (see ``pick_key_column``), ``clonotypeLabel``, ``heavyChain`` and
            ``lightChain``.
        pdb_dir: Directory receiving the PDB files (created if missing).
        manifest_tsv: Output TSV mapping the key to ``pdb_filename`` for every
            row that produced a PDB.
        confidence_tsv: Output TSV with one row per input row.
        summary_json: Optional path for the aggregate summary; skipped when
            ``None``.
        mode: ``ABodyBuilder2`` or ``NanoBodyBuilder2``.
        seed: Seed passed to ``_set_seed`` and recorded in each PDB.
        metric: Metric name used for the summary's confident count.
        threshold: Metric threshold (Å) for the confident count.

    Raises:
        ValueError: From ``_load_predictor`` when *mode* is unknown and the
            input has at least one row.
    """
    pdb_dir.mkdir(parents=True, exist_ok=True)
    manifest_tsv.parent.mkdir(parents=True, exist_ok=True)
    confidence_tsv.parent.mkdir(parents=True, exist_ok=True)

    # Pin UTF-8 instead of the locale default: the outputs contain non-ASCII
    # label text and must not depend on the container's locale.
    with open(input_tsv, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter="\t")
        fieldnames = reader.fieldnames or []
        # A file without a header has no columns to choose from; fall back to
        # the placeholder name so empty outputs can still be written.
        key_col = pick_key_column(fieldnames) if fieldnames else KEY_COLUMN_PLACEHOLDER
        rows = list(reader)

    ib_version = get_immunebuilder_version()
    block_version = get_block_version()

    _log(
        f"start mode={mode} rows={len(rows)} seed={seed} metric={metric} "
        f"threshold={threshold} immunebuilder={ib_version} block={block_version}"
    )

    if not rows:
        _log("no input rows; skipping ImmuneBuilder load and emitting empty outputs")

    _set_seed(seed)
    if rows:
        _log(f"loading {mode} ensemble (4 models)")
        predictor_t0 = time.time()
        predictor = _load_predictor(mode)
        _log(f"predictor ready in {time.time() - predictor_t0:.1f}s")
    else:
        predictor = None

    # Intra-batch dedup cache (R16). Key: (clean_vh, clean_vl).
    prediction_cache: dict[tuple[str, str], object] = {}
    results: list[RowResult] = []

    n = len(rows)
    # Roughly 20 progress lines for large batches, one per row for small ones.
    log_every = max(1, n // 20) if n > 20 else 1
    cache_hits = 0
    run_t0 = time.time()

    for row_idx, row in enumerate(rows):
        key = row.get(key_col, "")
        # Use the human-readable label (e.g. CDR3 sequence) for log lines
        # AND echo it into confidence.tsv as the `clonotypeLabel` column —
        # the workflow's xsv import surfaces that column as `pl7.app/label`
        # so the V3 structures table substitutes it into the row-axis cells.
        label_from_row = row.get("clonotypeLabel") or ""
        result = RowResult(clonotype_key=key, clonotype_label=label_from_row)
        prefix = f"[{row_idx + 1}/{n}] {label_from_row or key}"

        sanitized = sanitize_pair(
            row.get("heavyChain", ""),
            row.get("lightChain", "") if mode == "ABodyBuilder2" else None,
            mode,
        )
        result.warnings.extend(sanitized.warnings)

        if not sanitized.success:
            result.failure_reason = sanitized.failure_reason
            _log(f"{prefix} FAIL sanitize reason={sanitized.failure_reason}")
            results.append(result)
            continue

        if sanitized.warnings:
            _log(f"{prefix} warning {','.join(sanitized.warnings)}")

        cache_key = (sanitized.vh, sanitized.vl)
        antibody = prediction_cache.get(cache_key)
        if antibody is None:
            try:
                pred_t0 = time.time()
                antibody = _predict_one(predictor, mode, sanitized.vh, sanitized.vl)
                pred_dt = time.time() - pred_t0
                prediction_cache[cache_key] = antibody
                if (row_idx + 1) % log_every == 0 or row_idx == 0:
                    _log(f"{prefix} predicted in {pred_dt:.1f}s")
            except Exception as exc:  # noqa: BLE001
                # One bad sequence must not abort the batch: record and move on.
                _log(f"{prefix} FAIL ImmuneBuilder {type(exc).__name__}: {exc}")
                traceback.print_exc(file=sys.stderr)
                result.failure_reason = f"immunebuilder_exception:{type(exc).__name__}"
                results.append(result)
                continue
        else:
            cache_hits += 1
            _log(f"{prefix} dedup-cache hit (skipping ImmuneBuilder call)")

        # The predictor's own PDB goes to a temporary `_raw_` file; the
        # augmented version is what ends up under the final name.
        pdb_filename = indexed_filename(row_idx)
        raw_pdb_path = pdb_dir / ("_raw_" + pdb_filename)
        final_pdb_path = pdb_dir / pdb_filename
        try:
            antibody.save(str(raw_pdb_path))
            augment_pdb(
                raw_pdb_path,
                final_pdb_path,
                mode=mode,
                immunebuilder_version=ib_version,
                torch_seed=seed,
                block_version=block_version,
                numbering_scheme="imgt",
            )
            try:
                raw_pdb_path.unlink()
            except FileNotFoundError:
                pass
        except Exception as exc:  # noqa: BLE001
            _log(f"{prefix} FAIL save/augment {type(exc).__name__}: {exc}")
            traceback.print_exc(file=sys.stderr)
            result.failure_reason = f"immunebuilder_exception:{type(exc).__name__}"
            results.append(result)
            continue

        residues = _per_residue_records(antibody)
        per_res_json = _json_safe_per_residue(residues)
        all_err = [r["errorAngstroms"] for r in residues]
        mean_err = _mean(all_err)
        cdrh1 = _mean(_region_errors(residues, "H", "CDR1"))
        cdrh2 = _mean(_region_errors(residues, "H", "CDR2"))
        cdrh3 = _mean(_region_errors(residues, "H", "CDR3"))

        result.per_residue_json = per_res_json
        result.mean_error = _format_number(mean_err)
        result.cdrh1 = _format_number(cdrh1)
        result.cdrh2 = _format_number(cdrh2)
        result.cdrh3 = _format_number(cdrh3)

        if mode == "ABodyBuilder2":
            result.cdrl1 = _format_number(_mean(_region_errors(residues, "L", "CDR1")))
            result.cdrl2 = _format_number(_mean(_region_errors(residues, "L", "CDR2")))
            result.cdrl3 = _format_number(_mean(_region_errors(residues, "L", "CDR3")))

        numbered = extract_numbered_residues(antibody)
        nb = imgt_cdrh3_length(numbered)
        result.cdrh3_len = _format_int(nb)
        # `_format_int` accepts a missing length, so guard the comparison too:
        # `None >= 20` would raise TypeError and abort the whole batch.
        if nb is not None and nb >= 20:
            result.warnings.append("long_cdrh3")
        if mode == "NanoBodyBuilder2" and not vhh_hallmarks_present(numbered):
            result.warnings.append("vhh_hallmarks_missing")
        result.pdb_filename = pdb_filename

        if (row_idx + 1) % log_every == 0 or row_idx == 0 or row_idx == n - 1:
            _log(
                f"{prefix} OK mean={result.mean_error}Å cdrh3={result.cdrh3}Å "
                f"cdrh3Length={result.cdrh3_len}"
            )

        results.append(result)

    # Preserve the input key-column header in our outputs. The batch
    # orchestrator hands us TSVs whose key column is named after the axis spec
    # (e.g. `pl7.app/vdj/clonotypeKey`); when batches are concatenated, the
    # downstream xsv import expects that same header back.
    manifest_fields = [key_col if f == KEY_COLUMN_PLACEHOLDER else f for f in MANIFEST_FIELDS]
    confidence_fields = build_confidence_fields(key_col)

    with open(manifest_tsv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f, fieldnames=manifest_fields, delimiter="\t", lineterminator="\n"
        )
        writer.writeheader()
        for r in results:
            if r.pdb_filename:
                writer.writerow({key_col: r.clonotype_key, "pdb_filename": r.pdb_filename})

    with open(confidence_tsv, "w", newline="", encoding="utf-8") as f:
        # Same "\n" terminator as the manifest; csv's default would be "\r\n".
        writer = csv.DictWriter(
            f, fieldnames=confidence_fields, delimiter="\t", lineterminator="\n"
        )
        writer.writeheader()
        for r in results:
            writer.writerow(r.to_tsv_row(key_col))

    summary = _build_summary(results, metric, threshold)
    if summary_json is not None:
        summary_json.parent.mkdir(parents=True, exist_ok=True)
        with open(summary_json, "w", encoding="utf-8") as f:
            json.dump(summary, f)

    elapsed = time.time() - run_t0 if rows else 0.0
    _log(
        f"done total={summary['totalRows']} succeeded={summary['succeeded']} "
        f"failed={summary['failed']} confident={summary['confidentCount']} "
        f"cache_hits={cache_hits} elapsed={elapsed:.1f}s"
    )
    # Most frequent failure reasons / warnings first.
    if summary["byFailureReason"]:
        for reason, n_fail in sorted(
            summary["byFailureReason"].items(), key=lambda kv: -kv[1]
        ):
            _log(f" failure: {n_fail} × {reason}")
    if summary["byWarning"]:
        for warning, n_warn in sorted(
            summary["byWarning"].items(), key=lambda kv: -kv[1]
        ):
            _log(f" warning: {n_warn} × {warning}")
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def warmup(mode: str, sentinel: Path | None) -> None:
    """Force the ImmuneBuilder weight download into a known cache location.

    Run as a single pre-step before the parallel batch fan-out. Avoids the
    race where multiple batch containers download the same weight files into
    a shared cache dir simultaneously, producing partial / corrupt files.

    Args:
        mode: Predictor to load (``ABodyBuilder2`` or ``NanoBodyBuilder2``).
        sentinel: Optional file written after the predictor loaded; it records
            the mode and the ImmuneBuilder version. Missing parent directories
            are created.

    Raises:
        ValueError: From ``_load_predictor`` when *mode* is unknown.
    """
    _log(f"warmup mode={mode} loading predictor (this may download weights)")
    _load_predictor(mode)
    if sentinel is not None:
        # Create the parent like every other output path does, so a finished
        # download is not followed by a FileNotFoundError on the sentinel.
        sentinel.parent.mkdir(parents=True, exist_ok=True)
        sentinel.write_text(
            f"mode={mode}\nimmunebuilder_version={get_immunebuilder_version()}\n",
            encoding="utf-8",
        )
    _log("warmup OK")
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def main() -> None:
    """Parse the command line and run either the warmup step or one batch.

    With ``--warmup`` only ``--mode`` and ``--sentinel`` are used. Otherwise
    ``--input``, ``--output-dir``, ``--manifest`` and ``--confidence`` must all
    be given, or the parser exits with a usage error.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--mode", choices=["ABodyBuilder2", "NanoBodyBuilder2"], required=True)
    add(
        "--warmup",
        action="store_true",
        help="Pre-download model weights and exit. --input/--output-dir/--manifest/--confidence are not used.",
    )
    add(
        "--sentinel",
        default=None,
        help="In --warmup mode, path to a sentinel file written on success.",
    )
    add("--input", help="Batch TSV with clonotypeKey, heavyChain[, lightChain]")
    add("--output-dir", help="Directory for per-clonotype PDB files")
    add("--manifest", help="Path to manifest.tsv")
    add("--confidence", help="Path to confidence.tsv")
    add("--summary", default=None, help="Path to summary.json (aggregate stats for the model)")
    add("--seed", type=int, default=42)
    add("--metric", choices=["cdrh3Mean", "overallMean"], default="cdrh3Mean")
    add(
        "--threshold",
        type=float,
        default=2.5,
        help="Confidence threshold (Å) used to derive confidentCount in summary.json",
    )
    args = parser.parse_args()

    if args.warmup:
        sentinel = Path(args.sentinel) if args.sentinel else None
        warmup(args.mode, sentinel)
        return

    # These options are optional for argparse (warmup does not need them), so
    # their presence is checked by hand for a batch run.
    batch_options = {
        "--input": args.input,
        "--output-dir": args.output_dir,
        "--manifest": args.manifest,
        "--confidence": args.confidence,
    }
    missing = [flag for flag, value in batch_options.items() if not value]
    if missing:
        parser.error(f"the following arguments are required when not in --warmup mode: {', '.join(missing)}")

    summary_path = Path(args.summary) if args.summary else None
    process_batch(
        input_tsv=Path(args.input),
        pdb_dir=Path(args.output_dir),
        manifest_tsv=Path(args.manifest),
        confidence_tsv=Path(args.confidence),
        summary_json=summary_path,
        mode=args.mode,
        seed=args.seed,
        metric=args.metric,
        threshold=args.threshold,
    )


if __name__ == "__main__":
    main()
|