pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/imputers/vae.py (new file)
@@ -0,0 +1,1316 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import optuna
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from sklearn.exceptions import NotFittedError
12
+ from sklearn.model_selection import train_test_split
13
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
14
+ from snpio.utils.logging import LoggerManager
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR
16
+
17
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
18
+ from pgsui.data_processing.containers import VAEConfig
19
+ from pgsui.data_processing.transformers import SimMissingTransformer
20
+ from pgsui.impute.unsupervised.base import BaseNNImputer
21
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
22
+ from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
23
+ from pgsui.impute.unsupervised.models.vae_model import VAEModel
24
+ from pgsui.utils.logging_utils import configure_logger
25
+ from pgsui.utils.pretty_metrics import PrettyMetrics
26
+
27
+ if TYPE_CHECKING:
28
+ from snpio import TreeParser
29
+ from snpio.read_input.genotype_data import GenotypeData
30
+
31
+
32
+ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
33
+ """Normalize VAEConfig input from various sources.
34
+
35
+ Args:
36
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
37
+
38
+ Returns:
39
+ VAEConfig: Normalized configuration dataclass.
40
+ """
41
+ if config is None:
42
+ return VAEConfig()
43
+ if isinstance(config, VAEConfig):
44
+ return config
45
+ if isinstance(config, str):
46
+ return load_yaml_to_dataclass(config, VAEConfig)
47
+ if isinstance(config, dict):
48
+ base = VAEConfig()
49
+ # Respect top-level preset
50
+ preset = config.pop("preset", None)
51
+ if preset:
52
+ base = VAEConfig.from_preset(preset)
53
+ # Flatten + apply
54
+ flat: Dict[str, object] = {}
55
+
56
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
57
+ for k, v in d.items():
58
+ kk = f"{prefix}.{k}" if prefix else k
59
+ if isinstance(v, dict):
60
+ _flatten(kk, v, out)
61
+ else:
62
+ out[kk] = v
63
+ return out
64
+
65
+ flat = _flatten("", config, {})
66
+ return apply_dot_overrides(base, flat)
67
+ raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
68
+
69
+
70
+ class ImputeVAE(BaseNNImputer):
71
+ """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
72
+
73
+ This imputer implements a VAE with a Gaussian latent space and a categorical (haploid) or multi-label REF/ALT (diploid) reconstruction head. It handles missing data by inferring the latent distribution from the observed genotypes and generating plausible predictions for the masked entries. The model is trained with a reconstruction loss (focal cross-entropy or masked binary cross-entropy) plus a KL divergence term whose weight (beta) is annealed over training. The imputer supports both haploid and diploid genotype data.
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ genotype_data: "GenotypeData",
79
+ *,
80
+ tree_parser: Optional["TreeParser"] = None,
81
+ config: Optional[Union["VAEConfig", dict, str]] = None,
82
+ overrides: dict | None = None,
83
+ simulate_missing: bool | None = None,
84
+ sim_strategy: (
85
+ Literal[
86
+ "random",
87
+ "random_weighted",
88
+ "random_weighted_inv",
89
+ "nonrandom",
90
+ "nonrandom_weighted",
91
+ ]
92
+ | None
93
+ ) = None,
94
+ sim_prop: float | None = None,
95
+ sim_kwargs: dict | None = None,
96
+ ):
97
+ """Initialize the VAE imputer with a unified config interface.
98
+
99
+ This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
100
+
101
+ Args:
102
+ genotype_data (GenotypeData): Backing genotype data object.
103
+ tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
104
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
105
+ overrides (dict | None): Optional dot-key overrides with highest precedence.
106
+ simulate_missing (bool | None): Whether to simulate missing data during training.
107
+ sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
108
+ sim_prop (float | None): Proportion of data to simulate as missing if simulating.
109
+ sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
110
+ """
111
+ self.model_name = "ImputeVAE"
112
+ self.genotype_data = genotype_data
113
+ self.tree_parser = tree_parser
114
+
115
+ # Normalize configuration and apply top-precedence overrides
116
+ cfg = ensure_vae_config(config)
117
+ if overrides:
118
+ cfg = apply_dot_overrides(cfg, overrides)
119
+ self.cfg = cfg
120
+
121
+ # Logger (align with AE/NLPCA)
122
+ logman = LoggerManager(
123
+ __name__,
124
+ prefix=self.cfg.io.prefix,
125
+ debug=self.cfg.io.debug,
126
+ verbose=self.cfg.io.verbose,
127
+ )
128
+ self.logger = configure_logger(
129
+ logman.get_logger(),
130
+ verbose=self.cfg.io.verbose,
131
+ debug=self.cfg.io.debug,
132
+ )
133
+
134
+ # BaseNNImputer bootstraps device/dirs/log formatting
135
+ super().__init__(
136
+ model_name=self.model_name,
137
+ genotype_data=self.genotype_data,
138
+ prefix=self.cfg.io.prefix,
139
+ device=self.cfg.train.device,
140
+ verbose=self.cfg.io.verbose,
141
+ debug=self.cfg.io.debug,
142
+ )
143
+
144
+ # Model hook & encoder
145
+ self.Model = VAEModel
146
+ self.pgenc = GenotypeEncoder(genotype_data)
147
+
148
+ # IO/global
149
+ self.seed = self.cfg.io.seed
150
+ self.n_jobs = self.cfg.io.n_jobs
151
+ self.prefix = self.cfg.io.prefix
152
+ self.scoring_averaging = self.cfg.io.scoring_averaging
153
+ self.verbose = self.cfg.io.verbose
154
+ self.debug = self.cfg.io.debug
155
+ self.rng = np.random.default_rng(self.seed)
156
+ self.pos_weights_: torch.Tensor | None = None
157
+
158
+ # Simulated-missing controls (config defaults + ctor overrides)
159
+ sim_cfg = getattr(self.cfg, "sim", None)
160
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
161
+ if sim_kwargs:
162
+ sim_cfg_kwargs.update(sim_kwargs)
163
+ if sim_cfg is None:
164
+ default_sim_flag = bool(simulate_missing)
165
+ default_strategy = "random"
166
+ default_prop = 0.10
167
+ else:
168
+ default_sim_flag = sim_cfg.simulate_missing
169
+ default_strategy = sim_cfg.sim_strategy
170
+ default_prop = sim_cfg.sim_prop
171
+ self.simulate_missing = (
172
+ default_sim_flag if simulate_missing is None else bool(simulate_missing)
173
+ )
174
+ self.sim_strategy = sim_strategy or default_strategy
175
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
176
+ self.sim_kwargs = sim_cfg_kwargs
177
+
178
+ # Model hyperparams (AE-parity)
179
+ self.latent_dim = self.cfg.model.latent_dim
180
+ self.dropout_rate = self.cfg.model.dropout_rate
181
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
182
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
183
+ self.layer_schedule = self.cfg.model.layer_schedule
184
+ self.activation = self.cfg.model.hidden_activation
185
+ self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
186
+
187
+ # VAE-only KL controls
188
+ self.kl_beta_final = self.cfg.vae.kl_beta
189
+ self.kl_warmup = self.cfg.vae.kl_warmup
190
+ self.kl_ramp = self.cfg.vae.kl_ramp
191
+
192
+ # Train hyperparams (AE-parity)
193
+ self.batch_size = self.cfg.train.batch_size
194
+ self.learning_rate = self.cfg.train.learning_rate
195
+ self.l1_penalty: float = self.cfg.train.l1_penalty
196
+ self.early_stop_gen = self.cfg.train.early_stop_gen
197
+ self.min_epochs = self.cfg.train.min_epochs
198
+ self.epochs = self.cfg.train.max_epochs
199
+ self.validation_split = self.cfg.train.validation_split
200
+ self.beta = self.cfg.train.weights_beta
201
+ self.max_ratio = self.cfg.train.weights_max_ratio
202
+
203
+ # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
204
+ self.tune = self.cfg.tune.enabled
205
+ self.tune_fast = self.cfg.tune.fast
206
+ self.tune_batch_size = self.cfg.tune.batch_size
207
+ self.tune_epochs = self.cfg.tune.epochs
208
+ self.tune_eval_interval = self.cfg.tune.eval_interval
209
+ self.tune_metric: Literal[
210
+ "pr_macro",
211
+ "f1",
212
+ "accuracy",
213
+ "average_precision",
214
+ "precision",
215
+ "recall",
216
+ "roc_auc",
217
+ ] = self.cfg.tune.metric
218
+ self.n_trials = self.cfg.tune.n_trials
219
+ self.tune_save_db = self.cfg.tune.save_db
220
+ self.tune_resume = self.cfg.tune.resume
221
+ self.tune_max_samples = self.cfg.tune.max_samples
222
+ self.tune_max_loci = self.cfg.tune.max_loci
223
+ self.tune_patience = self.cfg.tune.patience
224
+
225
+ # Plotting (AE-parity)
226
+ self.plot_format = self.cfg.plot.fmt
227
+ self.plot_dpi = self.cfg.plot.dpi
228
+ self.plot_fontsize = self.cfg.plot.fontsize
229
+ self.title_fontsize = self.cfg.plot.fontsize
230
+ self.despine = self.cfg.plot.despine
231
+ self.show_plots = self.cfg.plot.show
232
+
233
+ # Derived at fit-time
234
+ self.is_haploid: bool = False
235
+ self.num_classes_: int = 3 # diploid default
236
+ self.model_params: Dict[str, Any] = {}
237
+ self.sim_mask_global_: np.ndarray | None = None
238
+ self.sim_mask_train_: np.ndarray | None = None
239
+ self.sim_mask_test_: np.ndarray | None = None
240
+
241
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
242
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
243
+ self.logger.error(msg)
244
+ raise ValueError(msg)
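A hedged construction sketch for the initializer above. `ImputeVAE` only needs a `GenotypeData` instance; the `VCFReader` call and its keyword arguments are assumptions drawn from typical SNPio usage (they are not part of this diff), and the file paths are placeholders.

```python
# Hypothetical setup: VCFReader and its keyword arguments are assumptions
# based on typical SNPio usage; only a GenotypeData object is required here.
from snpio import VCFReader

from pgsui.impute.unsupervised.imputers.vae import ImputeVAE

genotype_data = VCFReader(
    filename="example.vcf.gz",    # placeholder path
    popmapfile="example.popmap",  # placeholder path
)

imputer = ImputeVAE(
    genotype_data,
    config={"train": {"max_epochs": 200}},  # nested dict, normalized internally
    overrides={"model.latent_dim": 8},       # dot-key overrides take precedence
    simulate_missing=True,
    sim_strategy="random",
    sim_prop=0.10,
)
```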
245
+
246
+ # -------------------- Fit -------------------- #
247
+ def fit(self) -> "ImputeVAE":
248
+ """Fit the VAE on 0/1/2 encoded genotypes (missing -> -1).
249
+
250
+ This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. Missing positions are encoded as -1 for loss masking (any simulated-missing loci are temporarily tagged with -9 by ``SimMissingTransformer`` before being re-encoded as -1). It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
251
+
252
+ Returns:
253
+ ImputeVAE: Fitted instance.
254
+
255
+ Raises:
256
+ RuntimeError: If training fails to produce a model.
257
+ """
258
+ self.logger.info(f"Fitting {self.model_name} model...")
259
+
260
+ # Data prep aligns with AE/NLPCA
261
+ X012 = self._get_float_genotypes(copy=True)
262
+ GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
263
+ self.ground_truth_ = GT_full.astype(np.int64, copy=False)
264
+
265
+ self.sim_mask_global_ = None
266
+ cache_key = self._sim_mask_cache_key()
267
+ if self.simulate_missing:
268
+ cached_mask = (
269
+ None if cache_key is None else self._sim_mask_cache.get(cache_key)
270
+ )
271
+ if cached_mask is not None:
272
+ self.sim_mask_global_ = cached_mask.copy()
273
+ else:
274
+ tr = SimMissingTransformer(
275
+ genotype_data=self.genotype_data,
276
+ tree_parser=self.tree_parser,
277
+ prop_missing=self.sim_prop,
278
+ strategy=self.sim_strategy,
279
+ missing_val=-9,
280
+ mask_missing=True,
281
+ verbose=self.verbose,
282
+ **self.sim_kwargs,
283
+ )
284
+ tr.fit(X012.copy())
285
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
286
+ if cache_key is not None:
287
+ self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
288
+
289
+ X_for_model = self.ground_truth_.copy()
290
+ X_for_model[self.sim_mask_global_] = -1
291
+ else:
292
+ X_for_model = self.ground_truth_.copy()
293
+
294
+ # Ploidy/classes
295
+ self.is_haploid = bool(
296
+ np.all(
297
+ np.isin(
298
+ self.genotype_data.snp_data,
299
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
300
+ )
301
+ )
302
+ )
303
+ self.ploidy = 1 if self.is_haploid else 2
304
+ self.num_classes_ = 2 if self.is_haploid else 3
305
+ self.output_classes_ = 2
306
+ self.logger.info(
307
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
308
+ f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
309
+ )
310
+
311
+ if self.is_haploid:
312
+ self.ground_truth_[self.ground_truth_ == 2] = 1
313
+ X_for_model[X_for_model == 2] = 1
314
+
315
+ n_samples, self.num_features_ = X_for_model.shape
316
+
317
+ # Model params (decoder outputs L*K logits)
318
+ self.model_params = {
319
+ "n_features": self.num_features_,
320
+ "num_classes": self.output_classes_,
321
+ "latent_dim": self.latent_dim,
322
+ "dropout_rate": self.dropout_rate,
323
+ "activation": self.activation,
324
+ }
325
+
326
+ # Train/Val split
327
+ indices = np.arange(n_samples)
328
+ train_idx, val_idx = train_test_split(
329
+ indices, test_size=self.validation_split, random_state=self.seed
330
+ )
331
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
332
+ self.X_train_ = X_for_model[train_idx]
333
+ self.X_val_ = X_for_model[val_idx]
334
+ self.GT_train_full_ = self.ground_truth_[train_idx]
335
+ self.GT_test_full_ = self.ground_truth_[val_idx]
336
+
337
+ if self.sim_mask_global_ is not None:
338
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
339
+ self.sim_mask_test_ = self.sim_mask_global_[val_idx]
340
+ else:
341
+ self.sim_mask_train_ = None
342
+ self.sim_mask_test_ = None
343
+
344
+ # Plotters/scorers (shared utilities)
345
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
346
+
347
+ # Optional tuning
348
+ if self.tune:
349
+ self.tune_hyperparameters()
350
+
351
+ # Best params (tuned or default)
352
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
353
+
354
+ # Class weights (device-aware)
355
+ self.class_weights_ = self._normalize_class_weights(
356
+ self._class_weights_from_zygosity(self.X_train_)
357
+ )
358
+ if not self.is_haploid:
359
+ self.pos_weights_ = self._compute_pos_weights(self.X_train_)
360
+ else:
361
+ self.pos_weights_ = None
362
+
363
+ # DataLoader
364
+ train_loader = self._get_data_loader(self.X_train_)
365
+
366
+ # Build & train
367
+ model = self.build_model(self.Model, self.best_params_)
368
+ model.apply(self.initialize_weights)
369
+
370
+ loss, trained_model, history = self._train_and_validate_model(
371
+ model=model,
372
+ loader=train_loader,
373
+ lr=self.learning_rate,
374
+ l1_penalty=self.l1_penalty,
375
+ return_history=True,
376
+ class_weights=self.class_weights_,
377
+ X_val=self.X_val_,
378
+ params=self.best_params_,
379
+ prune_metric=self.tune_metric,
380
+ prune_warmup_epochs=10,
381
+ eval_interval=1,
382
+ eval_requires_latents=False, # no latent refinement for eval
383
+ eval_latent_steps=0,
384
+ eval_latent_lr=0.0,
385
+ eval_latent_weight_decay=0.0,
386
+ )
387
+
388
+ if trained_model is None:
389
+ msg = "VAE training failed; no model was returned."
390
+ self.logger.error(msg)
391
+ raise RuntimeError(msg)
392
+
393
+ torch.save(
394
+ trained_model.state_dict(),
395
+ self.models_dir / f"final_model_{self.model_name}.pt",
396
+ )
397
+
398
+ hist: dict = {"Train": history}
399
+ self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
400
+ self.is_fit_ = True
401
+
402
+ # Evaluate (AE-parity reporting)
403
+ eval_mask = (
404
+ self.sim_mask_test_
405
+ if (self.simulate_missing and self.sim_mask_test_ is not None)
406
+ else None
407
+ )
408
+ self._evaluate_model(
409
+ self.X_val_,
410
+ self.model_,
411
+ self.best_params_,
412
+ eval_mask_override=eval_mask,
413
+ )
414
+
415
+ self.plotter_.plot_history(self.history_)
416
+ self._save_best_params(self.best_params_)
417
+ return self
418
+
419
+ def transform(self) -> np.ndarray:
420
+ """Impute missing genotypes and return IUPAC strings.
421
+
422
+ This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
423
+
424
+ Returns:
425
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
426
+
427
+ Raises:
428
+ NotFittedError: If called before fit().
429
+ """
430
+ if not getattr(self, "is_fit_", False):
431
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
432
+
433
+ self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
434
+ X_to_impute = self.ground_truth_.copy()
435
+
436
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
437
+
438
+ # Fill only missing
439
+ missing_mask = X_to_impute == -1
440
+ imputed_array = X_to_impute.copy()
441
+ imputed_array[missing_mask] = pred_labels[missing_mask]
442
+
443
+ # Decode to IUPAC & optionally plot
444
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
445
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
446
+
447
+ plt.rcParams.update(self.plotter_.param_dict)
448
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
449
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
450
+
451
+ return imputed_genotypes
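The typical end-to-end call order, continuing the construction sketch above; `transform()` returns IUPAC strings as documented.

```python
# Assumes `imputer` was constructed as in the earlier ImputeVAE sketch.
imputer.fit()                        # train the VAE, evaluate on the held-out split
imputed_iupac = imputer.transform()  # IUPAC strings, shape (n_samples, n_loci)
print(imputed_iupac.shape)
```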
452
+
453
+ # ---------- plumbing identical to AE, naming aligned ---------- #
454
+
455
+ def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
456
+ """Create DataLoader over indices + integer targets (-1 for missing).
457
+
458
+ This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
459
+
460
+ Args:
461
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
462
+
463
+ Returns:
464
+ torch.utils.data.DataLoader: Shuffled DataLoader.
465
+ """
466
+ y_tensor = torch.from_numpy(y).long()
467
+ indices = torch.arange(len(y), dtype=torch.long)
468
+ dataset = torch.utils.data.TensorDataset(indices, y_tensor)
469
+ pin_memory = self.device.type == "cuda"
470
+ return torch.utils.data.DataLoader(
471
+ dataset,
472
+ batch_size=self.batch_size,
473
+ shuffle=True,
474
+ pin_memory=pin_memory,
475
+ )
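A toy illustration of the batches this loader yields: pairs of row indices and integer genotype rows, with -1 preserved for missing entries.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# 4 samples x 3 loci, 0/1/2 encoding with -1 for missing.
y = np.array([[0, 1, -1],
              [2, -1, 0],
              [1, 1, 2],
              [-1, 0, 0]])
dataset = TensorDataset(torch.arange(len(y)), torch.from_numpy(y).long())
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for idx_batch, y_batch in loader:
    print(idx_batch.shape, y_batch.shape)  # torch.Size([2]) torch.Size([2, 3])
```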
476
+
477
+ def _train_and_validate_model(
478
+ self,
479
+ model: torch.nn.Module,
480
+ loader: torch.utils.data.DataLoader,
481
+ lr: float,
482
+ l1_penalty: float,
483
+ trial: optuna.Trial | None = None,
484
+ return_history: bool = False,
485
+ class_weights: torch.Tensor | None = None,
486
+ *,
487
+ X_val: np.ndarray | None = None,
488
+ params: dict | None = None,
489
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
490
+ prune_warmup_epochs: int = 10,
491
+ eval_interval: int = 1,
492
+ eval_requires_latents: bool = False, # VAE: no latent eval refinement
493
+ eval_latent_steps: int = 0,
494
+ eval_latent_lr: float = 0.0,
495
+ eval_latent_weight_decay: float = 0.0,
496
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
497
+ """Wrap the VAE training loop with β-anneal & Optuna pruning.
498
+
499
+ This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
500
+
501
+ Args:
502
+ model (torch.nn.Module): VAE model.
503
+ loader (torch.utils.data.DataLoader): Training data loader.
504
+ lr (float): Learning rate.
505
+ l1_penalty (float): L1 regularization coefficient.
506
+ trial (optuna.Trial | None): Optuna trial for pruning.
507
+ return_history (bool): If True, return training history.
508
+ class_weights (torch.Tensor | None): CE class weights on device.
509
+ X_val (np.ndarray | None): Validation data for pruning eval.
510
+ params (dict | None): Current hyperparameters (for logging).
511
+ prune_metric (str | None): Metric for pruning decisions.
512
+ prune_warmup_epochs (int): Epochs to skip before pruning.
513
+ eval_interval (int): Epochs between validation evaluations.
514
+ eval_requires_latents (bool): If True, refine latents during eval.
515
+ eval_latent_steps (int): Latent refinement steps if needed.
516
+ eval_latent_lr (float): Latent refinement learning rate.
517
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
518
+
519
+ Returns:
520
+ Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
521
+ """
522
+ if class_weights is None:
523
+ msg = "Must provide class_weights."
524
+ self.logger.error(msg)
525
+ raise TypeError(msg)
526
+
527
+ max_epochs = (
528
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
529
+ )
530
+
531
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
532
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
533
+
534
+ best_loss, best_model, hist = self._execute_training_loop(
535
+ loader=loader,
536
+ optimizer=optimizer,
537
+ scheduler=scheduler,
538
+ model=model,
539
+ l1_penalty=l1_penalty,
540
+ trial=trial,
541
+ return_history=return_history,
542
+ class_weights=class_weights,
543
+ X_val=X_val,
544
+ params=params,
545
+ prune_metric=prune_metric,
546
+ prune_warmup_epochs=prune_warmup_epochs,
547
+ eval_interval=eval_interval,
548
+ eval_requires_latents=eval_requires_latents,
549
+ eval_latent_steps=eval_latent_steps,
550
+ eval_latent_lr=eval_latent_lr,
551
+ eval_latent_weight_decay=eval_latent_weight_decay,
552
+ )
553
+ if return_history:
554
+ return best_loss, best_model, hist
555
+
556
+ return best_loss, best_model, None
557
+
558
+ def _execute_training_loop(
559
+ self,
560
+ loader: torch.utils.data.DataLoader,
561
+ optimizer: torch.optim.Optimizer,
562
+ scheduler: CosineAnnealingLR,
563
+ model: torch.nn.Module,
564
+ l1_penalty: float,
565
+ trial: optuna.Trial | None,
566
+ return_history: bool,
567
+ class_weights: torch.Tensor,
568
+ *,
569
+ X_val: np.ndarray | None = None,
570
+ params: dict | None = None,
571
+ prune_metric: str | None = None,
572
+ prune_warmup_epochs: int = 10,
573
+ eval_interval: int = 1,
574
+ eval_requires_latents: bool = False,
575
+ eval_latent_steps: int = 0,
576
+ eval_latent_lr: float = 0.0,
577
+ eval_latent_weight_decay: float = 0.0,
578
+ ) -> Tuple[float, torch.nn.Module, list]:
579
+ """Train VAE with stable focal CE + KL(β) anneal and numeric guards.
580
+
581
+ This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
582
+
583
+ Args:
584
+ loader (torch.utils.data.DataLoader): Training data loader.
585
+ optimizer (torch.optim.Optimizer): Optimizer.
586
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
587
+ model (torch.nn.Module): VAE model.
588
+ l1_penalty (float): L1 regularization coefficient.
589
+ trial (optuna.Trial | None): Optuna trial for pruning.
590
+ return_history (bool): If True, return training history.
591
+ class_weights (torch.Tensor): CE class weights on device.
592
+ X_val (np.ndarray | None): Validation data for pruning eval.
593
+ params (dict | None): Current hyperparameters (for logging).
594
+ prune_metric (str | None): Metric for pruning decisions.
595
+ prune_warmup_epochs (int): Epochs to skip before pruning.
596
+ eval_interval (int): Epochs between validation evaluations.
597
+ eval_requires_latents (bool): If True, refine latents during eval.
598
+ eval_latent_steps (int): Latent refinement steps if needed.
599
+ eval_latent_lr (float): Latent refinement learning rate.
600
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
601
+
602
+ Returns:
603
+ Tuple[float, torch.nn.Module, list]: Best loss, best model, and training history.
604
+ """
605
+ best_model = None
606
+ history: list[float] = []
607
+
608
+ early_stopping = EarlyStopping(
609
+ patience=self.early_stop_gen,
610
+ min_epochs=self.min_epochs,
611
+ verbose=self.verbose,
612
+ prefix=self.prefix,
613
+ debug=self.debug,
614
+ )
615
+
616
+ # ---- scalarize schedule endpoints up front ----
617
+ kl = self.kl_beta_final
618
+ if isinstance(kl, (list, tuple)):
619
+ if len(kl) == 0:
620
+ msg = "kl_beta_final list is empty."
621
+ self.logger.error(msg)
622
+ raise ValueError(msg)
623
+ kl = kl[0]
624
+ beta_final = float(kl)
625
+
626
+ gamma_val = self.gamma
627
+ if isinstance(gamma_val, (list, tuple)):
628
+ if len(gamma_val) == 0:
629
+ raise ValueError("gamma list is empty.")
630
+ gamma_val = gamma_val[0]
631
+ gamma_final = float(gamma_val)
632
+
633
+ gamma_warm, gamma_ramp = 50, 100
634
+ beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
635
+
636
+ # Optional LR warmup
637
+ warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
638
+ base_lr = float(optimizer.param_groups[0]["lr"])
639
+ min_lr = base_lr * 0.1
640
+
641
+ max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
642
+
643
+ for epoch in range(max_epochs):
644
+ # focal γ schedule
645
+ if epoch < gamma_warm:
646
+ model.gamma = 0.0 # type: ignore[attr-defined]
647
+ elif epoch < gamma_warm + gamma_ramp:
648
+ model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore[attr-defined]
649
+ else:
650
+ model.gamma = gamma_final # type: ignore[attr-defined]
651
+
652
+ # KL β schedule (float throughout + ramp guard)
653
+ if epoch < beta_warm:
654
+ model.beta = 0.0 # type: ignore[attr-defined]
655
+ elif beta_ramp > 0 and epoch < beta_warm + beta_ramp:
656
+ model.beta = beta_final * ((epoch - beta_warm) / beta_ramp) # type: ignore[attr-defined]
657
+ else:
658
+ model.beta = beta_final # type: ignore[attr-defined]
659
+ # LR warmup
660
+ if epoch < warmup_epochs:
661
+ scale = float(epoch + 1) / warmup_epochs
662
+ for g in optimizer.param_groups:
663
+ g["lr"] = min_lr + (base_lr - min_lr) * scale
664
+
665
+ train_loss = self._train_step(
666
+ loader=loader,
667
+ optimizer=optimizer,
668
+ model=model,
669
+ l1_penalty=l1_penalty,
670
+ class_weights=class_weights,
671
+ )
672
+
673
+ if not np.isfinite(train_loss):
674
+ if trial:
675
+ raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
676
+ # shrink LR and continue
677
+ for g in optimizer.param_groups:
678
+ g["lr"] *= 0.5
679
+ continue
680
+
681
+ if scheduler is not None:
682
+ scheduler.step()
683
+
684
+ if return_history:
685
+ history.append(train_loss)
686
+
687
+ early_stopping(train_loss, model)
688
+ if early_stopping.early_stop:
689
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
690
+ break
691
+
692
+ # Optional Optuna pruning on a validation metric
693
+ if (
694
+ trial is not None
695
+ and X_val is not None
696
+ and ((epoch + 1) % eval_interval == 0)
697
+ ):
698
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
699
+ mask_override = None
700
+ if (
701
+ self.simulate_missing
702
+ and getattr(self, "sim_mask_test_", None) is not None
703
+ and getattr(self, "X_val_", None) is not None
704
+ and X_val.shape == self.X_val_.shape
705
+ ):
706
+ mask_override = self.sim_mask_test_
707
+ metric_val = self._eval_for_pruning(
708
+ model=model,
709
+ X_val=X_val,
710
+ params=params or getattr(self, "best_params_", {}),
711
+ metric=metric_key,
712
+ objective_mode=True,
713
+ do_latent_infer=False, # VAE uses encoder; no latent refine
714
+ latent_steps=0,
715
+ latent_lr=0.0,
716
+ latent_weight_decay=0.0,
717
+ latent_seed=self.seed, # type: ignore
718
+ _latent_cache=None,
719
+ _latent_cache_key=None,
720
+ eval_mask_override=mask_override,
721
+ )
722
+ trial.report(metric_val, step=epoch + 1)
723
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
724
+ raise optuna.exceptions.TrialPruned(
725
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
726
+ )
727
+
728
+ best_loss = early_stopping.best_score
729
+ best_model = copy.deepcopy(early_stopping.best_model)
730
+ if best_model is None:
731
+ best_model = copy.deepcopy(model)
732
+ return best_loss, best_model, history
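The warmup/ramp logic above, used for both the focal gamma and the KL beta, reduces to one piecewise-linear schedule. The standalone sketch below reproduces it for clarity; the helper name `anneal` and the warmup/ramp/final values in the example are illustrative, not the class defaults.

```python
def anneal(epoch: int, final: float, warmup: int, ramp: int) -> float:
    """Piecewise-linear anneal: 0 during warmup, linear ramp to `final`
    over `ramp` epochs, then held constant at `final`."""
    if epoch < warmup:
        return 0.0
    if ramp > 0 and epoch < warmup + ramp:
        return final * (epoch - warmup) / ramp
    return final

# KL beta with e.g. warmup=10, ramp=20, final beta=1.0 (illustrative values):
print([round(anneal(e, 1.0, 10, 20), 2) for e in (0, 9, 10, 20, 29, 30, 50)])
# -> [0.0, 0.0, 0.0, 0.5, 0.95, 1.0, 1.0]
```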
733
+
734
+ def _train_step(
735
+ self,
736
+ loader: torch.utils.data.DataLoader,
737
+ optimizer: torch.optim.Optimizer,
738
+ model: torch.nn.Module,
739
+ l1_penalty: float,
740
+ class_weights: torch.Tensor,
741
+ ) -> float:
742
+ """One epoch: inputs → VAE forward → focal CE + β·KL with guards.
743
+
744
+ This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
745
+
746
+ Args:
747
+ loader (torch.utils.data.DataLoader): Training data loader.
748
+ optimizer (torch.optim.Optimizer): Optimizer.
749
+ model (torch.nn.Module): VAE model.
750
+ l1_penalty (float): L1 regularization coefficient.
751
+ class_weights (torch.Tensor): CE class weights on device.
752
+
753
+ Returns:
754
+ float: Average training loss for the epoch.
755
+ """
756
+ model.train()
757
+ running, used = 0.0, 0
758
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
759
+ if class_weights is not None and class_weights.device != self.device:
760
+ class_weights = class_weights.to(self.device)
761
+
762
+ for _, y_batch in loader:
763
+ optimizer.zero_grad(set_to_none=True)
764
+
765
+ y_int = y_batch.to(self.device, non_blocking=True).long()
766
+
767
+ if self.is_haploid:
768
+ x_in = self._one_hot_encode_012(y_int) # (B, L, 2)
769
+ else:
770
+ x_in = self._encode_multilabel_inputs(y_int) # (B, L, 2)
771
+
772
+ out = model(x_in)
773
+ if isinstance(out, (list, tuple)):
774
+ recon_logits, mu, logvar = out[0], out[1], out[2]
775
+ else:
776
+ recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]
777
+
778
+ # Upstream guard
779
+ if (
780
+ not torch.isfinite(recon_logits).all()
781
+ or not torch.isfinite(mu).all()
782
+ or not torch.isfinite(logvar).all()
783
+ ):
784
+ continue
785
+
786
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
787
+ beta = float(getattr(model, "beta", getattr(self, "kl_beta_final", 0.0)))
788
+ gamma = max(0.0, min(gamma, 10.0))
789
+
790
+ if self.is_haploid:
791
+ loss = compute_vae_loss(
792
+ recon_logits=recon_logits,
793
+ targets=y_int,
794
+ mu=mu,
795
+ logvar=logvar,
796
+ class_weights=class_weights,
797
+ gamma=gamma,
798
+ beta=beta,
799
+ )
800
+ else:
801
+ targets = self._multi_hot_targets(y_int)
802
+ pos_w = getattr(self, "pos_weights_", None)
803
+ bce = F.binary_cross_entropy_with_logits(
804
+ recon_logits, targets, pos_weight=pos_w, reduction="none"
805
+ )
806
+ mask = (y_int != -1).unsqueeze(-1).float()
807
+ recon_loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
808
+ kl = (
809
+ -0.5
810
+ * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
811
+ / (y_int.shape[0] + 1e-8)
812
+ )
813
+ loss = recon_loss + beta * kl
814
+
815
+ if l1_penalty > 0:
816
+ l1 = torch.zeros((), device=self.device)
817
+ for p in l1_params:
818
+ l1 = l1 + p.abs().sum()
819
+ loss = loss + l1_penalty * l1
820
+
821
+ if not torch.isfinite(loss):
822
+ continue
823
+
824
+ loss.backward()
825
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
826
+
827
+ # skip update if any grad is non-finite
828
+ bad = any(
829
+ p.grad is not None and not torch.isfinite(p.grad).all()
830
+ for p in model.parameters()
831
+ )
832
+ if bad:
833
+ optimizer.zero_grad(set_to_none=True)
834
+ continue
835
+
836
+ optimizer.step()
837
+
838
+ running += float(loss.detach().item())
839
+ used += 1
840
+
841
+ return (running / used) if used > 0 else float("inf")
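A self-contained numerical sketch of the diploid branch of the loss above: masked multi-label BCE over the REF/ALT channels plus the analytic Gaussian KL term, combined with the current beta. All tensor values are toy inputs.

```python
import torch
import torch.nn.functional as F

B, L = 2, 3                                   # 2 samples, 3 loci
y_int = torch.tensor([[0, 1, -1],
                      [2, 2, 1]])             # 0/1/2 with -1 = missing

# Multi-hot targets: REF channel on for {0, 1}, ALT channel on for {1, 2}.
targets = torch.zeros(B, L, 2)
valid = y_int != -1
targets[..., 0] = (valid & (y_int != 2)).float()
targets[..., 1] = (valid & (y_int != 0)).float()

recon_logits = torch.randn(B, L, 2)           # stand-in decoder outputs
mu, logvar = torch.randn(B, 4), torch.randn(B, 4)  # latent stats (dim=4)
beta = 0.5

bce = F.binary_cross_entropy_with_logits(recon_logits, targets, reduction="none")
mask = valid.unsqueeze(-1).float()            # exclude missing loci from the loss
recon_loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / (B + 1e-8)
loss = recon_loss + beta * kl
print(float(loss))
```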
842
+
843
+ def _predict(
844
+ self,
845
+ model: torch.nn.Module,
846
+ X: np.ndarray | torch.Tensor,
847
+ return_proba: bool = False,
848
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
849
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
850
+
851
+ This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
852
+
853
+ Args:
854
+ model (torch.nn.Module): Trained model.
855
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
856
+ return_proba (bool): If True, also return probabilities.
857
+
858
+ Returns:
859
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
860
+ """
861
+ if model is None:
862
+ msg = "Model is not trained. Call fit() before predict()."
863
+ self.logger.error(msg)
864
+ raise NotFittedError(msg)
865
+
866
+ model.eval()
867
+ with torch.no_grad():
868
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
869
+ X_tensor = X_tensor.to(self.device).long()
870
+ if self.is_haploid:
871
+ x_ohe = self._one_hot_encode_012(X_tensor)
872
+ outputs = model(x_ohe)
873
+ logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
874
+ probas = torch.softmax(logits, dim=-1)
875
+ labels = torch.argmax(probas, dim=-1)
876
+ else:
877
+ x_in = self._encode_multilabel_inputs(X_tensor)
878
+ outputs = model(x_in)
879
+ logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
880
+ probas2 = torch.sigmoid(logits)
881
+ p_ref = probas2[..., 0]
882
+ p_alt = probas2[..., 1]
883
+ p_het = p_ref * p_alt
884
+ p_ref_only = p_ref * (1 - p_alt)
885
+ p_alt_only = p_alt * (1 - p_ref)
886
+ probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
887
+ probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
888
+ labels = torch.argmax(probas, dim=-1)
889
+
890
+ if return_proba:
891
+ return labels.cpu().numpy(), probas.cpu().numpy()
892
+
893
+ return labels.cpu().numpy()
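A worked example of the diploid probability composition above, with assumed channel probabilities p_ref = 0.9 and p_alt = 0.6.

```python
import torch

# Per-locus channel probabilities from the two sigmoid outputs (toy values).
p_ref, p_alt = torch.tensor(0.9), torch.tensor(0.6)

p_het = p_ref * p_alt                # both channels on  -> heterozygote (1)
p_ref_only = p_ref * (1 - p_alt)     # REF on, ALT off   -> homozygous REF (0)
p_alt_only = p_alt * (1 - p_ref)     # ALT on, REF off   -> homozygous ALT (2)
probas = torch.stack([p_ref_only, p_het, p_alt_only])
probas = probas / probas.sum().clamp_min(1e-8)
print(probas)                        # tensor([0.3750, 0.5625, 0.0625])
print(int(torch.argmax(probas)))     # 1 -> predicted genotype HET
```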
894
+
895
+ def _evaluate_model(
896
+ self,
897
+ X_val: np.ndarray,
898
+ model: torch.nn.Module,
899
+ params: dict,
900
+ objective_mode: bool = False,
901
+ latent_vectors_val: np.ndarray | None = None,
902
+ *,
903
+ eval_mask_override: np.ndarray | None = None,
904
+ ) -> Dict[str, float]:
905
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
906
+
907
+ This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
908
+
909
+ Args:
910
+ X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
911
+ model (torch.nn.Module): Trained model.
912
+ params (dict): Current hyperparameters (for logging).
913
+ objective_mode (bool): If True, minimize logging for Optuna.
914
+ latent_vectors_val (np.ndarray | None): Not used by VAE.
915
+ eval_mask_override (np.ndarray | None): Optional mask to override default eval mask.
916
+
917
+ Returns:
918
+ Dict[str, float]: Computed metrics.
919
+
920
+ Raises:
921
+ NotFittedError: If called before fit().
922
+ """
923
+ pred_labels, pred_probas = self._predict(
924
+ model=model, X=X_val, return_proba=True
925
+ )
926
+
927
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
928
+
929
+ # FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
930
+ if (
931
+ hasattr(self, "X_val_")
932
+ and getattr(self, "X_val_", None) is not None
933
+ and X_val.shape[0] == self.X_val_.shape[0]
934
+ ):
935
+ GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
936
+ elif (
937
+ hasattr(self, "X_train_")
938
+ and getattr(self, "X_train_", None) is not None
939
+ and X_val.shape[0] == self.X_train_.shape[0]
940
+ ):
941
+ GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
942
+ else:
943
+ GT_ref = self.ground_truth_
944
+
945
+ # FIX 2: Handle Feature Mismatch
946
+ # If GT has more columns than X_val, slice it to match.
947
+ if GT_ref.shape[1] > X_val.shape[1]:
948
+ GT_ref = GT_ref[:, : X_val.shape[1]]
949
+
950
+ # Fallback safeguard
951
+ if GT_ref.shape != X_val.shape:
952
+ GT_ref = X_val
953
+
954
+ # FIX 3: Allow override mask to be sliced if it's too wide
955
+ if eval_mask_override is not None:
956
+ if eval_mask_override.shape[0] != X_val.shape[0]:
957
+ msg = (
958
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
959
+ f"does not match X_val rows {X_val.shape[0]}"
960
+ )
961
+ self.logger.error(msg)
962
+ raise ValueError(msg)
963
+
964
+ if eval_mask_override.shape[1] > X_val.shape[1]:
965
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
966
+ else:
967
+ eval_mask = eval_mask_override.astype(bool)
968
+ else:
969
+ eval_mask = X_val != -1
970
+
971
+ eval_mask = eval_mask & finite_mask & (GT_ref != -1)
972
+
973
+ y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
974
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
975
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
976
+
977
+ if y_true_flat.size == 0:
978
+ return {self.tune_metric: 0.0}
979
+
980
+ # ensure valid probability simplex after masking
981
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
982
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
983
+ row_sums[row_sums == 0] = 1.0
984
+ y_proba_flat = y_proba_flat / row_sums
985
+
986
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
987
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
988
+
989
+ if self.is_haploid:
990
+ y_true_flat = y_true_flat.copy()
991
+ y_pred_flat = y_pred_flat.copy()
992
+ y_true_flat[y_true_flat == 2] = 1
993
+ y_pred_flat[y_pred_flat == 2] = 1
994
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
995
+ proba_2[:, 0] = y_proba_flat[:, 0]
996
+ proba_2[:, 1] = y_proba_flat[:, 2]
997
+ y_proba_flat = proba_2
998
+
999
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
1000
+
1001
+ metrics = self.scorers_.evaluate(
1002
+ y_true_flat,
1003
+ y_pred_flat,
1004
+ y_true_ohe,
1005
+ y_proba_flat,
1006
+ objective_mode,
1007
+ self.tune_metric,
1008
+ )
1009
+
1010
+ if not objective_mode:
1011
+ pm = PrettyMetrics(
1012
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
1013
+ )
1014
+ pm.render()
1015
+
1016
+ # Primary report
1017
+ self._make_class_reports(
1018
+ y_true=y_true_flat,
1019
+ y_pred_proba=y_proba_flat,
1020
+ y_pred=y_pred_flat,
1021
+ metrics=metrics,
1022
+ labels=target_names,
1023
+ )
1024
+
1025
+ # IUPAC decode & 10-base integer report
1026
+ # FIX 4: Use current shape (X_val.shape) not self.num_features_
1027
+ y_true_dec = self.pgenc.decode_012(
1028
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1029
+ )
1030
+ X_pred = X_val.copy()
1031
+ X_pred[eval_mask] = y_pred_flat
1032
+ y_pred_dec = self.pgenc.decode_012(
1033
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
1034
+ )
1035
+
1036
+ encodings_dict = {
1037
+ "A": 0,
1038
+ "C": 1,
1039
+ "G": 2,
1040
+ "T": 3,
1041
+ "W": 4,
1042
+ "R": 5,
1043
+ "M": 6,
1044
+ "K": 7,
1045
+ "Y": 8,
1046
+ "S": 9,
1047
+ "N": -1,
1048
+ }
1049
+ y_true_int = self.pgenc.convert_int_iupac(
1050
+ y_true_dec, encodings_dict=encodings_dict
1051
+ )
1052
+ y_pred_int = self.pgenc.convert_int_iupac(
1053
+ y_pred_dec, encodings_dict=encodings_dict
1054
+ )
1055
+
1056
+ valid_iupac_mask = y_true_int[eval_mask] >= 0
1057
+ if valid_iupac_mask.any():
1058
+ self._make_class_reports(
1059
+ y_true=y_true_int[eval_mask][valid_iupac_mask],
1060
+ y_pred=y_pred_int[eval_mask][valid_iupac_mask],
1061
+ metrics=metrics,
1062
+ y_pred_proba=None,
1063
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
1064
+ )
1065
+ else:
1066
+ self.logger.warning(
1067
+ "Skipped IUPAC confusion matrix: No valid ground truths."
1068
+ )
1069
+
1070
+ return metrics
1071
+
1072
+ def _objective(self, trial: optuna.Trial) -> float:
1073
+ """Optuna objective for VAE (no latent refinement during eval).
1074
+
1075
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, trains the VAE model with these parameters, and evaluates its performance on a validation set. The evaluation metric specified by `self.tune_metric` is returned for optimization. If training fails, the trial is pruned to keep the tuning process efficient.
1076
+
1077
+ Args:
1078
+ trial (optuna.Trial): Optuna trial object.
1079
+
1080
+ Returns:
1081
+ float: Value of the tuning metric to be optimized.
1082
+ """
1083
+ try:
1084
+ params = self._sample_hyperparameters(trial)
1085
+
1086
+ # Use tune subsets when available (tune_fast)
1087
+ X_train = getattr(self, "_tune_X_train", None)
1088
+ X_val = getattr(self, "_tune_X_test", None)
1089
+ if X_train is None or X_val is None:
1090
+ X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1091
+ X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1092
+
1093
+ class_weights = self._normalize_class_weights(
1094
+ self._class_weights_from_zygosity(X_train)
1095
+ )
1096
+ # Pos weights for diploid multilabel BCE during tuning
1097
+ if not self.is_haploid:
1098
+ self.pos_weights_ = self._compute_pos_weights(X_train)
1099
+ else:
1100
+ self.pos_weights_ = None
1101
+ train_loader = self._get_data_loader(X_train)
1102
+
1103
+ model = self.build_model(self.Model, params["model_params"])
1104
+ model.apply(self.initialize_weights)
1105
+
1106
+ lr: float = params["lr"]
1107
+ l1_penalty: float = params["l1_penalty"]
1108
+
1109
+ # Train + prune on metric
1110
+ _, model, __ = self._train_and_validate_model(
1111
+ model=model,
1112
+ loader=train_loader,
1113
+ lr=lr,
1114
+ l1_penalty=l1_penalty,
1115
+ trial=trial,
1116
+ return_history=False,
1117
+ class_weights=class_weights,
1118
+ X_val=X_val,
1119
+ params=params,
1120
+ prune_metric=self.tune_metric,
1121
+ prune_warmup_epochs=10,
1122
+ eval_interval=self.tune_eval_interval,
1123
+ eval_requires_latents=False,
1124
+ eval_latent_steps=0,
1125
+ eval_latent_lr=0.0,
1126
+ eval_latent_weight_decay=0.0,
1127
+ )
1128
+
1129
+ eval_mask = (
1130
+ self.sim_mask_test_
1131
+ if (
1132
+ self.simulate_missing
1133
+ and getattr(self, "sim_mask_test_", None) is not None
1134
+ )
1135
+ else None
1136
+ )
1137
+
1138
+ if model is None:
1139
+ raise RuntimeError("Model training failed; no model was returned.")
1140
+
1141
+ metrics = self._evaluate_model(
1142
+ X_val, model, params, objective_mode=True, eval_mask_override=eval_mask
1143
+ )
1144
+ self._clear_resources(model, train_loader)
1145
+ return metrics[self.tune_metric]
1146
+
1147
+ except Exception as e:
1148
+ # Keep sweeps moving
1149
+ self.logger.debug(f"Trial failed with error: {e}")
1150
+ raise optuna.exceptions.TrialPruned(
1151
+ f"Trial failed with error. Enable debug logging for details."
1152
+ )
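Tuning normally runs through `self.tune_hyperparameters()` (inherited from `BaseNNImputer` and not shown in this diff). The sketch below only illustrates how `_objective` would plug into a plain Optuna study; the pruner choice and the assumption that fit-time attributes are already prepared are mine, not the module's.

```python
import optuna

# Assumes `imputer` was constructed as in the earlier sketch and that its
# training/validation splits have already been prepared.
study = optuna.create_study(
    direction="maximize",  # the configured tune_metric is maximized
    pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
)
study.optimize(imputer._objective, n_trials=imputer.n_trials)
print(study.best_params)
```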
1153
+
1154
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1155
+ """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
1156
+
1157
+ Args:
1158
+ trial (optuna.Trial): Optuna trial object.
1159
+
1160
+ Returns:
1161
+ Dict[str, int | float | str]: Sampled hyperparameters.
1162
+ """
1163
+ params = {
1164
+ "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
1165
+ "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
1166
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
1167
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
1168
+ "activation": trial.suggest_categorical(
1169
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
1170
+ ),
1171
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1172
+ "layer_scaling_factor": trial.suggest_float(
1173
+ "layer_scaling_factor", 2.0, 4.0, step=0.5
1174
+ ),
1175
+ "layer_schedule": trial.suggest_categorical(
1176
+ "layer_schedule", ["pyramid", "linear"]
1177
+ ),
1178
+ # VAE-specific β (final value after anneal)
1179
+ "beta": trial.suggest_float("beta", 0.5, 2.0, step=0.5),
1180
+ # focal gamma (if used in VAE recon CE)
1181
+ "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
1182
+ }
1183
+
1184
+ use_n_features = (
1185
+ self._tune_num_features
1186
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
1187
+ else self.num_features_
1188
+ )
1189
+ input_dim = use_n_features * self.output_classes_
1190
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1191
+ n_inputs=input_dim,
1192
+ n_outputs=input_dim,
1193
+ n_samples=len(self.train_idx_),
1194
+ n_hidden=params["num_hidden_layers"],
1195
+ alpha=params["layer_scaling_factor"],
1196
+ schedule=params["layer_schedule"],
1197
+ )
1198
+
1199
+ # [latent_dim] + interior widths (exclude output width)
1200
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1201
+
1202
+ params["model_params"] = {
1203
+ "n_features": use_n_features,
1204
+ "num_classes": self.output_classes_,
1205
+ "latent_dim": params["latent_dim"],
1206
+ "dropout_rate": params["dropout_rate"],
1207
+ "hidden_layer_sizes": hidden_only,
1208
+ "activation": params["activation"],
1209
+ # Pass through VAE recon/regularization coefficients
1210
+ "beta": params["beta"],
1211
+ "gamma": params["gamma"],
1212
+ }
1213
+ return params
1214
+
1215
+ def _set_best_params(self, best_params: dict) -> dict:
1216
+ """Adopt best params and return VAE model_params.
1217
+
1218
+ Args:
1219
+ best_params (dict): Best hyperparameters from tuning.
1220
+
1221
+ Returns:
1222
+ dict: VAE model parameters.
1223
+ """
1224
+ self.latent_dim = best_params["latent_dim"]
1225
+ self.dropout_rate = best_params["dropout_rate"]
1226
+ self.learning_rate = best_params["learning_rate"]
1227
+ self.l1_penalty = best_params["l1_penalty"]
1228
+ self.activation = best_params["activation"]
1229
+ self.layer_scaling_factor = best_params["layer_scaling_factor"]
1230
+ self.layer_schedule = best_params["layer_schedule"]
1231
+ self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
1232
+ self.gamma = best_params.get("gamma", self.gamma)
1233
+
1234
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1235
+ n_inputs=self.num_features_ * self.output_classes_,
1236
+ n_outputs=self.num_features_ * self.output_classes_,
1237
+ n_samples=len(self.train_idx_),
1238
+ n_hidden=best_params["num_hidden_layers"],
1239
+ alpha=best_params["layer_scaling_factor"],
1240
+ schedule=best_params["layer_schedule"],
1241
+ )
1242
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1243
+
1244
+ return {
1245
+ "n_features": self.num_features_,
1246
+ "latent_dim": self.latent_dim,
1247
+ "hidden_layer_sizes": hidden_only,
1248
+ "dropout_rate": self.dropout_rate,
1249
+ "activation": self.activation,
1250
+ "num_classes": self.output_classes_,
1251
+ "beta": self.kl_beta_final,
1252
+ "gamma": self.gamma,
1253
+ }
1254
+
1255
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
1256
+ """Default VAE model params when tuning is disabled.
1257
+
1258
+ Returns:
1259
+ Dict[str, int | float | str | list]: VAE model parameters.
1260
+ """
1261
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1262
+ n_inputs=self.num_features_ * self.output_classes_,
1263
+ n_outputs=self.num_features_ * self.output_classes_,
1264
+ n_samples=len(self.ground_truth_),
1265
+ n_hidden=self.num_hidden_layers,
1266
+ alpha=self.layer_scaling_factor,
1267
+ schedule=self.layer_schedule,
1268
+ )
1269
+ return {
1270
+ "n_features": self.num_features_,
1271
+ "latent_dim": self.latent_dim,
1272
+ "hidden_layer_sizes": hidden_layer_sizes,
1273
+ "dropout_rate": self.dropout_rate,
1274
+ "activation": self.activation,
1275
+ "num_classes": self.output_classes_,
1276
+ "beta": self.kl_beta_final,
1277
+ "gamma": self.gamma,
1278
+ }
1279
+
1280
+ def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
1281
+ """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
1282
+ if self.is_haploid:
1283
+ return self._one_hot_encode_012(y)
1284
+ y = y.to(self.device)
1285
+ shape = y.shape + (2,)
1286
+ out = torch.zeros(shape, device=self.device, dtype=torch.float32)
1287
+ valid = y != -1
1288
+ ref_mask = valid & (y != 2)
1289
+ alt_mask = valid & (y != 0)
1290
+ out[ref_mask, 0] = 1.0
1291
+ out[alt_mask, 1] = 1.0
1292
+ return out
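The 0/1/2 to two-channel mapping above, spelled out on a toy tensor: 0 maps to [1, 0], 1 (heterozygote) to [1, 1], 2 to [0, 1], and missing (-1) to [0, 0].

```python
import torch

y = torch.tensor([[0, 1, 2, -1]])
out = torch.zeros(y.shape + (2,))
valid = y != -1
out[valid & (y != 2), 0] = 1.0   # REF channel
out[valid & (y != 0), 1] = 1.0   # ALT channel
print(out)
# tensor([[[1., 0.],    # 0  (hom-REF)
#          [1., 1.],    # 1  (HET)
#          [0., 1.],    # 2  (hom-ALT)
#          [0., 0.]]])  # -1 (missing)
```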
1293
+
1294
+ def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
1295
+ """Targets aligned with _encode_multilabel_inputs for diploid training."""
1296
+ if self.is_haploid:
1297
+ raise RuntimeError("_multi_hot_targets called for haploid data.")
1298
+ y = y.to(self.device)
1299
+ out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
1300
+ valid = y != -1
1301
+ ref_mask = valid & (y != 2)
1302
+ alt_mask = valid & (y != 0)
1303
+ out[ref_mask, 0] = 1.0
1304
+ out[alt_mask, 1] = 1.0
1305
+ return out
1306
+
1307
+ def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
1308
+ """Balance REF/ALT channels for multilabel BCE."""
1309
+ ref_pos = np.count_nonzero((X == 0) | (X == 1))
1310
+ alt_pos = np.count_nonzero((X == 2) | (X == 1))
1311
+ total_valid = np.count_nonzero(X != -1)
1312
+ pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
1313
+ neg_counts = np.maximum(total_valid - pos_counts, 1.0)
1314
+ pos_counts = np.maximum(pos_counts, 1.0)
1315
+ weights = neg_counts / pos_counts
1316
+ return torch.tensor(weights, device=self.device, dtype=torch.float32)
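A worked example of the REF/ALT channel weighting above on a toy matrix; the resulting ratios are what get passed to BCE as `pos_weight`.

```python
import numpy as np

X = np.array([[0, 1, 2],
              [2, 2, -1]])                         # 0/1/2 with -1 = missing

ref_pos = np.count_nonzero((X == 0) | (X == 1))    # entries carrying REF -> 2
alt_pos = np.count_nonzero((X == 2) | (X == 1))    # entries carrying ALT -> 4
total_valid = np.count_nonzero(X != -1)            # 5

pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
neg_counts = np.maximum(total_valid - pos_counts, 1.0)
pos_counts = np.maximum(pos_counts, 1.0)
print(neg_counts / pos_counts)                     # [1.5, 0.25] -> pos_weight
```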