pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic. Click here for more details.

Files changed (112) hide show
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
@@ -0,0 +1,971 @@
1
+ import copy
2
+ import random
3
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
4
+
5
+ # Third-party imports
6
+ import numpy as np
7
+ import pandas as pd
8
+ import scipy.linalg
9
+ import toytree as tt
10
+
11
+ from pgsui.utils.plotting import Plotting
12
+ from pgsui.utils.scorers import Scorer
13
+
14
+ if TYPE_CHECKING:
15
+ from snpio.analysis.tree_parser import TreeParser
16
+ from snpio.read_input.genotype_data import GenotypeData
17
+
18
+
19
+ class _DiploidAggregator:
20
+ """Precompute mappings to lump ordered diploid states (16) into unordered genotypes (10).
21
+
22
+ Notes:
23
+ Genotype class order matches the SNPio allele → genotype encodings:
24
+ - 0: AA, 1: AC, 2: AG, 3: AT, 4: CC, 5: CG, 6: CT, 7: GG, 8: GT, 9: T
25
+ - Allele order is A(0), C(1), G(2), T(3).
26
+ """
27
+
28
+ def __init__(self) -> None:
29
+ # Ordered pairs (i,j) lexicographic over {0,1,2,3}; 16 states
30
+ self.ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
31
+ # Unordered genotype classes as (min,max) in the specified class order
32
+ self.genotype_classes = (
33
+ [(0, 0)]
34
+ + [(0, 1), (0, 2), (0, 3)]
35
+ + [(1, 1), (1, 2), (1, 3)]
36
+ + [(2, 2), (2, 3)]
37
+ + [(3, 3)]
38
+ ) # 10 states
39
+
40
+ # Map: class index -> list of ordered-state indices
41
+ self.class_to_ordered: list[list[int]] = []
42
+ for i, j in self.genotype_classes:
43
+ members = []
44
+ for k, (a, b) in enumerate(self.ordered_pairs):
45
+ if i == j: # homozygote
46
+ if a == i and b == j:
47
+ members.append(k) # exactly one member
48
+ else: # heterozygote, both permutations
49
+ if (a == i and b == j) or (a == j and b == i):
50
+ members.append(k)
51
+ self.class_to_ordered.append(members)
52
+
53
+ # Build R (16x10) and C (10x16)
54
+ self.R = np.zeros((16, 10), dtype=float)
55
+ self.C = np.zeros((10, 16), dtype=float)
56
+ for c, members in enumerate(self.class_to_ordered):
57
+ m = float(len(members))
58
+ for o in members:
59
+ self.R[o, c] = 1.0 / m # spread class prob equally to ordered members
60
+ self.C[c, o] = 1.0 # sum ordered probs back to class
61
+
62
+ # For allele-marginalization from 10-state genotype posterior to 4 alleles
63
+ # p_allele[i] = P(ii) + 0.5 * sum_{j!=i} P(min(i,j), max(i,j))
64
+ self.het_classes_by_allele: dict[int, list[int]] = {0: [], 1: [], 2: [], 3: []}
65
+ self.hom_class_index = {0: 0, 1: 4, 2: 7, 3: 9} # AA, CC, GG, TT indices
66
+ for idx, (i, j) in enumerate(self.genotype_classes):
67
+ if i != j:
68
+ self.het_classes_by_allele[i].append(idx)
69
+ self.het_classes_by_allele[j].append(idx)
70
+
71
+
72
+ class _QCache:
73
+ """Cache P(t) for haploid (4) and optionally diploid (10) via Kronecker + lumping."""
74
+
75
+ def __init__(
76
+ self,
77
+ q_df: pd.DataFrame,
78
+ mode: str = "haploid",
79
+ diploid_agg: "_DiploidAggregator | None" = None,
80
+ ) -> None:
81
+ """Precompute eigendecomposition of haploid Q.
82
+
83
+ Args:
84
+ q_df: 4x4 generator in A,C,G,T order.
85
+ mode: "haploid" or "diploid".
86
+ diploid_agg: required if mode="diploid".
87
+ """
88
+ m = np.asarray(q_df, dtype=float)
89
+ evals, V = scipy.linalg.eig(m)
90
+ Vinv = scipy.linalg.inv(V)
91
+ self.evals = evals
92
+ self.V = V
93
+ self.Vinv = Vinv
94
+ self.mode = mode
95
+ self.agg = diploid_agg
96
+ if self.mode == "diploid" and self.agg is None:
97
+ raise ValueError("Diploid mode requires a _DiploidAggregator.")
98
+ self._cache4: dict[float, np.ndarray] = {}
99
+ self._cache10: dict[float, np.ndarray] = {}
100
+
101
+ def _P4(self, s: float) -> np.ndarray:
102
+ """Return P(t) for haploid (4 states)."""
103
+ key = round(float(s), 12)
104
+ if key in self._cache4:
105
+ return self._cache4[key]
106
+ expo = np.exp(self.evals * key)
107
+ P4 = (self.V * expo) @ self.Vinv
108
+ P4 = np.real_if_close(P4, tol=1e5)
109
+ P4[P4 < 0.0] = 0.0
110
+ P4 /= P4.sum(axis=1, keepdims=True).clip(min=1.0)
111
+ self._cache4[key] = P4
112
+ return P4
113
+
114
+ def P(self, t: float, rate: float = 1.0) -> np.ndarray:
115
+ """Return P(t) in the active mode."""
116
+ s = float(rate) * float(t)
117
+ if self.mode == "haploid":
118
+ return self._P4(s)
119
+ # diploid
120
+ key = round(s, 12)
121
+ if key in self._cache10:
122
+ return self._cache10[key]
123
+ P4 = self._P4(s)
124
+ P16 = np.kron(P4, P4) # independent alleles
125
+ P10 = self.agg.C @ P16 @ self.agg.R # lump to unordered genotypes
126
+ P10 = np.maximum(P10, 0.0)
127
+ P10 /= P10.sum(axis=1, keepdims=True).clip(min=1.0)
128
+ self._cache10[key] = P10
129
+ return P10
130
+
131
+
132
+ class ImputePhylo:
133
+ """Imputes missing genotype data using a phylogenetic likelihood model.
134
+
135
+ This imputer uses a continuous-time Markov chain (CTMC) model of sequence evolution to impute missing genotype data based on a provided phylogenetic tree. It supports both haploid and diploid data, with options for evaluating imputation accuracy through simulated missingness.
136
+
137
+ Notes:
138
+ - Haploid CTMC (4 states: A,C,G,T) [default].
139
+ - Diploid CTMC (10 unordered genotype states: AA,...,TT), derived by independent allele evolution and lumping ordered pairs.
140
+ """
141
+
142
def __init__(
    self,
    genotype_data: "GenotypeData",
    tree_parser: "TreeParser",
    prefix: str,
    min_branch_length: float = 1e-10,
    *,
    haploid: bool = False,
    eval_missing_rate: float = 0.0,
    column_subset: Optional[List[int]] = None,
    save_plots: bool = False,
    verbose: bool = False,
    debug: bool = False,
) -> None:
    """Store the configuration and build the scoring and plotting helpers.

    Args:
        genotype_data: Genotype container; its ``logger`` and ``plot_*``
            attributes are read here.
        tree_parser: Tree parser object, stored for later use.
        prefix: Prefix handed to the scorer and the plotter.
        min_branch_length: Lower bound applied to branch lengths.
        haploid: If True use the 4-state model, otherwise the 10-state
            diploid model.
        eval_missing_rate: Fraction of known calls to mask for evaluation;
            values <= 0 disable the evaluation.
        column_subset: Optional list of column indices, stored as given.
        save_plots: Stored as given.
        verbose: Verbosity flag forwarded to the scorer and the plotter.
        debug: Debug flag forwarded to the scorer and the plotter.
    """
    # Inputs and plain settings
    self.genotype_data = genotype_data
    self.tree_parser = tree_parser
    self.prefix = prefix
    self.min_branch_length = min_branch_length
    self.eval_missing_rate = eval_missing_rate
    self.column_subset = column_subset
    self.save_plots = save_plots
    self.logger = genotype_data.logger
    self.verbose = verbose
    self.debug = debug

    # Nucleotide <-> index lookups
    self.char_map = {"A": 0, "C": 1, "G": 2, "T": 3}
    self.nuc_map = dict(zip(self.char_map.values(), self.char_map.keys()))

    # Populated by fit() and by the masking simulation
    self.imputer_data_: Optional[Tuple] = None
    self.ground_truth_: Optional[Dict] = None

    self.scorer = Scorer(self.prefix, average="macro", verbose=verbose, debug=debug)

    plot_settings = {
        "prefix": self.prefix,
        "plot_format": self.genotype_data.plot_format,
        "plot_fontsize": self.genotype_data.plot_fontsize,
        "plot_dpi": self.genotype_data.plot_dpi,
        "title_fontsize": self.genotype_data.plot_fontsize,
        "verbose": verbose,
        "debug": debug,
    }
    self.plotter = Plotting("ImputePhylo", **plot_settings)

    # Tokens treated as missing calls
    self._MISSING_TOKENS = {"N", "n", "-9", ".", "?", "./.", ""}

    self.haploid = haploid
    self._dip_agg: _DiploidAggregator | None = None

    # Posteriors recorded for masked positions, keyed by (sample, site index)
    self.imputed_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
    self.imputed_genotype_likelihoods_: Dict[Tuple[str, int], np.ndarray] = {}
    self.evaluation_results_: dict | None = None
190
+
191
def fit(self) -> "ImputePhylo":
    """Parse the inputs and cache them for a later ``transform`` call.

    The result of ``self._parse_arguments()`` is stored on
    ``self.imputer_data_``; ``transform`` unpacks it as the genotypes, the
    tree, the Q-matrix and the site rates.

    Returns:
        The imputer itself, so calls can be chained.
    """
    parsed_inputs = self._parse_arguments()
    self.imputer_data_ = parsed_inputs
    return self
201
+
202
def transform(self) -> pd.DataFrame:
    """Impute missing values with the fitted inputs.

    This method does the following:
        - Repairs and rescales the Q-matrix and builds the transition cache
          (10-state diploid unless ``self.haploid`` is set).
        - Masks known calls first when ``eval_missing_rate`` > 0.
        - Runs the phylogenetic imputation and, if a masked ground truth
          exists, evaluates the result.

    Returns:
        The imputed genotype DataFrame, indexed by sample.

    Raises:
        RuntimeError: If ``fit`` has not been called.
    """
    if self.imputer_data_ is None:
        msg = "The imputer has not been fitted. Call 'fit' first."
        self.logger.error(msg)
        raise RuntimeError(msg)

    original_genotypes, tree, q_matrix_in, site_rates = self.imputer_data_
    q_matrix = self._repair_and_scale_q(q_matrix_in, target_mean_rate=1.0)

    if not self.haploid:
        # The aggregator is stateless, so one instance is reused across calls.
        if self._dip_agg is None:
            self._dip_agg = _DiploidAggregator()
        self._qcache = _QCache(q_matrix, mode="diploid", diploid_agg=self._dip_agg)
    else:
        self._qcache = _QCache(q_matrix, mode="haploid")

    genotypes_to_impute = original_genotypes

    if self.eval_missing_rate > 0:
        genotypes_to_impute = self._simulate_missing_data(original_genotypes)

    imputed_df = self.impute_phylo(genotypes_to_impute, tree, q_matrix, site_rates)

    if self.ground_truth_:
        self._evaluate_imputation(imputed_df)
    return imputed_df  # keep as DataFrame with sample index
234
+
235
def fit_transform(self) -> pd.DataFrame:
    """Run ``fit`` followed by ``transform``.

    Returns:
        Whatever ``transform`` returns (the imputed genotype DataFrame).
    """
    self.fit()
    imputed = self.transform()
    return imputed
245
+
246
+ def _simulate_missing_data(
247
+ self, original_genotypes: Dict[str, List[str]]
248
+ ) -> Dict[str, List[str]]:
249
+ """Masks a fraction of known genotypes for evaluation."""
250
+ genotypes_to_impute = copy.deepcopy(original_genotypes)
251
+ known_positions = [
252
+ (sample, site_idx)
253
+ for sample, seq in genotypes_to_impute.items()
254
+ for site_idx, base in enumerate(seq)
255
+ if self._is_known(base)
256
+ ]
257
+
258
+ if not known_positions:
259
+ self.logger.warning("No known values to mask for evaluation.")
260
+ return genotypes_to_impute
261
+
262
+ num_to_mask = int(len(known_positions) * self.eval_missing_rate)
263
+ if num_to_mask == 0:
264
+ self.logger.warning(f"eval_missing_rate is too low to mask any values.")
265
+ return genotypes_to_impute
266
+
267
+ # Sample the (sample, site_idx) tuples to be masked
268
+ positions_to_mask = random.sample(known_positions, num_to_mask)
269
+
270
+ # Correctly build the ground_truth dictionary
271
+ self.ground_truth_ = {}
272
+ for sample, site_idx in positions_to_mask:
273
+ # Store the original base before masking it
274
+ self.ground_truth_[(sample, site_idx)] = genotypes_to_impute[sample][
275
+ site_idx
276
+ ]
277
+ # Now, apply the mask
278
+ genotypes_to_impute[sample][site_idx] = "N"
279
+
280
+ self.logger.info(f"Masked {len(self.ground_truth_)} values for evaluation.")
281
+ return genotypes_to_impute
282
+
283
+ def _evaluate_imputation(self, imputed_df: pd.DataFrame) -> None:
284
+ """Evaluate imputation with 4 classes (haploid) or 10/16 classes (diploid).
285
+
286
+ Behavior:
287
+ - If self.haploid is True:
288
+ Evaluate per-allele with 4 classes (A,C,G,T), keeping your existing logic.
289
+ - If self.haploid is False:
290
+ Evaluate per-genotype.
291
+ * Default: 10 unordered classes (AA,AC,AG,AT,CC,CG,CT,GG,GT,TT).
292
+ * If getattr(self, 'diploid_ordered', False) is True:
293
+ Also compute 16-class ordered metrics (AA,AC,...,TT).
294
+ Requires self._dip_agg (the _DiploidAggregator). If absent, it is created.
295
+
296
+ Notes:
297
+ - Truth handling: only unambiguous IUPAC that maps to exactly one genotype class
298
+ is used for 10-class evaluation (e.g., A, C, G, T, M, R, W, S, Y, K).
299
+ Ambiguous codes (N, B, D, H, V, '-', etc.) are skipped for strict genotype
300
+ evaluation.
301
+ - Probabilities: prefer stored 10-class posteriors in
302
+ self.imputed_genotype_likelihoods_; if missing, fall back by converting the
303
+ 4-class allele posterior to 10-class using HW independence.
304
+ """
305
+ if not self.ground_truth_:
306
+ return
307
+
308
+ # -------------------------
309
+ # Haploid path (4 classes)
310
+ # -------------------------
311
+ if getattr(self, "haploid", True):
312
+ # --- Gather ground truth bases and posterior vectors (only if both exist)
313
+ rows = []
314
+ for (sample, site_idx), true_base in self.ground_truth_.items():
315
+ if (
316
+ sample in imputed_df.index
317
+ and (sample, site_idx) in self.imputed_likelihoods_
318
+ ):
319
+ rows.append(
320
+ (true_base, self.imputed_likelihoods_[(sample, site_idx)])
321
+ )
322
+ if not rows:
323
+ self.logger.warning("No matching masked sites found to evaluate.")
324
+ return
325
+
326
+ true_bases = np.array([b for (b, _) in rows], dtype=object)
327
+ proba_mat = np.vstack([p for (_, p) in rows]) # (n,4) in A,C,G,T
328
+ nuc_to_idx = self.char_map # {'A':0,'C':1,'G':2,'T':3}
329
+
330
+ y_true_site = np.array(
331
+ [nuc_to_idx.get(str(b).upper(), -1) for b in true_bases], dtype=int
332
+ )
333
+ valid_mask = (y_true_site >= 0) & np.isfinite(proba_mat).all(axis=1)
334
+ if not np.any(valid_mask):
335
+ self.logger.error(
336
+ "Evaluation arrays empty after filtering invalid entries."
337
+ )
338
+ return
339
+
340
+ y_true_site = y_true_site[valid_mask]
341
+ proba_mat = proba_mat[valid_mask]
342
+
343
+ # Interleave each site as two alleles (iid/HW assumption) for per-allele metrics
344
+ y_true_int = np.repeat(y_true_site, 2)
345
+ y_pred_proba = np.repeat(proba_mat, 2, axis=0)
346
+ y_pred_int = np.argmax(y_pred_proba, axis=1)
347
+
348
+ n_classes = 4
349
+ y_true_ohe = np.eye(n_classes, dtype=float)[y_true_int]
350
+ idx_to_nuc = {v: k for k, v in nuc_to_idx.items()}
351
+
352
+ self.logger.info("--- Per-Allele Imputation Performance (4-class) ---")
353
+ self.evaluation_results_ = self.scorer.evaluate(
354
+ y_true=y_true_int,
355
+ y_pred=y_pred_int,
356
+ y_true_ohe=y_true_ohe,
357
+ y_pred_proba=y_pred_proba,
358
+ )
359
+ self.logger.info(f"Evaluation results: {self.evaluation_results_}")
360
+
361
+ # Plots
362
+ self.plotter.plot_metrics(
363
+ y_true=y_true_int,
364
+ y_pred_proba=y_pred_proba,
365
+ metrics=self.evaluation_results_,
366
+ )
367
+ self.plotter.plot_confusion_matrix(
368
+ y_true_1d=y_true_int,
369
+ y_pred_1d=y_pred_int,
370
+ label_names=["A", "C", "G", "T"],
371
+ )
372
+ return
373
+
374
+ # -------------------------
375
+ # Diploid path (10/16 classes)
376
+ # -------------------------
377
+ # Ensure aggregator
378
+ if getattr(self, "_dip_agg", None) is None:
379
+ self._dip_agg = _DiploidAggregator()
380
+
381
+ # Label catalogs
382
+ geno_classes = self._dip_agg.genotype_classes # list of (i,j) with i<=j
383
+ alleles = ["A", "C", "G", "T"]
384
+ labels10 = [f"{alleles[i]}{alleles[j]}" for (i, j) in geno_classes] # 10 labels
385
+ ordered_pairs = [(i, j) for i in range(4) for j in range(4)]
386
+ labels16 = [
387
+ f"{alleles[i]}{alleles[j]}" for (i, j) in ordered_pairs
388
+ ] # 16 labels
389
+
390
+ # Helper: strict IUPAC→single 10-class index (skip ambiguous >2-allele codes)
391
+ iupac_to_10_single = {
392
+ "A": "AA",
393
+ "C": "CC",
394
+ "G": "GG",
395
+ "T": "TT",
396
+ "M": "AC",
397
+ "R": "AG",
398
+ "W": "AT",
399
+ "S": "CG",
400
+ "Y": "CT",
401
+ "K": "GT",
402
+ }
403
+ label_to_10_idx = {lab: idx for idx, lab in enumerate(labels10)}
404
+
405
+ rows_10: list[tuple[int, np.ndarray]] = [] # (true_idx10, p10)
406
+ rows_16_truth: list[int] = [] # expanded true labels (for 16)
407
+ rows_16_proba: list[np.ndarray] = [] # matching (p16) rows
408
+
409
+ # Collect rows
410
+ for (sample, site_idx), true_code in self.ground_truth_.items():
411
+ # Need probabilities for this masked (sample,site)
412
+ p10 = None
413
+ if (sample, site_idx) in self.imputed_genotype_likelihoods_:
414
+ p10 = np.asarray(
415
+ self.imputed_genotype_likelihoods_[(sample, site_idx)], float
416
+ )
417
+ elif (sample, site_idx) in self.imputed_likelihoods_:
418
+ # Fallback: build 10-class from 4-class posterior (HW independence)
419
+ p4 = np.asarray(self.imputed_likelihoods_[(sample, site_idx)], float)
420
+ p4 = np.clip(p4, 0.0, np.inf)
421
+ p4 = p4 / (p4.sum() or 1.0)
422
+ # Unordered genotype probs: AA=pA^2, AC=2pApC, ...
423
+ p10 = np.zeros(10, dtype=float)
424
+ # class indices in our order:
425
+ # 0:AA, 1:AC, 2:AG, 3:AT, 4:CC, 5:CG, 6:CT, 7:GG, 8:GT, 9:TT
426
+ pA, pC, pG, pT = p4
427
+ p10[0] = pA * pA
428
+ p10[1] = 2 * pA * pC
429
+ p10[2] = 2 * pA * pG
430
+ p10[3] = 2 * pA * pT
431
+ p10[4] = pC * pC
432
+ p10[5] = 2 * pC * pG
433
+ p10[6] = 2 * pC * pT
434
+ p10[7] = pG * pG
435
+ p10[8] = 2 * pG * pT
436
+ p10[9] = pT * pT
437
+ p10 = p10 / (p10.sum() or 1.0)
438
+ else:
439
+ continue # no probabilities available
440
+
441
+ # Truth mapping to a single unordered class (skip ambiguous >2-allele codes)
442
+ c = str(true_code).upper()
443
+ lab10 = iupac_to_10_single.get(c, None)
444
+ if lab10 is None:
445
+ # Skip N, B, D, H, V, '-', etc., because they expand to multiple classes
446
+ continue
447
+ true_idx10 = label_to_10_idx[lab10]
448
+ rows_10.append((true_idx10, p10))
449
+
450
+ # Optional ordered-16 expansion (for confusion/metrics if requested)
451
+ if getattr(self, "diploid_ordered", False):
452
+ # Spread 10->16 via R (16x10)
453
+ p16 = self._dip_agg.R @ p10
454
+ p16 = p16 / (p16.sum() or 1.0)
455
+ i, j = geno_classes[true_idx10]
456
+ if i == j:
457
+ # Homozygote: single ordered index
458
+ true16 = ordered_pairs.index((i, j))
459
+ rows_16_truth.append(true16)
460
+ rows_16_proba.append(p16)
461
+ else:
462
+ # Heterozygote: duplicate as two ordered permutations
463
+ true16a = ordered_pairs.index((i, j))
464
+ true16b = ordered_pairs.index((j, i))
465
+ rows_16_truth.append(true16a)
466
+ rows_16_proba.append(p16)
467
+ rows_16_truth.append(true16b)
468
+ rows_16_proba.append(p16)
469
+
470
+ if not rows_10:
471
+ self.logger.warning("No valid diploid truth rows for evaluation.")
472
+ return
473
+
474
+ # ---- 10-class metrics ----
475
+ y_true_10 = np.array([t for (t, _) in rows_10], dtype=int)
476
+ y_pred_proba_10 = np.vstack([p for (_, p) in rows_10]) # (n,10)
477
+
478
+ y_pred_10 = np.argmax(y_pred_proba_10, axis=1)
479
+ y_true_ohe_10 = np.eye(10, dtype=float)[y_true_10]
480
+
481
+ self.logger.info("--- Per-Genotype Imputation Performance (10-class) ---")
482
+ self.evaluation_results_ = self.scorer.evaluate(
483
+ y_true=y_true_10,
484
+ y_pred=y_pred_10,
485
+ y_true_ohe=y_true_ohe_10,
486
+ y_pred_proba=y_pred_proba_10,
487
+ )
488
+ self.logger.info(f"Evaluation results (10-class): {self.evaluation_results_}")
489
+
490
+ # Plots (10-class)
491
+ self.plotter.plot_metrics(
492
+ y_true=y_true_10,
493
+ y_pred_proba=y_pred_proba_10,
494
+ metrics=self.evaluation_results_,
495
+ label_names=labels10,
496
+ )
497
+ self.plotter.plot_confusion_matrix(
498
+ y_true_1d=y_true_10, y_pred_1d=y_pred_10, label_names=labels10
499
+ )
500
+
501
+ # ---- Optional 16-class ordered metrics ----
502
+ if getattr(self, "diploid_ordered", False) and rows_16_truth:
503
+ y_true_16 = np.array(rows_16_truth, dtype=int)
504
+ y_pred_proba_16 = np.vstack(rows_16_proba)
505
+ y_pred_16 = np.argmax(y_pred_proba_16, axis=1)
506
+ y_true_ohe_16 = np.eye(16, dtype=float)[y_true_16]
507
+
508
+ self.logger.info(
509
+ "--- Per-Genotype Imputation Performance (16-class ordered) ---"
510
+ )
511
+ eval16 = self.scorer.evaluate(
512
+ y_true=y_true_16,
513
+ y_pred=y_pred_16,
514
+ y_true_ohe=y_true_ohe_16,
515
+ y_pred_proba=y_pred_proba_16,
516
+ )
517
+ self.logger.info(f"Evaluation results (16-class): {eval16}")
518
+
519
+ self.plotter.plot_metrics(
520
+ y_true=y_true_16,
521
+ y_pred_proba=y_pred_proba_16,
522
+ metrics=eval16,
523
+ label_names=labels16,
524
+ )
525
+ self.plotter.plot_confusion_matrix(
526
+ y_true_1d=y_true_16, y_pred_1d=y_pred_16, label_names=labels16
527
+ )
528
+
529
+ def _infer_proba_permutation(
530
+ self, y_ref: np.ndarray, P: np.ndarray, n_classes: int = 4
531
+ ) -> tuple[np.ndarray, list[int]]:
532
+ """Infers a permutation to align probability columns to the label space of y_ref."""
533
+ perm = [-1] * n_classes
534
+ taken = set()
535
+ for k in range(n_classes):
536
+ mask = y_ref == k
537
+ if not np.any(mask):
538
+ means = P.mean(axis=0)
539
+ else:
540
+ means = P[mask].mean(axis=0)
541
+
542
+ # Find best unused column
543
+ for col in np.argsort(means)[::-1]:
544
+ if col not in taken:
545
+ perm[k] = col
546
+ taken.add(col)
547
+ break
548
+
549
+ if len(taken) != n_classes: # Fallback if permutation is incomplete
550
+ unassigned_cols = [c for c in range(n_classes) if c not in taken]
551
+ for i in range(n_classes):
552
+ if perm[i] == -1:
553
+ perm[i] = unassigned_cols.pop(0)
554
+
555
+ self.logger.info(f"Inferred probability permutation (label->col): {perm}")
556
+ return P[:, perm], perm
557
+
558
+ def _stationary_pi(self, Q: np.ndarray) -> np.ndarray:
559
+ """Robustly calculates the stationary distribution of Q."""
560
+ w, v = scipy.linalg.eig(Q.T)
561
+ k = int(np.argmin(np.abs(w)))
562
+ pi = np.real(v[:, k])
563
+ pi = np.maximum(pi, 0.0)
564
+ s = pi.sum()
565
+ return (pi / s) if s > 0 else np.ones(Q.shape[0]) / Q.shape[0]
566
+
567
def impute_phylo(
    self,
    genotypes: Dict[str, List[Union[str, int]]],
    tree: tt.tree,
    q_matrix: pd.DataFrame,
    site_rates: Optional[List[float]],
) -> pd.DataFrame:
    """Impute missing calls site by site using a phylogenetic guide tree.

    For every site with at least one missing tip, a post-order pass computes
    the partial likelihoods below each node and a pre-order pass computes the
    message arriving from the rest of the tree. A missing tip is then called
    from the normalised product of the two (MAP call).

    Transition matrices come from ``self._qcache``, which ``transform``
    builds before calling this method.

    Args:
        genotypes: Mapping of sample name to its list of calls. Not modified.
        tree: Guide tree; only samples present both here and in ``genotypes``
            are imputed.
        q_matrix: 4x4 rate matrix, used here for the root prior.
        site_rates: Optional per-site rate multipliers; a rate of 1.0 is used
            when None.

    Returns:
        DataFrame indexed by sample, in the order the samples appear in
        ``genotypes``, with no missing calls left.

    Raises:
        ValueError: If tree and genotypes share no samples, or if
            ``site_rates`` does not have one entry per site.
        AssertionError: If missing values remain after imputation.
    """
    self.imputed_likelihoods_.clear()
    self.imputed_genotype_likelihoods_.clear()

    tip_labels = set(tree.get_tip_labels())
    # Follow the caller's sample order: iterating a set intersection would
    # make the output row order vary from run to run.
    common_samples = [s for s in genotypes if s in tip_labels]
    if not common_samples:
        raise ValueError("No samples in common between tree and genotypes.")
    filt_genotypes = copy.deepcopy({s: genotypes[s] for s in common_samples})
    num_snps = len(next(iter(filt_genotypes.values())))

    if site_rates is not None and len(site_rates) != num_snps:
        raise ValueError(
            f"len(site_rates)={len(site_rates)} != num_snps={num_snps}"
        )

    # Stationary prior at root
    pi4 = self._stationary_pi(q_matrix.to_numpy())
    if not self.haploid:
        # Same lazy creation as in _evaluate_imputation, so a direct call
        # does not fail on a missing aggregator.
        if self._dip_agg is None:
            self._dip_agg = _DiploidAggregator()
        # pi4 from Q; then lump pi4 (x) pi4 to 10 classes
        pi10 = self._dip_agg.C @ np.kron(pi4, pi4)
        root_prior = pi10 / (pi10.sum() or 1.0)
        n_states = 10
    else:
        root_prior = pi4
        n_states = 4

    min_len = self.min_branch_length
    tip_nodes: dict = {}  # sample name -> tip node, resolved once per sample

    for snp_index in range(num_snps):
        rate = site_rates[snp_index] if site_rates is not None else 1.0
        tips_with_missing = [
            s
            for s, seq in filt_genotypes.items()
            if self._is_missing(seq[snp_index])
        ]
        if not tips_with_missing:
            continue

        # ---- Upward (post-order) pass: likelihood of the data below a node
        down_liks: Dict[int, np.ndarray] = {}
        for node in tree.treenode.traverse("postorder"):
            if node.is_leaf():
                lik = np.zeros(n_states, dtype=float)
                if node.name not in filt_genotypes or self._is_missing(
                    filt_genotypes[node.name][snp_index]
                ):
                    lik[:] = 1.0  # missing: uniform emission
                else:
                    obs = filt_genotypes[node.name][snp_index]
                    if not self.haploid:
                        cls = self._iupac_to_genotype_classes(obs)
                        lik[cls] = 1.0
                    else:
                        for state in self._get_iupac_full(obs):
                            lik[self.char_map[state]] = 1.0
                if not lik.any():
                    # No compatible state: treat as uninformative rather
                    # than dividing by zero and spreading NaN up the tree.
                    lik[:] = 1.0
                down_liks[node.idx] = lik / lik.sum()
            else:
                msg = np.ones(n_states, dtype=float)
                for child in node.children:
                    P = self._qcache.P(max(child.dist, min_len), rate)
                    msg *= P @ down_liks[child.idx]
                down_liks[node.idx] = self._norm(msg)

        # ---- Downward (pre-order) pass: message from the rest of the tree
        up_liks: Dict[int, np.ndarray] = {tree.treenode.idx: root_prior.copy()}
        for node in tree.treenode.traverse("preorder"):
            if node.is_root():
                continue
            parent = node.up
            sib_prod = np.ones(n_states, dtype=float)
            for sib in parent.children:
                if sib.idx == node.idx:
                    continue
                P_sib = self._qcache.P(max(sib.dist, min_len), rate)
                sib_prod *= P_sib @ down_liks[sib.idx]
            parent_msg = up_liks[parent.idx] * sib_prod
            P = self._qcache.P(max(node.dist, min_len), rate)
            up = parent_msg @ P
            up_liks[node.idx] = up / (up.sum() or 1.0)

        # ---- Call every missing tip at this site
        for samp in tips_with_missing:
            if samp not in tip_nodes:
                tip_nodes[samp] = tree.get_nodes(samp)[0]
            node = tip_nodes[samp]
            leaf_emission = down_liks[node.idx]  # uniform if missing
            tip_post = self._norm(up_liks[node.idx] * leaf_emission)

            # Posteriors are only kept for positions masked for evaluation
            if self.ground_truth_ and (samp, snp_index) in self.ground_truth_:
                if not self.haploid:
                    # store genotype posterior (10) and allele-marginal (4)
                    self.imputed_genotype_likelihoods_[(samp, snp_index)] = (
                        tip_post.copy()
                    )
                    self.imputed_likelihoods_[(samp, snp_index)] = (
                        self._marginalize_genotype_to_allele(tip_post)
                    )
                else:
                    self.imputed_likelihoods_[(samp, snp_index)] = tip_post.copy()

            if not self.haploid:
                call = self._genotype_posterior_to_iupac(tip_post, mode="MAP")
            else:
                call = self._allele_posterior_to_iupac_genotype(
                    tip_post, mode="MAP"
                )
            filt_genotypes[samp][snp_index] = call

    df = pd.DataFrame.from_dict(filt_genotypes, orient="index")

    # DataFrame.applymap is deprecated in favour of DataFrame.map; use map
    # when it exists and fall back to applymap on older pandas.
    elementwise = getattr(df, "map", None) or df.applymap
    if elementwise(self._is_missing).any().any():
        raise AssertionError("Imputation failed. Missing values remain.")

    return df
685
+
686
def _iupac_to_genotype_classes(self, char: str) -> list[int]:
    """Translate an IUPAC character into the genotype-class indices it permits.

    Allele indices are A=0, C=1, G=2, T=3. Genotype classes are the unordered
    allele pairs listed in ``self._dip_agg.genotype_classes``.

    Args:
        char: IUPAC symbol (case-insensitive; coerced with ``str``).

    Returns:
        List of class indices compatible with ``char``. Unambiguous bases and
        two-allele codes give a single index; ``N``, ``-``, ``B``, ``D``,
        ``H`` and ``V`` give every class whose two alleles are both allowed;
        anything else gives ``[0, ..., 9]``.
    """
    code = str(char).upper()

    # Homozygous classes AA, CC, GG, TT. These positions are hard-coded
    # (not looked up in genotype_classes), exactly as in the lookup below
    # for heterozygotes they are NOT.
    # NOTE(review): presumes genotype_classes is ordered
    # (0,0),(0,1),...,(3,3) - confirm against the aggregator.
    homozygous_class = {"A": 0, "C": 4, "G": 7, "T": 9}
    if code in homozygous_class:
        return [homozygous_class[code]]

    # Two-allele ambiguity codes, stored as (low, high) allele index pairs.
    heterozygous_pair = {
        "M": (0, 1),
        "R": (0, 2),
        "W": (0, 3),
        "S": (1, 2),
        "Y": (1, 3),
        "K": (2, 3),
    }
    pair = heterozygous_pair.get(code)
    if pair is not None:
        return [self._dip_agg.genotype_classes.index(pair)]

    # Codes covering more than two alleles, plus the missing symbols:
    # every genotype built only from the allowed alleles is compatible.
    allowed_alleles = {
        "N": {0, 1, 2, 3},
        "-": {0, 1, 2, 3},
        "B": {1, 2, 3},
        "D": {0, 2, 3},
        "H": {0, 1, 3},
        "V": {0, 1, 2},
    }.get(code)
    if allowed_alleles is not None:
        return [
            class_idx
            for class_idx, (a, b) in enumerate(self._dip_agg.genotype_classes)
            if a in allowed_alleles and b in allowed_alleles
        ]

    # Unrecognised symbol: place no constraint at all.
    return list(range(10))
726
+
727
def _genotype_posterior_to_iupac(self, p10: np.ndarray, mode: str = "MAP") -> str:
    """Convert a genotype posterior over the unordered classes to an IUPAC code.

    Args:
        p10: Non-negative weights, one per entry of
            ``self._dip_agg.genotype_classes``. Need not be normalised.
        mode: ``"SAMPLE"`` (case-insensitive) draws a class at random from the
            posterior via ``np.random.choice``; any other value selects the
            most probable class (MAP).

    Returns:
        IUPAC code produced by ``self._genotype_to_iupac`` for the selected
        allele pair.
    """
    p = np.asarray(p10, dtype=float)
    total = float(p.sum())
    if np.isfinite(total) and total > 0:
        p = p / total
    else:
        # Degenerate posterior (all zeros, negative or non-finite mass).
        # np.random.choice rejects probabilities that do not sum to 1, so
        # fall back to a uniform distribution instead of crashing in SAMPLE
        # mode; this mirrors the fallback used by `_norm`.
        p = np.ones_like(p) / max(p.size, 1)

    if mode.upper() == "SAMPLE":
        k = int(np.random.choice(len(p), p=p))
    else:
        k = int(np.argmax(p))

    i, j = self._dip_agg.genotype_classes[k]
    alleles = ["A", "C", "G", "T"]
    return self._genotype_to_iupac(alleles[i] + alleles[j])
739
+
740
def _marginalize_genotype_to_allele(self, p10: np.ndarray) -> np.ndarray:
    """Collapse a genotype posterior into a length-4 allele posterior.

    Each allele receives the full mass of its homozygous class plus half the
    mass of every heterozygous class listed for it in ``self._dip_agg``.

    Args:
        p10: Genotype-class weights; normalised internally (left unscaled if
            the total is zero).

    Returns:
        Array of 4 allele weights, renormalised to sum to 1 unless the total
        is zero.
    """
    posterior = np.asarray(p10, dtype=float)
    posterior = posterior / (posterior.sum() or 1.0)

    aggregator = self._dip_agg
    allele_post = np.zeros(4, dtype=float)
    for allele in range(4):
        hom_mass = posterior[aggregator.hom_class_index[allele]]
        het_mass = np.sum(posterior[aggregator.het_classes_by_allele[allele]])
        # A heterozygote contributes one of its two copies to this allele.
        allele_post[allele] = hom_mass + 0.5 * het_mass

    total = allele_post.sum()
    return allele_post / (total or 1.0)
752
+
753
def _parse_arguments(
    self,
) -> Tuple[
    Dict[str, List[Union[str, int]]], tt.tree, pd.DataFrame, Optional[List[float]]
]:
    """Validate and collect the inputs held on the attached data objects.

    Returns:
        Tuple ``(snpsdict, tree, qmat, site_rates)`` read from
        ``self.genotype_data`` and ``self.tree_parser``. ``site_rates`` is
        ``None`` when the tree parser does not define that attribute.

    Raises:
        TypeError: If ``snpsdict``, ``tree`` or ``qmat`` is absent or ``None``.
    """
    genotype_data = self.genotype_data
    if getattr(genotype_data, "snpsdict", None) is None:
        raise TypeError("`GenotypeData.snpsdict` must be defined.")

    parser = self.tree_parser
    if getattr(parser, "tree", None) is None:
        raise TypeError("`TreeParser.tree` must be defined.")
    if getattr(parser, "qmat", None) is None:
        raise TypeError("`TreeParser.qmat` must be defined.")

    # Site rates are optional; every other input is mandatory.
    optional_site_rates = getattr(parser, "site_rates", None)
    return (
        genotype_data.snpsdict,
        parser.tree,
        parser.qmat,
        optional_site_rates,
    )
774
+
775
def _get_iupac_full(self, char: str) -> List[str]:
    """Expand an IUPAC symbol into the list of nucleotides it can stand for.

    Args:
        char: IUPAC symbol (case-insensitive).

    Returns:
        Nucleotides covered by ``char``. Symbols not in the table (as well as
        ``N`` and ``-``) expand to all four bases in the order A, C, T, G.
    """
    # Each value spells out the expansion; character order is the order of
    # the returned list.
    expansions = {
        "A": "A",
        "G": "G",
        "C": "C",
        "T": "T",
        "N": "ACTG",
        "-": "ACTG",
        "R": "AG",
        "Y": "CT",
        "S": "GC",
        "W": "AT",
        "K": "GT",
        "M": "AC",
        "B": "CGT",
        "D": "AGT",
        "H": "ACT",
        "V": "ACG",
    }
    bases = expansions.get(char.upper())
    if bases is None:
        return ["A", "C", "T", "G"]
    return list(bases)
795
+
796
def _allele_posterior_to_genotype_probs(self, p: np.ndarray) -> Dict[str, float]:
    """Build unordered-genotype probabilities from an allele posterior.

    Negative entries are clipped to zero and the vector is normalised (or
    replaced by a uniform vector when it has no positive mass). Each genotype
    gets the product of its two allele probabilities, doubled for
    heterozygotes, and the result is renormalised.

    Args:
        p: Allele weights in A, C, G, T order.

    Returns:
        Mapping from two-letter genotype (``"AA"``, ``"AC"``, ..., ``"TT"``)
        to probability.
    """
    weights = np.maximum(p, 0.0)
    mass = weights.sum()
    if mass > 0:
        weights = weights / mass
    else:
        weights = np.ones_like(weights) / len(weights)

    bases = "ACGT"
    unnormalised = {}
    for first in range(len(bases)):
        for second in range(first, len(bases)):
            # Two orderings exist for a heterozygote, one for a homozygote.
            multiplicity = 1.0 if first == second else 2.0
            unnormalised[bases[first] + bases[second]] = float(
                weights[first] * weights[second] * multiplicity
            )

    norm = sum(unnormalised.values()) or 1.0
    return {genotype: value / norm for genotype, value in unnormalised.items()}
807
+
808
def _genotype_to_iupac(self, gt: str) -> str:
    """Encode a two-letter genotype as a single IUPAC character.

    Args:
        gt: Genotype string whose first two characters are the alleles.

    Returns:
        The allele itself for a homozygote, the two-allele ambiguity code for
        a recognised heterozygote, and ``"N"`` otherwise.
    """
    if gt[0] != gt[1]:
        # Sort so that e.g. "CA" and "AC" resolve to the same key.
        ordered = "".join(sorted(gt))
        het_codes = {
            "AC": "M",
            "AG": "R",
            "AT": "W",
            "CG": "S",
            "CT": "Y",
            "GT": "K",
        }
        return het_codes.get(ordered, "N")
    return gt[0]
814
+
815
def _allele_posterior_to_iupac_genotype(
    self, p: np.ndarray, mode: str = "MAP"
) -> str:
    """Pick a genotype from an allele posterior and return its IUPAC code.

    Args:
        p: Allele posterior passed to
            ``self._allele_posterior_to_genotype_probs``.
        mode: ``"SAMPLE"`` (case-insensitive) draws a genotype at random with
            ``np.random.choice``; any other value takes the most probable one.

    Returns:
        IUPAC code from ``self._genotype_to_iupac`` for the chosen genotype.
    """
    genotype_probs = self._allele_posterior_to_genotype_probs(p)

    if mode.upper() != "SAMPLE":
        # MAP: the first genotype (in dict order) with the highest probability.
        top_genotype, _ = max(genotype_probs.items(), key=lambda item: item[1])
        return self._genotype_to_iupac(top_genotype)

    labels, weights = zip(*genotype_probs.items())
    drawn = np.random.choice(len(labels), p=np.array(weights, dtype=float))
    return self._genotype_to_iupac(labels[drawn])
825
+
826
def _align_and_check_q(self, q_df: pd.DataFrame) -> pd.DataFrame:
    """Reorder Q to A, C, G, T and verify it is a valid rate matrix.

    Args:
        q_df: Rate matrix whose index and columns contain the labels
            ``A``, ``C``, ``G`` and ``T`` (extra labels are dropped).

    Returns:
        A float DataFrame of the 4x4 sub-matrix in A, C, G, T order.

    Raises:
        ValueError: If a required label is absent, an off-diagonal entry is
            below ``-1e-12``, a diagonal entry is above ``1e-12``, or a row
            does not sum to 0 within ``1e-8``.
    """
    required = ["A", "C", "G", "T"]
    needed = set(required)
    if not (needed <= set(q_df.index) and needed <= set(q_df.columns)):
        raise ValueError("Q must have index/columns including exactly A,C,G,T.")

    rates = q_df.loc[required, required].astype(float).to_numpy()

    # Generator constraints: non-negative exits, non-positive diagonal,
    # zero row sums (small numerical slack allowed on each).
    off_diagonal = rates[~np.eye(4, dtype=bool)]
    if (off_diagonal < -1e-12).any():
        raise ValueError("Q off-diagonal entries must be non-negative.")

    if (np.diag(rates) > 1e-12).any():
        raise ValueError("Q diagonal entries must be non-positive.")

    row_sums = rates.sum(axis=1)
    if not np.allclose(row_sums, 0.0, atol=1e-8):
        self.logger.error(row_sums)
        raise ValueError("Q rows must sum to 0.")

    return pd.DataFrame(rates, index=required, columns=required)
860
+
861
def _stationary_pi_from_q(self, Q: np.ndarray) -> np.ndarray:
    """Compute the stationary distribution pi of a CTMC generator Q.

    Takes the eigenvector of ``Q.T`` whose eigenvalue is closest to 0, fixes
    its sign, clips tiny negative entries to 0 and renormalises.

    Args:
        Q: Square CTMC generator (rows sum to 0).

    Returns:
        1-D array with one entry per state, summing to 1. A uniform vector is
        returned when no usable eigenvector is found.
    """
    Q = np.asarray(Q, dtype=float)
    w, v = scipy.linalg.eig(Q.T)
    k = int(np.argmin(np.abs(w)))
    pi = np.real(v[:, k])

    # An eigenvector is only determined up to a scalar factor, so the solver
    # may legitimately hand back -pi. Without re-orienting it, the clipping
    # below would zero the whole vector and silently yield a uniform pi.
    if pi.sum() < 0:
        pi = -pi

    pi = np.maximum(pi, 0.0)
    s = float(pi.sum())
    if not np.isfinite(s) or s <= 0:
        # Fallback uniform if eigen failed
        return np.ones(Q.shape[0]) / Q.shape[0]
    return pi / s
881
+
882
def _repair_and_scale_q(
    self,
    q_df: pd.DataFrame,
    target_mean_rate: float = 1.0,
    state_order: tuple[str, ...] = ("A", "C", "G", "T"),
    neg_offdiag_tol: float = 1e-8,
) -> pd.DataFrame:
    """Repair Q to a valid CTMC generator and scale its mean rate.

    Steps:
        1) Reindex to `state_order` and cast to float.
        2) Set negative off-diagonals to 0 (warn if any < -neg_offdiag_tol).
        3) Set diagonal q_ii = -sum_{j!=i} q_ij so rows sum to 0. The input
           diagonal is ignored entirely.
        4) If a row has zero off-diagonal sum, inject a tiny uniform exit rate.
        5) Compute stationary pi and scale Q so that
           -sum_i pi_i q_ii = target_mean_rate.

    Args:
        q_df: DataFrame with index=columns=states.
        target_mean_rate: Desired average rate under pi (commonly 1.0). Must
            be a positive finite number.
        state_order: Desired nucleotide order.
        neg_offdiag_tol: Tolerance to log warnings for negative off-diags.

    Returns:
        Repaired, scaled Q as a DataFrame in `state_order`.

    Raises:
        ValueError: If a state is missing from the index or columns, or if
            `target_mean_rate` is not a positive finite number.
        RuntimeError: If rows still do not sum to 0 after the repair (for
            example because an off-diagonal entry is NaN).
    """
    # Dividing by a zero/negative/NaN target would either crash with
    # ZeroDivisionError or silently flip the sign of the generator.
    target = float(target_mean_rate)
    if not (np.isfinite(target) and target > 0):
        raise ValueError(
            f"target_mean_rate must be a positive finite number, got {target_mean_rate!r}"
        )

    # 1) Align. A list (not a tuple) is used for label selection so pandas
    # always treats it as a collection of labels.
    states = list(state_order)
    missing = (set(states) - set(q_df.index)) | (set(states) - set(q_df.columns))
    if missing:
        raise ValueError(
            f"Q must have states {state_order}, missing: {sorted(missing)}"
        )
    Q = q_df.loc[states, states].to_numpy(dtype=float, copy=True)
    n_states = Q.shape[0]

    # 2) Clip negative off-diagonals
    off_mask = ~np.eye(n_states, dtype=bool)
    off_values = Q[off_mask]
    if np.any(off_values < -neg_offdiag_tol):
        self.logger.warning(
            f"Q has negative off-diagonals; clipping {int(np.sum(off_values < 0))} entries."
        )
    Q[off_mask] = np.maximum(off_values, 0.0)

    # 3) Set diagonal so rows sum to 0. Zero the diagonal first so that the
    # row sums contain off-diagonal entries only: a NaN or otherwise bogus
    # input diagonal must not leak into the repaired matrix.
    np.fill_diagonal(Q, 0.0)
    row_off_sum = Q.sum(axis=1)
    np.fill_diagonal(Q, -row_off_sum)

    # 4) Ensure no absorbing rows
    zero_rows = row_off_sum <= 0
    if np.any(zero_rows):
        eps = 1e-8
        rows = np.flatnonzero(zero_rows)
        Q[rows, :] = eps / (n_states - 1)
        Q[rows, rows] = -eps
        self.logger.warning(
            f"Injected tiny exit rates in {int(np.sum(zero_rows))} rows."
        )

    # Sanity: rows now sum to 0 up to floating-point rounding. The tolerance
    # scales with the largest entry so large-magnitude rates do not trip it.
    tol = 1e-12 * max(1.0, float(np.abs(Q).max()))
    if not np.allclose(Q.sum(axis=1), 0.0, atol=tol):
        raise RuntimeError("Internal error: row sums not zero after repair.")

    # 5) Scale to target mean rate under stationary pi
    pi = self._stationary_pi_from_q(Q)
    mean_rate = float(-np.dot(pi, np.diag(Q)))
    if np.isfinite(mean_rate) and mean_rate > 0:
        Q /= mean_rate / target
    else:
        self.logger.warning("Mean rate non-positive; skipping scaling.")

    return pd.DataFrame(Q, index=states, columns=states)
957
+
958
def _is_missing(self, x: object) -> bool:
    """Tell whether ``x`` is one of the missing-genotype tokens.

    The value is converted with ``str`` and stripped of surrounding
    whitespace before being looked up in ``self._MISSING_TOKENS``.
    """
    return str(x).strip() in self._MISSING_TOKENS
962
+
963
def _is_known(self, x: object) -> bool:
    """Return True when ``x`` is not a missing-genotype token.

    This is the exact negation of ``self._is_missing(x)``.
    """
    missing = self._is_missing(x)
    return not missing
965
+
966
@staticmethod
def _norm(v: np.ndarray) -> np.ndarray:
    """Scale ``v`` so its entries sum to 1.

    Args:
        v: Array of weights.

    Returns:
        ``v`` divided by its sum, or a uniform array of the same shape when
        the sum is non-positive or not finite.
    """
    total = float(np.sum(v))
    if np.isfinite(total) and total > 0:
        return v / total
    # No usable mass: spread it evenly over all entries.
    return np.ones_like(v) / v.size