pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of pg-sui might be problematic.
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
- pg_sui-1.6.8.dist-info/RECORD +78 -0
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
- pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
- pg_sui-1.6.8.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +635 -0
- pgsui/data_processing/config.py +576 -0
- pgsui/data_processing/containers.py +1782 -0
- pgsui/data_processing/transformers.py +121 -1103
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +189 -0
- pgsui/electron/app/package-lock.json +6893 -0
- pgsui/electron/app/package.json +50 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +146 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +130 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +59 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
- pgsui/impute/deterministic/imputers/mode.py +679 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +971 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
- pgsui/impute/supervised/base.py +339 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
- pgsui/impute/supervised/imputers/random_forest.py +287 -0
- pgsui/impute/unsupervised/base.py +924 -0
- pgsui/impute/unsupervised/callbacks.py +89 -263
- pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
- pgsui/impute/unsupervised/imputers/vae.py +957 -0
- pgsui/impute/unsupervised/loss_functions.py +158 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
- pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
- pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
- pgsui/impute/unsupervised/models/vae_model.py +259 -618
- pgsui/impute/unsupervised/nn_scorers.py +215 -0
- pgsui/utils/classification_viz.py +591 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +514 -824
- pgsui/utils/scorers.py +212 -438
- pg_sui-1.0.2.1.dist-info/RECORD +0 -75
- pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -735
- pgsui/impute/impute.py +0 -1486
- pgsui/impute/simple_imputers.py +0 -1439
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
- pgsui/impute/unsupervised/keras_classifiers.py +0 -702
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -297
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -214
- {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
- /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/deterministic/imputers/phylo.py
@@ -0,0 +1,971 @@
import copy
import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

# Third-party imports
import numpy as np
import pandas as pd
import scipy.linalg
import toytree as tt

from pgsui.utils.plotting import Plotting
from pgsui.utils.scorers import Scorer

if TYPE_CHECKING:
    from snpio.analysis.tree_parser import TreeParser
    from snpio.read_input.genotype_data import GenotypeData


class _DiploidAggregator:
    """Precompute mappings to lump ordered diploid states (16) into unordered genotypes (10).

    Notes:
        Genotype class order matches the SNPio allele → genotype encodings:
        - 0: AA, 1: AC, 2: AG, 3: AT, 4: CC, 5: CG, 6: CT, 7: GG, 8: GT, 9: TT
        - Allele order is A(0), C(1), G(2), T(3).
    """

    def __init__(self) -> None:
        # Ordered pairs (i,j) lexicographic over {0,1,2,3}; 16 states
        self.ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
        # Unordered genotype classes as (min,max) in the specified class order
        self.genotype_classes = (
            [(0, 0)]
            + [(0, 1), (0, 2), (0, 3)]
            + [(1, 1), (1, 2), (1, 3)]
            + [(2, 2), (2, 3)]
            + [(3, 3)]
        )  # 10 states

        # Map: class index -> list of ordered-state indices
        self.class_to_ordered: list[list[int]] = []
        for i, j in self.genotype_classes:
            members = []
            for k, (a, b) in enumerate(self.ordered_pairs):
                if i == j:  # homozygote
                    if a == i and b == j:
                        members.append(k)  # exactly one member
                else:  # heterozygote, both permutations
                    if (a == i and b == j) or (a == j and b == i):
                        members.append(k)
            self.class_to_ordered.append(members)

        # Build R (16x10) and C (10x16)
        self.R = np.zeros((16, 10), dtype=float)
        self.C = np.zeros((10, 16), dtype=float)
        for c, members in enumerate(self.class_to_ordered):
            m = float(len(members))
            for o in members:
                self.R[o, c] = 1.0 / m  # spread class prob equally to ordered members
                self.C[c, o] = 1.0  # sum ordered probs back to class

        # For allele-marginalization from 10-state genotype posterior to 4 alleles
        # p_allele[i] = P(ii) + 0.5 * sum_{j!=i} P(min(i,j), max(i,j))
        self.het_classes_by_allele: dict[int, list[int]] = {0: [], 1: [], 2: [], 3: []}
        self.hom_class_index = {0: 0, 1: 4, 2: 7, 3: 9}  # AA, CC, GG, TT indices
        for idx, (i, j) in enumerate(self.genotype_classes):
            if i != j:
                self.het_classes_by_allele[i].append(idx)
                self.het_classes_by_allele[j].append(idx)
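The lumping matrices R (16x10) and C (10x16) are the heart of the diploid mode: C collects ordered-pair probabilities into the 10 unordered genotype classes, and R spreads class mass back uniformly over each class's ordered members. A minimal self-contained check of that algebra, added here as an annotation rather than package code:

import numpy as np

# Ordered diploid states: 16 pairs (i, j) over alleles A(0), C(1), G(2), T(3).
ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
# The same 10 unordered classes _DiploidAggregator uses, in the same order.
genotype_classes = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
                    (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]

R = np.zeros((16, 10))  # spread: class -> its ordered members
C = np.zeros((10, 16))  # collect: ordered members -> class
for c, (i, j) in enumerate(genotype_classes):
    members = [k for k, (a, b) in enumerate(ordered_pairs) if {a, b} == {i, j}]
    for o in members:
        R[o, c] = 1.0 / len(members)
        C[c, o] = 1.0

# Collecting after spreading is the identity on the 10 classes.
assert np.allclose(C @ R, np.eye(10))

# Lumping a uniform ordered distribution preserves total mass; heterozygote
# classes receive twice the probability of homozygotes (two ordered members).
p10 = C @ np.full(16, 1.0 / 16.0)
assert np.isclose(p10.sum(), 1.0)
print(np.round(p10, 4))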
class _QCache:
    """Cache P(t) for haploid (4) and optionally diploid (10) via Kronecker + lumping."""

    def __init__(
        self,
        q_df: pd.DataFrame,
        mode: str = "haploid",
        diploid_agg: "_DiploidAggregator | None" = None,
    ) -> None:
        """Precompute eigendecomposition of haploid Q.

        Args:
            q_df: 4x4 generator in A,C,G,T order.
            mode: "haploid" or "diploid".
            diploid_agg: required if mode="diploid".
        """
        m = np.asarray(q_df, dtype=float)
        evals, V = scipy.linalg.eig(m)
        Vinv = scipy.linalg.inv(V)
        self.evals = evals
        self.V = V
        self.Vinv = Vinv
        self.mode = mode
        self.agg = diploid_agg
        if self.mode == "diploid" and self.agg is None:
            raise ValueError("Diploid mode requires a _DiploidAggregator.")
        self._cache4: dict[float, np.ndarray] = {}
        self._cache10: dict[float, np.ndarray] = {}

    def _P4(self, s: float) -> np.ndarray:
        """Return P(t) for haploid (4 states)."""
        key = round(float(s), 12)
        if key in self._cache4:
            return self._cache4[key]
        expo = np.exp(self.evals * key)
        P4 = (self.V * expo) @ self.Vinv
        P4 = np.real_if_close(P4, tol=1e5)
        P4[P4 < 0.0] = 0.0
        P4 /= P4.sum(axis=1, keepdims=True).clip(min=1.0)
        self._cache4[key] = P4
        return P4

    def P(self, t: float, rate: float = 1.0) -> np.ndarray:
        """Return P(t) in the active mode."""
        s = float(rate) * float(t)
        if self.mode == "haploid":
            return self._P4(s)
        # diploid
        key = round(s, 12)
        if key in self._cache10:
            return self._cache10[key]
        P4 = self._P4(s)
        P16 = np.kron(P4, P4)  # independent alleles
        P10 = self.agg.C @ P16 @ self.agg.R  # lump to unordered genotypes
        P10 = np.maximum(P10, 0.0)
        P10 /= P10.sum(axis=1, keepdims=True).clip(min=1.0)
        self._cache10[key] = P10
        return P10
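_QCache amortizes the matrix exponential: Q is eigendecomposed once, and every branch length afterwards costs only a diagonal scaling and two matrix products. A standalone sketch of the same trick, assuming a Jukes-Cantor generator so the result can be checked against the known closed form (not part of the package):

import numpy as np
import scipy.linalg

# Jukes-Cantor generator with mean substitution rate 1 at stationarity.
Q = np.full((4, 4), 1.0 / 3.0)
np.fill_diagonal(Q, -1.0)

# One-time eigendecomposition, as in _QCache.__init__.
evals, V = scipy.linalg.eig(Q)
Vinv = scipy.linalg.inv(V)

def transition_matrix(t: float) -> np.ndarray:
    """P(t) = V @ diag(exp(evals * t)) @ V^-1, clipped and row-normalized."""
    P = (V * np.exp(evals * t)) @ Vinv
    P = np.clip(np.real_if_close(P, tol=1e5).real, 0.0, None)
    return P / P.sum(axis=1, keepdims=True)

assert np.allclose(transition_matrix(0.0), np.eye(4), atol=1e-10)
P = transition_matrix(0.5)
assert np.allclose(P.sum(axis=1), 1.0)
# Jukes-Cantor closed form: every off-diagonal of P(t) is (1 - exp(-4t/3)) / 4.
assert np.isclose(P[0, 1], (1.0 - np.exp(-4.0 * 0.5 / 3.0)) / 4.0)
print(np.round(P, 4))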
class ImputePhylo:
    """Imputes missing genotype data using a phylogenetic likelihood model.

    This imputer uses a continuous-time Markov chain (CTMC) model of sequence evolution to impute missing genotype data based on a provided phylogenetic tree. It supports both haploid and diploid data, with options for evaluating imputation accuracy through simulated missingness.

    Notes:
        - Haploid CTMC (4 states: A,C,G,T), enabled with haploid=True.
        - Diploid CTMC (10 unordered genotype states: AA,...,TT) [default], derived by independent allele evolution and lumping ordered pairs.
    """

    def __init__(
        self,
        genotype_data: "GenotypeData",
        tree_parser: "TreeParser",
        prefix: str,
        min_branch_length: float = 1e-10,
        *,
        haploid: bool = False,
        eval_missing_rate: float = 0.0,
        column_subset: Optional[List[int]] = None,
        save_plots: bool = False,
        verbose: bool = False,
        debug: bool = False,
    ) -> None:
        self.genotype_data = genotype_data
        self.tree_parser = tree_parser
        self.prefix = prefix
        self.min_branch_length = min_branch_length
        self.eval_missing_rate = eval_missing_rate
        self.column_subset = column_subset
        self.save_plots = save_plots
        self.logger = genotype_data.logger
        self.verbose = verbose
        self.debug = debug
        self.char_map = {"A": 0, "C": 1, "G": 2, "T": 3}
        self.nuc_map = {v: k for k, v in self.char_map.items()}
        self.imputer_data_: Optional[Tuple] = None
        self.ground_truth_: Optional[Dict] = None
        self.scorer = Scorer(self.prefix, average="macro", verbose=verbose, debug=debug)
        self.plotter = Plotting(
            "ImputePhylo",
            prefix=self.prefix,
            plot_format=self.genotype_data.plot_format,
            plot_fontsize=self.genotype_data.plot_fontsize,
            plot_dpi=self.genotype_data.plot_dpi,
            title_fontsize=self.genotype_data.plot_fontsize,
            verbose=verbose,
            debug=debug,
        )

        self._MISSING_TOKENS = {"N", "n", "-9", ".", "?", "./.", ""}

        self.haploid = haploid
        self._dip_agg: _DiploidAggregator | None = None

        self.imputed_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
        self.imputed_genotype_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
        self.evaluation_results_: dict | None = None

    def fit(self) -> "ImputePhylo":
        """Prepares the imputer by parsing and validating input data.

        This method does the following:
        - Validates the genotype data and phylogenetic tree.
        - Extracts the genotype matrix, pruned tree, Q-matrix, and site rates.
        - Sets up internal structures for imputation.
        """
        self.imputer_data_ = self._parse_arguments()
        return self

    def transform(self) -> pd.DataFrame:
        """Transforms the data by imputing missing values.

        This method does the following:
        - Uses the fitted imputer to perform phylogenetic imputation.
        - Returns the imputed genotype DataFrame.
        """
        if self.imputer_data_ is None:
            msg = "The imputer has not been fitted. Call 'fit' first."
            self.logger.error(msg)
            raise RuntimeError(msg)

        original_genotypes, tree, q_matrix_in, site_rates = self.imputer_data_
        q_matrix = self._repair_and_scale_q(q_matrix_in, target_mean_rate=1.0)

        if not self.haploid:
            if self._dip_agg is None:
                self._dip_agg = _DiploidAggregator()
            self._qcache = _QCache(q_matrix, mode="diploid", diploid_agg=self._dip_agg)
        else:
            self._qcache = _QCache(q_matrix, mode="haploid")

        genotypes_to_impute = original_genotypes

        if self.eval_missing_rate > 0:
            genotypes_to_impute = self._simulate_missing_data(original_genotypes)

        imputed_df = self.impute_phylo(genotypes_to_impute, tree, q_matrix, site_rates)

        if self.ground_truth_:
            self._evaluate_imputation(imputed_df)
        return imputed_df  # keep as DataFrame with sample index

    def fit_transform(self) -> pd.DataFrame:
        """Fits the imputer and transforms the data in one step.

        This method does the following:
        - Calls the `fit` method to prepare the imputer.
        - Calls the `transform` method to perform imputation.
        - Returns the imputed genotype DataFrame.
        """
        self.fit()
        return self.transform()
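A hypothetical end-to-end call, sketched from the constructor signature above. The genotype_data and tree_parser inputs come from snpio (see the TYPE_CHECKING imports); their construction is outside this diff, so they are assumed here to be already-built objects with .snpsdict, .tree, and .qmat populated:

from pgsui.impute.deterministic.imputers.phylo import ImputePhylo

imputer = ImputePhylo(
    genotype_data=genotype_data,  # assumed: snpio GenotypeData, .snpsdict set
    tree_parser=tree_parser,      # assumed: snpio TreeParser, .tree/.qmat set
    prefix="phylo_run",
    haploid=False,                # 10-state diploid CTMC (the default)
    eval_missing_rate=0.1,        # mask 10% of known calls to score accuracy
)
imputed_df = imputer.fit_transform()  # rows = samples, columns = SNP sites
print(imputer.evaluation_results_)    # populated when eval_missing_rate > 0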
    def _simulate_missing_data(
        self, original_genotypes: Dict[str, List[str]]
    ) -> Dict[str, List[str]]:
        """Masks a fraction of known genotypes for evaluation."""
        genotypes_to_impute = copy.deepcopy(original_genotypes)
        known_positions = [
            (sample, site_idx)
            for sample, seq in genotypes_to_impute.items()
            for site_idx, base in enumerate(seq)
            if self._is_known(base)
        ]

        if not known_positions:
            self.logger.warning("No known values to mask for evaluation.")
            return genotypes_to_impute

        num_to_mask = int(len(known_positions) * self.eval_missing_rate)
        if num_to_mask == 0:
            self.logger.warning("eval_missing_rate is too low to mask any values.")
            return genotypes_to_impute

        # Sample the (sample, site_idx) tuples to be masked
        positions_to_mask = random.sample(known_positions, num_to_mask)

        # Build the ground-truth dictionary before masking
        self.ground_truth_ = {}
        for sample, site_idx in positions_to_mask:
            # Store the original base before masking it
            self.ground_truth_[(sample, site_idx)] = genotypes_to_impute[sample][
                site_idx
            ]
            # Now, apply the mask
            genotypes_to_impute[sample][site_idx] = "N"

        self.logger.info(f"Masked {len(self.ground_truth_)} values for evaluation.")
        return genotypes_to_impute
    def _evaluate_imputation(self, imputed_df: pd.DataFrame) -> None:
        """Evaluate imputation with 4 classes (haploid) or 10/16 classes (diploid).

        Behavior:
        - If self.haploid is True:
            Evaluate per-allele with 4 classes (A,C,G,T), keeping the existing logic.
        - If self.haploid is False:
            Evaluate per-genotype.
            * Default: 10 unordered classes (AA,AC,AG,AT,CC,CG,CT,GG,GT,TT).
            * If getattr(self, 'diploid_ordered', False) is True:
                Also compute 16-class ordered metrics (AA,AC,...,TT).
                Requires self._dip_agg (the _DiploidAggregator). If absent, it is created.

        Notes:
        - Truth handling: only unambiguous IUPAC that maps to exactly one genotype class
          is used for 10-class evaluation (e.g., A, C, G, T, M, R, W, S, Y, K).
          Ambiguous codes (N, B, D, H, V, '-', etc.) are skipped for strict genotype
          evaluation.
        - Probabilities: prefer stored 10-class posteriors in
          self.imputed_genotype_likelihoods_; if missing, fall back by converting the
          4-class allele posterior to 10-class using HW independence.
        """
        if not self.ground_truth_:
            return

        # -------------------------
        # Haploid path (4 classes)
        # -------------------------
        if getattr(self, "haploid", True):
            # --- Gather ground truth bases and posterior vectors (only if both exist)
            rows = []
            for (sample, site_idx), true_base in self.ground_truth_.items():
                if (
                    sample in imputed_df.index
                    and (sample, site_idx) in self.imputed_likelihoods_
                ):
                    rows.append(
                        (true_base, self.imputed_likelihoods_[(sample, site_idx)])
                    )
            if not rows:
                self.logger.warning("No matching masked sites found to evaluate.")
                return

            true_bases = np.array([b for (b, _) in rows], dtype=object)
            proba_mat = np.vstack([p for (_, p) in rows])  # (n,4) in A,C,G,T
            nuc_to_idx = self.char_map  # {'A':0,'C':1,'G':2,'T':3}

            y_true_site = np.array(
                [nuc_to_idx.get(str(b).upper(), -1) for b in true_bases], dtype=int
            )
            valid_mask = (y_true_site >= 0) & np.isfinite(proba_mat).all(axis=1)
            if not np.any(valid_mask):
                self.logger.error(
                    "Evaluation arrays empty after filtering invalid entries."
                )
                return

            y_true_site = y_true_site[valid_mask]
            proba_mat = proba_mat[valid_mask]

            # Interleave each site as two alleles (iid/HW assumption) for per-allele metrics
            y_true_int = np.repeat(y_true_site, 2)
            y_pred_proba = np.repeat(proba_mat, 2, axis=0)
            y_pred_int = np.argmax(y_pred_proba, axis=1)

            n_classes = 4
            y_true_ohe = np.eye(n_classes, dtype=float)[y_true_int]
            idx_to_nuc = {v: k for k, v in nuc_to_idx.items()}

            self.logger.info("--- Per-Allele Imputation Performance (4-class) ---")
            self.evaluation_results_ = self.scorer.evaluate(
                y_true=y_true_int,
                y_pred=y_pred_int,
                y_true_ohe=y_true_ohe,
                y_pred_proba=y_pred_proba,
            )
            self.logger.info(f"Evaluation results: {self.evaluation_results_}")

            # Plots
            self.plotter.plot_metrics(
                y_true=y_true_int,
                y_pred_proba=y_pred_proba,
                metrics=self.evaluation_results_,
            )
            self.plotter.plot_confusion_matrix(
                y_true_1d=y_true_int,
                y_pred_1d=y_pred_int,
                label_names=["A", "C", "G", "T"],
            )
            return

        # -------------------------
        # Diploid path (10/16 classes)
        # -------------------------
        # Ensure aggregator
        if getattr(self, "_dip_agg", None) is None:
            self._dip_agg = _DiploidAggregator()

        # Label catalogs
        geno_classes = self._dip_agg.genotype_classes  # list of (i,j) with i<=j
        alleles = ["A", "C", "G", "T"]
        labels10 = [f"{alleles[i]}{alleles[j]}" for (i, j) in geno_classes]  # 10 labels
        ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
        labels16 = [
            f"{alleles[i]}{alleles[j]}" for (i, j) in ordered_pairs
        ]  # 16 labels

        # Helper: strict IUPAC→single 10-class index (skip ambiguous >2-allele codes)
        iupac_to_10_single = {
            "A": "AA",
            "C": "CC",
            "G": "GG",
            "T": "TT",
            "M": "AC",
            "R": "AG",
            "W": "AT",
            "S": "CG",
            "Y": "CT",
            "K": "GT",
        }
        label_to_10_idx = {lab: idx for idx, lab in enumerate(labels10)}

        rows_10: list[tuple[int, np.ndarray]] = []  # (true_idx10, p10)
        rows_16_truth: list[int] = []  # expanded true labels (for 16)
        rows_16_proba: list[np.ndarray] = []  # matching (p16) rows

        # Collect rows
        for (sample, site_idx), true_code in self.ground_truth_.items():
            # Need probabilities for this masked (sample,site)
            p10 = None
            if (sample, site_idx) in self.imputed_genotype_likelihoods_:
                p10 = np.asarray(
                    self.imputed_genotype_likelihoods_[(sample, site_idx)], float
                )
            elif (sample, site_idx) in self.imputed_likelihoods_:
                # Fallback: build 10-class from 4-class posterior (HW independence)
                p4 = np.asarray(self.imputed_likelihoods_[(sample, site_idx)], float)
                p4 = np.clip(p4, 0.0, np.inf)
                p4 = p4 / (p4.sum() or 1.0)
                # Unordered genotype probs: AA=pA^2, AC=2pApC, ...
                p10 = np.zeros(10, dtype=float)
                # class indices in our order:
                # 0:AA, 1:AC, 2:AG, 3:AT, 4:CC, 5:CG, 6:CT, 7:GG, 8:GT, 9:TT
                pA, pC, pG, pT = p4
                p10[0] = pA * pA
                p10[1] = 2 * pA * pC
                p10[2] = 2 * pA * pG
                p10[3] = 2 * pA * pT
                p10[4] = pC * pC
                p10[5] = 2 * pC * pG
                p10[6] = 2 * pC * pT
                p10[7] = pG * pG
                p10[8] = 2 * pG * pT
                p10[9] = pT * pT
                p10 = p10 / (p10.sum() or 1.0)
            else:
                continue  # no probabilities available

            # Truth mapping to a single unordered class (skip ambiguous >2-allele codes)
            c = str(true_code).upper()
            lab10 = iupac_to_10_single.get(c, None)
            if lab10 is None:
                # Skip N, B, D, H, V, '-', etc., because they expand to multiple classes
                continue
            true_idx10 = label_to_10_idx[lab10]
            rows_10.append((true_idx10, p10))

            # Optional ordered-16 expansion (for confusion/metrics if requested)
            if getattr(self, "diploid_ordered", False):
                # Spread 10->16 via R (16x10)
                p16 = self._dip_agg.R @ p10
                p16 = p16 / (p16.sum() or 1.0)
                i, j = geno_classes[true_idx10]
                if i == j:
                    # Homozygote: single ordered index
                    true16 = ordered_pairs.index((i, j))
                    rows_16_truth.append(true16)
                    rows_16_proba.append(p16)
                else:
                    # Heterozygote: duplicate as two ordered permutations
                    true16a = ordered_pairs.index((i, j))
                    true16b = ordered_pairs.index((j, i))
                    rows_16_truth.append(true16a)
                    rows_16_proba.append(p16)
                    rows_16_truth.append(true16b)
                    rows_16_proba.append(p16)

        if not rows_10:
            self.logger.warning("No valid diploid truth rows for evaluation.")
            return

        # ---- 10-class metrics ----
        y_true_10 = np.array([t for (t, _) in rows_10], dtype=int)
        y_pred_proba_10 = np.vstack([p for (_, p) in rows_10])  # (n,10)

        y_pred_10 = np.argmax(y_pred_proba_10, axis=1)
        y_true_ohe_10 = np.eye(10, dtype=float)[y_true_10]

        self.logger.info("--- Per-Genotype Imputation Performance (10-class) ---")
        self.evaluation_results_ = self.scorer.evaluate(
            y_true=y_true_10,
            y_pred=y_pred_10,
            y_true_ohe=y_true_ohe_10,
            y_pred_proba=y_pred_proba_10,
        )
        self.logger.info(f"Evaluation results (10-class): {self.evaluation_results_}")

        # Plots (10-class)
        self.plotter.plot_metrics(
            y_true=y_true_10,
            y_pred_proba=y_pred_proba_10,
            metrics=self.evaluation_results_,
            label_names=labels10,
        )
        self.plotter.plot_confusion_matrix(
            y_true_1d=y_true_10, y_pred_1d=y_pred_10, label_names=labels10
        )

        # ---- Optional 16-class ordered metrics ----
        if getattr(self, "diploid_ordered", False) and rows_16_truth:
            y_true_16 = np.array(rows_16_truth, dtype=int)
            y_pred_proba_16 = np.vstack(rows_16_proba)
            y_pred_16 = np.argmax(y_pred_proba_16, axis=1)
            y_true_ohe_16 = np.eye(16, dtype=float)[y_true_16]

            self.logger.info(
                "--- Per-Genotype Imputation Performance (16-class ordered) ---"
            )
            eval16 = self.scorer.evaluate(
                y_true=y_true_16,
                y_pred=y_pred_16,
                y_true_ohe=y_true_ohe_16,
                y_pred_proba=y_pred_proba_16,
            )
            self.logger.info(f"Evaluation results (16-class): {eval16}")

            self.plotter.plot_metrics(
                y_true=y_true_16,
                y_pred_proba=y_pred_proba_16,
                metrics=eval16,
                label_names=labels16,
            )
            self.plotter.plot_confusion_matrix(
                y_true_1d=y_true_16, y_pred_1d=y_pred_16, label_names=labels16
            )
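The diploid fallback above expands a 4-class allele posterior into the 10 unordered genotype classes under Hardy-Weinberg independence: p(aa) = p_a^2 and p(ab) = 2 p_a p_b. The same arithmetic in isolation (illustrative only, not package code):

import numpy as np

def hw_expand(p4: np.ndarray) -> np.ndarray:
    """4-class allele posterior -> 10-class genotype distribution (HW)."""
    pA, pC, pG, pT = p4 / p4.sum()
    # Class order: AA, AC, AG, AT, CC, CG, CT, GG, GT, TT
    p10 = np.array([
        pA * pA, 2 * pA * pC, 2 * pA * pG, 2 * pA * pT,
        pC * pC, 2 * pC * pG, 2 * pC * pT,
        pG * pG, 2 * pG * pT,
        pT * pT,
    ])
    return p10 / p10.sum()

labels = ["AA", "AC", "AG", "AT", "CC", "CG", "CT", "GG", "GT", "TT"]
p10 = hw_expand(np.array([0.7, 0.2, 0.05, 0.05]))  # posterior in A,C,G,T order
assert np.isclose(p10.sum(), 1.0)
print(dict(zip(labels, np.round(p10, 3))))  # AA (0.49) and AC (0.28) dominate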
    def _infer_proba_permutation(
        self, y_ref: np.ndarray, P: np.ndarray, n_classes: int = 4
    ) -> tuple[np.ndarray, list[int]]:
        """Infers a permutation to align probability columns to the label space of y_ref."""
        perm = [-1] * n_classes
        taken = set()
        for k in range(n_classes):
            mask = y_ref == k
            if not np.any(mask):
                means = P.mean(axis=0)
            else:
                means = P[mask].mean(axis=0)

            # Find best unused column
            for col in np.argsort(means)[::-1]:
                if col not in taken:
                    perm[k] = col
                    taken.add(col)
                    break

        if len(taken) != n_classes:  # Fallback if permutation is incomplete
            unassigned_cols = [c for c in range(n_classes) if c not in taken]
            for i in range(n_classes):
                if perm[i] == -1:
                    perm[i] = unassigned_cols.pop(0)

        self.logger.info(f"Inferred probability permutation (label->col): {perm}")
        return P[:, perm], perm

    def _stationary_pi(self, Q: np.ndarray) -> np.ndarray:
        """Robustly calculates the stationary distribution of Q."""
        w, v = scipy.linalg.eig(Q.T)
        k = int(np.argmin(np.abs(w)))
        pi = np.real(v[:, k])
        pi = np.maximum(pi, 0.0)
        s = pi.sum()
        return (pi / s) if s > 0 else np.ones(Q.shape[0]) / Q.shape[0]
    def impute_phylo(
        self,
        genotypes: Dict[str, List[Union[str, int]]],
        tree: tt.tree,
        q_matrix: pd.DataFrame,
        site_rates: Optional[List[float]],
    ) -> pd.DataFrame:
        """Imputes missing values using a phylogenetic guide tree."""
        self.imputed_likelihoods_.clear()
        self.imputed_genotype_likelihoods_.clear()

        common_samples = set(tree.get_tip_labels()) & set(genotypes.keys())
        if not common_samples:
            raise ValueError("No samples in common between tree and genotypes.")
        filt_genotypes = copy.deepcopy({s: genotypes[s] for s in common_samples})
        num_snps = len(next(iter(filt_genotypes.values())))

        if site_rates is not None and len(site_rates) != num_snps:
            raise ValueError(
                f"len(site_rates)={len(site_rates)} != num_snps={num_snps}"
            )

        # Stationary prior at root
        if not self.haploid:
            # pi4 from Q; then lump pi4⊗pi4 to 10 classes
            pi4 = self._stationary_pi(q_matrix.to_numpy())
            pi_ord = np.kron(pi4, pi4)  # 16
            pi10 = self._dip_agg.C @ pi_ord
            pi10 = pi10 / (pi10.sum() or 1.0)
            root_prior = pi10
            n_states = 10
        else:
            root_prior = self._stationary_pi(q_matrix.to_numpy())  # 4
            n_states = 4

        for snp_index in range(num_snps):
            rate = site_rates[snp_index] if site_rates is not None else 1.0
            tips_with_missing = [
                s
                for s, seq in filt_genotypes.items()
                if self._is_missing(seq[snp_index])
            ]
            if not tips_with_missing:
                continue

            down_liks: Dict[int, np.ndarray] = {}
            for node in tree.treenode.traverse("postorder"):
                lik = np.zeros(n_states, dtype=float)
                if node.is_leaf():
                    if node.name not in filt_genotypes or self._is_missing(
                        filt_genotypes[node.name][snp_index]
                    ):
                        lik[:] = 1.0  # missing: uniform emission
                    else:
                        obs = filt_genotypes[node.name][snp_index]
                        if not self.haploid:
                            cls = self._iupac_to_genotype_classes(obs)
                            lik[cls] = 1.0
                        else:
                            for state in self._get_iupac_full(obs):
                                lik[self.char_map[state]] = 1.0
                    down_liks[node.idx] = lik / lik.sum()
                else:
                    msg = np.ones(n_states, dtype=float)
                    for child in node.children:
                        P = self._qcache.P(
                            max(child.dist, self.min_branch_length), rate
                        )
                        msg *= P @ down_liks[child.idx]
                    down_liks[node.idx] = self._norm(msg)

            up_liks: Dict[int, np.ndarray] = {tree.treenode.idx: root_prior.copy()}
            for node in tree.treenode.traverse("preorder"):
                if node.is_root():
                    continue
                parent = node.up
                sib_prod = np.ones(n_states, dtype=float)
                for sib in parent.children:
                    if sib.idx == node.idx:
                        continue
                    P_sib = self._qcache.P(max(sib.dist, self.min_branch_length), rate)
                    sib_prod *= P_sib @ down_liks[sib.idx]
                parent_msg = up_liks[parent.idx] * sib_prod
                P = self._qcache.P(max(node.dist, self.min_branch_length), rate)
                up = parent_msg @ P
                up_liks[node.idx] = up / (up.sum() or 1.0)

            for samp in tips_with_missing:
                node = tree.get_nodes(samp)[0]
                leaf_emission = down_liks[node.idx]  # uniform if missing
                tip_post = self._norm(up_liks[node.idx] * leaf_emission)

                if self.ground_truth_ and (samp, snp_index) in self.ground_truth_:
                    if not self.haploid:
                        # store genotype posterior (10) and allele-marginal (4)
                        self.imputed_genotype_likelihoods_[(samp, snp_index)] = (
                            tip_post.copy()
                        )
                        self.imputed_likelihoods_[(samp, snp_index)] = (
                            self._marginalize_genotype_to_allele(tip_post)
                        )
                    else:
                        self.imputed_likelihoods_[(samp, snp_index)] = tip_post.copy()

                if not self.haploid:
                    call = self._genotype_posterior_to_iupac(tip_post, mode="MAP")
                else:
                    call = self._allele_posterior_to_iupac_genotype(
                        tip_post, mode="MAP"
                    )
                filt_genotypes[samp][snp_index] = call

        df = pd.DataFrame.from_dict(filt_genotypes, orient="index")

        if df.applymap(self._is_missing).any().any():
            raise AssertionError("Imputation failed. Missing values remain.")

        return df
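impute_phylo is a two-pass message-passing (Felsenstein-style) scheme: the postorder pass collects tip evidence toward the root, and the preorder pass pushes the root prior plus sibling messages back down; the elementwise product at a missing tip is its posterior. The same computation by hand on a two-tip star tree under Jukes-Cantor, an illustration rather than package code:

import numpy as np

def P(t: float) -> np.ndarray:
    """Jukes-Cantor transition matrix via its closed form."""
    off = (1.0 - np.exp(-4.0 * t / 3.0)) / 4.0
    return np.full((4, 4), off) + np.eye(4) * (1.0 - 4.0 * off)

pi = np.full(4, 0.25)                    # stationary root prior
e_obs = np.array([1.0, 0.0, 0.0, 0.0])   # tip 1 observes 'A'
e_miss = np.ones(4)                      # tip 2 is missing: uniform emission
P1, P2 = P(0.1), P(0.5)                  # branch lengths to the two tips

# Upward message into tip 2: root prior times the sibling's downward
# contribution, pushed through tip 2's branch (parent_msg @ P, as above).
parent_msg = pi * (P1 @ e_obs)
tip2_post = (parent_msg @ P2) * e_miss
tip2_post /= tip2_post.sum()
print(np.round(tip2_post, 4))  # ~[0.587, 0.138, 0.138, 0.138]: 'A' favored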
    def _iupac_to_genotype_classes(self, char: str) -> list[int]:
        """Map IUPAC code to allowed unordered genotype class indices (10-state)."""
        c = str(char).upper()
        # Allele indices: A=0, C=1, G=2, T=3
        single = {"A": 0, "C": 1, "G": 2, "T": 3}
        het_map = {  # two-allele ambiguity
            "M": (0, 1),
            "R": (0, 2),
            "W": (0, 3),
            "S": (2, 1),
            "Y": (1, 3),
            "K": (2, 3),
        }
        if c in single:
            # homozygote only
            return [[0, 4, 7, 9][single[c]]]  # AA,CC,GG,TT class indices
        if c in het_map:
            i, j = het_map[c]
            # normalize to i <= j to match the class order
            i, j = (i, j) if i <= j else (j, i)
            class_order = self._dip_agg.genotype_classes
            return [class_order.index((i, j))]
        # Ambiguity codes with >2 alleles or missing: allow any compatible genotype
        amb = {
            "N": {0, 1, 2, 3},
            "-": {0, 1, 2, 3},
            "B": {1, 2, 3},
            "D": {0, 2, 3},
            "H": {0, 1, 3},
            "V": {0, 1, 2},
        }
        if c in amb:
            allowed = amb[c]
            classes = []
            for idx, (i, j) in enumerate(self._dip_agg.genotype_classes):
                if i in allowed and j in allowed:
                    classes.append(idx)
            return classes
        # Fallback: allow all
        return list(range(10))

    def _genotype_posterior_to_iupac(self, p10: np.ndarray, mode: str = "MAP") -> str:
        """Convert genotype posterior over 10 classes to an IUPAC code."""
        p = np.asarray(p10, dtype=float)
        p = p / (p.sum() or 1.0)
        if mode.upper() == "SAMPLE":
            k = int(np.random.choice(len(p), p=p))
        else:
            k = int(np.argmax(p))
        i, j = self._dip_agg.genotype_classes[k]
        alleles = ["A", "C", "G", "T"]
        gt = alleles[i] + alleles[j]
        return self._genotype_to_iupac(gt)

    def _marginalize_genotype_to_allele(self, p10: np.ndarray) -> np.ndarray:
        """Allele posterior from genotype posterior for eval/diagnostics (length 4)."""
        p10 = np.asarray(p10, dtype=float)
        p10 = p10 / (p10.sum() or 1.0)
        agg = self._dip_agg
        out = np.zeros(4, dtype=float)
        for a in range(4):
            out[a] = p10[agg.hom_class_index[a]] + 0.5 * np.sum(
                p10[agg.het_classes_by_allele[a]]
            )
        s = out.sum()
        return out / (s or 1.0)
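Marginalizing a 10-class genotype posterior back to 4 allele probabilities inverts the Hardy-Weinberg expansion used in the evaluation fallback: each homozygote contributes its full mass to its allele, each heterozygote half to each of its two alleles. A round-trip check (illustrative, not package code):

import numpy as np

classes = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
           (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]

def marginalize(p10: np.ndarray) -> np.ndarray:
    """10-class genotype posterior -> 4-class allele posterior."""
    out = np.zeros(4)
    for prob, (i, j) in zip(p10, classes):
        if i == j:
            out[i] += prob
        else:
            out[i] += 0.5 * prob
            out[j] += 0.5 * prob
    return out / out.sum()

# HW-expand an allele posterior, then marginalize: recovers it exactly.
p4 = np.array([0.7, 0.2, 0.05, 0.05])
outer = np.outer(p4, p4)
p10 = np.array([outer[i, j] * (1 if i == j else 2) for (i, j) in classes])
assert np.allclose(marginalize(p10), p4)
print(np.round(marginalize(p10), 3))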
    def _parse_arguments(
        self,
    ) -> Tuple[
        Dict[str, List[Union[str, int]]], tt.tree, pd.DataFrame, Optional[List[float]]
    ]:
        if (
            not hasattr(self.genotype_data, "snpsdict")
            or self.genotype_data.snpsdict is None
        ):
            raise TypeError("`GenotypeData.snpsdict` must be defined.")
        if not hasattr(self.tree_parser, "tree") or self.tree_parser.tree is None:
            raise TypeError("`TreeParser.tree` must be defined.")
        if not hasattr(self.tree_parser, "qmat") or self.tree_parser.qmat is None:
            raise TypeError("`TreeParser.qmat` must be defined.")
        site_rates = getattr(self.tree_parser, "site_rates", None)
        return (
            self.genotype_data.snpsdict,
            self.tree_parser.tree,
            self.tree_parser.qmat,
            site_rates,
        )

    def _get_iupac_full(self, char: str) -> List[str]:
        iupac_map = {
            "A": ["A"],
            "G": ["G"],
            "C": ["C"],
            "T": ["T"],
            "N": ["A", "C", "T", "G"],
            "-": ["A", "C", "T", "G"],
            "R": ["A", "G"],
            "Y": ["C", "T"],
            "S": ["G", "C"],
            "W": ["A", "T"],
            "K": ["G", "T"],
            "M": ["A", "C"],
            "B": ["C", "G", "T"],
            "D": ["A", "G", "T"],
            "H": ["A", "C", "T"],
            "V": ["A", "C", "G"],
        }
        return iupac_map.get(char.upper(), ["A", "C", "T", "G"])

    def _allele_posterior_to_genotype_probs(self, p: np.ndarray) -> Dict[str, float]:
        p = np.maximum(p, 0.0)
        s = p.sum()
        p = (p / s) if s > 0 else np.ones_like(p) / len(p)
        alleles = ["A", "C", "G", "T"]
        probs = {}
        for i, a in enumerate(alleles):
            for j, b in enumerate(alleles[i:], start=i):
                probs[a + b] = float(p[i] * p[j] * (2.0 if i != j else 1.0))
        z = sum(probs.values()) or 1.0
        return {k: v / z for k, v in probs.items()}

    def _genotype_to_iupac(self, gt: str) -> str:
        if gt[0] == gt[1]:
            return gt[0]
        pair = "".join(sorted(gt))
        het_map = {"AC": "M", "AG": "R", "AT": "W", "CG": "S", "CT": "Y", "GT": "K"}
        return het_map.get(pair, "N")

    def _allele_posterior_to_iupac_genotype(
        self, p: np.ndarray, mode: str = "MAP"
    ) -> str:
        gprobs = self._allele_posterior_to_genotype_probs(p)
        if mode.upper() == "SAMPLE":
            gts, vals = zip(*gprobs.items())
            choice = np.random.choice(len(gts), p=np.array(vals, dtype=float))
            return self._genotype_to_iupac(gts[choice])
        best_gt = max(gprobs, key=gprobs.get)
        return self._genotype_to_iupac(best_gt)
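For the haploid path, the final call is made on genotype probabilities derived from the allele posterior (p_a^2 for homozygotes, 2 p_a p_b for heterozygotes), with the MAP genotype rendered as an IUPAC code. A compact restatement of that pipeline (illustrative, mirroring _allele_posterior_to_genotype_probs and _genotype_to_iupac):

import numpy as np

het_map = {"AC": "M", "AG": "R", "AT": "W", "CG": "S", "CT": "Y", "GT": "K"}
alleles = "ACGT"

def call_genotype(p4: np.ndarray) -> str:
    """MAP IUPAC genotype call from a 4-state allele posterior."""
    probs = {}
    for i, a in enumerate(alleles):
        for j in range(i, 4):
            b = alleles[j]
            probs[a + b] = p4[i] * p4[j] * (2.0 if i != j else 1.0)
    gt = max(probs, key=probs.get)  # keys are already sorted pairs, e.g. "AC"
    return gt[0] if gt[0] == gt[1] else het_map[gt]

print(call_genotype(np.array([0.7, 0.2, 0.05, 0.05])))  # 'A' (AA=0.49 beats AC=0.28)
print(call_genotype(np.array([0.5, 0.5, 0.0, 0.0])))    # 'M' (AC=0.50 beats AA=0.25)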
    def _align_and_check_q(self, q_df: pd.DataFrame) -> pd.DataFrame:
        """Return Q reindexed to ['A','C','G','T'] and assert CTMC sanity.

        Args:
            q_df: Square rate matrix with nucleotide index/columns.

        Returns:
            Reindexed Q as a DataFrame in A,C,G,T order.

        Raises:
            ValueError: If Q is malformed or not alignable.
        """
        required = ["A", "C", "G", "T"]
        if not set(required).issubset(set(q_df.index)) or not set(required).issubset(
            set(q_df.columns)
        ):
            raise ValueError("Q index/columns must include A, C, G, T.")

        q_df = q_df.loc[required, required].astype(float)

        q = q_df.to_numpy()

        # Off-diagonals >= 0, diagonals <= 0, rows sum ~ 0
        if np.any(q[~np.eye(4, dtype=bool)] < -1e-12):
            raise ValueError("Q off-diagonal entries must be non-negative.")

        if np.any(np.diag(q) > 1e-12):
            raise ValueError("Q diagonal entries must be non-positive.")

        if not np.allclose(q.sum(axis=1), 0.0, atol=1e-8):
            self.logger.error(q.sum(axis=1))
            raise ValueError("Q rows must sum to 0.")

        return pd.DataFrame(q, index=required, columns=required)

    def _stationary_pi_from_q(self, Q: np.ndarray) -> np.ndarray:
        """Compute stationary distribution pi for generator Q.

        Uses eigenvector of Q^T at eigenvalue 0; clips tiny negatives; renormalizes.

        Args:
            Q: (4,4) CTMC generator.

        Returns:
            (4,) stationary distribution pi with sum=1.
        """
        w, v = scipy.linalg.eig(Q.T)
        k = int(np.argmin(np.abs(w)))
        pi = np.real(v[:, k])
        pi = np.maximum(pi, 0.0)
        s = float(pi.sum())
        if not np.isfinite(s) or s <= 0:
            # Fallback uniform if eigen failed
            return np.ones(Q.shape[0]) / Q.shape[0]
        return pi / s

    def _repair_and_scale_q(
        self,
        q_df: pd.DataFrame,
        target_mean_rate: float = 1.0,
        state_order: tuple[str, ...] = ("A", "C", "G", "T"),
        neg_offdiag_tol: float = 1e-8,
    ) -> pd.DataFrame:
        """Repair Q to a valid CTMC generator and scale its mean rate.

        Steps:
            1) Reindex to `state_order` and cast to float.
            2) Set negative off-diagonals to 0 (warn if any < -neg_offdiag_tol).
            3) Set diagonal q_ii = -sum_{j!=i} q_ij so rows sum to 0 exactly.
            4) If a row has zero off-diagonal sum, inject a tiny uniform exit rate.
            5) Compute stationary pi and scale Q so that -sum_i pi_i q_ii = target_mean_rate.

        Args:
            q_df: DataFrame with index=columns=states.
            target_mean_rate: Desired average rate under pi (commonly 1.0).
            state_order: Desired nucleotide order.
            neg_offdiag_tol: Tolerance to log warnings for negative off-diags.

        Returns:
            Repaired, scaled Q as a DataFrame in `state_order`.
        """
        # 1) Align
        missing = (set(state_order) - set(q_df.index)) | (
            set(state_order) - set(q_df.columns)
        )
        if missing:
            raise ValueError(
                f"Q must have states {state_order}, missing: {sorted(missing)}"
            )
        Q = q_df.loc[state_order, state_order].to_numpy(dtype=float, copy=True)

        # 2) Clip negative off-diagonals
        off_mask = ~np.eye(Q.shape[0], dtype=bool)
        neg_off = Q[off_mask] < 0
        if np.any(Q[off_mask] < -neg_offdiag_tol):
            self.logger.warning(
                f"Q has negative off-diagonals; clipping {int(np.sum(neg_off))} entries."
            )
        Q[off_mask] = np.maximum(Q[off_mask], 0.0)

        # 3) Set diagonal so rows sum to 0 exactly
        row_off_sum = Q.sum(axis=1) - np.diag(Q)  # off-diagonal row sums
        np.fill_diagonal(Q, -row_off_sum)

        # 4) Ensure no absorbing rows
        zero_rows = row_off_sum <= 0
        if np.any(zero_rows):
            eps = 1e-8
            Q[zero_rows, :] = 0.0
            for i in np.where(zero_rows)[0]:
                Q[i, :] = eps / (Q.shape[0] - 1)
                Q[i, i] = -eps
            self.logger.warning(
                f"Injected tiny exit rates in {int(np.sum(zero_rows))} rows."
            )

        # Sanity: rows now sum to 0 exactly
        if not np.allclose(Q.sum(axis=1), 0.0, atol=1e-12):
            raise RuntimeError("Internal error: row sums not zero after repair.")

        # 5) Scale to target mean rate under stationary pi
        pi = self._stationary_pi_from_q(Q)
        mean_rate = float(-np.dot(pi, np.diag(Q)))
        if not (np.isfinite(mean_rate) and mean_rate > 0):
            self.logger.warning("Mean rate non-positive; skipping scaling.")
            scale = 1.0
        else:
            scale = mean_rate / float(target_mean_rate)
        Q /= scale

        return pd.DataFrame(Q, index=state_order, columns=state_order)
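_repair_and_scale_q normalizes an estimated rate matrix so that one unit of branch length corresponds to one expected substitution under the stationary distribution. The same numeric steps on a small hand-made Q (illustrative, not package code):

import numpy as np

Q = np.array([
    [-0.9, 0.3, 0.3, 0.3],
    [0.2, -0.6, 0.2, 0.2],
    [0.1, 0.1, -0.3, 0.1],
    [0.4, 0.4, 0.4, -1.2],
])
Q[0, 1] = -1e-10  # a tiny negative off-diagonal, as from a noisy estimate

off = ~np.eye(4, dtype=bool)
Q[off] = np.maximum(Q[off], 0.0)      # step 2: clip negative off-diagonals
np.fill_diagonal(Q, 0.0)
np.fill_diagonal(Q, -Q.sum(axis=1))   # step 3: rows sum to exactly 0

# Step 5: stationary pi from the (uniform-sign) null vector of Q^T.
w, v = np.linalg.eig(Q.T)
pi = np.abs(np.real(v[:, np.argmin(np.abs(w))]))
pi /= pi.sum()

mean_rate = float(-pi @ np.diag(Q))   # expected substitutions per unit time
Q /= mean_rate                        # scale so the mean rate becomes 1.0
assert np.isclose(-pi @ np.diag(Q), 1.0)
print(np.round(Q, 3))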
    def _is_missing(self, x: object) -> bool:
        """Return True if x represents a missing genotype token."""
        s = str(x).strip()
        return s in self._MISSING_TOKENS

    def _is_known(self, x: object) -> bool:
        return not self._is_missing(x)

    @staticmethod
    def _norm(v: np.ndarray) -> np.ndarray:
        s = float(np.sum(v))
        if s <= 0 or not np.isfinite(s):
            return np.ones_like(v) / v.size
        return v / s