pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,25 @@
  from __future__ import annotations
 
  import copy
- from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union
+ import traceback
+ from collections import defaultdict
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
 
  import matplotlib.pyplot as plt
  import numpy as np
  import optuna
  import torch
+ import torch.nn.functional as F
  from sklearn.exceptions import NotFittedError
- from sklearn.model_selection import train_test_split
  from snpio.analysis.genotype_encoder import GenotypeEncoder
  from snpio.utils.logging import LoggerManager
  from torch.optim.lr_scheduler import CosineAnnealingLR
 
  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
  from pgsui.data_processing.containers import VAEConfig
- from pgsui.data_processing.transformers import SimMissingTransformer
  from pgsui.impute.unsupervised.base import BaseNNImputer
  from pgsui.impute.unsupervised.callbacks import EarlyStopping
- from pgsui.impute.unsupervised.loss_functions import compute_vae_loss
+ from pgsui.impute.unsupervised.loss_functions import FocalCELoss, compute_vae_loss
  from pgsui.impute.unsupervised.models.vae_model import VAEModel
  from pgsui.utils.logging_utils import configure_logger
  from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -28,14 +29,47 @@ if TYPE_CHECKING:
  from snpio.read_input.genotype_data import GenotypeData
 
 
- def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
- """Normalize VAEConfig input from various sources.
+ def _make_warmup_cosine_scheduler(
+ optimizer: torch.optim.Optimizer,
+ *,
+ max_epochs: int,
+ warmup_epochs: int,
+ start_factor: float = 0.1,
+ ) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
+ """Create a warmup->cosine LR scheduler.
 
  Args:
- config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
+ optimizer (torch.optim.Optimizer): The optimizer to schedule.
+ max_epochs (int): Total number of epochs for training.
+ warmup_epochs (int): Number of warmup epochs.
+ start_factor (float): Initial LR factor for warmup.
 
  Returns:
- VAEConfig: Normalized configuration dataclass.
+ torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: The learning rate scheduler.
+ """
+ warmup_epochs = int(max(0, warmup_epochs))
+
+ if warmup_epochs == 0:
+ return CosineAnnealingLR(optimizer, T_max=max_epochs)
+
+ warmup = torch.optim.lr_scheduler.LinearLR(
+ optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
+ )
+ cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
+
+ return torch.optim.lr_scheduler.SequentialLR(
+ optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
+ )
+
+
+ def ensure_vae_config(config: VAEConfig | dict | str | None) -> VAEConfig:
+ """Ensure a VAEConfig instance from various input types.
+
+ Args:
+ config (VAEConfig | dict | str | None): Configuration input.
+
+ Returns:
+ VAEConfig: The resulting VAEConfig instance.
  """
  if config is None:
  return VAEConfig()
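
The new `_make_warmup_cosine_scheduler` helper above composes two stock PyTorch schedulers: a `LinearLR` warmup chained into `CosineAnnealingLR` via `SequentialLR`. A minimal standalone sketch of the same composition (the toy model, learning rate, and epoch counts are illustrative, not pg-sui defaults):

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    model = torch.nn.Linear(10, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    max_epochs, warmup_epochs = 50, 5
    warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

    for _ in range(max_epochs):
        optimizer.step()   # stands in for one epoch of training
        scheduler.step()   # LR ramps 1e-4 -> 1e-3 during warmup, then decays along a cosine curve

`SequentialLR` switches from the warmup schedule to the cosine schedule at the `milestones` epoch, which is why the cosine horizon is `max_epochs - warmup_epochs` rather than `max_epochs`.
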
@@ -44,13 +78,13 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
  if isinstance(config, str):
  return load_yaml_to_dataclass(config, VAEConfig)
  if isinstance(config, dict):
+ cfg_in = copy.deepcopy(config)
  base = VAEConfig()
- # Respect top-level preset
- preset = config.pop("preset", None)
+ preset = cfg_in.pop("preset", None)
+ if "io" in cfg_in and isinstance(cfg_in["io"], dict):
+ preset = preset or cfg_in["io"].pop("preset", None)
  if preset:
  base = VAEConfig.from_preset(preset)
- # Flatten + apply
- flat: Dict[str, object] = {}
 
  def _flatten(prefix: str, d: dict, out: dict) -> dict:
  for k, v in d.items():
@@ -61,15 +95,19 @@ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
  out[kk] = v
  return out
 
- flat = _flatten("", config, {})
+ flat = _flatten("", cfg_in, {})
  return apply_dot_overrides(base, flat)
  raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
 
 
  class ImputeVAE(BaseNNImputer):
- """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
+ """Variational Autoencoder (VAE) imputer for 0/1/2 genotypes.
+
+ Trains a VAE on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. The workflow simulates missingness once on the full matrix, then creates train/val/test splits. It supports haploid and diploid data, focal-CE reconstruction loss with a KL term (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
 
- This imputer implements a VAE with a multinomial (categorical) latent space. It is designed to handle missing data by inferring the latent distribution and generating plausible predictions. The model is trained using a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over time. The imputer supports both haploid and diploid genotype data.
+ Notes:
+ - Training includes early stopping based on validation loss.
+ - The imputer can handle both haploid and diploid genotype data.
  """
 
  def __init__(
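
`ensure_vae_config` accepts a `VAEConfig`, a nested dict, a YAML path, or `None`, and normalizes dict input by flattening it into dot-keys before handing them to `apply_dot_overrides`. A self-contained sketch of just the flattening step (the sample keys and values are hypothetical, not pg-sui defaults, and this is not the package's own `_flatten`):

    def flatten(prefix: str, d: dict, out: dict) -> dict:
        # Recursively turn {"train": {"batch_size": 64}} into {"train.batch_size": 64}.
        for k, v in d.items():
            key = f"{prefix}.{k}" if prefix else k
            if isinstance(v, dict):
                flatten(key, v, out)
            else:
                out[key] = v
        return out

    nested = {"train": {"batch_size": 64, "learning_rate": 1e-3}, "model": {"latent_dim": 8}}
    print(flatten("", nested, {}))
    # {'train.batch_size': 64, 'train.learning_rate': 0.001, 'model.latent_dim': 8}
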
@@ -78,46 +116,37 @@ class ImputeVAE(BaseNNImputer):
78
116
  *,
79
117
  tree_parser: Optional["TreeParser"] = None,
80
118
  config: Optional[Union["VAEConfig", dict, str]] = None,
81
- overrides: dict | None = None,
82
- simulate_missing: bool | None = None,
83
- sim_strategy: (
84
- Literal[
85
- "random",
86
- "random_weighted",
87
- "random_weighted_inv",
88
- "nonrandom",
89
- "nonrandom_weighted",
90
- ]
91
- | None
92
- ) = None,
93
- sim_prop: float | None = None,
94
- sim_kwargs: dict | None = None,
95
- ):
96
- """Initialize the VAE imputer with a unified config interface.
97
-
98
- This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
119
+ overrides: Optional[dict] = None,
120
+ sim_strategy: Literal[
121
+ "random",
122
+ "random_weighted",
123
+ "random_weighted_inv",
124
+ "nonrandom",
125
+ "nonrandom_weighted",
126
+ ] = "random",
127
+ sim_prop: Optional[float] = None,
128
+ sim_kwargs: Optional[dict] = None,
129
+ ) -> None:
130
+ """Initialize the ImputeVAE imputer.
99
131
 
100
132
  Args:
101
- genotype_data (GenotypeData): Backing genotype data object.
102
- tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for nonrandom sim_strategy modes.
103
- config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
104
- overrides (dict | None): Optional dot-key overrides with highest precedence.
105
- simulate_missing (bool | None): Whether to simulate missing data during training.
106
- sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
107
- sim_prop (float | None): Proportion of data to simulate as missing if simulating.
108
- sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
133
+ genotype_data (GenotypeData): Genotype data for imputation.
134
+ tree_parser (Optional[TreeParser]): Tree parser required for nonrandom strategies.
135
+ config (Optional[Union[VAEConfig, dict, str]]): Config dataclass, nested dict, YAML path, or None.
136
+ overrides (Optional[dict]): Dot-key overrides applied last with highest precedence.
137
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Missingness simulation strategy (overrides config).
138
+ sim_prop (Optional[float]): Proportion of entries to simulate as missing (overrides config). Default is None.
139
+ sim_kwargs (Optional[dict]): Extra missingness kwargs merged into config.
109
140
  """
110
141
  self.model_name = "ImputeVAE"
111
142
  self.genotype_data = genotype_data
112
143
  self.tree_parser = tree_parser
113
144
 
114
- # Normalize configuration and apply top-precedence overrides
115
145
  cfg = ensure_vae_config(config)
116
146
  if overrides:
117
147
  cfg = apply_dot_overrides(cfg, overrides)
118
148
  self.cfg = cfg
119
149
 
120
- # Logger (align with AE/NLPCA)
121
150
  logman = LoggerManager(
122
151
  __name__,
123
152
  prefix=self.cfg.io.prefix,
@@ -125,12 +154,10 @@ class ImputeVAE(BaseNNImputer):
125
154
  verbose=self.cfg.io.verbose,
126
155
  )
127
156
  self.logger = configure_logger(
128
- logman.get_logger(),
129
- verbose=self.cfg.io.verbose,
130
- debug=self.cfg.io.debug,
157
+ logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
131
158
  )
159
+ self.logger.propagate = False
132
160
 
133
- # BaseNNImputer bootstraps device/dirs/log formatting
134
161
  super().__init__(
135
162
  model_name=self.model_name,
136
163
  genotype_data=self.genotype_data,
@@ -140,11 +167,10 @@ class ImputeVAE(BaseNNImputer):
140
167
  debug=self.cfg.io.debug,
141
168
  )
142
169
 
143
- # Model hook & encoder
144
170
  self.Model = VAEModel
145
171
  self.pgenc = GenotypeEncoder(genotype_data)
146
172
 
147
- # IO/global
173
+ # I/O and general parameters
148
174
  self.seed = self.cfg.io.seed
149
175
  self.n_jobs = self.cfg.io.n_jobs
150
176
  self.prefix = self.cfg.io.prefix
@@ -153,41 +179,40 @@ class ImputeVAE(BaseNNImputer):
153
179
  self.debug = self.cfg.io.debug
154
180
  self.rng = np.random.default_rng(self.seed)
155
181
 
156
- # Simulated-missing controls (config defaults + ctor overrides)
182
+ # Simulation parameters
157
183
  sim_cfg = getattr(self.cfg, "sim", None)
158
184
  sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
159
185
  if sim_kwargs:
160
186
  sim_cfg_kwargs.update(sim_kwargs)
161
187
  if sim_cfg is None:
162
- default_sim_flag = bool(simulate_missing)
163
188
  default_strategy = "random"
164
- default_prop = 0.10
189
+ default_prop = 0.2
165
190
  else:
166
- default_sim_flag = sim_cfg.simulate_missing
167
191
  default_strategy = sim_cfg.sim_strategy
168
192
  default_prop = sim_cfg.sim_prop
169
- self.simulate_missing = (
170
- default_sim_flag if simulate_missing is None else bool(simulate_missing)
171
- )
193
+
194
+ self.simulate_missing = True
172
195
  self.sim_strategy = sim_strategy or default_strategy
173
196
  self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
174
197
  self.sim_kwargs = sim_cfg_kwargs
175
198
 
176
- # Model hyperparams (AE-parity)
199
+ # Model architecture parameters
177
200
  self.latent_dim = self.cfg.model.latent_dim
178
201
  self.dropout_rate = self.cfg.model.dropout_rate
179
202
  self.num_hidden_layers = self.cfg.model.num_hidden_layers
180
203
  self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
181
204
  self.layer_schedule = self.cfg.model.layer_schedule
182
- self.activation = self.cfg.model.hidden_activation
183
- self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
205
+ self.activation = self.cfg.model.activation
184
206
 
185
- # VAE-only KL controls
186
- self.kl_beta_final = self.cfg.vae.kl_beta
187
- self.kl_warmup = self.cfg.vae.kl_warmup
188
- self.kl_ramp = self.cfg.vae.kl_ramp
207
+ # VAE-specific parameters
208
+ self.kl_beta = self.cfg.vae.kl_beta
209
+ self.kl_beta_schedule = self.cfg.vae.kl_beta_schedule
189
210
 
190
- # Train hyperparams (AE-parity)
211
+ # Training parameters
212
+ self.power: float = self.cfg.train.weights_power
213
+ self.max_ratio: Any | float | None = self.cfg.train.weights_max_ratio
214
+ self.normalize: bool = self.cfg.train.weights_normalize
215
+ self.inverse: bool = self.cfg.train.weights_inverse
191
216
  self.batch_size = self.cfg.train.batch_size
192
217
  self.learning_rate = self.cfg.train.learning_rate
193
218
  self.l1_penalty: float = self.cfg.train.l1_penalty
@@ -195,32 +220,18 @@ class ImputeVAE(BaseNNImputer):
195
220
  self.min_epochs = self.cfg.train.min_epochs
196
221
  self.epochs = self.cfg.train.max_epochs
197
222
  self.validation_split = self.cfg.train.validation_split
198
- self.beta = self.cfg.train.weights_beta
199
- self.max_ratio = self.cfg.train.weights_max_ratio
223
+ self.gamma = self.cfg.train.gamma
224
+ self.gamma_schedule = self.cfg.train.gamma_schedule
200
225
 
201
- # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
226
+ # Hyperparameter tuning
202
227
  self.tune = self.cfg.tune.enabled
203
- self.tune_fast = self.cfg.tune.fast
204
- self.tune_batch_size = self.cfg.tune.batch_size
205
- self.tune_epochs = self.cfg.tune.epochs
206
- self.tune_eval_interval = self.cfg.tune.eval_interval
207
- self.tune_metric: Literal[
208
- "pr_macro",
209
- "f1",
210
- "accuracy",
211
- "average_precision",
212
- "precision",
213
- "recall",
214
- "roc_auc",
215
- ] = self.cfg.tune.metric
228
+ self.tune_metric = self.cfg.tune.metric
216
229
  self.n_trials = self.cfg.tune.n_trials
217
230
  self.tune_save_db = self.cfg.tune.save_db
218
231
  self.tune_resume = self.cfg.tune.resume
219
- self.tune_max_samples = self.cfg.tune.max_samples
220
- self.tune_max_loci = self.cfg.tune.max_loci
221
232
  self.tune_patience = self.cfg.tune.patience
222
233
 
223
- # Plotting (AE-parity)
234
+ # Plotting parameters
224
235
  self.plot_format = self.cfg.plot.fmt
225
236
  self.plot_dpi = self.cfg.plot.dpi
226
237
  self.plot_fontsize = self.cfg.plot.fontsize
@@ -228,90 +239,61 @@ class ImputeVAE(BaseNNImputer):
228
239
  self.despine = self.cfg.plot.despine
229
240
  self.show_plots = self.cfg.plot.show
230
241
 
231
- # Derived at fit-time
232
- self.is_haploid: bool = False
233
- self.num_classes_: int = 3 # diploid default
242
+ # Internal attributes set during fitting
243
+ self.is_haploid_: bool = False
244
+ self.num_classes_: int = 3
234
245
  self.model_params: Dict[str, Any] = {}
235
- self.sim_mask_global_: np.ndarray | None = None
236
- self.sim_mask_train_: np.ndarray | None = None
237
- self.sim_mask_test_: np.ndarray | None = None
246
+ self.sim_mask_test_: np.ndarray
238
247
 
239
248
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
240
- msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
249
+ msg = "tree_parser is required for nonrandom sim strategies."
241
250
  self.logger.error(msg)
242
251
  raise ValueError(msg)
243
252
 
244
- # -------------------- Fit -------------------- #
245
253
  def fit(self) -> "ImputeVAE":
246
- """Fit the VAE on 0/1/2 encoded genotypes (missing -> -1).
247
-
248
- This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. Missing positions are encoded as -1 for loss masking (any simulated-missing loci are temporarily tagged with -9 by ``SimMissingTransformer`` before being re-encoded as -1). It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
254
+ """Fit the VAE imputer model to the genotype data.
255
+
256
+ This method performs the following steps:
257
+ 1. Validates the presence of SNP data in the genotype data.
258
+ 2. Determines the ploidy of the genotype data and sets up haploid/diploid handling.
259
+ 3. Simulates missingness in the genotype data based on the specified strategy.
260
+ 4. Splits the data into training, validation, and test sets.
261
+ 5. One-hot encodes the genotype data for model input.
262
+ 6. Initializes data loaders for training and validation.
263
+ 7. If hyperparameter tuning is enabled, tunes the model hyperparameters.
264
+ 8. Builds the VAE model with the best hyperparameters.
265
+ 9. Trains the VAE model using the training data and validates on the validation set.
266
+ 10. Evaluates the trained model on the test set and computes performance metrics.
267
+ 11. Saves the trained model and best hyperparameters.
268
+ 12. Generates plots of training history if enabled.
269
+ 13. Returns the fitted ImputeVAE instance.
249
270
 
250
271
  Returns:
251
- ImputeVAE: Fitted instance.
252
-
253
- Raises:
254
- RuntimeError: If training fails to produce a model.
272
+ ImputeVAE: The fitted ImputeVAE instance.
255
273
  """
256
274
  self.logger.info(f"Fitting {self.model_name} model...")
257
275
 
258
- # Data prep aligns with AE/NLPCA
259
- X012 = self._get_float_genotypes(copy=True)
260
- GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
261
- self.ground_truth_ = GT_full.astype(np.int64, copy=False)
276
+ if self.genotype_data.snp_data is None:
277
+ msg = f"SNP data is required for {self.model_name}."
278
+ self.logger.error(msg)
279
+ raise AttributeError(msg)
262
280
 
263
- self.sim_mask_global_ = None
264
- cache_key = self._sim_mask_cache_key()
265
- if self.simulate_missing:
266
- cached_mask = (
267
- None if cache_key is None else self._sim_mask_cache.get(cache_key)
268
- )
269
- if cached_mask is not None:
270
- self.sim_mask_global_ = cached_mask.copy()
271
- else:
272
- tr = SimMissingTransformer(
273
- genotype_data=self.genotype_data,
274
- tree_parser=self.tree_parser,
275
- prop_missing=self.sim_prop,
276
- strategy=self.sim_strategy,
277
- missing_val=-9,
278
- mask_missing=True,
279
- verbose=self.verbose,
280
- **self.sim_kwargs,
281
- )
282
- tr.fit(X012.copy())
283
- self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
284
- if cache_key is not None:
285
- self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
281
+ self.ploidy = self.cfg.io.ploidy
282
+ self.is_haploid_ = self.ploidy == 1
286
283
 
287
- X_for_model = self.ground_truth_.copy()
288
- X_for_model[self.sim_mask_global_] = -1
289
- else:
290
- X_for_model = self.ground_truth_.copy()
291
-
292
- # Ploidy/classes
293
- self.is_haploid = bool(
294
- np.all(
295
- np.isin(
296
- self.genotype_data.snp_data,
297
- ["A", "C", "G", "T", "N", "-", ".", "?"],
298
- )
299
- )
300
- )
301
- self.ploidy = 1 if self.is_haploid else 2
302
- self.num_classes_ = 2 if self.is_haploid else 3
303
- self.logger.info(
304
- f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
305
- f"using {self.num_classes_} classes."
306
- )
284
+ if self.ploidy > 2:
285
+ msg = f"{self.model_name} currently supports only haploid (1) or diploid (2) data; got ploidy={self.ploidy}."
286
+ self.logger.error(msg)
287
+ raise ValueError(msg)
307
288
 
308
- if self.is_haploid:
309
- self.ground_truth_[self.ground_truth_ == 2] = 1
310
- X_for_model[X_for_model == 2] = 1
289
+ self.num_classes_ = 2 if self.is_haploid_ else 3
311
290
 
312
- n_samples, self.num_features_ = X_for_model.shape
291
+ gt_full = self.pgenc.genotypes_012.copy()
292
+ gt_full[gt_full < 0] = -1
293
+ gt_full = np.nan_to_num(gt_full, nan=-1.0)
294
+ self.ground_truth_ = gt_full.astype(np.int8)
295
+ self.num_features_ = gt_full.shape[1]
313
296
 
314
- # Model params (decoder outputs L*K logits)
315
297
  self.model_params = {
316
298
  "n_features": self.num_features_,
317
299
  "num_classes": self.num_classes_,
@@ -320,66 +302,218 @@ class ImputeVAE(BaseNNImputer):
320
302
  "activation": self.activation,
321
303
  }
322
304
 
323
- # Train/Val split
324
- indices = np.arange(n_samples)
325
- train_idx, val_idx = train_test_split(
326
- indices, test_size=self.validation_split, random_state=self.seed
305
+ # Simulate missingness ONCE on the full matrix
306
+ sim_tup = self.sim_missing_transform(self.ground_truth_)
307
+ X_for_model_full = sim_tup[0]
308
+ self.sim_mask_ = sim_tup[1]
309
+ self.orig_mask_ = sim_tup[2]
310
+
311
+ # Split indices based on clean ground truth
312
+ self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
313
+ self.ground_truth_
327
314
  )
328
- self.train_idx_, self.test_idx_ = train_idx, val_idx
329
- self.X_train_ = X_for_model[train_idx]
330
- self.X_val_ = X_for_model[val_idx]
331
- self.GT_train_full_ = self.ground_truth_[train_idx]
332
- self.GT_test_full_ = self.ground_truth_[val_idx]
333
-
334
- if self.sim_mask_global_ is not None:
335
- self.sim_mask_train_ = self.sim_mask_global_[train_idx]
336
- self.sim_mask_test_ = self.sim_mask_global_[val_idx]
337
- else:
338
- self.sim_mask_train_ = None
339
- self.sim_mask_test_ = None
340
315
 
341
- # Plotters/scorers (shared utilities)
316
+ # --- Clean (targets) per split ---
317
+ X_train_clean = self.ground_truth_[self.train_idx_].copy()
318
+ X_val_clean = self.ground_truth_[self.val_idx_].copy()
319
+ X_test_clean = self.ground_truth_[self.test_idx_].copy()
320
+
321
+ # --- Corrupted (inputs) per split (from the single simulation) ---
322
+ X_train_corrupted = X_for_model_full[self.train_idx_].copy()
323
+ X_val_corrupted = X_for_model_full[self.val_idx_].copy()
324
+ X_test_corrupted = X_for_model_full[self.test_idx_].copy()
325
+
326
+ # --- Masks per split ---
327
+ self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
328
+ self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
329
+ self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
330
+
331
+ self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
332
+ self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
333
+ self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
334
+
335
+ # Persist clean/corrupted matrices if you want them accessible later
336
+ self.X_train_clean_ = X_train_clean
337
+ self.X_val_clean_ = X_val_clean
338
+ self.X_test_clean_ = X_test_clean
339
+
340
+ self.X_train_corrupted_ = X_train_corrupted
341
+ self.X_val_corrupted_ = X_val_corrupted
342
+ self.X_test_corrupted_ = X_test_corrupted
343
+
344
+ # --- Haploid harmonization (do NOT resimulate; just recode values) ---
345
+ if self.is_haploid_:
346
+
347
+ def _haploidize(arr: np.ndarray) -> np.ndarray:
348
+ out = arr.copy()
349
+ miss = out < 0
350
+ out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
351
+ out[miss] = -1
352
+ return out
353
+
354
+ X_train_clean = _haploidize(X_train_clean)
355
+ X_val_clean = _haploidize(X_val_clean)
356
+ X_test_clean = _haploidize(X_test_clean)
357
+
358
+ X_train_corrupted = _haploidize(X_train_corrupted)
359
+ X_val_corrupted = _haploidize(X_val_corrupted)
360
+ X_test_corrupted = _haploidize(X_test_corrupted)
361
+
362
+ # Write back the persisted versions too
363
+ self.X_train_clean_ = X_train_clean
364
+ self.X_val_clean_ = X_val_clean
365
+ self.X_test_clean_ = X_test_clean
366
+
367
+ self.X_train_corrupted_ = X_train_corrupted
368
+ self.X_val_corrupted_ = X_val_corrupted
369
+ self.X_test_corrupted_ = X_test_corrupted
370
+
371
+ # Final training tensors/matrices used by the pipeline
372
+ # Convention: X_* are corrupted inputs; y_* are clean targets
373
+ self.X_train_ = self.X_train_corrupted_
374
+ self.y_train_ = self.X_train_clean_
375
+
376
+ self.X_val_ = self.X_val_corrupted_
377
+ self.y_val_ = self.X_val_clean_
378
+
379
+ self.X_test_ = self.X_test_corrupted_
380
+ self.y_test_ = self.X_test_clean_
381
+
382
+ self.X_train_ = self._one_hot_encode_012(
383
+ self.X_train_, num_classes=self.num_classes_
384
+ )
385
+ self.X_val_ = self._one_hot_encode_012(
386
+ self.X_val_, num_classes=self.num_classes_
387
+ )
388
+ self.X_test_ = self._one_hot_encode_012(
389
+ self.X_test_, num_classes=self.num_classes_
390
+ )
391
+ self.X_test_ = self.X_test_.detach().cpu().numpy()
392
+
342
393
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
394
+ self.valid_class_mask_ = self._build_valid_class_mask()
395
+
396
+ loci = getattr(self, "valid_class_mask_conflict_loci_", None)
397
+ if loci is not None and loci.size:
398
+ self._repair_ref_alt_from_iupac(loci)
399
+ self.valid_class_mask_ = self._build_valid_class_mask()
400
+
401
+ train_loader = self._get_data_loaders(
402
+ self.X_train_.detach().cpu().numpy(),
403
+ self.y_train_,
404
+ ~self.orig_mask_train_,
405
+ self.batch_size,
406
+ shuffle=True,
407
+ )
408
+
409
+ val_loader = self._get_data_loaders(
410
+ self.X_val_.detach().cpu().numpy(),
411
+ self.y_val_,
412
+ ~self.orig_mask_val_,
413
+ self.batch_size,
414
+ shuffle=False,
415
+ )
416
+
417
+ self.train_loader_ = train_loader
418
+ self.val_loader_ = val_loader
343
419
 
344
- # Optional tuning
345
420
  if self.tune:
346
- self.tune_hyperparameters()
421
+ self.tuned_params_ = self.tune_hyperparameters()
422
+ self.model_tuned_ = True
423
+ else:
424
+ self.model_tuned_ = False
425
+ self.class_weights_ = self._class_weights_from_zygosity(
426
+ self.y_train_,
427
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
428
+ inverse=self.inverse,
429
+ normalize=self.normalize,
430
+ max_ratio=self.max_ratio,
431
+ power=self.power,
432
+ )
433
+ self.tuned_params_ = {
434
+ "latent_dim": self.latent_dim,
435
+ "learning_rate": self.learning_rate,
436
+ "dropout_rate": self.dropout_rate,
437
+ "num_hidden_layers": self.num_hidden_layers,
438
+ "activation": self.activation,
439
+ "l1_penalty": self.l1_penalty,
440
+ "layer_scaling_factor": self.layer_scaling_factor,
441
+ "layer_schedule": self.layer_schedule,
442
+ }
443
+ self.tuned_params_.update(
444
+ {
445
+ "kl_beta": self.kl_beta,
446
+ "kl_beta_schedule": self.kl_beta_schedule,
447
+ "gamma": self.gamma,
448
+ "gamma_schedule": self.gamma_schedule,
449
+ "inverse": self.inverse,
450
+ "normalize": self.normalize,
451
+ "power": self.power,
452
+ }
453
+ )
454
+
455
+ self.tuned_params_["model_params"] = self.model_params
456
+
457
+ if self.class_weights_ is not None:
458
+ self.logger.info(
459
+ f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
460
+ )
347
461
 
348
- # Best params (tuned or default)
349
- self.best_params_ = getattr(self, "best_params_", self._default_best_params())
462
+ # Always start clean
463
+ self.best_params_ = copy.deepcopy(self.tuned_params_)
464
+
465
+ model_params_final = {
466
+ "n_features": self.num_features_,
467
+ "num_classes": self.num_classes_,
468
+ "latent_dim": int(self.best_params_["latent_dim"]),
469
+ "dropout_rate": float(self.best_params_["dropout_rate"]),
470
+ "activation": str(self.best_params_["activation"]),
471
+ "kl_beta": float(
472
+ self.best_params_.get("kl_beta", getattr(self, "kl_beta", 1.0))
473
+ ),
474
+ }
350
475
 
351
- # Class weights (device-aware)
352
- self.class_weights_ = self._normalize_class_weights(
353
- self._class_weights_from_zygosity(self.X_train_)
476
+ input_dim = self.num_features_ * self.num_classes_
477
+ model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
478
+ n_inputs=input_dim,
479
+ n_outputs=self.num_classes_,
480
+ n_samples=len(self.X_train_),
481
+ n_hidden=int(self.best_params_["num_hidden_layers"]),
482
+ latent_dim=int(self.best_params_["latent_dim"]),
483
+ alpha=float(self.best_params_["layer_scaling_factor"]),
484
+ schedule=str(self.best_params_["layer_schedule"]),
485
+ min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
354
486
  )
355
487
 
356
- # DataLoader
357
- train_loader = self._get_data_loader(self.X_train_)
488
+ self.best_params_["model_params"] = model_params_final
358
489
 
359
- # Build & train
360
- model = self.build_model(self.Model, self.best_params_)
490
+ # Now build the model
491
+ model = self.build_model(self.Model, self.best_params_["model_params"])
361
492
  model.apply(self.initialize_weights)
362
493
 
494
+ if self.verbose or self.debug:
495
+ self.logger.info("Using model hyperparameters:")
496
+ pm = PrettyMetrics(
497
+ self.best_params_, precision=3, title="Model Hyperparameters"
498
+ )
499
+ pm.render()
500
+
501
+ lr_final = float(self.best_params_["learning_rate"])
502
+ l1_final = float(self.best_params_["l1_penalty"])
503
+
363
504
  loss, trained_model, history = self._train_and_validate_model(
364
505
  model=model,
365
- loader=train_loader,
366
- lr=self.learning_rate,
367
- l1_penalty=self.l1_penalty,
368
- return_history=True,
369
- class_weights=self.class_weights_,
370
- X_val=self.X_val_,
506
+ lr=lr_final,
507
+ l1_penalty=l1_final,
371
508
  params=self.best_params_,
372
- prune_metric=self.tune_metric,
373
- prune_warmup_epochs=5,
374
- eval_interval=1,
375
- eval_requires_latents=False, # no latent refinement for eval
376
- eval_latent_steps=0,
377
- eval_latent_lr=0.0,
378
- eval_latent_weight_decay=0.0,
509
+ trial=None,
510
+ class_weights=self.class_weights_,
511
+ kl_beta_schedule=self.best_params_["kl_beta_schedule"],
512
+ gamma_schedule=self.best_params_["gamma_schedule"],
379
513
  )
380
514
 
381
515
  if trained_model is None:
382
- msg = "VAE training failed; no model was returned."
516
+ msg = f"{self.model_name} training failed."
383
517
  self.logger.error(msg)
384
518
  raise RuntimeError(msg)
385
519
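
Throughout the new fit() pipeline, training and evaluation masks are composed as `sim_mask & ~orig_mask`: entries the simulator deliberately hid, restricted to positions that were actually observed in the raw data, so a ground-truth label exists for scoring. A tiny NumPy illustration of that convention (the toy arrays are made up):

    import numpy as np

    orig_mask = np.array([[False, True, False],
                          [False, False, False]])  # genuinely missing in the raw data
    sim_mask = np.array([[True, True, False],
                         [False, True, False]])    # entries hidden by the simulator

    eval_mask = sim_mask & ~orig_mask              # score only entries with known ground truth
    print(eval_mask)
    # [[ True False False]
    #  [False  True False]]
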
 
@@ -388,215 +522,209 @@ class ImputeVAE(BaseNNImputer):
388
522
  self.models_dir / f"final_model_{self.model_name}.pt",
389
523
  )
390
524
 
391
- hist: dict = {"Train": history}
392
- self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
525
+ if history is None:
526
+ hist = {"Train": []}
527
+ elif isinstance(history, dict):
528
+ hist = dict(history)
529
+ else:
530
+ hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
531
+
532
+ self.best_loss_ = loss
533
+ self.model_ = trained_model
534
+ self.history_ = hist
393
535
  self.is_fit_ = True
394
536
 
395
- # Evaluate (AE-parity reporting)
396
- eval_mask = (
397
- self.sim_mask_test_
398
- if (self.simulate_missing and self.sim_mask_test_ is not None)
399
- else None
400
- )
401
537
  self._evaluate_model(
402
- self.X_val_,
403
538
  self.model_,
404
- self.best_params_,
405
- eval_mask_override=eval_mask,
539
+ X=self.X_test_,
540
+ y=self.y_test_,
541
+ eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
542
+ objective_mode=False,
406
543
  )
407
544
 
408
- self.plotter_.plot_history(self.history_)
545
+ if self.show_plots:
546
+ self.plotter_.plot_history(self.history_)
547
+
409
548
  self._save_best_params(self.best_params_)
549
+
550
+ if self.model_tuned_:
551
+ title = f"{self.model_name} Optimized Parameters"
552
+
553
+ if self.verbose or self.debug:
554
+ pm = PrettyMetrics(self.best_params_, precision=2, title=title)
555
+ pm.render()
556
+
557
+ # Save best parameters to a JSON file.
558
+ self._save_best_params(self.best_params_, objective_mode=True)
410
559
  return self
411
560
 
412
561
  def transform(self) -> np.ndarray:
413
562
  """Impute missing genotypes and return IUPAC strings.
414
563
 
415
- This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
564
+ This method performs the following steps:
565
+ 1. Validates that the model has been fitted.
566
+ 2. Uses the trained model to predict missing genotypes for the entire dataset.
567
+ 3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
568
+ 4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
569
+ 5. Checks for any remaining missing values or decoding issues, raising errors if found.
570
+ 6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
571
+ 7. Returns the imputed IUPAC genotype matrix.
416
572
 
417
573
  Returns:
418
- np.ndarray: IUPAC strings of shape (n_samples, n_loci).
574
+ np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
419
575
 
420
- Raises:
421
- NotFittedError: If called before fit().
576
+ Notes:
577
+ - ``transform()`` does not take any arguments; it operates on the data provided during initialization.
578
+ - Ensure that the model has been fitted before calling this method.
579
+ - For haploid data, genotypes encoded as '1' are treated as '2' during decoding.
580
+ - The method checks for decoding failures (i.e., resulting in 'N') and raises an error if any are found.
422
581
  """
423
582
  if not getattr(self, "is_fit_", False):
424
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
583
+ msg = "Model is not fitted. Call fit() before transform()."
584
+ self.logger.error(msg)
585
+ raise NotFittedError(msg)
425
586
 
426
587
  self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
427
588
  X_to_impute = self.ground_truth_.copy()
428
589
 
429
- pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
590
+ # 1. Predict labels (0/1/2) for the entire matrix
591
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute)
430
592
 
431
- # Fill only missing
432
- missing_mask = X_to_impute == -1
593
+ # 2. Fill ONLY originally missing values
594
+ missing_mask = X_to_impute < 0
433
595
  imputed_array = X_to_impute.copy()
434
596
  imputed_array[missing_mask] = pred_labels[missing_mask]
435
597
 
436
- # Decode to IUPAC & optionally plot
437
- imputed_genotypes = self.pgenc.decode_012(imputed_array)
438
- original_genotypes = self.pgenc.decode_012(X_to_impute)
598
+ # Sanity check: all -1s should be gone
599
+ if np.any(imputed_array < 0):
600
+ msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
601
+ self.logger.error(msg)
602
+ raise RuntimeError(msg)
439
603
 
440
- plt.rcParams.update(self.plotter_.param_dict)
441
- self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
442
- self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
604
+ # 3. Handle Haploid mapping (2->1) before decoding if needed
605
+ decode_input = imputed_array
606
+ if self.is_haploid_:
607
+ decode_input = imputed_array.copy()
608
+ decode_input[decode_input == 1] = 2
443
609
 
444
- return imputed_genotypes
610
+ # 4. Decode integers to IUPAC strings
611
+ imputed_genotypes = self.decode_012(decode_input)
445
612
 
446
- # ---------- plumbing identical to AE, naming aligned ---------- #
613
+ # 5. Check for decoding failures (N)
614
+ # Hard error: downstream pipelines expect fully imputed data.
615
+ bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
616
+ if bad_loci.size > 0:
617
+ msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
618
+ self.logger.error(msg)
619
+ self.logger.debug(
620
+ "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
621
+ )
622
+ raise RuntimeError(msg)
447
623
 
448
- def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
449
- """Create DataLoader over indices + integer targets (-1 for missing).
624
+ if self.show_plots:
625
+ original_input = X_to_impute
626
+ if self.is_haploid_:
627
+ original_input = X_to_impute.copy()
628
+ original_input[original_input == 1] = 2
450
629
 
451
- This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
630
+ original_genotypes = self.decode_012(original_input)
452
631
 
453
- Args:
454
- y (np.ndarray): 0/1/2 matrix with -1 for missing.
632
+ plt.rcParams.update(self.plotter_.param_dict)
633
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
634
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
455
635
 
456
- Returns:
457
- torch.utils.data.DataLoader: Shuffled DataLoader.
458
- """
459
- y_tensor = torch.from_numpy(y).long()
460
- indices = torch.arange(len(y), dtype=torch.long)
461
- dataset = torch.utils.data.TensorDataset(indices, y_tensor)
462
- pin_memory = self.device.type == "cuda"
463
- return torch.utils.data.DataLoader(
464
- dataset,
465
- batch_size=self.batch_size,
466
- shuffle=True,
467
- pin_memory=pin_memory,
468
- )
636
+ return imputed_genotypes
469
637
 
470
638
  def _train_and_validate_model(
471
639
  self,
472
640
  model: torch.nn.Module,
473
- loader: torch.utils.data.DataLoader,
641
+ *,
474
642
  lr: float,
475
643
  l1_penalty: float,
476
- trial: optuna.Trial | None = None,
477
- return_history: bool = False,
478
- class_weights: torch.Tensor | None = None,
479
- *,
480
- X_val: np.ndarray | None = None,
481
- params: dict | None = None,
482
- prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
483
- prune_warmup_epochs: int = 3,
484
- eval_interval: int = 1,
485
- eval_requires_latents: bool = False, # VAE: no latent eval refinement
486
- eval_latent_steps: int = 0,
487
- eval_latent_lr: float = 0.0,
488
- eval_latent_weight_decay: float = 0.0,
489
- ) -> Tuple[float, torch.nn.Module | None, list | None]:
490
- """Wrap the VAE training loop with β-anneal & Optuna pruning.
491
-
492
- This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
644
+ trial: Optional[optuna.Trial] = None,
645
+ params: Optional[dict[str, Any]] = None,
646
+ class_weights: Optional[torch.Tensor] = None,
647
+ kl_beta_schedule: bool = False,
648
+ gamma_schedule: bool = False,
649
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
650
+ """Train and validate the model.
651
+
652
+ This method orchestrates training with early stopping and optional Optuna pruning based on validation performance. It returns the best validation loss, the best model (with best weights loaded), and training history.
493
653
 
494
654
  Args:
495
655
  model (torch.nn.Module): VAE model.
496
- loader (torch.utils.data.DataLoader): Training data loader.
497
656
  lr (float): Learning rate.
498
657
  l1_penalty (float): L1 regularization coefficient.
499
- trial (optuna.Trial | None): Optuna trial for pruning.
500
- return_history (bool): If True, return training history.
501
- class_weights (torch.Tensor | None): CE class weights on device.
502
- X_val (np.ndarray | None): Validation data for pruning eval.
503
- params (dict | None): Current hyperparameters (for logging).
504
- prune_metric (str | None): Metric for pruning decisions.
505
- prune_warmup_epochs (int): Epochs to skip before pruning.
506
- eval_interval (int): Epochs between validation evaluations.
507
- eval_requires_latents (bool): If True, refine latents during eval.
508
- eval_latent_steps (int): Latent refinement steps if needed.
509
- eval_latent_lr (float): Latent refinement learning rate.
510
- eval_latent_weight_decay (float): Latent refinement L2 penalty.
658
+ trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
659
+ params (Optional[dict[str, float | int | str | dict[str, Any]]]): Model params for evaluation.
660
+ class_weights (Optional[torch.Tensor]): Class weights for loss computation.
661
+ kl_beta_schedule (bool): Whether to use KL beta scheduling.
662
+ gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
511
663
 
512
664
  Returns:
513
- Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
665
+ tuple[float, torch.nn.Module, dict[str, list[float]]]:
666
+ Best validation loss, best model, and training history.
514
667
  """
515
- if class_weights is None:
516
- msg = "Must provide class_weights."
517
- self.logger.error(msg)
518
- raise TypeError(msg)
668
+ max_epochs = self.epochs
669
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
519
670
 
520
- max_epochs = (
521
- self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
671
+ scheduler = _make_warmup_cosine_scheduler(
672
+ optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
522
673
  )
523
674
 
524
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
525
- scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
526
-
527
675
  best_loss, best_model, hist = self._execute_training_loop(
528
- loader=loader,
529
676
  optimizer=optimizer,
530
677
  scheduler=scheduler,
531
678
  model=model,
532
679
  l1_penalty=l1_penalty,
533
680
  trial=trial,
534
- return_history=return_history,
535
- class_weights=class_weights,
536
- X_val=X_val,
537
681
  params=params,
538
- prune_metric=prune_metric,
539
- prune_warmup_epochs=prune_warmup_epochs,
540
- eval_interval=eval_interval,
541
- eval_requires_latents=eval_requires_latents,
542
- eval_latent_steps=eval_latent_steps,
543
- eval_latent_lr=eval_latent_lr,
544
- eval_latent_weight_decay=eval_latent_weight_decay,
682
+ class_weights=class_weights,
683
+ kl_beta_schedule=kl_beta_schedule,
684
+ gamma_schedule=gamma_schedule,
545
685
  )
546
- if return_history:
547
- return best_loss, best_model, hist
548
-
549
- return best_loss, best_model, None
686
+ return best_loss, best_model, hist
550
687
 
551
688
  def _execute_training_loop(
552
689
  self,
553
- loader: torch.utils.data.DataLoader,
690
+ *,
554
691
  optimizer: torch.optim.Optimizer,
555
- scheduler: CosineAnnealingLR,
692
+ scheduler: (
693
+ torch.optim.lr_scheduler.CosineAnnealingLR
694
+ | torch.optim.lr_scheduler.SequentialLR
695
+ ),
556
696
  model: torch.nn.Module,
557
697
  l1_penalty: float,
558
- trial: optuna.Trial | None,
559
- return_history: bool,
560
- class_weights: torch.Tensor,
561
- *,
562
- X_val: np.ndarray | None = None,
563
- params: dict | None = None,
564
- prune_metric: str | None = None,
565
- prune_warmup_epochs: int = 3,
566
- eval_interval: int = 1,
567
- eval_requires_latents: bool = False,
568
- eval_latent_steps: int = 0,
569
- eval_latent_lr: float = 0.0,
570
- eval_latent_weight_decay: float = 0.0,
571
- ) -> Tuple[float, torch.nn.Module, list]:
572
- """Train VAE with stable focal CE + KL(β) anneal and numeric guards.
573
-
574
- This method implements the core training loop for the VAE model, incorporating a focal cross-entropy loss for reconstruction and a KL divergence term with an annealed weight (beta). It includes mechanisms for early stopping based on validation performance, learning rate scheduling, and optional pruning of unpromising trials when using Optuna for hyperparameter optimization. The method ensures numerical stability throughout the training process.
698
+ trial: Optional[optuna.Trial] = None,
699
+ params: Optional[dict[str, Any]] = None,
700
+ class_weights: Optional[torch.Tensor] = None,
701
+ kl_beta_schedule: bool = False,
702
+ gamma_schedule: bool = False,
703
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
704
+ """Train the model with focal CE reconstruction + KL divergence.
705
+
706
+ This method performs the training loop for the model using the provided optimizer and learning rate scheduler. It supports early stopping based on validation loss and integrates with Optuna for hyperparameter tuning. The method returns the best validation loss, the best model state, and the training history.
575
707
 
576
708
  Args:
577
- loader (torch.utils.data.DataLoader): Training data loader.
578
709
  optimizer (torch.optim.Optimizer): Optimizer.
579
- scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
710
+ scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): Learning rate scheduler.
580
711
  model (torch.nn.Module): VAE model.
581
712
  l1_penalty (float): L1 regularization coefficient.
582
- trial (optuna.Trial | None): Optuna trial for pruning.
583
- return_history (bool): If True, return training history.
584
- class_weights (torch.Tensor): CE class weights on device.
585
- X_val (np.ndarray | None): Validation data for pruning eval.
586
- params (dict | None): Current hyperparameters (for logging).
587
- prune_metric (str | None): Metric for pruning decisions.
588
- prune_warmup_epochs (int): Epochs to skip before pruning.
589
- eval_interval (int): Epochs between validation evaluations.
590
- eval_requires_latents (bool): If True, refine latents during eval.
591
- eval_latent_steps (int): Latent refinement steps if needed.
592
- eval_latent_lr (float): Latent refinement learning rate.
593
- eval_latent_weight_decay (float): Latent refinement L2 penalty.
713
+ trial (Optional[optuna.Trial]): Optuna trial for pruning (optional).
714
+ params (Optional[dict[str, Any]]): Model params for evaluation.
715
+ class_weights (Optional[torch.Tensor]): Class weights for loss computation.
716
+ kl_beta_schedule (bool): Whether to use KL beta scheduling.
717
+ gamma_schedule (bool): Whether to use gamma scheduling for focal CE loss.
594
718
 
595
719
  Returns:
596
- Tuple[float, torch.nn.Module, list]: Best loss, best model, and training history.
720
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, training history.
721
+
722
+ Notes:
723
+ - Use CE with class weights during training/validation.
724
+ - Inference de-bias happens in _predict (separate).
725
+ - If `class_weights` is None, this will fall back to self.class_weights_ if present.
597
726
  """
598
- best_model = None
599
- history: list[float] = []
727
+ history: dict[str, list[float]] = defaultdict(list)
600
728
 
601
729
  early_stopping = EarlyStopping(
602
730
  patience=self.early_stop_gen,
@@ -606,188 +734,174 @@ class ImputeVAE(BaseNNImputer):
606
734
  debug=self.debug,
607
735
  )
608
736
 
609
- # ---- scalarize schedule endpoints up front ----
610
- kl = self.kl_beta_final
611
- if isinstance(kl, (list, tuple)):
612
- if len(kl) == 0:
613
- msg = "kl_beta_final list is empty."
614
- self.logger.error(msg)
615
- raise ValueError(msg)
616
- kl = kl[0]
617
- beta_final = float(kl)
618
-
619
- gamma_val = self.gamma
620
- if isinstance(gamma_val, (list, tuple)):
621
- if len(gamma_val) == 0:
622
- raise ValueError("gamma list is empty.")
623
- gamma_val = gamma_val[0]
624
- gamma_final = float(gamma_val)
625
-
626
- gamma_warm, gamma_ramp = 50, 100
627
- beta_warm, beta_ramp = int(self.kl_warmup), int(self.kl_ramp)
628
-
629
- # Optional LR warmup
630
- warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
631
- base_lr = float(optimizer.param_groups[0]["lr"])
632
- min_lr = base_lr * 0.1
633
-
634
- max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
635
-
636
- for epoch in range(max_epochs):
637
- # focal γ schedule
638
- if epoch < gamma_warm:
639
- model.gamma = 0.0 # type: ignore[attr-defined]
640
- elif epoch < gamma_warm + gamma_ramp:
641
- model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore[attr-defined]
642
- else:
643
- model.gamma = gamma_final # type: ignore[attr-defined]
737
+ # KL schedule
738
+ kl_beta_target, kl_warm, kl_ramp = self._anneal_config(
739
+ params, "kl_beta", default=self.kl_beta, max_epochs=self.epochs
740
+ )
741
+
742
+ kl_beta_target = float(kl_beta_target)
743
+
744
+ gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
745
+ params, "gamma", default=self.gamma, max_epochs=self.epochs
746
+ )
747
+
748
+ cw = class_weights
749
+ if cw is not None and cw.device != self.device:
750
+ cw = cw.to(self.device)
751
+
752
+ ce_criterion = FocalCELoss(
753
+ alpha=cw, gamma=gamma_target, reduction="mean", ignore_index=-1
754
+ )
644
755
 
645
- # KL β schedule (float throughout + ramp guard)
646
- if epoch < beta_warm:
647
- model.beta = 0.0 # type: ignore[attr-defined]
648
- elif beta_ramp > 0 and epoch < beta_warm + beta_ramp:
649
- model.beta = beta_final * ((epoch - beta_warm) / beta_ramp) # type: ignore[attr-defined]
756
+ for epoch in range(int(self.epochs)):
757
+ if kl_beta_schedule:
758
+ kl_beta_current = self._update_anneal_schedule(
759
+ kl_beta_target,
760
+ warm=kl_warm,
761
+ ramp=kl_ramp,
762
+ epoch=epoch,
763
+ init_val=0.0,
764
+ )
650
765
  else:
651
- model.beta = beta_final # type: ignore[attr-defined]
652
- # LR warmup
653
- if epoch < warmup_epochs:
654
- scale = float(epoch + 1) / warmup_epochs
655
- for g in optimizer.param_groups:
656
- g["lr"] = min_lr + (base_lr - min_lr) * scale
766
+ kl_beta_current = kl_beta_target
767
+
768
+ if gamma_schedule:
769
+ gamma_current = self._update_anneal_schedule(
770
+ gamma_target,
771
+ warm=gamma_warm,
772
+ ramp=gamma_ramp,
773
+ epoch=epoch,
774
+ init_val=0.0,
775
+ )
776
+ ce_criterion.gamma = gamma_current
657
777
 
658
778
  train_loss = self._train_step(
659
- loader=loader,
779
+ loader=self.train_loader_,
660
780
  optimizer=optimizer,
661
781
  model=model,
782
+ ce_criterion=ce_criterion,
662
783
  l1_penalty=l1_penalty,
663
- class_weights=class_weights,
784
+ kl_beta=kl_beta_current,
664
785
  )
665
786
 
666
787
  if not np.isfinite(train_loss):
667
- if trial:
668
- raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
669
- # shrink LR and continue
670
- for g in optimizer.param_groups:
671
- g["lr"] *= 0.5
672
- continue
788
+ if trial is not None:
789
+ msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
790
+ self.logger.warning(msg)
791
+ raise optuna.exceptions.TrialPruned(msg)
673
792
 
674
- if scheduler is not None:
675
- scheduler.step()
793
+ msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
794
+ self.logger.error(msg)
795
+ raise RuntimeError(msg)
796
+
797
+ val_loss = self._val_step(
798
+ loader=self.val_loader_,
799
+ model=model,
800
+ ce_criterion=ce_criterion,
801
+ l1_penalty=l1_penalty,
802
+ kl_beta=kl_beta_current,
803
+ )
676
804
 
677
- if return_history:
678
- history.append(train_loss)
805
+ scheduler.step()
806
+ history["Train"].append(float(train_loss))
807
+ history["Val"].append(float(val_loss))
679
808
 
680
- early_stopping(train_loss, model)
809
+ early_stopping(val_loss, model)
681
810
  if early_stopping.early_stop:
682
- self.logger.info(f"Early stopping at epoch {epoch + 1}.")
811
+ self.logger.debug(
812
+ f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
813
+ )
683
814
  break
684
815
 
685
- # Optional Optuna pruning on a validation metric
686
- if (
687
- trial is not None
688
- and X_val is not None
689
- and ((epoch + 1) % eval_interval == 0)
690
- ):
691
- metric_key = prune_metric or getattr(self, "tune_metric", "f1")
692
- mask_override = None
693
- if (
694
- self.simulate_missing
695
- and getattr(self, "sim_mask_test_", None) is not None
696
- and getattr(self, "X_val_", None) is not None
697
- and X_val.shape == self.X_val_.shape
698
- ):
699
- mask_override = self.sim_mask_test_
700
- metric_val = self._eval_for_pruning(
816
+ if trial is not None:
817
+ metric_vals = self._evaluate_model(
701
818
  model=model,
702
- X_val=X_val,
703
- params=params or getattr(self, "best_params_", {}),
704
- metric=metric_key,
819
+ X=self.X_val_corrupted_,
820
+ y=self.y_val_,
821
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
705
822
  objective_mode=True,
706
- do_latent_infer=False, # VAE uses encoder; no latent refine
707
- latent_steps=0,
708
- latent_lr=0.0,
709
- latent_weight_decay=0.0,
710
- latent_seed=self.seed, # type: ignore
711
- _latent_cache=None,
712
- _latent_cache_key=None,
713
- eval_mask_override=mask_override,
714
823
  )
715
- trial.report(metric_val, step=epoch + 1)
716
- if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
824
+ trial.report(metric_vals[self.tune_metric], step=epoch + 1)
825
+ if trial.should_prune():
717
826
  raise optuna.exceptions.TrialPruned(
718
- f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
827
+ f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
719
828
  )
720
829
 
721
- best_loss = early_stopping.best_score
722
- best_model = copy.deepcopy(early_stopping.best_model)
723
- if best_model is None:
724
- best_model = copy.deepcopy(model)
725
- return best_loss, best_model, history
830
+ best_loss = float(early_stopping.best_score)
831
+
832
+ if early_stopping.best_state_dict is not None:
833
+ model.load_state_dict(early_stopping.best_state_dict)
834
+
835
+ return best_loss, model, dict(history)
726
836
 
727
837
  def _train_step(
728
838
  self,
729
839
  loader: torch.utils.data.DataLoader,
730
840
  optimizer: torch.optim.Optimizer,
731
841
  model: torch.nn.Module,
842
+ ce_criterion: torch.nn.Module,
843
+ *,
732
844
  l1_penalty: float,
733
- class_weights: torch.Tensor,
845
+ kl_beta: torch.Tensor | float,
734
846
  ) -> float:
735
- """One epoch: inputs VAE forward focal CE + β·KL with guards.
736
-
737
- This method performs a single training epoch for the VAE model. It processes batches of data, computes the reconstruction and KL divergence losses, applies L1 regularization if specified, and updates the model parameters. The method includes safeguards against non-finite values in the model outputs and gradients to ensure stable training.
847
+ """Single epoch train step across batches (focal CE + KL + optional L1).
738
848
 
739
849
  Args:
740
850
  loader (torch.utils.data.DataLoader): Training data loader.
741
851
  optimizer (torch.optim.Optimizer): Optimizer.
742
852
  model (torch.nn.Module): VAE model.
853
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
743
854
  l1_penalty (float): L1 regularization coefficient.
744
- class_weights (torch.Tensor): CE class weights on device.
855
+ kl_beta (torch.Tensor | float): KL divergence weight.
745
856
 
746
857
  Returns:
747
- float: Average training loss for the epoch.
858
+ float: Average training loss.
859
+
748
860
  """
749
861
  model.train()
750
- running, used = 0.0, 0
862
+ running = 0.0
863
+ num_batches = 0
864
+
865
+ nF_model = self.num_features_
866
+ nC_model = self.num_classes_
751
867
  l1_params = tuple(p for p in model.parameters() if p.requires_grad)
752
- if class_weights is not None and class_weights.device != self.device:
753
- class_weights = class_weights.to(self.device)
754
868
 
755
- for _, y_batch in loader:
869
+ for X_batch, y_batch, m_batch in loader:
756
870
  optimizer.zero_grad(set_to_none=True)
871
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
872
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
873
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
757
874
 
758
- # targets: (B, L) int in {0,1,2,-1}
759
- y_int = y_batch.to(self.device, non_blocking=True).long()
875
+ raw = model(X_batch)
876
+ logits0 = raw[0]
760
877
 
761
- # inputs: one-hot with zeros for missing
762
- x_ohe = self._one_hot_encode_012(y_int) # (B, L, K)
878
+ expected = X_batch.shape[0] * nF_model * nC_model
879
+ if logits0.numel() != expected:
880
+ msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
881
+ self.logger.error(msg)
882
+ raise ValueError(msg)
763
883
 
764
- # Forward. Expect model to return recon_logits, mu, logvar, ...
765
- out = model(x_ohe)
766
- if isinstance(out, (list, tuple)):
767
- recon_logits, mu, logvar = out[0], out[1], out[2]
768
- else:
769
- recon_logits, mu, logvar = out["recon_logits"], out["mu"], out["logvar"]
770
-
771
- # Upstream guard
772
- if (
773
- not torch.isfinite(recon_logits).all()
774
- or not torch.isfinite(mu).all()
775
- or not torch.isfinite(logvar).all()
776
- ):
884
+ logits_masked = logits0.view(-1, nC_model)
885
+ logits_masked = logits_masked[m_batch.view(-1)]
886
+
887
+ targets_masked = y_batch.view(-1)
888
+ targets_masked = targets_masked[m_batch.view(-1)]
889
+
890
+ mask_flat = m_batch.view(-1)
891
+ if not bool(mask_flat.any()):
777
892
  continue
778
893
 
779
- gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
780
- beta = float(getattr(model, "beta", getattr(self, "kl_beta_final", 0.0)))
781
- gamma = max(0.0, min(gamma, 10.0))
894
+ # average number of masked loci per sample (scalar)
895
+ recon_scale = (mask_flat.sum().float() / float(X_batch.shape[0])).detach()
782
896
 
783
897
  loss = compute_vae_loss(
784
- recon_logits=recon_logits,
785
- targets=y_int,
786
- mu=mu,
787
- logvar=logvar,
788
- class_weights=class_weights,
789
- gamma=gamma,
790
- beta=beta,
898
+ ce_criterion,
899
+ logits_masked,
900
+ targets_masked,
901
+ mu=raw[1],
902
+ logvar=raw[2],
903
+ kl_beta=kl_beta,
904
+ recon_scale=recon_scale,
791
905
  )
792
906
 
793
907
  if l1_penalty > 0:
@@ -796,171 +910,279 @@ class ImputeVAE(BaseNNImputer):
796
910
  l1 = l1 + p.abs().sum()
797
911
  loss = loss + l1_penalty * l1
798
912
 
799
- if not torch.isfinite(loss):
800
- continue
801
-
802
913
  loss.backward()
803
914
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
804
-
805
- # skip update if any grad is non-finite
806
- bad = any(
807
- p.grad is not None and not torch.isfinite(p.grad).all()
808
- for p in model.parameters()
809
- )
810
- if bad:
811
- optimizer.zero_grad(set_to_none=True)
812
- continue
813
-
814
915
  optimizer.step()
815
916
 
816
917
  running += float(loss.detach().item())
817
- used += 1
918
+ num_batches += 1
919
+
920
+ return float("inf") if num_batches == 0 else running / num_batches
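The batch loop above flattens the (B, L, K) logits to (B·L, K), keeps only the masked positions, and hands the masked logits/targets plus mu/logvar to compute_vae_loss. Below is a self-contained sketch of the same shape handling, using a plain cross-entropy reconstruction term and the analytic Gaussian KL; the real compute_vae_loss (focal CE, recon_scale weighting) is more involved, so treat this only as an illustration of the masking arithmetic:

    import torch
    import torch.nn.functional as F

    def masked_vae_loss_sketch(logits, targets, mask, mu, logvar, kl_beta=1.0):
        # logits: (B, L, K); targets: (B, L) ints in {0,1,2}; mask: (B, L) bool
        B, L, K = logits.shape
        flat = mask.reshape(-1)
        logits_masked = logits.reshape(-1, K)[flat]
        targets_masked = targets.reshape(-1)[flat]
        recon = F.cross_entropy(logits_masked, targets_masked)  # mean over masked entries
        kl = -0.5 * torch.mean(
            torch.sum(1.0 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        )
        return recon + kl_beta * kl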
921
+
922
+ def _val_step(
923
+ self,
924
+ loader: torch.utils.data.DataLoader,
925
+ model: torch.nn.Module,
926
+ ce_criterion: torch.nn.Module,
927
+ *,
928
+ l1_penalty: float,
929
+ kl_beta: torch.Tensor | float = 1.0,
930
+ ) -> float:
931
+ """Validation step for a single epoch (focal CE + KL + optional L1).
932
+
933
+ Args:
934
+ loader (torch.utils.data.DataLoader): Validation data loader.
935
+ model (torch.nn.Module): VAE model.
936
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
937
+ l1_penalty (float): L1 regularization coefficient.
938
+ kl_beta (torch.Tensor | float): KL divergence weight.
939
+
940
+ Returns:
941
+ float: Average validation loss.
942
+ """
943
+ model.eval()
944
+ running = 0.0
945
+ num_batches = 0
946
+
947
+ nF_model = self.num_features_
948
+ nC_model = self.num_classes_
949
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
950
+
951
+ with torch.no_grad():
952
+ for X_batch, y_batch, m_batch in loader:
953
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
954
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
955
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
956
+
957
+ raw = model(X_batch)
958
+ logits0 = raw[0]
959
+
960
+ expected = X_batch.shape[0] * nF_model * nC_model
961
+ if logits0.numel() != expected:
962
+ msg = f"VAE logits size mismatch: got {logits0.numel()}, expected {expected}"
963
+ self.logger.error(msg)
964
+ raise ValueError(msg)
965
+
966
+ logits_masked = logits0.view(-1, nC_model)
967
+ logits_masked = logits_masked[m_batch.view(-1)]
968
+
969
+ targets_masked = y_batch.view(-1)
970
+ targets_masked = targets_masked[m_batch.view(-1)]
971
+
972
+ mask_flat = m_batch.view(-1)
973
+
974
+ if not bool(mask_flat.any()):
975
+ continue
976
+
977
+ # average number of masked loci per sample (scalar)
978
+ recon_scale = (
979
+ mask_flat.sum().float() / float(X_batch.shape[0])
980
+ ).detach()
981
+
982
+ loss = compute_vae_loss(
983
+ ce_criterion,
984
+ logits_masked,
985
+ targets_masked,
986
+ mu=raw[1],
987
+ logvar=raw[2],
988
+ kl_beta=kl_beta,
989
+ recon_scale=recon_scale,
990
+ )
991
+
992
+ if l1_penalty > 0:
993
+ l1 = torch.zeros((), device=self.device)
994
+ for p in l1_params:
995
+ l1 = l1 + p.abs().sum()
996
+ loss = loss + l1_penalty * l1
997
+
998
+ if not torch.isfinite(loss):
999
+ continue
818
1000
 
819
- return (running / used) if used > 0 else float("inf")
1001
+ running += float(loss.item())
1002
+ num_batches += 1
1003
+
1004
+ return float("inf") if num_batches == 0 else running / num_batches
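Both the training and validation steps pass recon_scale, the average number of masked loci per sample, into compute_vae_loss; a plausible reading is that it keeps the reconstruction term commensurate with the per-sample KL term, though the exact weighting lives inside compute_vae_loss and is not shown in this diff. The arithmetic itself is simple:

    import torch

    # 4 samples, 10 loci each, 12 masked positions in total -> 3.0 masked loci/sample.
    mask = torch.zeros(4, 10, dtype=torch.bool)
    mask[0, :5] = True  # 5 masked loci in sample 0
    mask[1, :4] = True  # 4 in sample 1
    mask[2, :3] = True  # 3 in sample 2
    recon_scale = mask.reshape(-1).sum().float() / float(mask.shape[0])
    assert recon_scale.item() == 3.0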
820
1005
 
821
1006
  def _predict(
822
1007
  self,
823
1008
  model: torch.nn.Module,
824
1009
  X: np.ndarray | torch.Tensor,
1010
+ *,
825
1011
  return_proba: bool = False,
826
- ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
827
- """Predict 0/1/2 labels (and probabilities) from masked inputs.
1012
+ ) -> tuple[np.ndarray, np.ndarray | None]:
1013
+ """Predict categorical genotype labels (argmax over reconstruction logits).
828
1014
 
829
- This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
1015
+ This method uses the trained model to predict genotype labels for the provided input data. It handles both 0/1/2 encoded matrices and one-hot encoded matrices, converting them as necessary for model input. The method returns the predicted labels and, optionally, the predicted probabilities.
830
1016
 
831
1017
  Args:
832
1018
  model (torch.nn.Module): Trained model.
833
- X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
834
- return_proba (bool): If True, also return probabilities.
1019
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
1020
+ return_proba (bool): If True, return probabilities.
835
1021
 
836
1022
  Returns:
837
- Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
1023
+ tuple[np.ndarray, np.ndarray | None]: (labels, probas|None).
838
1024
  """
839
1025
  if model is None:
840
- msg = "Model is not trained. Call fit() before predict()."
1026
+ msg = (
1027
+ "Model passed to predict() is not trained. "
1028
+ "Call fit() before predict()."
1029
+ )
841
1030
  self.logger.error(msg)
842
1031
  raise NotFittedError(msg)
843
1032
 
844
1033
  model.eval()
1034
+
1035
+ nF = self.num_features_
1036
+ nC = self.num_classes_
1037
+
1038
+ if isinstance(X, torch.Tensor):
1039
+ X_tensor = X
1040
+ else:
1041
+ X_tensor = torch.from_numpy(X)
1042
+ X_tensor = X_tensor.float()
1043
+
1044
+ if X_tensor.device != self.device:
1045
+ X_tensor = X_tensor.to(self.device)
1046
+
1047
+ if X_tensor.dim() == 2:
1048
+ # 0/1/2 matrix -> one-hot for model input
1049
+ X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
1050
+ X_tensor = X_tensor.float()
1051
+
1052
+ if X_tensor.device != self.device:
1053
+ X_tensor = X_tensor.to(self.device)
1054
+
1055
+ elif X_tensor.dim() != 3:
1056
+ msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
1057
+ self.logger.error(msg)
1058
+ raise ValueError(msg)
1059
+
1060
+ if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
1061
+ msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
1062
+ self.logger.error(msg)
1063
+ raise ValueError(msg)
1064
+
1065
+ X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
1066
+
845
1067
  with torch.no_grad():
846
- X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
847
- X_tensor = X_tensor.to(self.device).long()
848
- x_ohe = self._one_hot_encode_012(X_tensor)
849
- outputs = model(x_ohe) # first element must be recon logits
850
- logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
1068
+ raw = model(X_tensor)
1069
+ logits = raw[0].view(-1, nF, nC)
851
1070
  probas = torch.softmax(logits, dim=-1)
852
1071
  labels = torch.argmax(probas, dim=-1)
853
1072
 
854
1073
  if return_proba:
855
1074
  return labels.cpu().numpy(), probas.cpu().numpy()
856
-
857
- return labels.cpu().numpy()
1075
+ return labels.cpu().numpy(), None
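When _predict receives a 2D 0/1/2 matrix it one-hot encodes it and flattens to (B, nF * nC) before the forward pass. The helper _one_hot_encode_012 is not part of this diff, so the sketch below only illustrates the encoding the surrounding code implies (missing sites become all-zero rows); the actual helper may differ:

    import torch
    import torch.nn.functional as F

    def one_hot_012_sketch(x: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
        # x: (B, L) with values in {-1, 0, 1, 2}; -1 marks missing genotypes.
        missing = x < 0
        ohe = F.one_hot(x.clamp(min=0).long(), num_classes=num_classes).float()
        ohe[missing] = 0.0  # missing sites carry no class information
        return ohe

    x = torch.tensor([[0, 1, -1], [2, -1, 0]])
    ohe = one_hot_012_sketch(x)            # (2, 3, 3)
    flat = ohe.reshape(ohe.shape[0], -1)   # (2, 9), matching the (B, nF * nC) model input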
858
1076
 
859
1077
  def _evaluate_model(
860
1078
  self,
861
- X_val: np.ndarray,
862
1079
  model: torch.nn.Module,
863
- params: dict,
864
- objective_mode: bool = False,
865
- latent_vectors_val: np.ndarray | None = None,
1080
+ X: np.ndarray | torch.Tensor,
1081
+ y: np.ndarray,
1082
+ eval_mask: np.ndarray,
866
1083
  *,
867
- eval_mask_override: np.ndarray | None = None,
1084
+ objective_mode: bool = False,
868
1085
  ) -> Dict[str, float]:
869
- """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
1086
+ """Evaluate model performance on masked genotypes.
870
1087
 
871
- This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
1088
+ This method evaluates the performance of the trained model on a given dataset using a specified evaluation mask. It computes various classification metrics based on the predicted labels and probabilities, comparing them to the ground truth labels. The method returns a dictionary of evaluation metrics.
872
1089
 
873
1090
  Args:
874
- X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
875
1091
  model (torch.nn.Module): Trained model.
876
- params (dict): Current hyperparameters (for logging).
877
- objective_mode (bool): If True, minimize logging for Optuna.
878
- latent_vectors_val (np.ndarray | None): Not used by VAE.
879
- eval_mask_override (np.ndarray | None): Optional mask to override default eval mask.
1092
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing, or one-hot encoded (B, L, K).
1093
+ y (np.ndarray): Ground truth 0/1/2 matrix with -1 for missing.
1094
+ eval_mask (np.ndarray): Boolean mask indicating which genotypes to evaluate.
1095
+ objective_mode (bool): If True, suppresses verbose output.
880
1096
 
881
1097
  Returns:
882
- Dict[str, float]: Computed metrics.
883
-
884
- Raises:
885
- NotFittedError: If called before fit().
1098
+ Dict[str, float]: Evaluation metrics.
886
1099
  """
887
- pred_labels, pred_probas = self._predict(
888
- model=model, X=X_val, return_proba=True
889
- )
1100
+ if model is None:
1101
+ msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
1102
+ self.logger.error(msg)
1103
+ raise NotFittedError(msg)
890
1104
 
891
- finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
892
-
893
- # FIX 1: Match rows (shape[0]) only to allow feature subsets (tune_fast)
894
- if (
895
- hasattr(self, "X_val_")
896
- and getattr(self, "X_val_", None) is not None
897
- and X_val.shape[0] == self.X_val_.shape[0]
898
- ):
899
- GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
900
- elif (
901
- hasattr(self, "X_train_")
902
- and getattr(self, "X_train_", None) is not None
903
- and X_val.shape[0] == self.X_train_.shape[0]
904
- ):
905
- GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
906
- else:
907
- GT_ref = self.ground_truth_
1105
+ pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
908
1106
 
909
- # FIX 2: Handle Feature Mismatch
910
- # If GT has more columns than X_val, slice it to match.
911
- if GT_ref.shape[1] > X_val.shape[1]:
912
- GT_ref = GT_ref[:, : X_val.shape[1]]
1107
+ if pred_probas is None:
1108
+ msg = "Predicted probabilities are None in _evaluate_model()."
1109
+ self.logger.error(msg)
1110
+ raise ValueError(msg)
913
1111
 
914
- # Fallback safeguard
915
- if GT_ref.shape != X_val.shape:
916
- GT_ref = X_val
1112
+ y_true_flat = y[eval_mask].astype(np.int8, copy=False)
1113
+ y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
1114
+ y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
1115
+
1116
+ valid = y_true_flat >= 0
1117
+ y_true_flat = y_true_flat[valid]
1118
+ y_pred_flat = y_pred_flat[valid]
1119
+ y_proba_flat = y_proba_flat[valid]
1120
+
1121
+ if y_true_flat.size == 0:
1122
+ return {self.tune_metric: 0.0}
917
1123
 
918
- # FIX 3: Allow override mask to be sliced if it's too wide
919
- if eval_mask_override is not None:
920
- if eval_mask_override.shape[0] != X_val.shape[0]:
1124
+ # --- Hard assertions on probability shape ---
1125
+ if y_proba_flat.ndim != 2:
1126
+ msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got shape {y_proba_flat.shape}."
1127
+ self.logger.error(msg)
1128
+ raise ValueError(msg)
1129
+
1130
+ K = int(y_proba_flat.shape[1])
1131
+ if self.is_haploid_:
1132
+ if K not in (2, 3):
1133
+ msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
1134
+ self.logger.error(msg)
1135
+ raise ValueError(msg)
1136
+ else:
1137
+ if K != 3:
1138
+ msg = f"Diploid evaluation expects 3 classes; got {K}."
1139
+ self.logger.error(msg)
1140
+ raise ValueError(msg)
1141
+
1142
+ if not self.is_haploid_:
1143
+ if np.any((y_true_flat < 0) | (y_true_flat > 2)):
921
1144
  msg = (
922
- f"eval_mask_override rows {eval_mask_override.shape[0]} "
923
- f"does not match X_val rows {X_val.shape[0]}"
1145
+ "Diploid y_true_flat contains values outside {0,1,2} after masking."
924
1146
  )
925
1147
  self.logger.error(msg)
926
1148
  raise ValueError(msg)
927
1149
 
928
- if eval_mask_override.shape[1] > X_val.shape[1]:
929
- eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
1150
+ # --- Harmonize for haploid vs diploid ---
1151
+ if self.is_haploid_:
1152
+ # Binary scoring: REF=0, ALT=1 (treat any non-zero as ALT)
1153
+ y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
1154
+ y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
1155
+
1156
+ K = y_proba_flat.shape[1]
1157
+ if K == 2:
1158
+ pass
1159
+ elif K == 3:
1160
+ proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
1161
+ proba_2[:, 0] = y_proba_flat[:, 0]
1162
+ proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
1163
+ y_proba_flat = proba_2
930
1164
  else:
931
- eval_mask = eval_mask_override.astype(bool)
932
- else:
933
- eval_mask = X_val != -1
934
-
935
- eval_mask = eval_mask & finite_mask & (GT_ref != -1)
936
-
937
- y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
938
- y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
939
- y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
1165
+ msg = f"Haploid evaluation expects 2 or 3 prob columns; got {K}"
1166
+ self.logger.error(msg)
1167
+ raise ValueError(msg)
940
1168
 
941
- if y_true_flat.size == 0:
942
- return {self.tune_metric: 0.0}
1169
+ labels_for_scoring = [0, 1]
1170
+ target_names = ["REF", "ALT"]
1171
+ else:
1172
+ if y_proba_flat.shape[1] != 3:
1173
+ msg = f"Diploid evaluation expects 3 prob columns; got {y_proba_flat.shape[1]}"
1174
+ self.logger.error(msg)
1175
+ raise ValueError(msg)
1176
+ labels_for_scoring = [0, 1, 2]
1177
+ target_names = ["REF", "HET", "ALT"]
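For haploid data the branch above folds 3-class probabilities (REF, HET, ALT) into two classes by summing the non-REF mass; the renormalization that follows then restores a valid simplex. A short numeric example of that collapse:

    import numpy as np

    y_proba = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.3, 0.6]])
    proba_2 = np.empty((y_proba.shape[0], 2), dtype=y_proba.dtype)
    proba_2[:, 0] = y_proba[:, 0]                  # REF
    proba_2[:, 1] = y_proba[:, 1] + y_proba[:, 2]  # HET + ALT -> ALT
    proba_2 = np.clip(proba_2, 0.0, 1.0)
    proba_2 /= proba_2.sum(axis=1, keepdims=True)
    # proba_2 -> [[0.7, 0.3], [0.1, 0.9]]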
943
1178
 
944
- # ensure valid probability simplex after masking
1179
+ # Ensure valid probability simplex after masking/collapsing
945
1180
  y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
946
1181
  row_sums = y_proba_flat.sum(axis=1, keepdims=True)
947
- row_sums[row_sums == 0] = 1.0
1182
+ row_sums[row_sums == 0.0] = 1.0
948
1183
  y_proba_flat = y_proba_flat / row_sums
949
1184
 
950
- labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
951
- target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
952
-
953
- if self.is_haploid:
954
- y_true_flat = y_true_flat.copy()
955
- y_pred_flat = y_pred_flat.copy()
956
- y_true_flat[y_true_flat == 2] = 1
957
- y_pred_flat[y_pred_flat == 2] = 1
958
- proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
959
- proba_2[:, 0] = y_proba_flat[:, 0]
960
- proba_2[:, 1] = y_proba_flat[:, 2]
961
- y_proba_flat = proba_2
962
-
963
- y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
1185
+ y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
964
1186
 
965
1187
  metrics = self.scorers_.evaluate(
966
1188
  y_true_flat,
@@ -968,16 +1190,29 @@ class ImputeVAE(BaseNNImputer):
968
1190
  y_true_ohe,
969
1191
  y_proba_flat,
970
1192
  objective_mode,
971
- self.tune_metric,
1193
+ cast(
1194
+ Literal[
1195
+ "pr_macro",
1196
+ "roc_auc",
1197
+ "accuracy",
1198
+ "f1",
1199
+ "average_precision",
1200
+ "precision",
1201
+ "recall",
1202
+ "mcc",
1203
+ "jaccard",
1204
+ ],
1205
+ self.tune_metric,
1206
+ ),
972
1207
  )
973
1208
 
974
1209
  if not objective_mode:
975
- pm = PrettyMetrics(
976
- metrics, precision=3, title=f"{self.model_name} Validation Metrics"
977
- )
978
- pm.render()
1210
+ if self.verbose or self.debug:
1211
+ pm = PrettyMetrics(
1212
+ metrics, precision=2, title=f"{self.model_name} Validation Metrics"
1213
+ )
1214
+ pm.render()
979
1215
 
980
- # Primary report
981
1216
  self._make_class_reports(
982
1217
  y_true=y_true_flat,
983
1218
  y_pred_proba=y_proba_flat,
@@ -986,16 +1221,17 @@ class ImputeVAE(BaseNNImputer):
986
1221
  labels=target_names,
987
1222
  )
988
1223
 
989
- # IUPAC decode & 10-base integer report
990
- # FIX 4: Use current shape (X_val.shape) not self.num_features_
991
- y_true_dec = self.pgenc.decode_012(
992
- GT_ref.reshape(X_val.shape[0], X_val.shape[1])
993
- )
994
- X_pred = X_val.copy()
995
- X_pred[eval_mask] = y_pred_flat
996
- y_pred_dec = self.pgenc.decode_012(
997
- X_pred.reshape(X_val.shape[0], X_val.shape[1])
998
- )
1224
+ # --- IUPAC decode and 10-base integer report ---
1225
+ y_true_matrix = np.array(y, copy=True)
1226
+ y_pred_matrix = np.array(pred_labels, copy=True)
1227
+
1228
+ if self.is_haploid_:
1229
+ # Map any ALT-coded >0 to 2 for decode_012, preserve missing -1
1230
+ y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
1231
+ y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
1232
+
1233
+ y_true_dec = self.decode_012(y_true_matrix)
1234
+ y_pred_dec = self.decode_012(y_pred_matrix)
999
1235
 
1000
1236
  encodings_dict = {
1001
1237
  "A": 0,
@@ -1043,71 +1279,71 @@ class ImputeVAE(BaseNNImputer):
1043
1279
 
1044
1280
  Returns:
1045
1281
  float: Value of the tuning metric to be optimized.
1282
+
1283
+ Raises:
1284
+ RuntimeError: If model training returns None.
1285
+ optuna.exceptions.TrialPruned: If training fails unexpectedly or is unpromising.
1046
1286
  """
1047
1287
  try:
1048
1288
  params = self._sample_hyperparameters(trial)
1049
1289
 
1050
- X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1051
- X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1052
-
1053
- class_weights = self._normalize_class_weights(
1054
- self._class_weights_from_zygosity(X_train)
1055
- )
1056
- train_loader = self._get_data_loader(X_train)
1057
-
1058
1290
  model = self.build_model(self.Model, params["model_params"])
1059
1291
  model.apply(self.initialize_weights)
1060
1292
 
1061
- lr: float = params["lr"]
1293
+ lr: float = params["learning_rate"]
1062
1294
  l1_penalty: float = params["l1_penalty"]
1063
1295
 
1064
- # Train + prune on metric
1065
- _, model, __ = self._train_and_validate_model(
1296
+ class_weights = self._class_weights_from_zygosity(
1297
+ self.y_train_,
1298
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1299
+ inverse=params["inverse"],
1300
+ normalize=params["normalize"],
1301
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1302
+ power=params["power"],
1303
+ )
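The class weights above are derived from the zygosity counts of the masked training targets, with inverse, normalize, power, and max_ratio controlling how strongly rare classes are up-weighted. _class_weights_from_zygosity is defined elsewhere in the package, so the function below is only a hypothetical sketch of what such a weighting scheme typically looks like, not the actual helper:

    import numpy as np

    def class_weights_sketch(y, mask, power=1.0, max_ratio=5.0, normalize=True):
        # Hypothetical inverse-frequency weights; the real helper may differ.
        vals = y[mask].astype(int)
        vals = vals[vals >= 0]  # drop missing (-1) defensively
        counts = np.maximum(np.bincount(vals, minlength=3).astype(float), 1.0)
        w = (counts.sum() / counts) ** power    # rarer classes get larger weights
        w = np.minimum(w, w.min() * max_ratio)  # cap the spread at max_ratio
        if normalize:
            w = w * (len(w) / w.sum())          # mean weight == 1
        return w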
1304
+
1305
+ res = self._train_and_validate_model(
1066
1306
  model=model,
1067
- loader=train_loader,
1068
1307
  lr=lr,
1069
1308
  l1_penalty=l1_penalty,
1309
+ params=params,
1070
1310
  trial=trial,
1071
- return_history=False,
1072
1311
  class_weights=class_weights,
1073
- X_val=X_val,
1074
- params=params,
1075
- prune_metric=self.tune_metric,
1076
- prune_warmup_epochs=5,
1077
- eval_interval=self.tune_eval_interval,
1078
- eval_requires_latents=False,
1079
- eval_latent_steps=0,
1080
- eval_latent_lr=0.0,
1081
- eval_latent_weight_decay=0.0,
1082
- )
1083
-
1084
- eval_mask = (
1085
- self.sim_mask_test_
1086
- if (
1087
- self.simulate_missing
1088
- and getattr(self, "sim_mask_test_", None) is not None
1089
- )
1090
- else None
1312
+ kl_beta_schedule=params["kl_beta_schedule"],
1313
+ gamma_schedule=params["gamma_schedule"],
1091
1314
  )
1315
+ model = res[1]
1092
1316
 
1093
1317
  if model is None:
1094
- raise RuntimeError("Model training failed; no model was returned.")
1318
+ msg = "Model training returned None in tuning objective."
1319
+ self.logger.error(msg)
1320
+ raise RuntimeError(msg)
1095
1321
 
1096
1322
  metrics = self._evaluate_model(
1097
- X_val, model, params, objective_mode=True, eval_mask_override=eval_mask
1323
+ model=model,
1324
+ X=self.X_val_corrupted_,
1325
+ y=self.y_val_,
1326
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
1327
+ objective_mode=True,
1098
1328
  )
1099
- self._clear_resources(model, train_loader)
1329
+
1330
+ self._clear_resources(model)
1100
1331
  return metrics[self.tune_metric]
1101
1332
 
1102
1333
  except Exception as e:
1103
- # Keep sweeps moving
1104
- self.logger.debug(f"Trial failed with error: {e}")
1105
- raise optuna.exceptions.TrialPruned(
1106
- f"Trial failed with error. Enable debug logging for details."
1334
+ # Unexpected failure: surface full details in logs while still
1335
+ # pruning the trial to keep sweeps moving.
1336
+ err_type = type(e).__name__
1337
+ self.logger.warning(
1338
+ f"Trial {trial.number} failed due to exception {err_type}: {e}"
1107
1339
  )
1340
+ self.logger.debug(traceback.format_exc())
1341
+ raise optuna.exceptions.TrialPruned(
1342
+ f"Trial {trial.number} failed due to an exception. {err_type}: {e}. Enable debug logging for full traceback."
1343
+ ) from e
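The except block above downgrades unexpected failures to TrialPruned so a sweep keeps moving. At the study level the wiring around this objective looks roughly like the sketch below; the pruner, its settings, and n_trials are illustrative rather than pg-sui defaults, and the toy objective merely stands in for the trial callable above:

    import optuna

    def toy_objective(trial: optuna.Trial) -> float:
        # Stand-in: the real objective samples hyperparameters, trains,
        # evaluates, and returns metrics[tune_metric].
        return trial.suggest_float("x", 0.0, 1.0)

    study = optuna.create_study(
        direction="maximize",
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
    )
    study.optimize(toy_objective, n_trials=20)
    print(study.best_params)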
1108
1344
 
1109
1345
  def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1110
- """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
1346
+ """Sample VAE hyperparameters; hidden sizes use BaseNNImputer helper.
1111
1347
 
1112
1348
  Args:
1113
1349
  trial (optuna.Trial): Optuna trial object.
@@ -1116,113 +1352,109 @@ class ImputeVAE(BaseNNImputer):
1116
1352
  Dict[str, int | float | str]: Sampled hyperparameters.
1117
1353
  """
1118
1354
  params = {
1119
- "latent_dim": trial.suggest_int("latent_dim", 2, 64),
1120
- "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
1121
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
1122
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
1355
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
1356
+ "learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
1357
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
1358
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
1123
1359
  "activation": trial.suggest_categorical(
1124
- "activation", ["relu", "elu", "selu"]
1360
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
1125
1361
  ),
1126
- "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
1362
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1127
1363
  "layer_scaling_factor": trial.suggest_float(
1128
- "layer_scaling_factor", 2.0, 10.0
1364
+ "layer_scaling_factor", 2.0, 10.0, step=0.025
1129
1365
  ),
1130
1366
  "layer_schedule": trial.suggest_categorical(
1131
- "layer_schedule", ["pyramid", "constant", "linear"]
1367
+ "layer_schedule", ["pyramid", "linear"]
1368
+ ),
1369
+ "power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
1370
+ "normalize": trial.suggest_categorical("normalize", [True, False]),
1371
+ "inverse": trial.suggest_categorical("inverse", [True, False]),
1372
+ "gamma": trial.suggest_float("gamma", 0.0, 3.0, step=0.1),
1373
+ "kl_beta": trial.suggest_float("kl_beta", 0.1, 5.0, step=0.1),
1374
+ "kl_beta_schedule": trial.suggest_categorical(
1375
+ "kl_beta_schedule", [True, False]
1376
+ ),
1377
+ "gamma_schedule": trial.suggest_categorical(
1378
+ "gamma_schedule", [True, False]
1132
1379
  ),
1133
- # VAE-specific β (final value after anneal)
1134
- "beta": trial.suggest_float("beta", 0.25, 4.0),
1135
- # focal gamma (if used in VAE recon CE)
1136
- "gamma": trial.suggest_float("gamma", 0.0, 5.0),
1137
1380
  }
1138
1381
 
1139
- input_dim = self.num_features_ * self.num_classes_
1382
+ nF: int = self.num_features_
1383
+ nC: int = self.num_classes_
1384
+ input_dim = nF * nC
1385
+
1140
1386
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1141
1387
  n_inputs=input_dim,
1142
- n_outputs=input_dim,
1143
- n_samples=len(self.train_idx_),
1388
+ n_outputs=nC,
1389
+ n_samples=len(self.X_train_),
1144
1390
  n_hidden=params["num_hidden_layers"],
1391
+ latent_dim=params["latent_dim"],
1145
1392
  alpha=params["layer_scaling_factor"],
1146
1393
  schedule=params["layer_schedule"],
1147
1394
  )
1148
1395
 
1149
- # [latent_dim] + interior widths (exclude output width)
1150
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1151
-
1152
1396
  params["model_params"] = {
1153
1397
  "n_features": self.num_features_,
1154
- "num_classes": self.num_classes_,
1155
- "latent_dim": params["latent_dim"],
1398
+ "num_classes": nC, # categorical head: 2 or 3
1156
1399
  "dropout_rate": params["dropout_rate"],
1157
- "hidden_layer_sizes": hidden_only,
1400
+ "hidden_layer_sizes": hidden_layer_sizes,
1158
1401
  "activation": params["activation"],
1159
- # Pass through VAE recon/regularization coefficients
1160
- "beta": params["beta"],
1161
- "gamma": params["gamma"],
1402
+ "kl_beta": params["kl_beta"],
1162
1403
  }
1404
+
1163
1405
  return params
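The sampled layer_schedule ("pyramid" or "linear") and layer_scaling_factor feed _compute_hidden_layer_sizes on BaseNNImputer, which is not part of this diff. The sketch below is only a hypothetical illustration of how such schedules commonly taper widths from the input toward the latent dimension; the actual helper may compute sizes differently:

    def hidden_sizes_sketch(n_inputs, latent_dim, n_hidden, alpha=4.0, schedule="pyramid"):
        # Hypothetical width schedule, for illustration only.
        first = max(latent_dim, int(n_inputs / alpha))  # widest hidden layer
        if schedule == "linear":
            step = (first - latent_dim) / max(n_hidden, 1)
            return [max(latent_dim, int(first - i * step)) for i in range(n_hidden)]
        sizes, width = [], float(first)
        while len(sizes) < n_hidden:  # "pyramid": geometric taper toward latent_dim
            sizes.append(max(latent_dim, int(width)))
            width /= 2.0
        return sizes

    # hidden_sizes_sketch(3000, latent_dim=16, n_hidden=3) -> [750, 375, 187]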
1164
1406
 
1165
- def _set_best_params(self, best_params: dict) -> dict:
1166
- """Adopt best params and return VAE model_params.
1407
+ def _set_best_params(self, params: dict) -> dict:
1408
+ """Update instance fields from tuned params and return model_params dict.
1167
1409
 
1168
1410
  Args:
1169
- best_params (dict): Best hyperparameters from tuning.
1411
+ params (dict): Best hyperparameters from tuning.
1170
1412
 
1171
1413
  Returns:
1172
- dict: VAE model parameters.
1414
+ dict: Model parameters for building the VAE.
1173
1415
  """
1174
- self.latent_dim = best_params["latent_dim"]
1175
- self.dropout_rate = best_params["dropout_rate"]
1176
- self.learning_rate = best_params["learning_rate"]
1177
- self.l1_penalty = best_params["l1_penalty"]
1178
- self.activation = best_params["activation"]
1179
- self.layer_scaling_factor = best_params["layer_scaling_factor"]
1180
- self.layer_schedule = best_params["layer_schedule"]
1181
- self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
1182
- self.gamma = best_params.get("gamma", self.gamma)
1183
-
1184
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
1185
- n_inputs=self.num_features_ * self.num_classes_,
1186
- n_outputs=self.num_features_ * self.num_classes_,
1187
- n_samples=len(self.train_idx_),
1188
- n_hidden=best_params["num_hidden_layers"],
1189
- alpha=best_params["layer_scaling_factor"],
1190
- schedule=best_params["layer_schedule"],
1416
+ self.latent_dim = params["latent_dim"]
1417
+ self.dropout_rate = params["dropout_rate"]
1418
+ self.learning_rate = params["learning_rate"]
1419
+ self.l1_penalty = params["l1_penalty"]
1420
+ self.activation = params["activation"]
1421
+ self.layer_scaling_factor = params["layer_scaling_factor"]
1422
+ self.layer_schedule = params["layer_schedule"]
1423
+ self.power = params["power"]
1424
+ self.normalize = params["normalize"]
1425
+ self.inverse = params["inverse"]
1426
+ self.gamma = params["gamma"]
1427
+ self.gamma_schedule = params["gamma_schedule"]
1428
+ self.kl_beta = params["kl_beta"]
1429
+ self.kl_beta_schedule = params["kl_beta_schedule"]
1430
+ self.class_weights_ = self._class_weights_from_zygosity(
1431
+ self.y_train_,
1432
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1433
+ inverse=self.inverse,
1434
+ normalize=self.normalize,
1435
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1436
+ power=self.power,
1191
1437
  )
1192
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1438
+ nF = self.num_features_
1439
+ nC = self.num_classes_
1440
+ input_dim = nF * nC
1193
1441
 
1194
- return {
1195
- "n_features": self.num_features_,
1196
- "latent_dim": self.latent_dim,
1197
- "hidden_layer_sizes": hidden_only,
1198
- "dropout_rate": self.dropout_rate,
1199
- "activation": self.activation,
1200
- "num_classes": self.num_classes_,
1201
- "beta": self.kl_beta_final,
1202
- "gamma": self.gamma,
1203
- }
1204
-
1205
- def _default_best_params(self) -> Dict[str, int | float | str | list]:
1206
- """Default VAE model params when tuning is disabled.
1207
-
1208
- Returns:
1209
- Dict[str, int | float | str | list]: VAE model parameters.
1210
- """
1211
1442
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1212
- n_inputs=self.num_features_ * self.num_classes_,
1213
- n_outputs=self.num_features_ * self.num_classes_,
1214
- n_samples=len(self.ground_truth_),
1215
- n_hidden=self.num_hidden_layers,
1216
- alpha=self.layer_scaling_factor,
1217
- schedule=self.layer_schedule,
1443
+ n_inputs=input_dim,
1444
+ n_outputs=nC,
1445
+ n_samples=len(self.X_train_),
1446
+ n_hidden=params["num_hidden_layers"],
1447
+ latent_dim=params["latent_dim"],
1448
+ alpha=params["layer_scaling_factor"],
1449
+ schedule=params["layer_schedule"],
1218
1450
  )
1451
+
1219
1452
  return {
1220
- "n_features": self.num_features_,
1453
+ "n_features": nF,
1221
1454
  "latent_dim": self.latent_dim,
1222
1455
  "hidden_layer_sizes": hidden_layer_sizes,
1223
1456
  "dropout_rate": self.dropout_rate,
1224
1457
  "activation": self.activation,
1225
- "num_classes": self.num_classes_,
1226
- "beta": self.kl_beta_final,
1227
- "gamma": self.gamma,
1458
+ "num_classes": nC,
1459
+ "kl_beta": params["kl_beta"],
1228
1460
  }