pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/deterministic/imputers/phylo.py (new file)
@@ -0,0 +1,973 @@
+ # type: ignore
+
+ import copy
+ import random
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+ # Third-party imports
+ import numpy as np
+ import pandas as pd
+ import scipy.linalg
+ import toytree as tt
+
+ from pgsui.utils.plotting import Plotting
+ from pgsui.utils.scorers import Scorer
+
+ if TYPE_CHECKING:
+     from snpio.analysis.tree_parser import TreeParser
+     from snpio.read_input.genotype_data import GenotypeData
+
+
+ class _DiploidAggregator:
+     """Precompute mappings to lump ordered diploid states (16) into unordered genotypes (10).
+
+     Notes:
+         Genotype class order matches the SNPio allele → genotype encodings:
+         - 0: AA, 1: AC, 2: AG, 3: AT, 4: CC, 5: CG, 6: CT, 7: GG, 8: GT, 9: TT
+         - Allele order is A(0), C(1), G(2), T(3).
+     """
+
+     def __init__(self) -> None:
+         # Ordered pairs (i,j) lexicographic over {0,1,2,3}; 16 states
+         self.ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
+         # Unordered genotype classes as (min,max) in the specified class order
+         self.genotype_classes = (
+             [(0, 0)]
+             + [(0, 1), (0, 2), (0, 3)]
+             + [(1, 1), (1, 2), (1, 3)]
+             + [(2, 2), (2, 3)]
+             + [(3, 3)]
+         )  # 10 states
+
+         # Map: class index -> list of ordered-state indices
+         self.class_to_ordered: list[list[int]] = []
+         for i, j in self.genotype_classes:
+             members = []
+             for k, (a, b) in enumerate(self.ordered_pairs):
+                 if i == j:  # homozygote
+                     if a == i and b == j:
+                         members.append(k)  # exactly one member
+                 else:  # heterozygote, both permutations
+                     if (a == i and b == j) or (a == j and b == i):
+                         members.append(k)
+             self.class_to_ordered.append(members)
+
+         # Build R (16x10) and C (10x16)
+         self.R = np.zeros((16, 10), dtype=float)
+         self.C = np.zeros((10, 16), dtype=float)
+         for c, members in enumerate(self.class_to_ordered):
+             m = float(len(members))
+             for o in members:
+                 self.R[o, c] = 1.0 / m  # spread class prob equally to ordered members
+                 self.C[c, o] = 1.0  # sum ordered probs back to class
+
+         # For allele-marginalization from 10-state genotype posterior to 4 alleles
+         # p_allele[i] = P(ii) + 0.5 * sum_{j!=i} P(min(i,j), max(i,j))
+         self.het_classes_by_allele: dict[int, list[int]] = {0: [], 1: [], 2: [], 3: []}
+         self.hom_class_index = {0: 0, 1: 4, 2: 7, 3: 9}  # AA, CC, GG, TT indices
+         for idx, (i, j) in enumerate(self.genotype_classes):
+             if i != j:
+                 self.het_classes_by_allele[i].append(idx)
+                 self.het_classes_by_allele[j].append(idx)
+
+
+ class _QCache:
+     """Cache P(t) for haploid (4) and optionally diploid (10) via Kronecker + lumping."""
+
+     def __init__(
+         self,
+         q_df: pd.DataFrame,
+         mode: str = "haploid",
+         diploid_agg: "_DiploidAggregator | None" = None,
+     ) -> None:
+         """Precompute eigendecomposition of haploid Q.
+
+         Args:
+             q_df: 4x4 generator in A,C,G,T order.
+             mode: "haploid" or "diploid".
+             diploid_agg: required if mode="diploid".
+         """
+         m = np.asarray(q_df, dtype=float)
+         evals, V = scipy.linalg.eig(m)
+         Vinv = scipy.linalg.inv(V)
+         self.evals = evals
+         self.V = V
+         self.Vinv = Vinv
+         self.mode = mode
+         self.agg = diploid_agg
+         if self.mode == "diploid" and self.agg is None:
+             raise ValueError("Diploid mode requires a _DiploidAggregator.")
+         self._cache4: dict[float, np.ndarray] = {}
+         self._cache10: dict[float, np.ndarray] = {}
+
+     def _P4(self, s: float) -> np.ndarray:
+         """Return P(t) for haploid (4 states)."""
+         key = round(float(s), 12)
+         if key in self._cache4:
+             return self._cache4[key]
+         expo = np.exp(self.evals * key)
+         P4 = (self.V * expo) @ self.Vinv
+         P4 = np.real_if_close(P4, tol=1e5)
+         P4[P4 < 0.0] = 0.0
+         P4 /= P4.sum(axis=1, keepdims=True).clip(min=1.0)
+         self._cache4[key] = P4
+         return P4
+
+     def P(self, t: float, rate: float = 1.0) -> np.ndarray:
+         """Return P(t) in the active mode."""
+         s = float(rate) * float(t)
+         if self.mode == "haploid":
+             return self._P4(s)
+         # diploid
+         key = round(s, 12)
+         if key in self._cache10:
+             return self._cache10[key]
+         P4 = self._P4(s)
+         P16 = np.kron(P4, P4)  # independent alleles
+         P10 = self.agg.C @ P16 @ self.agg.R  # lump to unordered genotypes
+         P10 = np.maximum(P10, 0.0)
+         P10 /= P10.sum(axis=1, keepdims=True).clip(min=1.0)
+         self._cache10[key] = P10
+         return P10
+
+
+ class ImputePhylo:
+     """Imputes missing genotype data using a phylogenetic likelihood model.
+
+     This imputer uses a continuous-time Markov chain (CTMC) model of sequence evolution to impute missing genotype data based on a provided phylogenetic tree. It supports both haploid and diploid data, with options for evaluating imputation accuracy through simulated missingness.
+
+     Notes:
+         - Haploid CTMC (4 states: A,C,G,T).
+         - Diploid CTMC (10 unordered genotype states: AA,...,TT) [default], derived by independent allele evolution and lumping ordered pairs.
+     """
+
+     def __init__(
+         self,
+         genotype_data: "GenotypeData",
+         tree_parser: "TreeParser",
+         prefix: str,
+         min_branch_length: float = 1e-10,
+         *,
+         haploid: bool = False,
+         eval_missing_rate: float = 0.0,
+         column_subset: Optional[List[int]] = None,
+         save_plots: bool = False,
+         verbose: bool = False,
+         debug: bool = False,
+     ) -> None:
+         self.genotype_data = genotype_data
+         self.tree_parser = tree_parser
+         self.prefix = prefix
+         self.min_branch_length = min_branch_length
+         self.eval_missing_rate = eval_missing_rate
+         self.column_subset = column_subset
+         self.save_plots = save_plots
+         self.logger = genotype_data.logger
+         self.verbose = verbose
+         self.debug = debug
+         self.char_map = {"A": 0, "C": 1, "G": 2, "T": 3}
+         self.nuc_map = {v: k for k, v in self.char_map.items()}
+         self.imputer_data_: Optional[Tuple] = None
+         self.ground_truth_: Optional[Dict] = None
+         self.scorer = Scorer(self.prefix, average="macro", verbose=verbose, debug=debug)
+         self.plotter = Plotting(
+             "ImputePhylo",
+             prefix=self.prefix,
+             plot_format=self.genotype_data.plot_format,
+             plot_fontsize=self.genotype_data.plot_fontsize,
+             plot_dpi=self.genotype_data.plot_dpi,
+             title_fontsize=self.genotype_data.plot_fontsize,
+             verbose=verbose,
+             debug=debug,
+         )
+
+         self._MISSING_TOKENS = {"N", "n", "-9", ".", "?", "./.", ""}
+
+         self.haploid = haploid
+         self._dip_agg: _DiploidAggregator | None = None
+
+         self.imputed_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
+         self.imputed_genotype_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
+         self.evaluation_results_: dict | None = None
+
+     def fit(self) -> "ImputePhylo":
+         """Prepares the imputer by parsing and validating input data.
+
+         This method does the following:
+         - Validates the genotype data and phylogenetic tree.
+         - Extracts the genotype matrix, pruned tree, Q-matrix, and site rates.
+         - Sets up internal structures for imputation.
+         """
+         self.imputer_data_ = self._parse_arguments()
+         return self
+
+     def transform(self) -> pd.DataFrame:
+         """Transforms the data by imputing missing values.
+
+         This method does the following:
+         - Uses the fitted imputer to perform phylogenetic imputation.
+         - Returns the imputed genotype DataFrame.
+         """
+         if self.imputer_data_ is None:
+             msg = "The imputer has not been fitted. Call 'fit' first."
+             self.logger.error(msg)
+             raise RuntimeError(msg)
+
+         original_genotypes, tree, q_matrix_in, site_rates = self.imputer_data_
+         q_matrix = self._repair_and_scale_q(q_matrix_in, target_mean_rate=1.0)
+
+         if not self.haploid:
+             if self._dip_agg is None:
+                 self._dip_agg = _DiploidAggregator()
+             self._qcache = _QCache(q_matrix, mode="diploid", diploid_agg=self._dip_agg)
+         else:
+             self._qcache = _QCache(q_matrix, mode="haploid")
+
+         genotypes_to_impute = original_genotypes
+
+         if self.eval_missing_rate > 0:
+             genotypes_to_impute = self._simulate_missing_data(original_genotypes)
+
+         imputed_df = self.impute_phylo(genotypes_to_impute, tree, q_matrix, site_rates)
+
+         if self.ground_truth_:
+             self._evaluate_imputation(imputed_df)
+         return imputed_df  # keep as DataFrame with sample index
+
+     def fit_transform(self) -> pd.DataFrame:
+         """Fits the imputer and transforms the data in one step.
+
+         This method does the following:
+         - Calls the `fit` method to prepare the imputer.
+         - Calls the `transform` method to perform imputation.
+         - Returns the imputed genotype DataFrame.
+         """
+         self.fit()
+         return self.transform()
+
+     def _simulate_missing_data(
+         self, original_genotypes: Dict[str, List[str]]
+     ) -> Dict[str, List[str]]:
+         """Masks a fraction of known genotypes for evaluation."""
+         genotypes_to_impute = copy.deepcopy(original_genotypes)
+         known_positions = [
+             (sample, site_idx)
+             for sample, seq in genotypes_to_impute.items()
+             for site_idx, base in enumerate(seq)
+             if self._is_known(base)
+         ]
+
+         if not known_positions:
+             self.logger.warning("No known values to mask for evaluation.")
+             return genotypes_to_impute
+
+         num_to_mask = int(len(known_positions) * self.eval_missing_rate)
+         if num_to_mask == 0:
+             self.logger.warning(f"eval_missing_rate={self.eval_missing_rate} is too low to mask any values.")
+             return genotypes_to_impute
+
+         # Sample the (sample, site_idx) tuples to be masked
+         positions_to_mask = random.sample(known_positions, num_to_mask)
+
+         # Build the ground-truth dictionary
+         self.ground_truth_ = {}
+         for sample, site_idx in positions_to_mask:
+             # Store the original base before masking it
+             self.ground_truth_[(sample, site_idx)] = genotypes_to_impute[sample][
+                 site_idx
+             ]
+             # Now, apply the mask
+             genotypes_to_impute[sample][site_idx] = "N"
+
+         self.logger.info(f"Masked {len(self.ground_truth_)} values for evaluation.")
+         return genotypes_to_impute
+
+     def _evaluate_imputation(self, imputed_df: pd.DataFrame) -> None:
+         """Evaluate imputation with 4 classes (haploid) or 10/16 classes (diploid).
+
+         Behavior:
+             - If self.haploid is True:
+               Evaluate per-allele with 4 classes (A,C,G,T), keeping the existing per-allele logic.
+             - If self.haploid is False:
+               Evaluate per-genotype.
+                 * Default: 10 unordered classes (AA,AC,AG,AT,CC,CG,CT,GG,GT,TT).
+                 * If getattr(self, 'diploid_ordered', False) is True:
+                   Also compute 16-class ordered metrics (AA,AC,...,TT).
+                   Requires self._dip_agg (the _DiploidAggregator). If absent, it is created.
+
+         Notes:
+             - Truth handling: only unambiguous IUPAC codes that map to exactly one genotype
+               class are used for 10-class evaluation (e.g., A, C, G, T, M, R, W, S, Y, K).
+               Ambiguous codes (N, B, D, H, V, '-', etc.) are skipped for strict genotype
+               evaluation.
+             - Probabilities: prefer stored 10-class posteriors in
+               self.imputed_genotype_likelihoods_; if missing, fall back by converting the
+               4-class allele posterior to 10-class using HW independence.
+         """
+         if not self.ground_truth_:
+             return
+
+         # -------------------------
+         # Haploid path (4 classes)
+         # -------------------------
+         if getattr(self, "haploid", True):
+             # --- Gather ground truth bases and posterior vectors (only if both exist)
+             rows = []
+             for (sample, site_idx), true_base in self.ground_truth_.items():
+                 if (
+                     sample in imputed_df.index
+                     and (sample, site_idx) in self.imputed_likelihoods_
+                 ):
+                     rows.append(
+                         (true_base, self.imputed_likelihoods_[(sample, site_idx)])
+                     )
+             if not rows:
+                 self.logger.warning("No matching masked sites found to evaluate.")
+                 return
+
+             true_bases = np.array([b for (b, _) in rows], dtype=object)
+             proba_mat = np.vstack([p for (_, p) in rows])  # (n,4) in A,C,G,T
+             nuc_to_idx = self.char_map  # {'A':0,'C':1,'G':2,'T':3}
+
+             y_true_site = np.array(
+                 [nuc_to_idx.get(str(b).upper(), -1) for b in true_bases], dtype=int
+             )
+             valid_mask = (y_true_site >= 0) & np.isfinite(proba_mat).all(axis=1)
+             if not np.any(valid_mask):
+                 self.logger.error(
+                     "Evaluation arrays empty after filtering invalid entries."
+                 )
+                 return
+
+             y_true_site = y_true_site[valid_mask]
+             proba_mat = proba_mat[valid_mask]
+
+             # Interleave each site as two alleles (iid/HW assumption) for per-allele metrics
+             y_true_int = np.repeat(y_true_site, 2)
+             y_pred_proba = np.repeat(proba_mat, 2, axis=0)
+             y_pred_int = np.argmax(y_pred_proba, axis=1)
+
+             n_classes = 4
+             y_true_ohe = np.eye(n_classes, dtype=float)[y_true_int]
+             idx_to_nuc = {v: k for k, v in nuc_to_idx.items()}
+
+             self.logger.info("--- Per-Allele Imputation Performance (4-class) ---")
+             self.evaluation_results_ = self.scorer.evaluate(
+                 y_true=y_true_int,
+                 y_pred=y_pred_int,
+                 y_true_ohe=y_true_ohe,
+                 y_pred_proba=y_pred_proba,
+             )
+             self.logger.info(f"Evaluation results: {self.evaluation_results_}")
+
+             # Plots
+             self.plotter.plot_metrics(
+                 y_true=y_true_int,
+                 y_pred_proba=y_pred_proba,
+                 metrics=self.evaluation_results_,
+             )
+             self.plotter.plot_confusion_matrix(
+                 y_true_1d=y_true_int,
+                 y_pred_1d=y_pred_int,
+                 label_names=["A", "C", "G", "T"],
+             )
+             return
+
+         # -------------------------
+         # Diploid path (10/16 classes)
+         # -------------------------
+         # Ensure aggregator
+         if getattr(self, "_dip_agg", None) is None:
+             self._dip_agg = _DiploidAggregator()
+
+         # Label catalogs
+         geno_classes = self._dip_agg.genotype_classes  # list of (i,j) with i<=j
+         alleles = ["A", "C", "G", "T"]
+         labels10 = [f"{alleles[i]}{alleles[j]}" for (i, j) in geno_classes]  # 10 labels
+         ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
+         labels16 = [
+             f"{alleles[i]}{alleles[j]}" for (i, j) in ordered_pairs
+         ]  # 16 labels
+
+         # Helper: strict IUPAC→single 10-class index (skip ambiguous >2-allele codes)
+         iupac_to_10_single = {
+             "A": "AA",
+             "C": "CC",
+             "G": "GG",
+             "T": "TT",
+             "M": "AC",
+             "R": "AG",
+             "W": "AT",
+             "S": "CG",
+             "Y": "CT",
+             "K": "GT",
+         }
+         label_to_10_idx = {lab: idx for idx, lab in enumerate(labels10)}
+
+         rows_10: list[tuple[int, np.ndarray]] = []  # (true_idx10, p10)
+         rows_16_truth: list[int] = []  # expanded true labels (for 16)
+         rows_16_proba: list[np.ndarray] = []  # matching (p16) rows
+
+         # Collect rows
+         for (sample, site_idx), true_code in self.ground_truth_.items():
+             # Need probabilities for this masked (sample,site)
+             p10 = None
+             if (sample, site_idx) in self.imputed_genotype_likelihoods_:
+                 p10 = np.asarray(
+                     self.imputed_genotype_likelihoods_[(sample, site_idx)], float
+                 )
+             elif (sample, site_idx) in self.imputed_likelihoods_:
+                 # Fallback: build 10-class from 4-class posterior (HW independence)
+                 p4 = np.asarray(self.imputed_likelihoods_[(sample, site_idx)], float)
+                 p4 = np.clip(p4, 0.0, np.inf)
+                 p4 = p4 / (p4.sum() or 1.0)
+                 # Unordered genotype probs: AA=pA^2, AC=2pApC, ...
+                 p10 = np.zeros(10, dtype=float)
+                 # class indices in our order:
+                 # 0:AA, 1:AC, 2:AG, 3:AT, 4:CC, 5:CG, 6:CT, 7:GG, 8:GT, 9:TT
+                 pA, pC, pG, pT = p4
+                 p10[0] = pA * pA
+                 p10[1] = 2 * pA * pC
+                 p10[2] = 2 * pA * pG
+                 p10[3] = 2 * pA * pT
+                 p10[4] = pC * pC
+                 p10[5] = 2 * pC * pG
+                 p10[6] = 2 * pC * pT
+                 p10[7] = pG * pG
+                 p10[8] = 2 * pG * pT
+                 p10[9] = pT * pT
+                 p10 = p10 / (p10.sum() or 1.0)
+             else:
+                 continue  # no probabilities available
+
+             # Truth mapping to a single unordered class (skip ambiguous >2-allele codes)
+             c = str(true_code).upper()
+             lab10 = iupac_to_10_single.get(c, None)
+             if lab10 is None:
+                 # Skip N, B, D, H, V, '-', etc., because they expand to multiple classes
+                 continue
+             true_idx10 = label_to_10_idx[lab10]
+             rows_10.append((true_idx10, p10))
+
+             # Optional ordered-16 expansion (for confusion/metrics if requested)
+             if getattr(self, "diploid_ordered", False):
+                 # Spread 10->16 via R (16x10)
+                 p16 = self._dip_agg.R @ p10
+                 p16 = p16 / (p16.sum() or 1.0)
+                 i, j = geno_classes[true_idx10]
+                 if i == j:
+                     # Homozygote: single ordered index
+                     true16 = ordered_pairs.index((i, j))
+                     rows_16_truth.append(true16)
+                     rows_16_proba.append(p16)
+                 else:
+                     # Heterozygote: duplicate as two ordered permutations
+                     true16a = ordered_pairs.index((i, j))
+                     true16b = ordered_pairs.index((j, i))
+                     rows_16_truth.append(true16a)
+                     rows_16_proba.append(p16)
+                     rows_16_truth.append(true16b)
+                     rows_16_proba.append(p16)
+
+         if not rows_10:
+             self.logger.warning("No valid diploid truth rows for evaluation.")
+             return
+
+         # ---- 10-class metrics ----
+         y_true_10 = np.array([t for (t, _) in rows_10], dtype=int)
+         y_pred_proba_10 = np.vstack([p for (_, p) in rows_10])  # (n,10)
+
+         y_pred_10 = np.argmax(y_pred_proba_10, axis=1)
+         y_true_ohe_10 = np.eye(10, dtype=float)[y_true_10]
+
+         self.logger.info("--- Per-Genotype Imputation Performance (10-class) ---")
+         self.evaluation_results_ = self.scorer.evaluate(
+             y_true=y_true_10,
+             y_pred=y_pred_10,
+             y_true_ohe=y_true_ohe_10,
+             y_pred_proba=y_pred_proba_10,
+         )
+         self.logger.info(f"Evaluation results (10-class): {self.evaluation_results_}")
+
+         # Plots (10-class)
+         self.plotter.plot_metrics(
+             y_true=y_true_10,
+             y_pred_proba=y_pred_proba_10,
+             metrics=self.evaluation_results_,
+             label_names=labels10,
+         )
+         self.plotter.plot_confusion_matrix(
+             y_true_1d=y_true_10, y_pred_1d=y_pred_10, label_names=labels10
+         )
+
+         # ---- Optional 16-class ordered metrics ----
+         if getattr(self, "diploid_ordered", False) and rows_16_truth:
+             y_true_16 = np.array(rows_16_truth, dtype=int)
+             y_pred_proba_16 = np.vstack(rows_16_proba)
+             y_pred_16 = np.argmax(y_pred_proba_16, axis=1)
+             y_true_ohe_16 = np.eye(16, dtype=float)[y_true_16]
+
+             self.logger.info(
+                 "--- Per-Genotype Imputation Performance (16-class ordered) ---"
+             )
+             eval16 = self.scorer.evaluate(
+                 y_true=y_true_16,
+                 y_pred=y_pred_16,
+                 y_true_ohe=y_true_ohe_16,
+                 y_pred_proba=y_pred_proba_16,
+             )
+             self.logger.info(f"Evaluation results (16-class): {eval16}")
+
+             self.plotter.plot_metrics(
+                 y_true=y_true_16,
+                 y_pred_proba=y_pred_proba_16,
+                 metrics=eval16,
+                 label_names=labels16,
+             )
+             self.plotter.plot_confusion_matrix(
+                 y_true_1d=y_true_16, y_pred_1d=y_pred_16, label_names=labels16
+             )
+
+     def _infer_proba_permutation(
+         self, y_ref: np.ndarray, P: np.ndarray, n_classes: int = 4
+     ) -> tuple[np.ndarray, list[int]]:
+         """Infers a permutation to align probability columns to the label space of y_ref."""
+         perm = [-1] * n_classes
+         taken = set()
+         for k in range(n_classes):
+             mask = y_ref == k
+             if not np.any(mask):
+                 means = P.mean(axis=0)
+             else:
+                 means = P[mask].mean(axis=0)
+
+             # Find best unused column
+             for col in np.argsort(means)[::-1]:
+                 if col not in taken:
+                     perm[k] = col
+                     taken.add(col)
+                     break
+
+         if len(taken) != n_classes:  # Fallback if permutation is incomplete
+             unassigned_cols = [c for c in range(n_classes) if c not in taken]
+             for i in range(n_classes):
+                 if perm[i] == -1:
+                     perm[i] = unassigned_cols.pop(0)
+
+         self.logger.info(f"Inferred probability permutation (label->col): {perm}")
+         return P[:, perm], perm
+
+     def _stationary_pi(self, Q: np.ndarray) -> np.ndarray:
+         """Robustly calculates the stationary distribution of Q."""
+         w, v = scipy.linalg.eig(Q.T)
+         k = int(np.argmin(np.abs(w)))
+         pi = np.real(v[:, k])
+         pi = np.maximum(pi, 0.0)
+         s = pi.sum()
+         return (pi / s) if s > 0 else np.ones(Q.shape[0]) / Q.shape[0]
+
+     def impute_phylo(
+         self,
+         genotypes: Dict[str, List[Union[str, int]]],
+         tree: tt.tree,
+         q_matrix: pd.DataFrame,
+         site_rates: Optional[List[float]],
+     ) -> pd.DataFrame:
+         """Imputes missing values using a phylogenetic guide tree."""
+         self.imputed_likelihoods_.clear()
+         self.imputed_genotype_likelihoods_.clear()
+
+         common_samples = set(tree.get_tip_labels()) & set(genotypes.keys())
+         if not common_samples:
+             raise ValueError("No samples in common between tree and genotypes.")
+         filt_genotypes = copy.deepcopy({s: genotypes[s] for s in common_samples})
+         num_snps = len(next(iter(filt_genotypes.values())))
+
+         if site_rates is not None and len(site_rates) != num_snps:
+             raise ValueError(
+                 f"len(site_rates)={len(site_rates)} != num_snps={num_snps}"
+             )
+
+         # Stationary prior at root
+         if not self.haploid:
+             # pi4 from Q; then lump pi4⊗pi4 to 10 classes
+             pi4 = self._stationary_pi(q_matrix.to_numpy())
+             pi_ord = np.kron(pi4, pi4)  # 16
+             pi10 = self._dip_agg.C @ pi_ord
+             pi10 = pi10 / (pi10.sum() or 1.0)
+             root_prior = pi10
+             n_states = 10
+         else:
+             root_prior = self._stationary_pi(q_matrix.to_numpy())  # 4
+             n_states = 4
+
+         for snp_index in range(num_snps):
+             rate = site_rates[snp_index] if site_rates is not None else 1.0
+             tips_with_missing = [
+                 s
+                 for s, seq in filt_genotypes.items()
+                 if self._is_missing(seq[snp_index])
+             ]
+             if not tips_with_missing:
+                 continue
+
+             down_liks: Dict[int, np.ndarray] = {}
+             for node in tree.treenode.traverse("postorder"):
+                 lik = np.zeros(n_states, dtype=float)
+                 if node.is_leaf():
+                     if node.name not in filt_genotypes or self._is_missing(
+                         filt_genotypes[node.name][snp_index]
+                     ):
+                         lik[:] = 1.0  # missing: uniform emission
+                     else:
+                         obs = filt_genotypes[node.name][snp_index]
+                         if not self.haploid:
+                             cls = self._iupac_to_genotype_classes(obs)
+                             lik[cls] = 1.0
+                         else:
+                             for state in self._get_iupac_full(obs):
+                                 lik[self.char_map[state]] = 1.0
+                     down_liks[node.idx] = lik / lik.sum()
+                 else:
+                     msg = np.ones(n_states, dtype=float)
+                     for child in node.children:
+                         P = self._qcache.P(
+                             max(child.dist, self.min_branch_length), rate
+                         )
+                         msg *= P @ down_liks[child.idx]
+                     down_liks[node.idx] = self._norm(msg)
+
+             up_liks: Dict[int, np.ndarray] = {tree.treenode.idx: root_prior.copy()}
+             for node in tree.treenode.traverse("preorder"):
+                 if node.is_root():
+                     continue
+                 parent = node.up
+                 sib_prod = np.ones(n_states, dtype=float)
+                 for sib in parent.children:
+                     if sib.idx == node.idx:
+                         continue
+                     P_sib = self._qcache.P(max(sib.dist, self.min_branch_length), rate)
+                     sib_prod *= P_sib @ down_liks[sib.idx]
+                 parent_msg = up_liks[parent.idx] * sib_prod
+                 P = self._qcache.P(max(node.dist, self.min_branch_length), rate)
+                 up = parent_msg @ P
+                 up_liks[node.idx] = up / (up.sum() or 1.0)
+
+             for samp in tips_with_missing:
+                 node = tree.get_nodes(samp)[0]
+                 leaf_emission = down_liks[node.idx]  # uniform if missing
+                 tip_post = self._norm(up_liks[node.idx] * leaf_emission)
+
+                 if self.ground_truth_ and (samp, snp_index) in self.ground_truth_:
+                     if not self.haploid:
+                         # store genotype posterior (10) and allele-marginal (4)
+                         self.imputed_genotype_likelihoods_[(samp, snp_index)] = (
+                             tip_post.copy()
+                         )
+                         self.imputed_likelihoods_[(samp, snp_index)] = (
+                             self._marginalize_genotype_to_allele(tip_post)
+                         )
+                     else:
+                         self.imputed_likelihoods_[(samp, snp_index)] = tip_post.copy()
+
+                 if not self.haploid:
+                     call = self._genotype_posterior_to_iupac(tip_post, mode="MAP")
+                 else:
+                     call = self._allele_posterior_to_iupac_genotype(
+                         tip_post, mode="MAP"
+                     )
+                 filt_genotypes[samp][snp_index] = call
+
+         df = pd.DataFrame.from_dict(filt_genotypes, orient="index")
+
+         if df.applymap(self._is_missing).any().any():
+             raise AssertionError("Imputation failed. Missing values remain.")
+
+         return df
+
+     def _iupac_to_genotype_classes(self, char: str) -> list[int]:
+         """Map IUPAC code to allowed unordered genotype class indices (10-state)."""
+         c = str(char).upper()
+         # Allele indices: A=0, C=1, G=2, T=3
+         single = {"A": 0, "C": 1, "G": 2, "T": 3}
+         het_map = {  # two-allele ambiguity
+             "M": (0, 1),
+             "R": (0, 2),
+             "W": (0, 3),
+             "S": (2, 1),
+             "Y": (1, 3),
+             "K": (2, 3),
+         }
+         if c in single:
+             # homozygote only
+             return [[0, 4, 7, 9][single[c]]]  # AA,CC,GG,TT class indices
+         if c in het_map:
+             i, j = het_map[c]
+             # normalize to i<=j, then find the index in our class order
+             i, j = (i, j) if i <= j else (j, i)
+             class_order = self._dip_agg.genotype_classes
+             return [class_order.index((i, j))]
+         # Ambiguity codes with >2 alleles or missing: allow any compatible genotype
+         amb = {
+             "N": {0, 1, 2, 3},
+             "-": {0, 1, 2, 3},
+             "B": {1, 2, 3},
+             "D": {0, 2, 3},
+             "H": {0, 1, 3},
+             "V": {0, 1, 2},
+         }
+         if c in amb:
+             allowed = amb[c]
+             classes = []
+             for idx, (i, j) in enumerate(self._dip_agg.genotype_classes):
+                 if i in allowed and j in allowed:
+                     classes.append(idx)
+             return classes
+         # Fallback: allow all
+         return list(range(10))
+
+     def _genotype_posterior_to_iupac(self, p10: np.ndarray, mode: str = "MAP") -> str:
+         """Convert genotype posterior over 10 classes to an IUPAC code."""
+         p = np.asarray(p10, dtype=float)
+         p = p / (p.sum() or 1.0)
+         if mode.upper() == "SAMPLE":
+             k = int(np.random.choice(len(p), p=p))
+         else:
+             k = int(np.argmax(p))
+         i, j = self._dip_agg.genotype_classes[k]
+         alleles = ["A", "C", "G", "T"]
+         gt = alleles[i] + alleles[j]
+         return self._genotype_to_iupac(gt)
+
+     def _marginalize_genotype_to_allele(self, p10: np.ndarray) -> np.ndarray:
+         """Allele posterior from genotype posterior for eval/diagnostics (length 4)."""
+         p10 = np.asarray(p10, dtype=float)
+         p10 = p10 / (p10.sum() or 1.0)
+         agg = self._dip_agg
+         out = np.zeros(4, dtype=float)
+         for a in range(4):
+             out[a] = p10[agg.hom_class_index[a]] + 0.5 * np.sum(
+                 p10[agg.het_classes_by_allele[a]]
+             )
+         s = out.sum()
+         return out / (s or 1.0)
+
+     def _parse_arguments(
+         self,
+     ) -> Tuple[
+         Dict[str, List[Union[str, int]]], tt.tree, pd.DataFrame, Optional[List[float]]
+     ]:
+         if (
+             not hasattr(self.genotype_data, "snpsdict")
+             or self.genotype_data.snpsdict is None
+         ):
+             raise TypeError("`GenotypeData.snpsdict` must be defined.")
+         if not hasattr(self.tree_parser, "tree") or self.tree_parser.tree is None:
+             raise TypeError("`TreeParser.tree` must be defined.")
+         if not hasattr(self.tree_parser, "qmat") or self.tree_parser.qmat is None:
+             raise TypeError("`TreeParser.qmat` must be defined.")
+         site_rates = getattr(self.tree_parser, "site_rates", None)
+         return (
+             self.genotype_data.snpsdict,
+             self.tree_parser.tree,
+             self.tree_parser.qmat,
+             site_rates,
+         )
+
+     def _get_iupac_full(self, char: str) -> List[str]:
+         iupac_map = {
+             "A": ["A"],
+             "G": ["G"],
+             "C": ["C"],
+             "T": ["T"],
+             "N": ["A", "C", "T", "G"],
+             "-": ["A", "C", "T", "G"],
+             "R": ["A", "G"],
+             "Y": ["C", "T"],
+             "S": ["G", "C"],
+             "W": ["A", "T"],
+             "K": ["G", "T"],
+             "M": ["A", "C"],
+             "B": ["C", "G", "T"],
+             "D": ["A", "G", "T"],
+             "H": ["A", "C", "T"],
+             "V": ["A", "C", "G"],
+         }
+         return iupac_map.get(char.upper(), ["A", "C", "T", "G"])
+
+     def _allele_posterior_to_genotype_probs(self, p: np.ndarray) -> Dict[str, float]:
+         p = np.maximum(p, 0.0)
+         s = p.sum()
+         p = (p / s) if s > 0 else np.ones_like(p) / len(p)
+         alleles = ["A", "C", "G", "T"]
+         probs = {}
+         for i, a in enumerate(alleles):
+             for j, b in enumerate(alleles[i:], start=i):
+                 probs[a + b] = float(p[i] * p[j] * (2.0 if i != j else 1.0))
+         z = sum(probs.values()) or 1.0
+         return {k: v / z for k, v in probs.items()}
+
+     def _genotype_to_iupac(self, gt: str) -> str:
+         if gt[0] == gt[1]:
+             return gt[0]
+         pair = "".join(sorted(gt))
+         het_map = {"AC": "M", "AG": "R", "AT": "W", "CG": "S", "CT": "Y", "GT": "K"}
+         return het_map.get(pair, "N")
+
+     def _allele_posterior_to_iupac_genotype(
+         self, p: np.ndarray, mode: str = "MAP"
+     ) -> str:
+         gprobs = self._allele_posterior_to_genotype_probs(p)
+         if mode.upper() == "SAMPLE":
+             gts, vals = zip(*gprobs.items())
+             choice = np.random.choice(len(gts), p=np.array(vals, dtype=float))
+             return self._genotype_to_iupac(gts[choice])
+         best_gt = max(gprobs, key=gprobs.get)
+         return self._genotype_to_iupac(best_gt)
+
+     def _align_and_check_q(self, q_df: pd.DataFrame) -> pd.DataFrame:
+         """Return Q reindexed to ['A','C','G','T'] and assert CTMC sanity.
+
+         Args:
+             q_df: Square rate matrix with nucleotide index/columns.
+
+         Returns:
+             Reindexed Q as a DataFrame in A,C,G,T order.
+
+         Raises:
+             ValueError: If Q is malformed or not alignable.
+         """
+         required = ["A", "C", "G", "T"]
+         if not set(required).issubset(set(q_df.index)) or not set(required).issubset(
+             set(q_df.columns)
+         ):
+             raise ValueError("Q index and columns must include A, C, G, T.")
+
+         q_df = q_df.loc[required, required].astype(float)
+
+         q = q_df.to_numpy()
+
+         # Off-diagonals >= 0, diagonals <= 0, rows sum ~ 0
+         if np.any(q[~np.eye(4, dtype=bool)] < -1e-12):
+             raise ValueError("Q off-diagonal entries must be non-negative.")
+
+         if np.any(np.diag(q) > 1e-12):
+             raise ValueError("Q diagonal entries must be non-positive.")
+
+         if not np.allclose(q.sum(axis=1), 0.0, atol=1e-8):
+             self.logger.error(q.sum(axis=1))
+             raise ValueError("Q rows must sum to 0.")
+
+         return pd.DataFrame(q, index=required, columns=required)
+
+     def _stationary_pi_from_q(self, Q: np.ndarray) -> np.ndarray:
+         """Compute stationary distribution pi for generator Q.
+
+         Uses eigenvector of Q^T at eigenvalue 0; clips tiny negatives; renormalizes.
+
+         Args:
+             Q: (4,4) CTMC generator.
+
+         Returns:
+             (4,) stationary distribution pi with sum=1.
+         """
+         w, v = scipy.linalg.eig(Q.T)
+         k = int(np.argmin(np.abs(w)))
+         pi = np.real(v[:, k])
+         pi = np.maximum(pi, 0.0)
+         s = float(pi.sum())
+         if not np.isfinite(s) or s <= 0:
+             # Fallback uniform if eigen failed
+             return np.ones(Q.shape[0]) / Q.shape[0]
+         return pi / s
+
+     def _repair_and_scale_q(
+         self,
+         q_df: pd.DataFrame,
+         target_mean_rate: float = 1.0,
+         state_order: tuple[str, ...] = ("A", "C", "G", "T"),
+         neg_offdiag_tol: float = 1e-8,
+     ) -> pd.DataFrame:
+         """Repair Q to a valid CTMC generator and scale its mean rate.
+
+         Steps:
+             1) Reindex to `state_order` and cast to float.
+             2) Set negative off-diagonals to 0 (warn if any < -neg_offdiag_tol).
+             3) Set diagonal q_ii = -sum_{j!=i} q_ij so rows sum to 0 exactly.
+             4) If a row has zero off-diagonal sum, inject a tiny uniform exit rate.
+             5) Compute stationary pi and scale Q so that -sum_i pi_i q_ii = target_mean_rate.
+
+         Args:
+             q_df: DataFrame with index=columns=states.
+             target_mean_rate: Desired average rate under pi (commonly 1.0).
+             state_order: Desired nucleotide order.
+             neg_offdiag_tol: Tolerance to log warnings for negative off-diags.
+
+         Returns:
+             Repaired, scaled Q as a DataFrame in `state_order`.
+         """
+         # 1) Align
+         missing = set(state_order) - set(q_df.index) | set(state_order) - set(
+             q_df.columns
+         )
+         if missing:
+             raise ValueError(
+                 f"Q must have states {state_order}, missing: {sorted(missing)}"
+             )
+         Q = q_df.loc[list(state_order), list(state_order)].to_numpy(dtype=float, copy=True)
+
+         # 2) Clip negative off-diagonals
+         off_mask = ~np.eye(Q.shape[0], dtype=bool)
+         neg_off = Q[off_mask] < 0
+         if np.any(Q[off_mask] < -neg_offdiag_tol):
+             self.logger.warning(
+                 f"Q has negative off-diagonals; clipping {int(np.sum(neg_off))} entries."
+             )
+         Q[off_mask] = np.maximum(Q[off_mask], 0.0)
+
+         # 3) Set diagonal so rows sum to 0 exactly
+         row_off_sum = Q.sum(axis=1) - np.diag(Q)  # off-diagonal row sums
+         np.fill_diagonal(Q, -row_off_sum)
+
+         # 4) Ensure no absorbing rows
+         zero_rows = row_off_sum <= 0
+         if np.any(zero_rows):
+             eps = 1e-8
+             Q[zero_rows, :] = 0.0
+             for i in np.where(zero_rows)[0]:
+                 Q[i, :] = eps / (Q.shape[0] - 1)
+                 Q[i, i] = -eps
+             self.logger.warning(
+                 f"Injected tiny exit rates in {int(np.sum(zero_rows))} rows."
+             )
+
+         # Sanity: rows now sum to 0 exactly
+         if not np.allclose(Q.sum(axis=1), 0.0, atol=1e-12):
+             raise RuntimeError("Internal error: row sums not zero after repair.")
+
+         # 5) Scale to target mean rate under stationary pi
+         pi = self._stationary_pi_from_q(Q)
+         mean_rate = float(-np.dot(pi, np.diag(Q)))
+         if not (np.isfinite(mean_rate) and mean_rate > 0):
+             self.logger.warning("Mean rate non-positive; skipping scaling.")
+             scale = 1.0
+         else:
+             scale = mean_rate / float(target_mean_rate)
+         Q /= scale
+
+         return pd.DataFrame(Q, index=state_order, columns=state_order)
+
+     def _is_missing(self, x: object) -> bool:
+         """Return True if x represents a missing genotype token."""
+         s = str(x).strip()
+         return s in self._MISSING_TOKENS
+
+     def _is_known(self, x: object) -> bool:
+         return not self._is_missing(x)
+
+     @staticmethod
+     def _norm(v: np.ndarray) -> np.ndarray:
+         s = float(np.sum(v))
+         if s <= 0 or not np.isfinite(s):
+             return np.ones_like(v) / v.size
+         return v / s
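
The diff above adds the complete ImputePhylo module. A minimal usage sketch follows; the constructor arguments and the fit/transform flow come from the code itself, but `genotype_data` and `tree_parser` are assumed to be snpio GenotypeData and TreeParser objects built elsewhere (their loaders are outside this diff):

    # Minimal sketch, assuming `genotype_data` (snpio GenotypeData) and
    # `tree_parser` (snpio TreeParser) were already constructed; how they
    # are loaded is not shown in this diff.
    from pgsui.impute.deterministic.imputers.phylo import ImputePhylo

    imputer = ImputePhylo(
        genotype_data=genotype_data,
        tree_parser=tree_parser,
        prefix="phylo_run",
        haploid=False,            # default: 10-state unordered diploid CTMC
        eval_missing_rate=0.1,    # mask 10% of known calls to score accuracy
    )
    imputed_df = imputer.fit_transform()  # pandas DataFrame, samples as index
    print(imputer.evaluation_results_)    # populated when eval_missing_rate > 0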
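The 16 → 10 lumping behind _DiploidAggregator can be checked in isolation. The sketch below rebuilds the R (spread) and C (sum) matrices from the same class order and verifies two properties the imputer relies on: lump-then-spread is the identity on genotype classes, and the root prior pi10 = C (pi4 ⊗ pi4) used in impute_phylo is a proper distribution. This is a standalone reconstruction, not code from the package:

    # Standalone check of the ordered-16 -> unordered-10 lumping.
    import numpy as np

    ordered_pairs = [(i, j) for i in range(4) for j in range(4)]      # 16 states
    classes = [(i, j) for i in range(4) for j in range(i, 4)]         # 10 states
    R = np.zeros((16, 10))  # spread class prob equally over ordered members
    C = np.zeros((10, 16))  # sum ordered probs back into their class
    for c, (i, j) in enumerate(classes):
        members = [k for k, (a, b) in enumerate(ordered_pairs) if {a, b} == {i, j}]
        for o in members:
            R[o, c] = 1.0 / len(members)
            C[c, o] = 1.0

    assert np.allclose(C @ R, np.eye(10))  # lump-then-spread is the identity

    pi4 = np.array([0.3, 0.2, 0.2, 0.3])   # any allele distribution
    pi10 = C @ np.kron(pi4, pi4)           # root prior, as in impute_phylo
    assert np.isclose(pi10.sum(), 1.0)     # still a distribution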
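The 4-class → 10-class fallback in _evaluate_imputation is plain Hardy-Weinberg expansion of an allele posterior. A self-contained version, using the same AA..TT class order as the imputer (the function name is illustrative):

    # Hardy-Weinberg expansion: 4-allele posterior -> 10-genotype posterior.
    import numpy as np

    def allele_to_genotype_posterior(p4: np.ndarray) -> np.ndarray:
        # P(ij) = 2 * p_i * p_j for heterozygotes, p_i^2 for homozygotes,
        # in the AA,AC,AG,AT,CC,CG,CT,GG,GT,TT class order.
        p4 = np.clip(np.asarray(p4, dtype=float), 0.0, None)
        p4 = p4 / (p4.sum() or 1.0)
        classes = [(i, j) for i in range(4) for j in range(i, 4)]
        p10 = np.array([p4[i] * p4[j] * (2.0 if i != j else 1.0) for (i, j) in classes])
        return p10 / (p10.sum() or 1.0)

    p10 = allele_to_genotype_posterior(np.array([0.7, 0.1, 0.1, 0.1]))
    assert np.isclose(p10.sum(), 1.0)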
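Step 5 of _repair_and_scale_q normalizes Q so the expected substitution rate under the stationary distribution equals target_mean_rate, which makes branch lengths read as expected substitutions per site. A toy check with a hand-built reversible generator (the matrix values are illustrative, not from the package):

    # Toy check of the mean-rate scaling in _repair_and_scale_q.
    import numpy as np
    import scipy.linalg

    Q = np.array([
        [-2.2,  0.6,  1.2,  0.4],
        [ 0.6, -2.0,  0.4,  1.0],
        [ 1.2,  0.4, -2.0,  0.4],
        [ 0.4,  1.0,  0.4, -1.8],
    ])  # rows sum to 0; symmetric, so the stationary distribution is uniform

    w, v = scipy.linalg.eig(Q.T)
    pi = np.real(v[:, np.argmin(np.abs(w))])
    pi = pi / pi.sum()                     # fix the eigenvector's arbitrary scale/sign
    mean_rate = float(-pi @ np.diag(Q))    # expected rate under pi (here 2.0)
    Q_scaled = Q / mean_rate
    assert np.isclose(-pi @ np.diag(Q_scaled), 1.0)  # target_mean_rate = 1.0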