@platforma-open/milaboratories.3d-structure-prediction.software 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,647 @@
1
+ """ImmuneBuilder batch runner for the Platforma 3D Structure Prediction block.
2
+
3
+ Reads a batch TSV of clonotypes and predicts structures via ABodyBuilder2 or
4
+ NanoBodyBuilder2 (spec R22). Emits:
5
+
6
+ - Per-clonotype PDB files named `pdb_<NNNNN>.pdb`, i.e. the zero-padded batch row index (see `indexed_filename`; R30).
7
+ - `manifest.tsv` : (clonotypeKey, pdb_filename).
8
+ - `confidence.tsv`: aggregate + per-residue confidence (Å error, R32-R36)
9
+ plus failureReason (R40) and warning columns.
10
+
11
+ Dependencies (ImmuneBuilder, torch) ride the venv that pl-pkg's install-deps
12
+ creates. ANARCI and pdbfixer are not on PyPI; the atls runenv builds them
13
+ from source and stages them in the runenv's site-packages. The SDK's venv
14
+ is created without --system-site-packages, so we bootstrap the runenv's
15
+ site-packages onto sys.path here before any project-level imports run.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import os as _os
21
+ import sys as _sys
22
+ import sysconfig as _sysconfig
23
+
24
def _prepend_path_entry(entry: str, path_env: str) -> str:
    """Return a PATH-style string with ``entry`` as its first component.

    ``path_env`` is returned unchanged when ``entry`` is already one of its
    ``os.pathsep``-separated components. An empty ``path_env`` yields just
    ``entry``: joining would leave a trailing zero-length component, which
    POSIX PATH resolution treats as the current working directory.
    """
    if entry in path_env.split(_os.pathsep):
        return path_env
    return entry + _os.pathsep + path_env if path_env else entry


# Runenv bootstrap. The SDK's venv is created without --system-site-packages,
# so packages the atls runenv stages outside the venv (ANARCI, pdbfixer, the
# HMMER binaries) are not visible by default. This code treats PYTHONHOME as
# the runenv root and, when it is set, exposes that runenv's site-packages and
# bin/ to this process before any project-level imports run.
_runenv_root = _os.environ.get("PYTHONHOME")
if _runenv_root:
    # Pick up packages staged via runenv-python-builder's `copyFiles` directive
    # (e.g. `{site-packages}/anarci`, `{site-packages}/pdbfixer`).
    _py_ver_short = f"python{_sys.version_info.major}.{_sys.version_info.minor}"
    _candidates = [
        _os.path.join(_runenv_root, "lib", _py_ver_short, "site-packages"),
        _os.path.join(_runenv_root, "Lib", "site-packages"),  # Windows layout
    ]
    _platlib = _sysconfig.get_paths().get("platlib")
    if _platlib:
        _candidates.append(_platlib)
    # Appended (not prepended) so packages installed in the venv keep precedence.
    for _candidate in _candidates:
        if _os.path.isdir(_candidate) and _candidate not in _sys.path:
            _sys.path.append(_candidate)

    # ANARCI shells out to `hmmscan` (HMMER); the binary ships in the runenv's
    # bin/ via copyFiles, but only the venv's bin is on PATH by default.
    _runenv_bin = _os.path.join(_runenv_root, "bin")
    if _os.path.isdir(_runenv_bin):
        _os.environ["PATH"] = _prepend_path_entry(
            _runenv_bin, _os.environ.get("PATH", "")
        )
47
+
48
+ import argparse
49
+ import csv
50
+ import hashlib
51
+ import json
52
+ import os
53
+ import sys
54
+ import time
55
+ import traceback
56
+ from dataclasses import asdict, dataclass, field
57
+ from datetime import datetime, timezone
58
+ from pathlib import Path
59
+
60
+
61
def _log(message: str) -> None:
    """Write ``message`` to stderr as one UTC-timestamped line, flushed at once.

    The workflow merges stderr into stdout (`printErrStreamToStdout`) and
    saves that stream (`saveStdoutStream`) for the UI's `PlLogView`. The
    per-line timestamp lets users follow progress through long batches.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    line = "[" + stamp + "] " + message
    print(line, file=sys.stderr, flush=True)
70
+
71
+ from numbering import cdrh3_length as imgt_cdrh3_length
72
+ from numbering import extract_numbered_residues, vhh_hallmarks_present
73
+ from pdb_writer import augment_pdb
74
+ from sanitize import sanitize_pair
75
+
76
# The key column's real header is whatever the input TSV calls it: the batch
# orchestrator names it after the axis spec (e.g. `pl7.app/vdj/clonotypeKey`)
# and expects that same header back. `clonotypeKey` is only a stand-in that
# build_confidence_fields() swaps for the real name.
KEY_COLUMN_PLACEHOLDER = "clonotypeKey"

# Column order of confidence.tsv (DictWriter emits columns in this order), so
# keep it stable.
CONFIDENCE_FIELDS = (
    # Identity columns.
    [KEY_COLUMN_PLACEHOLDER, "clonotypeLabel"]
    # Mean predicted error (Å): whole structure, then each heavy/light CDR.
    + [
        "meanError",
        "cdrh1Error",
        "cdrh2Error",
        "cdrh3Error",
        "cdrl1Error",
        "cdrl2Error",
        "cdrl3Error",
    ]
    # Per-residue error JSON and the CDR-H3 length.
    + ["perResidueError", "cdrh3Length"]
    # Machine-readable codes, each followed by its display text.
    + ["failureReason", "failureReasonText", "warning", "warningText"]
)


def build_confidence_fields(key_col: str) -> list[str]:
    """Return a new copy of CONFIDENCE_FIELDS with the placeholder renamed to ``key_col``."""
    fields = list(CONFIDENCE_FIELDS)
    for position, name in enumerate(fields):
        if name == KEY_COLUMN_PLACEHOLDER:
            fields[position] = key_col
    return fields
103
+
104
+
105
# Code -> display-label maps. confidence.tsv carries both forms: the raw code
# columns (failureReason / warning), which stay available for grouping by enum
# value, and the *Text columns with these labels, which users see in the table.
FAILURE_REASON_LABELS: dict[str, str] = {
    "empty_sequence": "Empty sequence",
    "stop_codon_mid_sequence": "Stop codon in sequence",
    "non_standard_aa_only_after_strip": "No standard amino acids after cleanup",
    "non_standard_aa_residue": "Non-standard amino acid residue",
    "length_out_of_range": "Length outside expected VH/VL range",
    "light_chain_missing_in_paired_mode":
        "Light chain missing — switch to NanoBodyBuilder2 or pick a light chain",
}

WARNING_LABELS: dict[str, str] = {
    "probable_signal_peptide": "Possible N-terminal signal peptide",
    "long_cdrh3": "Long CDR-H3 (≥20 aa) — confidence may be reduced",
    "vhh_hallmarks_missing": "VHH hallmark residues not detected",
}


def _failure_reason_label(code: str) -> str:
    """Return the display label for a failure code ("" when there is no failure).

    Known codes map through FAILURE_REASON_LABELS. The structured codes
    `immunebuilder_exception:<detail>` and `unknown_mode:<detail>` keep their
    detail in the label for triage. Any other code is returned unchanged.
    """
    if not code:
        return ""
    known = FAILURE_REASON_LABELS.get(code)
    if known is not None:
        return known
    prefix, colon, detail = code.partition(":")
    structured = {
        "immunebuilder_exception": "ImmuneBuilder error: ",
        "unknown_mode": "Unknown prediction mode: ",
    }
    if colon and prefix in structured:
        return structured[prefix] + detail
    return code


def _warning_label(code: str) -> str:
    """Return the display label for a warning code, or the code itself if unknown."""
    return WARNING_LABELS[code] if code in WARNING_LABELS else code
141
+
142
# Column order of manifest.tsv. The placeholder is replaced by the input's real
# key-column header when the manifest is written.
MANIFEST_FIELDS = [
    KEY_COLUMN_PLACEHOLDER,
    "pdb_filename",
]
143
+
144
+
145
def indexed_filename(row_idx: int) -> str:
    """Return the PDB filename for batch row ``row_idx``: ``pdb_`` + index zero-padded to 5 digits."""
    return "pdb_" + format(row_idx, "05d") + ".pdb"


def sha1_filename(key: str) -> str:
    """Return ``<hex SHA-1 of key encoded as UTF-8>.pdb``.

    Not used for the PDB names this script writes (see indexed_filename).
    It is retained so downstream blocks can SHA-1 clonotype keys and
    cross-reference them.
    """
    digest = hashlib.sha1(key.encode("utf-8")).hexdigest()
    return f"{digest}.pdb"
152
+
153
+
154
def get_immunebuilder_version() -> str:
    """Return the installed ImmuneBuilder distribution version, or "unknown".

    Best effort: any failure to import `importlib.metadata` or to look the
    package up yields "unknown".
    """
    try:
        from importlib import metadata
        found = metadata.version("ImmuneBuilder")
    except Exception:
        found = "unknown"
    return found


def get_block_version() -> str:
    """Return $BLOCK_VERSION, or "unknown" when the variable is not set."""
    value = os.environ.get("BLOCK_VERSION")
    return "unknown" if value is None else value
164
+
165
+
166
def pick_key_column(fieldnames: list[str], default: str = "clonotypeKey") -> str:
    """Choose the clonotype-key column from a TSV header.

    The orchestrator may name the key column after its axis spec (e.g.
    `pl7.app/vdj/clonotypeKey`), so any header ending in "clonotypekey"
    (case-insensitive) matches, and the first match wins. Otherwise the first
    column is assumed to be the key.

    Args:
        fieldnames: Header names in file order. May be empty, e.g. when the
            input file has no header line (DictReader.fieldnames is None).
        default: Returned for an empty header; the same value as
            KEY_COLUMN_PLACEHOLDER.

    Returns:
        The chosen column name.
    """
    for name in fieldnames:
        # endswith() also covers an exact (case-insensitive) match.
        if name.lower().endswith("clonotypekey"):
            return name
    # Previously fieldnames[0] raised IndexError for an empty header.
    return fieldnames[0] if fieldnames else default
171
+
172
+
173
@dataclass
class RowResult:
    """Outcome of one input row, as written to confidence.tsv and manifest.tsv.

    Numeric results are stored pre-formatted as strings ("" when unavailable),
    so they can be written to the TSV verbatim.
    """

    clonotype_key: str
    clonotype_label: str = ""
    mean_error: str = ""  # mean over all per-residue errors
    cdrh1: str = ""
    cdrh2: str = ""
    cdrh3: str = ""
    cdrl1: str = ""  # light-chain CDRs are only filled in ABodyBuilder2 mode
    cdrl2: str = ""
    cdrl3: str = ""
    per_residue_json: str = ""
    cdrh3_len: str = ""
    failure_reason: str = ""  # "" means the row succeeded
    warnings: list[str] = field(default_factory=list)  # warning codes, in order raised
    pdb_filename: str = ""  # set only for rows whose PDB was written

    @property
    def warning_str(self) -> str:
        """Warning codes joined with ";" (the raw `warning` column)."""
        return ";".join(self.warnings)

    @property
    def warning_text(self) -> str:
        """Human-readable warning labels joined with "; " (the `warningText` column)."""
        return "; ".join(map(_warning_label, self.warnings))

    def to_tsv_row(self, key_col: str) -> dict[str, str]:
        """Render this result as one confidence.tsv row, keyed by column name.

        ``key_col`` is the input's real key-column header, used in place of
        the `clonotypeKey` placeholder.
        """
        row = {key_col: self.clonotype_key, "clonotypeLabel": self.clonotype_label}
        row.update(
            meanError=self.mean_error,
            cdrh1Error=self.cdrh1,
            cdrh2Error=self.cdrh2,
            cdrh3Error=self.cdrh3,
            cdrl1Error=self.cdrl1,
            cdrl2Error=self.cdrl2,
            cdrl3Error=self.cdrl3,
            perResidueError=self.per_residue_json,
            cdrh3Length=self.cdrh3_len,
            failureReason=self.failure_reason,
            failureReasonText=_failure_reason_label(self.failure_reason),
            warning=self.warning_str,
            warningText=self.warning_text,
        )
        return row
216
+
217
+
218
def _mean(values: list[float | None]) -> float | None:
    """Arithmetic mean of the non-None entries of ``values``; None if there are none."""
    present = [v for v in values if v is not None]
    return None if not present else sum(present) / len(present)


def _region_errors(per_residue, chain: str, cdr_name: str) -> list[float]:
    """Collect `errorAngstroms` of records on ``chain`` tagged with region "CDR<n>".

    Only the last character of ``cdr_name`` is used (callers pass "CDR1",
    "CDR2" or "CDR3"), so the wanted `_region` tag is "CDR" plus that digit.
    """
    wanted = "CDR" + cdr_name[-1]
    picked = []
    for rec in per_residue:
        if rec["chain"] != chain:
            continue
        if rec.get("_region") == wanted:
            picked.append(rec["errorAngstroms"])
    return picked
230
+
231
+
232
def _load_predictor(mode: str):
    """Instantiate the ImmuneBuilder predictor class named by ``mode``.

    The mode is validated before ImmuneBuilder is imported.

    Raises:
        ValueError: if ``mode`` is neither "ABodyBuilder2" nor "NanoBodyBuilder2".
    """
    if mode not in ("ABodyBuilder2", "NanoBodyBuilder2"):
        raise ValueError(f"unknown mode: {mode}")
    if mode == "NanoBodyBuilder2":
        from ImmuneBuilder import NanoBodyBuilder2 as predictor_cls
    else:
        from ImmuneBuilder import ABodyBuilder2 as predictor_cls
    return predictor_cls()
240
+
241
+
242
def _set_seed(seed: int) -> None:
    """Seed Python's `random`, NumPy's global RNG and, best effort, torch with ``seed``.

    torch (including every CUDA device when CUDA is available) is seeded only
    if it imports and seeds cleanly; any exception there is ignored.
    """
    import random
    import numpy as np

    for seeder in (random.seed, np.random.seed):
        seeder(seed)
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        # Deliberately best effort: a missing or broken torch must not stop the run.
        pass
254
+
255
+
256
def _predict_one(predictor, mode: str, vh: str, vl: str | None):
    """Run ``predictor.predict`` on one antibody and return its result.

    The heavy chain is always sent as "H". The light chain is added as "L"
    only in ABodyBuilder2 mode and only when ``vl`` is non-empty.
    """
    paired = mode == "ABodyBuilder2" and bool(vl)
    chains = {"H": vh, "L": vl} if paired else {"H": vh}
    return predictor.predict(chains)
261
+
262
+
263
def _per_residue_records(antibody) -> list[dict]:
    """Build one error record per numbered residue of ``antibody``.

    Errors come from ImmuneBuilder's ensemble disagreement,
    `antibody.error_estimates.mean(0).sqrt()`, moved to a NumPy array. This is
    the per-residue RMSD in Å that ImmuneBuilder also writes to the PDB
    B-factor column (`ImmuneBuilder.util.add_errors_as_bfactors`), ordered
    like the flattened `numbered_sequences` (heavy chain, then light).

    Each record follows spec R34,
        {"pos": <str>, "chain": "H" | "L", "errorAngstroms": <float>},
    plus an internal `_region` tag that `_region_errors` uses for per-CDR
    aggregation and `_json_safe_per_residue` drops before emission.
    """
    residues = extract_numbered_residues(antibody)

    estimates = getattr(antibody, "error_estimates", None)
    errors = None
    if estimates is not None:
        try:
            errors = estimates.mean(0).sqrt().cpu().numpy()
        except Exception:
            # Best effort: fall back to the default error value below.
            errors = None

    records: list[dict] = []
    for idx, res in enumerate(residues):
        # NOTE(review): residues without an estimate get 0.0 Å, which reads as
        # a perfectly confident residue and pulls the means down; consider
        # None (JSON null) if downstream consumers of perResidueError allow it.
        has_error = errors is not None and idx < len(errors)
        err_val = float(errors[idx]) if has_error else 0.0
        records.append(
            {
                "pos": res.pos,
                "chain": res.chain,
                "errorAngstroms": err_val,
                "_region": res.region,
            }
        )
    return records
299
+
300
+
301
def _json_safe_per_residue(records: list[dict]) -> str:
    """Serialize records as compact JSON, keeping only pos / chain / errorAngstroms.

    Internal keys such as `_region` are dropped.
    """
    public_keys = ("pos", "chain", "errorAngstroms")
    payload = [{k: rec[k] for k in public_keys} for rec in records]
    return json.dumps(payload, separators=(",", ":"))


def _format_number(value: float | None, digits: int = 3) -> str:
    """Fixed-point string with ``digits`` decimals, or "" for None."""
    return "" if value is None else format(value, f".{digits}f")


def _format_int(value: int | None) -> str:
    """Decimal string of ``value``, or "" for None."""
    if value is None:
        return ""
    return str(value)
317
+
318
+
319
def _metric_value(result: RowResult, metric: str) -> float | None:
    """Parse the row's summary metric: CDR-H3 mean for "cdrh3Mean", else the overall mean.

    Returns None when the chosen field is empty or not parseable as a float.
    """
    raw = result.cdrh3 if metric == "cdrh3Mean" else result.mean_error
    try:
        return float(raw) if raw else None
    except ValueError:
        return None
327
+
328
+
329
def _build_summary(
    results: list[RowResult],
    metric: str,
    threshold: float,
) -> dict:
    """Aggregate per-row outcomes into the summary.json payload.

    The schema is part of the contract between this script and the block
    model (model.failureStats), so keep the field names stable. Keys:
    totalRows, succeeded, failed, byFailureReason (code -> count), byWarning
    (code -> count, over all rows including failed ones), metric,
    thresholdAngstroms, and confidentCount (succeeded rows whose metric value
    is <= threshold). metricMean / metricMin / metricMax are added only when
    at least one succeeded row has a metric value.
    """

    def tally(codes) -> dict[str, int]:
        counts: dict[str, int] = {}
        for code in codes:
            counts[code] = counts.get(code, 0) + 1
        return counts

    succeeded_rows = [r for r in results if not r.failure_reason]
    scores = [
        v
        for v in (_metric_value(r, metric) for r in succeeded_rows)
        if v is not None
    ]

    summary: dict = {
        "totalRows": len(results),
        "succeeded": len(succeeded_rows),
        "failed": len(results) - len(succeeded_rows),
        "byFailureReason": tally(r.failure_reason for r in results if r.failure_reason),
        "byWarning": tally(w for r in results for w in r.warnings),
        "metric": metric,
        "thresholdAngstroms": threshold,
        "confidentCount": sum(1 for v in scores if v <= threshold),
    }
    if scores:
        summary["metricMean"] = sum(scores) / len(scores)
        summary["metricMin"] = min(scores)
        summary["metricMax"] = max(scores)
    return summary
373
+
374
+
375
def _read_batch(input_tsv: Path) -> tuple[str, list[dict[str, str]]]:
    """Read a batch TSV and return (key-column header, data rows as dicts).

    A file without a header line yields KEY_COLUMN_PLACEHOLDER and no rows.
    """
    with open(input_tsv, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f, delimiter="\t")
        fieldnames = list(reader.fieldnames or [])
        key_col = pick_key_column(fieldnames) if fieldnames else KEY_COLUMN_PLACEHOLDER
        rows = list(reader)
    return key_col, rows


def _unlink_quietly(path: Path) -> None:
    """Best-effort delete of ``path``; a missing file or any OSError is ignored."""
    try:
        path.unlink()
    except OSError:
        pass


def _write_pdb(
    antibody,
    pdb_dir: Path,
    pdb_filename: str,
    *,
    mode: str,
    ib_version: str,
    seed: int,
    block_version: str,
) -> None:
    """Save ``antibody`` and write the augmented PDB to ``pdb_dir / pdb_filename``.

    ImmuneBuilder's raw output goes to a `_raw_`-prefixed temp file that is
    always removed afterwards. On failure the possibly-partial final file is
    removed too, so no stray files are left in the output directory.
    Exceptions propagate to the caller.
    """
    raw_path = pdb_dir / ("_raw_" + pdb_filename)
    final_path = pdb_dir / pdb_filename
    try:
        antibody.save(str(raw_path))
        augment_pdb(
            raw_path,
            final_path,
            mode=mode,
            immunebuilder_version=ib_version,
            torch_seed=seed,
            block_version=block_version,
            numbering_scheme="imgt",
        )
    except Exception:
        _unlink_quietly(final_path)
        raise
    finally:
        _unlink_quietly(raw_path)


def _fill_confidence(result: RowResult, antibody, mode: str) -> None:
    """Populate the error, CDR-H3-length and post-prediction warning fields of ``result``."""
    residues = _per_residue_records(antibody)
    result.per_residue_json = _json_safe_per_residue(residues)
    result.mean_error = _format_number(_mean([r["errorAngstroms"] for r in residues]))
    result.cdrh1 = _format_number(_mean(_region_errors(residues, "H", "CDR1")))
    result.cdrh2 = _format_number(_mean(_region_errors(residues, "H", "CDR2")))
    result.cdrh3 = _format_number(_mean(_region_errors(residues, "H", "CDR3")))
    if mode == "ABodyBuilder2":
        result.cdrl1 = _format_number(_mean(_region_errors(residues, "L", "CDR1")))
        result.cdrl2 = _format_number(_mean(_region_errors(residues, "L", "CDR2")))
        result.cdrl3 = _format_number(_mean(_region_errors(residues, "L", "CDR3")))

    numbered = extract_numbered_residues(antibody)
    cdrh3_len = imgt_cdrh3_length(numbered)
    result.cdrh3_len = _format_int(cdrh3_len)
    # _format_int accepts None, so presumably the length can be unknown.
    # Comparing None with 20 would raise TypeError, so only test a real length.
    if cdrh3_len is not None and cdrh3_len >= 20:
        result.warnings.append("long_cdrh3")
    if mode == "NanoBodyBuilder2" and not vhh_hallmarks_present(numbered):
        result.warnings.append("vhh_hallmarks_missing")


def _record_failure(result: RowResult, prefix: str, stage: str, exc: BaseException) -> RowResult:
    """Log ``exc`` with its traceback and tag ``result`` as `immunebuilder_exception:<Type>`.

    Call this from inside the `except` block so the traceback is still available.
    """
    _log(f"{prefix} FAIL {stage} {type(exc).__name__}: {exc}")
    traceback.print_exc(file=sys.stderr)
    result.failure_reason = f"immunebuilder_exception:{type(exc).__name__}"
    return result


def _write_tsv(path: Path, fieldnames: list[str], rows) -> None:
    """Write ``rows`` (dicts) as UTF-8, tab-separated text with a header and "\\n" line endings."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(
            f, fieldnames=fieldnames, delimiter="\t", lineterminator="\n"
        )
        writer.writeheader()
        writer.writerows(rows)


def _log_run_summary(summary: dict, cache_hits: int, elapsed: float) -> None:
    """Log the final totals, then failure and warning counts, most frequent first."""
    _log(
        f"done total={summary['totalRows']} succeeded={summary['succeeded']} "
        f"failed={summary['failed']} confident={summary['confidentCount']} "
        f"cache_hits={cache_hits} elapsed={elapsed:.1f}s"
    )
    for reason, n_fail in sorted(summary["byFailureReason"].items(), key=lambda kv: -kv[1]):
        _log(f" failure: {n_fail} × {reason}")
    for warning, n_warn in sorted(summary["byWarning"].items(), key=lambda kv: -kv[1]):
        _log(f" warning: {n_warn} × {warning}")


def process_batch(
    input_tsv: Path,
    pdb_dir: Path,
    manifest_tsv: Path,
    confidence_tsv: Path,
    summary_json: Path | None,
    mode: str,
    seed: int,
    metric: str,
    threshold: float,
) -> None:
    """Predict structures for every row of one batch TSV and write all outputs.

    Each row goes through four stages:
    1. Sanitize the chains (`sanitize_pair`).
    2. Predict, reusing an earlier prediction for an identical sanitized pair.
    3. Write the PDB.
    4. Compute error and confidence fields.

    A failure in any per-row stage marks only that row (failureReason), and
    the batch continues.

    Args:
        input_tsv: Batch TSV with a clonotype-key column, `heavyChain`, and
            optional `lightChain` / `clonotypeLabel` columns.
        pdb_dir: Directory that receives the `pdb_<row>.pdb` files.
        manifest_tsv: Output (key, pdb_filename) table; successful rows only.
        confidence_tsv: Output confidence table; one line per input row.
        summary_json: Optional output path for the aggregate stats.
        mode: "ABodyBuilder2" (heavy + light) or "NanoBodyBuilder2" (heavy only).
        seed: RNG seed applied before the predictor is loaded.
        metric: Summary metric passed to `_build_summary`.
        threshold: Confidence threshold (Å) passed to `_build_summary`.
    """
    pdb_dir.mkdir(parents=True, exist_ok=True)
    manifest_tsv.parent.mkdir(parents=True, exist_ok=True)
    confidence_tsv.parent.mkdir(parents=True, exist_ok=True)

    key_col, rows = _read_batch(input_tsv)

    ib_version = get_immunebuilder_version()
    block_version = get_block_version()
    _log(
        f"start mode={mode} rows={len(rows)} seed={seed} metric={metric} "
        f"threshold={threshold} immunebuilder={ib_version} block={block_version}"
    )
    if not rows:
        _log("no input rows; skipping ImmuneBuilder load and emitting empty outputs")

    _set_seed(seed)
    predictor = None
    if rows:
        _log(f"loading {mode} ensemble (4 models)")
        predictor_t0 = time.time()
        predictor = _load_predictor(mode)
        _log(f"predictor ready in {time.time() - predictor_t0:.1f}s")

    # Intra-batch dedup cache (R16), keyed by the sanitized (VH, VL) pair.
    prediction_cache: dict[tuple[str, str | None], object] = {}
    results: list[RowResult] = []
    n = len(rows)
    log_every = max(1, n // 20) if n > 20 else 1
    cache_hits = 0
    run_t0 = time.time()

    for row_idx, row in enumerate(rows):
        # DictReader fills short rows with None, hence `or ""` on each lookup.
        key = row.get(key_col) or ""
        # The human-readable label (e.g. CDR3 sequence) is used in log lines
        # and echoed into confidence.tsv as `clonotypeLabel`. The workflow's
        # xsv import surfaces that column as `pl7.app/label`, so the V3
        # structures table substitutes it into the row-axis cells.
        label = row.get("clonotypeLabel") or ""
        result = RowResult(clonotype_key=key, clonotype_label=label)
        prefix = f"[{row_idx + 1}/{n}] {label or key}"
        periodic_log = (row_idx + 1) % log_every == 0 or row_idx == 0

        sanitized = sanitize_pair(
            row.get("heavyChain") or "",
            (row.get("lightChain") or "") if mode == "ABodyBuilder2" else None,
            mode,
        )
        result.warnings.extend(sanitized.warnings)
        if not sanitized.success:
            result.failure_reason = sanitized.failure_reason
            _log(f"{prefix} FAIL sanitize reason={sanitized.failure_reason}")
            results.append(result)
            continue
        if sanitized.warnings:
            _log(f"{prefix} warning {','.join(sanitized.warnings)}")

        cache_key = (sanitized.vh, sanitized.vl)
        antibody = prediction_cache.get(cache_key)
        if antibody is None:
            pred_t0 = time.time()
            try:
                antibody = _predict_one(predictor, mode, sanitized.vh, sanitized.vl)
            except Exception as exc:  # noqa: BLE001 - one bad row must not abort the batch
                results.append(_record_failure(result, prefix, "ImmuneBuilder", exc))
                continue
            prediction_cache[cache_key] = antibody
            if periodic_log:
                _log(f"{prefix} predicted in {time.time() - pred_t0:.1f}s")
        else:
            cache_hits += 1
            _log(f"{prefix} dedup-cache hit (skipping ImmuneBuilder call)")

        pdb_filename = indexed_filename(row_idx)
        try:
            _write_pdb(
                antibody,
                pdb_dir,
                pdb_filename,
                mode=mode,
                ib_version=ib_version,
                seed=seed,
                block_version=block_version,
            )
        except Exception as exc:  # noqa: BLE001
            results.append(_record_failure(result, prefix, "save/augment", exc))
            continue

        try:
            _fill_confidence(result, antibody, mode)
        except Exception as exc:  # noqa: BLE001
            # Drop the partially-filled metrics and the now-orphaned PDB.
            _unlink_quietly(pdb_dir / pdb_filename)
            result = RowResult(
                clonotype_key=key,
                clonotype_label=label,
                warnings=list(sanitized.warnings),
            )
            results.append(_record_failure(result, prefix, "confidence", exc))
            continue

        result.pdb_filename = pdb_filename
        if periodic_log or row_idx == n - 1:
            _log(
                f"{prefix} OK mean={result.mean_error}Å cdrh3={result.cdrh3}Å "
                f"cdrh3Length={result.cdrh3_len}"
            )
        results.append(result)

    # Preserve the input key-column header in the outputs. The orchestrator
    # names it after the axis spec (e.g. `pl7.app/vdj/clonotypeKey`), and the
    # downstream xsv import of the concatenated batches expects it back.
    manifest_fields = [key_col if f == KEY_COLUMN_PLACEHOLDER else f for f in MANIFEST_FIELDS]
    _write_tsv(
        manifest_tsv,
        manifest_fields,
        (
            {key_col: r.clonotype_key, "pdb_filename": r.pdb_filename}
            for r in results
            if r.pdb_filename
        ),
    )
    _write_tsv(
        confidence_tsv,
        build_confidence_fields(key_col),
        (r.to_tsv_row(key_col) for r in results),
    )

    summary = _build_summary(results, metric, threshold)
    if summary_json is not None:
        summary_json.parent.mkdir(parents=True, exist_ok=True)
        with open(summary_json, "w", encoding="utf-8") as f:
            json.dump(summary, f)

    elapsed = time.time() - run_t0 if rows else 0.0
    _log_run_summary(summary, cache_hits, elapsed)
582
+
583
+
584
def warmup(mode: str, sentinel: Path | None) -> None:
    """Force the ImmuneBuilder weight download into a known cache location.

    Run as a single pre-step before the parallel batch fan-out. This avoids
    the race where multiple batch containers download the same weight files
    into a shared cache dir simultaneously, producing partial or corrupt files.

    Args:
        mode: Predictor to load ("ABodyBuilder2" or "NanoBodyBuilder2").
        sentinel: Optional file written only after the predictor loaded. It
            records the mode and ImmuneBuilder version.
    """
    _log(f"warmup mode={mode} loading predictor (this may download weights)")
    _load_predictor(mode)
    if sentinel is not None:
        # Create the parent directory, as process_batch does for its outputs,
        # so a missing directory does not fail the step after the (possibly
        # long) weight download has already succeeded.
        sentinel.parent.mkdir(parents=True, exist_ok=True)
        sentinel.write_text(
            f"mode={mode}\nimmunebuilder_version={get_immunebuilder_version()}\n",
            encoding="utf-8",
        )
    _log("warmup OK")
598
+
599
+
600
def main() -> None:
    """Parse the command line, then run either the weight warm-up or one batch.

    Outside --warmup mode, --input, --output-dir, --manifest and --confidence
    are all required. If any is missing, `parser.error` reports it and exits.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--mode", choices=["ABodyBuilder2", "NanoBodyBuilder2"], required=True)
    add("--warmup", action="store_true",
        help="Pre-download model weights and exit. --input/--output-dir/--manifest/--confidence are not used.")
    add("--sentinel", default=None,
        help="In --warmup mode, path to a sentinel file written on success.")
    add("--input", help="Batch TSV with clonotypeKey, heavyChain[, lightChain]")
    add("--output-dir", help="Directory for per-clonotype PDB files")
    add("--manifest", help="Path to manifest.tsv")
    add("--confidence", help="Path to confidence.tsv")
    add("--summary", default=None, help="Path to summary.json (aggregate stats for the model)")
    add("--seed", type=int, default=42)
    add("--metric", choices=["cdrh3Mean", "overallMean"], default="cdrh3Mean")
    add("--threshold", type=float, default=2.5,
        help="Confidence threshold (Å) used to derive confidentCount in summary.json")
    args = parser.parse_args()

    if args.warmup:
        sentinel = Path(args.sentinel) if args.sentinel else None
        warmup(args.mode, sentinel)
        return

    batch_only = {
        "--input": args.input,
        "--output-dir": args.output_dir,
        "--manifest": args.manifest,
        "--confidence": args.confidence,
    }
    missing = [flag for flag, value in batch_only.items() if not value]
    if missing:
        parser.error(
            "the following arguments are required when not in --warmup mode: "
            + ", ".join(missing)
        )

    process_batch(
        input_tsv=Path(args.input),
        pdb_dir=Path(args.output_dir),
        manifest_tsv=Path(args.manifest),
        confidence_tsv=Path(args.confidence),
        summary_json=Path(args.summary) if args.summary else None,
        mode=args.mode,
        seed=args.seed,
        metric=args.metric,
        threshold=args.threshold,
    )
644
+
645
+
646
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()