pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: the registry flags this version of pg-sui for review.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/deterministic/imputers/allele_freq.py (new file)
@@ -0,0 +1,691 @@
# Standard library imports
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union

# Third-party imports
import numpy as np
import pandas as pd
from sklearn.exceptions import NotFittedError
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from snpio import GenotypeEncoder

# Local imports
from pgsui.utils.plotting import Plotting

# Type checking imports
if TYPE_CHECKING:
    from snpio.read_input.genotype_data import GenotypeData


class ImputeAlleleFreq:
    """Frequency-based imputer for integer-encoded categorical genotype data with test-only evaluation.

    This implementation imputes missing values by sampling from the empirical frequency distribution of observed integer codes (e.g., 0-9 for nucleotides). It supports a strict train/test protocol: distributions are learned on the train split, missingness is simulated only on the test split, and all metrics are computed exclusively on the test split.

    The algorithm is as follows:
        1. Split the dataset into train and test sets (row-wise).
        2. For each feature (column), compute the empirical frequency distribution of observed values in the train set.
        3. On the test set, simulate additional missing values by randomly masking a specified proportion of observed entries.
        4. Impute missing values in the test set by sampling from the train-learned distributions.
        5. Evaluate imputation accuracy using various metrics on the test set only.
    """

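A minimal usage sketch for the class above (editorial, not part of the diff). It assumes snpio's VCFReader and uses the example files shipped in this wheel; the exact reader arguments are assumptions and may differ:

    # Hypothetical usage; VCFReader argument names are assumptions.
    from snpio import VCFReader
    from pgsui.impute.deterministic.imputers.allele_freq import ImputeAlleleFreq

    gd = VCFReader(
        filename="pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz",
        popmapfile="pgsui/example_data/popmaps/phylogen_nomx.popmap",
    )
    imputer = ImputeAlleleFreq(gd, by_populations=True, seed=42, test_size=0.2)
    X_imputed = imputer.fit_transform()  # imputed matrix, IUPAC encoding
    print(imputer.metrics_)  # TEST-only accuracy, f1, precision, recall
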
    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        prefix: str = "pgsui",
        by_populations: bool = False,
        default: int = 0,
        missing: int = -9,
        verbose: bool = True,
        seed: Optional[int] = None,
        sim_prop_missing: float = 0.30,
        debug: bool = False,
        test_size: float = 0.2,
        test_indices: Optional[Sequence[int]] = None,
        stratify_by_populations: bool = False,
    ) -> None:
        """Initialize ImputeAlleleFreq.

        Args:
            genotype_data: Object with `.genotypes_int`, `.ref`, `.alt`, and optional `.populations`.
            prefix: Output prefix.
            by_populations: Learn separate distributions per population.
            default: Default genotype for cold-start loci.
            missing: Integer code for missing values in X_.
            verbose: Verbosity switch.
            seed: RNG seed.
            sim_prop_missing: Fraction of OBSERVED test cells to mask for evaluation.
            debug: Debug switch.
            test_size: Fraction of rows held out for test if `test_indices` not provided.
            test_indices: Explicit test row indices. Overrides `test_size` if given.
            stratify_by_populations: If True and populations are available, create a stratified test split per population.
        """
        self.genotype_data = genotype_data
        self.prefix = prefix
        self.by_populations = by_populations
        self.default = int(default)
        self.missing = int(missing)
        self.verbose = verbose
        self.sim_prop_missing = float(sim_prop_missing)
        self.debug = debug

        if not (0.0 <= self.sim_prop_missing <= 0.95):
            raise ValueError("sim_prop_missing must be in [0, 0.95].")

        self.rng = np.random.default_rng(seed)
        self.encoder = GenotypeEncoder(self.genotype_data)
        self.X_ = np.asarray(self.encoder.genotypes_int, dtype=np.int8)
        self.num_features_ = self.X_.shape[1]

        # --- split controls ---
        self.test_size = float(test_size)
        self.test_indices = (
            None if test_indices is None else np.asarray(test_indices, dtype=int)
        )

        self.stratify_by_populations = bool(stratify_by_populations)

        self.pops = None
        if self.by_populations:
            pops = getattr(self.genotype_data, "populations", None)
            if pops is None:
                raise TypeError(
                    "by_populations=True requires genotype_data.populations."
                )
            self.pops = np.asarray(pops)
            if len(self.pops) != self.X_.shape[0]:
                raise ValueError(
                    f"`populations` length ({len(self.pops)}) != number of samples ({self.X_.shape[0]})."
                )

        self.is_fit_: bool = False
        self.global_dist_: Dict[int, Tuple[np.ndarray, np.ndarray]] = {}
        self.group_dist_: Dict[
            Union[str, int], Dict[int, Tuple[np.ndarray, np.ndarray]]
        ] = {}
        self.sim_mask_: Optional[np.ndarray] = None
        self.train_idx_: Optional[np.ndarray] = None
        self.test_idx_: Optional[np.ndarray] = None
        self.X_train_: Optional[pd.DataFrame] = None
        self.ground_truths_: Optional[np.ndarray] = None
        self.metrics_: Dict[str, float] = {}
        self.X_imputed_: Optional[np.ndarray] = None

        # VCF ref/alt cache + IUPAC LUT
        self.ref_codes_: Optional[np.ndarray] = None  # (nF,) in {0..3}
        self.alt_mask_: Optional[np.ndarray] = None  # (nF,4) bool
        self._iupac_presence_lut_: Optional[np.ndarray] = None  # (10,4) bool

        self.plotter = Plotting(
            "ImputeAlleleFreq",
            prefix=self.prefix,
            plot_format=genotype_data.plot_format,
            plot_fontsize=genotype_data.plot_fontsize,
            plot_dpi=genotype_data.plot_dpi,
            title_fontsize=genotype_data.plot_fontsize,
            despine=genotype_data.plot_despine,
            show_plots=genotype_data.show_plots,
            verbose=self.verbose,
            debug=self.debug,
        )

    # ------------------------------------------
    # Helpers for VCF ref/alt and IUPAC mapping
    # ------------------------------------------
    def _map_base_to_int(self, arr) -> Optional[np.ndarray]:
        """Map bases A/C/G/T -> 0/1/2/3; pass integers through; others -> -1.

        Args:
            arr: Array-like of bases (str) or integer codes.

        Returns:
            Optional[np.ndarray]: Mapped integer array, or None if input is None or invalid.
        """
        if arr is None:
            return None
        arr = np.asarray(arr)
        if arr.dtype.kind in ("i", "u"):
            return arr.astype(np.int32, copy=False)
        if arr.dtype.kind in ("U", "S", "O"):
            up = np.char.upper(arr.astype("U1"))
            out = np.full(up.shape, -1, dtype=np.int32)
            out[up == "A"] = 0
            out[up == "C"] = 1
            out[up == "G"] = 2
            out[up == "T"] = 3
            return out
        return None

    def _ref_codes_from_genotype_data(self) -> np.ndarray:
        """Fetch per-locus reference base from genotype_data.ref as 0..3.

        Returns:
            np.ndarray: Array of shape (n_features,) with values in {0,1,2,3}.
        """
        ref_raw = getattr(self.genotype_data, "ref", None)
        ref_codes = self._map_base_to_int(ref_raw)
        if ref_codes is None or ref_codes.shape[0] != self.num_features_:
            msg = (
                "genotype_data.ref missing or wrong length; "
                f"expected ({self.num_features_},) got "
                f"{None if ref_codes is None else ref_codes.shape}"
            )
            raise ValueError(msg)
        if np.any((ref_codes < 0) | (ref_codes > 3)):
            # Guard downstream fancy indexing (ra_mask[rows, ref_k]) against
            # non-ACGT reference bases, which _map_base_to_int encodes as -1.
            raise ValueError("genotype_data.ref contains non-ACGT bases.")
        return ref_codes

    def _alt_mask_from_genotype_data(self) -> np.ndarray:
        """Build a per-locus mask of which bases are ALT (supports multi-alt).

        Returns:
            np.ndarray: Boolean array of shape (n_features, 4) indicating presence of A,C,G,T as ALT.
        """
        nF = self.num_features_
        alt_raw = getattr(self.genotype_data, "alt", None)
        alt_mask = np.zeros((nF, 4), dtype=bool)
        if alt_raw is None:
            return alt_mask

        alt_arr = np.asarray(alt_raw, dtype=object)
        if alt_arr.shape[0] != nF and self.verbose:
            print(
                f"[warn] genotype_data.alt length {alt_arr.shape[0]} != n_features {nF}; truncating."
            )

        def add_code(mask_row, x):
            # Recursively mark ALT alleles from ints, single bases,
            # comma-separated strings, or nested sequences.
            if x is None:
                return
            if isinstance(x, (int, np.integer)):
                v = int(x)
                if 0 <= v <= 3:
                    mask_row[v] = True
                return
            if isinstance(x, str):
                s = x.strip().upper()
                if not s:
                    return
                if "," in s:
                    for token in s.split(","):
                        add_code(mask_row, token.strip())
                    return
                if s in ("A", "C", "G", "T"):
                    idx = {"A": 0, "C": 1, "G": 2, "T": 3}[s]
                    mask_row[idx] = True
                return
            if isinstance(x, (list, tuple, np.ndarray)):
                for t in x:
                    add_code(mask_row, t)
                return

        for i in range(min(nF, alt_arr.shape[0])):
            add_code(alt_mask[i], alt_arr[i])
        return alt_mask

    def _build_iupac_presence_lut(self) -> np.ndarray:
        """Create LUT mapping integer codes {0..9} -> allele presence over A,C,G,T.

        Returns:
            np.ndarray: Boolean array of shape (10,4) indicating presence of A,C,G,T for each IUPAC code.
        """
        lut = np.zeros((10, 4), dtype=bool)  # columns: A, C, G, T
        lut[0, 0] = True  # 0 = A
        lut[1, 3] = True  # 1 = T
        lut[2, 2] = True  # 2 = G
        lut[3, 1] = True  # 3 = C
        lut[4, [0, 3]] = True  # 4 = W (A/T)
        lut[5, [0, 2]] = True  # 5 = R (A/G)
        lut[6, [0, 1]] = True  # 6 = M (A/C)
        lut[7, [2, 3]] = True  # 7 = K (G/T)
        lut[8, [1, 3]] = True  # 8 = Y (C/T)
        lut[9, [1, 2]] = True  # 9 = S (C/G)
        return lut
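
To make the encoding concrete, here is a small worked example (editorial sketch) of how this LUT combines with a locus's REF/ALT to produce the 0/1/2 zygosity classes used later in transform(); values are illustrative only:

    # Toy zygosity derivation from an integer IUPAC code.
    import numpy as np

    lut = np.zeros((10, 4), dtype=bool)  # columns: A, C, G, T
    lut[0, 0] = lut[1, 3] = lut[2, 2] = lut[3, 1] = True  # A, T, G, C
    lut[4, [0, 3]] = lut[5, [0, 2]] = lut[6, [0, 1]] = True  # W, R, M
    lut[7, [2, 3]] = lut[8, [1, 3]] = lut[9, [1, 2]] = True  # K, Y, S

    code, ref = 5, 0  # genotype "R" (A/G) at a locus with REF=A, ALT=G
    alleles = lut[code]  # -> [True, False, True, False]
    if alleles.sum() == 1 and alleles[ref]:
        zyg = 0  # hom-ref
    elif alleles.sum() == 1:
        zyg = 2  # hom-alt
    else:
        zyg = 1  # het (this example)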

    # -----------------------
    # Fit / Transform
    # -----------------------
    def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
        """Create row-wise train/test split according to init settings.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (train indices, test indices) as integer arrays.
        """
        n = self.X_.shape[0]
        all_idx = np.arange(n, dtype=int)

        if self.test_indices is not None:
            test_idx = np.unique(self.test_indices)
            if np.any((test_idx < 0) | (test_idx >= n)):
                raise IndexError("Some test_indices are out of bounds.")
            train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
            return train_idx, test_idx

        # Random split (optionally stratified by population)
        if (
            self.by_populations
            and self.stratify_by_populations
            and (self.pops is not None)
        ):
            test_buckets = []
            for pop in np.unique(self.pops):
                pop_rows = np.where(self.pops == pop)[0]
                k = int(round(self.test_size * pop_rows.size))
                if k > 0:
                    chosen = self.rng.choice(pop_rows, size=k, replace=False)
                    test_buckets.append(chosen)
            test_idx = (
                np.sort(np.concatenate(test_buckets))
                if test_buckets
                else np.array([], dtype=int)
            )
        else:
            k = int(round(self.test_size * n))
            test_idx = (
                self.rng.choice(n, size=k, replace=False)
                if k > 0
                else np.array([], dtype=int)
            )

        train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
        return train_idx, test_idx
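
A toy illustration (editorial) of the stratified branch above: each population contributes round(test_size * n_pop) rows to the test set:

    # Stratified test split per population (toy example).
    import numpy as np

    rng = np.random.default_rng(0)
    pops = np.array(["A", "A", "A", "A", "B", "B"])
    test_size = 0.5
    test_buckets = []
    for pop in np.unique(pops):
        pop_rows = np.where(pops == pop)[0]
        k = int(round(test_size * pop_rows.size))  # 2 from A, 1 from B
        if k > 0:
            test_buckets.append(rng.choice(pop_rows, size=k, replace=False))
    test_idx = np.sort(np.concatenate(test_buckets))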

    def fit(self) -> "ImputeAlleleFreq":
        """Learn per-locus distributions on TRAIN rows; simulate missingness on TEST rows.

        Notes:
            The general workflow is:
                1) Split rows into train/test.
                2) Build distributions from TRAIN only.
                3) Simulate missingness on TEST only (stores `sim_mask_`).
                4) Cache ref/alt and IUPAC LUT.
        """
        # 0) Row split
        self.train_idx_, self.test_idx_ = self._make_train_test_split()

        self.ground_truths_ = self.X_.copy()
        df_all = pd.DataFrame(self.ground_truths_, dtype=np.float32)
        df_all.replace(self.missing, np.nan, inplace=True)

        # 1) TRAIN-only distributions (no simulated holes here)
        df_train = df_all.iloc[self.train_idx_].copy()

        self.global_dist_ = {
            col: self._series_distribution(df_train[col]) for col in df_train.columns
        }

        self.group_dist_.clear()
        if self.by_populations:
            tmp = df_train.copy()
            tmp["_pops_"] = self.pops[self.train_idx_]
            for pop, grp in tmp.groupby("_pops_"):
                gdf = grp.drop(columns=["_pops_"])
                self.group_dist_[pop] = {
                    col: self._series_distribution(gdf[col]) for col in gdf.columns
                }

        # 2) TEST-only simulated missingness
        obs_mask = df_all.notna().to_numpy()
        sim_mask = np.zeros_like(obs_mask, dtype=bool)
        if self.test_idx_.size > 0 and self.sim_prop_missing > 0.0:
            # restrict candidate coords to TEST rows that are observed
            coords = np.argwhere(obs_mask)
            test_row_mask = np.zeros(obs_mask.shape[0], dtype=bool)
            test_row_mask[self.test_idx_] = True
            coords_test = coords[test_row_mask[coords[:, 0]]]
            total_obs_test = coords_test.shape[0]
            if total_obs_test > 0:
                n_to_mask = int(round(self.sim_prop_missing * total_obs_test))
                if n_to_mask > 0:
                    choice_idx = self.rng.choice(
                        total_obs_test, size=n_to_mask, replace=False
                    )
                    chosen_coords = coords_test[choice_idx]
                    sim_mask[chosen_coords[:, 0], chosen_coords[:, 1]] = True

        # mask() avoids writing through .values, which is not guaranteed to
        # mutate the frame (and breaks under copy-on-write).
        df_sim = df_all.mask(sim_mask)

        # Store matrix to be imputed (train rows intact; test rows with simulated NaNs)
        self.sim_mask_ = sim_mask
        self.X_train_ = df_sim

        # 3) Cache VCF ref/alt + IUPAC LUT once
        self.ref_codes_ = self._ref_codes_from_genotype_data()  # (nF,)
        self.alt_mask_ = self._alt_mask_from_genotype_data()  # (nF,4)
        self._iupac_presence_lut_ = self._build_iupac_presence_lut()

        self.is_fit_ = True
        if self.verbose:
            n_masked = int(sim_mask.sum())
            print(
                f"Fit complete. Train rows: {self.train_idx_.size}, Test rows: {self.test_idx_.size}."
            )
            print(
                f"Simulated {n_masked} missing values (TEST rows only) for evaluation."
            )
        return self

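The TEST-only masking in step 2 above is the core of the evaluation protocol. A toy-sized editorial sketch of the same logic:

    # Simulate missingness on TEST rows only (toy shapes).
    import numpy as np

    rng = np.random.default_rng(0)
    obs_mask = np.array(
        [[True, True, True], [True, False, True], [True, True, True]]
    )  # observed (non-missing) cells
    test_idx = np.array([1, 2])  # held-out rows

    coords = np.argwhere(obs_mask)  # (row, col) pairs of observed cells
    coords = coords[np.isin(coords[:, 0], test_idx)]  # TEST rows only
    n_to_mask = int(round(0.30 * coords.shape[0]))  # sim_prop_missing = 0.30
    chosen = coords[rng.choice(coords.shape[0], size=n_to_mask, replace=False)]

    sim_mask = np.zeros_like(obs_mask)
    sim_mask[chosen[:, 0], chosen[:, 1]] = True  # cells hidden, later scored
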
    def transform(self) -> np.ndarray:
        """Impute the matrix and evaluate on the TEST-set simulated cells.

        Returns:
            np.ndarray: Imputed genotypes in the original (IUPAC/int) encoding.
        """
        if not self.is_fit_:
            # No logger is configured on this class; raise directly.
            msg = "Model is not fitted. Call `fit()` before `transform()`."
            raise NotFittedError(msg)
        assert (
            self.X_train_ is not None
            and self.sim_mask_ is not None
            and self.ground_truths_ is not None
        )

        # ---- Impute using train-learned distributions ----
        # uses self.global_dist_/group_dist_
        imputed_df = self._impute_df(self.X_train_)
        X_imp = imputed_df.to_numpy(dtype=np.int8)
        self.X_imputed_ = X_imp

        # -------------------------------------------------------------------
        # Test-set evaluation on IUPAC/int codes (mask is test-only by design)
        # -------------------------------------------------------------------
        sim_mask = self.sim_mask_
        y_true = self.ground_truths_[sim_mask]
        y_pred = X_imp[sim_mask]

        if y_true.size > 0:
            labels = sorted(np.unique(np.concatenate([y_true, y_pred])))
            self.metrics_ = {
                "n_masked_test": int(y_true.size),
                "accuracy": accuracy_score(y_true, y_pred),
                "f1": f1_score(
                    y_true, y_pred, average="macro", labels=labels, zero_division=0
                ),
                "precision": precision_score(
                    y_true, y_pred, average="macro", labels=labels, zero_division=0
                ),
                "recall": recall_score(
                    y_true, y_pred, average="macro", labels=labels, zero_division=0
                ),
            }
            if self.verbose:
                print("\n--- TEST-only Evaluation (IUPAC/int) ---")
                for k, v in self.metrics_.items():
                    print(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")
                print("\nClassification Report (IUPAC/int, TEST-only):")
                print(
                    classification_report(
                        y_true, y_pred, labels=labels, zero_division=0
                    )
                )
        else:
            self.metrics_.update({"n_masked_test": 0})
            if self.verbose:
                print("No TEST cells were held out for evaluation (n_masked_test=0).")

        # Optional confusion matrix (IUPAC/int)
        labels_map = {
            "A": 0,
            "T": 1,
            "G": 2,
            "C": 3,
            "W": 4,
            "R": 5,
            "M": 6,
            "K": 7,
            "Y": 8,
            "S": 9,
            "N": -9,
        }
        if y_true.size > 0:
            # Skip plotting when nothing was masked.
            self.plotter.plot_confusion_matrix(y_true, y_pred, label_names=labels_map)

        # ----------------------------------------------------------------
        # TEST-only zygosity 0/1/2 evaluation using VCF ref/alt (hom-ref/het/hom-alt)
        # ----------------------------------------------------------------
        r_idx, f_idx = np.nonzero(sim_mask)
        if r_idx.size > 0:
            true_codes = self.ground_truths_[r_idx, f_idx].astype(np.int16, copy=False)
            pred_codes = X_imp[r_idx, f_idx].astype(np.int16, copy=False)

            keep_nm = (true_codes != self.missing) & (pred_codes != self.missing)
            if np.any(keep_nm):
                true_codes = true_codes[keep_nm]
                pred_codes = pred_codes[keep_nm]
                f_k = f_idx[keep_nm]

                ref_k = self.ref_codes_[f_k]  # (n,)
                alt_rows = self.alt_mask_[f_k, :]  # (n,4)
                ra_mask = alt_rows.copy()
                ra_mask[np.arange(ref_k.size), ref_k] = True

                lut = self._iupac_presence_lut_
                valid_true = (true_codes >= 0) & (true_codes < lut.shape[0])
                valid_pred = (pred_codes >= 0) & (pred_codes < lut.shape[0])
                keep_valid = valid_true & valid_pred

                if np.any(keep_valid):
                    true_codes = true_codes[keep_valid]
                    pred_codes = pred_codes[keep_valid]
                    ref_k = ref_k[keep_valid]
                    ra_mask = ra_mask[keep_valid, :]

                    A_true = lut[true_codes]  # (n,4)
                    A_pred = lut[pred_codes]  # (n,4)

                    any_true = A_true.any(axis=1)
                    any_pred = A_pred.any(axis=1)
                    out_true = (A_true & ~ra_mask).any(axis=1)
                    out_pred = (A_pred & ~ra_mask).any(axis=1)
                    valid_rows = any_true & any_pred & (~out_true) & (~out_pred)

                    if np.any(valid_rows):
                        A_true = A_true[valid_rows]
                        A_pred = A_pred[valid_rows]
                        ref_kv = ref_k[valid_rows]
                        n = A_true.shape[0]
                        rows = np.arange(n, dtype=int)

                        cnt_true = A_true.sum(axis=1)
                        homref_true = (cnt_true == 1) & A_true[rows, ref_kv]
                        homalt_true = (cnt_true == 1) & (~A_true[rows, ref_kv])
                        y_true_3 = np.empty(n, dtype=np.int8)
                        y_true_3[homref_true] = 0
                        y_true_3[homalt_true] = 2
                        y_true_3[~(homref_true | homalt_true)] = 1

                        cnt_pred = A_pred.sum(axis=1)
                        homref_pred = (cnt_pred == 1) & A_pred[rows, ref_kv]
                        homalt_pred = (cnt_pred == 1) & (~A_pred[rows, ref_kv])
                        y_pred_3 = np.empty(n, dtype=np.int8)
                        y_pred_3[homref_pred] = 0
                        y_pred_3[homalt_pred] = 2
                        y_pred_3[~(homref_pred | homalt_pred)] = 1

                        labels_3 = [0, 1, 2]
                        self.metrics_.update(
                            {
                                "zyg_n_test": int(n),
                                "zyg_accuracy": accuracy_score(y_true_3, y_pred_3),
                                "zyg_f1": f1_score(
                                    y_true_3,
                                    y_pred_3,
                                    average="macro",
                                    labels=labels_3,
                                    zero_division=0,
                                ),
                                "zyg_precision": precision_score(
                                    y_true_3,
                                    y_pred_3,
                                    average="macro",
                                    labels=labels_3,
                                    zero_division=0,
                                ),
                                "zyg_recall": recall_score(
                                    y_true_3,
                                    y_pred_3,
                                    average="macro",
                                    labels=labels_3,
                                    zero_division=0,
                                ),
                            }
                        )
                        if self.verbose:
                            print(
                                "\n--- TEST-only Zygosity (0=hom-ref, 1=het, 2=hom-alt) ---"
                            )
                            for k in (
                                "zyg_n_test",
                                "zyg_accuracy",
                                "zyg_f1",
                                "zyg_precision",
                                "zyg_recall",
                            ):
                                v = self.metrics_[k]
                                print(
                                    f"  {k}: {v:.4f}"
                                    if isinstance(v, float)
                                    else f"  {k}: {v}"
                                )
                            print("\nClassification Report (zyg, TEST-only):")
                            print(
                                classification_report(
                                    y_true_3,
                                    y_pred_3,
                                    labels=labels_3,
                                    target_names=["hom-ref", "het", "hom-alt"],
                                    zero_division=0,
                                )
                            )
                        self.plotter.plot_confusion_matrix(
                            y_true_3,
                            y_pred_3,
                            label_names=["hom-ref", "het", "hom-alt"],
                        )
                    else:
                        if self.verbose:
                            print(
                                "[info] Zygosity TEST-only: no valid rows after RA filtering."
                            )
                else:
                    if self.verbose:
                        print(
                            "[info] Zygosity TEST-only: no valid rows after code filtering."
                        )
            else:
                if self.verbose:
                    print(
                        "[info] Zygosity TEST-only: nothing to score (all masked entries missing)."
                    )
        else:
            if self.verbose:
                print("[info] TEST-only evaluation: no masked coordinates found.")

        return self.encoder.inverse_int_iupac(X_imp)

    def fit_transform(self) -> np.ndarray:
        """Convenience method that calls `fit()` then `transform()`."""
        self.fit()
        return self.transform()

    # -----------------------
    # Core frequency model
    # -----------------------
    def _safe_probs(self, probs: np.ndarray) -> np.ndarray:
        """Ensure probs are non-negative and sum to 1; fall back to uniform if invalid.

        Args:
            probs: Array of non-negative values (not necessarily summing to 1).

        Returns:
            np.ndarray: Valid probability distribution summing to 1.
        """
        probs = np.asarray(probs, dtype=float)
        probs[probs < 0] = 0
        s = probs.sum()

        if not np.isfinite(s) or s <= 0:
            return np.full(probs.size, 1.0 / max(1, probs.size))
        return probs / s

    def _series_distribution(self, s: pd.Series) -> Tuple[np.ndarray, np.ndarray]:
        """Compute empirical (states, probs) for one locus from observed integer codes.

        Args:
            s: One column (locus) as a pandas Series with NaNs for missing.

        Returns:
            Tuple[np.ndarray, np.ndarray]: (states, probs) sorted by state.
        """
        s_valid = s.dropna().astype(int)
        if s_valid.empty:
            return np.array([self.default], dtype=int), np.array([1.0])
        freqs = s_valid.value_counts(normalize=True).sort_index()
        states = freqs.index.to_numpy(dtype=int)
        probs = self._safe_probs(freqs.to_numpy(dtype=float))
        return states, probs
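
For a concrete picture of what _series_distribution() returns, a toy column (editorial sketch):

    # One locus with five observed calls and one missing value.
    import numpy as np
    import pandas as pd

    s = pd.Series([0, 0, 2, np.nan, 2, 2])
    freqs = s.dropna().astype(int).value_counts(normalize=True).sort_index()
    states = freqs.index.to_numpy(dtype=int)  # -> array([0, 2])
    probs = freqs.to_numpy(dtype=float)  # -> array([0.4, 0.6])
    # Missing cells at this locus are then drawn via rng.choice(states, p=probs).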

    def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute NaNs in df_in using precomputed TRAIN distributions.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed.
        """
        return (
            self._impute_global(df_in)
            if not self.by_populations
            else self._impute_by_population(df_in)
        )

    def _impute_global(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute dataframe globally, preserving original sample order.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed.
        """
        df = df_in.copy()
        for col in df.columns:
            if not df[col].isnull().any():
                continue
            states, probs = self.global_dist_[col]
            n_missing = int(df[col].isnull().sum())
            samples = self.rng.choice(states, size=n_missing, p=probs)
            df.loc[df[col].isnull(), col] = samples
        return df.astype(np.int8)

    def _impute_by_population(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute dataframe by population, preserving original sample order.

        Args:
            df_in: DataFrame with NaNs to impute.

        Returns:
            DataFrame with NaNs imputed.
        """
        df = df_in.copy()
        df["_pops_"] = getattr(self, "pops", None)
        for pop, grp in df.groupby("_pops_"):
            grp_imputed = grp.copy()
            per_pop_dist = self.group_dist_.get(pop, {})
            for col in grp.columns:
                if col == "_pops_":
                    continue
                if not grp[col].isnull().any():
                    continue
                states, probs = per_pop_dist.get(col, self.global_dist_[col])
                n_missing = int(grp[col].isnull().sum())
                samples = self.rng.choice(states, size=n_missing, p=probs)
                grp_imputed.loc[grp_imputed[col].isnull(), col] = samples
            df.update(grp_imputed)
        return df.drop(columns=["_pops_"]).astype(np.int8)