pg-sui 1.6.16a3__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,9 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import copy
4
- from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
4
+ import traceback
5
+ from collections import defaultdict
6
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
5
7
 
6
8
  import matplotlib.pyplot as plt
7
9
  import numpy as np
@@ -9,17 +11,15 @@ import optuna
9
11
  import torch
10
12
  import torch.nn.functional as F
11
13
  from sklearn.exceptions import NotFittedError
12
- from sklearn.model_selection import train_test_split
13
14
  from snpio.analysis.genotype_encoder import GenotypeEncoder
14
15
  from snpio.utils.logging import LoggerManager
15
16
  from torch.optim.lr_scheduler import CosineAnnealingLR
16
17
 
17
18
  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
18
19
  from pgsui.data_processing.containers import VAEConfig
19
- from pgsui.data_processing.transformers import SimMissingTransformer
20
20
  from pgsui.impute.unsupervised.base import BaseNNImputer
21
21
  from pgsui.impute.unsupervised.callbacks import EarlyStopping
22
- from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
22
+ from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss
23
23
  from pgsui.impute.unsupervised.models.vae_model import VAEModel
24
24
  from pgsui.utils.logging_utils import configure_logger
25
25
  from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -29,14 +29,47 @@ if TYPE_CHECKING:
29
29
  from snpio.read_input.genotype_data import GenotypeData
30
30
 
31
31
 
32
- def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
33
- """Normalize VAEConfig input from various sources.
32
+ def _make_warmup_cosine_scheduler(
33
+ optimizer: torch.optim.Optimizer,
34
+ *,
35
+ max_epochs: int,
36
+ warmup_epochs: int,
37
+ start_factor: float = 0.1,
38
+ ) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
39
+ """Create a warmup->cosine LR scheduler.
34
40
 
35
41
  Args:
36
- config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
42
+ optimizer (torch.optim.Optimizer): The optimizer to schedule.
43
+ max_epochs (int): Total number of epochs for training.
44
+ warmup_epochs (int): Number of warmup epochs.
45
+ start_factor (float): Initial LR factor for warmup.
37
46
 
38
47
  Returns:
39
- VAEConfig: Normalized configuration dataclass.
48
+ torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: The learning rate scheduler.
49
+ """
50
+ warmup_epochs = int(max(0, warmup_epochs))
51
+
52
+ if warmup_epochs == 0:
53
+ return CosineAnnealingLR(optimizer, T_max=max_epochs)
54
+
55
+ warmup = torch.optim.lr_scheduler.LinearLR(
56
+ optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
57
+ )
58
+ cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
59
+
60
+ return torch.optim.lr_scheduler.SequentialLR(
61
+ optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
62
+ )
63
+
64
+
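
The new `_make_warmup_cosine_scheduler` helper chains a linear warmup into a cosine decay using stock PyTorch schedulers. A minimal sketch of how the resulting schedule behaves, using a throwaway parameter and optimizer that are not part of pg-sui:

import torch

# Dummy parameter/optimizer purely for illustration.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)

warmup_epochs, max_epochs = 5, 50
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, total_iters=warmup_epochs)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max_epochs - warmup_epochs)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, cosine], milestones=[warmup_epochs])

for epoch in range(max_epochs):
    opt.step()    # LR ramps from 1e-4 to 1e-3 over the first 5 epochs,
    sched.step()  # then decays toward 0 on a cosine curve for the remaining epochs.
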
65
+ def ensure_vae_config(config: VAEConfig | dict | str | None) -> VAEConfig:
66
+ """Ensure a VAEConfig instance from various input types.
67
+
68
+ Args:
69
+ config (VAEConfig | dict | str | None): Configuration input.
70
+
71
+ Returns:
72
+ VAEConfig: The resulting VAEConfig instance.
40
73
  """
41
74
  if config is None:
42
75
  return VAEConfig()
@@ -45,13 +78,13 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
45
78
  if isinstance(config, str):
46
79
  return load_yaml_to_dataclass(config, VAEConfig)
47
80
  if isinstance(config, dict):
81
+ cfg_in = copy.deepcopy(config)
48
82
  base = VAEConfig()
49
- # Respect top-level preset
50
- preset = config.pop("preset", None)
83
+ preset = cfg_in.pop("preset", None)
84
+ if "io" in cfg_in and isinstance(cfg_in["io"], dict):
85
+ preset = preset or cfg_in["io"].pop("preset", None)
51
86
  if preset:
52
87
  base = VAEConfig.from_preset(preset)
53
- # Flatten + apply
54
- flat: Dict[str, object] = {}
55
88
 
56
89
  def _flatten(prefix: str, d: dict, out: dict) -> dict:
57
90
  for k, v in d.items():
@@ -62,15 +95,19 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
62
95
  out[kk] = v
63
96
  return out
64
97
 
65
- flat = _flatten("", config, {})
98
+ flat = _flatten("", cfg_in, {})
66
99
  return apply_dot_overrides(base, flat)
67
100
  raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
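
`ensure_vae_config` now deep-copies dict input, accepts a `preset` either at the top level or nested under `io`, and flattens the remaining nested keys into dot-separated overrides for `apply_dot_overrides`. A self-contained sketch of the flattening step (a standalone reimplementation for illustration, not the packaged helper):

def flatten(prefix: str, d: dict, out: dict) -> dict:
    # Nested dicts become dot-separated keys; leaf values are kept as-is.
    for k, v in d.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            flatten(key, v, out)
        else:
            out[key] = v
    return out

overrides = flatten("", {"model": {"latent_dim": 8}, "train": {"learning_rate": 1e-3}}, {})
# {'model.latent_dim': 8, 'train.learning_rate': 0.001}
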
68
101
 
69
102
 
70
103
  class ImputeVAE(BaseNNImputer):
71
- """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
104
+ """Variational Autoencoder (VAE) imputer for 0/1/2 genotypes.
105
+
106
+ Trains a VAE on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. The workflow simulates missingness once on the full matrix, then creates train/val/test splits. It supports haploid and diploid data, focal-CE reconstruction loss with a KL term (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
72
107
 
73
- This imputer implements a VAE with a multinomial (categorical) latent space. It is designed to handle missing data by inferring the latent distribution and generating plausible predictions. The model is trained using a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over time. The imputer supports both haploid and diploid genotype data.
108
+ Notes:
109
+ - Training includes early stopping based on validation loss.
110
+ - The imputer can handle both haploid and diploid genotype data.
74
111
  """
75
112
 
76
113
  def __init__(
@@ -79,46 +116,37 @@ class ImputeVAE(BaseNNImputer):
79
116
  *,
80
117
  tree_parser: Optional["TreeParser"] = None,
81
118
  config: Optional[Union["VAEConfig", dict, str]] = None,
82
- overrides: dict | None = None,
83
- simulate_missing: bool | None = None,
84
- sim_strategy: (
85
- Literal[
86
- "random",
87
- "random_weighted",
88
- "random_weighted_inv",
89
- "nonrandom",
90
- "nonrandom_weighted",
91
- ]
92
- | None
93
- ) = None,
94
- sim_prop: float | None = None,
95
- sim_kwargs: dict | None = None,
96
- ):
97
- """Initialize the VAE imputer with a unified config interface.
98
-
99
- This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
119
+ overrides: Optional[dict] = None,
120
+ sim_strategy: Literal[
121
+ "random",
122
+ "random_weighted",
123
+ "random_weighted_inv",
124
+ "nonrandom",
125
+ "nonrandom_weighted",
126
+ ] = "random",
127
+ sim_prop: Optional[float] = None,
128
+ sim_kwargs: Optional[dict] = None,
129
+ ) -> None:
130
+ """Initialize the ImputeVAE imputer.
100
131
 
101
132
  Args:
102
- genotype_data (GenotypeData): Backing genotype data object.
103
- tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
104
- config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
105
- overrides (dict | None): Optional dot-key overrides with highest precedence.
106
- simulate_missing (bool | None): Whether to simulate missing data during training.
107
- sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
108
- sim_prop (float | None): Proportion of data to simulate as missing if simulating.
109
- sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
133
+ genotype_data (GenotypeData): Genotype data for imputation.
134
+ tree_parser (Optional[TreeParser]): Tree parser required for nonrandom strategies.
135
+ config (Optional[Union[VAEConfig, dict, str]]): Config dataclass, nested dict, YAML path, or None.
136
+ overrides (Optional[dict]): Dot-key overrides applied last with highest precedence.
137
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Missingness simulation strategy (overrides config).
138
+ sim_prop (Optional[float]): Proportion of entries to simulate as missing (overrides config). Default is None.
139
+ sim_kwargs (Optional[dict]): Extra missingness kwargs merged into config.
110
140
  """
111
141
  self.model_name = "ImputeVAE"
112
142
  self.genotype_data = genotype_data
113
143
  self.tree_parser = tree_parser
114
144
 
115
- # Normalize configuration and apply top-precedence overrides
116
145
  cfg = ensure_vae_config(config)
117
146
  if overrides:
118
147
  cfg = apply_dot_overrides(cfg, overrides)
119
148
  self.cfg = cfg
120
149
 
121
- # Logger (align with AE/NLPCA)
122
150
  logman = LoggerManager(
123
151
  __name__,
124
152
  prefix=self.cfg.io.prefix,
@@ -126,12 +154,10 @@ class ImputeVAE(BaseNNImputer):
126
154
  verbose=self.cfg.io.verbose,
127
155
  )
128
156
  self.logger = configure_logger(
129
- logman.get_logger(),
130
- verbose=self.cfg.io.verbose,
131
- debug=self.cfg.io.debug,
157
+ logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
132
158
  )
159
+ self.logger.propagate = False
133
160
 
134
- # BaseNNImputer bootstraps device/dirs/log formatting
135
161
  super().__init__(
136
162
  model_name=self.model_name,
137
163
  genotype_data=self.genotype_data,
@@ -141,11 +167,10 @@ class ImputeVAE(BaseNNImputer):
141
167
  debug=self.cfg.io.debug,
142
168
  )
143
169
 
144
- # Model hook & encoder
145
170
  self.Model = VAEModel
146
171
  self.pgenc = GenotypeEncoder(genotype_data)
147
172
 
148
- # IO/global
173
+ # I/O and general parameters
149
174
  self.seed = self.cfg.io.seed
150
175
  self.n_jobs = self.cfg.io.n_jobs
151
176
  self.prefix = self.cfg.io.prefix
@@ -153,43 +178,41 @@ class ImputeVAE(BaseNNImputer):
153
178
  self.verbose = self.cfg.io.verbose
154
179
  self.debug = self.cfg.io.debug
155
180
  self.rng = np.random.default_rng(self.seed)
156
- self.pos_weights_: torch.Tensor | None = None
157
181
 
158
- # Simulated-missing controls (config defaults + ctor overrides)
182
+ # Simulation parameters
159
183
  sim_cfg = getattr(self.cfg, "sim", None)
160
184
  sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
161
185
  if sim_kwargs:
162
186
  sim_cfg_kwargs.update(sim_kwargs)
163
187
  if sim_cfg is None:
164
- default_sim_flag = bool(simulate_missing)
165
188
  default_strategy = "random"
166
- default_prop = 0.10
189
+ default_prop = 0.2
167
190
  else:
168
- default_sim_flag = sim_cfg.simulate_missing
169
191
  default_strategy = sim_cfg.sim_strategy
170
192
  default_prop = sim_cfg.sim_prop
171
- self.simulate_missing = (
172
- default_sim_flag if simulate_missing is None else bool(simulate_missing)
173
- )
193
+
194
+ self.simulate_missing = True
174
195
  self.sim_strategy = sim_strategy or default_strategy
175
196
  self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
176
197
  self.sim_kwargs = sim_cfg_kwargs
177
198
 
178
- # Model hyperparams (AE-parity)
199
+ # Model architecture parameters
179
200
  self.latent_dim = self.cfg.model.latent_dim
180
201
  self.dropout_rate = self.cfg.model.dropout_rate
181
202
  self.num_hidden_layers = self.cfg.model.num_hidden_layers
182
203
  self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
183
204
  self.layer_schedule = self.cfg.model.layer_schedule
184
- self.activation = self.cfg.model.hidden_activation
185
- self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
205
+ self.activation = self.cfg.model.activation
186
206
 
187
- # VAE-only KL controls
188
- self.kl_beta_final = self.cfg.vae.kl_beta
189
- self.kl_warmup = self.cfg.vae.kl_warmup
190
- self.kl_ramp = self.cfg.vae.kl_ramp
207
+ # VAE-specific parameters
208
+ self.kl_beta = self.cfg.vae.kl_beta
209
+ self.kl_beta_schedule = self.cfg.vae.kl_beta_schedule
191
210
 
192
- # Train hyperparams (AE-parity)
211
+ # Training parameters
212
+ self.power: float = self.cfg.train.weights_power
213
+ self.max_ratio: Any | float | None = self.cfg.train.weights_max_ratio
214
+ self.normalize: bool = self.cfg.train.weights_normalize
215
+ self.inverse: bool = self.cfg.train.weights_inverse
193
216
  self.batch_size = self.cfg.train.batch_size
194
217
  self.learning_rate = self.cfg.train.learning_rate
195
218
  self.l1_penalty: float = self.cfg.train.l1_penalty
@@ -197,32 +220,18 @@ class ImputeVAE(BaseNNImputer):
197
220
  self.min_epochs = self.cfg.train.min_epochs
198
221
  self.epochs = self.cfg.train.max_epochs
199
222
  self.validation_split = self.cfg.train.validation_split
200
- self.beta = self.cfg.train.weights_beta
201
- self.max_ratio = self.cfg.train.weights_max_ratio
223
+ self.gamma = self.cfg.train.gamma
224
+ self.gamma_schedule = self.cfg.train.gamma_schedule
202
225
 
203
- # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
226
+ # Hyperparameter tuning
204
227
  self.tune = self.cfg.tune.enabled
205
- self.tune_fast = self.cfg.tune.fast
206
- self.tune_batch_size = self.cfg.tune.batch_size
207
- self.tune_epochs = self.cfg.tune.epochs
208
- self.tune_eval_interval = self.cfg.tune.eval_interval
209
- self.tune_metric: Literal[
210
- "pr_macro",
211
- "f1",
212
- "accuracy",
213
- "average_precision",
214
- "precision",
215
- "recall",
216
- "roc_auc",
217
- ] = self.cfg.tune.metric
228
+ self.tune_metric = self.cfg.tune.metric
218
229
  self.n_trials = self.cfg.tune.n_trials
219
230
  self.tune_save_db = self.cfg.tune.save_db
220
231
  self.tune_resume = self.cfg.tune.resume
221
- self.tune_max_samples = self.cfg.tune.max_samples
222
- self.tune_max_loci = self.cfg.tune.max_loci
223
232
  self.tune_patience = self.cfg.tune.patience
224
233
 
225
- # Plotting (AE-parity)
234
+ # Plotting parameters
226
235
  self.plot_format = self.cfg.plot.fmt
227
236
  self.plot_dpi = self.cfg.plot.dpi
228
237
  self.plot_fontsize = self.cfg.plot.fontsize
@@ -230,163 +239,281 @@ class ImputeVAE(BaseNNImputer):
230
239
  self.despine = self.cfg.plot.despine
231
240
  self.show_plots = self.cfg.plot.show
232
241
 
233
- # Derived at fit-time
234
- self.is_haploid: bool = False
235
- self.num_classes_: int = 3 # diploid default
242
+ # Internal attributes set during fitting
243
+ self.is_haploid_: bool = False
244
+ self.num_classes_: int = 3
236
245
  self.model_params: Dict[str, Any] = {}
237
- self.sim_mask_global_: np.ndarray | None = None
238
- self.sim_mask_train_: np.ndarray | None = None
239
- self.sim_mask_test_: np.ndarray | None = None
246
+ self.sim_mask_test_: np.ndarray
240
247
 
241
248
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
242
- msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
249
+ msg = "tree_parser is required for nonrandom sim strategies."
243
250
  self.logger.error(msg)
244
251
  raise ValueError(msg)
245
252
 
246
- # -------------------- Fit -------------------- #
247
253
  def fit(self) -> "ImputeVAE":
248
- """Fit the VAE on 0/1/2 encoded genotypes (missing -> -1).
249
-
250
- This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. Missing positions are encoded as -1 for loss masking (any simulated-missing loci are temporarily tagged with -9 by ``SimMissingTransformer`` before being re-encoded as -1). It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
254
+ """Fit the VAE imputer model to the genotype data.
255
+
256
+ This method performs the following steps:
257
+ 1. Validates the presence of SNP data in the genotype data.
258
+ 2. Determines the ploidy of the genotype data and sets up haploid/diploid handling.
259
+ 3. Simulates missingness in the genotype data based on the specified strategy.
260
+ 4. Splits the data into training, validation, and test sets.
261
+ 5. One-hot encodes the genotype data for model input.
262
+ 6. Initializes data loaders for training and validation.
263
+ 7. If hyperparameter tuning is enabled, tunes the model hyperparameters.
264
+ 8. Builds the VAE model with the best hyperparameters.
265
+ 9. Trains the VAE model using the training data and validates on the validation set.
266
+ 10. Evaluates the trained model on the test set and computes performance metrics.
267
+ 11. Saves the trained model and best hyperparameters.
268
+ 12. Generates plots of training history if enabled.
269
+ 13. Returns the fitted ImputeVAE instance.
251
270
 
252
271
  Returns:
253
- ImputeVAE: Fitted instance.
254
-
255
- Raises:
256
- RuntimeError: If training fails to produce a model.
272
+ ImputeVAE: The fitted ImputeVAE instance.
257
273
  """
258
274
  self.logger.info(f"Fitting {self.model_name} model...")
259
275
 
260
- # Data prep aligns with AE/NLPCA
261
- X012 = self._get_float_genotypes(copy=True)
262
- GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
263
- self.ground_truth_ = GT_full.astype(np.int64, copy=False)
276
+ if self.genotype_data.snp_data is None:
277
+ msg = f"SNP data is required for {self.model_name}."
278
+ self.logger.error(msg)
279
+ raise AttributeError(msg)
264
280
 
265
- self.sim_mask_global_ = None
266
- cache_key = self._sim_mask_cache_key()
267
- if self.simulate_missing:
268
- cached_mask = (
269
- None if cache_key is None else self._sim_mask_cache.get(cache_key)
270
- )
271
- if cached_mask is not None:
272
- self.sim_mask_global_ = cached_mask.copy()
273
- else:
274
- tr = SimMissingTransformer(
275
- genotype_data=self.genotype_data,
276
- tree_parser=self.tree_parser,
277
- prop_missing=self.sim_prop,
278
- strategy=self.sim_strategy,
279
- missing_val=-9,
280
- mask_missing=True,
281
- verbose=self.verbose,
282
- **self.sim_kwargs,
283
- )
284
- tr.fit(X012.copy())
285
- self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
286
- if cache_key is not None:
287
- self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
281
+ self.ploidy = self.cfg.io.ploidy
282
+ self.is_haploid_ = self.ploidy == 1
288
283
 
289
- X_for_model = self.ground_truth_.copy()
290
- X_for_model[self.sim_mask_global_] = -1
291
- else:
292
- X_for_model = self.ground_truth_.copy()
293
-
294
- # Ploidy/classes
295
- self.is_haploid = bool(
296
- np.all(
297
- np.isin(
298
- self.genotype_data.snp_data,
299
- ["A", "C", "G", "T", "N", "-", ".", "?"],
300
- )
301
- )
302
- )
303
- self.ploidy = 1 if self.is_haploid else 2
304
- self.num_classes_ = 2 if self.is_haploid else 3
305
- self.output_classes_ = 2
306
- self.logger.info(
307
- f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
308
- f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
309
- )
284
+ if self.ploidy > 2:
285
+ msg = f"{self.model_name} currently supports only haploid (1) or diploid (2) data; got ploidy={self.ploidy}."
286
+ self.logger.error(msg)
287
+ raise ValueError(msg)
310
288
 
311
- if self.is_haploid:
312
- self.ground_truth_[self.ground_truth_ == 2] = 1
313
- X_for_model[X_for_model == 2] = 1
289
+ self.num_classes_ = 2 if self.is_haploid_ else 3
314
290
 
315
- n_samples, self.num_features_ = X_for_model.shape
291
+ gt_full = self.pgenc.genotypes_012.copy()
292
+ gt_full[gt_full < 0] = -1
293
+ gt_full = np.nan_to_num(gt_full, nan=-1.0)
294
+ self.ground_truth_ = gt_full.astype(np.int8)
295
+ self.num_features_ = gt_full.shape[1]
316
296
 
317
- # Model params (decoder outputs L*K logits)
318
297
  self.model_params = {
319
298
  "n_features": self.num_features_,
320
- "num_classes": self.output_classes_,
299
+ "num_classes": self.num_classes_,
321
300
  "latent_dim": self.latent_dim,
322
301
  "dropout_rate": self.dropout_rate,
323
302
  "activation": self.activation,
324
303
  }
325
304
 
326
- # Train/Val split
327
- indices = np.arange(n_samples)
328
- train_idx, val_idx = train_test_split(
329
- indices, test_size=self.validation_split, random_state=self.seed
305
+ # Simulate missingness ONCE on the full matrix
306
+ sim_tup = self.sim_missing_transform(self.ground_truth_)
307
+ X_for_model_full = sim_tup[0]
308
+ self.sim_mask_ = sim_tup[1]
309
+ self.orig_mask_ = sim_tup[2]
310
+
311
+ # Split indices based on clean ground truth
312
+ self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
313
+ self.ground_truth_
330
314
  )
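
`_train_val_test_split` is defined in the base class and is not part of this diff; it returns row indices for the three splits. A rough standalone equivalent for orientation only, with the split fractions assumed rather than taken from the package:

import numpy as np

def split_indices(n_samples: int, val_frac: float = 0.2, test_frac: float = 0.2, seed: int = 42):
    # Shuffle row indices once, then carve off validation and test blocks.
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_val = int(round(val_frac * n_samples))
    n_test = int(round(test_frac * n_samples))
    return idx[n_val + n_test:], idx[:n_val], idx[n_val:n_val + n_test]

train_idx, val_idx, test_idx = split_indices(200)
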
331
- self.train_idx_, self.test_idx_ = train_idx, val_idx
332
- self.X_train_ = X_for_model[train_idx]
333
- self.X_val_ = X_for_model[val_idx]
334
- self.GT_train_full_ = self.ground_truth_[train_idx]
335
- self.GT_test_full_ = self.ground_truth_[val_idx]
336
-
337
- if self.sim_mask_global_ is not None:
338
- self.sim_mask_train_ = self.sim_mask_global_[train_idx]
339
- self.sim_mask_test_ = self.sim_mask_global_[val_idx]
340
- else:
341
- self.sim_mask_train_ = None
342
- self.sim_mask_test_ = None
343
315
 
344
- # Plotters/scorers (shared utilities)
316
+ # --- Clean (targets) per split ---
317
+ X_train_clean = self.ground_truth_[self.train_idx_].copy()
318
+ X_val_clean = self.ground_truth_[self.val_idx_].copy()
319
+ X_test_clean = self.ground_truth_[self.test_idx_].copy()
320
+
321
+ # --- Corrupted (inputs) per split (from the single simulation) ---
322
+ X_train_corrupted = X_for_model_full[self.train_idx_].copy()
323
+ X_val_corrupted = X_for_model_full[self.val_idx_].copy()
324
+ X_test_corrupted = X_for_model_full[self.test_idx_].copy()
325
+
326
+ # --- Masks per split ---
327
+ self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
328
+ self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
329
+ self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
330
+
331
+ self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
332
+ self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
333
+ self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
334
+
335
+ # Persist clean/corrupted matrices if you want them accessible later
336
+ self.X_train_clean_ = X_train_clean
337
+ self.X_val_clean_ = X_val_clean
338
+ self.X_test_clean_ = X_test_clean
339
+
340
+ self.X_train_corrupted_ = X_train_corrupted
341
+ self.X_val_corrupted_ = X_val_corrupted
342
+ self.X_test_corrupted_ = X_test_corrupted
343
+
344
+ # --- Haploid harmonization (do NOT resimulate; just recode values) ---
345
+ if self.is_haploid_:
346
+
347
+ def _haploidize(arr: np.ndarray) -> np.ndarray:
348
+ out = arr.copy()
349
+ miss = out < 0
350
+ out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
351
+ out[miss] = -1
352
+ return out
353
+
354
+ X_train_clean = _haploidize(X_train_clean)
355
+ X_val_clean = _haploidize(X_val_clean)
356
+ X_test_clean = _haploidize(X_test_clean)
357
+
358
+ X_train_corrupted = _haploidize(X_train_corrupted)
359
+ X_val_corrupted = _haploidize(X_val_corrupted)
360
+ X_test_corrupted = _haploidize(X_test_corrupted)
361
+
362
+ # Write back the persisted versions too
363
+ self.X_train_clean_ = X_train_clean
364
+ self.X_val_clean_ = X_val_clean
365
+ self.X_test_clean_ = X_test_clean
366
+
367
+ self.X_train_corrupted_ = X_train_corrupted
368
+ self.X_val_corrupted_ = X_val_corrupted
369
+ self.X_test_corrupted_ = X_test_corrupted
370
+
371
+ # Final training tensors/matrices used by the pipeline
372
+ # Convention: X_* are corrupted inputs; y_* are clean targets
373
+ self.X_train_ = self.X_train_corrupted_
374
+ self.y_train_ = self.X_train_clean_
375
+
376
+ self.X_val_ = self.X_val_corrupted_
377
+ self.y_val_ = self.X_val_clean_
378
+
379
+ self.X_test_ = self.X_test_corrupted_
380
+ self.y_test_ = self.X_test_clean_
381
+
382
+ self.X_train_ = self._one_hot_encode_012(
383
+ self.X_train_, num_classes=self.num_classes_
384
+ )
385
+ self.X_val_ = self._one_hot_encode_012(
386
+ self.X_val_, num_classes=self.num_classes_
387
+ )
388
+ self.X_test_ = self._one_hot_encode_012(
389
+ self.X_test_, num_classes=self.num_classes_
390
+ )
391
+ self.X_test_ = self.X_test_.detach().cpu().numpy()
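
`_one_hot_encode_012` lives in the base class and is not shown in this diff; the behavior the code above relies on is mapping 0/1/2 to one-hot channels while leaving missing entries (negative codes) with no signal. A minimal sketch under the assumption that missing entries become all-zero rows:

import numpy as np
import torch
import torch.nn.functional as F

def one_hot_012(x: np.ndarray, num_classes: int = 3) -> torch.Tensor:
    # Missing entries (< 0) get an all-zero one-hot row so they contribute nothing.
    t = torch.from_numpy(x).long()
    valid = t >= 0
    oh = F.one_hot(t.clamp(min=0), num_classes=num_classes).float()
    oh[~valid] = 0.0
    return oh  # shape: (n_samples, n_loci, num_classes)
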
392
+
345
393
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
394
+ self.valid_class_mask_ = self._build_valid_class_mask()
395
+
396
+ loci = getattr(self, "valid_class_mask_conflict_loci_", None)
397
+ if loci is not None and loci.size:
398
+ self._repair_ref_alt_from_iupac(loci)
399
+ self.valid_class_mask_ = self._build_valid_class_mask()
400
+
401
+ train_loader = self._get_data_loaders(
402
+ self.X_train_.detach().cpu().numpy(),
403
+ self.y_train_,
404
+ ~self.orig_mask_train_,
405
+ self.batch_size,
406
+ shuffle=True,
407
+ )
408
+
409
+ val_loader = self._get_data_loaders(
410
+ self.X_val_.detach().cpu().numpy(),
411
+ self.y_val_,
412
+ ~self.orig_mask_val_,
413
+ self.batch_size,
414
+ shuffle=False,
415
+ )
416
+
417
+ self.train_loader_ = train_loader
418
+ self.val_loader_ = val_loader
346
419
 
347
- # Optional tuning
348
420
  if self.tune:
349
- self.tune_hyperparameters()
421
+ self.tuned_params_ = self.tune_hyperparameters()
422
+ self.model_tuned_ = True
423
+ else:
424
+ self.model_tuned_ = False
425
+ self.class_weights_ = self._class_weights_from_zygosity(
426
+ self.y_train_,
427
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
428
+ inverse=self.inverse,
429
+ normalize=self.normalize,
430
+ max_ratio=self.max_ratio,
431
+ power=self.power,
432
+ )
433
+ self.tuned_params_ = {
434
+ "latent_dim": self.latent_dim,
435
+ "learning_rate": self.learning_rate,
436
+ "dropout_rate": self.dropout_rate,
437
+ "num_hidden_layers": self.num_hidden_layers,
438
+ "activation": self.activation,
439
+ "l1_penalty": self.l1_penalty,
440
+ "layer_scaling_factor": self.layer_scaling_factor,
441
+ "layer_schedule": self.layer_schedule,
442
+ }
443
+ self.tuned_params_.update(
444
+ {
445
+ "kl_beta": self.kl_beta,
446
+ "kl_beta_schedule": self.kl_beta_schedule,
447
+ "gamma": self.gamma,
448
+ "gamma_schedule": self.gamma_schedule,
449
+ "inverse": self.inverse,
450
+ "normalize": self.normalize,
451
+ "power": self.power,
452
+ }
453
+ )
454
+
455
+ self.tuned_params_["model_params"] = self.model_params
350
456
 
351
- # Best params (tuned or default)
352
- self.best_params_ = getattr(self, "best_params_", self._default_best_params())
457
+ if self.class_weights_ is not None:
458
+ self.logger.info(
459
+ f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
460
+ )
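
`_class_weights_from_zygosity` is implemented in the base class; only its call site appears in this diff. The `inverse`/`power`/`max_ratio`/`normalize` knobs suggest inverse-frequency weighting with a power transform and a cap, so the following is a hypothetical standalone version for intuition only, not the packaged implementation:

import numpy as np
import torch

def class_weights(y: np.ndarray, num_classes: int = 3, power: float = 1.0,
                  max_ratio: float | None = 10.0, normalize: bool = True) -> torch.Tensor:
    # Inverse-frequency weights over observed (non-negative) genotype classes.
    counts = np.bincount(y[y >= 0].ravel(), minlength=num_classes).astype(float)
    w = (counts.sum() / np.maximum(counts, 1.0)) ** power
    if max_ratio is not None:
        w = np.minimum(w, w.min() * max_ratio)  # cap the spread between classes
    if normalize:
        w = w * num_classes / w.sum()           # keep the average weight near 1
    return torch.tensor(w, dtype=torch.float32)
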
461
+
462
+ # Always start clean
463
+ self.best_params_ = copy.deepcopy(self.tuned_params_)
353
464
 
354
- # Class weights (device-aware)
355
- self.class_weights_ = self._normalize_class_weights(
356
- self._class_weights_from_zygosity(self.X_train_)
465
+ model_params_final = {
466
+ "n_features": self.num_features_,
467
+ "num_classes": self.num_classes_,
468
+ "latent_dim": int(self.best_params_["latent_dim"]),
469
+ "dropout_rate": float(self.best_params_["dropout_rate"]),
470
+ "activation": str(self.best_params_["activation"]),
471
+ "kl_beta": float(
472
+ self.best_params_.get("kl_beta", getattr(self, "kl_beta", 1.0))
473
+ ),
474
+ }
475
+
476
+ input_dim = self.num_features_ * self.num_classes_
477
+ model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
478
+ n_inputs=input_dim,
479
+ n_outputs=self.num_classes_,
480
+ n_samples=len(self.X_train_),
481
+ n_hidden=int(self.best_params_["num_hidden_layers"]),
482
+ latent_dim=int(self.best_params_["latent_dim"]),
483
+ alpha=float(self.best_params_["layer_scaling_factor"]),
484
+ schedule=str(self.best_params_["layer_schedule"]),
485
+ min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
357
486
  )
358
- if not self.is_haploid:
359
- self.pos_weights_ = self._compute_pos_weights(self.X_train_)
360
- else:
361
- self.pos_weights_ = None
362
487
 
363
- # DataLoader
364
- train_loader = self._get_data_loader(self.X_train_)
488
+ self.best_params_["model_params"] = model_params_final
365
489
 
366
- # Build & train
367
- model = self.build_model(self.Model, self.best_params_)
490
+ # Now build the model
491
+ model = self.build_model(self.Model, self.best_params_["model_params"])
368
492
  model.apply(self.initialize_weights)
369
493
 
494
+ if self.verbose or self.debug:
495
+ self.logger.info("Using model hyperparameters:")
496
+ pm = PrettyMetrics(
497
+ self.best_params_, precision=3, title="Model Hyperparameters"
498
+ )
499
+ pm.render()
500
+
501
+ lr_final = float(self.best_params_["learning_rate"])
502
+ l1_final = float(self.best_params_["l1_penalty"])
503
+
370
504
  loss, trained_model, history = self._train_and_validate_model(
371
505
  model=model,
372
- loader=train_loader,
373
- lr=self.learning_rate,
374
- l1_penalty=self.l1_penalty,
375
- return_history=True,
376
- class_weights=self.class_weights_,
377
- X_val=self.X_val_,
506
+ lr=lr_final,
507
+ l1_penalty=l1_final,
378
508
  params=self.best_params_,
379
- prune_metric=self.tune_metric,
380
- prune_warmup_epochs=10,
381
- eval_interval=1,
382
- eval_requires_latents=False, # no latent refinement for eval
383
- eval_latent_steps=0,
384
- eval_latent_lr=0.0,
385
- eval_latent_weight_decay=0.0,
509
+ trial=None,
510
+ class_weights=self.class_weights_,
511
+ kl_beta_schedule=self.best_params_["kl_beta_schedule"],
512
+ gamma_schedule=self.best_params_["gamma_schedule"],
386
513
  )
387
514
 
388
515
  if trained_model is None:
389
- msg = "VAE training failed; no model was returned."
516
+ msg = f"{self.model_name} training failed."
390
517
  self.logger.error(msg)
391
518
  raise RuntimeError(msg)
392
519
 
@@ -395,215 +522,209 @@ class ImputeVAE(BaseNNImputer):
395
522
  self.models_dir / f"final_model_{self.model_name}.pt",
396
523
  )
397
524
 
398
- hist: dict = {"Train": history}
399
- self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
525
+ if history is None:
526
+ hist = {"Train": []}
527
+ elif isinstance(history, dict):
528
+ hist = dict(history)
529
+ else:
530
+ hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
531
+
532
+ self.best_loss_ = loss
533
+ self.model_ = trained_model
534
+ self.history_ = hist
400
535
  self.is_fit_ = True
401
536
 
402
- # Evaluate (AE-parity reporting)
403
- eval_mask = (
404
- self.sim_mask_test_
405
- if (self.simulate_missing and self.sim_mask_test_ is not None)
406
- else None
407
- )
408
537
  self._evaluate_model(
409
- self.X_val_,
410
538
  self.model_,
411
- self.best_params_,
412
- eval_mask_override=eval_mask,
539
+ X=self.X_test_,
540
+ y=self.y_test_,
541
+ eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
542
+ objective_mode=False,
413
543
  )
414
544
 
415
- self.plotter_.plot_history(self.history_)
545
+ if self.show_plots:
546
+ self.plotter_.plot_history(self.history_)
547
+
416
548
  self._save_best_params(self.best_params_)
549
+
550
+ if self.model_tuned_:
551
+ title = f"{self.model_name} Optimized Parameters"
552
+
553
+ if self.verbose or self.debug:
554
+ pm = PrettyMetrics(self.best_params_, precision=2, title=title)
555
+ pm.render()
556
+
557
+ # Save best parameters to a JSON file.
558
+ self._save_best_params(self.best_params_, objective_mode=True)
417
559
  return self
418
560
 
419
561
  def transform(self) -> np.ndarray:
420
562
  """Impute missing genotypes and return IUPAC strings.
421
563
 
422
- This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
564
+ This method performs the following steps:
565
+ 1. Validates that the model has been fitted.
566
+ 2. Uses the trained model to predict missing genotypes for the entire dataset.
567
+ 3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
568
+ 4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
569
+ 5. Checks for any remaining missing values or decoding issues, raising errors if found.
570
+ 6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
571
+ 7. Returns the imputed IUPAC genotype matrix.
423
572
 
424
573
  Returns:
425
- np.ndarray: IUPAC strings of shape (n_samples, n_loci).
574
+ np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
426
575
 
427
- Raises:
428
- NotFittedError: If called before fit().
576
+ Notes:
577
+ - ``transform()`` does not take any arguments; it operates on the data provided during initialization.
578
+ - Ensure that the model has been fitted before calling this method.
579
+ - For haploid data, genotypes encoded as '1' are treated as '2' during decoding.
580
+ - The method checks for decoding failures (i.e., resulting in 'N') and raises an error if any are found.
429
581
  """
430
582
  if not getattr(self, "is_fit_", False):
431
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
583
+ msg = "Model is not fitted. Call fit() before transform()."
584
+ self.logger.error(msg)
585
+ raise NotFittedError(msg)
432
586
 
433
587
  self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
434
588
  X_to_impute = self.ground_truth_.copy()
435
589
 
436
- pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
590
+ # 1. Predict labels (0/1/2) for the entire matrix
591
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute)
437
592
 
438
- # Fill only missing
439
- missing_mask = X_to_impute == -1
593
+ # 2. Fill ONLY originally missing values
594
+ missing_mask = X_to_impute < 0
440
595
  imputed_array = X_to_impute.copy()
441
596
  imputed_array[missing_mask] = pred_labels[missing_mask]
442
597
 
443
- # Decode to IUPAC & optionally plot
444
- imputed_genotypes = self.pgenc.decode_012(imputed_array)
445
- original_genotypes = self.pgenc.decode_012(X_to_impute)
598
+ # Sanity check: all -1s should be gone
599
+ if np.any(imputed_array < 0):
600
+ msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
601
+ self.logger.error(msg)
602
+ raise RuntimeError(msg)
446
603
 
447
- plt.rcParams.update(self.plotter_.param_dict)
448
- self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
449
- self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
604
+ # 3. Handle Haploid mapping (2->1) before decoding if needed
605
+ decode_input = imputed_array
606
+ if self.is_haploid_:
607
+ decode_input = imputed_array.copy()
608
+ decode_input[decode_input == 1] = 2
450
609
 
451
- return imputed_genotypes
610
+ # 4. Decode integers to IUPAC strings
611
+ imputed_genotypes = self.decode_012(decode_input)
452
612
 
453
- # ---------- plumbing identical to AE, naming aligned ---------- #
613
+ # 5. Check for decoding failures (N)
614
+ # Hard error: downstream pipelines expect fully imputed data.
615
+ bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
616
+ if bad_loci.size > 0:
617
+ msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
618
+ self.logger.error(msg)
619
+ self.logger.debug(
620
+ "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
621
+ )
622
+ raise RuntimeError(msg)
454
623
 
455
- def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
456
- """Create DataLoader over indices + integer targets (-1 for missing).
624
+ if self.show_plots:
625
+ original_input = X_to_impute
626
+ if self.is_haploid_:
627
+ original_input = X_to_impute.copy()
628
+ original_input[original_input == 1] = 2
457
629
 
458
- This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
630
+ original_genotypes = self.decode_012(original_input)
459
631
 
460
- Args:
461
- y (np.ndarray): 0/1/2 matrix with -1 for missing.
632
+ plt.rcParams.update(self.plotter_.param_dict)
633
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
634
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
462
635
 
463
- Returns:
464
- torch.utils.data.DataLoader: Shuffled DataLoader.
465
- """
466
- y_tensor = torch.from_numpy(y).long()
467
- indices = torch.arange(len(y), dtype=torch.long)
468
- dataset = torch.utils.data.TensorDataset(indices, y_tensor)
469
- pin_memory = self.device.type == "cuda"
470
- return torch.utils.data.DataLoader(
471
- dataset,
472
- batch_size=self.batch_size,
473
- shuffle=True,
474
- pin_memory=pin_memory,
475
- )
636
+ return imputed_genotypes
476
637
 
477
638
  def _train_and_validate_model(
478
639
  self,
479
640
  model: torch.nn.Module,
480
- loader: torch.utils.data.DataLoader,
641
+ *,
481
642
  lr: float,
482
643
  l1_penalty: float,
483
- trial: optuna.Trial | None = None,
484
- return_history: bool = False,
485
- class_weights: torch.Tensor | None = None,
486
- *,
487
- X_val: np.ndarray | None = None,
488
- params: dict | None = None,
489
- prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
490
- prune_warmup_epochs: int = 10,
491
- eval_interval: int = 1,
492
- eval_requires_latents: bool = False, # VAE: no latent eval refinement
493
- eval_latent_steps: int = 0,
494
- eval_latent_lr: float = 0.0,
495
- eval_latent_weight_decay: float = 0.0,
496
- ) -> Tuple[float, torch.nn.Module | None, list | None]:
497
- """Wrap the VAE training loop with β-anneal & Optuna pruning.
498
-
499
- This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
644
+ trial: Optional[optuna.Trial] = None,
645
+ params: Optional[dict[str, Any]] = None,
646
+ class_weights: Optional[torch.Tensor] = None,
647
+ kl_beta_schedule: bool = False,
648
+ gamma_schedule: bool = False,
649
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
650
+ """Train and validate the model.
651
+
652
+ This method orchestrates training with early stopping and optional Optuna pruning based on validation performance. It returns the best validation loss, the best model (with best weights loaded), and training history.
500
653
 
501
654
  Args:
502
655
  model (torch.nn.Module): VAE model.
503
- loader (torch.utils.data.DataLoader): Training data loader.
504
656
  lr (float): Learning rate.
505
657
  l1_penalty (float): L1 regularization coefficient.
506
- trial (optuna.Trial | None): Optuna trial for pruning.
507
- return_history (bool): If True, return training history.
508
- class_weights (torch.Tensor | None): CE class weights on device.
509
- X_val (np.ndarray | None): Validation data for pruning eval.
510
- params (dict | None): Current hyperparameters (for logging).
511
- prune_metric (str | None): Metric for pruning decisions.
512
- prune_warmup_epochs (int): Epochs to skip before pruning.
513
- eval_interval (int): Epochs between validation evaluations.
514
- eval_requires_latents (bool): If True, refine latents during eval.
515
- eval_latent_steps (int): Latent refinement steps if needed.
516
- eval_latent_lr (float): Latent refinement learning rate.
517
- eval_latent_weight_decay (float): Latent refinement L2 penalty.
658
+ trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
659
+ params (Optional[dict[str, float | int | str | dict[str, Any]]]): Model params for evaluation.
660
+ class_weights (Optional[torch.Tensor]): Class weights for loss computation.
661
+ kl_beta_schedule (bool): Whether to use KL beta scheduling.
662
+ gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
518
663
 
519
664
  Returns:
520
- Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
665
+ tuple[float, torch.nn.Module, dict[str, list[float]]]:
666
+ Best validation loss, best model, and training history.
521
667
  """
522
- if class_weights is None:
523
- msg = "Must provide class_weights."
524
- self.logger.error(msg)
525
- raise TypeError(msg)
668
+ max_epochs = self.epochs
669
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
526
670
 
527
- max_epochs = (
528
- self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
671
+ scheduler = _make_warmup_cosine_scheduler(
672
+ optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
529
673
  )
530
674
 
531
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
532
- scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
533
-
534
675
  best_loss, best_model, hist = self._execute_training_loop(
535
- loader=loader,
536
676
  optimizer=optimizer,
537
677
  scheduler=scheduler,
538
678
  model=model,
539
679
  l1_penalty=l1_penalty,
540
680
  trial=trial,
541
- return_history=return_history,
542
- class_weights=class_weights,
543
- X_val=X_val,
544
681
  params=params,
545
- prune_metric=prune_metric,
546
- prune_warmup_epochs=prune_warmup_epochs,
547
- eval_interval=eval_interval,
548
- eval_requires_latents=eval_requires_latents,
549
- eval_latent_steps=eval_latent_steps,
550
- eval_latent_lr=eval_latent_lr,
551
- eval_latent_weight_decay=eval_latent_weight_decay,
682
+ class_weights=class_weights,
683
+ kl_beta_schedule=kl_beta_schedule,
684
+ gamma_schedule=gamma_schedule,
552
685
  )
553
- if return_history:
554
- return best_loss, best_model, hist
555
-
556
- return best_loss, best_model, None
686
+ return best_loss, best_model, hist
557
687
 
558
688
  def _execute_training_loop(
559
689
  self,
560
- loader: torch.utils.data.DataLoader,
690
+ *,
561
691
  optimizer: torch.optim.Optimizer,
562
- scheduler: CosineAnnealingLR,
692
+ scheduler: (
693
+ torch.optim.lr_scheduler.CosineAnnealingLR
694
+ | torch.optim.lr_scheduler.SequentialLR
695
+ ),
563
696
  model: torch.nn.Module,
564
697
  l1_penalty: float,
565
- trial: optuna.Trial | None,
566
- return_history: bool,
567
- class_weights: torch.Tensor,
568
- *,
569
- X_val: np.ndarray | None = None,
570
- params: dict | None = None,
571
- prune_metric: str | None = None,
572
- prune_warmup_epochs: int = 10,
573
- eval_interval: int = 1,
574
- eval_requires_latents: bool = False,
575
- eval_latent_steps: int = 0,
576
- eval_latent_lr: float = 0.0,
577
- eval_latent_weight_decay: float = 0.0,
578
- ) -> Tuple[float, torch.nn.Module, list]:
579
- """Train VAE with stable focal CE + KL(β) anneal and numeric guards.
580
-
581
- This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
698
+ trial: Optional[optuna.Trial] = None,
699
+ params: Optional[dict[str, Any]] = None,
700
+ class_weights: Optional[torch.Tensor] = None,
701
+ kl_beta_schedule: bool = False,
702
+ gamma_schedule: bool = False,
703
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
704
+ """Train the model with focal CE reconstruction + KL divergence.
705
+
706
+ This method performs the training loop for the model using the provided optimizer and learning rate scheduler. It supports early stopping based on validation loss and integrates with Optuna for hyperparameter tuning. The method returns the best validation loss, the best model state, and the training history.
582
707
 
583
708
  Args:
584
- loader (torch.utils.data.DataLoader): Training data loader.
585
709
  optimizer (torch.optim.Optimizer): Optimizer.
586
- scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
710
+ scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): Learning rate scheduler.
587
711
  model (torch.nn.Module): VAE model.
588
712
  l1_penalty (float): L1 regularization coefficient.
589
- trial (optuna.Trial | None): Optuna trial for pruning.
590
- return_history (bool): If True, return training history.
591
- class_weights (torch.Tensor): CE class weights on device.
592
- X_val (np.ndarray | None): Validation data for pruning eval.
593
- params (dict | None): Current hyperparameters (for logging).
594
- prune_metric (str | None): Metric for pruning decisions.
595
- prune_warmup_epochs (int): Epochs to skip before pruning.
596
- eval_interval (int): Epochs between validation evaluations.
597
- eval_requires_latents (bool): If True, refine latents during eval.
598
- eval_latent_steps (int): Latent refinement steps if needed.
599
- eval_latent_lr (float): Latent refinement learning rate.
600
- eval_latent_weight_decay (float): Latent refinement L2 penalty.
713
+ trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
714
+ params (Optional[dict[str, Any]]): Model params for evaluation.
715
+ class_weights (Optional[torch.Tensor]): Class weights for loss computation.
716
+ kl_beta_schedule (bool): Whether to use KL beta scheduling.
717
+ gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
601
718
 
602
719
  Returns:
603
- Tuple[float, torch.nn.Module, list]: Best loss, best model, and training history.
720
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, training history.
721
+
722
+ Notes:
723
+ - Use CE with class weights during training/validation.
724
+ - Inference de-bias happens in _predict (separate).
725
+ - If `class_weights` is None, this will fall back to self.class_weights_ if present.
604
726
  """
605
- best_model = None
606
- history: list[float] = []
727
+ history: dict[str, list[float]] = defaultdict(list)
607
728
 
608
729
  early_stopping = EarlyStopping(
609
730
  patience=self.early_stop_gen,
@@ -613,204 +734,175 @@ class ImputeVAE(BaseNNImputer):
613
734
  debug=self.debug,
614
735
  )
615
736
 
616
- # ---- scalarize schedule endpoints up front ----
617
- kl = self.kl_beta_final
618
- if isinstance(kl, (list, tuple)):
619
- if len(kl) == 0:
620
- msg = "kl_beta_final list is empty."
621
- self.logger.error(msg)
622
- raise ValueError(msg)
623
- kl = kl[0]
624
- beta_final = float(kl)
625
-
626
- gamma_val = self.gamma
627
- if isinstance(gamma_val, (list, tuple)):
628
- if len(gamma_val) == 0:
629
- raise ValueError("gamma list is empty.")
630
- gamma_val = gamma_val[0]
631
- gamma_final = float(gamma_val)
632
-
633
- gamma_warm, gamma_ramp = 50, 100
634
- beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
635
-
636
- # Optional LR warmup
637
- warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
638
- base_lr = float(optimizer.param_groups[0]["lr"])
639
- min_lr = base_lr * 0.1
640
-
641
- max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
642
-
643
- for epoch in range(max_epochs):
644
- # focal γ schedule
645
- if epoch < gamma_warm:
646
- model.gamma = 0.0 # type: ignore[attr-defined]
647
- elif epoch < gamma_warm + gamma_ramp:
648
- model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore[attr-defined]
649
- else:
650
- model.gamma = gamma_final # type: ignore[attr-defined]
737
+ # KL schedule
738
+ kl_beta_target, kl_warm, kl_ramp = self._anneal_config(
739
+ params, "kl_beta", default=self.kl_beta, max_epochs=self.epochs
740
+ )
741
+
742
+ kl_beta_target = float(kl_beta_target)
743
+
744
+ gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
745
+ params, "gamma", default=self.gamma, max_epochs=self.epochs
746
+ )
747
+
748
+ cw = class_weights
749
+ if cw is not None and cw.device != self.device:
750
+ cw = cw.to(self.device)
651
751
 
652
- # KL β schedule (float throughout + ramp guard)
653
- if epoch < beta_warm:
654
- model.beta = 0.0 # type: ignore[attr-defined]
655
- elif beta_ramp > 0 and epoch < beta_warm + beta_ramp:
656
- model.beta = beta_final * ((epoch - beta_warm) / beta_ramp) # type: ignore[attr-defined]
752
+ ce_criterion = FocalCELoss(
753
+ alpha=cw, gamma=gamma_target, reduction="mean", ignore_index=-1
754
+ )
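
`FocalCELoss` is imported from `pgsui.impute.unsupervised.loss_functions`; its implementation is not part of this hunk. The standard focal cross-entropy it names down-weights easy examples by a factor of (1 - p_t)^gamma. A compact sketch of that formulation, assuming per-class weights `alpha` and `ignore_index=-1` as in the call above:

import torch
import torch.nn.functional as F

def focal_ce(logits: torch.Tensor, targets: torch.Tensor, alpha=None,
             gamma: float = 2.0, ignore_index: int = -1) -> torch.Tensor:
    # FL = (1 - p_t)^gamma * CE, averaged over non-ignored targets.
    keep = targets != ignore_index
    logits, targets = logits[keep], targets[keep]
    ce = F.cross_entropy(logits, targets, weight=alpha, reduction="none")
    with torch.no_grad():  # modulating factor detached here for simplicity
        p_t = F.softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    return ((1.0 - p_t) ** gamma * ce).mean()
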
755
+
756
+ for epoch in range(int(self.epochs)):
757
+ if kl_beta_schedule:
758
+ kl_beta_current = self._update_anneal_schedule(
759
+ kl_beta_target,
760
+ warm=kl_warm,
761
+ ramp=kl_ramp,
762
+ epoch=epoch,
763
+ init_val=0.0,
764
+ )
657
765
  else:
658
- model.beta = beta_final # type: ignore[attr-defined]
659
- # LR warmup
660
- if epoch < warmup_epochs:
661
- scale = float(epoch + 1) / warmup_epochs
662
- for g in optimizer.param_groups:
663
- g["lr"] = min_lr + (base_lr - min_lr) * scale
766
+ kl_beta_current = kl_beta_target
767
+
768
+ if gamma_schedule:
769
+ gamma_current = self._update_anneal_schedule(
770
+ gamma_target,
771
+ warm=gamma_warm,
772
+ ramp=gamma_ramp,
773
+ epoch=epoch,
774
+ init_val=0.0,
775
+ )
776
+ ce_criterion.gamma = gamma_current
664
777
 
665
778
  train_loss = self._train_step(
666
- loader=loader,
779
+ loader=self.train_loader_,
667
780
  optimizer=optimizer,
668
781
  model=model,
782
+ ce_criterion=ce_criterion,
669
783
  l1_penalty=l1_penalty,
670
- class_weights=class_weights,
784
+ kl_beta=kl_beta_current,
671
785
  )
672
786
 
673
787
  if not np.isfinite(train_loss):
674
- if trial:
675
- raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
676
- # shrink LR and continue
677
- for g in optimizer.param_groups:
678
- g["lr"] *= 0.5
679
- continue
788
+ if trial is not None:
789
+ msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
790
+ self.logger.warning(msg)
791
+ raise optuna.exceptions.TrialPruned(msg)
792
+
793
+ msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
794
+ self.logger.error(msg)
795
+ raise RuntimeError(msg)
680
796
 
681
- if scheduler is not None:
682
- scheduler.step()
797
+ val_loss = self._val_step(
798
+ loader=self.val_loader_,
799
+ model=model,
800
+ ce_criterion=ce_criterion,
801
+ l1_penalty=l1_penalty,
802
+ kl_beta=kl_beta_current,
803
+ )
683
804
 
684
- if return_history:
685
- history.append(train_loss)
805
+ scheduler.step()
806
+ history["Train"].append(float(train_loss))
807
+ history["Val"].append(float(val_loss))
686
808
 
687
- early_stopping(train_loss, model)
809
+ early_stopping(val_loss, model)
688
810
  if early_stopping.early_stop:
689
- self.logger.info(f"Early stopping at epoch {epoch + 1}.")
811
+ self.logger.debug(
812
+ f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
813
+ )
690
814
  break
691
815
 
692
- # Optional Optuna pruning on a validation metric
693
- if (
694
- trial is not None
695
- and X_val is not None
696
- and ((epoch + 1) % eval_interval == 0)
697
- ):
698
- metric_key = prune_metric or getattr(self, "tune_metric", "f1")
699
- mask_override = None
700
- if (
701
- self.simulate_missing
702
- and getattr(self, "sim_mask_test_", None) is not None
703
- and getattr(self, "X_val_", None) is not None
704
- and X_val.shape == self.X_val_.shape
705
- ):
706
- mask_override = self.sim_mask_test_
707
- metric_val = self._eval_for_pruning(
816
+ if trial is not None:
817
+ metric_vals = self._evaluate_model(
708
818
  model=model,
709
- X_val=X_val,
710
- params=params or getattr(self, "best_params_", {}),
711
- metric=metric_key,
819
+ X=self.X_val_corrupted_,
820
+ y=self.y_val_,
821
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
712
822
  objective_mode=True,
713
- do_latent_infer=False, # VAE uses encoder; no latent refine
714
- latent_steps=0,
715
- latent_lr=0.0,
716
- latent_weight_decay=0.0,
717
- latent_seed=self.seed, # type: ignore
718
- _latent_cache=None,
719
- _latent_cache_key=None,
720
- eval_mask_override=mask_override,
721
823
  )
722
- trial.report(metric_val, step=epoch + 1)
723
- if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
824
+ trial.report(metric_vals[self.tune_metric], step=epoch + 1)
825
+ if trial.should_prune():
724
826
  raise optuna.exceptions.TrialPruned(
725
- f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
827
+ f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
726
828
  )
727
829
 
728
- best_loss = early_stopping.best_score
729
- best_model = copy.deepcopy(early_stopping.best_model)
730
- if best_model is None:
731
- best_model = copy.deepcopy(model)
732
- return best_loss, best_model, history
830
+ best_loss = float(early_stopping.best_score)
831
+
832
+ if early_stopping.best_state_dict is not None:
833
+ model.load_state_dict(early_stopping.best_state_dict)
834
+
835
+ return best_loss, model, dict(history)
733
836
 
734
837
  def _train_step(
735
838
  self,
736
839
  loader: torch.utils.data.DataLoader,
737
840
  optimizer: torch.optim.Optimizer,
738
841
  model: torch.nn.Module,
842
+ ce_criterion: torch.nn.Module,
843
+ *,
739
844
  l1_penalty: float,
740
- class_weights: torch.Tensor,
845
+ kl_beta: torch.Tensor | float,
741
846
  ) -> float:
742
- """One epoch: inputs VAE forward focal CE + β·KL with guards.
743
-
744
- This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
847
+ """Single epoch train step across batches (focal CE + KL + optional L1).
745
848
 
746
849
  Args:
747
850
  loader (torch.utils.data.DataLoader): Training data loader.
748
851
  optimizer (torch.optim.Optimizer): Optimizer.
749
852
  model (torch.nn.Module): VAE model.
853
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
750
854
  l1_penalty (float): L1 regularization coefficient.
751
- class_weights (torch.Tensor): CE class weights on device.
855
+ kl_beta (torch.Tensor | float): KL divergence weight.
752
856
 
753
857
  Returns:
754
- float: Average training loss for the epoch.
858
+ float: Average training loss.
859
+
755
860
  """
756
861
  model.train()
757
- running, used = 0.0, 0
862
+ running = 0.0
863
+ num_batches = 0
864
+
865
+ nF_model = self.num_features_
866
+ nC_model = self.num_classes_
758
867
  l1_params = tuple(p for p in model.parameters() if p.requires_grad)
759
- if class_weights is not None and class_weights.device != self.device:
760
- class_weights = class_weights.to(self.device)
761
868
 
762
- for _, y_batch in loader:
869
+ for X_batch, y_batch, m_batch in loader:
763
870
  optimizer.zero_grad(set_to_none=True)
871
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
872
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
873
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
764
874
 
765
- y_int = y_batch.to(self.device, non_blocking=True).long()
875
+ raw = model(X_batch)
876
+ logits0 = raw[0]
766
877
 
767
- if self.is_haploid:
768
- x_in = self._one_hot_encode_012(y_int) # (B, L, 2)
769
- else:
770
- x_in = self._encode_multilabel_inputs(y_int) # (B, L, 2)
878
+ expected = X_batch.shape[0] * nF_model * nC_model
879
+ if logits0.numel() != expected:
880
+ msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
881
+ self.logger.error(msg)
882
+ raise ValueError(msg)
771
883
 
772
- out = model(x_in)
773
- if isinstance(out, (list, tuple)):
774
- recon_logits, mu, logvar = out[0], out[1], out[2]
775
- else:
776
- recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]
777
-
778
- # Upstream guard
779
- if (
780
- not torch.isfinite(recon_logits).all()
781
- or not torch.isfinite(mu).all()
782
- or not torch.isfinite(logvar).all()
783
- ):
784
- continue
884
+ logits_masked = logits0.view(-1, nC_model)
885
+ logits_masked = logits_masked[m_batch.view(-1)]
785
886
 
786
- gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
787
- beta = float(getattr(model, "beta", getattr(self, "kl_beta_final", 0.0)))
788
- gamma = max(0.0, min(gamma, 10.0))
887
+ targets_masked = y_batch.view(-1)
888
+ targets_masked = targets_masked[m_batch.view(-1)]
789
889
 
790
- if self.is_haploid:
791
- loss = compute_vae_loss(
792
- recon_logits=recon_logits,
793
- targets=y_int,
794
- mu=mu,
795
- logvar=logvar,
796
- class_weights=class_weights,
797
- gamma=gamma,
798
- beta=beta,
799
- )
800
- else:
801
- targets = self._multi_hot_targets(y_int)
802
- pos_w = getattr(self, "pos_weights_", None)
803
- bce = F.binary_cross_entropy_with_logits(
804
- recon_logits, targets, pos_weight=pos_w, reduction="none"
805
- )
806
- mask = (y_int != -1).unsqueeze(-1).float()
807
- recon_loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
808
- kl = (
809
- -0.5
810
- * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
811
- / (y_int.shape[0] + 1e-8)
812
- )
813
- loss = recon_loss + beta * kl
890
+ mask_flat = m_batch.view(-1)
891
+ if not bool(mask_flat.any()):
892
+ continue
893
+
894
+ # average number of masked loci per sample (scalar)
895
+ recon_scale = (mask_flat.sum().float() / float(X_batch.shape[0])).detach()
896
+
897
+ loss = compute_vae_loss(
898
+ ce_criterion,
899
+ logits_masked,
900
+ targets_masked,
901
+ mu=raw[1],
902
+ logvar=raw[2],
903
+ kl_beta=kl_beta,
904
+ recon_scale=recon_scale,
905
+ )
814
906
 
815
907
  if l1_penalty > 0:
816
908
  l1 = torch.zeros((), device=self.device)
@@ -818,185 +910,279 @@ class ImputeVAE(BaseNNImputer):
818
910
  l1 = l1 + p.abs().sum()
819
911
  loss = loss + l1_penalty * l1
820
912
 
821
- if not torch.isfinite(loss):
822
- continue
823
-
824
913
  loss.backward()
825
914
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
826
-
827
- # skip update if any grad is non-finite
828
- bad = any(
829
- p.grad is not None and not torch.isfinite(p.grad).all()
830
- for p in model.parameters()
831
- )
832
- if bad:
833
- optimizer.zero_grad(set_to_none=True)
834
- continue
835
-
836
915
  optimizer.step()
837
916
 
838
917
  running += float(loss.detach().item())
839
- used += 1
918
+ num_batches += 1
919
+
920
+ return float("inf") if num_batches == 0 else running / num_batches
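The masking arithmetic in this step is the easiest part to get wrong, so here is a minimal sketch of the per-batch loss as it appears to be computed: flatten the logits to (B*L, K), keep only the positions selected by the simulated-missing mask, apply a cross-entropy on those cells, and add the analytic Gaussian KL term weighted by `kl_beta`. The helper name `masked_vae_batch_loss` and the use of plain `F.cross_entropy` in place of pg-sui's `FocalCELoss`/`compute_vae_loss` are illustrative assumptions, not the package's implementation.

    import torch
    import torch.nn.functional as F

    def masked_vae_batch_loss(
        logits: torch.Tensor,   # (B, L, K) reconstruction logits
        targets: torch.Tensor,  # (B, L) integer genotype classes
        mask: torch.Tensor,     # (B, L) bool; True where the loss is evaluated
        mu: torch.Tensor,       # (B, latent_dim)
        logvar: torch.Tensor,   # (B, latent_dim)
        kl_beta: float = 1.0,
    ) -> torch.Tensor:
        B, L, K = logits.shape
        flat_mask = mask.reshape(-1)
        if not bool(flat_mask.any()):
            return torch.zeros((), device=logits.device)

        # Reconstruction term over masked cells only (mean over those cells).
        ce = F.cross_entropy(
            logits.reshape(-1, K)[flat_mask],
            targets.reshape(-1)[flat_mask].long(),
            reduction="mean",
        )

        # Analytic KL(q(z|x) || N(0, I)), averaged over the batch.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / B

        # Scale the per-cell CE back up to a per-sample quantity so the two
        # terms stay on comparable scales (mirrors the recon_scale above).
        recon_scale = flat_mask.sum().float() / B
        return recon_scale * ce + kl_beta * kl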
921
+
922
+ def _val_step(
923
+ self,
924
+ loader: torch.utils.data.DataLoader,
925
+ model: torch.nn.Module,
926
+ ce_criterion: torch.nn.Module,
927
+ *,
928
+ l1_penalty: float,
929
+ kl_beta: torch.Tensor | float = 1.0,
930
+ ) -> float:
931
+ """Validation step for a single epoch (focal CE + KL + optional L1).
932
+
933
+ Args:
934
+ loader (torch.utils.data.DataLoader): Validation data loader.
935
+ model (torch.nn.Module): VAE model.
936
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
937
+ l1_penalty (float): L1 regularization coefficient.
938
+ kl_beta (torch.Tensor | float): KL divergence weight.
939
+
940
+ Returns:
941
+ float: Average validation loss.
942
+ """
943
+ model.eval()
944
+ running = 0.0
945
+ num_batches = 0
946
+
947
+ nF_model = self.num_features_
948
+ nC_model = self.num_classes_
949
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
950
+
951
+ with torch.no_grad():
952
+ for X_batch, y_batch, m_batch in loader:
953
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
954
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
955
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
956
+
957
+ raw = model(X_batch)
958
+ logits0 = raw[0]
959
+
960
+ expected = X_batch.shape[0] * nF_model * nC_model
961
+ if logits0.numel() != expected:
962
+ msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
963
+ self.logger.error(msg)
964
+ raise ValueError(msg)
965
+
966
+ logits_masked = logits0.view(-1, nC_model)
967
+ logits_masked = logits_masked[m_batch.view(-1)]
968
+
969
+ targets_masked = y_batch.view(-1)
970
+ targets_masked = targets_masked[m_batch.view(-1)]
971
+
972
+ mask_flat = m_batch.view(-1)
973
+
974
+ if not bool(mask_flat.any()):
975
+ continue
840
976
 
841
- return (running / used) if used > 0 else float("inf")
977
+ # average number of masked loci per sample (scalar)
978
+ recon_scale = (
979
+ mask_flat.sum().float() / float(X_batch.shape[0])
980
+ ).detach()
981
+
982
+ loss = compute_vae_loss(
983
+ ce_criterion,
984
+ logits_masked,
985
+ targets_masked,
986
+ mu=raw[1],
987
+ logvar=raw[2],
988
+ kl_beta=kl_beta,
989
+ recon_scale=recon_scale,
990
+ )
991
+
992
+ if l1_penalty > 0:
993
+ l1 = torch.zeros((), device=self.device)
994
+ for p in l1_params:
995
+ l1 = l1 + p.abs().sum()
996
+ loss = loss + l1_penalty * l1
997
+
998
+ if not torch.isfinite(loss):
999
+ continue
1000
+
1001
+ running += float(loss.item())
1002
+ num_batches += 1
1003
+
1004
+ return float("inf") if num_batches == 0 else running / num_batches
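The `kl_beta_schedule` flag sampled later suggests that β may be warmed up over training rather than held fixed. The exact schedule lives elsewhere in pg-sui; the linear warm-up below is only a common convention shown for orientation, not the package's implementation.

    def kl_beta_at_epoch(epoch: int, n_epochs: int, kl_beta: float,
                         warmup_frac: float = 0.3, schedule: bool = True) -> float:
        """Linearly ramp beta from 0 to kl_beta over the first warmup_frac of epochs."""
        if not schedule:
            return kl_beta
        warmup_epochs = max(1, int(warmup_frac * n_epochs))
        return kl_beta * min(1.0, epoch / warmup_epochs)

    # e.g. kl_beta_at_epoch(5, 100, kl_beta=1.0) -> ~0.167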
842
1005
 
843
1006
  def _predict(
844
1007
  self,
845
1008
  model: torch.nn.Module,
846
1009
  X: np.ndarray | torch.Tensor,
1010
+ *,
847
1011
  return_proba: bool = False,
848
- ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
849
- """Predict 0/1/2 labels (and probabilities) from masked inputs.
1012
+ ) -> tuple[np.ndarray, np.ndarray | None]:
1013
+ """Predict categorical genotype labels from logits.
850
1014
 
851
- This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
1015
+ This method uses the trained model to predict genotype labels for the provided input data. It handles both 0/1/2 encoded matrices and one-hot encoded matrices, converting them as necessary for model input. The method returns the predicted labels and, optionally, the predicted probabilities.
852
1016
 
853
1017
  Args:
854
1018
  model (torch.nn.Module): Trained model.
855
- X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
856
- return_proba (bool): If True, also return probabilities.
1019
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
1020
+ return_proba (bool): If True, return probabilities.
857
1021
 
858
1022
  Returns:
859
- Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
1023
+ tuple[np.ndarray, np.ndarray | None]: (labels, probas|None).
860
1024
  """
861
1025
  if model is None:
862
- msg = "Model is not trained. Call fit() before predict()."
1026
+ msg = (
1027
+ "Model passed to predict() is not trained. "
1028
+ "Call fit() before predict()."
1029
+ )
863
1030
  self.logger.error(msg)
864
1031
  raise NotFittedError(msg)
865
1032
 
866
1033
  model.eval()
1034
+
1035
+ nF = self.num_features_
1036
+ nC = self.num_classes_
1037
+
1038
+ if isinstance(X, torch.Tensor):
1039
+ X_tensor = X
1040
+ else:
1041
+ X_tensor = torch.from_numpy(X)
1042
+ X_tensor = X_tensor.float()
1043
+
1044
+ if X_tensor.device != self.device:
1045
+ X_tensor = X_tensor.to(self.device)
1046
+
1047
+ if X_tensor.dim() == 2:
1048
+ # 0/1/2 matrix -> one-hot for model input
1049
+ X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
1050
+ X_tensor = X_tensor.float()
1051
+
1052
+ if X_tensor.device != self.device:
1053
+ X_tensor = X_tensor.to(self.device)
1054
+
1055
+ elif X_tensor.dim() != 3:
1056
+ msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
1057
+ self.logger.error(msg)
1058
+ raise ValueError(msg)
1059
+
1060
+ if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
1061
+ msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
1062
+ self.logger.error(msg)
1063
+ raise ValueError(msg)
1064
+
1065
+ X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
1066
+
867
1067
  with torch.no_grad():
868
- X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
869
- X_tensor = X_tensor.to(self.device).long()
870
- if self.is_haploid:
871
- x_ohe = self._one_hot_encode_012(X_tensor)
872
- outputs = model(x_ohe)
873
- logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
874
- probas = torch.softmax(logits, dim=-1)
875
- labels = torch.argmax(probas, dim=-1)
876
- else:
877
- x_in = self._encode_multilabel_inputs(X_tensor)
878
- outputs = model(x_in)
879
- logits = outputs[0].view(-1, self.num_features_, self.output_classes_)
880
- probas2 = torch.sigmoid(logits)
881
- p_ref = probas2[..., 0]
882
- p_alt = probas2[..., 1]
883
- p_het = p_ref * p_alt
884
- p_ref_only = p_ref * (1 - p_alt)
885
- p_alt_only = p_alt * (1 - p_ref)
886
- probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
887
- probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
888
- labels = torch.argmax(probas, dim=-1)
1068
+ raw = model(X_tensor)
1069
+ logits = raw[0].view(-1, nF, nC)
1070
+ probas = torch.softmax(logits, dim=-1)
1071
+ labels = torch.argmax(probas, dim=-1)
889
1072
 
890
1073
  if return_proba:
891
1074
  return labels.cpu().numpy(), probas.cpu().numpy()
892
-
893
- return labels.cpu().numpy()
1075
+ return labels.cpu().numpy(), None
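The 2D branch above relies on `_one_hot_encode_012`, whose exact behaviour is defined in the base class. A reasonable reading, sketched below, is that valid genotypes become standard one-hot rows while missing (-1) cells become all-zero rows, and decoding back to labels is a softmax/argmax over the class axis. The helper names here are hypothetical.

    import torch
    import torch.nn.functional as F

    def one_hot_012(x: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
        """(B, L) int matrix with -1 for missing -> (B, L, num_classes) float one-hot."""
        x = x.long()
        valid = x >= 0
        ohe = F.one_hot(x.clamp(min=0), num_classes=num_classes).float()
        return ohe * valid.unsqueeze(-1)  # missing cells become all-zero vectors

    def decode_logits(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """(B, L, K) logits -> (labels (B, L), probabilities (B, L, K))."""
        probas = torch.softmax(logits, dim=-1)
        return torch.argmax(probas, dim=-1), probas

    # Example:
    # X = torch.tensor([[0, 2, -1], [1, -1, 0]])
    # one_hot_012(X).shape  -> torch.Size([2, 3, 3])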
894
1076
 
895
1077
  def _evaluate_model(
896
1078
  self,
897
- X_val: np.ndarray,
898
1079
  model: torch.nn.Module,
899
- params: dict,
900
- objective_mode: bool = False,
901
- latent_vectors_val: np.ndarray | None = None,
1080
+ X: np.ndarray | torch.Tensor,
1081
+ y: np.ndarray,
1082
+ eval_mask: np.ndarray,
902
1083
  *,
903
- eval_mask_override: np.ndarray | None = None,
1084
+ objective_mode: bool = False,
904
1085
  ) -> Dict[str, float]:
905
- """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
1086
+ """Evaluate model performance on masked genotypes.
906
1087
 
907
- This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
1088
+ This method evaluates the performance of the trained model on a given dataset using a specified evaluation mask. It computes various classification metrics based on the predicted labels and probabilities, comparing them to the ground truth labels. The method returns a dictionary of evaluation metrics.
908
1089
 
909
1090
  Args:
910
- X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
911
1091
  model (torch.nn.Module): Trained model.
912
- params (dict): Current hyperparameters (for logging).
913
- objective_mode (bool): If True, minimize logging for Optuna.
914
- latent_vectors_val (np.ndarray | None): Not used by VAE.
915
- eval_mask_override (np.ndarray | None): Optional mask to override default eval mask.
1092
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
1093
+ y (np.ndarray): Ground truth 0/1/2 matrix with -1 for missing.
1094
+ eval_mask (np.ndarray): Boolean mask indicating which genotypes to evaluate.
1095
+ objective_mode (bool): If True, suppresses verbose output.
916
1096
 
917
1097
  Returns:
918
- Dict[str, float]: Computed metrics.
919
-
920
- Raises:
921
- NotFittedError: If called before fit().
1098
+ Dict[str, float]: Evaluation metrics.
922
1099
  """
923
- pred_labels, pred_probas = self._predict(
924
- model=model, X=X_val, return_proba=True
925
- )
1100
+ if model is None:
1101
+ msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
1102
+ self.logger.error(msg)
1103
+ raise NotFittedError(msg)
926
1104
 
927
- finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
928
-
929
- # FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
930
- if (
931
- hasattr(self, "X_val_")
932
- and getattr(self, "X_val_", None) is not None
933
- and X_val.shape[0] == self.X_val_.shape[0]
934
- ):
935
- GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
936
- elif (
937
- hasattr(self, "X_train_")
938
- and getattr(self, "X_train_", None) is not None
939
- and X_val.shape[0] == self.X_train_.shape[0]
940
- ):
941
- GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
942
- else:
943
- GT_ref = self.ground_truth_
1105
+ pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
1106
+
1107
+ if pred_probas is None:
1108
+ msg = "Predicted probabilities are None in _evaluate_model()."
1109
+ self.logger.error(msg)
1110
+ raise ValueError(msg)
1111
+
1112
+ y_true_flat = y[eval_mask].astype(np.int8, copy=False)
1113
+ y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
1114
+ y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
944
1115
 
945
- # FIX 2: Handle Feature Mismatch
946
- # If GT has more columns than X_val, slice it to match.
947
- if GT_ref.shape[1] > X_val.shape[1]:
948
- GT_ref = GT_ref[:, : X_val.shape[1]]
1116
+ valid = y_true_flat >= 0
1117
+ y_true_flat = y_true_flat[valid]
1118
+ y_pred_flat = y_pred_flat[valid]
1119
+ y_proba_flat = y_proba_flat[valid]
949
1120
 
950
- # Fallback safeguard
951
- if GT_ref.shape != X_val.shape:
952
- GT_ref = X_val
1121
+ if y_true_flat.size == 0:
1122
+ return {self.tune_metric: 0.0}
953
1123
 
954
- # FIX 3: Allow override mask to be sliced if it's too wide
955
- if eval_mask_override is not None:
956
- if eval_mask_override.shape[0] != X_val.shape[0]:
1124
+ # --- Hard assertions on probability shape ---
1125
+ if y_proba_flat.ndim != 2:
1126
+ msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got shape {y_proba_flat.shape}."
1127
+ self.logger.error(msg)
1128
+ raise ValueError(msg)
1129
+
1130
+ K = int(y_proba_flat.shape[1])
1131
+ if self.is_haploid_:
1132
+ if K not in (2, 3):
1133
+ msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
1134
+ self.logger.error(msg)
1135
+ raise ValueError(msg)
1136
+ else:
1137
+ if K != 3:
1138
+ msg = f"Diploid evaluation expects 3 classes; got {K}."
1139
+ self.logger.error(msg)
1140
+ raise ValueError(msg)
1141
+
1142
+ if not self.is_haploid_:
1143
+ if np.any((y_true_flat < 0) | (y_true_flat > 2)):
957
1144
  msg = (
958
- f"eval_mask_override rows {eval_mask_override.shape[0]} "
959
- f"does not match X_val rows {X_val.shape[0]}"
1145
+ "Diploid y_true_flat contains values outside {0,1,2} after masking."
960
1146
  )
961
1147
  self.logger.error(msg)
962
1148
  raise ValueError(msg)
963
1149
 
964
- if eval_mask_override.shape[1] > X_val.shape[1]:
965
- eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
1150
+ # --- Harmonize for haploid vs diploid ---
1151
+ if self.is_haploid_:
1152
+ # Binary scoring: REF=0, ALT=1 (treat any non-zero as ALT)
1153
+ y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
1154
+ y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
1155
+
1156
+ K = y_proba_flat.shape[1]
1157
+ if K == 2:
1158
+ pass
1159
+ elif K == 3:
1160
+ proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
1161
+ proba_2[:, 0] = y_proba_flat[:, 0]
1162
+ proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
1163
+ y_proba_flat = proba_2
966
1164
  else:
967
- eval_mask = eval_mask_override.astype(bool)
968
- else:
969
- eval_mask = X_val != -1
970
-
971
- eval_mask = eval_mask & finite_mask & (GT_ref != -1)
972
-
973
- y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
974
- y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
975
- y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
1165
+ msg = f"Haploid evaluation expects 2 or 3 prob columns; got {K}"
1166
+ self.logger.error(msg)
1167
+ raise ValueError(msg)
976
1168
 
977
- if y_true_flat.size == 0:
978
- return {self.tune_metric: 0.0}
1169
+ labels_for_scoring = [0, 1]
1170
+ target_names = ["REF", "ALT"]
1171
+ else:
1172
+ if y_proba_flat.shape[1] != 3:
1173
+ msg = f"Diploid evaluation expects 3 prob columns; got {y_proba_flat.shape[1]}"
1174
+ self.logger.error(msg)
1175
+ raise ValueError(msg)
1176
+ labels_for_scoring = [0, 1, 2]
1177
+ target_names = ["REF", "HET", "ALT"]
979
1178
 
980
- # ensure valid probability simplex after masking
1179
+ # Ensure valid probability simplex after masking/collapsing
981
1180
  y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
982
1181
  row_sums = y_proba_flat.sum(axis=1, keepdims=True)
983
- row_sums[row_sums == 0] = 1.0
1182
+ row_sums[row_sums == 0.0] = 1.0
984
1183
  y_proba_flat = y_proba_flat / row_sums
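As a concrete illustration of the haploid collapse and renormalization above (with made-up values): a 3-column probability row is reduced to REF vs. ALT by summing the HET and ALT columns, then each row is rescaled back onto the probability simplex.

    import numpy as np

    proba3 = np.array([[0.70, 0.20, 0.10],
                       [0.05, 0.25, 0.70]])      # columns: REF, HET, ALT

    # Collapse to two columns for haploid scoring: ALT = HET + ALT.
    proba2 = np.stack([proba3[:, 0], proba3[:, 1] + proba3[:, 2]], axis=1)

    # Re-normalize each row so it sums to one.
    proba2 = np.clip(proba2, 0.0, 1.0)
    row_sums = proba2.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    proba2 = proba2 / row_sums
    # -> [[0.70, 0.30], [0.05, 0.95]]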
985
1184
 
986
- labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
987
- target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
988
-
989
- if self.is_haploid:
990
- y_true_flat = y_true_flat.copy()
991
- y_pred_flat = y_pred_flat.copy()
992
- y_true_flat[y_true_flat == 2] = 1
993
- y_pred_flat[y_pred_flat == 2] = 1
994
- proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
995
- proba_2[:, 0] = y_proba_flat[:, 0]
996
- proba_2[:, 1] = y_proba_flat[:, 2]
997
- y_proba_flat = proba_2
998
-
999
- y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
1185
+ y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
1000
1186
 
1001
1187
  metrics = self.scorers_.evaluate(
1002
1188
  y_true_flat,
@@ -1004,16 +1190,29 @@ class ImputeVAE(BaseNNImputer):
1004
1190
  y_true_ohe,
1005
1191
  y_proba_flat,
1006
1192
  objective_mode,
1007
- self.tune_metric,
1193
+ cast(
1194
+ Literal[
1195
+ "pr_macro",
1196
+ "roc_auc",
1197
+ "accuracy",
1198
+ "f1",
1199
+ "average_precision",
1200
+ "precision",
1201
+ "recall",
1202
+ "mcc",
1203
+ "jaccard",
1204
+ ],
1205
+ self.tune_metric,
1206
+ ),
1008
1207
  )
1009
1208
 
1010
1209
  if not objective_mode:
1011
- pm = PrettyMetrics(
1012
- metrics, precision=3, title=f"{self.model_name} Validation Metrics"
1013
- )
1014
- pm.render()
1210
+ if self.verbose or self.debug:
1211
+ pm = PrettyMetrics(
1212
+ metrics, precision=2, title=f"{self.model_name} Validation Metrics"
1213
+ )
1214
+ pm.render()
1015
1215
 
1016
- # Primary report
1017
1216
  self._make_class_reports(
1018
1217
  y_true=y_true_flat,
1019
1218
  y_pred_proba=y_proba_flat,
@@ -1022,16 +1221,17 @@ class ImputeVAE(BaseNNImputer):
1022
1221
  labels=target_names,
1023
1222
  )
1024
1223
 
1025
- # IUPAC decode & 10-base integer report
1026
- # FIX 4: Use current shape (X_val.shape) not self.num_features_
1027
- y_true_dec = self.pgenc.decode_012(
1028
- GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1029
- )
1030
- X_pred = X_val.copy()
1031
- X_pred[eval_mask] = y_pred_flat
1032
- y_pred_dec = self.pgenc.decode_012(
1033
- X_pred.reshape(X_val.shape[0], X_val.shape[1])
1034
- )
1224
+ # --- IUPAC decode and 10-base integer report ---
1225
+ y_true_matrix = np.array(y, copy=True)
1226
+ y_pred_matrix = np.array(pred_labels, copy=True)
1227
+
1228
+ if self.is_haploid_:
1229
+ # Map any ALT-coded >0 to 2 for decode_012, preserve missing -1
1230
+ y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
1231
+ y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
1232
+
1233
+ y_true_dec = self.decode_012(y_true_matrix)
1234
+ y_pred_dec = self.decode_012(y_pred_matrix)
1035
1235
 
1036
1236
  encodings_dict = {
1037
1237
  "A": 0,
@@ -1079,80 +1279,71 @@ class ImputeVAE(BaseNNImputer):
1079
1279
 
1080
1280
  Returns:
1081
1281
  float: Value of the tuning metric to be optimized.
1282
+
1283
+ Raises:
1284
+ RuntimeError: If model training returns None.
1285
+ optuna.exceptions.TrialPruned: If the trial is pruned as unpromising or training raises an unexpected exception.
1082
1286
  """
1083
1287
  try:
1084
1288
  params = self._sample_hyperparameters(trial)
1085
1289
 
1086
- # Use tune subsets when available (tune_fast)
1087
- X_train = getattr(self, "_tune_X_train", None)
1088
- X_val = getattr(self, "_tune_X_test", None)
1089
- if X_train is None or X_val is None:
1090
- X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1091
- X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1092
-
1093
- class_weights = self._normalize_class_weights(
1094
- self._class_weights_from_zygosity(X_train)
1095
- )
1096
- # Pos weights for diploid multilabel BCE during tuning
1097
- if not self.is_haploid:
1098
- self.pos_weights_ = self._compute_pos_weights(X_train)
1099
- else:
1100
- self.pos_weights_ = None
1101
- train_loader = self._get_data_loader(X_train)
1102
-
1103
1290
  model = self.build_model(self.Model, params["model_params"])
1104
1291
  model.apply(self.initialize_weights)
1105
1292
 
1106
- lr: float = params["lr"]
1293
+ lr: float = params["learning_rate"]
1107
1294
  l1_penalty: float = params["l1_penalty"]
1108
1295
 
1109
- # Train + prune on metric
1110
- _, model, __ = self._train_and_validate_model(
1296
+ class_weights = self._class_weights_from_zygosity(
1297
+ self.y_train_,
1298
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1299
+ inverse=params["inverse"],
1300
+ normalize=params["normalize"],
1301
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1302
+ power=params["power"],
1303
+ )
1304
+
1305
+ res = self._train_and_validate_model(
1111
1306
  model=model,
1112
- loader=train_loader,
1113
1307
  lr=lr,
1114
1308
  l1_penalty=l1_penalty,
1309
+ params=params,
1115
1310
  trial=trial,
1116
- return_history=False,
1117
1311
  class_weights=class_weights,
1118
- X_val=X_val,
1119
- params=params,
1120
- prune_metric=self.tune_metric,
1121
- prune_warmup_epochs=10,
1122
- eval_interval=self.tune_eval_interval,
1123
- eval_requires_latents=False,
1124
- eval_latent_steps=0,
1125
- eval_latent_lr=0.0,
1126
- eval_latent_weight_decay=0.0,
1127
- )
1128
-
1129
- eval_mask = (
1130
- self.sim_mask_test_
1131
- if (
1132
- self.simulate_missing
1133
- and getattr(self, "sim_mask_test_", None) is not None
1134
- )
1135
- else None
1312
+ kl_beta_schedule=params["kl_beta_schedule"],
1313
+ gamma_schedule=params["gamma_schedule"],
1136
1314
  )
1315
+ model = res[1]
1137
1316
 
1138
1317
  if model is None:
1139
- raise RuntimeError("Model training failed; no model was returned.")
1318
+ msg = "Model training returned None in tuning objective."
1319
+ self.logger.error(msg)
1320
+ raise RuntimeError(msg)
1140
1321
 
1141
1322
  metrics = self._evaluate_model(
1142
- X_val, model, params, objective_mode=True, eval_mask_override=eval_mask
1323
+ model=model,
1324
+ X=self.X_val_corrupted_,
1325
+ y=self.y_val_,
1326
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
1327
+ objective_mode=True,
1143
1328
  )
1144
- self._clear_resources(model, train_loader)
1329
+
1330
+ self._clear_resources(model)
1145
1331
  return metrics[self.tune_metric]
1146
1332
 
1147
1333
  except Exception as e:
1148
- # Keep sweeps moving
1149
- self.logger.debug(f"Trial failed with error: {e}")
1150
- raise optuna.exceptions.TrialPruned(
1151
- f"Trial failed with error. Enable debug logging for details."
1334
+ # Unexpected failure: surface full details in logs while still
1335
+ # pruning the trial to keep sweeps moving.
1336
+ err_type = type(e).__name__
1337
+ self.logger.warning(
1338
+ f"Trial {trial.number} failed due to exception {err_type}: {e}"
1152
1339
  )
1340
+ self.logger.debug(traceback.format_exc())
1341
+ raise optuna.exceptions.TrialPruned(
1342
+ f"Trial {trial.number} failed due to an exception. {err_type}: {e}. Enable debug logging for full traceback."
1343
+ ) from e
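Because the objective raises `optuna.exceptions.TrialPruned` both for pruner decisions and for unexpected failures, it is meant to be driven by a study with a pruner attached. The snippet below shows the standard Optuna pattern; the toy objective, pruner choice, and `n_trials` value are illustrative, not pg-sui defaults.

    import optuna

    def objective(trial: optuna.Trial) -> float:
        # Stand-in for ImputeVAE.objective(trial): sample params, train,
        # and return the tuning metric (e.g. macro F1) to maximize.
        x = trial.suggest_float("x", 0.0, 1.0)
        return 1.0 - (x - 0.3) ** 2

    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=10),
    )
    study.optimize(objective, n_trials=50)
    print(study.best_trial.params, study.best_value)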
1153
1344
 
1154
1345
  def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1155
- """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
1346
+ """Sample VAE hyperparameters; hidden sizes use BaseNNImputer helper.
1156
1347
 
1157
1348
  Args:
1158
1349
  trial (optuna.Trial): Optuna trial object.
@@ -1161,156 +1352,109 @@ class ImputeVAE(BaseNNImputer):
1161
1352
  Dict[str, int | float | str]: Sampled hyperparameters.
1162
1353
  """
1163
1354
  params = {
1164
- "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
1165
- "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
1166
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
1167
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
1355
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
1356
+ "learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
1357
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
1358
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
1168
1359
  "activation": trial.suggest_categorical(
1169
1360
  "activation", ["relu", "elu", "selu", "leaky_relu"]
1170
1361
  ),
1171
1362
  "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1172
1363
  "layer_scaling_factor": trial.suggest_float(
1173
- "layer_scaling_factor", 2.0, 4.0, step=0.5
1364
+ "layer_scaling_factor", 2.0, 10.0, step=0.025
1174
1365
  ),
1175
1366
  "layer_schedule": trial.suggest_categorical(
1176
1367
  "layer_schedule", ["pyramid", "linear"]
1177
1368
  ),
1178
- # VAE-specific β (final value after anneal)
1179
- "beta": trial.suggest_float("beta", 0.5, 2.0, step=0.5),
1180
- # focal gamma (if used in VAE recon CE)
1181
- "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
1369
+ "power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
1370
+ "normalize": trial.suggest_categorical("normalize", [True, False]),
1371
+ "inverse": trial.suggest_categorical("inverse", [True, False]),
1372
+ "gamma": trial.suggest_float("gamma", 0.0, 3.0, step=0.1),
1373
+ "kl_beta": trial.suggest_float("kl_beta", 0.1, 5.0, step=0.1),
1374
+ "kl_beta_schedule": trial.suggest_categorical(
1375
+ "kl_beta_schedule", [True, False]
1376
+ ),
1377
+ "gamma_schedule": trial.suggest_categorical(
1378
+ "gamma_schedule", [True, False]
1379
+ ),
1182
1380
  }
1183
1381
 
1184
- use_n_features = (
1185
- self._tune_num_features
1186
- if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
1187
- else self.num_features_
1188
- )
1189
- input_dim = use_n_features * self.output_classes_
1382
+ nF: int = self.num_features_
1383
+ nC: int = self.num_classes_
1384
+ input_dim = nF * nC
1385
+
1190
1386
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1191
1387
  n_inputs=input_dim,
1192
- n_outputs=input_dim,
1193
- n_samples=len(self.train_idx_),
1388
+ n_outputs=nC,
1389
+ n_samples=len(self.X_train_),
1194
1390
  n_hidden=params["num_hidden_layers"],
1391
+ latent_dim=params["latent_dim"],
1195
1392
  alpha=params["layer_scaling_factor"],
1196
1393
  schedule=params["layer_schedule"],
1197
1394
  )
1198
1395
 
1199
- # [latent_dim] + interior widths (exclude output width)
1200
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1201
-
1202
1396
  params["model_params"] = {
1203
- "n_features": use_n_features,
1204
- "num_classes": self.output_classes_,
1205
- "latent_dim": params["latent_dim"],
1397
+ "n_features": self.num_features_,
1398
+ "num_classes": nC, # categorical head: 2 or 3
1206
1399
  "dropout_rate": params["dropout_rate"],
1207
- "hidden_layer_sizes": hidden_only,
1400
+ "hidden_layer_sizes": hidden_layer_sizes,
1208
1401
  "activation": params["activation"],
1209
- # Pass through VAE recon/regularization coefficients
1210
- "beta": params["beta"],
1211
- "gamma": params["gamma"],
1402
+ "kl_beta": params["kl_beta"],
1212
1403
  }
1404
+
1213
1405
  return params
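The hidden widths come from `_compute_hidden_layer_sizes` in the base class, which is not shown in this diff. Purely for orientation, the sketch below interpolates widths between the input dimension and the latent dimension under a "linear" or geometric "pyramid" schedule; it is an assumption about the general idea, not the helper's actual formula.

    import numpy as np

    def hidden_layer_sizes(input_dim: int, latent_dim: int, n_hidden: int,
                           schedule: str = "pyramid") -> list[int]:
        """Interpolate n_hidden widths between input_dim and latent_dim."""
        # Positions strictly between the two endpoints.
        t = np.linspace(0.0, 1.0, n_hidden + 2)[1:-1]
        if schedule == "linear":
            widths = input_dim + t * (latent_dim - input_dim)
        else:  # "pyramid": geometric decay from input_dim toward latent_dim
            widths = input_dim * (latent_dim / input_dim) ** t
        return [max(int(round(w)), latent_dim) for w in widths]

    # hidden_layer_sizes(3000, 8, 3, "pyramid") -> [682, 155, 35]
    # hidden_layer_sizes(3000, 8, 3, "linear")  -> [2252, 1504, 756]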
1214
1406
 
1215
- def _set_best_params(self, best_params: dict) -> dict:
1216
- """Adopt best params and return VAE model_params.
1407
+ def _set_best_params(self, params: dict) -> dict:
1408
+ """Update instance fields from tuned params and return model_params dict.
1217
1409
 
1218
1410
  Args:
1219
- best_params (dict): Best hyperparameters from tuning.
1411
+ params (dict): Best hyperparameters from tuning.
1220
1412
 
1221
1413
  Returns:
1222
- dict: VAE model parameters.
1414
+ dict: Model parameters for building the VAE.
1223
1415
  """
1224
- self.latent_dim = best_params["latent_dim"]
1225
- self.dropout_rate = best_params["dropout_rate"]
1226
- self.learning_rate = best_params["learning_rate"]
1227
- self.l1_penalty = best_params["l1_penalty"]
1228
- self.activation = best_params["activation"]
1229
- self.layer_scaling_factor = best_params["layer_scaling_factor"]
1230
- self.layer_schedule = best_params["layer_schedule"]
1231
- self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
1232
- self.gamma = best_params.get("gamma", self.gamma)
1233
-
1234
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
1235
- n_inputs=self.num_features_ * self.output_classes_,
1236
- n_outputs=self.num_features_ * self.output_classes_,
1237
- n_samples=len(self.train_idx_),
1238
- n_hidden=best_params["num_hidden_layers"],
1239
- alpha=best_params["layer_scaling_factor"],
1240
- schedule=best_params["layer_schedule"],
1416
+ self.latent_dim = params["latent_dim"]
1417
+ self.dropout_rate = params["dropout_rate"]
1418
+ self.learning_rate = params["learning_rate"]
1419
+ self.l1_penalty = params["l1_penalty"]
1420
+ self.activation = params["activation"]
1421
+ self.layer_scaling_factor = params["layer_scaling_factor"]
1422
+ self.layer_schedule = params["layer_schedule"]
1423
+ self.power = params["power"]
1424
+ self.normalize = params["normalize"]
1425
+ self.inverse = params["inverse"]
1426
+ self.gamma = params["gamma"]
1427
+ self.gamma_schedule = params["gamma_schedule"]
1428
+ self.kl_beta = params["kl_beta"]
1429
+ self.kl_beta_schedule = params["kl_beta_schedule"]
1430
+ self.class_weights_ = self._class_weights_from_zygosity(
1431
+ self.y_train_,
1432
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1433
+ inverse=self.inverse,
1434
+ normalize=self.normalize,
1435
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1436
+ power=self.power,
1241
1437
  )
1242
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1438
+ nF = self.num_features_
1439
+ nC = self.num_classes_
1440
+ input_dim = nF * nC
1243
1441
 
1244
- return {
1245
- "n_features": self.num_features_,
1246
- "latent_dim": self.latent_dim,
1247
- "hidden_layer_sizes": hidden_only,
1248
- "dropout_rate": self.dropout_rate,
1249
- "activation": self.activation,
1250
- "num_classes": self.output_classes_,
1251
- "beta": self.kl_beta_final,
1252
- "gamma": self.gamma,
1253
- }
1254
-
1255
- def _default_best_params(self) -> Dict[str, int | float | str | list]:
1256
- """Default VAE model params when tuning is disabled.
1257
-
1258
- Returns:
1259
- Dict[str, int | float | str | list]: VAE model parameters.
1260
- """
1261
1442
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1262
- n_inputs=self.num_features_ * self.output_classes_,
1263
- n_outputs=self.num_features_ * self.output_classes_,
1264
- n_samples=len(self.ground_truth_),
1265
- n_hidden=self.num_hidden_layers,
1266
- alpha=self.layer_scaling_factor,
1267
- schedule=self.layer_schedule,
1443
+ n_inputs=input_dim,
1444
+ n_outputs=nC,
1445
+ n_samples=len(self.X_train_),
1446
+ n_hidden=params["num_hidden_layers"],
1447
+ latent_dim=params["latent_dim"],
1448
+ alpha=params["layer_scaling_factor"],
1449
+ schedule=params["layer_schedule"],
1268
1450
  )
1451
+
1269
1452
  return {
1270
- "n_features": self.num_features_,
1453
+ "n_features": nF,
1271
1454
  "latent_dim": self.latent_dim,
1272
1455
  "hidden_layer_sizes": hidden_layer_sizes,
1273
1456
  "dropout_rate": self.dropout_rate,
1274
1457
  "activation": self.activation,
1275
- "num_classes": self.output_classes_,
1276
- "beta": self.kl_beta_final,
1277
- "gamma": self.gamma,
1458
+ "num_classes": nC,
1459
+ "kl_beta": params["kl_beta"],
1278
1460
  }
1279
-
1280
- def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
1281
- """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
1282
- if self.is_haploid:
1283
- return self._one_hot_encode_012(y)
1284
- y = y.to(self.device)
1285
- shape = y.shape + (2,)
1286
- out = torch.zeros(shape, device=self.device, dtype=torch.float32)
1287
- valid = y != -1
1288
- ref_mask = valid & (y != 2)
1289
- alt_mask = valid & (y != 0)
1290
- out[ref_mask, 0] = 1.0
1291
- out[alt_mask, 1] = 1.0
1292
- return out
1293
-
1294
- def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
1295
- """Targets aligned with _encode_multilabel_inputs for diploid training."""
1296
- if self.is_haploid:
1297
- raise RuntimeError("_multi_hot_targets called for haploid data.")
1298
- y = y.to(self.device)
1299
- out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
1300
- valid = y != -1
1301
- ref_mask = valid & (y != 2)
1302
- alt_mask = valid & (y != 0)
1303
- out[ref_mask, 0] = 1.0
1304
- out[alt_mask, 1] = 1.0
1305
- return out
1306
-
1307
- def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
1308
- """Balance REF/ALT channels for multilabel BCE."""
1309
- ref_pos = np.count_nonzero((X == 0) | (X == 1))
1310
- alt_pos = np.count_nonzero((X == 2) | (X == 1))
1311
- total_valid = np.count_nonzero(X != -1)
1312
- pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
1313
- neg_counts = np.maximum(total_valid - pos_counts, 1.0)
1314
- pos_counts = np.maximum(pos_counts, 1.0)
1315
- weights = neg_counts / pos_counts
1316
- return torch.tensor(weights, device=self.device, dtype=torch.float32)
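The removed `_compute_pos_weights` balanced the REF/ALT channels of the old multilabel BCE head; with the categorical head, class balance is instead handled by `_class_weights_from_zygosity` (defined in the base class and not shown in this diff). The sketch below is only a plausible reading of its parameters — inverse-frequency weights raised to a `power`, an optional normalization, and a `max_ratio` cap — not the actual implementation.

    import numpy as np
    import torch

    def class_weights_from_zygosity(y: np.ndarray, mask: np.ndarray, num_classes: int = 3,
                                    inverse: bool = True, power: float = 1.0,
                                    normalize: bool = True, max_ratio: float = 5.0) -> torch.Tensor:
        """Weight REF/HET/ALT by (inverse) frequency over the masked training cells."""
        vals = y[mask]
        vals = vals[vals >= 0].astype(np.int64)
        counts = np.bincount(vals, minlength=num_classes).astype(np.float64)
        counts = np.maximum(counts, 1.0)              # avoid division by zero
        freqs = counts / counts.sum()
        w = (1.0 / freqs) ** power if inverse else freqs ** power
        w = np.clip(w, w.min(), w.min() * max_ratio)  # cap rarest/commonest weight ratio
        if normalize:
            w = w * (num_classes / w.sum())           # mean weight of 1.0
        return torch.tensor(w, dtype=torch.float32)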