pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/RECORD +0 -75
  83. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
@@ -0,0 +1,725 @@
1
+ # Standard library imports
2
+ import logging
3
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple
4
+
5
+ # Third-party imports
6
+ import numpy as np
7
+ import pandas as pd
8
+ from sklearn.exceptions import NotFittedError
9
+ from sklearn.metrics import (
10
+ accuracy_score,
11
+ classification_report,
12
+ f1_score,
13
+ precision_score,
14
+ recall_score,
15
+ )
16
+ from snpio import GenotypeEncoder
17
+
18
+ # Local imports
19
+ from pgsui.utils.plotting import Plotting
20
+
21
+ # Type checking imports
22
+ if TYPE_CHECKING:
23
+ from snpio.read_input.genotype_data import GenotypeData
24
+
25
+
26
+ class ImputeAlleleFreq:
27
+ """Frequency-based imputer for integer-encoded categorical genotype data with test-only evaluation.
28
+
29
+ This implementation imputes missing values by sampling from the empirical frequency distribution of observed integer codes (e.g., 0-9 for nucleotides). It supports a strict train/test protocol: distributions are learned on the train split, missingness is simulated only on the test split, and all metrics are computed exclusively on the test split.
30
+
31
+ The algorithm is as follows:
32
+ 1. Split the dataset into train and test sets (row-wise).
33
+ 2. For each feature (column), compute the empirical frequency distribution of observed values in the train set.
34
+ 3. On the test set, simulate additional missing values by randomly masking a specified proportion of observed entries.
35
+ 4. Impute missing values in the test set by sampling from the train-learned distributions.
36
+ 5. Evaluate imputation accuracy using various metrics on the test set only.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ genotype_data: "GenotypeData",
42
+ *,
43
+ prefix: str = "pgsui",
44
+ by_populations: bool = False,
45
+ default: int = 0,
46
+ missing: int = -9,
47
+ verbose: bool = True,
48
+ seed: Optional[int] = None,
49
+ sim_prop_missing: float = 0.30,
50
+ debug: bool = False,
51
+ test_size: float = 0.2,
52
+ test_indices: Optional[Sequence[int]] = None,
53
+ stratify_by_populations: bool = False,
54
+ logger: logging.Logger | None = None,
55
+ ) -> None:
56
+ """Initialize ImputeAlleleFreq.
57
+
58
+ Args:
59
+ genotype_data: Object with `.genotypes_int`, `.ref`, `.alt`, and optional `.populations`.
60
+ prefix: Output prefix.
61
+ by_populations: Learn separate dists per population.
62
+ default: Default genotype for cold-start loci.
63
+ missing: Integer code for missing values in X_.
64
+ verbose: Verbosity switch.
65
+ seed: RNG seed.
66
+ sim_prop_missing: Fraction of OBSERVED test cells to mask for evaluation.
67
+ debug: Debug switch.
68
+ test_size: Fraction of rows held out for test if `test_indices` not provided.
69
+ test_indices: Explicit test row indices. Overrides `test_size` if given.
70
+ stratify_by_populations: If True and populations are available, create a stratified test split per population.
71
+ """
72
+ self.genotype_data: "GenotypeData" = genotype_data
73
+ self.prefix: str = prefix
74
+ self.by_populations: bool = by_populations
75
+ self.default: int = int(default)
76
+ self.missing: int = int(missing)
77
+ self.sim_prop_missing: float = float(sim_prop_missing)
78
+ self.verbose: bool = verbose
79
+ self.debug: bool = debug
80
+ self.logger: logging.Logger = logger or logging.getLogger(__name__)
81
+
82
+ if not (0.0 <= self.sim_prop_missing <= 0.95):
83
+ raise ValueError("sim_prop_missing must be in [0, 0.95].")
84
+
85
+ self.rng: np.random.Generator = np.random.default_rng(seed)
86
+ self.encoder: GenotypeEncoder = GenotypeEncoder(self.genotype_data)
87
+ self.X_: np.ndarray = np.asarray(self.encoder.genotypes_int, dtype=np.int8)
88
+ self.num_features_: int = self.X_.shape[1]
89
+
90
+ # --- split controls ---
91
+ self.test_size: float = float(test_size)
92
+ self.test_indices: np.ndarray | None = (
93
+ None if test_indices is None else np.asarray(test_indices, dtype=int)
94
+ )
95
+
96
+ self.stratify_by_populations: bool = bool(stratify_by_populations)
97
+ self.pops: np.ndarray | None = None
98
+ if self.by_populations:
99
+ pops = getattr(self.genotype_data, "populations", None)
100
+ if pops is None:
101
+ raise TypeError(
102
+ "by_populations=True requires genotype_data.populations."
103
+ )
104
+ self.pops = np.asarray(pops)
105
+ if len(self.pops) != self.X_.shape[0]:
106
+ raise ValueError(
107
+ f"`populations` length ({len(self.pops)}) != number of samples ({self.X_.shape[0]})."
108
+ )
109
+
110
+ self.is_fit_: bool = False
111
+ self.global_dist_: Dict[int | str, Tuple[np.ndarray, np.ndarray]] = {}
112
+ self.group_dist_: Dict[
113
+ str | int, Dict[int | str, Tuple[np.ndarray, np.ndarray]]
114
+ ] = {}
115
+ self.sim_mask_: Optional[np.ndarray] = None
116
+ self.train_idx_: Optional[np.ndarray] = None
117
+ self.test_idx_: Optional[np.ndarray] = None
118
+ self.X_train_: Optional[pd.DataFrame] = None
119
+ self.ground_truths_: Optional[np.ndarray] = None
120
+ self.metrics_: Dict[str, float] = {}
121
+ self.X_imputed_: Optional[np.ndarray] = None
122
+
123
+ # VCF ref/alt cache + IUPAC LUT
124
+ self.ref_codes_: Optional[np.ndarray] = None # (nF,) in {0..3}
125
+ self.alt_mask_: Optional[np.ndarray] = None # (nF,4) bool
126
+ self._iupac_presence_lut_: Optional[np.ndarray] = None # (10,4) bool
127
+
128
+ plot_fmt: Literal["pdf", "png", "jpg", "jpeg"] = getattr(
129
+ genotype_data, "plot_format", "png"
130
+ )
131
+
132
+ self.plotter = Plotting(
133
+ "ImputeAlleleFreq",
134
+ prefix=self.prefix,
135
+ plot_format=plot_fmt,
136
+ plot_fontsize=genotype_data.plot_fontsize,
137
+ plot_dpi=genotype_data.plot_dpi,
138
+ title_fontsize=genotype_data.plot_fontsize,
139
+ despine=genotype_data.plot_despine,
140
+ show_plots=genotype_data.show_plots,
141
+ verbose=self.verbose,
142
+ debug=self.debug,
143
+ )
144
+
145
+ # ------------------------------------------
146
+ # Helpers for VCF ref/alt and IUPAC mapping
147
+ # ------------------------------------------
148
+ def _map_base_to_int(self, arr: np.ndarray | None) -> np.ndarray | None:
149
+ """Map bases to A/C/G/T -> 0/1/2/3; pass integers through; others -> -1.
150
+
151
+ Args:
152
+ arr (np.ndarray | None): Array-like of bases (str) or integer codes.
153
+
154
+ Returns:
155
+ np.ndarray | None: Mapped integer array, or None if input is None or invalid.
156
+ """
157
+ if arr is None:
158
+ return None
159
+ arr = np.asarray(arr)
160
+ if arr.dtype.kind in ("i", "u"):
161
+ return arr.astype(np.int32, copy=False)
162
+ if arr.dtype.kind in ("U", "S", "O"):
163
+ up = np.char.upper(arr.astype("U1"))
164
+ out = np.full(up.shape, -1, dtype=np.int32)
165
+ out[up == "A"] = 0
166
+ out[up == "C"] = 1
167
+ out[up == "G"] = 2
168
+ out[up == "T"] = 3
169
+ return out
170
+ return None
171
+
172
+ def _ref_codes_from_genotype_data(self) -> np.ndarray:
173
+ """Fetch per-locus reference base from genotype_data.ref as 0..3.
174
+
175
+ Returns:
176
+ np.ndarray: Array of shape (n_features,) with values in {0,1,2,3}.
177
+ """
178
+ ref_raw = getattr(self.genotype_data, "ref", None)
179
+ ref_codes = self._map_base_to_int(ref_raw)
180
+ if ref_codes is None or ref_codes.shape[0] != self.num_features_:
181
+ msg = (
182
+ "genotype_data.ref missing or wrong length; "
183
+ f"expected ({self.num_features_},) got "
184
+ f"{None if ref_codes is None else ref_codes.shape}"
185
+ )
186
+ raise ValueError(msg)
187
+ return ref_codes
188
+
189
+ def _alt_mask_from_genotype_data(self) -> np.ndarray:
190
+ """Build a per-locus mask of which bases are ALT (supports multi-alt).
191
+
192
+ Returns:
193
+ np.ndarray: Boolean array of shape (n_features, 4) indicating presence of A,C,G,T as ALT.
194
+ """
195
+ nF = self.num_features_
196
+ alt_raw = getattr(self.genotype_data, "alt", None)
197
+ alt_mask = np.zeros((nF, 4), dtype=bool)
198
+ if alt_raw is None:
199
+ return alt_mask
200
+
201
+ alt_arr = np.asarray(alt_raw, dtype=object)
202
+ if alt_arr.shape[0] != nF and self.verbose:
203
+ print(
204
+ f"[warn] genotype_data.alt length {alt_arr.shape[0]} != n_features {nF}; truncating."
205
+ )
206
+
207
+ def add_code(mask_row, x):
208
+ if x is None:
209
+ return
210
+ if isinstance(x, (int, np.integer)):
211
+ v = int(x)
212
+ if 0 <= v <= 3:
213
+ mask_row[v] = True
214
+ return
215
+ if isinstance(x, str):
216
+ s = x.strip().upper()
217
+ if not s:
218
+ return
219
+ if "," in s:
220
+ for token in s.split(","):
221
+ add_code(mask_row, token.strip())
222
+ return
223
+ if s in ("A", "C", "G", "T"):
224
+ idx = {"A": 0, "C": 1, "G": 2, "T": 3}[s]
225
+ mask_row[idx] = True
226
+ return
227
+ if isinstance(x, (list, tuple, np.ndarray)):
228
+ for t in x:
229
+ add_code(mask_row, t)
230
+ return
231
+
232
+ for i in range(min(nF, alt_arr.shape[0])):
233
+ add_code(alt_mask[i], alt_arr[i])
234
+ return alt_mask
235
+
236
+ def _build_iupac_presence_lut(self) -> np.ndarray:
237
+ """Create LUT mapping integer codes {0..9} -> allele presence over A,C,G,T.
238
+
239
+ Returns:
240
+ np.ndarray: Boolean array of shape (10,4) indicating presence of A,C,G,T for each IUPAC code.
241
+ """
242
+ lut = np.zeros((10, 4), dtype=bool) # A,C,G,T
243
+ lut[0, 0] = True
244
+ lut[1, 3] = True
245
+ lut[2, 2] = True
246
+ lut[3, 1] = True # A T G C
247
+ lut[4, [0, 3]] = True
248
+ lut[5, [0, 2]] = True
249
+ lut[6, [0, 1]] = True # W R M
250
+ lut[7, [2, 3]] = True
251
+ lut[8, [1, 3]] = True
252
+ lut[9, [1, 2]] = True # K Y S
253
+ return lut
254
+
255
+ # -----------------------
256
+ # Fit / Transform
257
+ # -----------------------
258
+ def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
259
+ """Create row-wise train/test split according to init settings.
260
+
261
+ Returns:
262
+ Tuple[np.ndarray, np.ndarray]: (train indices, test indices) as integer arrays.
263
+ """
264
+ n = self.X_.shape[0]
265
+ all_idx = np.arange(n, dtype=int)
266
+
267
+ if self.test_indices is not None:
268
+ test_idx = np.unique(self.test_indices)
269
+ if np.any((test_idx < 0) | (test_idx >= n)):
270
+ raise IndexError("Some test_indices are out of bounds.")
271
+ train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
272
+ return train_idx, test_idx
273
+
274
+ # Random split (optionally stratified by population)
275
+ if (
276
+ self.by_populations
277
+ and self.stratify_by_populations
278
+ and (self.pops is not None)
279
+ ):
280
+ test_buckets = []
281
+ for pop in np.unique(self.pops):
282
+ pop_rows = np.where(self.pops == pop)[0]
283
+ k = int(round(self.test_size * pop_rows.size))
284
+ if k > 0:
285
+ chosen = self.rng.choice(pop_rows, size=k, replace=False)
286
+ test_buckets.append(chosen)
287
+ test_idx = (
288
+ np.sort(np.concatenate(test_buckets))
289
+ if test_buckets
290
+ else np.array([], dtype=int)
291
+ )
292
+ else:
293
+ k = int(round(self.test_size * n))
294
+ test_idx = (
295
+ self.rng.choice(n, size=k, replace=False)
296
+ if k > 0
297
+ else np.array([], dtype=int)
298
+ )
299
+
300
+ train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
301
+ return train_idx, test_idx
302
+
303
+ def fit(self) -> "ImputeAlleleFreq":
304
+ """Learn per-locus distributions on TRAIN rows; simulate missingness on TEST rows.
305
+
306
+ Notes:
307
+ The general workflow is:
308
+ 1) Split rows into train/test.
309
+ 2) Build distributions from TRAIN only.
310
+ 3) Simulate missingness on TEST only (stores `sim_mask_`).
311
+ 4) Cache ref/alt and IUPAC LUT.
312
+ """
313
+ # 0) Row split
314
+ self.train_idx_, self.test_idx_ = self._make_train_test_split()
315
+
316
+ self.ground_truths_ = self.X_.copy()
317
+ df_all = pd.DataFrame(self.ground_truths_, dtype=np.float32)
318
+ df_all.replace(self.missing, np.nan, inplace=True)
319
+
320
+ # 1) TRAIN-only distributions (no simulated holes here)
321
+ df_train = df_all.iloc[self.train_idx_].copy()
322
+
323
+ self.global_dist_ = {
324
+ col: self._series_distribution(df_train[col]) for col in df_train.columns
325
+ }
326
+
327
+ self.group_dist_.clear()
328
+ if self.by_populations:
329
+ tmp = df_train.copy()
330
+
331
+ if self.pops is not None:
332
+ tmp["_pops_"] = self.pops[self.train_idx_]
333
+ for pop, grp in tmp.groupby("_pops_"):
334
+ pop_key = str(pop)
335
+ gdf = grp.drop(columns=["_pops_"])
336
+ self.group_dist_[pop_key] = {
337
+ col: self._series_distribution(gdf[col]) for col in gdf.columns
338
+ }
339
+
340
+ # 2) TEST-only simulated missingness
341
+ obs_mask = df_all.notna().to_numpy()
342
+ sim_mask = np.zeros_like(obs_mask, dtype=bool)
343
+ if self.test_idx_.size > 0 and self.sim_prop_missing > 0.0:
344
+ # restrict candidate coords to TEST rows that are observed
345
+ coords = np.argwhere(obs_mask)
346
+ test_row_mask = np.zeros(obs_mask.shape[0], dtype=bool)
347
+ test_row_mask[self.test_idx_] = True
348
+ coords_test = coords[test_row_mask[coords[:, 0]]]
349
+ total_obs_test = coords_test.shape[0]
350
+ if total_obs_test > 0:
351
+ n_to_mask = int(round(self.sim_prop_missing * total_obs_test))
352
+ if n_to_mask > 0:
353
+ choice_idx = self.rng.choice(
354
+ total_obs_test, size=n_to_mask, replace=False
355
+ )
356
+ chosen_coords = coords_test[choice_idx]
357
+ sim_mask[chosen_coords[:, 0], chosen_coords[:, 1]] = True
358
+
359
+ df_sim = df_all.copy()
360
+ df_sim.values[sim_mask] = np.nan
361
+
362
+ # Store matrix to be imputed (train rows intact; test rows with simulated NaNs)
363
+ self.sim_mask_ = sim_mask
364
+ self.X_train_ = df_sim
365
+
366
+ # 3) Cache VCF ref/alt + IUPAC LUT once
367
+ self.ref_codes_ = self._ref_codes_from_genotype_data() # (nF,)
368
+ self.alt_mask_ = self._alt_mask_from_genotype_data() # (nF,4)
369
+ self._iupac_presence_lut_ = self._build_iupac_presence_lut()
370
+
371
+ self.is_fit_ = True
372
+ if self.verbose:
373
+ n_masked = int(sim_mask.sum())
374
+ print(
375
+ f"Fit complete. Train rows: {self.train_idx_.size}, Test rows: {self.test_idx_.size}."
376
+ )
377
+ print(
378
+ f"Simulated {n_masked} missing values (TEST rows only) for evaluation."
379
+ )
380
+ return self
381
+
382
+ def transform(self) -> np.ndarray:
383
+ """Impute the matrix and evaluate on the TEST-set simulated cells.
384
+
385
+ Returns:
386
+ np.ndarray: Imputed genotypes in the original (IUPAC/int) encoding.
387
+ """
388
+ if not self.is_fit_:
389
+ msg = "Model is not fitted. Call `fit()` before `transform()`."
390
+ self.logger.error(msg)
391
+ raise NotFittedError(msg)
392
+ assert (
393
+ self.X_train_ is not None
394
+ and self.sim_mask_ is not None
395
+ and self.ground_truths_ is not None
396
+ )
397
+
398
+ # ---- Impute using train-learned distributions ----
399
+ # uses self.global_dist_/group_dist_
400
+ imputed_df = self._impute_df(self.X_train_)
401
+ X_imp = imputed_df.to_numpy(dtype=np.int8)
402
+ self.X_imputed_ = X_imp
403
+
404
+ # -------------------------------------------------------------------
405
+ # Test-set evaluation on IUPAC/int codes (mask is test-only by design)
406
+ # -------------------------------------------------------------------
407
+ sim_mask = self.sim_mask_
408
+ y_true = self.ground_truths_[sim_mask]
409
+ y_pred = X_imp[sim_mask]
410
+
411
+ if y_true.size > 0:
412
+ labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
413
+ self.metrics_ = {
414
+ "n_masked_test": int(y_true.size),
415
+ "accuracy": float(accuracy_score(y_true, y_pred)),
416
+ "f1": float(
417
+ f1_score(
418
+ y_true, y_pred, average="macro", labels=labels, zero_division=0
419
+ )
420
+ ),
421
+ "precision": float(
422
+ precision_score(
423
+ y_true, y_pred, average="macro", labels=labels, zero_division=0
424
+ )
425
+ ),
426
+ "recall": float(
427
+ recall_score(
428
+ y_true, y_pred, average="macro", labels=labels, zero_division=0
429
+ )
430
+ ),
431
+ }
432
+ if self.verbose:
433
+ print("\n--- TEST-only Evaluation (IUPAC/int) ---")
434
+ for k, v in self.metrics_.items():
435
+ print(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")
436
+ print("\nClassification Report (IUPAC/int, TEST-only):")
437
+ print(
438
+ classification_report(
439
+ y_true, y_pred, labels=labels, zero_division=0
440
+ )
441
+ )
442
+ else:
443
+ self.metrics_.update({"n_masked_test": 0})
444
+ if self.verbose:
445
+ print("No TEST cells were held out for evaluation (n_masked_test=0).")
446
+
447
+ # Optional confusion matrix (IUPAC/int)
448
+ labels_map = {
449
+ "A": 0,
450
+ "T": 1,
451
+ "G": 2,
452
+ "C": 3,
453
+ "W": 4,
454
+ "R": 5,
455
+ "M": 6,
456
+ "K": 7,
457
+ "Y": 8,
458
+ "S": 9,
459
+ "N": -9,
460
+ }
461
+ self.plotter.plot_confusion_matrix(y_true, y_pred, label_names=labels_map)
462
+
463
+ # ----------------------------------------------------------------
464
+ # TEST-only Zygosity 0/1/2 evaluation using VCF ref/alt (hom-ref/het/hom-alt)
465
+ # ----------------------------------------------------------------
466
+ r_idx, f_idx = np.nonzero(sim_mask)
467
+ if r_idx.size > 0:
468
+ true_codes = self.ground_truths_[r_idx, f_idx].astype(np.int16, copy=False)
469
+ pred_codes = X_imp[r_idx, f_idx].astype(np.int16, copy=False)
470
+
471
+ keep_nm = (true_codes != self.missing) & (pred_codes != self.missing)
472
+ if np.any(keep_nm):
473
+ true_codes = true_codes[keep_nm]
474
+ pred_codes = pred_codes[keep_nm]
475
+ f_k = f_idx[keep_nm]
476
+
477
+ if self.ref_codes_ is None or self.alt_mask_ is None:
478
+ raise RuntimeError(
479
+ "VCF ref/alt codes not cached; cannot compute zygosity metrics."
480
+ )
481
+
482
+ ref_k = self.ref_codes_[f_k] # (n,)
483
+ alt_rows = self.alt_mask_[f_k, :] # (n,4)
484
+ ra_mask = alt_rows.copy()
485
+ ra_mask[np.arange(ref_k.size), ref_k] = True
486
+
487
+ lut = self._iupac_presence_lut_
488
+
489
+ if lut is None:
490
+ raise RuntimeError(
491
+ "IUPAC presence LUT not cached; cannot compute zygosity metrics."
492
+ )
493
+ valid_true = (true_codes >= 0) & (true_codes < lut.shape[0])
494
+ valid_pred = (pred_codes >= 0) & (pred_codes < lut.shape[0])
495
+ keep_valid = valid_true & valid_pred
496
+
497
+ if np.any(keep_valid):
498
+ true_codes = true_codes[keep_valid]
499
+ pred_codes = pred_codes[keep_valid]
500
+ ref_k = ref_k[keep_valid]
501
+ ra_mask = ra_mask[keep_valid, :]
502
+
503
+ A_true = lut[true_codes] # (n,4)
504
+ A_pred = lut[pred_codes] # (n,4)
505
+
506
+ any_true = A_true.any(axis=1)
507
+ any_pred = A_pred.any(axis=1)
508
+ out_true = (A_true & ~ra_mask).any(axis=1)
509
+ out_pred = (A_pred & ~ra_mask).any(axis=1)
510
+ valid_rows = any_true & any_pred & (~out_true) & (~out_pred)
511
+
512
+ if np.any(valid_rows):
513
+ A_true = A_true[valid_rows]
514
+ A_pred = A_pred[valid_rows]
515
+ ref_kv = ref_k[valid_rows]
516
+ n = A_true.shape[0]
517
+ rows = np.arange(n, dtype=int)
518
+
519
+ cnt_true = A_true.sum(axis=1)
520
+ homref_true = (cnt_true == 1) & A_true[rows, ref_kv]
521
+ homalt_true = (cnt_true == 1) & (~A_true[rows, ref_kv])
522
+ y_true_3 = np.empty(n, dtype=np.int8)
523
+ y_true_3[homref_true] = 0
524
+ y_true_3[homalt_true] = 2
525
+ y_true_3[~(homref_true | homalt_true)] = 1
526
+
527
+ cnt_pred = A_pred.sum(axis=1)
528
+ homref_pred = (cnt_pred == 1) & A_pred[rows, ref_kv]
529
+ homalt_pred = (cnt_pred == 1) & (~A_pred[rows, ref_kv])
530
+ y_pred_3 = np.empty(n, dtype=np.int8)
531
+ y_pred_3[homref_pred] = 0
532
+ y_pred_3[homalt_pred] = 2
533
+ y_pred_3[~(homref_pred | homalt_pred)] = 1
534
+
535
+ labels_3 = [0, 1, 2]
536
+ self.metrics_.update(
537
+ {
538
+ "zyg_n_test": int(n),
539
+ "zyg_accuracy": float(
540
+ accuracy_score(y_true_3, y_pred_3)
541
+ ),
542
+ "zyg_f1": float(
543
+ f1_score(
544
+ y_true_3,
545
+ y_pred_3,
546
+ average="macro",
547
+ labels=labels_3,
548
+ zero_division=0,
549
+ )
550
+ ),
551
+ "zyg_precision": float(
552
+ precision_score(
553
+ y_true_3,
554
+ y_pred_3,
555
+ average="macro",
556
+ labels=labels_3,
557
+ zero_division=0,
558
+ )
559
+ ),
560
+ "zyg_recall": float(
561
+ recall_score(
562
+ y_true_3,
563
+ y_pred_3,
564
+ average="macro",
565
+ labels=labels_3,
566
+ zero_division=0,
567
+ )
568
+ ),
569
+ }
570
+ )
571
+ if self.verbose:
572
+ print(
573
+ "\n--- TEST-only Zygosity (0=hom-ref,1=het,2=hom-alt) ---"
574
+ )
575
+ for k in (
576
+ "zyg_n_test",
577
+ "zyg_accuracy",
578
+ "zyg_f1",
579
+ "zyg_precision",
580
+ "zyg_recall",
581
+ ):
582
+ v = self.metrics_[k]
583
+ print(
584
+ f" {k}: {v:.4f}"
585
+ if isinstance(v, float)
586
+ else f" {k}: {v}"
587
+ )
588
+ print("\nClassification Report (zyg, TEST-only):")
589
+ print(
590
+ classification_report(
591
+ y_true_3,
592
+ y_pred_3,
593
+ labels=labels_3,
594
+ target_names=["hom-ref", "het", "hom-alt"],
595
+ zero_division=0,
596
+ )
597
+ )
598
+ self.plotter.plot_confusion_matrix(
599
+ y_true_3,
600
+ y_pred_3,
601
+ label_names=["hom-ref", "het", "hom-alt"],
602
+ )
603
+ else:
604
+ if self.verbose:
605
+ print(
606
+ "[info] Zygosity TEST-only: no valid rows after RA filtering."
607
+ )
608
+ else:
609
+ if self.verbose:
610
+ print(
611
+ "[info] Zygosity TEST-only: no valid rows after code filtering."
612
+ )
613
+ else:
614
+ if self.verbose:
615
+ print(
616
+ "[info] Zygosity TEST-only: nothing to score (all masked entries missing)."
617
+ )
618
+ else:
619
+ if self.verbose:
620
+ print("[info] TEST-only evaluation: no masked coordinates found.")
621
+
622
+ return self.encoder.inverse_int_iupac(X_imp)
623
+
624
+ def fit_transform(self) -> np.ndarray:
625
+ """Convenience method that calls `fit()` then `transform()`."""
626
+ self.fit()
627
+ return self.transform()
628
+
629
+ # -----------------------
630
+ # Core frequency model
631
+ # -----------------------
632
+ def _safe_probs(self, probs: np.ndarray) -> np.ndarray:
633
+ """Ensure probs are non-negative and sum to 1; fallback to uniform if invalid.
634
+
635
+ Args:
636
+ probs: Array of non-negative values (not necessarily summing to 1).
637
+
638
+ Returns:
639
+ np.ndarray: Valid probability distribution summing to 1.
640
+ """
641
+ probs = np.asarray(probs, dtype=float)
642
+ probs[probs < 0] = 0
643
+ s = probs.sum()
644
+
645
+ if not np.isfinite(s) or s <= 0:
646
+ return np.full(probs.size, 1.0 / max(1, probs.size))
647
+ return probs / s
648
+
649
+ def _series_distribution(self, s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
650
+ """Compute empirical (states, probs) for one locus from observed integer codes.
651
+
652
+ Args:
653
+ s: One column (locus) as a pandas Series with NaNs for missing.
654
+
655
+ Returns:
656
+ Tuple[np.ndarray, np.ndarray]: (states, probs) sorted by state.
657
+ """
658
+ s_valid = s.dropna().astype(int)
659
+ if s_valid.empty:
660
+ return np.array([self.default], dtype=int), np.array([1.0])
661
+ freqs = s_valid.value_counts(normalize=True).sort_index()
662
+ states = freqs.index.to_numpy(dtype=int)
663
+ probs = self._safe_probs(freqs.to_numpy(dtype=float))
664
+ return states, probs
665
+
666
+ def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
667
+ """Impute NaNs in df_in using precomputed TRAIN distributions.
668
+
669
+ Args:
670
+ df_in: DataFrame with NaNs to impute.
671
+
672
+ Returns:
673
+ DataFrame with NaNs imputed.
674
+ """
675
+ return (
676
+ self._impute_global(df_in)
677
+ if not self.by_populations
678
+ else self._impute_by_population(df_in)
679
+ )
680
+
681
+ def _impute_global(self, df_in: pd.DataFrame) -> pd.DataFrame:
682
+ """Impute dataframe globally, preserving original sample order.
683
+
684
+ Args:
685
+ df_in: DataFrame with NaNs to impute.
686
+
687
+ Returns:
688
+ DataFrame with NaNs imputed.
689
+ """
690
+ df = df_in.copy()
691
+ for col in df.columns:
692
+ if not df[col].isnull().any():
693
+ continue
694
+ states, probs = self.global_dist_[col]
695
+ n_missing = int(df[col].isnull().sum())
696
+ samples = self.rng.choice(states, size=n_missing, p=probs)
697
+ df.loc[df[col].isnull(), col] = samples
698
+ return df.astype(np.int8)
699
+
700
+ def _impute_by_population(self, df_in: pd.DataFrame) -> pd.DataFrame:
701
+ """Impute dataframe by population, preserving original sample order.
702
+
703
+ Args:
704
+ df_in: DataFrame with NaNs to impute.
705
+
706
+ Returns:
707
+ DataFrame with NaNs imputed.
708
+ """
709
+ df = df_in.copy()
710
+ df["_pops_"] = getattr(self, "pops", None)
711
+ for pop, grp in df.groupby("_pops_"):
712
+ pop_key = str(pop)
713
+ grp_imputed = grp.copy()
714
+ per_pop_dist = self.group_dist_.get(pop_key, {})
715
+ for col in grp.columns:
716
+ if col == "_pops_":
717
+ continue
718
+ if not grp[col].isnull().any():
719
+ continue
720
+ states, probs = per_pop_dist.get(col, self.global_dist_[col])
721
+ n_missing = int(grp[col].isnull().sum())
722
+ samples = self.rng.choice(states, size=n_missing, p=probs)
723
+ grp_imputed.loc[grp_imputed[col].isnull(), col] = samples
724
+ df.update(grp_imputed)
725
+ return df.drop(columns=["_pops_"]).astype(np.int8)