pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
pgsui/impute/deterministic/imputers/ref_allele.py (new file)
@@ -0,0 +1,669 @@
+ # Standard library
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
+
+ # Third-party
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ from matplotlib.figure import Figure
+ from plotly.graph_objs import Figure as PlotlyFigure
+ from sklearn.exceptions import NotFittedError
+ from sklearn.metrics import (
+     accuracy_score,
+     classification_report,
+     f1_score,
+     precision_score,
+     recall_score,
+ )
+
+ # Project
+ from snpio import GenotypeEncoder
+ from snpio.utils.logging import LoggerManager
+
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
+ from pgsui.data_processing.containers import RefAlleleConfig
+ from pgsui.data_processing.transformers import SimMissingTransformer
+ from pgsui.utils.classification_viz import ClassificationReportVisualizer
+ from pgsui.utils.logging_utils import configure_logger
+ from pgsui.utils.plotting import Plotting
+ from pgsui.utils.pretty_metrics import PrettyMetrics
+
+ if TYPE_CHECKING:
+     from snpio import TreeParser
+     from snpio.read_input.genotype_data import GenotypeData
+
+
+ def ensure_refallele_config(
+     config: Union[RefAlleleConfig, dict, str, None],
+ ) -> RefAlleleConfig:
+     """Return a concrete RefAlleleConfig (dataclass, dict, YAML path, or None).
+
+     This function normalizes the input configuration for the RefAllele imputer. It accepts a RefAlleleConfig instance, a dictionary of parameters, a path to a YAML file, or None. If None is provided, it returns a default RefAlleleConfig instance. If a dictionary is provided, it flattens any nested structures and applies the parameters to a base configuration, honoring any top-level 'preset' key. If a string path is provided, it loads the configuration from the specified YAML file.
+
+     Args:
+         config (Union[RefAlleleConfig, dict, str, None]): Configuration input, which can be a RefAlleleConfig instance, a dictionary of parameters, a path to a YAML file, or None.
+
+     Returns:
+         RefAlleleConfig: A concrete RefAlleleConfig instance.
+
+     Raises:
+         TypeError: If the input type is not supported.
+     """
+     if config is None:
+         return RefAlleleConfig()
+     if isinstance(config, RefAlleleConfig):
+         return config
+     if isinstance(config, str):
+         return load_yaml_to_dataclass(config, RefAlleleConfig)
+     if isinstance(config, dict):
+         base = RefAlleleConfig()
+         # honor optional top-level 'preset'
+         preset = config.pop("preset", None)
+         if preset:
+             base = RefAlleleConfig.from_preset(preset)
+
+         def _flatten(prefix: str, d: dict, out: dict) -> dict:
+             for k, v in d.items():
+                 kk = f"{prefix}.{k}" if prefix else k
+                 if isinstance(v, dict):
+                     _flatten(kk, v, out)
+                 else:
+                     out[kk] = v
+             return out
+
+         flat = _flatten("", config, {})
+         return apply_dot_overrides(base, flat)
+
+     raise TypeError(
+         f"config must be RefAlleleConfig, dict, YAML path, or None, but got: {type(config)}."
+     )
+
+
+ class ImputeRefAllele:
+     """Deterministic imputer that replaces all missing 0/1/2 genotype values with the REF genotype (0).
+
+     The imputer works on 0/1/2 encodings with -1 as missing. For evaluation, samples are split once into TRAIN/TEST and all originally observed cells on TEST rows are masked. It produces a 0/1/2 (zygosity) classification report and confusion matrix, a 10-class IUPAC classification report (via decode_012) and confusion matrix, and plots of the genotype distribution before and after imputation.
+     """
+
+     def __init__(
+         self,
+         genotype_data: "GenotypeData",
+         *,
+         tree_parser: Optional["TreeParser"] = None,
+         config: Optional[Union[RefAlleleConfig, dict, str]] = None,
+         overrides: Optional[dict] = None,
+         simulate_missing: bool = True,
+         sim_strategy: Literal[
+             "random",
+             "random_weighted",
+             "random_weighted_inv",
+             "nonrandom",
+             "nonrandom_weighted",
+         ] = "random",
+         sim_prop: float = 0.2,
+         sim_kwargs: Optional[dict] = None,
+     ) -> None:
+         """Initialize the Ref-Allele imputer from a unified config.
+
+         This constructor ensures that the provided configuration is valid and initializes the imputer's internal state. It sets up logging, random number generation, genotype encoding, and various parameters based on the configuration. The imputer is prepared to handle population-specific modes if specified in the configuration.
+
+         Args:
+             genotype_data (GenotypeData): Backing genotype data.
+             tree_parser (Optional[TreeParser]): Optional SNPio phylogenetic tree parser for population-specific modes.
+             config (RefAlleleConfig | dict | str | None): Configuration as a dataclass, nested dict, or YAML path. If None, defaults are used.
+             overrides (dict | None): Flat dot-key overrides applied last with highest precedence, e.g. {'split.test_size': 0.25, 'algo.missing': -1}.
+             simulate_missing (bool): Whether to simulate missing data during evaluation. Default is True.
+             sim_strategy (Literal): Strategy for simulating missing data if enabled in config.
+             sim_prop (float): Proportion of data to simulate as missing if enabled in config.
+             sim_kwargs (Optional[dict]): Additional keyword arguments for the simulated missing data transformer.
+         """
+         # Normalize config then apply highest-precedence overrides
+         cfg = ensure_refallele_config(config)
+         if overrides:
+             cfg = apply_dot_overrides(cfg, overrides)
+         self.cfg = cfg
+
+         # Basic fields
+         self.genotype_data = genotype_data
+         self.tree_parser = tree_parser
+         self.prefix = cfg.io.prefix
+         self.verbose = cfg.io.verbose
+         self.debug = cfg.io.debug
+
+         # Simulation knobs (shared with other deterministic imputers)
+         if cfg.sim is None:
+             self.simulate_missing = simulate_missing
+             self.sim_strategy = sim_strategy
+             self.sim_prop = sim_prop
+             self.sim_kwargs = sim_kwargs or {}
+         else:
+             sim_cfg = cfg.sim
+             self.simulate_missing = getattr(
+                 sim_cfg, "simulate_missing", simulate_missing
+             )
+             self.sim_strategy = getattr(sim_cfg, "sim_strategy", sim_strategy)
+             self.sim_prop = float(getattr(sim_cfg, "sim_prop", sim_prop))
+             self.sim_kwargs: Dict[str, Any] = dict(
+                 getattr(sim_cfg, "sim_kwargs", sim_kwargs) or {}
+             )
+
+         # Output dirs
+         self.plots_dir: Path
+         self.metrics_dir: Path
+         self.parameters_dir: Path
+         self.model_dir: Path
+         self.optimize_dir: Path
+
+         # Logger
+         logman = LoggerManager(
+             __name__, prefix=self.prefix, verbose=self.verbose, debug=self.debug
+         )
+         self.logger = configure_logger(
+             logman.get_logger(), verbose=self.verbose, debug=self.debug
+         )
+
+         if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
+             msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
+             self.logger.error(msg)
+             raise ValueError(msg)
+
+         # RNG / encoder
+         self.rng = np.random.default_rng(cfg.io.seed)
+         self.encoder = GenotypeEncoder(self.genotype_data)
+
+         # Work in 0/1/2 with -1 for missing
+         X012 = self.encoder.genotypes_012.astype(np.int16, copy=True)
+         X012[X012 < 0] = -1
+         self.X012_ = X012
+         self.num_features_ = X012.shape[1]
+
+         # Split & algo knobs
+         self.test_size = float(cfg.split.test_size)
+         self.test_indices = (
+             None
+             if cfg.split.test_indices is None
+             else np.asarray(cfg.split.test_indices, dtype=int)
+         )
+         self.missing = int(cfg.algo.missing)
+
+         # State
+         self.is_fit_: bool = False
+         self.sim_mask_: np.ndarray | None = None
+         self.train_idx_: np.ndarray | None = None
+         self.test_idx_: np.ndarray | None = None
+         self.X_train_df_: pd.DataFrame | None = None
+         self.ground_truth012_: np.ndarray | None = None
+         self.X_imputed012_: np.ndarray | None = None
+         self.metrics_: Dict[str, int | float] = {}
+
+         # Ploidy heuristic for 0/1/2 scoring parity
+         uniq = np.unique(self.X012_[self.X012_ != -1])
+         self.is_haploid_ = np.array_equal(np.sort(uniq), np.array([0, 2]))
+
+         # Plotting (use config)
+         self.plot_format = cfg.plot.fmt
+         self.plot_fontsize = cfg.plot.fontsize
+         self.plot_despine = cfg.plot.despine
+         self.plot_dpi = cfg.plot.dpi
+         self.show_plots = cfg.plot.show
+
+         self.model_name = "ImputeRefAllele"
+         self.plotter_ = Plotting(
+             self.model_name,
+             prefix=self.prefix,
+             plot_format=self.plot_format,
+             plot_fontsize=self.plot_fontsize,
+             plot_dpi=self.plot_dpi,
+             title_fontsize=self.plot_fontsize,
+             despine=self.plot_despine,
+             show_plots=self.show_plots,
+             verbose=self.verbose,
+             debug=self.debug,
+             multiqc=True,
+             multiqc_section=f"PG-SUI: {self.model_name} Model Imputation",
+         )
+
+         # Output dirs
+         dirs = ["models", "plots", "metrics", "optimize", "parameters"]
+         self._create_model_directories(self.prefix, dirs)
+
+     def fit(self) -> "ImputeRefAllele":
+         """Create TRAIN/TEST split and build eval mask, with optional sim-missing.
+
+         This method prepares the imputer by splitting the data into training and testing sets and constructing an evaluation mask. If `cfg.sim.simulate_missing` is False (default), it masks all originally observed genotype entries on TEST rows. If `cfg.sim.simulate_missing` is True, it uses SimMissingTransformer to select a subset of observed cells as simulated-missing, then restricts that mask to TEST rows only. Evaluation is then performed only on these simulated-missing cells, mirroring the deep learning models.
+
+         Returns:
+             ImputeRefAllele: The fitted imputer instance.
+         """
+         # Train/test split indices
+         self.train_idx_, self.test_idx_ = self._make_train_test_split()
+         self.ground_truth012_ = self.X012_.copy()
+
+         # Use NaN for missing inside a DataFrame to leverage fillna
+         df_all = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
+         df_all = df_all.replace(self.missing, np.nan)
+         df_all = df_all.replace(-9, np.nan)  # Just in case
+
+         # Observed mask in the ORIGINAL data (before any simulated-missing)
+         obs_mask = df_all.notna().to_numpy()  # shape (n_samples, n_loci)
+
+         # TEST row selector
+         test_rows_mask = np.zeros(obs_mask.shape[0], dtype=bool)
+         if self.test_idx_ is not None and self.test_idx_.size > 0:
+             test_rows_mask[self.test_idx_] = True
+
+         # Decide how to build the sim mask: legacy vs simulated-missing
+         if getattr(self, "simulate_missing", False):
+             # Simulate missing on the full matrix; we only use the mask.
+             tr = SimMissingTransformer(
+                 genotype_data=self.genotype_data,
+                 tree_parser=self.tree_parser,
+                 prop_missing=self.sim_prop,
+                 strategy=self.sim_strategy,
+                 missing_val=-9,
+                 mask_missing=True,
+                 verbose=self.verbose,
+                 **(self.sim_kwargs or {}),
+             )
+             tr.fit(self.ground_truth012_.copy())
+             sim_mask_global = tr.sim_missing_mask_.astype(bool)
+
+             # Only consider cells that were originally observed
+             sim_mask_global = sim_mask_global & obs_mask
+
+             # Restrict evaluation to TEST rows only
+             sim_mask = sim_mask_global & test_rows_mask[:, None]
+             mode_desc = "simulated missing on TEST rows"
+         else:
+             # Legacy behavior: mask ALL originally observed TEST cells
+             sim_mask = obs_mask & test_rows_mask[:, None]
+             mode_desc = "all originally observed cells on TEST rows"
+
+         # Apply eval mask: set these cells to NaN in the eval DataFrame
+         df_sim = df_all.copy()
+         df_sim.values[sim_mask] = np.nan
+
+         # Store state
+         self.sim_mask_ = sim_mask
+         self.X_train_df_ = df_sim
+         self.is_fit_ = True
+
+         n_masked = int(sim_mask.sum())
+         self.logger.info(
+             f"Fit complete. Train rows: {self.train_idx_.size}, "
+             f"Test rows: {self.test_idx_.size}. "
+             f"Masked {n_masked} cells for evaluation ({mode_desc})."
+         )
+
+         # Persist config for reproducibility
+         params_fp = self.parameters_dir / "best_parameters.json"
+         best_params = self.cfg.to_dict()
+         with open(params_fp, "w") as f:
+             json.dump(best_params, f, indent=4)
+
+         return self
+
+     def transform(self) -> np.ndarray:
+         """Impute missing values with REF genotype (0) and evaluate on masked test cells.
+
+         This method performs the imputation by replacing all missing genotype values with the REF genotype (0). It evaluates the imputation performance on the masked test cells, producing classification reports and plots that mirror those generated by deep learning models. The final output is the fully imputed genotype matrix in IUPAC string format.
+
+         Returns:
+             np.ndarray: The fully imputed genotype matrix in IUPAC string format.
+
+         Raises:
+             NotFittedError: If the model has not been fitted yet.
+         """
+         if not self.is_fit_:
+             raise NotFittedError("Model is not fitted. Call fit() before transform().")
+         assert self.X_train_df_ is not None
+
+         # 1) Impute the evaluation-masked copy (compute metrics)
+         imputed_eval_df = self._impute_ref(df_in=self.X_train_df_)
+         X_imputed_eval = imputed_eval_df.to_numpy(dtype=np.int16)
+         self.X_imputed012_ = X_imputed_eval
+
+         # Evaluate parity with DL models
+         self._evaluate_and_report()
+
+         # 2) Impute the FULL dataset (only true missings)
+         df_missingonly = pd.DataFrame(self.ground_truth012_, dtype=np.float32)
+         df_missingonly = df_missingonly.replace(self.missing, np.nan)
+         df_missingonly = df_missingonly.replace(-9, np.nan)  # Just in case
+
+         imputed_full_df = self._impute_ref(df_in=df_missingonly)
+         X_imputed_full_012 = imputed_full_df.to_numpy(dtype=np.int16)
+
+         # Plot distributions (like DL .transform())
+         if self.ground_truth012_ is None:
+             msg = "ground_truth012_ is None; cannot plot distributions."
+             self.logger.error(msg)
+             raise NotFittedError(msg)
+
+         gt_decoded = self.encoder.decode_012(self.ground_truth012_)
+         imp_decoded = self.encoder.decode_012(X_imputed_full_012)
+         self.plotter_.plot_gt_distribution(gt_decoded, is_imputed=False)
+         self.plotter_.plot_gt_distribution(imp_decoded, is_imputed=True)
+
+         # Return IUPAC strings
+         return imp_decoded
+
+     def _impute_ref(self, df_in: pd.DataFrame) -> pd.DataFrame:
+         """Replace every NaN with the REF genotype code (0) across all loci.
+
+         This is the deterministic REF-allele imputation in 0/1/2 encoding. The method fills all NaN values in the input DataFrame with 0, representing the REF genotype. Because the fill value is constant, a single vectorized fillna over the whole frame is sufficient.
+
+         Args:
+             df_in (pd.DataFrame): Input DataFrame with NaNs representing missing genotypes.
+
+         Returns:
+             pd.DataFrame: DataFrame with NaNs replaced by 0 (REF genotype).
+         """
+         df = df_in.copy()
+         # Fill all NaNs with 0 (homozygous REF); constant fill, so vectorized fillna is fine
+         df = df.fillna(0)
+         return df.astype(np.int16)
+
+     def _evaluate_and_report(self) -> None:
+         """Evaluate imputed vs. ground truth on masked test cells; produce reports and plots.
+
+         Requires that fit() and transform() have been called. This method evaluates the imputed genotypes against the ground truth for the masked test cells, generating classification reports and confusion matrices for both 0/1/2 zygosity and 10-class IUPAC codes. It logs the results and saves the reports and plots to the designated output directories.
+
+         Raises:
+             NotFittedError: If fit() and transform() have not been called.
+         """
+         assert (
+             self.sim_mask_ is not None
+             and self.ground_truth012_ is not None
+             and self.X_imputed012_ is not None
+         )
+         y_true_012 = self.ground_truth012_[self.sim_mask_]
+         y_pred_012 = self.X_imputed012_[self.sim_mask_]
+
+         if y_true_012.size == 0:
+             self.logger.info("No masked test cells; skipping evaluation.")
+             return
+
+         # 0/1/2 report (REF/HET/ALT), with haploid folding 2->1 if needed
+         self._evaluate_012_and_plot(y_true_012.copy(), y_pred_012.copy())
+
+         # 10-class IUPAC report from decoded strings (parity with DL)
+         X_pred_eval = self.ground_truth012_.copy()
+         X_pred_eval[self.sim_mask_] = self.X_imputed012_[self.sim_mask_]
+
+         y_true_dec = self.encoder.decode_012(self.ground_truth012_)
+         y_pred_dec = self.encoder.decode_012(X_pred_eval)
+
+         encodings_dict = {
+             "A": 0,
+             "C": 1,
+             "G": 2,
+             "T": 3,
+             "W": 4,
+             "R": 5,
+             "M": 6,
+             "K": 7,
+             "Y": 8,
+             "S": 9,
+             "N": -1,
+         }
+         y_true_int = self.encoder.convert_int_iupac(
+             y_true_dec, encodings_dict=encodings_dict
+         )
+         y_pred_int = self.encoder.convert_int_iupac(
+             y_pred_dec, encodings_dict=encodings_dict
+         )
+         y_true_10 = y_true_int[self.sim_mask_]
+         y_pred_10 = y_pred_int[self.sim_mask_]
+         self._evaluate_iupac10_and_plot(y_true_10, y_pred_10)
+
+     def _evaluate_012_and_plot(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
+         """0/1/2 zygosity report & confusion matrix.
+
+         This method generates a classification report and confusion matrix for genotypes encoded as 0 (REF), 1 (HET), and 2 (ALT). If the data is determined to be haploid (only 0 and 2 present), it folds the ALT genotype (2) into class 1 for evaluation purposes. The method computes various performance metrics, logs the classification report, and creates visualizations of the results.
+
+         Args:
+             y_true (np.ndarray): True genotypes (0/1/2) for masked test cells.
+             y_pred (np.ndarray): Predicted genotypes (0/1/2) for masked test cells.
+         """
+         labels = [0, 1, 2]
+         report_names = ["REF", "HET", "ALT"]
+
+         # Haploid parity: fold 2 -> 1
+         if self.is_haploid_:
+             y_true[y_true == 2] = 1
+             y_pred[y_pred == 2] = 1
+             labels = [0, 1]
+             report_names = ["REF", "ALT"]
+
+         metrics = {
+             "n_masked_test": int(y_true.size),
+             "accuracy": accuracy_score(y_true, y_pred),
+             "f1": f1_score(
+                 y_true, y_pred, average="weighted", labels=labels, zero_division=0
+             ),
+             "precision": precision_score(
+                 y_true, y_pred, average="weighted", labels=labels, zero_division=0
+             ),
+             "recall": recall_score(
+                 y_true, y_pred, average="weighted", labels=labels, zero_division=0
+             ),
+         }
+         self.metrics_.update({f"zygosity_{k}": v for k, v in metrics.items()})
+
+         report: str | dict = classification_report(
+             y_true,
+             y_pred,
+             labels=labels,
+             target_names=report_names,
+             zero_division=0,
+             output_dict=True,
+         )
+
+         if not isinstance(report, dict):
+             msg = "classification_report did not return a dict as expected."
+             self.logger.error(msg)
+             raise TypeError(msg)
+
+         report_subset = {}
+         for k, v in report.items():
+             tmp = {}
+             if isinstance(v, dict) and "support" in v:
+                 for k2, v2 in v.items():
+                     if k2 != "support":
+                         tmp[k2] = v2
+             if tmp:
+                 report_subset[k] = tmp
+
+         if report_subset:
+             pm = PrettyMetrics(
+                 report_subset,
+                 precision=3,
+                 title=f"{self.model_name} Zygosity Report",
+             )
+             pm.render()
+
+         viz = ClassificationReportVisualizer(reset_kwargs=self.plotter_.param_dict)
+
+         if not isinstance(report, dict):
+             msg = "classification_report did not return a dict as expected."
+             self.logger.error(msg)
+             raise TypeError(msg)
+
+         plots = viz.plot_all(
+             report,
+             title_prefix=f"{self.model_name} Zygosity Report",
+             show=getattr(self, "show_plots", False),
+             heatmap_classes_only=True,
+         )
+
+         # Reset the style from Optuna's plotting.
+         plt.rcParams.update(self.plotter_.param_dict)
+
+         for name, fig in plots.items():
+             fout = self.plots_dir / f"zygosity_report_{name}.{self.plot_format}"
+             if hasattr(fig, "savefig") and isinstance(fig, Figure):
+                 fig.savefig(fout, dpi=300, facecolor="#111122")
+                 plt.close(fig)
+             elif isinstance(fig, PlotlyFigure):
+                 fig.write_html(file=fout.with_suffix(".html"))
+
+         viz._reset_mpl_style()
+
+         self._save_report(report, suffix="zygosity")
+
+         # Confusion matrix
+         self.plotter_.plot_confusion_matrix(
+             y_true, y_pred, label_names=report_names, prefix="zygosity"
+         )
+
+     def _evaluate_iupac10_and_plot(
+         self, y_true: np.ndarray, y_pred: np.ndarray
+     ) -> None:
+         """10-class IUPAC report & confusion matrix.
+
+         This method generates a classification report and confusion matrix for genotypes encoded using the 10 IUPAC codes (0-9). The IUPAC codes represent various nucleotide combinations, including ambiguous bases.
+
+         Args:
+             y_true (np.ndarray): True genotypes (0-9) for masked test cells.
+             y_pred (np.ndarray): Predicted genotypes (0-9) for masked test cells.
+         """
+         labels_idx = list(range(10))
+         labels_names = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
+
+         metrics = {
+             "accuracy": accuracy_score(y_true, y_pred),
+             "f1": f1_score(
+                 y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+             ),
+             "precision": precision_score(
+                 y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+             ),
+             "recall": recall_score(
+                 y_true, y_pred, average="weighted", labels=labels_idx, zero_division=0
+             ),
+         }
+         self.metrics_.update({f"iupac_{k}": v for k, v in metrics.items()})
+
+         report = classification_report(
+             y_true,
+             y_pred,
+             labels=labels_idx,
+             target_names=labels_names,
+             zero_division=0,
+             output_dict=True,
+         )
+
+         if not isinstance(report, dict):
+             msg = "classification_report did not return a dict as expected."
+             self.logger.error(msg)
+             raise TypeError(msg)
+
+         report_subset = {}
+         for k, v in report.items():
+             tmp = {}
+             if isinstance(v, dict) and "support" in v:
+                 for k2, v2 in v.items():
+                     if k2 != "support":
+                         tmp[k2] = v2
+             if tmp:
+                 report_subset[k] = tmp
+
+         if report_subset:
+             pm = PrettyMetrics(
+                 report_subset,
+                 precision=3,
+                 title=f"{self.model_name} IUPAC 10-Class Report",
+             )
+             pm.render()
+
+         self._save_report(report, suffix="iupac")
+
+         # Confusion matrix
+         self.plotter_.plot_confusion_matrix(
+             y_true, y_pred, label_names=labels_names, prefix="iupac"
+         )
+
+     def _make_train_test_split(self) -> Tuple[np.ndarray, np.ndarray]:
+         """Create train/test split indices.
+
+         This method generates training and testing indices for the dataset. If specific test indices are provided, it uses those; otherwise, it randomly selects a proportion of samples as the test set based on the specified test size. The method ensures that the selected test indices are within valid bounds and that there is no overlap between training and testing sets.
+
+         Returns:
+             Tuple[np.ndarray, np.ndarray]: Arrays of train and test indices.
+
+         Raises:
+             IndexError: If provided test_indices are out of bounds.
+         """
+         n = self.X012_.shape[0]
+         all_idx = np.arange(n, dtype=int)
+
+         if self.test_indices is not None:
+             test_idx = np.unique(self.test_indices)
+
+             if np.any((test_idx < 0) | (test_idx >= n)):
+                 msg = "Some test_indices are out of bounds."
+                 self.logger.error(msg)
+                 raise IndexError(msg)
+
+             train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
+             return train_idx, test_idx
+
+         k = int(round(self.test_size * n))
+
+         test_idx = (
+             self.rng.choice(n, size=k, replace=False)
+             if k > 0
+             else np.array([], dtype=int)
+         )
+
+         train_idx = np.setdiff1d(all_idx, test_idx, assume_unique=False)
+         return train_idx, test_idx
+
+     def _save_report(self, report_dict: Dict[str, float], suffix: str) -> None:
+         """Save classification report dictionary as a JSON file.
+
+         This method saves the provided classification report dictionary to a JSON file in the metrics directory. The filename includes a suffix to distinguish between different types of reports (e.g., 'zygosity' or 'iupac').
+
+         Args:
+             report_dict (Dict[str, float]): The classification report dictionary to save.
+             suffix (str): Suffix to append to the filename (e.g., 'zygosity' or 'iupac').
+
+         Raises:
+             NotFittedError: If fit() and transform() have not been called.
+         """
+         if not self.is_fit_ or self.X_imputed012_ is None:
+             raise NotFittedError("No report to save. Ensure fit() and transform() ran.")
+
+         out_fp = self.metrics_dir / f"classification_report_{suffix}.json"
+         with open(out_fp, "w") as f:
+             json.dump(report_dict, f, indent=4)
+         self.logger.info(f"{self.model_name} {suffix} report saved to {out_fp}.")
+
+     def _create_model_directories(self, prefix: str, outdirs: List[str]) -> None:
+         """Create the directory structure for storing model outputs.
+
+         This method sets up a standardized folder hierarchy for saving models, plots, metrics, parameters, and optimization results, organized under a main directory named after the provided prefix.
+
+         Args:
+             prefix (str): The prefix for the main output directory.
+             outdirs (List[str]): A list of subdirectory names to create within the main directory.
+
+         Raises:
+             Exception: If any of the directories cannot be created.
+         """
+         formatted_output_dir = Path(f"{prefix}_output")
+         base_dir = formatted_output_dir / "Deterministic"
+
+         for d in outdirs:
+             subdir = base_dir / d / self.model_name
+             setattr(self, f"{d}_dir", subdir)
+             try:
+                 getattr(self, f"{d}_dir").mkdir(parents=True, exist_ok=True)
+             except Exception as e:
+                 msg = f"Failed to create directory {getattr(self, f'{d}_dir')}: {e}"
+                 self.logger.error(msg)
+                 raise Exception(msg)
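
The config-normalization helper at the top of this module can be exercised on its own. The sketch below is illustrative only: it assumes RefAlleleConfig exposes 'split.test_size' and 'algo.missing' fields (the same dot-keys used as examples in the __init__ docstring) and that the package defaults are valid.

    from pgsui.data_processing.containers import RefAlleleConfig
    from pgsui.impute.deterministic.imputers.ref_allele import ensure_refallele_config

    # None -> defaults; an existing RefAlleleConfig passes through unchanged;
    # a str is treated as a YAML path and loaded via load_yaml_to_dataclass().
    cfg_default = ensure_refallele_config(None)

    # A nested dict is flattened to dot-keys ("split.test_size", "algo.missing")
    # and layered onto the defaults (or onto a named preset, if a top-level
    # "preset" key is given) via apply_dot_overrides().
    cfg = ensure_refallele_config({"split": {"test_size": 0.25}, "algo": {"missing": -1}})
    assert isinstance(cfg, RefAlleleConfig)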
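
The evaluation mask built in fit() is the key detail behind the reported metrics: only cells that were observed in the original matrix, and only on TEST rows, are hidden and scored. A minimal NumPy sketch of that masking rule, using toy values rather than package data:

    import numpy as np

    X = np.array([[0, -1, 2],
                  [1,  0, -1],
                  [2,  2,  0]])                     # 0/1/2 genotypes, -1 = missing
    obs_mask = X != -1                              # originally observed cells
    test_rows_mask = np.array([False, True, True])  # rows chosen as TEST

    # Legacy mode: every observed cell on a TEST row is masked and later scored.
    eval_mask = obs_mask & test_rows_mask[:, None]

    # With simulate_missing=True, the same intersection is further restricted to
    # the cells flagged by SimMissingTransformer.sim_missing_mask_.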
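
Finally, a hedged end-to-end sketch of how the new imputer appears intended to be driven, based only on the constructor and the fit()/transform() methods shown above. Loading the GenotypeData object is assumed to happen elsewhere via SNPio (for example from the bundled phylogen_subset14K.vcf.gz), and the 'io.prefix' override key is inferred from the cfg.io.prefix access in __init__:

    from pgsui.impute.deterministic.imputers.ref_allele import ImputeRefAllele

    # 'gd' is assumed to be an snpio GenotypeData instance loaded beforehand.
    imputer = ImputeRefAllele(
        gd,
        config={"split": {"test_size": 0.25}},      # nested dict, YAML path, or RefAlleleConfig
        overrides={"io.prefix": "refallele_demo"},  # flat dot-keys, highest precedence
        simulate_missing=True,
        sim_strategy="random",
        sim_prop=0.2,
    )
    imputer.fit()                  # builds the TRAIN/TEST split and the evaluation mask
    X_iupac = imputer.transform()  # returns the imputed matrix as IUPAC strings

Per _create_model_directories(), reports then land under refallele_demo_output/Deterministic/metrics/ImputeRefAllele/ and plots under the matching plots/ subdirectory.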