pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
@@ -0,0 +1,1228 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import optuna
9
+ import torch
10
+ from sklearn.exceptions import NotFittedError
11
+ from sklearn.model_selection import train_test_split
12
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
13
+ from snpio.utils.logging import LoggerManager
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+
16
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
17
+ from pgsui.data_processing.containers import VAEConfig
18
+ from pgsui.data_processing.transformers import SimMissingTransformer
19
+ from pgsui.impute.unsupervised.base import BaseNNImputer
20
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
21
+ from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
22
+ from pgsui.impute.unsupervised.models.vae_model import VAEModel
23
+ from pgsui.utils.logging_utils import configure_logger
24
+ from pgsui.utils.pretty_metrics import PrettyMetrics
25
+
26
+ if TYPE_CHECKING:
27
+ from snpio import TreeParser
28
+ from snpio.read_input.genotype_data import GenotypeData
29
+
30
+
31
+ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
32
+ """Normalize VAEConfig input from various sources.
33
+
34
+ Args:
35
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
36
+
37
+ Returns:
38
+ VAEConfig: Normalized configuration dataclass.
39
+ """
40
+ if config is None:
41
+ return VAEConfig()
42
+ if isinstance(config, VAEConfig):
43
+ return config
44
+ if isinstance(config, str):
45
+ return load_yaml_to_dataclass(config, VAEConfig)
46
+ if isinstance(config, dict):
47
+ base = VAEConfig()
48
+ # Respect top-level preset
49
+ preset = config.pop("preset", None)
50
+ if preset:
51
+ base = VAEConfig.from_preset(preset)
52
+ # Flatten + apply
53
+ flat: Dict[str, object] = {}
54
+
55
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
56
+ for k, v in d.items():
57
+ kk = f"{prefix}.{k}" if prefix else k
58
+ if isinstance(v, dict):
59
+ _flatten(kk, v, out)
60
+ else:
61
+ out[kk] = v
62
+ return out
63
+
64
+ flat = _flatten("", config, {})
65
+ return apply_dot_overrides(base, flat)
66
+ raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
67
+
68
+
69
+ class ImputeVAE(BaseNNImputer):
70
+ """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
71
+
72
+ This imputer implements a VAE with a Gaussian latent space and a categorical (multinomial) reconstruction of the 0/1/2 genotype classes. It handles missing data by inferring the latent distribution and generating plausible predictions. The model is trained using a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over time. The imputer supports both haploid and diploid genotype data.
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ genotype_data: "GenotypeData",
78
+ *,
79
+ tree_parser: Optional["TreeParser"] = None,
80
+ config: Optional[Union["VAEConfig", dict, str]] = None,
81
+ overrides: dict | None = None,
82
+ simulate_missing: bool | None = None,
83
+ sim_strategy: (
84
+ Literal[
85
+ "random",
86
+ "random_weighted",
87
+ "random_weighted_inv",
88
+ "nonrandom",
89
+ "nonrandom_weighted",
90
+ ]
91
+ | None
92
+ ) = None,
93
+ sim_prop: float | None = None,
94
+ sim_kwargs: dict | None = None,
95
+ ):
96
+ """Initialize the VAE imputer with a unified config interface.
97
+
98
+ This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
99
+
100
+ Args:
101
+ genotype_data (GenotypeData): Backing genotype data object.
102
+ tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
103
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
104
+ overrides (dict | None): Optional dot-key overrides with highest precedence.
105
+ simulate_missing (bool | None): Whether to simulate missing data during training.
106
+ sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
107
+ sim_prop (float | None): Proportion of data to simulate as missing if simulating.
108
+ sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
109
+ """
110
+ self.model_name = "ImputeVAE"
111
+ self.genotype_data = genotype_data
112
+ self.tree_parser = tree_parser
113
+
114
+ # Normalize configuration and apply top-precedence overrides
115
+ cfg = ensure_vae_config(config)
116
+ if overrides:
117
+ cfg = apply_dot_overrides(cfg, overrides)
118
+ self.cfg = cfg
119
+
120
+ # Logger (align with AE/NLPCA)
121
+ logman = LoggerManager(
122
+ __name__,
123
+ prefix=self.cfg.io.prefix,
124
+ debug=self.cfg.io.debug,
125
+ verbose=self.cfg.io.verbose,
126
+ )
127
+ self.logger = configure_logger(
128
+ logman.get_logger(),
129
+ verbose=self.cfg.io.verbose,
130
+ debug=self.cfg.io.debug,
131
+ )
132
+
133
+ # BaseNNImputer bootstraps device/dirs/log formatting
134
+ super().__init__(
135
+ model_name=self.model_name,
136
+ genotype_data=self.genotype_data,
137
+ prefix=self.cfg.io.prefix,
138
+ device=self.cfg.train.device,
139
+ verbose=self.cfg.io.verbose,
140
+ debug=self.cfg.io.debug,
141
+ )
142
+
143
+ # Model hook & encoder
144
+ self.Model = VAEModel
145
+ self.pgenc = GenotypeEncoder(genotype_data)
146
+
147
+ # IO/global
148
+ self.seed = self.cfg.io.seed
149
+ self.n_jobs = self.cfg.io.n_jobs
150
+ self.prefix = self.cfg.io.prefix
151
+ self.scoring_averaging = self.cfg.io.scoring_averaging
152
+ self.verbose = self.cfg.io.verbose
153
+ self.debug = self.cfg.io.debug
154
+ self.rng = np.random.default_rng(self.seed)
155
+
156
+ # Simulated-missing controls (config defaults + ctor overrides)
157
+ sim_cfg = getattr(self.cfg, "sim", None)
158
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
159
+ if sim_kwargs:
160
+ sim_cfg_kwargs.update(sim_kwargs)
161
+ if sim_cfg is None:
162
+ default_sim_flag = bool(simulate_missing)
163
+ default_strategy = "random"
164
+ default_prop = 0.10
165
+ else:
166
+ default_sim_flag = sim_cfg.simulate_missing
167
+ default_strategy = sim_cfg.sim_strategy
168
+ default_prop = sim_cfg.sim_prop
169
+ self.simulate_missing = (
170
+ default_sim_flag if simulate_missing is None else bool(simulate_missing)
171
+ )
172
+ self.sim_strategy = sim_strategy or default_strategy
173
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
174
+ self.sim_kwargs = sim_cfg_kwargs
175
+
176
+ # Model hyperparams (AE-parity)
177
+ self.latent_dim = self.cfg.model.latent_dim
178
+ self.dropout_rate = self.cfg.model.dropout_rate
179
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
180
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
181
+ self.layer_schedule = self.cfg.model.layer_schedule
182
+ self.activation = self.cfg.model.hidden_activation
183
+ self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
184
+
185
+ # VAE-only KL controls
186
+ self.kl_beta_final = self.cfg.vae.kl_beta
187
+ self.kl_warmup = self.cfg.vae.kl_warmup
188
+ self.kl_ramp = self.cfg.vae.kl_ramp
189
+
190
+ # Train hyperparams (AE-parity)
191
+ self.batch_size = self.cfg.train.batch_size
192
+ self.learning_rate = self.cfg.train.learning_rate
193
+ self.l1_penalty: float = self.cfg.train.l1_penalty
194
+ self.early_stop_gen = self.cfg.train.early_stop_gen
195
+ self.min_epochs = self.cfg.train.min_epochs
196
+ self.epochs = self.cfg.train.max_epochs
197
+ self.validation_split = self.cfg.train.validation_split
198
+ self.beta = self.cfg.train.weights_beta
199
+ self.max_ratio = self.cfg.train.weights_max_ratio
200
+
201
+ # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
202
+ self.tune = self.cfg.tune.enabled
203
+ self.tune_fast = self.cfg.tune.fast
204
+ self.tune_batch_size = self.cfg.tune.batch_size
205
+ self.tune_epochs = self.cfg.tune.epochs
206
+ self.tune_eval_interval = self.cfg.tune.eval_interval
207
+ self.tune_metric: Literal[
208
+ "pr_macro",
209
+ "f1",
210
+ "accuracy",
211
+ "average_precision",
212
+ "precision",
213
+ "recall",
214
+ "roc_auc",
215
+ ] = self.cfg.tune.metric
216
+ self.n_trials = self.cfg.tune.n_trials
217
+ self.tune_save_db = self.cfg.tune.save_db
218
+ self.tune_resume = self.cfg.tune.resume
219
+ self.tune_max_samples = self.cfg.tune.max_samples
220
+ self.tune_max_loci = self.cfg.tune.max_loci
221
+ self.tune_patience = self.cfg.tune.patience
222
+
223
+ # Plotting (AE-parity)
224
+ self.plot_format = self.cfg.plot.fmt
225
+ self.plot_dpi = self.cfg.plot.dpi
226
+ self.plot_fontsize = self.cfg.plot.fontsize
227
+ self.title_fontsize = self.cfg.plot.fontsize
228
+ self.despine = self.cfg.plot.despine
229
+ self.show_plots = self.cfg.plot.show
230
+
231
+ # Derived at fit-time
232
+ self.is_haploid: bool = False
233
+ self.num_classes_: int = 3 # diploid default
234
+ self.model_params: Dict[str, Any] = {}
235
+ self.sim_mask_global_: np.ndarray | None = None
236
+ self.sim_mask_train_: np.ndarray | None = None
237
+ self.sim_mask_test_: np.ndarray | None = None
238
+
239
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
240
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
241
+ self.logger.error(msg)
242
+ raise ValueError(msg)
243
+
244
+ # -------------------- Fit -------------------- #
245
+ def fit(self) -> "ImputeVAE":
246
+ """Fit the VAE on 0/1/2 encoded genotypes (missing -> -1).
247
+
248
+ This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. Missing positions are encoded as -1 for loss masking (any simulated-missing loci are temporarily tagged with -9 by ``SimMissingTransformer`` before being re-encoded as -1). It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
249
+
250
+ Returns:
251
+ ImputeVAE: Fitted instance.
252
+
253
+ Raises:
254
+ RuntimeError: If training fails to produce a model.
255
+ """
256
+ self.logger.info(f"Fitting {self.model_name} model...")
257
+
258
+ # Data prep aligns with AE/NLPCA
259
+ X012 = self._get_float_genotypes(copy=True)
260
+ GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
261
+ self.ground_truth_ = GT_full.astype(np.int64, copy=False)
262
+
263
+ self.sim_mask_global_ = None
264
+ cache_key = self._sim_mask_cache_key()
265
+ if self.simulate_missing:
266
+ cached_mask = (
267
+ None if cache_key is None else self._sim_mask_cache.get(cache_key)
268
+ )
269
+ if cached_mask is not None:
270
+ self.sim_mask_global_ = cached_mask.copy()
271
+ else:
272
+ tr = SimMissingTransformer(
273
+ genotype_data=self.genotype_data,
274
+ tree_parser=self.tree_parser,
275
+ prop_missing=self.sim_prop,
276
+ strategy=self.sim_strategy,
277
+ missing_val=-9,
278
+ mask_missing=True,
279
+ verbose=self.verbose,
280
+ **self.sim_kwargs,
281
+ )
282
+ tr.fit(X012.copy())
283
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
284
+ if cache_key is not None:
285
+ self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
286
+
287
+ X_for_model = self.ground_truth_.copy()
288
+ X_for_model[self.sim_mask_global_] = -1
289
+ else:
290
+ X_for_model = self.ground_truth_.copy()
291
+
292
+ # Ploidy/classes
293
+ self.is_haploid = bool(
294
+ np.all(
295
+ np.isin(
296
+ self.genotype_data.snp_data,
297
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
298
+ )
299
+ )
300
+ )
301
+ self.ploidy = 1 if self.is_haploid else 2
302
+ self.num_classes_ = 2 if self.is_haploid else 3
303
+ self.logger.info(
304
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
305
+ f"using {self.num_classes_} classes."
306
+ )
307
+
308
+ if self.is_haploid:
309
+ self.ground_truth_[self.ground_truth_ == 2] = 1
310
+ X_for_model[X_for_model == 2] = 1
311
+
312
+ n_samples, self.num_features_ = X_for_model.shape
313
+
314
+ # Model params (decoder outputs L*K logits)
315
+ self.model_params = {
316
+ "n_features": self.num_features_,
317
+ "num_classes": self.num_classes_,
318
+ "latent_dim": self.latent_dim,
319
+ "dropout_rate": self.dropout_rate,
320
+ "activation": self.activation,
321
+ }
322
+
323
+ # Train/Val split
324
+ indices = np.arange(n_samples)
325
+ train_idx, val_idx = train_test_split(
326
+ indices, test_size=self.validation_split, random_state=self.seed
327
+ )
328
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
329
+ self.X_train_ = X_for_model[train_idx]
330
+ self.X_val_ = X_for_model[val_idx]
331
+ self.GT_train_full_ = self.ground_truth_[train_idx]
332
+ self.GT_test_full_ = self.ground_truth_[val_idx]
333
+
334
+ if self.sim_mask_global_ is not None:
335
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
336
+ self.sim_mask_test_ = self.sim_mask_global_[val_idx]
337
+ else:
338
+ self.sim_mask_train_ = None
339
+ self.sim_mask_test_ = None
340
+
341
+ # Plotters/scorers (shared utilities)
342
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
343
+
344
+ # Optional tuning
345
+ if self.tune:
346
+ self.tune_hyperparameters()
347
+
348
+ # Best params (tuned or default)
349
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
350
+
351
+ # Class weights (device-aware)
352
+ self.class_weights_ = self._normalize_class_weights(
353
+ self._class_weights_from_zygosity(self.X_train_)
354
+ )
355
+
356
+ # DataLoader
357
+ train_loader = self._get_data_loader(self.X_train_)
358
+
359
+ # Build & train
360
+ model = self.build_model(self.Model, self.best_params_)
361
+ model.apply(self.initialize_weights)
362
+
363
+ loss, trained_model, history = self._train_and_validate_model(
364
+ model=model,
365
+ loader=train_loader,
366
+ lr=self.learning_rate,
367
+ l1_penalty=self.l1_penalty,
368
+ return_history=True,
369
+ class_weights=self.class_weights_,
370
+ X_val=self.X_val_,
371
+ params=self.best_params_,
372
+ prune_metric=self.tune_metric,
373
+ prune_warmup_epochs=5,
374
+ eval_interval=1,
375
+ eval_requires_latents=False, # no latent refinement for eval
376
+ eval_latent_steps=0,
377
+ eval_latent_lr=0.0,
378
+ eval_latent_weight_decay=0.0,
379
+ )
380
+
381
+ if trained_model is None:
382
+ msg = "VAE training failed; no model was returned."
383
+ self.logger.error(msg)
384
+ raise RuntimeError(msg)
385
+
386
+ torch.save(
387
+ trained_model.state_dict(),
388
+ self.models_dir / f"final_model_{self.model_name}.pt",
389
+ )
390
+
391
+ hist: dict = {"Train": history}
392
+ self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
393
+ self.is_fit_ = True
394
+
395
+ # Evaluate (AE-parity reporting)
396
+ eval_mask = (
397
+ self.sim_mask_test_
398
+ if (self.simulate_missing and self.sim_mask_test_ is not None)
399
+ else None
400
+ )
401
+ self._evaluate_model(
402
+ self.X_val_,
403
+ self.model_,
404
+ self.best_params_,
405
+ eval_mask_override=eval_mask,
406
+ )
407
+
408
+ self.plotter_.plot_history(self.history_)
409
+ self._save_best_params(self.best_params_)
410
+ return self
411
+
412
+ def transform(self) -> np.ndarray:
413
+ """Impute missing genotypes and return IUPAC strings.
414
+
415
+ This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
416
+
417
+ Returns:
418
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
419
+
420
+ Raises:
421
+ NotFittedError: If called before fit().
422
+ """
423
+ if not getattr(self, "is_fit_", False):
424
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
425
+
426
+ self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
427
+ X_to_impute = self.ground_truth_.copy()
428
+
429
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
430
+
431
+ # Fill only missing
432
+ missing_mask = X_to_impute == -1
433
+ imputed_array = X_to_impute.copy()
434
+ imputed_array[missing_mask] = pred_labels[missing_mask]
435
+
436
+ # Decode to IUPAC & optionally plot
437
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
438
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
439
+
440
+ plt.rcParams.update(self.plotter_.param_dict)
441
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
442
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
443
+
444
+ return imputed_genotypes
445
+
446
+ # ---------- plumbing identical to AE, naming aligned ---------- #
447
+
448
+ def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
449
+ """Create DataLoader over indices + integer targets (-1 for missing).
450
+
451
+ This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
452
+
453
+ Args:
454
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
455
+
456
+ Returns:
457
+ torch.utils.data.DataLoader: Shuffled DataLoader.
458
+ """
459
+ y_tensor = torch.from_numpy(y).long()
460
+ indices = torch.arange(len(y), dtype=torch.long)
461
+ dataset = torch.utils.data.TensorDataset(indices, y_tensor)
462
+ pin_memory = self.device.type == "cuda"
463
+ return torch.utils.data.DataLoader(
464
+ dataset,
465
+ batch_size=self.batch_size,
466
+ shuffle=True,
467
+ pin_memory=pin_memory,
468
+ )
469
+
470
+ def _train_and_validate_model(
471
+ self,
472
+ model: torch.nn.Module,
473
+ loader: torch.utils.data.DataLoader,
474
+ lr: float,
475
+ l1_penalty: float,
476
+ trial: optuna.Trial | None = None,
477
+ return_history: bool = False,
478
+ class_weights: torch.Tensor | None = None,
479
+ *,
480
+ X_val: np.ndarray | None = None,
481
+ params: dict | None = None,
482
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
483
+ prune_warmup_epochs: int = 3,
484
+ eval_interval: int = 1,
485
+ eval_requires_latents: bool = False, # VAE: no latent eval refinement
486
+ eval_latent_steps: int = 0,
487
+ eval_latent_lr: float = 0.0,
488
+ eval_latent_weight_decay: float = 0.0,
489
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
490
+ """Wrap the VAE training loop with β-anneal & Optuna pruning.
491
+
492
+ This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
493
+
494
+ Args:
495
+ model (torch.nn.Module): VAE model.
496
+ loader (torch.utils.data.DataLoader): Training data loader.
497
+ lr (float): Learning rate.
498
+ l1_penalty (float): L1 regularization coefficient.
499
+ trial (optuna.Trial | None): Optuna trial for pruning.
500
+ return_history (bool): If True, return training history.
501
+ class_weights (torch.Tensor | None): CE class weights on device.
502
+ X_val (np.ndarray | None): Validation data for pruning eval.
503
+ params (dict | None): Current hyperparameters (for logging).
504
+ prune_metric (str | None): Metric for pruning decisions.
505
+ prune_warmup_epochs (int): Epochs to skip before pruning.
506
+ eval_interval (int): Epochs between validation evaluations.
507
+ eval_requires_latents (bool): If True, refine latents during eval.
508
+ eval_latent_steps (int): Latent refinement steps if needed.
509
+ eval_latent_lr (float): Latent refinement learning rate.
510
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
511
+
512
+ Returns:
513
+ Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
514
+ """
515
+ if class_weights is None:
516
+ msg = "Must provide class_weights."
517
+ self.logger.error(msg)
518
+ raise TypeError(msg)
519
+
520
+ max_epochs = (
521
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
522
+ )
523
+
524
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
525
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
526
+
527
+ best_loss, best_model, hist = self._execute_training_loop(
528
+ loader=loader,
529
+ optimizer=optimizer,
530
+ scheduler=scheduler,
531
+ model=model,
532
+ l1_penalty=l1_penalty,
533
+ trial=trial,
534
+ return_history=return_history,
535
+ class_weights=class_weights,
536
+ X_val=X_val,
537
+ params=params,
538
+ prune_metric=prune_metric,
539
+ prune_warmup_epochs=prune_warmup_epochs,
540
+ eval_interval=eval_interval,
541
+ eval_requires_latents=eval_requires_latents,
542
+ eval_latent_steps=eval_latent_steps,
543
+ eval_latent_lr=eval_latent_lr,
544
+ eval_latent_weight_decay=eval_latent_weight_decay,
545
+ )
546
+ if return_history:
547
+ return best_loss, best_model, hist
548
+
549
+ return best_loss, best_model, None
550
+
551
+ def _execute_training_loop(
552
+ self,
553
+ loader: torch.utils.data.DataLoader,
554
+ optimizer: torch.optim.Optimizer,
555
+ scheduler: CosineAnnealingLR,
556
+ model: torch.nn.Module,
557
+ l1_penalty: float,
558
+ trial: optuna.Trial | None,
559
+ return_history: bool,
560
+ class_weights: torch.Tensor,
561
+ *,
562
+ X_val: np.ndarray | None = None,
563
+ params: dict | None = None,
564
+ prune_metric: str | None = None,
565
+ prune_warmup_epochs: int = 3,
566
+ eval_interval: int = 1,
567
+ eval_requires_latents: bool = False,
568
+ eval_latent_steps: int = 0,
569
+ eval_latent_lr: float = 0.0,
570
+ eval_latent_weight_decay: float = 0.0,
571
+ ) -> Tuple[float, torch.nn.Module, list]:
572
+ """Train VAE with stable focal CE + KL(β) anneal and numeric guards.
573
+
574
+ This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
575
+
576
+ Args:
577
+ loader (torch.utils.data.DataLoader): Training data loader.
578
+ optimizer (torch.optim.Optimizer): Optimizer.
579
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
580
+ model (torch.nn.Module): VAE model.
581
+ l1_penalty (float): L1 regularization coefficient.
582
+ trial (optuna.Trial | None): Optuna trial for pruning.
583
+ return_history (bool): If True, return training history.
584
+ class_weights (torch.Tensor): CE class weights on device.
585
+ X_val (np.ndarray | None): Validation data for pruning eval.
586
+ params (dict | None): Current hyperparameters (for logging).
587
+ prune_metric (str | None): Metric for pruning decisions.
588
+ prune_warmup_epochs (int): Epochs to skip before pruning.
589
+ eval_interval (int): Epochs between validation evaluations.
590
+ eval_requires_latents (bool): If True, refine latents during eval.
591
+ eval_latent_steps (int): Latent refinement steps if needed.
592
+ eval_latent_lr (float): Latent refinement learning rate.
593
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
594
+
595
+ Returns:
596
+ Tuple[float, torch.nn.Module, list]: Best loss, best model, and training history.
597
+ """
598
+ best_model = None
599
+ history: list[float] = []
600
+
601
+ early_stopping = EarlyStopping(
602
+ patience=self.early_stop_gen,
603
+ min_epochs=self.min_epochs,
604
+ verbose=self.verbose,
605
+ prefix=self.prefix,
606
+ debug=self.debug,
607
+ )
608
+
609
+ # ---- scalarize schedule endpoints up front ----
610
+ kl = self.kl_beta_final
611
+ if isinstance(kl, (list, tuple)):
612
+ if len(kl) == 0:
613
+ msg = "kl_beta_final list is empty."
614
+ self.logger.error(msg)
615
+ raise ValueError(msg)
616
+ kl = kl[0]
617
+ beta_final = float(kl)
618
+
619
+ gamma_val = self.gamma
620
+ if isinstance(gamma_val, (list, tuple)):
621
+ if len(gamma_val) == 0:
622
+ raise ValueError("gamma list is empty.")
623
+ gamma_val = gamma_val[0]
624
+ gamma_final = float(gamma_val)
625
+
626
+ gamma_warm, gamma_ramp = 50, 100
627
+ beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
628
+
629
+ # Optional LR warmup
630
+ warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
631
+ base_lr = float(optimizer.param_groups[0]["lr"])
632
+ min_lr = base_lr * 0.1
633
+
634
+ max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
635
+
636
+ for epoch in range(max_epochs):
637
+ # focal γ schedule
638
+ if epoch < gamma_warm:
639
+ model.gamma = 0.0 # type: ignore[attr-defined]
640
+ elif epoch < gamma_warm + gamma_ramp:
641
+ model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore[attr-defined]
642
+ else:
643
+ model.gamma = gamma_final # type: ignore[attr-defined]
644
+
645
+ # KL β schedule (float throughout + ramp guard)
646
+ if epoch < beta_warm:
647
+ model.beta = 0.0 # type: ignore[attr-defined]
648
+ elif beta_ramp > 0 and epoch < beta_warm + beta_ramp:
649
+ model.beta = beta_final * ((epoch - beta_warm) / beta_ramp) # type: ignore[attr-defined]
650
+ else:
651
+ model.beta = beta_final # type: ignore[attr-defined]
652
+ # LR warmup
653
+ if epoch < warmup_epochs:
654
+ scale = float(epoch + 1) / warmup_epochs
655
+ for g in optimizer.param_groups:
656
+ g["lr"] = min_lr + (base_lr - min_lr) * scale
657
+
658
+ train_loss = self._train_step(
659
+ loader=loader,
660
+ optimizer=optimizer,
661
+ model=model,
662
+ l1_penalty=l1_penalty,
663
+ class_weights=class_weights,
664
+ )
665
+
666
+ if not np.isfinite(train_loss):
667
+ if trial:
668
+ raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
669
+ # shrink LR and continue
670
+ for g in optimizer.param_groups:
671
+ g["lr"] *= 0.5
672
+ continue
673
+
674
+ if scheduler is not None:
675
+ scheduler.step()
676
+
677
+ if return_history:
678
+ history.append(train_loss)
679
+
680
+ early_stopping(train_loss, model)
681
+ if early_stopping.early_stop:
682
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
683
+ break
684
+
685
+ # Optional Optuna pruning on a validation metric
686
+ if (
687
+ trial is not None
688
+ and X_val is not None
689
+ and ((epoch + 1) % eval_interval == 0)
690
+ ):
691
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
692
+ mask_override = None
693
+ if (
694
+ self.simulate_missing
695
+ and getattr(self, "sim_mask_test_", None) is not None
696
+ and getattr(self, "X_val_", None) is not None
697
+ and X_val.shape == self.X_val_.shape
698
+ ):
699
+ mask_override = self.sim_mask_test_
700
+ metric_val = self._eval_for_pruning(
701
+ model=model,
702
+ X_val=X_val,
703
+ params=params or getattr(self, "best_params_", {}),
704
+ metric=metric_key,
705
+ objective_mode=True,
706
+ do_latent_infer=False, # VAE uses encoder; no latent refine
707
+ latent_steps=0,
708
+ latent_lr=0.0,
709
+ latent_weight_decay=0.0,
710
+ latent_seed=self.seed, # type: ignore
711
+ _latent_cache=None,
712
+ _latent_cache_key=None,
713
+ eval_mask_override=mask_override,
714
+ )
715
+ trial.report(metric_val, step=epoch + 1)
716
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
717
+ raise optuna.exceptions.TrialPruned(
718
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
719
+ )
720
+
721
+ best_loss = early_stopping.best_score
722
+ best_model = copy.deepcopy(early_stopping.best_model)
723
+ if best_model is None:
724
+ best_model = copy.deepcopy(model)
725
+ return best_loss, best_model, history
726
+
727
+ def _train_step(
728
+ self,
729
+ loader: torch.utils.data.DataLoader,
730
+ optimizer: torch.optim.Optimizer,
731
+ model: torch.nn.Module,
732
+ l1_penalty: float,
733
+ class_weights: torch.Tensor,
734
+ ) -> float:
735
+ """One epoch: inputs → VAE forward → focal CE + β·KL with guards.
736
+
737
+ This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
738
+
739
+ Args:
740
+ loader (torch.utils.data.DataLoader): Training data loader.
741
+ optimizer (torch.optim.Optimizer): Optimizer.
742
+ model (torch.nn.Module): VAE model.
743
+ l1_penalty (float): L1 regularization coefficient.
744
+ class_weights (torch.Tensor): CE class weights on device.
745
+
746
+ Returns:
747
+ float: Average training loss for the epoch.
748
+ """
749
+ model.train()
750
+ running, used = 0.0, 0
751
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
752
+ if class_weights is not None and class_weights.device != self.device:
753
+ class_weights = class_weights.to(self.device)
754
+
755
+ for _, y_batch in loader:
756
+ optimizer.zero_grad(set_to_none=True)
757
+
758
+ # targets: (B, L) int in {0,1,2,-1}
759
+ y_int = y_batch.to(self.device, non_blocking=True).long()
760
+
761
+ # inputs: one-hot with zeros for missing
762
+ x_ohe = self._one_hot_encode_012(y_int) # (B, L, K)
763
+
764
+ # Forward. Expect model to return recon_logits, mu, logvar, ...
765
+ out = model(x_ohe)
766
+ if isinstance(out, (list, tuple)):
767
+ recon_logits, mu, logvar = out[0], out[1], out[2]
768
+ else:
769
+ recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]
770
+
771
+ # Upstream guard
772
+ if (
773
+ not torch.isfinite(recon_logits).all()
774
+ or not torch.isfinite(mu).all()
775
+ or not torch.isfinite(logvar).all()
776
+ ):
777
+ continue
778
+
779
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
780
+ beta = float(getattr(model, "beta", getattr(self, "kl_beta_final", 0.0)))
781
+ gamma = max(0.0, min(gamma, 10.0))
782
+
783
+ loss = compute_vae_loss(
784
+ recon_logits=recon_logits,
785
+ targets=y_int,
786
+ mu=mu,
787
+ logvar=logvar,
788
+ class_weights=class_weights,
789
+ gamma=gamma,
790
+ beta=beta,
791
+ )
792
+
793
+ if l1_penalty > 0:
794
+ l1 = torch.zeros((), device=self.device)
795
+ for p in l1_params:
796
+ l1 = l1 + p.abs().sum()
797
+ loss = loss + l1_penalty * l1
798
+
799
+ if not torch.isfinite(loss):
800
+ continue
801
+
802
+ loss.backward()
803
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
804
+
805
+ # skip update if any grad is non-finite
806
+ bad = any(
807
+ p.grad is not None and not torch.isfinite(p.grad).all()
808
+ for p in model.parameters()
809
+ )
810
+ if bad:
811
+ optimizer.zero_grad(set_to_none=True)
812
+ continue
813
+
814
+ optimizer.step()
815
+
816
+ running += float(loss.detach().item())
817
+ used += 1
818
+
819
+ return (running / used) if used > 0 else float("inf")
820
+
821
+ def _predict(
822
+ self,
823
+ model: torch.nn.Module,
824
+ X: np.ndarray | torch.Tensor,
825
+ return_proba: bool = False,
826
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
827
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
828
+
829
+ This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
830
+
831
+ Args:
832
+ model (torch.nn.Module): Trained model.
833
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
834
+ return_proba (bool): If True, also return probabilities.
835
+
836
+ Returns:
837
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
838
+ """
839
+ if model is None:
840
+ msg = "Model is not trained. Call fit() before predict()."
841
+ self.logger.error(msg)
842
+ raise NotFittedError(msg)
843
+
844
+ model.eval()
845
+ with torch.no_grad():
846
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
847
+ X_tensor = X_tensor.to(self.device).long()
848
+ x_ohe = self._one_hot_encode_012(X_tensor)
849
+ outputs = model(x_ohe) # first element must be recon logits
850
+ logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
851
+ probas = torch.softmax(logits, dim=-1)
852
+ labels = torch.argmax(probas, dim=-1)
853
+
854
+ if return_proba:
855
+ return labels.cpu().numpy(), probas.cpu().numpy()
856
+
857
+ return labels.cpu().numpy()
858
+
859
+ def _evaluate_model(
860
+ self,
861
+ X_val: np.ndarray,
862
+ model: torch.nn.Module,
863
+ params: dict,
864
+ objective_mode: bool = False,
865
+ latent_vectors_val: np.ndarray | None = None,
866
+ *,
867
+ eval_mask_override: np.ndarray | None = None,
868
+ ) -> Dict[str, float]:
869
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
870
+
871
+ This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
872
+
873
+ Args:
874
+ X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
875
+ model (torch.nn.Module): Trained model.
876
+ params (dict): Current hyperparameters (for logging).
877
+ objective_mode (bool): If True, minimize logging for Optuna.
878
+ latent_vectors_val (np.ndarray | None): Not used by VAE.
879
+ eval_mask_override (np.ndarray | None): Optional mask to override default eval mask.
880
+
881
+ Returns:
882
+ Dict[str, float]: Computed metrics.
883
+
884
+ Raises:
885
+ NotFittedError: If called before fit().
886
+ """
887
+ pred_labels, pred_probas = self._predict(
888
+ model=model, X=X_val, return_proba=True
889
+ )
890
+
891
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
892
+
893
+ # FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
894
+ if (
895
+ hasattr(self, "X_val_")
896
+ and getattr(self, "X_val_", None) is not None
897
+ and X_val.shape[0] == self.X_val_.shape[0]
898
+ ):
899
+ GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
900
+ elif (
901
+ hasattr(self, "X_train_")
902
+ and getattr(self, "X_train_", None) is not None
903
+ and X_val.shape[0] == self.X_train_.shape[0]
904
+ ):
905
+ GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
906
+ else:
907
+ GT_ref = self.ground_truth_
908
+
909
+ # FIX 2: Handle Feature Mismatch
910
+ # If GT has more columns than X_val, slice it to match.
911
+ if GT_ref.shape[1] > X_val.shape[1]:
912
+ GT_ref = GT_ref[:, : X_val.shape[1]]
913
+
914
+ # Fallback safeguard
915
+ if GT_ref.shape != X_val.shape:
916
+ GT_ref = X_val
917
+
918
+ # FIX 3: Allow override mask to be sliced if it's too wide
919
+ if eval_mask_override is not None:
920
+ if eval_mask_override.shape[0] != X_val.shape[0]:
921
+ msg = (
922
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
923
+ f"does not match X_val rows {X_val.shape[0]}"
924
+ )
925
+ self.logger.error(msg)
926
+ raise ValueError(msg)
927
+
928
+ if eval_mask_override.shape[1] > X_val.shape[1]:
929
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
930
+ else:
931
+ eval_mask = eval_mask_override.astype(bool)
932
+ else:
933
+ eval_mask = X_val != -1
934
+
935
+ eval_mask = eval_mask & finite_mask & (GT_ref != -1)
936
+
937
+ y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
938
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
939
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
940
+
941
+ if y_true_flat.size == 0:
942
+ return {self.tune_metric: 0.0}
943
+
944
+ # ensure valid probability simplex after masking
945
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
946
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
947
+ row_sums[row_sums == 0] = 1.0
948
+ y_proba_flat = y_proba_flat / row_sums
949
+
950
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
951
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
952
+
953
+ if self.is_haploid:
954
+ y_true_flat = y_true_flat.copy()
955
+ y_pred_flat = y_pred_flat.copy()
956
+ y_true_flat[y_true_flat == 2] = 1
957
+ y_pred_flat[y_pred_flat == 2] = 1
958
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
959
+ proba_2[:, 0] = y_proba_flat[:, 0]
960
+ proba_2[:, 1] = y_proba_flat[:, -1]  # last column is ALT whether probas have 2 (haploid) or 3 (diploid) classes
961
+ y_proba_flat = proba_2
962
+
963
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
964
+
965
+ metrics = self.scorers_.evaluate(
966
+ y_true_flat,
967
+ y_pred_flat,
968
+ y_true_ohe,
969
+ y_proba_flat,
970
+ objective_mode,
971
+ self.tune_metric,
972
+ )
973
+
974
+ if not objective_mode:
975
+ pm = PrettyMetrics(
976
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
977
+ )
978
+ pm.render()
979
+
980
+ # Primary report
981
+ self._make_class_reports(
982
+ y_true=y_true_flat,
983
+ y_pred_proba=y_proba_flat,
984
+ y_pred=y_pred_flat,
985
+ metrics=metrics,
986
+ labels=target_names,
987
+ )
988
+
989
+ # IUPAC decode & 10-base integer report
990
+ # FIX 4: Use current shape (X_val.shape) not self.num_features_
991
+ y_true_dec = self.pgenc.decode_012(
992
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
993
+ )
994
+ X_pred = X_val.copy()
995
+ X_pred[eval_mask] = y_pred_flat
996
+ y_pred_dec = self.pgenc.decode_012(
997
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
998
+ )
999
+
1000
+ encodings_dict = {
1001
+ "A": 0,
1002
+ "C": 1,
1003
+ "G": 2,
1004
+ "T": 3,
1005
+ "W": 4,
1006
+ "R": 5,
1007
+ "M": 6,
1008
+ "K": 7,
1009
+ "Y": 8,
1010
+ "S": 9,
1011
+ "N": -1,
1012
+ }
1013
+ y_true_int = self.pgenc.convert_int_iupac(
1014
+ y_true_dec, encodings_dict=encodings_dict
1015
+ )
1016
+ y_pred_int = self.pgenc.convert_int_iupac(
1017
+ y_pred_dec, encodings_dict=encodings_dict
1018
+ )
1019
+
1020
+ valid_iupac_mask = y_true_int[eval_mask] >= 0
1021
+ if valid_iupac_mask.any():
1022
+ self._make_class_reports(
1023
+ y_true=y_true_int[eval_mask][valid_iupac_mask],
1024
+ y_pred=y_pred_int[eval_mask][valid_iupac_mask],
1025
+ metrics=metrics,
1026
+ y_pred_proba=None,
1027
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
1028
+ )
1029
+ else:
1030
+ self.logger.warning(
1031
+ "Skipped IUPAC confusion matrix: No valid ground truths."
1032
+ )
1033
+
1034
+ return metrics
1035
+
1036
+ def _objective(self, trial: optuna.Trial) -> float:
1037
+ """Optuna objective for VAE (no latent refinement during eval).
1038
+
1039
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, trains the VAE model with these parameters, and evaluates its performance on a validation set. The evaluation metric specified by `self.tune_metric` is returned for optimization. If training fails, the trial is pruned to keep the tuning process efficient.
1040
+
1041
+ Args:
1042
+ trial (optuna.Trial): Optuna trial object.
1043
+
1044
+ Returns:
1045
+ float: Value of the tuning metric to be optimized.
1046
+ """
1047
+ try:
1048
+ params = self._sample_hyperparameters(trial)
1049
+
1050
+ X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1051
+ X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1052
+
1053
+ class_weights = self._normalize_class_weights(
1054
+ self._class_weights_from_zygosity(X_train)
1055
+ )
1056
+ train_loader = self._get_data_loader(X_train)
1057
+
1058
+ model = self.build_model(self.Model, params["model_params"])
1059
+ model.apply(self.initialize_weights)
1060
+
1061
+ lr: float = params["lr"]
1062
+ l1_penalty: float = params["l1_penalty"]
1063
+
1064
+ # Train + prune on metric
1065
+ _, model, __ = self._train_and_validate_model(
1066
+ model=model,
1067
+ loader=train_loader,
1068
+ lr=lr,
1069
+ l1_penalty=l1_penalty,
1070
+ trial=trial,
1071
+ return_history=False,
1072
+ class_weights=class_weights,
1073
+ X_val=X_val,
1074
+ params=params,
1075
+ prune_metric=self.tune_metric,
1076
+ prune_warmup_epochs=5,
1077
+ eval_interval=self.tune_eval_interval,
1078
+ eval_requires_latents=False,
1079
+ eval_latent_steps=0,
1080
+ eval_latent_lr=0.0,
1081
+ eval_latent_weight_decay=0.0,
1082
+ )
1083
+
1084
+ eval_mask = (
1085
+ self.sim_mask_test_
1086
+ if (
1087
+ self.simulate_missing
1088
+ and getattr(self, "sim_mask_test_", None) is not None
1089
+ )
1090
+ else None
1091
+ )
1092
+
1093
+ if model is None:
1094
+ raise RuntimeError("Model training failed; no model was returned.")
1095
+
1096
+ metrics = self._evaluate_model(
1097
+ X_val, model, params, objective_mode=True, eval_mask_override=eval_mask
1098
+ )
1099
+ self._clear_resources(model, train_loader)
1100
+ return metrics[self.tune_metric]
1101
+
1102
+ except Exception as e:
1103
+ # Keep sweeps moving
1104
+ self.logger.debug(f"Trial failed with error: {e}")
1105
+ raise optuna.exceptions.TrialPruned(
1106
+ f"Trial failed with error. Enable debug logging for details."
1107
+ )
1108
+
1109
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1110
+ """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
1111
+
1112
+ Args:
1113
+ trial (optuna.Trial): Optuna trial object.
1114
+
1115
+ Returns:
1116
+ Dict[str, int | float | str]: Sampled hyperparameters.
1117
+ """
1118
+ params = {
1119
+ "latent_dim": trial.suggest_int("latent_dim", 2, 64),
1120
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
1121
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
1122
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
1123
+ "activation": trial.suggest_categorical(
1124
+ "activation", ["relu", "elu", "selu"]
1125
+ ),
1126
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
1127
+ "layer_scaling_factor": trial.suggest_float(
1128
+ "layer_scaling_factor", 2.0, 10.0
1129
+ ),
1130
+ "layer_schedule": trial.suggest_categorical(
1131
+ "layer_schedule", ["pyramid", "constant", "linear"]
1132
+ ),
1133
+ # VAE-specific β (final value after anneal)
1134
+ "beta": trial.suggest_float("beta", 0.25, 4.0),
1135
+ # focal gamma (if used in VAE recon CE)
1136
+ "gamma": trial.suggest_float("gamma", 0.0, 5.0),
1137
+ }
1138
+
1139
+ input_dim = self.num_features_ * self.num_classes_
1140
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1141
+ n_inputs=input_dim,
1142
+ n_outputs=input_dim,
1143
+ n_samples=len(self.train_idx_),
1144
+ n_hidden=params["num_hidden_layers"],
1145
+ alpha=params["layer_scaling_factor"],
1146
+ schedule=params["layer_schedule"],
1147
+ )
1148
+
1149
+ # [latent_dim] + interior widths (exclude output width)
1150
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1151
+
1152
+ params["model_params"] = {
1153
+ "n_features": self.num_features_,
1154
+ "num_classes": self.num_classes_,
1155
+ "latent_dim": params["latent_dim"],
1156
+ "dropout_rate": params["dropout_rate"],
1157
+ "hidden_layer_sizes": hidden_only,
1158
+ "activation": params["activation"],
1159
+ # Pass through VAE recon/regularization coefficients
1160
+ "beta": params["beta"],
1161
+ "gamma": params["gamma"],
1162
+ }
1163
+ return params
1164
+
1165
+ def _set_best_params(self, best_params: dict) -> dict:
1166
+ """Adopt best params and return VAE model_params.
1167
+
1168
+ Args:
1169
+ best_params (dict): Best hyperparameters from tuning.
1170
+
1171
+ Returns:
1172
+ dict: VAE model parameters.
1173
+ """
1174
+ self.latent_dim = best_params["latent_dim"]
1175
+ self.dropout_rate = best_params["dropout_rate"]
1176
+ self.learning_rate = best_params["learning_rate"]
1177
+ self.l1_penalty = best_params["l1_penalty"]
1178
+ self.activation = best_params["activation"]
1179
+ self.layer_scaling_factor = best_params["layer_scaling_factor"]
1180
+ self.layer_schedule = best_params["layer_schedule"]
1181
+ self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
1182
+ self.gamma = best_params.get("gamma", self.gamma)
1183
+
1184
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1185
+ n_inputs=self.num_features_ * self.num_classes_,
1186
+ n_outputs=self.num_features_ * self.num_classes_,
1187
+ n_samples=len(self.train_idx_),
1188
+ n_hidden=best_params["num_hidden_layers"],
1189
+ alpha=best_params["layer_scaling_factor"],
1190
+ schedule=best_params["layer_schedule"],
1191
+ )
1192
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1193
+
1194
+ return {
1195
+ "n_features": self.num_features_,
1196
+ "latent_dim": self.latent_dim,
1197
+ "hidden_layer_sizes": hidden_only,
1198
+ "dropout_rate": self.dropout_rate,
1199
+ "activation": self.activation,
1200
+ "num_classes": self.num_classes_,
1201
+ "beta": self.kl_beta_final,
1202
+ "gamma": self.gamma,
1203
+ }
1204
+
1205
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
1206
+ """Default VAE model params when tuning is disabled.
1207
+
1208
+ Returns:
1209
+ Dict[str, int | float | str | list]: VAE model parameters.
1210
+ """
1211
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1212
+ n_inputs=self.num_features_ * self.num_classes_,
1213
+ n_outputs=self.num_features_ * self.num_classes_,
1214
+ n_samples=len(self.ground_truth_),
1215
+ n_hidden=self.num_hidden_layers,
1216
+ alpha=self.layer_scaling_factor,
1217
+ schedule=self.layer_schedule,
1218
+ )
1219
+ return {
1220
+ "n_features": self.num_features_,
1221
+ "latent_dim": self.latent_dim,
1222
+ "hidden_layer_sizes": hidden_layer_sizes,
1223
+ "dropout_rate": self.dropout_rate,
1224
+ "activation": self.activation,
1225
+ "num_classes": self.num_classes_,
1226
+ "beta": self.kl_beta_final,
1227
+ "gamma": self.gamma,
1228
+ }
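
For orientation, the new ImputeVAE class added in pgsui/impute/unsupervised/imputers/vae.py exposes a fit/transform workflow driven by the unified config interface shown above. The sketch below is illustrative only and is not part of the packaged diff; the genotype_data object is assumed to come from snpio (see the TYPE_CHECKING import in the diff), and the "train.max_epochs" override key is an assumption mirroring the cfg.train.max_epochs attribute read in __init__.

# Illustrative usage sketch (not part of the diff). Assumes `genotype_data`
# is an snpio GenotypeData instance already loaded by the caller.
from pgsui.impute.unsupervised.imputers.vae import ImputeVAE

imputer = ImputeVAE(
    genotype_data,                            # snpio GenotypeData (assumed preloaded)
    config=None,                              # defaults; a VAEConfig, nested dict, or YAML path also work
    overrides={"train.max_epochs": 200},      # dot-key overrides take highest precedence (assumed key name)
    simulate_missing=True,                    # hold out simulated-missing loci for evaluation
    sim_strategy="random",
    sim_prop=0.10,
)
imputer.fit()                # trains the VAE and reports validation metrics
iupac = imputer.transform()  # (n_samples, n_loci) IUPAC strings with missing sites imputed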