pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of pg-sui might be problematic. Click here for more details.

Files changed (112) hide show
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
@@ -0,0 +1,679 @@
1
+ # Standard library imports
2
+ import json
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5
+
6
+ # Third-party imports
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ from sklearn.exceptions import NotFittedError
11
+ from sklearn.metrics import (
12
+ accuracy_score,
13
+ classification_report,
14
+ f1_score,
15
+ precision_score,
16
+ recall_score,
17
+ )
18
+ from snpio import GenotypeEncoder
19
+ from snpio.utils.logging import LoggerManager
20
+
21
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
22
+ from pgsui.data_processing.containers import MostFrequentConfig
23
+ from pgsui.utils.classification_viz import ClassificationReportVisualizer
24
+
25
+ # Local imports
26
+ from pgsui.utils.plotting import Plotting
27
+
28
+ # Type checking imports
29
+ if TYPE_CHECKING:
30
+ from snpio.read_input.genotype_data import GenotypeData
31
+
32
+
33
def ensure_mostfrequent_config(
    config: Union[MostFrequentConfig, dict, str, None],
) -> MostFrequentConfig:
    """Return a concrete MostFrequentConfig (dataclass, dict, YAML path, or None).

    Args:
        config (Union[MostFrequentConfig, dict, str, None]): The configuration
            to normalize. ``None`` yields the defaults, a ``MostFrequentConfig``
            is returned as-is, a ``str`` (or ``pathlib.Path``) is loaded as a
            YAML file, and a (possibly nested) ``dict`` is flattened to dot-keys
            and applied on top of the defaults or of the preset named by its
            optional top-level ``'preset'`` key.

    Returns:
        MostFrequentConfig: The ensured MostFrequentConfig.

    Raises:
        TypeError: If ``config`` is none of the supported types.
    """
    if config is None:
        return MostFrequentConfig()
    if isinstance(config, MostFrequentConfig):
        return config
    if isinstance(config, (str, Path)):
        # str() keeps the loader call identical for plain-string callers.
        return load_yaml_to_dataclass(
            str(config),
            MostFrequentConfig,
            preset_builder=MostFrequentConfig.from_preset,
        )
    if isinstance(config, dict):
        # Read the preset without popping it: the caller's dict must not be
        # mutated, otherwise reusing the same dict would silently lose the preset.
        preset = config.get("preset")
        base = (
            MostFrequentConfig.from_preset(preset) if preset else MostFrequentConfig()
        )

        def _flatten(prefix: str, nested: dict, out: dict) -> dict:
            """Collapse nested dicts into ``{'a.b.c': value}`` form."""
            for key, value in nested.items():
                dotted = f"{prefix}.{key}" if prefix else key
                if isinstance(value, dict):
                    _flatten(dotted, value, out)
                else:
                    out[dotted] = value
            return out

        # Only the top-level 'preset' key is consumed here; nested keys of the
        # same name are forwarded as ordinary overrides.
        settings = {k: v for k, v in config.items() if k != "preset"}
        return apply_dot_overrides(base, _flatten("", settings, {}))

    raise TypeError("config must be a MostFrequentConfig, dict, YAML path, or None.")
72
+
73
+
74
class ImputeMostFrequent:
    """Most-frequent (mode) imputer that mirrors DL evaluation on 0/1/2.

    For every locus the most frequent genotype (mode) is learned from the
    training rows, either globally or per population, and used to fill missing
    calls. Evaluation follows the same protocol as the deep-learning imputers:
    rows are split into train/test, every observed cell on the test rows is
    masked, and the masked cells are scored with a 0/1/2 (zygosity) report and a
    10-class IUPAC report. Genotypes are handled as 0/1/2 integers with -1
    marking missing values; data whose observed calls are exactly {0, 2} is
    treated as haploid when scoring.
    """

    # Class order used by the 10-class IUPAC report.
    _IUPAC_LABELS = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]

    # Integer code per IUPAC character; "N" (missing) maps to -1.
    _IUPAC_ENCODINGS = {
        "A": 0,
        "C": 1,
        "G": 2,
        "T": 3,
        "W": 4,
        "R": 5,
        "M": 6,
        "K": 7,
        "Y": 8,
        "S": 9,
        "N": -1,
    }

    def __init__(
        self,
        genotype_data: "GenotypeData",
        *,
        config: "Optional[Union[MostFrequentConfig, dict, str]]" = None,
        overrides: Optional[dict] = None,
    ) -> None:
        """Initialize the Most-Frequent (mode) imputer from a unified config.

        Normalizes the configuration, sets up the logger, RNG, genotype encoder
        and plotter, caches the 0/1/2 genotype matrix (missing = -1) and creates
        the output directories.

        Args:
            genotype_data (GenotypeData): Backing genotype data.
            config (MostFrequentConfig | dict | str | None): Configuration as a
                dataclass, nested dict, or YAML path. If None, defaults are used.
            overrides (dict | None): Flat dot-key overrides applied last with
                highest precedence, e.g.
                {'algo.by_populations': True, 'split.test_size': 0.3}.

        Raises:
            TypeError: If ``algo.by_populations`` is set but ``genotype_data``
                has no ``populations``.
            ValueError: If the number of population labels differs from the
                number of samples.

        Notes:
            - Evaluation split behavior uses cfg.split; plotting uses cfg.plot.
            - I/O/logging seeds and verbosity use cfg.io.
        """
        # Normalize config then apply highest-precedence overrides
        cfg = ensure_mostfrequent_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)
        self.cfg = cfg

        # Basic fields
        self.genotype_data = genotype_data
        self.prefix = cfg.io.prefix
        self.verbose = cfg.io.verbose
        self.debug = cfg.io.debug

        # Logger
        logman = LoggerManager(
            __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
        )
        self.logger = logman.get_logger()

        # RNG / encoder
        self.rng = np.random.default_rng(cfg.io.seed)
        self.encoder = GenotypeEncoder(self.genotype_data)

        # Work in 0/1/2 with -1 for missing (parity with DL modules); every
        # negative code coming from the encoder is folded into -1.
        X012 = self.encoder.genotypes_012.astype(np.int16, copy=True)
        X012[X012 < 0] = -1
        self.X012_ = X012
        self.num_features_ = X012.shape[1]

        # Split & algo knobs
        self.test_size = float(cfg.split.test_size)
        self.test_indices = (
            None
            if cfg.split.test_indices is None
            else np.asarray(cfg.split.test_indices, dtype=int)
        )
        self.by_populations = bool(cfg.algo.by_populations)
        self.default = int(cfg.algo.default)
        self.missing = int(cfg.algo.missing)

        # Populations (if requested)
        self.pops = None
        if self.by_populations:
            pops = getattr(self.genotype_data, "populations", None)
            if pops is None:
                msg = "by_populations=True requires genotype_data.populations."
                self.logger.error(msg)
                raise TypeError(msg)
            self.pops = np.asarray(pops)
            if len(self.pops) != self.X012_.shape[0]:
                msg = f"`populations` length ({len(self.pops)}) != number of samples ({self.X012_.shape[0]})."
                self.logger.error(msg)
                raise ValueError(msg)

        # State filled in by fit() / transform()
        self.is_fit_: bool = False
        self.global_modes_: Dict[int, int] = {}
        self.group_modes_: Dict[Union[str, int], Dict[int, int]] = {}
        self.sim_mask_: Optional[np.ndarray] = None
        self.train_idx_: Optional[np.ndarray] = None
        self.test_idx_: Optional[np.ndarray] = None
        self.X_train_df_: Optional[pd.DataFrame] = None
        self.ground_truth012_: Optional[np.ndarray] = None
        self.metrics_: Dict[str, Union[int, float]] = {}
        self.X_imputed012_: Optional[np.ndarray] = None

        # Ploidy heuristic for 0/1/2 scoring parity: observed calls that are
        # exactly {0, 2} are treated as haploid (np.unique is already sorted).
        observed = np.unique(self.X012_[self.X012_ != -1])
        self.is_haploid_ = bool(np.array_equal(observed, np.array([0, 2])))

        # Plotting (use config, not genotype_data fields)
        self.plot_format = cfg.plot.fmt
        self.plot_fontsize = cfg.plot.fontsize
        self.plot_despine = cfg.plot.despine
        self.plot_dpi = cfg.plot.dpi
        self.show_plots = cfg.plot.show

        self.model_name = (
            "ImputeMostFrequentPerPop" if self.by_populations else "ImputeMostFrequent"
        )
        self.plotter_ = Plotting(
            self.model_name,
            prefix=self.prefix,
            plot_format=self.plot_format,
            plot_fontsize=self.plot_fontsize,
            plot_dpi=self.plot_dpi,
            title_fontsize=self.plot_fontsize,
            despine=self.plot_despine,
            show_plots=self.show_plots,
            verbose=self.verbose,
            debug=self.debug,
        )

        # Output dirs (sets self.<name>_dir for each entry)
        dirs = ["models", "plots", "metrics", "optimize", "parameters"]
        self._create_model_directories(self.prefix, dirs)

    def fit(self) -> "ImputeMostFrequent":
        """Learn per-locus modes on TRAIN rows; mask all observed cells on TEST rows.

        Splits the samples into train/test rows, computes the most frequent
        genotype of every locus from the train rows (globally and, when
        ``by_populations`` is set, per population), builds the evaluation mask
        covering every observed cell of the test rows, and writes the resolved
        configuration to ``best_parameters.json`` in the parameters directory.

        Returns:
            ImputeMostFrequent: The fitted imputer instance.
        """
        self.train_idx_, self.test_idx_ = self._make_train_test_split()
        self.ground_truth012_ = self.X012_.copy()

        # NaN marks missing so that pandas ignores those cells in mode().
        df_all = self._missing_as_nan_frame()

        # Modes from TRAIN rows only (per-locus)
        df_train = df_all.iloc[self.train_idx_]
        self.global_modes_ = {
            col: self._series_mode(df_train[col]) for col in df_train.columns
        }

        self.group_modes_.clear()
        if self.by_populations:
            # Index-aligned Series key: unambiguous grouping by population label.
            train_pops = pd.Series(self.pops[self.train_idx_], index=df_train.index)
            for pop, grp in df_train.groupby(train_pops):
                self.group_modes_[pop] = {
                    col: self._series_mode(grp[col]) for col in grp.columns
                }

        # Mask ALL observed cells on TEST rows (evaluation protocol parity)
        obs_mask = df_all.notna().to_numpy()  # observed = not NaN
        test_rows_mask = np.zeros(obs_mask.shape[0], dtype=bool)
        if self.test_idx_.size > 0:
            test_rows_mask[self.test_idx_] = True
        sim_mask = obs_mask & test_rows_mask[:, None]  # cells to mask for eval

        # DataFrame.mask returns a new frame; writing through `.values` is not
        # guaranteed to reach the frame (and is read-only under copy-on-write).
        df_sim = df_all.mask(sim_mask)

        self.sim_mask_ = sim_mask
        self.X_train_df_ = df_sim
        self.is_fit_ = True

        params_fp = self.parameters_dir / "best_parameters.json"
        with open(params_fp, "w", encoding="utf-8") as f:
            json.dump(self.cfg.to_dict(), f, indent=4)

        self.logger.info(
            f"Fit complete. Train rows: {self.train_idx_.size}, Test rows: {self.test_idx_.size}. Masked {int(sim_mask.sum())} observed test cells for evaluation."
        )
        return self

    def transform(self) -> np.ndarray:
        """Impute missing cells in the FULL dataset; evaluate on masked test cells.

        First imputes the evaluation-masked frame built by fit() and scores it
        (0/1/2 zygosity report, then 10-class IUPAC report). Then imputes the
        full dataset, filling only the truly missing cells, plots the genotype
        distributions before and after imputation, and returns the decoded
        result.

        Returns:
            np.ndarray: Imputed genotypes as decoded by the genotype encoder
            (IUPAC strings), shape (n_samples, n_variants).

        Raises:
            NotFittedError: If fit() has not been called prior to transform().
        """
        if not self.is_fit_ or self.X_train_df_ is None:
            msg = "Model is not fitted. Call fit() before transform()."
            self.logger.error(msg)
            raise NotFittedError(msg)

        # 1) Impute the evaluation-masked copy (to compute metrics)
        self.X_imputed012_ = self._impute_df(self.X_train_df_).to_numpy(
            dtype=np.int16
        )

        # Evaluate like DL models (0/1/2, then 10-class from decoded strings)
        self._evaluate_and_report()

        # 2) Impute the FULL dataset (only true missings); same missing-value
        #    handling as fit() so both passes agree on what is missing.
        imputed_full_df = self._impute_df(self._missing_as_nan_frame())
        X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)

        # Plot distributions (parity with DL transform())
        gt_decoded = self.encoder.decode_012(self.ground_truth012_)
        imp_decoded = self.encoder.decode_012(X_imputed_full_012)
        self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
        self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)

        # Return IUPAC strings (same as DL .transform())
        return imp_decoded

    def _missing_as_nan_frame(self) -> pd.DataFrame:
        """Return the ground-truth 0/1/2 matrix as float32 with missing as NaN.

        Every negative code is treated as missing, because ``__init__`` folds
        all negative genotypes into -1 regardless of the configured sentinel.
        The configured ``missing`` value is masked as well, which covers a
        non-negative sentinel.

        Returns:
            pd.DataFrame: Frame of shape (n_samples, n_variants) with NaN in
            every missing cell.
        """
        frame = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
        frame = frame.mask(frame < 0)
        return frame.replace(self.missing, np.nan)

    def _impute_df(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute missing cells in df_in using global or population-specific modes.

        Args:
            df_in (pd.DataFrame): Input DataFrame with missing values as NaN.

        Returns:
            pd.DataFrame: int16 DataFrame with missing values imputed.
        """
        if self.by_populations:
            return self._impute_by_population_mode(df_in)
        return self._impute_global_mode(df_in)

    def _impute_global_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute missing cells in df_in using global modes.

        Each NaN is replaced by the mode stored in ``global_modes_`` for its
        column; the input frame is not modified.

        Args:
            df_in (pd.DataFrame): Input DataFrame with missing values as NaN.

        Returns:
            pd.DataFrame: int16 DataFrame with missing values imputed.
        """
        if not df_in.isnull().values.any():
            return df_in.astype(np.int16)
        # A Series passed to fillna is matched against the column labels.
        return df_in.fillna(pd.Series(self.global_modes_)).astype(np.int16)

    def _impute_by_population_mode(self, df_in: pd.DataFrame) -> pd.DataFrame:
        """Impute missing cells in df_in using population-specific modes.

        Each NaN is replaced by the mode of its locus within the sample's
        population. Populations or loci without a stored population mode fall
        back to the global mode.

        Args:
            df_in (pd.DataFrame): Input DataFrame with missing values as NaN;
                rows must be in the same order as ``self.pops``.

        Returns:
            pd.DataFrame: int16 DataFrame with missing values imputed.
        """
        if not df_in.isnull().values.any():
            return df_in.astype(np.int16)

        global_modes = pd.Series(self.global_modes_)

        # One row per population, one column per locus.
        pop_modes = pd.DataFrame.from_dict(self.group_modes_, orient="index")
        if pop_modes.empty:
            pop_modes = pd.DataFrame(
                index=pd.Index([], name="population"), columns=df_in.columns
            )
        pop_modes = pop_modes.reindex(columns=df_in.columns).fillna(global_modes)

        # Expand to one row per sample; populations never seen during fit()
        # come out as NaN rows and take the global modes.
        per_sample = pop_modes.reindex(np.asarray(self.pops), fill_value=np.nan)
        per_sample = per_sample.fillna(global_modes)

        values = df_in.to_numpy(dtype=np.float32, copy=True)
        replacements = per_sample.to_numpy(dtype=np.float32)
        holes = np.isnan(values)
        values[holes] = replacements[holes]

        return pd.DataFrame(
            values, columns=df_in.columns, index=df_in.index
        ).astype(np.int16)

    def _series_mode(self, s: pd.Series) -> int:
        """Compute the mode of a pandas Series, ignoring NaNs.

        Args:
            s (pd.Series): Genotypes of one locus, NaN where missing.

        Returns:
            int: The mode of the series. ``self.default`` is returned when no
            valid entry exists. A mode outside {0, 1, 2} is replaced by
            ``self.default`` if that is a valid genotype, otherwise by 0.
        """
        s_valid = s.dropna().astype(int)
        if s_valid.empty:
            return self.default
        # Series.mode() is sorted, so ties resolve to the smallest genotype.
        mode_val = int(s_valid.mode().iloc[0])
        if mode_val in (0, 1, 2):
            return mode_val
        # Safety: clamp to valid zygosity in case of odd inputs
        return self.default if self.default in (0, 1, 2) else 0

    def _evaluate_and_report(self) -> None:
        """Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.

        Scores the cells selected by ``sim_mask_`` twice: as 0/1/2 zygosity
        classes and as 10-class IUPAC codes obtained by decoding the matrices
        with the genotype encoder. Evaluation is skipped (with a log message)
        when no cell was masked.

        Raises:
            NotFittedError: If fit() and transform() have not populated the
                mask, ground truth and imputed matrix.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if (
            self.sim_mask_ is None
            or self.ground_truth012_ is None
            or self.X_imputed012_ is None
        ):
            msg = "Nothing to evaluate. Ensure fit() and transform() have been called."
            self.logger.error(msg)
            raise NotFittedError(msg)

        mask = self.sim_mask_

        # Cells we masked for eval (boolean indexing yields fresh arrays)
        y_true_012 = self.ground_truth012_[mask]
        y_pred_012 = self.X_imputed012_[mask]
        if y_true_012.size == 0:
            self.logger.info("No masked test cells; skipping evaluation.")
            return

        # 0/1/2 report (REF/HET/ALT), with haploid folding 2->1
        self._evaluate_012_and_plot(y_true_012, y_pred_012)

        # 10-class report from decoded IUPAC strings: rebuild the full matrix
        # with predictions placed on the masked cells so it can be decoded.
        X_pred_eval = self.ground_truth012_.copy()
        X_pred_eval[mask] = self.X_imputed012_[mask]

        y_true_dec = self.encoder.decode_012(self.ground_truth012_)
        y_pred_dec = self.encoder.decode_012(X_pred_eval)

        # Hand the encoder its own copy of the shared mapping each time.
        y_true_int = self.encoder.convert_int_iupac(
            y_true_dec, encodings_dict=dict(self._IUPAC_ENCODINGS)
        )
        y_pred_int = self.encoder.convert_int_iupac(
            y_pred_dec, encodings_dict=dict(self._IUPAC_ENCODINGS)
        )

        self._evaluate_iupac10_and_plot(y_true_int[mask], y_pred_int[mask])

    def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
        """0/1/2 zygosity report & confusion matrix.

        Computes accuracy and macro F1/precision/recall, logs the
        classification report, saves the report plots and JSON, and draws the
        confusion matrix. For haploid data the class 2 is folded into class 1
        and only two classes are reported. The input arrays are not modified.

        Args:
            y_true (np.ndarray): True genotypes (0/1/2) of the masked cells.
            y_pred (np.ndarray): Predicted genotypes (0/1/2) of the masked cells.
        """
        labels = [0, 1, 2]
        report_names = ["REF", "HET", "ALT"]
        # Haploid parity: fold ALT (2) into ALT/Present (1)
        if self.is_haploid_:
            y_true = np.where(y_true == 2, 1, y_true)
            y_pred = np.where(y_pred == 2, 1, y_pred)
            labels = [0, 1]
            report_names = ["REF", "HET"]

        metrics = {
            "n_masked_test": int(y_true.size),
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
            "precision": precision_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
            "recall": recall_score(
                y_true, y_pred, average="macro", labels=labels, zero_division=0
            ),
        }
        self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})

        self.logger.info(
            f"\n{classification_report(y_true, y_pred, labels=labels, target_names=report_names, zero_division=0)}"
        )

        report = classification_report(
            y_true,
            y_pred,
            labels=labels,
            target_names=report_names,
            zero_division=0,
            output_dict=True,
        )

        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

        plots = viz.plot_all(
            report,
            title_prefix=f"{self.model_name} Zygosity Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        self._save_report_figures(plots, "zygosity_report")

        viz._reset_mpl_style()

        # Save JSON
        self._save_report(report, suffix="zygosity")

        # Confusion matrix
        self.plotter_.plot_confusion_matrix(
            y_true, y_pred, label_names=report_names, prefix="zygosity"
        )

    def _evaluate_iupac10_and_plot(
        self, y_true: np.ndarray, y_pred: np.ndarray
    ) -> None:
        """10-class IUPAC report & confusion matrix.

        Computes accuracy and macro F1/precision/recall over the ten IUPAC
        classes, logs the classification report, saves the report plots and
        JSON, and draws the confusion matrix.

        Args:
            y_true (np.ndarray): True genotypes (0-9) of the masked cells.
            y_pred (np.ndarray): Predicted genotypes (0-9) of the masked cells.
        """
        labels_names = list(self._IUPAC_LABELS)
        labels_idx = list(range(len(labels_names)))

        metrics = {
            "accuracy": accuracy_score(y_true, y_pred),
            "f1": f1_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
            "precision": precision_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
            "recall": recall_score(
                y_true, y_pred, average="macro", labels=labels_idx, zero_division=0
            ),
        }
        self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})

        self.logger.info(
            f"\n{classification_report(y_true, y_pred, labels=labels_idx, target_names=labels_names, zero_division=0)}"
        )

        report = classification_report(
            y_true,
            y_pred,
            labels=labels_idx,
            target_names=labels_names,
            zero_division=0,
            output_dict=True,
        )

        viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)

        plots = viz.plot_all(
            report,
            title_prefix=f"{self.model_name} IUPAC Report",
            show=getattr(self, "show_plots", False),
            heatmap_classes_only=True,
        )

        # Restore the plotter's matplotlib settings before saving the figures.
        plt.rcParams.update(self.plotter_.param_dict)

        self._save_report_figures(plots, "iupac_report")

        # Reset the style
        viz._reset_mpl_style()

        # Save JSON
        self._save_report(report, suffix="iupac")

        # Confusion matrix
        self.plotter_.plot_confusion_matrix(
            y_true, y_pred, label_names=labels_names, prefix="iupac"
        )

    def _save_report_figures(self, plots: dict, stem: str) -> None:
        """Write the figures of one classification report to the plots directory.

        Args:
            plots (dict): Mapping of plot name to figure object, as returned by
                ``ClassificationReportVisualizer.plot_all``.
            stem (str): Filename stem; files are named
                ``{stem}_{name}.{plot_format}``.
        """
        for name, fig in plots.items():
            fout = self.plots_dir / f"{stem}_{name}.{self.plot_format}"
            if hasattr(fig, "savefig"):
                # Figures exposing savefig are saved as images and then closed.
                fig.savefig(fout, dpi=300, facecolor="#111122")
                plt.close(fig)
            else:
                # Anything else is expected to offer write_html.
                fig.write_html(file=fout.with_suffix(".html"))

    def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
        """Create train/test split indices.

        Explicit ``test_indices`` take precedence. Otherwise ``test_size`` is
        applied either within each population (when ``by_populations`` is set
        and populations are available) or across all samples.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Arrays of train and test indices.

        Raises:
            IndexError: If provided test_indices are out of bounds.
        """
        n = self.X012_.shape[0]
        all_idx = np.arange(n, dtype=int)

        if self.test_indices is not None:
            test_idx = np.unique(self.test_indices)
            if np.any((test_idx < 0) | (test_idx >= n)):
                msg = "Some test_indices are out of bounds."
                self.logger.error(msg)
                raise IndexError(msg)
            return np.setdiff1d(all_idx, test_idx, assume_unique=False), test_idx

        if self.by_populations and self.pops is not None:
            # Stratified: draw the same fraction from every population.
            buckets = []
            for pop in np.unique(self.pops):
                rows = np.where(self.pops == pop)[0]
                k = int(round(self.test_size * rows.size))
                if k > 0:
                    buckets.append(self.rng.choice(rows, size=k, replace=False))
            test_idx = (
                np.sort(np.concatenate(buckets)) if buckets else np.array([], dtype=int)
            )
        else:
            k = int(round(self.test_size * n))
            test_idx = (
                self.rng.choice(n, size=k, replace=False)
                if k > 0
                else np.array([], dtype=int)
            )

        train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
        return train_idx, test_idx

    def _save_report(self, report_dict: dict, suffix: str) -> None:
        """Save classification report dictionary as a JSON file.

        The file is written to the metrics directory as
        ``classification_report_{suffix}.json``.

        Args:
            report_dict (dict): The classification report dictionary to save.
            suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').

        Raises:
            NotFittedError: If fit() and transform() have not been called.
        """
        if not self.is_fit_ or self.X_imputed012_ is None:
            msg = "No report to save. Ensure fit() and transform() have been called."
            raise NotFittedError(msg)

        out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
        with open(out_fp, "w", encoding="utf-8") as f:
            json.dump(report_dict, f, indent=4)
        self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")

    def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
        """Creates the directory structure for storing model outputs.

        For every name ``d`` in ``outdirs`` the directory
        ``{prefix}_output/Deterministic/{d}/{model_name}`` is created and stored
        on the instance as ``self.{d}_dir``.

        Args:
            prefix (str): The prefix for the main output directory.
            outdirs (List[str]): A list of subdirectory names to create within the main directory.

        Raises:
            OSError: If any of the directories cannot be created.
        """
        base_dir = Path(f"{prefix}_output") / "Deterministic"

        for d in outdirs:
            subdir = base_dir / d / self.model_name
            setattr(self, f"{d}_dir", subdir)
            try:
                subdir.mkdir(parents=True, exist_ok=True)
            except OSError as e:
                msg = f"Failed to create directory {subdir}: {e}"
                self.logger.error(msg)
                # Chain the cause so the original errno/traceback is kept.
                raise OSError(msg) from e