variantfold 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,41 @@
1
+ """
2
+ VariantFold — Classify variants of uncertain significance using
3
+ AlphaFold-predicted protein structures and graph neural networks.
4
+
5
+ Workflow
6
+ --------
7
+ 1. Parse ClinVar variant data
8
+ 2. Mutate reference protein sequences
9
+ 3. Predict 3-D structures with ColabFold / AlphaFold
10
+ 4. Convert PDB structures to residue-level graphs
11
+ 5. Train a GCN classifier (benign vs pathogenic)
12
+ 6. Classify VUS with the trained model
13
+ """
14
+
15
+ __version__ = "0.1.0"
16
+
17
+ from variantfold.config import VariantFoldConfig
18
+ from variantfold.variants import (
19
+ parse_clinvar_variant,
20
+ load_clinvar_table,
21
+ swap_amino_acid,
22
+ generate_mutant_sequences,
23
+ )
24
+ from variantfold.graphs import pdb_to_graph, load_pdb_directory
25
+ from variantfold.model import VariantGCN, train_model, evaluate_model, predict_vus
26
+ from variantfold.pipeline import VariantFoldPipeline
27
+
28
# Public names re-exported at package level, grouped by the submodule
# each one is imported from above.
__all__ = [
    # variantfold.config
    "VariantFoldConfig",
    # variantfold.variants
    "parse_clinvar_variant", "load_clinvar_table",
    "swap_amino_acid", "generate_mutant_sequences",
    # variantfold.graphs
    "pdb_to_graph", "load_pdb_directory",
    # variantfold.model
    "VariantGCN", "train_model", "evaluate_model", "predict_vus",
    # variantfold.pipeline
    "VariantFoldPipeline",
]
variantfold/cli.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ Command-line interface for VariantFold.
3
+
4
+ Usage
5
+ -----
6
+ variantfold run --gene VHL --email me@example.com --steps 1,3,4,5
7
+ variantfold predict --model model.pt --pdb-dir ./vus_library/
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import logging
14
+ import sys
15
+
16
+ from variantfold import __version__
17
+
18
+
19
def _build_parser() -> argparse.ArgumentParser:
    """Create the ``variantfold`` argument parser.

    The parser exposes a ``--version`` flag and two sub-commands, ``run``
    and ``predict``; the chosen sub-command name is stored in
    ``args.command`` (``None`` when no sub-command is given).
    """
    parser = argparse.ArgumentParser(
        prog="variantfold",
        description="Classify VUS using AlphaFold structures and GNNs.",
    )
    parser.add_argument(
        "--version", action="version", version=f"%(prog)s {__version__}",
    )
    commands = parser.add_subparsers(dest="command")

    # ---- run ----------------------------------------------------------------
    run_cmd = commands.add_parser("run", help="Run the full or partial pipeline.")
    for flag, description in (
        ("--gene", "HGNC gene symbol."),
        ("--email", "Email for NCBI Entrez."),
    ):
        run_cmd.add_argument(flag, required=True, help=description)
    run_cmd.add_argument(
        "--work-dir",
        default=None,
        help="Working directory (default: ./variantfold_<gene>).",
    )
    steps_help = (
        "Comma-separated step numbers to run. "
        "1=parse, 2=predict structures, 3=collect models, "
        "4=train, 5=classify VUS. "
        "Default: 1,3,4,5 (skip structure prediction)."
    )
    run_cmd.add_argument("--steps", default="1,3,4,5", help=steps_help)
    run_cmd.add_argument(
        "--accession",
        default=None,
        help="Protein accession number (auto-fetched if omitted).",
    )
    # Plain typed hyper-parameters, registered in their original order so
    # the generated help text is unchanged.
    for flag, kind, default in (
        ("--epochs", int, 200),
        ("--lr", float, 0.01),
        ("--batch-size", int, 32),
        ("--distance-threshold", float, 6.5),
        ("--hidden-dim", int, 64),
        ("--num-layers", int, 3),
    ):
        run_cmd.add_argument(flag, type=kind, default=default)
    run_cmd.add_argument(
        "--legacy-features",
        action="store_true",
        help="Use pLDDT-only features (1-dim) instead of rich features (24-dim).",
    )
    run_cmd.add_argument("--seed", type=int, default=42)
    run_cmd.add_argument("-v", "--verbose", action="store_true")

    # ---- predict (standalone inference) -------------------------------------
    predict_cmd = commands.add_parser(
        "predict", help="Classify PDBs with a trained model.",
    )
    predict_cmd.add_argument(
        "--model", required=True, help="Path to .pt model file.",
    )
    predict_cmd.add_argument(
        "--pdb-dir", required=True, help="Directory of VUS PDB files.",
    )
    predict_cmd.add_argument("--output", default="vus_predictions.csv")
    predict_cmd.add_argument("--distance-threshold", type=float, default=6.5)
    predict_cmd.add_argument("--batch-size", type=int, default=32)
    predict_cmd.add_argument(
        "--legacy-features",
        action="store_true",
        help="Use pLDDT-only features (must match training mode).",
    )
    predict_cmd.add_argument("-v", "--verbose", action="store_true")

    return parser
75
+
76
+
77
def main(argv=None) -> None:
    """Entry point of the ``variantfold`` command-line interface.

    Parses *argv* (``None`` lets argparse read ``sys.argv``), configures
    logging, and dispatches to the handler of the selected sub-command.
    Prints the help text and exits with status 1 when no sub-command is
    given.
    """
    cli = _build_parser()
    options = cli.parse_args(argv)

    # Without a sub-command there is nothing to do: show usage and fail.
    if options.command is None:
        cli.print_help()
        sys.exit(1)

    verbose = getattr(options, "verbose", False)
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    if options.command == "run":
        _cmd_run(options)
        return
    if options.command == "predict":
        _cmd_predict(options)
95
+
96
+
97
# Step numbers understood by ``variantfold run --steps``.
_PIPELINE_STEPS = (1, 2, 3, 4, 5)


def _parse_steps(raw: str) -> set:
    """Turn a ``--steps`` value such as ``"1,3,4,5"`` into a set of ints.

    Blank tokens (e.g. from a trailing comma) are ignored.

    Raises
    ------
    ValueError
        If a token is not an integer, is not one of the known pipeline
        steps, or if no step is selected at all.
    """
    selected = set()
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue  # tolerate "1,3," or "1, ,3"
        try:
            step = int(token)
        except ValueError:
            raise ValueError(
                f"invalid step {token!r}: expected an integer"
            ) from None
        if step not in _PIPELINE_STEPS:
            raise ValueError(
                f"unknown step {step}: valid steps are "
                + ", ".join(str(s) for s in _PIPELINE_STEPS)
            )
        selected.add(step)
    if not selected:
        raise ValueError("no pipeline steps selected")
    return selected


def _cmd_run(args) -> None:
    """Handle ``variantfold run``: execute the requested pipeline steps.

    The ``--steps`` value is validated *before* the configuration and
    pipeline objects are created, so a typo fails immediately with a
    readable message instead of a traceback (or, for an unknown step
    number, instead of being silently ignored).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments of the ``run`` sub-command.

    Raises
    ------
    SystemExit
        If ``args.steps`` is malformed.
    """
    try:
        steps = _parse_steps(args.steps)
    except ValueError as exc:
        # A string argument makes sys.exit print to stderr and exit with 1.
        sys.exit(f"error: {exc}")

    from variantfold.config import VariantFoldConfig
    from variantfold.pipeline import VariantFoldPipeline

    cfg = VariantFoldConfig(
        gene_symbol=args.gene,
        entrez_email=args.email,
        work_dir=args.work_dir,
        accession_number=args.accession,
        epochs=args.epochs,
        learning_rate=args.lr,
        batch_size=args.batch_size,
        distance_threshold=args.distance_threshold,
        gcn_hidden_dim=args.hidden_dim,
        gcn_num_layers=args.num_layers,
        use_residue_features=not args.legacy_features,
        random_seed=args.seed,
    )

    pipe = VariantFoldPipeline(cfg)

    # Steps always run in ascending order, whatever order they were given in.
    if 1 in steps:
        pipe.step1_parse_variants()
    if 2 in steps:
        pipe.step2_predict_structures()
    if 3 in steps:
        pipe.step3_collect_models()
    if 4 in steps:
        metrics = pipe.step4_train()
        print(f"\nTest accuracy: {metrics['accuracy']:.4f}")
        print(f"Confusion matrix:\n{metrics['confusion_matrix']}")
    if 5 in steps:
        df = pipe.step5_classify_vus()
        print(f"\nVUS predictions:\n{df.to_string(index=False)}")
132
+
133
+
134
def _cmd_predict(args) -> None:
    """Handle ``variantfold predict``: classify a directory of PDB files.

    Loads the model from ``args.model``, converts the PDB files in
    ``args.pdb_dir`` to unlabelled graphs, writes the predictions to
    ``args.output`` as CSV and prints them.  Exits with status 1 when no
    graph could be loaded.
    """
    import pandas as pd

    from variantfold.graphs import load_pdb_directory
    from variantfold.model import load_model, predict_vus

    classifier = load_model(args.model)

    vus_graphs = load_pdb_directory(
        args.pdb_dir,
        label=None,
        distance_threshold=args.distance_threshold,
        # Rich features unless the legacy (pLDDT-only) mode was requested.
        use_residue_features=not args.legacy_features,
    )
    if not vus_graphs:
        print(f"No PDB files found in {args.pdb_dir}", file=sys.stderr)
        sys.exit(1)

    predictions = predict_vus(classifier, vus_graphs, batch_size=args.batch_size)
    table = pd.DataFrame(predictions)
    table.to_csv(args.output, index=False)
    print(f"Predictions saved to {args.output}")
    print(table.to_string(index=False))
158
+
159
+
160
# Run the CLI when this module is executed directly as a script.
if __name__ == "__main__":
    main()
variantfold/config.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Central configuration for a VariantFold run.
3
+
4
+ All paths, thresholds, and hyper-parameters live here so that nothing
5
+ is hard-coded to Google Drive or Colab.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import logging
11
+ import os
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class VariantFoldConfig:
    """All settings of a single VariantFold analysis.

    Only ``gene_symbol`` and ``entrez_email`` are mandatory.  The six
    ``*_dir`` / ``*_library`` paths are not constructor arguments; they
    are derived from ``work_dir`` in ``__post_init__``.
    """

    # ---- Identity -----------------------------------------------------------
    gene_symbol: str
    entrez_email: str  # required by NCBI Entrez

    # ---- Paths (default to ./variantfold_<gene>/) ---------------------------
    work_dir: Optional[str] = None  # root working directory

    # ---- Variant parsing ----------------------------------------------------
    clinvar_benign_file: str = "clinvar_result_bng.txt"
    clinvar_pathogenic_file: str = "clinvar_result_ptg.txt"
    clinvar_vus_file: str = "clinvar_result_vus.txt"

    # ---- Protein sequence ---------------------------------------------------
    accession_number: Optional[str] = None  # auto-fetched if None

    # ---- Structure prediction (ColabFold) -----------------------------------
    num_models: int = 5
    num_relax: int = 0
    msa_mode: str = "mmseqs2_uniref_env"
    pair_mode: str = "unpaired_paired"
    model_type: str = "auto"
    num_recycles: Optional[int] = None  # None = auto
    recycle_early_stop_tolerance: Optional[float] = None
    num_seeds: int = 1
    use_dropout: bool = False
    use_templates: bool = False

    # ---- Graph construction -------------------------------------------------
    distance_threshold: float = 6.5  # Å, residue contact threshold
    use_residue_features: bool = True  # one-hot AA + coords + pLDDT

    # ---- GCN training -------------------------------------------------------
    gcn_hidden_dim: int = 64
    gcn_num_layers: int = 3
    gcn_dropout: float = 0.5
    learning_rate: float = 0.01
    epochs: int = 200
    batch_size: int = 32
    train_fraction: float = 0.8
    random_seed: int = 42

    # ---- Derived paths (set in __post_init__) -------------------------------
    benign_dir: str = field(init=False, repr=False)
    pathogenic_dir: str = field(init=False, repr=False)
    vus_dir: str = field(init=False, repr=False)
    benign_library: str = field(init=False, repr=False)
    pathogenic_library: str = field(init=False, repr=False)
    vus_library: str = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Fill in the default ``work_dir`` and compute the derived paths."""
        if self.work_dir is None:
            self.work_dir = os.path.join(".", f"variantfold_{self.gene_symbol}")

        base = Path(self.work_dir)
        # (attribute prefix, class folder, library sub-folder)
        layout = (
            ("benign", "Benign", "library_bng"),
            ("pathogenic", "Pathogenic", "library_ptg"),
            ("vus", "VUS", "library_vus"),
        )
        for prefix, folder, library in layout:
            class_dir = base / folder
            setattr(self, f"{prefix}_dir", str(class_dir))
            setattr(self, f"{prefix}_library", str(class_dir / library))

    # ---- Helpers ------------------------------------------------------------
    def ensure_directories(self) -> None:
        """Create the working directory tree; existing directories are kept."""
        required = (
            self.work_dir,
            self.benign_dir,
            self.pathogenic_dir,
            self.vus_dir,
            self.benign_library,
            self.pathogenic_library,
            self.vus_library,
        )
        for directory in required:
            os.makedirs(directory, exist_ok=True)
            logger.debug("Ensured directory: %s", directory)

    @property
    def num_node_features(self) -> int:
        """Width of the per-node feature vector.

        24 with residue features (20 one-hot AA + 3 coords + 1 pLDDT),
        otherwise 1 (pLDDT only, legacy mode).
        """
        return 24 if self.use_residue_features else 1
variantfold/graphs.py ADDED
@@ -0,0 +1,224 @@
1
+ """
2
+ Convert PDB structure files to PyTorch Geometric residue-level graphs.
3
+
4
+ Fixes from audit
5
+ -----------------
6
+ - BUG-7 : Labels are set correctly per-graph (no more hardcoded y=1).
7
+ - BUG-8 : VUS samples carry label=None rather than fake class 2.
8
+ - DESIGN-3: Rich node features (one-hot AA + 3-D coords + pLDDT).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import List, Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ from torch_geometric.data import Data
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ---- Amino-acid one-hot encoding -------------------------------------------
24
+
25
# Canonical ordering of the 20 standard amino acids (one-letter codes);
# a residue's position in this string is its one-hot slot.
_AA_ORDER = "ACDEFGHIKLMNPQRSTVWY"
_AA_INDEX = {letter: slot for slot, letter in enumerate(_AA_ORDER)}

# Three-letter residue names (as found in PDB ATOM records) mapped to
# their one-letter codes.
_AA3_TO_1 = {
    "ALA": "A",
    "CYS": "C",
    "ASP": "D",
    "GLU": "E",
    "PHE": "F",
    "GLY": "G",
    "HIS": "H",
    "ILE": "I",
    "LYS": "K",
    "LEU": "L",
    "MET": "M",
    "ASN": "N",
    "PRO": "P",
    "GLN": "Q",
    "ARG": "R",
    "SER": "S",
    "THR": "T",
    "VAL": "V",
    "TRP": "W",
    "TYR": "Y",
}
35
+
36
+
37
def _one_hot_aa(resname_3: str) -> np.ndarray:
    """One-hot encode a three-letter residue name as a float32 vector of length 20.

    The name is stripped and upper-cased before lookup.  Names that are
    not in ``_AA3_TO_1`` yield an all-zero vector.
    """
    encoding = np.zeros(20, dtype=np.float32)
    letter = _AA3_TO_1.get(resname_3.strip().upper())
    # ``.get(None)`` is simply a miss, so unknown residues fall through.
    slot = _AA_INDEX.get(letter)
    if slot is not None:
        encoding[slot] = 1.0
    return encoding
44
+
45
+
46
+ # ---- Distance matrix -------------------------------------------------------
47
+
48
def _distance_matrix(coords: np.ndarray) -> np.ndarray:
    """Return the matrix of Euclidean distances between all row pairs of *coords*.

    Entry ``[i, j]`` is the distance between row ``i`` and row ``j``.
    """
    # Broadcasting (n, 1, d) against (1, n, d) gives every pairwise difference.
    delta = coords[:, np.newaxis] - coords[np.newaxis, :]
    squared = np.square(delta).sum(axis=-1)
    return np.sqrt(squared)
52
+
53
+
54
+ # ---- Core graph builder -----------------------------------------------------
55
+
56
def pdb_to_graph(
    pdb_path: str,
    distance_threshold: float = 6.5,
    use_residue_features: bool = True,
    label: Optional[int] = None,
) -> Data:
    """Convert a PDB file to a PyTorch Geometric ``Data`` object.

    Nodes correspond to residues.  Each residue is placed at the mean
    position of its ATOM records (its centroid), and an edge is added
    between two distinct residues whose centroids are closer than
    *distance_threshold* Å.  Every contact is stored in both directions.

    Parameters
    ----------
    pdb_path : str
        Path to a ``.pdb`` file.
    distance_threshold : float
        Contact distance cutoff in Ångströms.
    use_residue_features : bool
        If True, each node carries 24 features: one-hot amino acid (20),
        zero-centred x/y/z (3) and the mean B-factor divided by 100 (1).
        If False, each node carries only the raw mean B-factor (legacy
        1-feature mode).  The B-factor column is treated as pLDDT, as in
        AlphaFold / ColabFold output.
    label : int or None
        Graph-level class label (0 = benign, 1 = pathogenic, None = VUS).

    Returns
    -------
    torch_geometric.data.Data
        Graph with ``x``, ``edge_index`` and ``y`` (``None`` when *label*
        is None), plus the extra attributes ``pdb_path`` and
        ``num_residues``.

    Raises
    ------
    ValueError
        If the file contains no ATOM records.
    """
    from biopandas.pdb import PandasPdb

    atom_df = PandasPdb().read_pdb(str(pdb_path)).df["ATOM"]
    if atom_df.empty:
        # Fail loudly in both feature modes: without this check the rich
        # mode dies inside np.concatenate and the legacy mode silently
        # returns a graph with zero nodes.
        raise ValueError(f"No ATOM records found in {pdb_path}")

    # Aggregate per residue.  NOTE: residues are keyed by residue_number
    # alone, so chains that share numbering are merged into one node.
    by_residue = atom_df.groupby("residue_number", sort=True)
    residue_df = by_residue[["x_coord", "y_coord", "z_coord", "b_factor"]].mean()
    # Same (sorted) group order as residue_df, so rows stay aligned.
    resnames = by_residue["residue_name"].first().to_numpy()

    coords = residue_df[["x_coord", "y_coord", "z_coord"]].to_numpy()
    plddt = residue_df["b_factor"].to_numpy()
    n_residues = len(coords)

    # --- Build node features -------------------------------------------------
    if use_residue_features:
        # 20 one-hot AA + 3 zero-centred coords + 1 pLDDT
        one_hot = np.array([_one_hot_aa(r) for r in resnames], dtype=np.float32)

        # Zero-centre the coordinates so features do not depend on where
        # the structure sits in space.
        normed_coords = (coords - coords.mean(axis=0)).astype(np.float32)

        # Rescale pLDDT from [0, 100] to [0, 1]
        plddt_norm = (plddt / 100.0).reshape(-1, 1).astype(np.float32)

        features = np.concatenate([one_hot, normed_coords, plddt_norm], axis=1)
    else:
        features = plddt.reshape(-1, 1).astype(np.float32)

    x = torch.from_numpy(features)

    # --- Build adjacency (contact map) ---------------------------------------
    dist_mat = _distance_matrix(coords)
    adj = dist_mat < distance_threshold
    np.fill_diagonal(adj, False)  # no self-loops
    src, dst = np.nonzero(adj)
    edge_index = torch.tensor(
        np.stack([src, dst]), dtype=torch.long,
    )

    # --- Construct Data object -----------------------------------------------
    y = torch.tensor([label], dtype=torch.long) if label is not None else None
    data = Data(x=x, edge_index=edge_index, y=y)
    data.pdb_path = str(pdb_path)
    data.num_residues = n_residues
    return data
140
+
141
+
142
+ # ---- Directory loader -------------------------------------------------------
143
+
144
def load_pdb_directory(
    directory: str,
    label: Optional[int] = None,
    distance_threshold: float = 6.5,
    use_residue_features: bool = True,
    filename_pattern: str = "*.pdb",
) -> List[Data]:
    """Convert every matching PDB file in *directory* to a graph.

    Files are processed in sorted order.  A file that fails to convert is
    logged as a warning and skipped, so one bad structure does not abort
    the whole batch.

    Parameters
    ----------
    directory : str
        Path to a folder containing ``.pdb`` files.
    label : int or None
        Class label to assign to every graph in the directory.
    distance_threshold : float
        Contact cutoff (Å).
    use_residue_features : bool
        Whether to use rich (24-dim) or minimal (1-dim) features.
    filename_pattern : str
        Glob pattern to match PDB files.

    Returns
    -------
    list of Data
        One graph per successfully converted file; empty when nothing
        matches the pattern.

    Raises
    ------
    FileNotFoundError
        If *directory* is not an existing directory.
    """
    folder = Path(directory)
    if not folder.is_dir():
        raise FileNotFoundError(f"Directory not found: {directory}")

    candidates = sorted(folder.glob(filename_pattern))
    if not candidates:
        logger.warning("No PDB files matching %r in %s", filename_pattern, directory)
        return []

    build_options = {
        "distance_threshold": distance_threshold,
        "use_residue_features": use_residue_features,
        "label": label,
    }
    converted: list[Data] = []
    for path in candidates:
        try:
            graph = pdb_to_graph(str(path), **build_options)
        except Exception as exc:
            # Best effort: report the failure and carry on with the rest.
            logger.warning("Failed to convert %s: %s", path.name, exc)
        else:
            converted.append(graph)

    logger.info(
        "Loaded %d graphs from %s (label=%s)", len(converted), directory, label,
    )
    return converted
196
+
197
+
198
def collect_best_models(
    source_dir: str,
    dest_dir: str,
    pattern: str = "*model_1_seed_000.pdb",
) -> List[Path]:
    """Copy every PDB under *source_dir* that matches *pattern* into *dest_dir*.

    The search is recursive.  This replaces the notebook's
    ``search_and_move_files`` function; files are *copied* (not moved) to
    avoid destructive side-effects.

    *dest_dir* may live inside *source_dir*.  The matches are collected
    before anything is copied and a file that already is its own
    destination is skipped, so re-running the collection does not fail
    with ``shutil.SameFileError``.

    Parameters
    ----------
    source_dir : str
        Root directory to search.
    dest_dir : str
        Directory receiving the copies; created if missing.
    pattern : str
        Glob pattern selecting the files to collect.

    Returns
    -------
    list of Path
        Destination paths of the collected files, without duplicates.
    """
    import shutil

    src = Path(source_dir)
    dst = Path(dest_dir)
    dst.mkdir(parents=True, exist_ok=True)

    # Snapshot first: copying while rglob is still walking the tree could
    # make it discover the files we have just written into dest_dir.
    candidates = sorted(p for p in src.rglob(pattern) if p.is_file())

    copied: list[Path] = []
    seen = set()
    for pdb in candidates:
        dest_path = dst / pdb.name
        if pdb.resolve() == dest_path.resolve():
            # Already sitting in the destination (left over from an
            # earlier run) - copying it onto itself would raise.
            logger.debug("Skipping %s: already in %s", pdb, dest_dir)
            continue
        if dest_path in seen:
            logger.warning(
                "Duplicate file name %s: %s overwrites an earlier copy",
                pdb.name, pdb,
            )
        shutil.copy2(pdb, dest_path)
        if dest_path not in seen:
            seen.add(dest_path)
            copied.append(dest_path)
        logger.debug("Copied %s → %s", pdb, dest_path)

    logger.info("Collected %d PDB models into %s", len(copied), dest_dir)
    return copied