pg-sui 1.0.2.1-py3-none-any.whl → 1.6.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/unsupervised/imputers/vae.py
@@ -0,0 +1,957 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import optuna
9
+ import torch
10
+ from sklearn.exceptions import NotFittedError
11
+ from sklearn.model_selection import train_test_split
12
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
13
+ from snpio.utils.logging import LoggerManager
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+
16
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
17
+ from pgsui.data_processing.containers import VAEConfig
18
+ from pgsui.impute.unsupervised.base import BaseNNImputer
19
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
+ from pgsui.impute.unsupervised.models.vae_model import VAEModel
21
+
22
+ if TYPE_CHECKING:
23
+ from snpio.read_input.genotype_data import GenotypeData
24
+
25
+
26
+ def ensure_vae_config(config: Union[VAEConfig, dict, str, None]) -> VAEConfig:
27
+ """Normalize VAEConfig input from various sources.
28
+
29
+ Args:
30
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
31
+
32
+ Returns:
33
+ VAEConfig: Normalized configuration dataclass.
34
+ """
35
+ if config is None:
36
+ return VAEConfig()
37
+ if isinstance(config, VAEConfig):
38
+ return config
39
+ if isinstance(config, str):
40
+ return load_yaml_to_dataclass(
41
+ config, VAEConfig, preset_builder=VAEConfig.from_preset
42
+ )
43
+ if isinstance(config, dict):
44
+ base = VAEConfig()
45
+ # Respect top-level preset
46
+ preset = config.pop("preset", None)
47
+ if preset:
48
+ base = VAEConfig.from_preset(preset)
49
+ # Flatten + apply
50
+ flat: Dict[str, object] = {}
51
+
52
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
53
+ for k, v in d.items():
54
+ kk = f"{prefix}.{k}" if prefix else k
55
+ if isinstance(v, dict):
56
+ _flatten(kk, v, out)
57
+ else:
58
+ out[kk] = v
59
+ return out
60
+
61
+ flat = _flatten("", config, {})
62
+ return apply_dot_overrides(base, flat)
63
+ raise TypeError("config must be a VAEConfig, dict, YAML path, or None.")
64
+
65
+
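# Illustrative usage sketch (not part of the diff): how the normalizer above might
# be called. Assumes the VAEConfig sections referenced later in ImputeVAE.__init__
# (e.g. model.latent_dim, train.learning_rate, train.batch_size) exist as shown.
cfg = ensure_vae_config({"model": {"latent_dim": 8}, "train": {"learning_rate": 1e-3}})
cfg = apply_dot_overrides(cfg, {"train.batch_size": 64})  # dot-key override, highest precedence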
66
+ class ImputeVAE(BaseNNImputer):
67
+ """Variational Autoencoder imputer on 0/1/2 encodings (missing=-1).
68
+
69
+ This imputer implements a VAE with a Gaussian latent space and a categorical (multinomial) output distribution over genotype classes. It handles missing data by inferring the latent distribution and generating plausible predictions for the masked entries. The model is trained with a combination of reconstruction loss (cross-entropy) and a KL divergence term, with the KL weight (beta) annealed over training. Both haploid and diploid genotype data are supported.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ genotype_data: "GenotypeData",
75
+ *,
76
+ config: Optional[Union["VAEConfig", dict, str]] = None,
77
+ overrides: dict | None = None,
78
+ ):
79
+ """Initialize the VAE imputer with a unified config interface.
80
+
81
+ This initializer sets up the VAE imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
82
+
83
+ Args:
84
+ genotype_data (GenotypeData): Backing genotype data object.
85
+ config (Union[VAEConfig, dict, str, None]): VAEConfig, nested dict, YAML path, or None (defaults).
86
+ overrides (dict | None): Optional dot-key overrides with highest precedence.
87
+ """
88
+ self.model_name = "ImputeVAE"
89
+ self.genotype_data = genotype_data
90
+
91
+ # Normalize configuration and apply top-precedence overrides
92
+ cfg = ensure_vae_config(config)
93
+ if overrides:
94
+ cfg = apply_dot_overrides(cfg, overrides)
95
+ self.cfg = cfg
96
+
97
+ # Logger (align with AE/NLPCA)
98
+ logman = LoggerManager(
99
+ __name__,
100
+ prefix=self.cfg.io.prefix,
101
+ debug=self.cfg.io.debug,
102
+ verbose=self.cfg.io.verbose,
103
+ )
104
+ self.logger = logman.get_logger()
105
+
106
+ # BaseNNImputer bootstraps device/dirs/log formatting
107
+ super().__init__(
108
+ prefix=self.cfg.io.prefix,
109
+ device=self.cfg.train.device,
110
+ verbose=self.cfg.io.verbose,
111
+ debug=self.cfg.io.debug,
112
+ )
113
+
114
+ # Model hook & encoder
115
+ self.Model = VAEModel
116
+ self.pgenc = GenotypeEncoder(genotype_data)
117
+
118
+ # IO/global
119
+ self.seed = self.cfg.io.seed
120
+ self.n_jobs = self.cfg.io.n_jobs
121
+ self.prefix = self.cfg.io.prefix
122
+ self.scoring_averaging = self.cfg.io.scoring_averaging
123
+ self.verbose = self.cfg.io.verbose
124
+ self.debug = self.cfg.io.debug
125
+ self.rng = np.random.default_rng(self.seed)
126
+
127
+ # Model hyperparams (AE-parity)
128
+ self.latent_dim = self.cfg.model.latent_dim
129
+ self.dropout_rate = self.cfg.model.dropout_rate
130
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
131
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
132
+ self.layer_schedule = self.cfg.model.layer_schedule
133
+ self.activation = self.cfg.model.hidden_activation
134
+ self.gamma = self.cfg.model.gamma # focal loss focusing (for recon CE)
135
+
136
+ # VAE-only KL controls
137
+ self.kl_beta_final = self.cfg.vae.kl_beta
138
+ self.kl_warmup = self.cfg.vae.kl_warmup
139
+ self.kl_ramp = self.cfg.vae.kl_ramp
140
+
141
+ # Train hyperparams (AE-parity)
142
+ self.batch_size = self.cfg.train.batch_size
143
+ self.learning_rate = self.cfg.train.learning_rate
144
+ self.l1_penalty = self.cfg.train.l1_penalty
145
+ self.early_stop_gen = self.cfg.train.early_stop_gen
146
+ self.min_epochs = self.cfg.train.min_epochs
147
+ self.epochs = self.cfg.train.max_epochs
148
+ self.validation_split = self.cfg.train.validation_split
149
+ self.beta = self.cfg.train.weights_beta
150
+ self.max_ratio = self.cfg.train.weights_max_ratio
151
+
152
+ # Tuning (AE-parity surface; VAE ignores latent refinement during eval)
153
+ self.tune = self.cfg.tune.enabled
154
+ self.tune_fast = self.cfg.tune.fast
155
+ self.tune_batch_size = self.cfg.tune.batch_size
156
+ self.tune_epochs = self.cfg.tune.epochs
157
+ self.tune_eval_interval = self.cfg.tune.eval_interval
158
+ self.tune_metric = self.cfg.tune.metric
159
+ self.n_trials = self.cfg.tune.n_trials
160
+ self.tune_save_db = self.cfg.tune.save_db
161
+ self.tune_resume = self.cfg.tune.resume
162
+ self.tune_max_samples = self.cfg.tune.max_samples
163
+ self.tune_max_loci = self.cfg.tune.max_loci
164
+ self.tune_patience = self.cfg.tune.patience
165
+
166
+ # Plotting (AE-parity)
167
+ self.plot_format = self.cfg.plot.fmt
168
+ self.plot_dpi = self.cfg.plot.dpi
169
+ self.plot_fontsize = self.cfg.plot.fontsize
170
+ self.title_fontsize = self.cfg.plot.fontsize
171
+ self.despine = self.cfg.plot.despine
172
+ self.show_plots = self.cfg.plot.show
173
+
174
+ # Derived at fit-time
175
+ self.is_haploid: bool | None = None
176
+ self.num_classes_: int | None = None
177
+ self.model_params: Dict[str, Any] = {}
178
+
179
+ # -------------------- Fit -------------------- #
180
+ def fit(self) -> "ImputeVAE":
181
+ """Fit the VAE on 0/1/2 encoded genotypes (missing → -9).
182
+
183
+ This method prepares the genotype data, initializes model parameters, splits the data into training and validation sets, and trains the VAE model. It handles both haploid and diploid data, applies class weighting, and supports optional hyperparameter tuning. After training, it evaluates the model on the validation set and saves the trained model.
184
+
185
+ Returns:
186
+ ImputeVAE: Fitted instance.
187
+
188
+ Raises:
189
+ RuntimeError: If training fails to produce a model.
190
+ """
191
+ self.logger.info(f"Fitting {self.model_name} (0/1/2 VAE) ...")
192
+
193
+ # Data prep aligns with AE/NLPCA
194
+ X = self.pgenc.genotypes_012.astype(np.float32)
195
+ X[X < 0] = np.nan
196
+ X[np.isnan(X)] = -1
197
+ self.ground_truth_ = X.astype(np.int64)
198
+
199
+ # Ploidy/classes
200
+ self.is_haploid = np.all(
201
+ np.isin(
202
+ self.genotype_data.snp_data,
203
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
204
+ )
205
+ )
206
+ self.ploidy = 1 if self.is_haploid else 2
207
+ self.num_classes_ = 2 if self.is_haploid else 3
208
+ self.logger.info(
209
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
210
+ f"using {self.num_classes_} classes."
211
+ )
212
+
213
+ n_samples, self.num_features_ = X.shape
214
+
215
+ # Model params (decoder outputs L*K logits)
216
+ self.model_params = {
217
+ "n_features": self.num_features_,
218
+ "num_classes": self.num_classes_,
219
+ "latent_dim": self.latent_dim,
220
+ "dropout_rate": self.dropout_rate,
221
+ "activation": self.activation,
222
+ }
223
+
224
+ # Train/Val split
225
+ indices = np.arange(n_samples)
226
+ train_idx, val_idx = train_test_split(
227
+ indices, test_size=self.validation_split, random_state=self.seed
228
+ )
229
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
230
+ self.X_train_ = self.ground_truth_[train_idx]
231
+ self.X_val_ = self.ground_truth_[val_idx]
232
+
233
+ # Plotters/scorers (shared utilities)
234
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
235
+
236
+ # Optional tuning
237
+ if self.tune:
238
+ self.tune_hyperparameters()
239
+
240
+ # Best params (tuned or default)
241
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
242
+
243
+ # Class weights (device-aware)
244
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_).to(
245
+ self.device
246
+ )
247
+
248
+ # DataLoader
249
+ train_loader = self._get_data_loader(self.X_train_)
250
+
251
+ # Build & train
252
+ model = self.build_model(self.Model, self.best_params_)
253
+ model.apply(self.initialize_weights)
254
+
255
+ loss, trained_model, history = self._train_and_validate_model(
256
+ model=model,
257
+ loader=train_loader,
258
+ lr=self.learning_rate,
259
+ l1_penalty=self.l1_penalty,
260
+ return_history=True,
261
+ class_weights=self.class_weights_,
262
+ X_val=self.X_val_,
263
+ params=self.best_params_,
264
+ prune_metric=self.tune_metric,
265
+ prune_warmup_epochs=5,
266
+ eval_interval=1,
267
+ eval_requires_latents=False, # no latent refinement for eval
268
+ eval_latent_steps=0,
269
+ eval_latent_lr=0.0,
270
+ eval_latent_weight_decay=0.0,
271
+ )
272
+
273
+ if trained_model is None:
274
+ msg = "VAE training failed; no model was returned."
275
+ self.logger.error(msg)
276
+ raise RuntimeError(msg)
277
+
278
+ torch.save(
279
+ trained_model.state_dict(),
280
+ self.models_dir / f"final_model_{self.model_name}.pt",
281
+ )
282
+
283
+ self.best_loss_, self.model_, self.history_ = (
284
+ loss,
285
+ trained_model,
286
+ {"Train": history},
287
+ )
288
+ self.is_fit_ = True
289
+
290
+ # Evaluate (AE-parity reporting)
291
+ self._evaluate_model(self.X_val_, self.model_, self.best_params_)
292
+ self.plotter_.plot_history(self.history_)
293
+ self._save_best_params(self.best_params_)
294
+ return self
295
+
296
+ def transform(self) -> np.ndarray:
297
+ """Impute missing genotypes and return IUPAC strings.
298
+
299
+ This method uses the trained VAE model to impute missing genotypes in the dataset. It predicts the most likely genotype for each missing entry based on the learned latent representations and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easy interpretation.
300
+
301
+ Returns:
302
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
303
+
304
+ Raises:
305
+ NotFittedError: If called before fit().
306
+ """
307
+ if not getattr(self, "is_fit_", False):
308
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
309
+
310
+ self.logger.info("Imputing entire dataset with VAE (0/1/2)...")
311
+ X_to_impute = self.ground_truth_.copy()
312
+
313
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
314
+
315
+ # Fill only missing
316
+ missing_mask = X_to_impute == -1
317
+ imputed_array = X_to_impute.copy()
318
+ imputed_array[missing_mask] = pred_labels[missing_mask]
319
+
320
+ # Decode to IUPAC & plot
321
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
322
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
323
+
324
+ plt.rcParams.update(self.plotter_.param_dict)
325
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
326
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
327
+
328
+ return imputed_genotypes
329
+
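# Illustrative end-to-end sketch (not part of the diff). Assumes a snpio GenotypeData
# object has already been loaded elsewhere; only the ImputeVAE calls below come from
# this module, and the config/override keys mirror the cfg.* fields read in __init__.
genotype_data = ...  # snpio GenotypeData instance, loaded elsewhere
imputer = ImputeVAE(
    genotype_data,
    config={"model": {"latent_dim": 4}},   # nested dict, YAML path, VAEConfig, or None
    overrides={"train.max_epochs": 200},   # dot-key overrides take highest precedence
)
imputer.fit()                              # trains the VAE on 0/1/2 encodings
imputed_iupac = imputer.transform()        # (n_samples, n_loci) IUPAC strings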
330
+ # ---------- plumbing identical to AE, naming aligned ---------- #
331
+
332
+ def _get_data_loader(self, y: np.ndarray) -> torch.utils.data.DataLoader:
333
+ """Create DataLoader over indices + integer targets (-1 for missing).
334
+
335
+ This method creates a PyTorch DataLoader for the training data. It converts the input genotype matrix into a tensor and constructs a dataset that includes both the indices and the genotype values. The DataLoader is configured to shuffle the data and use the specified batch size for training.
336
+
337
+ Args:
338
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
339
+
340
+ Returns:
341
+ torch.utils.data.DataLoader: Shuffled DataLoader.
342
+ """
343
+ y_tensor = torch.from_numpy(y).long().to(self.device)
344
+ dataset = torch.utils.data.TensorDataset(
345
+ torch.arange(len(y), device=self.device), y_tensor
346
+ )
347
+ return torch.utils.data.DataLoader(
348
+ dataset, batch_size=self.batch_size, shuffle=True
349
+ )
350
+
351
+ def _train_and_validate_model(
352
+ self,
353
+ model: torch.nn.Module,
354
+ loader: torch.utils.data.DataLoader,
355
+ lr: float,
356
+ l1_penalty: float,
357
+ trial: optuna.Trial | None = None,
358
+ return_history: bool = False,
359
+ class_weights: torch.Tensor | None = None,
360
+ *,
361
+ X_val: np.ndarray | None = None,
362
+ params: dict | None = None,
363
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
364
+ prune_warmup_epochs: int = 3,
365
+ eval_interval: int = 1,
366
+ eval_requires_latents: bool = False, # VAE: no latent eval refinement
367
+ eval_latent_steps: int = 0,
368
+ eval_latent_lr: float = 0.0,
369
+ eval_latent_weight_decay: float = 0.0,
370
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
371
+ """Wrap the VAE training loop with β-anneal & Optuna pruning.
372
+
373
+ This method orchestrates the training of the VAE model, including setting up the optimizer and learning rate scheduler, and executing the training loop with support for early stopping and Optuna pruning. It manages the training process, monitors performance on a validation set if provided, and returns the best model and training history.
374
+
375
+ Args:
376
+ model (torch.nn.Module): VAE model.
377
+ loader (torch.utils.data.DataLoader): Training data loader.
378
+ lr (float): Learning rate.
379
+ l1_penalty (float): L1 regularization coefficient.
380
+ trial (optuna.Trial | None): Optuna trial for pruning.
381
+ return_history (bool): If True, return training history.
382
+ class_weights (torch.Tensor | None): CE class weights on device.
383
+ X_val (np.ndarray | None): Validation data for pruning eval.
384
+ params (dict | None): Current hyperparameters (for logging).
385
+ prune_metric (str | None): Metric for pruning decisions.
386
+ prune_warmup_epochs (int): Epochs to skip before pruning.
387
+ eval_interval (int): Epochs between validation evaluations.
388
+ eval_requires_latents (bool): If True, refine latents during eval.
389
+ eval_latent_steps (int): Latent refinement steps if needed.
390
+ eval_latent_lr (float): Latent refinement learning rate.
391
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
392
+
393
+ Returns:
394
+ Tuple[float, torch.nn.Module | None, list | None]: Best loss, best model, and training history (if requested).
395
+ """
396
+ if class_weights is None:
397
+ msg = "Must provide class_weights."
398
+ self.logger.error(msg)
399
+ raise TypeError(msg)
400
+
401
+ max_epochs = (
402
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
403
+ )
404
+
405
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
406
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
407
+
408
+ best_loss, best_model, hist = self._execute_training_loop(
409
+ loader=loader,
410
+ optimizer=optimizer,
411
+ scheduler=scheduler,
412
+ model=model,
413
+ l1_penalty=l1_penalty,
414
+ trial=trial,
415
+ return_history=return_history,
416
+ class_weights=class_weights,
417
+ X_val=X_val,
418
+ params=params,
419
+ prune_metric=prune_metric,
420
+ prune_warmup_epochs=prune_warmup_epochs,
421
+ eval_interval=eval_interval,
422
+ eval_requires_latents=eval_requires_latents,
423
+ eval_latent_steps=eval_latent_steps,
424
+ eval_latent_lr=eval_latent_lr,
425
+ eval_latent_weight_decay=eval_latent_weight_decay,
426
+ )
427
+ if return_history:
428
+ return best_loss, best_model, hist
429
+
430
+ return best_loss, best_model, None
431
+
432
+ def _execute_training_loop(
433
+ self,
434
+ loader: torch.utils.data.DataLoader,
435
+ optimizer: torch.optim.Optimizer,
436
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
437
+ model: torch.nn.Module,
438
+ l1_penalty: float,
439
+ trial: optuna.Trial | None,
440
+ return_history: bool,
441
+ class_weights: torch.Tensor,
442
+ *,
443
+ X_val: np.ndarray | None = None,
444
+ params: dict | None = None,
445
+ prune_metric: str | None = None,
446
+ prune_warmup_epochs: int = 3,
447
+ eval_interval: int = 1,
448
+ eval_requires_latents: bool = False,
449
+ eval_latent_steps: int = 0,
450
+ eval_latent_lr: float = 0.0,
451
+ eval_latent_weight_decay: float = 0.0,
452
+ ) -> Tuple[float, torch.nn.Module, list]:
453
+ """Train VAE with focal CE + KL(β) anneal, early stopping & pruning.
454
+
455
+ This method implements the core training loop for the VAE model, incorporating focal cross-entropy loss for reconstruction and KL divergence with an annealed beta weight. It includes mechanisms for early stopping based on validation performance and supports pruning of unpromising trials when used with Optuna. The training process is monitored, and the best model is retained.
456
+
457
+ Args:
458
+ loader (torch.utils.data.DataLoader): Training data loader.
459
+ optimizer (torch.optim.Optimizer): Optimizer.
460
+ scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
461
+ model (torch.nn.Module): VAE model.
462
+ l1_penalty (float): L1 regularization coefficient.
463
+ trial (optuna.Trial | None): Optuna trial for pruning.
464
+ return_history (bool): If True, return training history.
465
+ class_weights (torch.Tensor): CE class weights on device.
466
+ X_val (np.ndarray | None): Validation data for pruning eval.
467
+ params (dict | None): Current hyperparameters (for logging).
468
+ prune_metric (str | None): Metric for pruning decisions.
469
+ prune_warmup_epochs (int): Epochs to skip before pruning.
470
+ eval_interval (int): Epochs between validation evaluations.
471
+ eval_requires_latents (bool): If True, refine latents during eval.
472
+ eval_latent_steps (int): Latent refinement steps if needed.
473
+ eval_latent_lr (float): Latent refinement learning rate.
474
+ eval_latent_weight_decay (float): Latent refinement L2 penalty.
475
+
476
+ Returns:
477
+ Tuple[float, torch.nn.Module, list[float]]: Best loss, best model, history.
478
+ """
479
+ best_model = None
480
+ history: list[float] = []
481
+
482
+ early_stopping = EarlyStopping(
483
+ patience=self.early_stop_gen,
484
+ min_epochs=self.min_epochs,
485
+ verbose=self.verbose,
486
+ prefix=self.prefix,
487
+ debug=self.debug,
488
+ )
489
+
490
+ # AE-parity gamma schedule for focal CE (reconstruction)
491
+ gamma_warm, gamma_ramp, gamma_final = 50, 100, self.gamma
492
+ # VAE β schedule for KL term
493
+ beta_warm, beta_ramp, beta_final = (
494
+ self.kl_warmup,
495
+ self.kl_ramp,
496
+ self.kl_beta_final,
497
+ )
498
+
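# Worked illustration of the linear anneal implemented below (hypothetical values):
# with beta_warm=10, beta_ramp=20, beta_final=1.0 the KL weight is
#   epoch  5 -> beta = 0.0                  (warm-up)
#   epoch 20 -> beta = 1.0 * 10/20 = 0.5    (ramp)
#   epoch 40 -> beta = 1.0                  (final)
# The focal-gamma schedule follows the same pattern with warm=50 and ramp=100.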
499
+ for epoch in range(scheduler.T_max):
500
+ # schedules
501
+ # focal γ schedule (if your VAEModel uses it for recon CE)
502
+ if epoch < gamma_warm:
503
+ model.gamma = 0.0
504
+ elif epoch < gamma_warm + gamma_ramp:
505
+ model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp)
506
+ else:
507
+ model.gamma = gamma_final
508
+
509
+ # KL β schedule
510
+ if epoch < beta_warm:
511
+ model.beta = 0.0
512
+ elif epoch < beta_warm + beta_ramp:
513
+ model.beta = beta_final * ((epoch - beta_warm) / beta_ramp)
514
+ else:
515
+ model.beta = beta_final
516
+
517
+ # one epoch
518
+ train_loss = self._train_step(
519
+ loader=loader,
520
+ optimizer=optimizer,
521
+ model=model,
522
+ l1_penalty=l1_penalty,
523
+ class_weights=class_weights,
524
+ )
525
+ if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
526
+ raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
527
+
528
+ scheduler.step()
529
+ if return_history:
530
+ history.append(train_loss)
531
+
532
+ early_stopping(train_loss, model)
533
+ if early_stopping.early_stop:
534
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
535
+ break
536
+
537
+ # Optuna report/prune on validation metric
538
+ if (
539
+ trial is not None
540
+ and X_val is not None
541
+ and ((epoch + 1) % eval_interval == 0)
542
+ ):
543
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
544
+ metric_val = self._eval_for_pruning(
545
+ model=model,
546
+ X_val=X_val,
547
+ params=params or getattr(self, "best_params_", {}),
548
+ metric=metric_key,
549
+ objective_mode=True,
550
+ do_latent_infer=False, # VAE: no latent refinement needed
551
+ latent_steps=0,
552
+ latent_lr=0.0,
553
+ latent_weight_decay=0.0,
554
+ latent_seed=(self.seed if self.seed is not None else 123),
555
+ _latent_cache=None,
556
+ _latent_cache_key=None,
557
+ )
558
+ trial.report(metric_val, step=epoch + 1)
559
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
560
+ raise optuna.exceptions.TrialPruned(
561
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
562
+ )
563
+
564
+ best_loss = early_stopping.best_score
565
+ best_model = copy.deepcopy(early_stopping.best_model)
566
+ return best_loss, best_model, history
567
+
568
+ def _train_step(
569
+ self,
570
+ loader: torch.utils.data.DataLoader,
571
+ optimizer: torch.optim.Optimizer,
572
+ model: torch.nn.Module,
573
+ l1_penalty: float,
574
+ class_weights: torch.Tensor,
575
+ ) -> float:
576
+ """One epoch: one-hot inputs → VAE forward → recon (focal) + KL.
577
+
578
+ The VAEModel is expected to return (recon_logits, mu, logvar, ...) and expose a `compute_loss(outputs, y, mask, class_weights)` method that reads scheduled `model.beta` (and optionally `model.gamma`) attributes.
579
+
580
+ Args:
581
+ loader (torch.utils.data.DataLoader): Yields (indices, y_int) where y_int is 0/1/2; -1 for missing.
582
+ optimizer (torch.optim.Optimizer): Optimizer.
583
+ model (torch.nn.Module): VAE model.
584
+ l1_penalty (float): L1 regularization coefficient.
585
+ class_weights (torch.Tensor): CE class weights on device.
586
+
587
+ Returns:
588
+ float: Mean training loss for the epoch.
589
+ """
590
+ model.train()
591
+ running = 0.0
592
+
593
+ for _, y_batch in loader:
594
+ optimizer.zero_grad(set_to_none=True)
595
+
596
+ x_ohe = self._one_hot_encode_012(y_batch) # (B, L, K), zeros for -1
597
+ outputs = model(x_ohe) # (recon_logits, mu, logvar, ...)
598
+
599
+ # Targets for masked focal CE, same shapes as AE path
600
+ y_ohe = self._one_hot_encode_012(y_batch)
601
+ valid_mask = y_batch != -1
602
+
603
+ loss = model.compute_loss(
604
+ outputs=outputs,
605
+ y=y_ohe, # (B, L, K)
606
+ mask=valid_mask, # (B, L)
607
+ class_weights=class_weights,
608
+ )
609
+
610
+ if l1_penalty > 0:
611
+ loss = loss + l1_penalty * sum(
612
+ p.abs().sum() for p in model.parameters()
613
+ )
614
+
615
+ loss.backward()
616
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
617
+ optimizer.step()
618
+ running += float(loss.item())
619
+
620
+ return running / len(loader)
621
+
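# Sketch of the loss contract assumed by _train_step above. This is NOT the actual
# VAEModel.compute_loss; it only illustrates the standard form implied by the
# docstring: masked, class-weighted cross-entropy reconstruction plus a
# beta-weighted KL term (the focal-gamma factor is omitted for brevity).
def _sketch_vae_loss(recon_logits, mu, logvar, y_ohe, mask, class_weights, beta):
    import torch
    import torch.nn.functional as F

    B, L, K = recon_logits.shape
    targets = y_ohe.argmax(dim=-1)                        # (B, L) integer classes
    ce = F.cross_entropy(
        recon_logits.reshape(-1, K),
        targets.reshape(-1),
        weight=class_weights,
        reduction="none",
    ).reshape(B, L)
    recon = (ce * mask).sum() / mask.sum().clamp(min=1)   # average over observed entries only
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl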
622
+ def _predict(
623
+ self,
624
+ model: torch.nn.Module,
625
+ X: np.ndarray | torch.Tensor,
626
+ return_proba: bool = False,
627
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
628
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
629
+
630
+ This method uses the trained VAE model to predict genotype labels for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted labels and probabilities. The method can return either just the predicted labels or both labels and probabilities based on the `return_proba` flag.
631
+
632
+ Args:
633
+ model (torch.nn.Module): Trained model.
634
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1 for missing.
635
+ return_proba (bool): If True, also return probabilities.
636
+
637
+ Returns:
638
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels, and probabilities if requested.
639
+ """
640
+ if model is None:
641
+ msg = "Model is not trained. Call fit() before predict()."
642
+ self.logger.error(msg)
643
+ raise NotFittedError(msg)
644
+
645
+ model.eval()
646
+ with torch.no_grad():
647
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
648
+ X_tensor = X_tensor.to(self.device).long()
649
+ x_ohe = self._one_hot_encode_012(X_tensor)
650
+ outputs = model(x_ohe) # first element must be recon logits
651
+ logits = outputs[0].view(-1, self.num_features_, self.num_classes_)
652
+ probas = torch.softmax(logits, dim=-1)
653
+ labels = torch.argmax(probas, dim=-1)
654
+
655
+ if return_proba:
656
+ return labels.cpu().numpy(), probas.cpu().numpy()
657
+
658
+ return labels.cpu().numpy()
659
+
660
+ def _evaluate_model(
661
+ self,
662
+ X_val: np.ndarray,
663
+ model: torch.nn.Module,
664
+ params: dict,
665
+ objective_mode: bool = False,
666
+ latent_vectors_val: np.ndarray | None = None,
667
+ ) -> Dict[str, float]:
668
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
669
+
670
+ This method evaluates the trained VAE model on a validation dataset, computing various performance metrics. It handles missing data appropriately and generates detailed classification reports for both the original 0/1/2 encoding and the decoded IUPAC and integer formats. The evaluation metrics are logged for review.
671
+
672
+ Args:
673
+ X_val (np.ndarray): Validation 0/1/2 matrix with -1 for missing.
674
+ model (torch.nn.Module): Trained model.
675
+ params (dict): Current hyperparameters (for logging).
676
+ objective_mode (bool): If True, minimize logging for Optuna.
677
+ latent_vectors_val (np.ndarray | None): Not used by VAE.
678
+
679
+ Returns:
680
+ Dict[str, float]: Computed metrics.
681
+
682
+ Raises:
683
+ NotFittedError: If called before fit().
684
+ """
685
+ pred_labels, pred_probas = self._predict(
686
+ model=model, X=X_val, return_proba=True
687
+ )
688
+
689
+ # mask out true missing AND any non-finite prob rows
690
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N,L)
691
+ eval_mask = (X_val != -1) & finite_mask
692
+
693
+ y_true_flat = X_val[eval_mask].astype(np.int64, copy=False)
694
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
695
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
696
+
697
+ if y_true_flat.size == 0:
698
+ return {self.tune_metric: 0.0}
699
+
700
+ # ensure valid probability simplex after masking
701
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
702
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
703
+ row_sums[row_sums == 0] = 1.0
704
+ y_proba_flat = y_proba_flat / row_sums
705
+
706
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
707
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
708
+
709
+ if self.is_haploid:
710
+ y_true_flat = y_true_flat.copy()
711
+ y_pred_flat = y_pred_flat.copy()
712
+ y_true_flat[y_true_flat == 2] = 1
713
+ y_pred_flat[y_pred_flat == 2] = 1
714
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
715
+ proba_2[:, 0] = y_proba_flat[:, 0]
716
+ proba_2[:, 1] = y_proba_flat[:, 2]
717
+ y_proba_flat = proba_2
718
+
719
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
720
+
721
+ metrics = self.scorers_.evaluate(
722
+ y_true_flat,
723
+ y_pred_flat,
724
+ y_true_ohe,
725
+ y_proba_flat,
726
+ objective_mode,
727
+ self.tune_metric,
728
+ )
729
+
730
+ if not objective_mode:
731
+ self.logger.info(f"Validation Metrics: {metrics}")
732
+
733
+ # Primary report
734
+ self._make_class_reports(
735
+ y_true=y_true_flat,
736
+ y_pred_proba=y_proba_flat,
737
+ y_pred=y_pred_flat,
738
+ metrics=metrics,
739
+ labels=target_names,
740
+ )
741
+
742
+ # IUPAC decode & 10-base integer report
743
+ y_true_dec = self.pgenc.decode_012(X_val)
744
+ X_pred = X_val.copy()
745
+ X_pred[eval_mask] = y_pred_flat
746
+ y_pred_dec = self.pgenc.decode_012(
747
+ X_pred.reshape(X_val.shape[0], self.num_features_)
748
+ )
749
+
750
+ encodings_dict = {
751
+ "A": 0,
752
+ "C": 1,
753
+ "G": 2,
754
+ "T": 3,
755
+ "W": 4,
756
+ "R": 5,
757
+ "M": 6,
758
+ "K": 7,
759
+ "Y": 8,
760
+ "S": 9,
761
+ "N": -1,
762
+ }
763
+ y_true_int = self.pgenc.convert_int_iupac(
764
+ y_true_dec, encodings_dict=encodings_dict
765
+ )
766
+ y_pred_int = self.pgenc.convert_int_iupac(
767
+ y_pred_dec, encodings_dict=encodings_dict
768
+ )
769
+
770
+ self._make_class_reports(
771
+ y_true=y_true_int[eval_mask],
772
+ y_pred=y_pred_int[eval_mask],
773
+ metrics=metrics,
774
+ y_pred_proba=None,
775
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
776
+ )
777
+
778
+ return metrics
779
+
780
+ def _objective(self, trial: optuna.Trial) -> float:
781
+ """Optuna objective for VAE (no latent refinement during eval).
782
+
783
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, trains the VAE model with these parameters, and evaluates its performance on a validation set. The evaluation metric specified by `self.tune_metric` is returned for optimization. If training fails, the trial is pruned to keep the tuning process efficient.
784
+
785
+ Args:
786
+ trial (optuna.Trial): Optuna trial object.
787
+
788
+ Returns:
789
+ float: Value of the tuning metric to be optimized.
790
+ """
791
+ try:
792
+ params = self._sample_hyperparameters(trial)
793
+
794
+ X_train = self.ground_truth_[self.train_idx_]
795
+ X_val = self.ground_truth_[self.test_idx_]
796
+
797
+ class_weights = self._class_weights_from_zygosity(X_train).to(self.device)
798
+ train_loader = self._get_data_loader(X_train)
799
+
800
+ model = self.build_model(self.Model, params["model_params"])
801
+ model.apply(self.initialize_weights)
802
+
803
+ # Train + prune on metric
804
+ _, model, _ = self._train_and_validate_model(
805
+ model=model,
806
+ loader=train_loader,
807
+ lr=params["lr"],
808
+ l1_penalty=params["l1_penalty"],
809
+ trial=trial,
810
+ return_history=False,
811
+ class_weights=class_weights,
812
+ X_val=X_val,
813
+ params=params,
814
+ prune_metric=self.tune_metric,
815
+ prune_warmup_epochs=5,
816
+ eval_interval=self.tune_eval_interval,
817
+ eval_requires_latents=False,
818
+ eval_latent_steps=0,
819
+ eval_latent_lr=0.0,
820
+ eval_latent_weight_decay=0.0,
821
+ )
822
+
823
+ metrics = self._evaluate_model(X_val, model, params, objective_mode=True)
824
+ self._clear_resources(model, train_loader)
825
+ return metrics[self.tune_metric]
826
+
827
+ except Exception as e:
828
+ # Keep sweeps moving
829
+ self.logger.debug(f"Trial failed with error: {e}")
830
+ raise optuna.exceptions.TrialPruned(
831
+ f"Trial failed with error. Enable debug logging for details."
832
+ )
833
+
834
+ def _sample_hyperparameters(
835
+ self, trial: optuna.Trial
836
+ ) -> Dict[str, int | float | str]:
837
+ """Sample VAE hyperparameters; hidden sizes mirror AE/NLPCA helper.
838
+
839
+ Args:
840
+ trial (optuna.Trial): Optuna trial object.
841
+
842
+ Returns:
843
+ Dict[str, int | float | str]: Sampled hyperparameters.
844
+ """
845
+ params = {
846
+ "latent_dim": trial.suggest_int("latent_dim", 2, 64),
847
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
848
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
849
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
850
+ "activation": trial.suggest_categorical(
851
+ "activation", ["relu", "elu", "selu"]
852
+ ),
853
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
854
+ "layer_scaling_factor": trial.suggest_float(
855
+ "layer_scaling_factor", 2.0, 10.0
856
+ ),
857
+ "layer_schedule": trial.suggest_categorical(
858
+ "layer_schedule", ["pyramid", "constant", "linear"]
859
+ ),
860
+ # VAE-specific β (final value after anneal)
861
+ "beta": trial.suggest_float("beta", 0.25, 4.0),
862
+ # focal gamma (if used in VAE recon CE)
863
+ "gamma": trial.suggest_float("gamma", 0.0, 5.0),
864
+ }
865
+
866
+ input_dim = self.num_features_ * self.num_classes_
867
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
868
+ n_inputs=input_dim,
869
+ n_outputs=input_dim,
870
+ n_samples=len(self.train_idx_),
871
+ n_hidden=params["num_hidden_layers"],
872
+ alpha=params["layer_scaling_factor"],
873
+ schedule=params["layer_schedule"],
874
+ )
875
+
876
+ # [latent_dim] + interior widths (exclude output width)
877
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
878
+
879
+ params["model_params"] = {
880
+ "n_features": self.num_features_,
881
+ "num_classes": self.num_classes_,
882
+ "latent_dim": params["latent_dim"],
883
+ "dropout_rate": params["dropout_rate"],
884
+ "hidden_layer_sizes": hidden_only,
885
+ "activation": params["activation"],
886
+ # Pass through VAE recon/regularization coefficients
887
+ "beta": params["beta"],
888
+ "gamma": params["gamma"],
889
+ }
890
+ return params
891
+
892
+ def _set_best_params(
893
+ self, best_params: Dict[str, int | float | str | list]
894
+ ) -> Dict[str, int | float | str | list]:
895
+ """Adopt best params and return VAE model_params.
896
+
897
+ Args:
898
+ best_params (Dict[str, int | float | str | list]): Best hyperparameters from tuning.
899
+
900
+ Returns:
901
+ Dict[str, int | float | str | list]: VAE model parameters.
902
+ """
903
+ self.latent_dim = best_params["latent_dim"]
904
+ self.dropout_rate = best_params["dropout_rate"]
905
+ self.learning_rate = best_params["learning_rate"]
906
+ self.l1_penalty = best_params["l1_penalty"]
907
+ self.activation = best_params["activation"]
908
+ self.layer_scaling_factor = best_params["layer_scaling_factor"]
909
+ self.layer_schedule = best_params["layer_schedule"]
910
+ self.kl_beta_final = best_params.get("beta", self.kl_beta_final)
911
+ self.gamma = best_params.get("gamma", self.gamma)
912
+
913
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
914
+ n_inputs=self.num_features_ * self.num_classes_,
915
+ n_outputs=self.num_features_ * self.num_classes_,
916
+ n_samples=len(self.train_idx_),
917
+ n_hidden=best_params["num_hidden_layers"],
918
+ alpha=best_params["layer_scaling_factor"],
919
+ schedule=best_params["layer_schedule"],
920
+ )
921
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
922
+
923
+ return {
924
+ "n_features": self.num_features_,
925
+ "latent_dim": self.latent_dim,
926
+ "hidden_layer_sizes": hidden_only,
927
+ "dropout_rate": self.dropout_rate,
928
+ "activation": self.activation,
929
+ "num_classes": self.num_classes_,
930
+ "beta": self.kl_beta_final,
931
+ "gamma": self.gamma,
932
+ }
933
+
934
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
935
+ """Default VAE model params when tuning is disabled.
936
+
937
+ Returns:
938
+ Dict[str, int | float | str | list]: VAE model parameters.
939
+ """
940
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
941
+ n_inputs=self.num_features_ * self.num_classes_,
942
+ n_outputs=self.num_features_ * self.num_classes_,
943
+ n_samples=len(self.ground_truth_),
944
+ n_hidden=self.num_hidden_layers,
945
+ alpha=self.layer_scaling_factor,
946
+ schedule=self.layer_schedule,
947
+ )
948
+ return {
949
+ "n_features": self.num_features_,
950
+ "latent_dim": self.latent_dim,
951
+ "hidden_layer_sizes": hidden_layer_sizes,
952
+ "dropout_rate": self.dropout_rate,
953
+ "activation": self.activation,
954
+ "num_classes": self.num_classes_,
955
+ "beta": self.kl_beta_final,
956
+ "gamma": self.gamma,
957
+ }
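# Illustrative sketch (not part of the diff): enabling the built-in Optuna tuning
# surface via dot-key overrides. The keys mirror the cfg.tune.* fields read in
# __init__; the exact spellings ("tune.enabled", "tune.n_trials", "tune.metric")
# are assumed from those attribute accesses.
genotype_data = ...  # snpio GenotypeData instance, loaded elsewhere
imputer = ImputeVAE(
    genotype_data,
    overrides={"tune.enabled": True, "tune.n_trials": 50, "tune.metric": "f1"},
)
imputer.fit()            # runs tune_hyperparameters() before the final training pass
imputed = imputer.transform()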