pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
@@ -0,0 +1,844 @@
1
+ # Standard library imports
2
+ import json
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
5
+
6
+ # Third-party imports
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from matplotlib.figure import Figure
11
+ from plotly.graph_objs import Figure as PlotlyFigure
12
+ from sklearn.exceptions import NotFittedError
13
+ from sklearn.metrics import (
14
+ accuracy_score,
15
+ classification_report,
16
+ f1_score,
17
+ precision_score,
18
+ recall_score,
19
+ )
20
+ from snpio import GenotypeEncoder
21
+ from snpio.utils.logging import LoggerManager
22
+
23
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
24
+ from pgsui.data_processing.containers import MostFrequentConfig
25
+ from pgsui.data_processing.transformers import SimMissingTransformer
26
+ from pgsui.utils.classification_viz import ClassificationReportVisualizer
27
+ from pgsui.utils.logging_utils import configure_logger
28
+
29
+ # Local imports
30
+ from pgsui.utils.plotting import Plotting
31
+ from pgsui.utils.pretty_metrics import PrettyMetrics
32
+
33
+ # Type checking imports
34
+ if TYPE_CHECKING:
35
+ from snpio import TreeParser
36
+ from snpio.read_input.genotype_data import GenotypeData
37
+
38
+
39
def ensure_mostfrequent_config(
    config: Union[MostFrequentConfig, dict, str, None],
) -> MostFrequentConfig:
    """Return a concrete MostFrequentConfig (dataclass, dict, YAML path, or None).

    Args:
        config (Union[MostFrequentConfig, dict, str, None]): The configuration to
            ensure is a MostFrequentConfig.
            - ``None`` -> ``MostFrequentConfig()`` defaults.
            - ``MostFrequentConfig`` -> returned unchanged (same object).
            - ``str`` or ``pathlib.Path`` -> loaded with ``load_yaml_to_dataclass``.
            - ``dict`` -> an optional top-level ``"preset"`` key selects
              ``MostFrequentConfig.from_preset(preset)`` as the base. The
              remaining (possibly nested) keys are flattened to dot-keys
              (e.g. ``{"algo": {"by_populations": True}}`` ->
              ``{"algo.by_populations": True}``) and applied with
              ``apply_dot_overrides``. The caller's dict is never modified.

    Returns:
        MostFrequentConfig: The ensured MostFrequentConfig.

    Raises:
        TypeError: If ``config`` is of any other type.
    """
    if config is None:
        return MostFrequentConfig()
    if isinstance(config, MostFrequentConfig):
        return config
    if isinstance(config, (str, Path)):
        return load_yaml_to_dataclass(str(config), MostFrequentConfig)
    if isinstance(config, dict):
        # Shallow copy so popping "preset" never mutates the caller's dict
        # (reusing the same dict would otherwise silently lose the preset).
        overrides = dict(config)
        preset = overrides.pop("preset", None)
        base = (
            MostFrequentConfig.from_preset(preset) if preset else MostFrequentConfig()
        )

        def _flatten(prefix: str, d: dict, out: dict) -> dict:
            # Recursively turn nested dicts into "a.b.c" dot-keys.
            for k, v in d.items():
                kk = f"{prefix}.{k}" if prefix else k
                if isinstance(v, dict):
                    _flatten(kk, v, out)
                else:
                    out[kk] = v
            return out

        flat = _flatten("", overrides, {})
        return apply_dot_overrides(base, flat)

    raise TypeError("config must be a MostFrequentConfig, dict, YAML path, or None.")
77
+
78
+
79
+ class ImputeMostFrequent:
80
+ """Most-frequent (mode) imputer that mirrors DL evaluation on 0/1/2.
81
+
82
+ This imputer computes the most frequent genotype (mode) for each locus based on the training set and uses it to fill in missing values. It supports both global modes and population-specific modes if population data is provided. The imputer follows an evaluation protocol similar to deep learning models, including splitting the data into training and testing sets, masking observed cells in the test set for evaluation, and producing detailed classification reports and plots. It handles both diploid and haploid data, with special considerations for haploid scenarios. The imputer is designed to work seamlessly with genotype data encoded in 0/1/2 format, where -1 indicates missing values.
83
+ """
84
+
85
def __init__(
    self,
    genotype_data: "GenotypeData",
    *,
    tree_parser: Optional["TreeParser"] = None,
    config: Optional[Union[MostFrequentConfig, dict, str]] = None,
    overrides: Optional[dict] = None,
    simulate_missing: bool = True,
    sim_strategy: Literal[
        "random",
        "random_weighted",
        "random_weighted_inv",
        "nonrandom",
        "nonrandom_weighted",
    ] = "random",
    sim_prop: float = 0.2,
    sim_kwargs: Optional[dict] = None,
) -> None:
    """Initialize the Most-Frequent (mode) imputer from a unified config.

    Normalizes the configuration and sets up logging, the RNG and the genotype
    encoder. It converts the genotypes to a 0/1/2 ``int16`` matrix with ``-1``
    for missing and resolves the simulated-missing settings. When
    ``algo.by_populations`` is enabled it validates the population labels.
    Finally it builds the plotting helper and creates the output directories.

    Args:
        genotype_data (GenotypeData): Backing genotype data.
        tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser.
            Required when the resolved ``sim_strategy`` starts with "nonrandom".
        config (MostFrequentConfig | dict | str | None): Configuration as a
            dataclass, nested dict, or YAML path. If None, defaults are used.
        overrides (dict | None): Flat dot-key overrides applied last with
            highest precedence, e.g. {'algo.by_populations': True,
            'split.test_size': 0.3}.
        simulate_missing (bool): Fallback used when the config has no ``sim``
            block (or that block lacks the field). Defaults to True.
        sim_strategy (Literal): Fallback used the same way as
            ``simulate_missing``.
        sim_prop (float): Fallback used the same way as ``simulate_missing``.
        sim_kwargs (Optional[dict]): Extra keyword arguments for the simulated
            missing data transformer. Merged with ``cfg.sim.sim_kwargs`` when
            present; config-level keys take precedence, like the other ``sim``
            fields.

    Raises:
        ValueError: If a "nonrandom*" strategy is requested without
            ``tree_parser``, or if the population labels do not match the
            number of samples.
        TypeError: If ``by_populations`` is enabled but
            ``genotype_data.populations`` is None.

    Notes:
        - This mirrors other config-driven models (AE/VAE/NLPCA/UBP).
        - Evaluation split behavior uses cfg.split; plotting uses cfg.plot.
        - I/O/logging seeds and verbosity use cfg.io.
    """
    # Normalize config then apply highest-precedence overrides
    cfg = ensure_mostfrequent_config(config)
    if overrides:
        cfg = apply_dot_overrides(cfg, overrides)
    self.cfg = cfg

    # Basic fields
    self.genotype_data = genotype_data
    self.tree_parser = tree_parser
    self.prefix = cfg.io.prefix
    self.verbose = cfg.io.verbose
    self.debug = cfg.io.debug

    # Declared here; assigned by _create_model_directories() below.
    self.parameters_dir: Path
    self.metrics_dir: Path
    self.plots_dir: Path
    self.models_dir: Path
    self.optimize_dir: Path

    # Logger
    logman = LoggerManager(
        __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
    )
    self.logger = configure_logger(
        logman.get_logger(), verbose=self.verbose, debug=self.debug
    )

    # RNG / encoder
    self.rng = np.random.default_rng(cfg.io.seed)
    self.encoder = GenotypeEncoder(self.genotype_data)

    # Work in 0/1/2 with -1 for missing (parity with DL modules).
    # Go through a float copy so NaNs are replaced *before* the integer cast
    # (casting NaN to an integer dtype is undefined), and so the in-place
    # normalization below never writes into the encoder's own array.
    X_float = np.array(self.encoder.genotypes_012, dtype=np.float64)
    X_float[np.isnan(X_float)] = -1.0
    X012 = X_float.astype(np.int16)
    X012[X012 < 0] = -1  # collapse any negative missing code (e.g. -9) to -1
    self.X012_ = X012
    self.num_features_ = X012.shape[1]

    # Simulated-missing controls (mirror VAE/AE/NLPCA semantics where possible)
    sim_cfg = getattr(self.cfg, "sim", None)

    self.simulate_missing: bool
    self.sim_strategy: str
    self.sim_prop: float
    self.sim_kwargs: dict

    # Constructor kwargs form the base; config-level kwargs override them.
    merged_sim_kwargs = dict(sim_kwargs or {})
    if sim_cfg is None:
        # Fallback defaults if MostFrequentConfig has no .sim block
        self.simulate_missing = simulate_missing
        self.sim_strategy = sim_strategy
        self.sim_prop = sim_prop
    else:
        self.simulate_missing = bool(
            getattr(sim_cfg, "simulate_missing", simulate_missing)
        )
        self.sim_strategy = getattr(sim_cfg, "sim_strategy", sim_strategy)
        self.sim_prop = float(getattr(sim_cfg, "sim_prop", sim_prop))
        merged_sim_kwargs.update(getattr(sim_cfg, "sim_kwargs", None) or {})
    self.sim_kwargs = merged_sim_kwargs

    # Validate before any filesystem side effects (directory creation below).
    if self.tree_parser is None and str(self.sim_strategy).startswith("nonrandom"):
        msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
        self.logger.error(msg)
        raise ValueError(msg)

    # Simulated-missing masks (global + test-only)
    self.sim_mask_global_: Optional[np.ndarray] = None  # shape (N, L), bool
    self.sim_mask_test_only_: Optional[np.ndarray] = None

    # Split & algo knobs
    self.test_size = float(cfg.split.test_size)
    self.test_indices = (
        None
        if cfg.split.test_indices is None
        else np.asarray(cfg.split.test_indices, dtype=int)
    )
    self.by_populations = bool(cfg.algo.by_populations)
    self.default = int(cfg.algo.default)
    self.missing = int(cfg.algo.missing)

    # Populations (if requested)
    self.pops = None
    if self.by_populations:
        pops = getattr(self.genotype_data, "populations", None)
        if pops is None:
            msg = "by_populations=True requires genotype_data.populations."
            self.logger.error(msg)
            raise TypeError(msg)
        self.pops = np.asarray(pops)
        if len(self.pops) != self.X012_.shape[0]:
            msg = f"`populations` length ({len(self.pops)}) != number of samples ({self.X012_.shape[0]})."
            self.logger.error(msg)
            raise ValueError(msg)

    # State
    self.is_fit_: bool = False
    self.global_modes_: Dict[str, int] = {}
    self.group_modes_: dict = {}
    self.sim_mask_: Optional[np.ndarray] = None
    self.train_idx_: Optional[np.ndarray] = None
    self.test_idx_: Optional[np.ndarray] = None
    self.X_train_df_: Optional[pd.DataFrame] = None
    self.ground_truth012_: Optional[np.ndarray] = None
    self.metrics_: Dict[str, int | float] = {}
    self.X_imputed012_: Optional[np.ndarray] = None

    # Ploidy heuristic for 0/1/2 scoring parity: only REF (0) and ALT (2)
    # observed => treat as haploid. np.unique already returns sorted values.
    uniq = np.unique(self.X012_[self.X012_ != -1])
    self.is_haploid_ = np.array_equal(uniq, np.array([0, 2]))

    # Plotting (use config, not genotype_data fields)
    self.plot_format = cfg.plot.fmt
    self.plot_fontsize = cfg.plot.fontsize
    self.plot_despine = cfg.plot.despine
    self.plot_dpi = cfg.plot.dpi
    self.show_plots = cfg.plot.show

    self.model_name = (
        "ImputeMostFrequentPerPop" if self.by_populations else "ImputeMostFrequent"
    )
    self.plotter_ = Plotting(
        self.model_name,
        prefix=self.prefix,
        plot_format=self.plot_format,
        plot_fontsize=self.plot_fontsize,
        plot_dpi=self.plot_dpi,
        title_fontsize=self.plot_fontsize,
        despine=self.plot_despine,
        show_plots=self.show_plots,
        verbose=self.verbose,
        debug=self.debug,
        multiqc=True,
        multiqc_section=f"PG-SUI: {self.model_name} Model Imputation",
    )

    # Output dirs
    dirs = ["models", "plots", "metrics", "optimize", "parameters"]
    self._create_model_directories(self.prefix, dirs)
269
+
270
def fit(self) -> "ImputeMostFrequent":
    """Learn per-locus modes on TRAIN rows; mask evaluation cells on TEST rows.

    Computes the most frequent genotype (mode) per locus from the training rows.
    When ``by_populations`` is enabled, it also computes per-population modes.
    It then builds the evaluation mask restricted to test rows. With
    ``simulate_missing`` the mask comes from ``SimMissingTransformer``;
    otherwise every observed test cell is masked. Finally it stores a copy of
    the data with those cells set to NaN in ``X_train_df_`` and writes the
    resolved config to ``parameters_dir / "best_parameters.json"``.

    Returns:
        ImputeMostFrequent: The fitted imputer instance.

    Raises:
        ValueError: If ``by_populations`` is True but no population labels are set.
    """
    self.train_idx_, self.test_idx_ = self._make_train_test_split()
    self.ground_truth012_ = self.X012_.copy()

    # Float DataFrame with NaN as missing for mode computation. __init__
    # normalizes every missing code to -1, so mask all negative codes rather
    # than relying on ``self.missing`` alone (it may be configured as e.g. -9,
    # in which case -1 would otherwise be counted as a genotype).
    df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
    df_all = df_all.replace(self.missing, np.nan)
    df_all = df_all.mask(df_all < 0)

    # Modes from TRAIN rows only (per-locus)
    df_train = df_all.iloc[self.train_idx_].copy()
    self.global_modes_ = {
        col: self._series_mode(df_train[col]) for col in df_train.columns
    }

    self.group_modes_.clear()
    if self.by_populations:
        if self.pops is None:
            msg = "Population data is required when by_populations=True."
            self.logger.error(msg)
            raise ValueError(msg)
        tmp = df_train.copy()
        tmp["_pops_"] = self.pops[self.train_idx_]
        for pop, grp in tmp.groupby("_pops_"):
            gdf = grp.drop(columns=["_pops_"])
            self.group_modes_[pop] = {
                col: self._series_mode(gdf[col]) for col in gdf.columns
            }

    # ------------------------------
    # Evaluation mask (global -> test-only)
    # ------------------------------
    obs_mask = df_all.notna().to_numpy()  # observed = not NaN
    n_samples = obs_mask.shape[0]

    # Evaluation is restricted to TEST rows in both modes.
    test_rows_mask = np.zeros(n_samples, dtype=bool)
    if self.test_idx_.size > 0:
        test_rows_mask[self.test_idx_] = True

    if self.simulate_missing:
        # Use the same transformer as VAE
        tr = SimMissingTransformer(
            genotype_data=self.genotype_data,
            tree_parser=self.tree_parser,
            prop_missing=self.sim_prop,
            strategy=self.sim_strategy,
            missing_val=-9,
            mask_missing=True,
            verbose=self.verbose,
            **self.sim_kwargs,
        )
        # Fit on a float copy of the 0/1/2 matrix with NaN marking missing cells.
        X_for_sim = self.ground_truth012_.astype(float, copy=True)
        X_for_sim[X_for_sim < 0] = np.nan
        tr.fit(X_for_sim)

        # Never simulate on already-missing cells (no ground truth there).
        sim_mask_global = np.asarray(tr.sim_missing_mask_, dtype=bool) & obs_mask
        sim_mask = sim_mask_global & test_rows_mask[:, None]
        self.sim_mask_global_ = sim_mask_global
    else:
        # Fallback: mask all observed cells on TEST rows
        sim_mask = obs_mask & test_rows_mask[:, None]
        self.sim_mask_global_ = None

    self.sim_mask_test_only_ = sim_mask

    # Build the evaluation frame with DataFrame.mask instead of writing through
    # ``DataFrame.values``, which is not guaranteed to be a writable view
    # (e.g. under pandas Copy-on-Write).
    df_sim = df_all.mask(sim_mask)

    self.sim_mask_ = sim_mask
    self.X_train_df_ = df_sim
    self.is_fit_ = True

    # Save the resolved configuration
    best_params = self.cfg.to_dict()
    params_fp = self.parameters_dir / "best_parameters.json"
    with open(params_fp, "w", encoding="utf-8") as f:
        json.dump(best_params, f, indent=4)

    n_masked = int(sim_mask.sum())
    self.logger.info(
        f"Fit complete. Train rows: {self.train_idx_.size}, "
        f"Test rows: {self.test_idx_.size}. "
        f"Masked {n_masked} test cells for evaluation "
        f"({'simulated' if self.simulate_missing else 'all observed'})."
    )
    return self
376
+
377
def transform(self) -> np.ndarray:
    """Impute missing cells in the FULL dataset; evaluate on masked test cells.

    First imputes the evaluation-masked copy built by ``fit()`` and stores it in
    ``X_imputed012_``, then runs ``_evaluate_and_report()``. Next it imputes
    the full ground-truth matrix, where only truly missing cells are filled. It
    plots the original and imputed genotype distributions and returns the
    decoded imputed matrix.

    Returns:
        np.ndarray: Imputed genotypes decoded by ``encoder.decode_012`` (IUPAC
            strings), shape (n_samples, n_variants).

    Raises:
        NotFittedError: If fit() has not been called prior to transform().
    """
    if not self.is_fit_:
        msg = "Model is not fitted. Call fit() before transform()."
        self.logger.error(msg)
        raise NotFittedError(msg)
    if self.X_train_df_ is None or self.ground_truth012_ is None:
        # Check up front, before ground_truth012_ is used below.
        msg = "ground_truth012_ is not set; cannot plot distributions."
        self.logger.error(msg)
        raise NotFittedError(msg)

    # 1) Impute the evaluation-masked copy (to compute metrics)
    imputed_eval_df = self._impute_df(self.X_train_df_)
    X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int16)
    self.X_imputed012_ = X_imputed_eval

    # Evaluate like DL models (0/1/2, then 10-class from decoded strings)
    self._evaluate_and_report()

    # 2) Impute the FULL dataset (only true missings). ground_truth012_ encodes
    # missing as -1 regardless of ``self.missing``, so mask every negative code.
    df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
    df_missingonly = df_missingonly.replace(self.missing, np.nan)
    df_missingonly = df_missingonly.mask(df_missingonly < 0)
    imputed_full_df = self._impute_df(df_missingonly)
    X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)

    # Plot distributions (parity with DL transform())
    gt_decoded = self.encoder.decode_012(self.ground_truth012_)
    imp_decoded = self.encoder.decode_012(X_imputed_full_012)
    self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
    self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)

    # Return IUPAC strings (same as DL .transform())
    return imp_decoded
421
+
422
def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
    """Fill NaN cells of ``df_in`` with the learned modes.

    Dispatches on ``self.by_populations``. When it is truthy, cells are filled
    from per-population modes (falling back to the global mode). Otherwise they
    are filled from the per-locus global modes.

    Args:
        df_in (pd.DataFrame): Input DataFrame with missing values as NaN.

    Returns:
        pd.DataFrame: DataFrame with missing values imputed.
    """
    if self.by_populations:
        return self._impute_by_population_mode(df_in)
    return self._impute_global_mode(df_in)
438
+
439
def _impute_global_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
    """Fill NaN cells using the per-locus global modes.

    ``self.global_modes_`` maps column label -> mode. ``fillna`` with a Series
    aligns on column labels, so each column's NaNs receive that column's mode.
    The result is always a new ``int16`` DataFrame; ``df_in`` is not modified.

    Args:
        df_in (pd.DataFrame): Input DataFrame with missing values as NaN.

    Returns:
        pd.DataFrame: DataFrame with missing values imputed.
    """
    has_missing = df_in.isna().to_numpy().any()
    if not has_missing:
        return df_in.copy().astype(np.int16)

    per_column_fill = pd.Series(self.global_modes_)
    filled = df_in.fillna(per_column_fill)
    return filled.astype(np.int16)
456
+
457
def _impute_by_population_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
    """Fill NaN cells using each sample's population-specific modes.

    Builds a (population x locus) table from ``self.group_modes_``. Gaps in that
    table are filled from ``self.global_modes_``, and the table is expanded to
    one row per sample using ``self.pops`` (positionally aligned with
    ``df_in``'s rows). Populations without learned modes fall back entirely to
    the global modes. Only NaN cells are replaced.

    Args:
        df_in (pd.DataFrame): Input DataFrame with missing values as NaN.

    Returns:
        pd.DataFrame: DataFrame with missing values imputed.
    """
    if not df_in.isna().to_numpy().any():
        return df_in.astype(np.int16)

    work = df_in.copy()
    fallback = pd.Series(self.global_modes_)
    # Raises if len(self.pops) != number of rows, as intended.
    sample_pops = pd.Series(self.pops, index=work.index)

    per_pop = pd.DataFrame.from_dict(self.group_modes_, orient="index")
    if per_pop.empty:
        per_pop = pd.DataFrame(
            index=pd.Index([], name="population"), columns=work.columns
        )
    per_pop = per_pop.reindex(columns=work.columns).fillna(fallback)

    # One row of candidate fill values per sample (duplicated pop labels OK).
    row_modes = per_pop.reindex(sample_pops.to_numpy(), fill_value=np.nan)
    row_modes = row_modes.fillna(fallback)

    grid = work.to_numpy(dtype=np.float32)
    fill_grid = row_modes.to_numpy(dtype=np.float32)
    holes = np.isnan(grid)
    grid[holes] = fill_grid[holes]

    imputed = pd.DataFrame(grid, columns=work.columns, index=work.index)
    return imputed.astype(np.int16)
493
+
494
def _series_mode(self, s: pd.Series) -> int:
    """Return the most frequent genotype in ``s`` (NaNs ignored).

    - No non-NaN values: returns ``self.default`` unchanged.
    - Ties: ``Series.mode`` returns sorted values, so the smallest wins.
    - A mode outside {0, 1, 2} is replaced by ``self.default`` when that is a
      valid genotype, otherwise by 0.

    Args:
        s (pd.Series): Input pandas Series.

    Returns:
        int: The mode of the series, or the default value if no valid entries exist.
    """
    observed = s.dropna().astype(int)
    if observed.empty:
        return self.default

    candidate = int(observed.mode().to_numpy()[0])
    valid_genotypes = (0, 1, 2)
    if candidate in valid_genotypes:
        return candidate
    return self.default if self.default in valid_genotypes else 0
514
+
515
def _evaluate_and_report(self) -> None:
    """Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.

    Uses the cells selected by ``self.sim_mask_``:
    1. 0/1/2 zygosity evaluation on the raw masked values.
    2. 10-class IUPAC evaluation. A prediction matrix is built that equals the
       ground truth except on masked cells (which take the imputed value). Both
       matrices are decoded with ``encoder.decode_012``, converted to integer
       codes with a fixed A/C/G/T/W/R/M/K/Y/S -> 0..9 mapping (N -> -1), and
       compared on the masked cells.
    If no cells are masked, evaluation is skipped with an info log.

    Raises:
        NotFittedError: If fit() and transform() have not been called.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if (
        self.sim_mask_ is None
        or self.ground_truth012_ is None
        or self.X_imputed012_ is None
    ):
        msg = "Evaluation requires fit() and transform() to be called first."
        self.logger.error(msg)
        raise NotFittedError(msg)

    mask = self.sim_mask_
    # Cells we masked for eval
    y_true_012 = self.ground_truth012_[mask]
    y_pred_012 = self.X_imputed012_[mask]
    if y_true_012.size == 0:
        self.logger.info("No masked test cells; skipping evaluation.")
        return

    # 0/1/2 report (REF/HET/ALT), with haploid folding 2->1. Pass copies because
    # the callee folds labels in place.
    self._evaluate_012_and_plot(y_true_012.copy(), y_pred_012.copy())

    # 10-class report from decoded IUPAC strings: rebuild full-matrix
    # predictions (ground truth everywhere except the masked cells) to decode.
    X_pred_eval = self.ground_truth012_.copy()
    X_pred_eval[mask] = y_pred_012

    y_true_dec = self.encoder.decode_012(self.ground_truth012_)
    y_pred_dec = self.encoder.decode_012(X_pred_eval)

    encodings_dict = {
        "A": 0,
        "C": 1,
        "G": 2,
        "T": 3,
        "W": 4,
        "R": 5,
        "M": 6,
        "K": 7,
        "Y": 8,
        "S": 9,
        "N": -1,
    }
    y_true_int = self.encoder.convert_int_iupac(
        y_true_dec, encodings_dict=encodings_dict
    )
    y_pred_int = self.encoder.convert_int_iupac(
        y_pred_dec, encodings_dict=encodings_dict
    )

    y_true_10 = y_true_int[mask]
    y_pred_10 = y_pred_int[mask]
    self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
569
+
570
def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Build the 0/1/2 zygosity report, metrics, plots and confusion matrix.

    For haploid data (``self.is_haploid_``), genotype 2 is folded into 1 in both
    arrays before scoring. The folding modifies ``y_true``/``y_pred`` in place.
    Macro-averaged F1/precision/recall and accuracy are stored in
    ``self.metrics_`` under ``zygosity_*`` keys. A classification report is
    rendered to the console (without "support"), visualized, and saved via
    ``_save_report``. A confusion matrix is drawn through ``self.plotter_``.

    Args:
        y_true (np.ndarray): True genotypes (0/1/2) for masked cells.
        y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked cells.

    Raises:
        TypeError: If ``classification_report`` does not return a dict.
    """
    if self.is_haploid_:
        # Haploid parity: fold ALT (2) into ALT/Present (1)
        for arr in (y_true, y_pred):
            arr[arr == 2] = 1
        labels = [0, 1]
        report_names = ["REF", "HET"]
    else:
        labels = [0, 1, 2]
        report_names = ["REF", "HET", "ALT"]

    macro_kwargs = dict(average="macro", labels=labels, zero_division=0)
    zygosity_scores = {
        "n_masked_test": int(y_true.size),
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, **macro_kwargs),
        "precision": precision_score(y_true, y_pred, **macro_kwargs),
        "recall": recall_score(y_true, y_pred, **macro_kwargs),
    }
    self.metrics_.update(
        {f"zygosity_{name}": value for name, value in zygosity_scores.items()}
    )

    report: dict | str = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=report_names,
        zero_division=0,
        output_dict=True,
    )
    if not isinstance(report, dict):
        msg = "classification_report did not return a dict as expected."
        self.logger.error(msg)
        raise TypeError(msg)

    # Console table: keep per-class / averaged sections, minus "support".
    report_subset = {}
    for section, stats in report.items():
        if not (isinstance(stats, dict) and "support" in stats):
            continue
        trimmed = {metric: val for metric, val in stats.items() if metric != "support"}
        if trimmed:
            report_subset[section] = trimmed

    if report_subset:
        PrettyMetrics(
            report_subset,
            precision=3,
            title=f"{self.model_name} Zygosity Report",
        ).render()

    viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
    plots = viz.plot_all(
        report,
        title_prefix=f"{self.model_name} Zygosity Report",
        show=getattr(self, "show_plots", False),
        heatmap_classes_only=True,
    )

    for plot_name, figure in plots.items():
        out_path = self.plots_dir / f"zygosity_report_{plot_name}.{self.plot_format}"
        if isinstance(figure, Figure):
            figure.savefig(out_path, dpi=300, facecolor="#111122")
            plt.close(figure)
        elif isinstance(figure, PlotlyFigure):
            figure.write_html(file=out_path.with_suffix(".html"))

    viz._reset_mpl_style()

    # Save JSON
    self._save_report(report, suffix="zygosity")

    # Confusion matrix
    self.plotter_.plot_confusion_matrix(
        y_true, y_pred, label_names=report_names, prefix="zygosity"
    )
664
+
665
def _evaluate_iupac10_and_plot(
    self, y_true: np.ndarray, y_pred: np.ndarray
) -> None:
    """Score, report and plot the 10-class IUPAC genotype predictions.

    The integer class codes 0-9 are displayed as A, C, G, T, W, R, M, K, Y, S
    (in that order). Steps:

    - Store accuracy plus macro-averaged F1, precision and recall in
      ``self.metrics_`` under ``iupac_*`` keys.
    - Render the per-class report as a table.
    - Plot the report into ``self.plots_dir``.
    - Save the report as JSON via ``_save_report``.
    - Draw a confusion matrix.

    Args:
        y_true (np.ndarray): True IUPAC codes (0-9) at the masked positions.
        y_pred (np.ndarray): Predicted IUPAC codes (0-9) at the same positions.

    Raises:
        TypeError: If ``classification_report`` does not return a dict.
        NotFittedError: Propagated from ``_save_report`` if fit() and
            transform() have not been called.
    """
    iupac_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
    class_ids = list(range(len(iupac_names)))

    # Shared settings for every macro-averaged scorer.
    macro_kwargs = {"average": "macro", "labels": class_ids, "zero_division": 0}
    scores = {"accuracy": accuracy_score(y_true, y_pred)}
    for metric_name, scorer in (
        ("f1", f1_score),
        ("precision", precision_score),
        ("recall", recall_score),
    ):
        scores[metric_name] = scorer(y_true, y_pred, **macro_kwargs)
    self.metrics_.update(
        {f"iupac_{metric_name}": value for metric_name, value in scores.items()}
    )

    report: dict | str = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=iupac_names,
        zero_division=0,
        output_dict=True,
    )
    if not isinstance(report, dict):
        msg = "classification_report did not return a dict as expected."
        self.logger.error(msg)
        raise TypeError(msg)

    # Rows that carry a "support" entry are per-class / averaged rows; show
    # them without the support count.
    per_class = {}
    for cls_key, stats in report.items():
        if not (isinstance(stats, dict) and "support" in stats):
            continue
        trimmed = {m: val for m, val in stats.items() if m != "support"}
        if trimmed:
            per_class[cls_key] = trimmed

    if per_class:
        table = PrettyMetrics(
            per_class,
            precision=3,
            title=f"{self.model_name} IUPAC 10-Class Report",
        )
        table.render()

    visualizer = ClassificationReportVisualizer(
        reset_kwargs=self.plotter_.param_dict
    )
    figures = visualizer.plot_all(
        report,
        title_prefix=f"{self.model_name} IUPAC Report",
        show=getattr(self, "show_plots", False),
        heatmap_classes_only=True,
    )

    # Re-apply the plotter's rcParams to undo style left over from Optuna's plotting.
    plt.rcParams.update(self.plotter_.param_dict)

    for plot_name, fig in figures.items():
        target = self.plots_dir / f"iupac_report_{plot_name}.{self.plot_format}"
        if isinstance(fig, Figure):
            fig.savefig(target, dpi=300, facecolor="#111122")
            plt.close(fig)
        elif isinstance(fig, PlotlyFigure):
            fig.write_html(file=target.with_suffix(".html"))

    visualizer._reset_mpl_style()

    self._save_report(report, suffix="iupac")

    self.plotter_.plot_confusion_matrix(
        y_true, y_pred, label_names=iupac_names, prefix="iupac"
    )
758
+
759
+ def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
760
+ """Create train/test split indices.
761
+
762
+ This method creates training and testing indices based on the specified test size or provided test indices. If population-based splitting is enabled, it ensures that the test set includes samples from each population according to the specified test size.
763
+
764
+ Returns:
765
+ Tuple[np.ndarray, np.ndarray]: Arrays of train and test indices.
766
+
767
+ Raises:
768
+ IndexError: If provided test_indices are out of bounds.
769
+ """
770
+ n = self.X012_.shape[0]
771
+ all_idx = np.arange(n, dtype=int)
772
+ if self.test_indices is not None:
773
+ test_idx = np.unique(self.test_indices)
774
+ if np.any((test_idx < 0) | (test_idx >= n)):
775
+ raise IndexError("Some test_indices are out of bounds.")
776
+ train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
777
+ return train_idx, test_idx
778
+
779
+ if self.by_populations and self.pops is not None:
780
+ buckets = []
781
+ for pop in np.unique(self.pops):
782
+ rows = np.where(self.pops == pop)[0]
783
+ k = int(round(self.test_size * rows.size))
784
+ if k > 0:
785
+ buckets.append(self.rng.choice(rows, size=k, replace=False))
786
+ test_idx = (
787
+ np.sort(np.concatenate(buckets)) if buckets else np.array([], dtype=int)
788
+ )
789
+ else:
790
+ k = int(round(self.test_size * n))
791
+ test_idx = (
792
+ self.rng.choice(n, size=k, replace=False)
793
+ if k > 0
794
+ else np.array([], dtype=int)
795
+ )
796
+
797
+ train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
798
+ return train_idx, test_idx
799
+
800
+ def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
801
+ """Save classification report dictionary as a JSON file.
802
+
803
+ This method saves the provided classification report dictionary to a JSON file in the metrics directory, appending the specified suffix to the filename.
804
+
805
+ Args:
806
+ report_dict (Dict[str, float]): The classification report dictionary to save.
807
+ suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
808
+
809
+ Raises:
810
+ NotFittedError: If fit() and transform() have not been called.
811
+ """
812
+ if not self.is_fit_ or self.X_imputed012_ is None:
813
+ msg = "No report to save. Ensure fit() and transform() have been called."
814
+ raise NotFittedError(msg)
815
+
816
+ out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
817
+ with open(out_fp, "w") as f:
818
+ json.dump(report_dict, f, indent=4)
819
+ self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")
820
+
821
+ def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
822
+ """Creates the directory structure for storing model outputs.
823
+
824
+ This method sets up a standardized folder hierarchy for saving models, plots, metrics, and optimization results, organized under a main directory named after the provided prefix.
825
+
826
+ Args:
827
+ prefix (str): The prefix for the main output directory.
828
+ outdirs (List[str]): A list of subdirectory names to create within the main directory.
829
+
830
+ Raises:
831
+ Exception: If any of the directories cannot be created.
832
+ """
833
+ formatted_output_dir = Path(f"{prefix}_output")
834
+ base_dir = formatted_output_dir / "Deterministic"
835
+
836
+ for d in outdirs:
837
+ subdir = base_dir / d / self.model_name
838
+ setattr(self, f"{d}_dir", subdir)
839
+ try:
840
+ getattr(self, f"{d}_dir").mkdir(parents=True, exist_ok=True)
841
+ except Exception as e:
842
+ msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
843
+ self.logger.error(msg)
844
+ raise Exception(msg)