variantfold 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
1
+ Metadata-Version: 2.4
2
+ Name: variantfold
3
+ Version: 0.1.0
4
+ Summary: Classify variants of uncertain significance using AlphaFold-predicted protein structures and graph neural networks.
5
+ Author: VariantFold Contributors
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/comparativechrono/VariantFold
8
+ Project-URL: Repository, https://github.com/comparativechrono/VariantFold
9
+ Project-URL: Issues, https://github.com/comparativechrono/VariantFold/issues
10
+ Keywords: bioinformatics,alphafold,variant classification,GNN,VUS,ACMG
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Requires-Python: >=3.9
20
+ Description-Content-Type: text/markdown
21
+ Requires-Dist: biopython>=1.80
22
+ Requires-Dist: biopandas>=0.4
23
+ Requires-Dist: numpy>=1.22
24
+ Requires-Dist: pandas>=1.4
25
+ Requires-Dist: scikit-learn>=1.0
26
+ Requires-Dist: torch>=2.0
27
+ Requires-Dist: torch-geometric>=2.3
28
+ Provides-Extra: structure
29
+ Requires-Dist: colabfold[alphafold-minus-jax]; extra == "structure"
30
+ Provides-Extra: viz
31
+ Requires-Dist: matplotlib>=3.5; extra == "viz"
32
+ Requires-Dist: seaborn>=0.12; extra == "viz"
33
+ Requires-Dist: networkx>=2.8; extra == "viz"
34
+ Requires-Dist: py3Dmol>=1.8; extra == "viz"
35
+ Provides-Extra: dgl
36
+ Requires-Dist: dgl>=1.0; extra == "dgl"
37
+ Provides-Extra: dev
38
+ Requires-Dist: pytest>=7.0; extra == "dev"
39
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
40
+ Requires-Dist: ruff>=0.1; extra == "dev"
41
+ Provides-Extra: all
42
+ Requires-Dist: variantfold[dev,dgl,structure,viz]; extra == "all"
43
+
44
+ # VariantFold
45
+
46
+ Classify **Variants of Uncertain Significance (VUS)** using AlphaFold-predicted protein structures and Graph Neural Networks.
47
+
48
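+ VariantFold combines protein structure predictions from ColabFold/AlphaFold with a Graph Convolutional Network (GCN) to classify VUS as likely benign or likely pathogenic. The GCN is trained on ClinVar variants that already carry benign or pathogenic classifications under the standardised ACMG-AMP variant classification system.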
49
+
50
+ ## Workflow
51
+
52
+ ```
53
+ ClinVar data → Parse variants → Mutate sequences → ColabFold 3-D prediction
54
+ → PDB-to-graph conversion → Train GCN (benign vs pathogenic) → Classify VUS
55
+ ```
56
+
57
+ 1. **Parse** — Extract missense variants from ClinVar downloads (benign, pathogenic, VUS).
58
+ 2. **Mutate** — Apply each variant to the reference protein sequence.
59
+ 3. **Predict** — Run ColabFold to generate 3-D structure models for every variant.
60
+ 4. **Convert** — Transform PDB files into PyTorch Geometric residue-level graphs with rich node features (one-hot amino acid, 3-D coordinates, pLDDT).
61
+ 5. **Train** — Train a multi-layer GCN on the benign vs pathogenic graph dataset.
62
+ 6. **Classify** — Run the trained model on VUS structures to predict likely benign / likely pathogenic with probabilities.
63
+
64
+ ## Installation
65
+
66
+ ```bash
67
+ # Core package (graph conversion + GCN training/inference)
68
+ pip install .
69
+
70
+ # With ColabFold for structure prediction (GPU recommended)
71
+ pip install ".[structure]"
72
+
73
+ # With visualisation tools
74
+ pip install ".[viz]"
75
+
76
+ # Everything
77
+ pip install ".[all]"
78
+ ```
79
+
80
+ ## Quick start — Python API
81
+
82
+ ```python
83
+ from variantfold import VariantFoldConfig, VariantFoldPipeline
84
+
85
+ cfg = VariantFoldConfig(
86
+ gene_symbol="VHL",
87
+ entrez_email="your_email@example.com",
88
+ )
89
+
90
+ pipe = VariantFoldPipeline(cfg)
91
+ pipe.step1_parse_variants() # Parse ClinVar files + fetch sequence
92
+ # pipe.step2_predict_structures() # Run ColabFold (long — needs GPU)
93
+ pipe.step3_collect_models() # Gather best PDB models
94
+ metrics = pipe.step4_train() # Train GCN
95
+ print(f"Test accuracy: {metrics['accuracy']:.2%}")
96
+
97
+ vus_df = pipe.step5_classify_vus()
98
+ print(vus_df)
99
+ ```
100
+
101
+ ## Quick start — CLI
102
+
103
+ ```bash
104
+ # Run steps 1, 3, 4, 5 (assumes PDB libraries are already populated)
105
+ variantfold run --gene VHL --email you@example.com --steps 1,3,4,5
106
+
107
+ # Standalone inference on new PDB files
108
+ variantfold predict --model variantfold_VHL/variantfold_model.pt \
109
+ --pdb-dir ./new_vus_pdbs/
110
+ ```
111
+
112
+ ## Input data
113
+
114
+ Place these files in the working directory (`./variantfold_<gene>/`):
115
+
116
+ | File | Description |
117
+ |------|-------------|
118
+ | `clinvar_result_bng.txt` | ClinVar download filtered to **benign** variants |
119
+ | `clinvar_result_ptg.txt` | ClinVar download filtered to **pathogenic** variants |
120
+ | `clinvar_result_vus.txt` | ClinVar download filtered to **VUS** *(optional)* |
121
+
122
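+ Download each file from [ClinVar](https://www.ncbi.nlm.nih.gov/clinvar/) as a tab-delimited text export with default settings. Each file must contain ClinVar's `Protein change` column, from which missense variants are extracted; nonsense, frameshift, in-frame deletion/insertion/duplication and synonymous entries are skipped.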
123
+
124
+ ## Configuration
125
+
126
+ All parameters are set via `VariantFoldConfig`:
127
+
128
+ ```python
129
+ cfg = VariantFoldConfig(
130
+ gene_symbol="TP53",
131
+ entrez_email="you@example.com",
132
+ distance_threshold=6.5, # Å, residue contact cutoff
133
+ gcn_hidden_dim=64, # GCN layer width
134
+ gcn_num_layers=3, # depth
135
+ epochs=200,
136
+ learning_rate=0.01,
137
+ train_fraction=0.8,
138
+ use_residue_features=True, # 24-dim features (set False for legacy 1-dim)
139
+ )
140
+ ```
141
+
142
+ ## Development
143
+
144
+ ```bash
145
+ pip install -e ".[dev]"
146
+ pytest
147
+ ```
148
+
149
+ ## Licence
150
+
151
+ MIT
@@ -0,0 +1,108 @@
1
+ # VariantFold
2
+
3
+ Classify **Variants of Uncertain Significance (VUS)** using AlphaFold-predicted protein structures and Graph Neural Networks.
4
+
5
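+ VariantFold combines protein structure predictions from ColabFold/AlphaFold with a Graph Convolutional Network (GCN) to classify VUS as likely benign or likely pathogenic. The GCN is trained on ClinVar variants that already carry benign or pathogenic classifications under the standardised ACMG-AMP variant classification system.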
6
+
7
+ ## Workflow
8
+
9
+ ```
10
+ ClinVar data → Parse variants → Mutate sequences → ColabFold 3-D prediction
11
+ → PDB-to-graph conversion → Train GCN (benign vs pathogenic) → Classify VUS
12
+ ```
13
+
14
+ 1. **Parse** — Extract missense variants from ClinVar downloads (benign, pathogenic, VUS).
15
+ 2. **Mutate** — Apply each variant to the reference protein sequence.
16
+ 3. **Predict** — Run ColabFold to generate 3-D structure models for every variant.
17
+ 4. **Convert** — Transform PDB files into PyTorch Geometric residue-level graphs with rich node features (one-hot amino acid, 3-D coordinates, pLDDT).
18
+ 5. **Train** — Train a multi-layer GCN on the benign vs pathogenic graph dataset.
19
+ 6. **Classify** — Run the trained model on VUS structures to predict likely benign / likely pathogenic with probabilities.
20
+
21
+ ## Installation
22
+
23
+ ```bash
24
+ # Core package (graph conversion + GCN training/inference)
25
+ pip install .
26
+
27
+ # With ColabFold for structure prediction (GPU recommended)
28
+ pip install ".[structure]"
29
+
30
+ # With visualisation tools
31
+ pip install ".[viz]"
32
+
33
+ # Everything
34
+ pip install ".[all]"
35
+ ```
36
+
37
+ ## Quick start — Python API
38
+
39
+ ```python
40
+ from variantfold import VariantFoldConfig, VariantFoldPipeline
41
+
42
+ cfg = VariantFoldConfig(
43
+ gene_symbol="VHL",
44
+ entrez_email="your_email@example.com",
45
+ )
46
+
47
+ pipe = VariantFoldPipeline(cfg)
48
+ pipe.step1_parse_variants() # Parse ClinVar files + fetch sequence
49
+ # pipe.step2_predict_structures() # Run ColabFold (long — needs GPU)
50
+ pipe.step3_collect_models() # Gather best PDB models
51
+ metrics = pipe.step4_train() # Train GCN
52
+ print(f"Test accuracy: {metrics['accuracy']:.2%}")
53
+
54
+ vus_df = pipe.step5_classify_vus()
55
+ print(vus_df)
56
+ ```
57
+
58
+ ## Quick start — CLI
59
+
60
+ ```bash
61
+ # Run steps 1, 3, 4, 5 (assumes PDB libraries are already populated)
62
+ variantfold run --gene VHL --email you@example.com --steps 1,3,4,5
63
+
64
+ # Standalone inference on new PDB files
65
+ variantfold predict --model variantfold_VHL/variantfold_model.pt \
66
+ --pdb-dir ./new_vus_pdbs/
67
+ ```
68
+
69
+ ## Input data
70
+
71
+ Place these files in the working directory (`./variantfold_<gene>/`):
72
+
73
+ | File | Description |
74
+ |------|-------------|
75
+ | `clinvar_result_bng.txt` | ClinVar download filtered to **benign** variants |
76
+ | `clinvar_result_ptg.txt` | ClinVar download filtered to **pathogenic** variants |
77
+ | `clinvar_result_vus.txt` | ClinVar download filtered to **VUS** *(optional)* |
78
+
79
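+ Download each file from [ClinVar](https://www.ncbi.nlm.nih.gov/clinvar/) as a tab-delimited text export with default settings. Each file must contain ClinVar's `Protein change` column, from which missense variants are extracted; nonsense, frameshift, in-frame deletion/insertion/duplication and synonymous entries are skipped.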
80
+
81
+ ## Configuration
82
+
83
+ All parameters are set via `VariantFoldConfig`:
84
+
85
+ ```python
86
+ cfg = VariantFoldConfig(
87
+ gene_symbol="TP53",
88
+ entrez_email="you@example.com",
89
+ distance_threshold=6.5, # Å, residue contact cutoff
90
+ gcn_hidden_dim=64, # GCN layer width
91
+ gcn_num_layers=3, # depth
92
+ epochs=200,
93
+ learning_rate=0.01,
94
+ train_fraction=0.8,
95
+ use_residue_features=True, # 24-dim features (set False for legacy 1-dim)
96
+ )
97
+ ```
98
+
99
+ ## Development
100
+
101
+ ```bash
102
+ pip install -e ".[dev]"
103
+ pytest
104
+ ```
105
+
106
+ ## Licence
107
+
108
+ MIT
@@ -0,0 +1,76 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "variantfold"
7
+ version = "0.1.0"
8
+ description = "Classify variants of uncertain significance using AlphaFold-predicted protein structures and graph neural networks."
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.9"
12
+ authors = [
13
+ {name = "VariantFold Contributors"},
14
+ ]
15
+ keywords = ["bioinformatics", "alphafold", "variant classification", "GNN", "VUS", "ACMG"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Science/Research",
19
+ "Topic :: Scientific/Engineering :: Bio-Informatics",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.9",
22
+ "Programming Language :: Python :: 3.10",
23
+ "Programming Language :: Python :: 3.11",
24
+ "Programming Language :: Python :: 3.12",
25
+ ]
26
+
27
+ dependencies = [
28
+ "biopython>=1.80",
29
+ "biopandas>=0.4",
30
+ "numpy>=1.22",
31
+ "pandas>=1.4",
32
+ "scikit-learn>=1.0",
33
+ "torch>=2.0",
34
+ "torch-geometric>=2.3",
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ structure = [
39
+ # ColabFold and its AlphaFold dependency — heavy; a GPU is strongly recommended
40
+ "colabfold[alphafold-minus-jax]",
41
+ ]
42
+ viz = [
43
+ "matplotlib>=3.5",
44
+ "seaborn>=0.12",
45
+ "networkx>=2.8",
46
+ "py3Dmol>=1.8",
47
+ ]
48
+ dgl = [
49
+ # Only needed for the legacy DGL-based analysis path
50
+ "dgl>=1.0",
51
+ ]
52
+ dev = [
53
+ "pytest>=7.0",
54
+ "pytest-cov>=4.0",
55
+ "ruff>=0.1",
56
+ ]
57
+ all = ["variantfold[structure,viz,dgl,dev]"]
58
+
59
+ [project.scripts]
60
+ variantfold = "variantfold.cli:main"
61
+
62
+ [project.urls]
63
+ Homepage = "https://github.com/comparativechrono/VariantFold"
64
+ Repository = "https://github.com/comparativechrono/VariantFold"
65
+ Issues = "https://github.com/comparativechrono/VariantFold/issues"
66
+
67
+ [tool.setuptools.packages.find]
68
+ include = ["variantfold*"]
69
+
70
+ [tool.pytest.ini_options]
71
+ testpaths = ["tests"]
72
+ addopts = "-v --tb=short"
73
+
74
+ [tool.ruff]
75
+ target-version = "py39"
76
+ line-length = 95
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,305 @@
1
+ """
2
+ Tests for VariantFold.
3
+
4
+ Run with: pytest tests/
5
+ """
6
+
7
+ import os
8
+ import tempfile
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import pytest
13
+ import torch
14
+
15
+ # ============================================================================
16
+ # Tests for variantfold.variants (BUG-1, BUG-2, BUG-9)
17
+ # ============================================================================
18
+
19
+ from variantfold.variants import (
20
+ parse_clinvar_variant,
21
+ load_clinvar_table,
22
+ swap_amino_acid,
23
+ generate_mutant_sequences,
24
+ )
25
+
26
+
27
class TestParseClinvarVariant:
    """Regression tests for BUG-1: ClinVar protein-change parsing.

    The earlier regex recognised only single-letter amino-acid codes. These
    cases pin down each accepted missense notation (three-letter codes with or
    without parentheses, single-letter codes with or without the ``p.`` prefix)
    and confirm that non-missense changes yield ``None``.
    """

    def test_three_letter_codes(self):
        parsed = parse_clinvar_variant("p.Val600Glu")
        assert parsed == ("V", 600, "E")

    def test_three_letter_with_parens(self):
        parsed = parse_clinvar_variant("p.(Arg100Trp)")
        assert parsed == ("R", 100, "W")

    def test_single_letter_with_prefix(self):
        parsed = parse_clinvar_variant("p.R155W")
        assert parsed == ("R", 155, "W")

    def test_single_letter_bare(self):
        parsed = parse_clinvar_variant("V600E")
        assert parsed == ("V", 600, "E")

    def test_nonsense_returns_none(self):
        # "Ter" denotes a stop codon, so this is nonsense rather than missense.
        parsed = parse_clinvar_variant("p.Arg214Ter")
        assert parsed is None

    def test_frameshift_returns_none(self):
        parsed = parse_clinvar_variant("p.Ter214GlnfsTer59")
        assert parsed is None

    def test_deletion_returns_none(self):
        parsed = parse_clinvar_variant("p.Val600del")
        assert parsed is None

    def test_insertion_returns_none(self):
        parsed = parse_clinvar_variant("p.Ala100_Glu101insGly")
        assert parsed is None

    def test_synonymous_returns_none(self):
        # Reference and alternate residues are identical: no amino-acid change.
        parsed = parse_clinvar_variant("p.Val600Val")
        assert parsed is None

    def test_empty_returns_none(self):
        for blank in ("", "-"):
            assert parse_clinvar_variant(blank) is None

    def test_comma_separated_takes_first(self):
        # Only the first of several comma-separated changes is used.
        parsed = parse_clinvar_variant("p.Val600Glu, p.Val600Ala")
        assert parsed == ("V", 600, "E")

    def test_duplication_returns_none(self):
        parsed = parse_clinvar_variant("p.Ala100dup")
        assert parsed is None
70
+
71
+
72
class TestLoadClinvarTable:
    """Loading of ClinVar tab-delimited exports from disk."""

    def test_loads_valid_file(self, tmp_path):
        rows = [
            "Name\tProtein change\tOther",
            "var1\tp.Val600Glu\tinfo",
            "var2\tp.Arg100Trp\tinfo",
            "var3\tp.Ter214GlnfsTer59\tinfo",  # not missense -> dropped
            "var4\t\tinfo",  # blank protein change (NaN) -> dropped
        ]
        table_path = tmp_path / "test_clinvar.txt"
        table_path.write_text("".join(f"{row}\n" for row in rows))

        parsed = load_clinvar_table(str(table_path))

        assert len(parsed) == 2
        first, second = parsed
        assert first == ("V", 600, "E")
        assert second == ("R", 100, "W")

    def test_missing_file_raises(self):
        absent = "/nonexistent/file.txt"
        with pytest.raises(FileNotFoundError):
            load_clinvar_table(absent)

    def test_missing_column_raises(self, tmp_path):
        header_only = tmp_path / "bad.txt"
        header_only.write_text("Name\tWrong Column\n")
        # The error message must name the column that is missing.
        with pytest.raises(ValueError, match="Protein change"):
            load_clinvar_table(str(header_only))
98
+
99
+
100
class TestSwapAminoAcid:
    """BUG-9: the reference residue is validated before it is replaced.

    Positions are 1-based: position 1 is the first residue of the sequence.
    """

    _SEQ = "MVLSPADKTN"

    def test_basic_swap(self):
        assert swap_amino_acid(self._SEQ, 1, "A") == "AVLSPADKTN"

    def test_last_position(self):
        assert swap_amino_acid(self._SEQ, 10, "A") == "MVLSPADKTA"

    def test_validation_passes(self):
        mutated = swap_amino_acid(self._SEQ, 1, "A", expected_ref="M")
        assert mutated[0] == "A"

    def test_validation_fails(self):
        # The sequence has M at position 1, so expecting G must be rejected.
        with pytest.raises(ValueError, match="Expected G at position 1"):
            swap_amino_acid(self._SEQ, 1, "A", expected_ref="G")

    def test_position_out_of_range(self):
        # 0 is before the first residue; 5 is past the end of a 4-residue sequence.
        for bad_position in (0, 5):
            with pytest.raises(ValueError, match="out of range"):
                swap_amino_acid("MVLS", bad_position, "A")

    def test_invalid_aa_code(self):
        # "X" is not accepted as a substitute residue.
        with pytest.raises(ValueError, match="Invalid amino acid"):
            swap_amino_acid("MVLS", 1, "X")
132
+
133
+
134
class TestGenerateMutantSequences:
    """Writing one FASTA file per variant that applies to the reference."""

    def test_generates_fasta_files(self, tmp_path):
        reference = "MVLSPADKTNVKAAWGKVGA"
        requested = [("M", 1, "A"), ("V", 2, "G")]

        written = generate_mutant_sequences(reference, requested, str(tmp_path))
        assert len(written) == 2

        label = written[0][0]
        fasta_path = written[0][1]
        assert label == "M1A"
        assert fasta_path.exists()

        text = fasta_path.read_text()
        assert text.startswith(">M1A\n")
        sequence_line = text.strip().split("\n")[1]
        assert sequence_line[0] == "A"

    def test_skips_mismatched_variants(self, tmp_path):
        # The reference has M at position 1, so a G1A variant cannot apply.
        written = generate_mutant_sequences("MVLS", [("G", 1, "A")], str(tmp_path))
        assert len(written) == 0
153
+
154
+
155
+ # ============================================================================
156
+ # Tests for variantfold.graphs
157
+ # ============================================================================
158
+
159
+ from variantfold.graphs import pdb_to_graph, load_pdb_directory
160
+
161
+
162
+ def _make_minimal_pdb(path: str, n_residues: int = 10) -> None:
163
+ """Write a minimal valid PDB file for testing."""
164
+ lines = []
165
+ for i in range(n_residues):
166
+ atom_num = i + 1
167
+ res_num = i + 1
168
+ x = float(i * 3.8)
169
+ y = 0.0
170
+ z = 0.0
171
+ b_factor = 80.0 + i
172
+ lines.append(
173
+ f"ATOM {atom_num:5d} CA ALA A{res_num:4d} "
174
+ f"{x:8.3f}{y:8.3f}{z:8.3f} 1.00{b_factor:6.2f} C "
175
+ )
176
+ lines.append("END")
177
+ with open(path, "w") as f:
178
+ f.write("\n".join(lines) + "\n")
179
+
180
+
181
class TestPdbToGraph:
    """Conversion of a single PDB file into a PyTorch Geometric graph."""

    @staticmethod
    def _pdb(tmp_path, n_residues):
        """Write a synthetic PDB with *n_residues* residues and return its path."""
        target = tmp_path / "test.pdb"
        _make_minimal_pdb(str(target), n_residues=n_residues)
        return str(target)

    def test_creates_graph(self, tmp_path):
        graph = pdb_to_graph(self._pdb(tmp_path, 10), distance_threshold=6.5, label=0)
        assert graph.x is not None
        assert graph.edge_index.shape[0] == 2
        assert graph.y.item() == 0

    def test_rich_features_shape(self, tmp_path):
        graph = pdb_to_graph(self._pdb(tmp_path, 5), use_residue_features=True, label=1)
        # 20 one-hot amino-acid columns + 3 coordinates + 1 pLDDT = 24.
        assert graph.x.shape == (5, 24)

    def test_legacy_features_shape(self, tmp_path):
        graph = pdb_to_graph(self._pdb(tmp_path, 5), use_residue_features=False, label=0)
        assert graph.x.shape == (5, 1)

    def test_vus_label_is_none(self, tmp_path):
        # Passing label=None leaves the graph unlabelled.
        graph = pdb_to_graph(self._pdb(tmp_path, 5), label=None)
        assert graph.y is None
211
+
212
+
213
class TestLoadPdbDirectory:
    """Bulk loading of every PDB file in a directory."""

    def test_loads_multiple_pdbs(self, tmp_path):
        for name in ("var_0.pdb", "var_1.pdb", "var_2.pdb"):
            _make_minimal_pdb(str(tmp_path / name), n_residues=8)

        graphs = load_pdb_directory(str(tmp_path), label=1)

        assert len(graphs) == 3
        labels = [graph.y.item() for graph in graphs]
        assert labels == [1, 1, 1]

    def test_empty_directory(self, tmp_path):
        assert load_pdb_directory(str(tmp_path), label=0) == []
225
+
226
+
227
+ # ============================================================================
228
+ # Tests for variantfold.model
229
+ # ============================================================================
230
+
231
+ from variantfold.model import VariantGCN, train_model, predict_vus, save_model, load_model
232
+
233
+
234
class TestVariantGCN:
    """Shape checks for the GCN forward pass."""

    def test_forward_pass(self):
        # Import the PyG containers here instead of relying on a module-level
        # ``import torch_geometric.data`` elsewhere in the file; otherwise this
        # test silently depends on where that import is placed.
        from torch_geometric.data import Batch, Data

        model = VariantGCN(input_dim=24, hidden_dim=16, num_layers=2).float()

        # Two small graphs with different node counts, batched together.
        first = Data(
            x=torch.randn(10, 24),
            edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]], dtype=torch.long),
            y=torch.tensor([0]),
        )
        second = Data(
            x=torch.randn(8, 24),
            edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.long),
            y=torch.tensor([1]),
        )
        batch = Batch.from_data_list([first, second])

        out = model(batch)

        # One row per graph in the batch, one column per class.
        assert out.shape == (2, 2)
255
+
256
+
257
class TestSaveLoadModel:
    """Persistence of a VariantGCN to disk and back."""

    def test_round_trip(self, tmp_path):
        model = VariantGCN(input_dim=24, hidden_dim=16, num_layers=2)
        path = str(tmp_path / "model.pt")

        save_model(model, path, config={"input_dim": 24, "hidden_dim": 16, "num_layers": 2})
        loaded = load_model(path)

        # Compare complete state dicts rather than zipping named_parameters():
        # zip() stops at the shorter iterable, so a reloaded model with missing
        # or extra tensors would still have passed. State dicts also include
        # registered buffers, not only trainable parameters.
        expected = model.state_dict()
        actual = loaded.state_dict()
        assert list(actual) == list(expected)
        for name, tensor in expected.items():
            # .cpu() keeps the comparison valid in case the reloaded weights
            # live on a different device.
            assert torch.equal(actual[name].cpu(), tensor.cpu()), name
271
+
272
+
273
+ # ============================================================================
274
+ # Tests for variantfold.config
275
+ # ============================================================================
276
+
277
+ from variantfold.config import VariantFoldConfig
278
+
279
+
280
class TestConfig:
    """Defaults and derived values of VariantFoldConfig."""

    @staticmethod
    def _config(gene="TP53", **overrides):
        """Build a config for *gene* with a fixed test e-mail address."""
        return VariantFoldConfig(gene_symbol=gene, entrez_email="test@test.com", **overrides)

    def test_defaults(self):
        cfg = self._config()
        # work_dir is derived from the gene symbol when not given explicitly.
        assert cfg.work_dir == "./variantfold_TP53"
        assert cfg.distance_threshold == 6.5
        assert cfg.num_node_features == 24

    def test_legacy_features(self):
        cfg = self._config(use_residue_features=False)
        assert cfg.num_node_features == 1

    def test_ensure_directories(self, tmp_path):
        cfg = self._config("VHL", work_dir=str(tmp_path / "test_run"))
        cfg.ensure_directories()
        for directory in (cfg.benign_library, cfg.vus_library):
            assert os.path.isdir(directory)
302
+
303
+
304
+ # Need this import for the model test
305
+ import torch_geometric.data
@@ -0,0 +1,41 @@
1
"""VariantFold: classify variants of uncertain significance (VUS).

Predicted protein structures (ColabFold / AlphaFold) are converted into
residue-level graphs and scored by a graph convolutional network.

Pipeline stages
---------------
1. Parse ClinVar variant data
2. Mutate reference protein sequences
3. Predict 3-D structures with ColabFold / AlphaFold
4. Convert PDB structures to residue-level graphs
5. Train a GCN classifier (benign vs pathogenic)
6. Classify VUS with the trained model
"""

# Keep in sync with ``version`` in pyproject.toml.
__version__ = "0.1.0"

# Submodules are imported in the same order as before; only the import style
# (explicit relative) differs.
from .config import VariantFoldConfig
from .variants import (
    generate_mutant_sequences,
    load_clinvar_table,
    parse_clinvar_variant,
    swap_amino_acid,
)
from .graphs import load_pdb_directory, pdb_to_graph
from .model import VariantGCN, evaluate_model, predict_vus, train_model
from .pipeline import VariantFoldPipeline

# Public API, grouped by the submodule each name comes from.
__all__ = [
    # config
    "VariantFoldConfig",
    # variants
    "parse_clinvar_variant",
    "load_clinvar_table",
    "swap_amino_acid",
    "generate_mutant_sequences",
    # graphs
    "pdb_to_graph",
    "load_pdb_directory",
    # model
    "VariantGCN",
    "train_model",
    "evaluate_model",
    "predict_vus",
    # pipeline
    "VariantFoldPipeline",
]