variantfold 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- variantfold/__init__.py +41 -0
- variantfold/cli.py +161 -0
- variantfold/config.py +104 -0
- variantfold/graphs.py +224 -0
- variantfold/model.py +316 -0
- variantfold/pipeline.py +254 -0
- variantfold/structure.py +256 -0
- variantfold/variants.py +283 -0
- variantfold-0.1.0.dist-info/METADATA +151 -0
- variantfold-0.1.0.dist-info/RECORD +13 -0
- variantfold-0.1.0.dist-info/WHEEL +5 -0
- variantfold-0.1.0.dist-info/entry_points.txt +2 -0
- variantfold-0.1.0.dist-info/top_level.txt +1 -0
variantfold/__init__.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
VariantFold — Classify variants of uncertain significance using
|
|
3
|
+
AlphaFold-predicted protein structures and graph neural networks.
|
|
4
|
+
|
|
5
|
+
Workflow
|
|
6
|
+
--------
|
|
7
|
+
1. Parse ClinVar variant data
|
|
8
|
+
2. Mutate reference protein sequences
|
|
9
|
+
3. Predict 3-D structures with ColabFold / AlphaFold
|
|
10
|
+
4. Convert PDB structures to residue-level graphs
|
|
11
|
+
5. Train a GCN classifier (benign vs pathogenic)
|
|
12
|
+
6. Classify VUS with the trained model
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
# Package version string; assigned before the submodule imports below run.
__version__ = "0.1.0"

# Re-export the main entry points so they can be imported straight from
# the ``variantfold`` package.
from variantfold.config import VariantFoldConfig
from variantfold.variants import (
    parse_clinvar_variant,
    load_clinvar_table,
    swap_amino_acid,
    generate_mutant_sequences,
)
from variantfold.graphs import pdb_to_graph, load_pdb_directory
from variantfold.model import VariantGCN, train_model, evaluate_model, predict_vus
from variantfold.pipeline import VariantFoldPipeline

# Names exported by ``from variantfold import *``; one entry per name
# imported above.
__all__ = [
    "VariantFoldConfig",
    "parse_clinvar_variant",
    "load_clinvar_table",
    "swap_amino_acid",
    "generate_mutant_sequences",
    "pdb_to_graph",
    "load_pdb_directory",
    "VariantGCN",
    "train_model",
    "evaluate_model",
    "predict_vus",
    "VariantFoldPipeline",
]
|
variantfold/cli.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface for VariantFold.
|
|
3
|
+
|
|
4
|
+
Usage
|
|
5
|
+
-----
|
|
6
|
+
variantfold run --gene VHL --email me@example.com --steps 1,3,4,5
|
|
7
|
+
variantfold predict --model model.pt --pdb-dir ./vus_library/
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import argparse
|
|
13
|
+
import logging
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
from variantfold import __version__
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the ``variantfold`` argument parser.

    The top-level parser offers ``--version`` and two sub-commands; the
    chosen sub-command name is stored in ``args.command`` (``None`` when
    no sub-command was given):

    * ``run``     -- execute selected pipeline steps for one gene.
    * ``predict`` -- standalone inference with an already trained model.

    Returns
    -------
    argparse.ArgumentParser
    """
    root = argparse.ArgumentParser(
        prog="variantfold",
        description="Classify VUS using AlphaFold structures and GNNs.",
    )
    root.add_argument(
        "--version", action="version", version=f"%(prog)s {__version__}",
    )
    commands = root.add_subparsers(dest="command")

    # ---- run ----------------------------------------------------------------
    run_cmd = commands.add_parser("run", help="Run the full or partial pipeline.")
    run_cmd.add_argument("--gene", required=True, help="HGNC gene symbol.")
    run_cmd.add_argument("--email", required=True, help="Email for NCBI Entrez.")
    run_cmd.add_argument(
        "--work-dir",
        default=None,
        help="Working directory (default: ./variantfold_<gene>).",
    )
    steps_help = (
        "Comma-separated step numbers to run. "
        "1=parse, 2=predict structures, 3=collect models, "
        "4=train, 5=classify VUS. "
        "Default: 1,3,4,5 (skip structure prediction)."
    )
    run_cmd.add_argument("--steps", default="1,3,4,5", help=steps_help)
    run_cmd.add_argument(
        "--accession",
        default=None,
        help="Protein accession number (auto-fetched if omitted).",
    )
    # Plain numeric hyper-parameters: (flag, converter, default), registered
    # in this order so the generated help keeps its layout.
    numeric_options = (
        ("--epochs", int, 200),
        ("--lr", float, 0.01),
        ("--batch-size", int, 32),
        ("--distance-threshold", float, 6.5),
        ("--hidden-dim", int, 64),
        ("--num-layers", int, 3),
    )
    for flag, converter, fallback in numeric_options:
        run_cmd.add_argument(flag, type=converter, default=fallback)
    run_cmd.add_argument(
        "--legacy-features",
        action="store_true",
        help="Use pLDDT-only features (1-dim) instead of rich features (24-dim).",
    )
    run_cmd.add_argument("--seed", type=int, default=42)
    run_cmd.add_argument("-v", "--verbose", action="store_true")

    # ---- predict (standalone inference) -------------------------------------
    infer_cmd = commands.add_parser(
        "predict", help="Classify PDBs with a trained model.",
    )
    infer_cmd.add_argument("--model", required=True, help="Path to .pt model file.")
    infer_cmd.add_argument(
        "--pdb-dir", required=True, help="Directory of VUS PDB files.",
    )
    infer_cmd.add_argument("--output", default="vus_predictions.csv")
    infer_cmd.add_argument("--distance-threshold", type=float, default=6.5)
    infer_cmd.add_argument("--batch-size", type=int, default=32)
    infer_cmd.add_argument(
        "--legacy-features",
        action="store_true",
        help="Use pLDDT-only features (must match training mode).",
    )
    infer_cmd.add_argument("-v", "--verbose", action="store_true")

    return root
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def main(argv=None) -> None:
    """Entry point of the ``variantfold`` command-line tool.

    Parameters
    ----------
    argv : list of str or None
        Arguments passed on to ``ArgumentParser.parse_args``.

    When no sub-command is given the help text is printed and the process
    exits with status 1.  Otherwise logging is configured (DEBUG when the
    parsed arguments carry a truthy ``verbose``, INFO if not) and the
    handler for the chosen sub-command is invoked.
    """
    cli = _build_parser()
    options = cli.parse_args(argv)

    if options.command is None:
        cli.print_help()
        sys.exit(1)

    verbose = getattr(options, "verbose", False)
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    # Sub-command name -> handler; a name without a handler is a no-op.
    handlers = {"run": _cmd_run, "predict": _cmd_predict}
    handler = handlers.get(options.command)
    if handler is not None:
        handler(options)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Step numbers understood by ``variantfold run --steps``.
_PIPELINE_STEPS = frozenset(range(1, 6))


def _parse_steps(spec: str) -> set:
    """Turn a ``--steps`` value such as ``"1, 3,4,"`` into a set of ints.

    Empty tokens (stray or trailing commas) are ignored.  Numbers outside
    1-5 are dropped with a logged warning.

    Raises
    ------
    SystemExit
        With a readable message when a token is not an integer or when no
        valid step remains.
    """
    requested = set()
    for raw in spec.split(","):
        token = raw.strip()
        if not token:
            continue
        try:
            requested.add(int(token))
        except ValueError:
            raise SystemExit(
                f"variantfold run: invalid --steps value {spec!r}: "
                f"{token!r} is not an integer step number"
            ) from None

    unknown = sorted(requested - _PIPELINE_STEPS)
    if unknown:
        logging.getLogger(__name__).warning(
            "Ignoring unknown pipeline step(s): %s",
            ", ".join(str(number) for number in unknown),
        )

    selected = requested & _PIPELINE_STEPS
    if not selected:
        raise SystemExit(
            f"variantfold run: --steps value {spec!r} selects no pipeline step "
            "(valid steps are 1-5)"
        )
    return selected


def _cmd_run(args) -> None:
    """Handle ``variantfold run``: execute the requested pipeline steps.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``run`` sub-command.

    The selected steps always run in ascending order, whatever order they
    were listed in.  Step 4 prints the test accuracy and confusion matrix
    it returns; step 5 prints the table of VUS predictions.
    """
    # Validate --steps first so a typo fails immediately, before the
    # package imports below and before the pipeline object is built.
    steps = _parse_steps(args.steps)

    from variantfold.config import VariantFoldConfig
    from variantfold.pipeline import VariantFoldPipeline

    cfg = VariantFoldConfig(
        gene_symbol=args.gene,
        entrez_email=args.email,
        work_dir=args.work_dir,
        accession_number=args.accession,
        epochs=args.epochs,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        distance_threshold=args.distance_threshold,
        gcn_hidden_dim=args.hidden_dim,
        gcn_num_layers=args.num_layers,
        # The CLI flag is the negation of the config switch.
        use_residue_features=not args.legacy_features,
        random_seed=args.seed,
    )

    pipe = VariantFoldPipeline(cfg)

    if 1 in steps:
        pipe.step1_parse_variants()
    if 2 in steps:
        pipe.step2_predict_structures()
    if 3 in steps:
        pipe.step3_collect_models()
    if 4 in steps:
        metrics = pipe.step4_train()
        print(f"\nTest accuracy: {metrics['accuracy']:.4f}")
        print(f"Confusion matrix:\n{metrics['confusion_matrix']}")
    if 5 in steps:
        df = pipe.step5_classify_vus()
        print(f"\nVUS predictions:\n{df.to_string(index=False)}")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _cmd_predict(args) -> None:
    """Handle ``variantfold predict``: classify a directory of PDB files.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``predict`` sub-command.

    Loads the model from ``args.model``, converts the PDB files under
    ``args.pdb_dir`` to unlabelled graphs, writes the predictions to
    ``args.output`` as CSV and prints them.  Exits with status 1 when no
    graph could be loaded.
    """
    import pandas as pd

    from variantfold.graphs import load_pdb_directory
    from variantfold.model import load_model, predict_vus

    classifier = load_model(args.model)

    structures = load_pdb_directory(
        args.pdb_dir,
        label=None,
        distance_threshold=args.distance_threshold,
        # Rich features unless the legacy flag was given.
        use_residue_features=not args.legacy_features,
    )
    if not structures:
        print(f"No PDB files found in {args.pdb_dir}", file=sys.stderr)
        sys.exit(1)

    predictions = predict_vus(classifier, structures, batch_size=args.batch_size)
    table = pd.DataFrame(predictions)
    table.to_csv(args.output, index=False)
    print(f"Predictions saved to {args.output}")
    print(table.to_string(index=False))
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# Run the CLI when this file is executed as a script rather than imported.
if __name__ == "__main__":
    main()
|
variantfold/config.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Central configuration for a VariantFold run.
|
|
3
|
+
|
|
4
|
+
All paths, thresholds, and hyper-parameters live here so that nothing
|
|
5
|
+
is hard-coded to Google Drive or Colab.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
import os
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class VariantFoldConfig:
    """All settings of a single end-to-end VariantFold analysis.

    Only ``gene_symbol`` and ``entrez_email`` are mandatory.  The six
    ``*_dir`` / ``*_library`` attributes are not constructor arguments;
    they are derived from ``work_dir`` in ``__post_init__``.
    """

    # ---- Identity -----------------------------------------------------------
    gene_symbol: str
    entrez_email: str  # required by NCBI Entrez

    # ---- Paths (default to ./variantfold_<gene>/) ---------------------------
    work_dir: Optional[str] = None  # root working directory

    # ---- Variant parsing ----------------------------------------------------
    clinvar_benign_file: str = "clinvar_result_bng.txt"
    clinvar_pathogenic_file: str = "clinvar_result_ptg.txt"
    clinvar_vus_file: str = "clinvar_result_vus.txt"

    # ---- Protein sequence ---------------------------------------------------
    accession_number: Optional[str] = None  # auto-fetched if None

    # ---- Structure prediction (ColabFold) -----------------------------------
    num_models: int = 5
    num_relax: int = 0
    msa_mode: str = "mmseqs2_uniref_env"
    pair_mode: str = "unpaired_paired"
    model_type: str = "auto"
    num_recycles: Optional[int] = None  # None = auto
    recycle_early_stop_tolerance: Optional[float] = None
    num_seeds: int = 1
    use_dropout: bool = False
    use_templates: bool = False

    # ---- Graph construction -------------------------------------------------
    distance_threshold: float = 6.5  # Å, residue contact threshold
    use_residue_features: bool = True  # one-hot AA + coords + pLDDT

    # ---- GCN training -------------------------------------------------------
    gcn_hidden_dim: int = 64
    gcn_num_layers: int = 3
    gcn_dropout: float = 0.5
    learning_rate: float = 0.01
    epochs: int = 200
    batch_size: int = 32
    train_fraction: float = 0.8
    random_seed: int = 42

    # ---- Derived paths (set in __post_init__) -------------------------------
    benign_dir: str = field(init=False, repr=False)
    pathogenic_dir: str = field(init=False, repr=False)
    vus_dir: str = field(init=False, repr=False)
    benign_library: str = field(init=False, repr=False)
    pathogenic_library: str = field(init=False, repr=False)
    vus_library: str = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Fill in the default ``work_dir`` and compute the derived paths."""
        if self.work_dir is None:
            self.work_dir = os.path.join(".", f"variantfold_{self.gene_symbol}")

        base = Path(self.work_dir)
        # (attribute prefix, class folder, library sub-folder)
        layout = (
            ("benign", "Benign", "library_bng"),
            ("pathogenic", "Pathogenic", "library_ptg"),
            ("vus", "VUS", "library_vus"),
        )
        for prefix, folder, library in layout:
            class_dir = base / folder
            setattr(self, f"{prefix}_dir", str(class_dir))
            setattr(self, f"{prefix}_library", str(class_dir / library))

    # ---- Helpers ------------------------------------------------------------
    def ensure_directories(self) -> None:
        """Create the working directory and every derived sub-directory.

        Directories that already exist are left as they are.
        """
        targets = (
            self.work_dir,
            self.benign_dir,
            self.pathogenic_dir,
            self.vus_dir,
            self.benign_library,
            self.pathogenic_library,
            self.vus_library,
        )
        for target in targets:
            os.makedirs(target, exist_ok=True)
            logger.debug("Ensured directory: %s", target)

    @property
    def num_node_features(self) -> int:
        """Width of the per-node feature vector for the chosen feature mode."""
        # 24 = 20 one-hot AA + 3 coords + 1 pLDDT; 1 = pLDDT only (legacy mode)
        return 24 if self.use_residue_features else 1
|
variantfold/graphs.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Convert PDB structure files to PyTorch Geometric residue-level graphs.
|
|
3
|
+
|
|
4
|
+
Fixes from audit
|
|
5
|
+
-----------------
|
|
6
|
+
- BUG-7 : Labels are set correctly per-graph (no more hardcoded y=1).
|
|
7
|
+
- BUG-8 : VUS samples carry label=None rather than fake class 2.
|
|
8
|
+
- DESIGN-3: Rich node features (one-hot AA + 3-D coords + pLDDT).
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import logging
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from typing import List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
import torch
|
|
19
|
+
from torch_geometric.data import Data
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# ---- Amino-acid one-hot encoding -------------------------------------------
|
|
24
|
+
|
|
25
|
+
_AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"
# One-letter code -> column of the one-hot vector.
_AA_INDEX = dict(zip(_AA_ORDER, range(len(_AA_ORDER))))

# Mapping 3-letter codes to 1-letter for PDB ATOM records
_AA3_TO_1 = {
    "ALA": "A", "CYS": "C", "ASP": "D", "GLU": "E",
    "PHE": "F", "GLY": "G", "HIS": "H", "ILE": "I",
    "LYS": "K", "LEU": "L", "MET": "M", "ASN": "N",
    "PRO": "P", "GLN": "Q", "ARG": "R", "SER": "S",
    "THR": "T", "VAL": "V", "TRP": "W", "TYR": "Y",
}


def _one_hot_aa(resname_3: str) -> np.ndarray:
    """Encode a three-letter residue name as a length-20 float32 one-hot.

    Surrounding whitespace and letter case are ignored.  A name that is
    not one of the 20 entries of ``_AA3_TO_1`` yields the all-zero vector.
    """
    encoding = np.zeros(20, dtype=np.float32)
    letter = _AA3_TO_1.get(resname_3.strip().upper())
    column = None if letter is None else _AA_INDEX.get(letter)
    if column is not None:
        encoding[column] = 1.0
    return encoding
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# ---- Distance matrix -------------------------------------------------------
|
|
47
|
+
|
|
48
|
+
def _distance_matrix(coords: np.ndarray) -> np.ndarray:
    """Return the matrix of Euclidean distances between all rows of *coords*.

    Entry ``[i, j]`` is the distance between row ``i`` and row ``j``.
    """
    # Broadcasting an (n, 1, d) view against a (1, n, d) view produces all
    # pairwise coordinate differences at once.
    deltas = coords[:, np.newaxis] - coords[np.newaxis, :]
    squared = (deltas ** 2).sum(axis=-1)
    return np.sqrt(squared)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# ---- Core graph builder -----------------------------------------------------
|
|
55
|
+
|
|
56
|
+
def pdb_to_graph(
    pdb_path: str,
    distance_threshold: float = 6.5,
    use_residue_features: bool = True,
    label: Optional[int] = None,
) -> Data:
    """Convert a PDB file to a PyTorch Geometric ``Data`` object.

    Nodes correspond to residues, ordered by residue number.  Each residue
    is placed at the mean position of its ATOM records, and an edge (in
    both directions) joins every pair of distinct residues whose mean
    positions are closer than *distance_threshold* Å.

    Parameters
    ----------
    pdb_path : str
        Path to a ``.pdb`` file.
    distance_threshold : float
        Contact distance cutoff in Ångströms.
    use_residue_features : bool
        If True, each node carries 24 features: the one-hot amino acid,
        the zero-centred x/y/z position, and the mean B-factor divided by
        100.  If False, each node carries only the raw mean B-factor
        (legacy 1-feature mode).
    label : int or None
        Graph-level class label (0 = benign, 1 = pathogenic, None = VUS).
        With ``None`` the returned object has ``y=None``.

    Returns
    -------
    torch_geometric.data.Data
        Also carries ``pdb_path`` and ``num_residues`` attributes.

    Raises
    ------
    ValueError
        If the file contains no ATOM records.
    """
    from biopandas.pdb import PandasPdb

    atom_df = PandasPdb().read_pdb(str(pdb_path)).df["ATOM"]
    if atom_df.empty:
        # Without this guard the rich-feature path fails later with an
        # opaque shape error and the legacy path returns a node-less graph.
        raise ValueError(f"No ATOM records found in {pdb_path}")

    # Aggregate per residue; sort=True orders the groups by residue number.
    # NOTE(review): residues are keyed by residue_number alone, so records
    # that share a number (other chains, insertion codes) are merged.
    per_residue = atom_df.groupby("residue_number", as_index=False, sort=True)
    residue_df = per_residue[["x_coord", "y_coord", "z_coord", "b_factor"]].mean()
    # Residue name of the first record of each group, for one-hot encoding.
    resnames = per_residue["residue_name"].first()["residue_name"].values

    coords = residue_df[["x_coord", "y_coord", "z_coord"]].values
    # The B-factor column is treated as pLDDT throughout.
    # NOTE(review): presumably valid for AlphaFold/ColabFold output only.
    plddt = residue_df["b_factor"].values
    n_residues = len(coords)

    # --- Build node features -------------------------------------------------
    if use_residue_features:
        # 20 one-hot AA + 3 zero-centred coords + 1 scaled pLDDT
        one_hot = np.array([_one_hot_aa(r) for r in resnames], dtype=np.float32)
        normed_coords = (coords - coords.mean(axis=0)).astype(np.float32)
        plddt_norm = (plddt / 100.0).reshape(-1, 1).astype(np.float32)
        features = np.concatenate([one_hot, normed_coords, plddt_norm], axis=1)
    else:
        features = plddt.reshape(-1, 1).astype(np.float32)

    x = torch.from_numpy(features)

    # --- Build adjacency (contact map) ---------------------------------------
    dist_mat = _distance_matrix(coords)
    adj = dist_mat < distance_threshold
    np.fill_diagonal(adj, False)  # no self-loops
    src, dst = np.nonzero(adj)
    edge_index = torch.tensor(np.stack([src, dst]), dtype=torch.long)

    # --- Construct Data object -----------------------------------------------
    y = torch.tensor([label], dtype=torch.long) if label is not None else None
    data = Data(x=x, edge_index=edge_index, y=y)
    data.pdb_path = str(pdb_path)
    data.num_residues = n_residues
    return data
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# ---- Directory loader -------------------------------------------------------
|
|
143
|
+
|
|
144
|
+
def load_pdb_directory(
    directory: str,
    label: Optional[int] = None,
    distance_threshold: float = 6.5,
    use_residue_features: bool = True,
    filename_pattern: str = "*.pdb",
) -> List[Data]:
    """Convert every matching PDB file of a directory into a graph.

    Files are processed in sorted path order.  A file whose conversion
    raises is reported with a warning and left out of the result.

    Parameters
    ----------
    directory : str
        Folder that holds the ``.pdb`` files.
    label : int or None
        Class label given to every graph built from this folder.
    distance_threshold : float
        Contact cutoff (Å) forwarded to ``pdb_to_graph``.
    use_residue_features : bool
        Rich (24-dim) or minimal (1-dim) node features.
    filename_pattern : str
        Glob pattern selecting the files inside *directory*.

    Returns
    -------
    list of Data
        Empty when nothing matches *filename_pattern*.

    Raises
    ------
    FileNotFoundError
        If *directory* is not an existing directory.
    """
    folder = Path(directory)
    if not folder.is_dir():
        raise FileNotFoundError(f"Directory not found: {directory}")

    matches = sorted(folder.glob(filename_pattern))
    if len(matches) == 0:
        logger.warning("No PDB files matching %r in %s", filename_pattern, directory)
        return []

    converted: List[Data] = []
    for path in matches:
        try:
            graph = pdb_to_graph(
                str(path),
                distance_threshold=distance_threshold,
                use_residue_features=use_residue_features,
                label=label,
            )
        except Exception as err:
            # Any failure on a single file is logged and that file is skipped.
            logger.warning("Failed to convert %s: %s", path.name, err)
        else:
            converted.append(graph)

    logger.info(
        "Loaded %d graphs from %s (label=%s)", len(converted), directory, label,
    )
    return converted
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def collect_best_models(
    source_dir: str,
    dest_dir: str,
    pattern: str = "*model_1_seed_000.pdb",
) -> List[Path]:
    """Copy every file under *source_dir* matching *pattern* into *dest_dir*.

    This replaces the notebook's ``search_and_move_files`` function.
    Files are *copied* (not moved) to avoid destructive side-effects.

    *source_dir* is searched recursively.  *dest_dir* is created if needed
    and may itself live inside *source_dir*: files already located under
    *dest_dir* are never treated as sources, so repeated calls are safe.
    Copies are flat (only the file name is kept); when two sources share a
    name the later one overwrites the earlier and a warning is logged.

    Parameters
    ----------
    source_dir : str
        Root folder to search.
    dest_dir : str
        Folder receiving the copies.
    pattern : str
        Glob pattern matched against file names under *source_dir*.

    Returns
    -------
    list of Path
        Destination path of every copy made, in sorted source order.
    """
    import shutil

    src = Path(source_dir)
    dst = Path(dest_dir)
    dst.mkdir(parents=True, exist_ok=True)
    dst_resolved = dst.resolve()

    # Take a snapshot before copying: the walk must not pick up files this
    # very loop creates when dest_dir is nested inside source_dir.
    candidates = sorted(p for p in src.rglob(pattern) if p.is_file())

    copied: List[Path] = []
    seen_names = set()
    for pdb in candidates:
        if dst_resolved in pdb.resolve().parents:
            # Already collected; copying it onto itself would raise
            # shutil.SameFileError.
            logger.debug("Skipping %s (already inside %s)", pdb, dest_dir)
            continue

        dest_path = dst / pdb.name
        if pdb.name in seen_names:
            logger.warning(
                "Duplicate file name %s: %s overwrites an earlier copy",
                pdb.name, pdb,
            )
        seen_names.add(pdb.name)

        shutil.copy2(pdb, dest_path)
        copied.append(dest_path)
        logger.debug("Copied %s → %s", pdb, dest_path)

    logger.info("Collected %d PDB models into %s", len(copied), dest_dir)
    return copied
|