pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Note: the registry flags this version of pg-sui as potentially problematic.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/impute/unsupervised/imputers/autoencoder.py (new file)
@@ -0,0 +1,972 @@
1
+ import copy
2
+ import json
3
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import optuna
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from sklearn.exceptions import NotFittedError
11
+ from sklearn.model_selection import train_test_split
12
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
13
+ from snpio.utils.logging import LoggerManager
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+
16
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
17
+ from pgsui.data_processing.containers import AutoencoderConfig
18
+ from pgsui.impute.unsupervised.base import BaseNNImputer
19
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
+ from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
21
+
22
+ if TYPE_CHECKING:
23
+ from snpio.read_input.genotype_data import GenotypeData
24
+
25
+
26
+ def ensure_autoencoder_config(
27
+ config: AutoencoderConfig | dict | str | None,
28
+ ) -> AutoencoderConfig:
29
+ """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
30
+
31
+ This function normalizes the configuration input for the Autoencoder imputer. It accepts a dataclass instance, a nested dictionary, a YAML file path, or None, and returns a concrete AutoencoderConfig with all necessary fields populated.
32
+
33
+ Args:
34
+ config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
35
+
36
+ Returns:
37
+ AutoencoderConfig: Concrete configuration instance.
38
+ """
39
+ if config is None:
40
+ return AutoencoderConfig()
41
+ if isinstance(config, AutoencoderConfig):
42
+ return config
43
+ if isinstance(config, str):
44
+ # YAML path — top-level `preset` key is supported
45
+ return load_yaml_to_dataclass(
46
+ config, AutoencoderConfig, preset_builder=AutoencoderConfig.from_preset
47
+ )
48
+ if isinstance(config, dict):
49
+ # Flatten dict into dot-keys then overlay onto a fresh instance
50
+ base = AutoencoderConfig()
51
+
52
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
53
+ for k, v in d.items():
54
+ kk = f"{prefix}.{k}" if prefix else k
55
+ if isinstance(v, dict):
56
+ _flatten(kk, v, out)
57
+ else:
58
+ out[kk] = v
59
+ return out
60
+
61
+ # Lift any present preset first
62
+ preset_name = config.pop("preset", None)
63
+ if "io" in config and isinstance(config["io"], dict):
64
+ preset_name = preset_name or config["io"].pop("preset", None)
65
+
66
+ if preset_name:
67
+ base = AutoencoderConfig.from_preset(preset_name)
68
+
69
+ flat = _flatten("", config, {})
70
+ return apply_dot_overrides(base, flat)
71
+
72
+ raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
73
+
74
+
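The helper above accepts a dataclass, a nested dict, a YAML path, or None. A minimal usage sketch, assuming the `io.prefix` and `model.latent_dim` fields referenced later in this file; the YAML path is hypothetical:

```python
from pgsui.data_processing.containers import AutoencoderConfig

# Dict input: nested keys are flattened to dot-keys and overlaid on a fresh
# (or preset-derived) AutoencoderConfig.
cfg = ensure_autoencoder_config(
    {"io": {"prefix": "ae_run"}, "model": {"latent_dim": 16}}
)

# YAML input: a top-level `preset` key is honored via AutoencoderConfig.from_preset.
# cfg = ensure_autoencoder_config("configs/autoencoder.yaml")  # hypothetical path

assert isinstance(cfg, AutoencoderConfig)
```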
75
+ class ImputeAutoencoder(BaseNNImputer):
76
+ """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.
77
+
78
+ This imputer uses a feedforward autoencoder to learn a compressed representation of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate), from which the original genotypes are reconstructed. Missing genotypes are represented as -1 during training and imputation.
79
+
80
+ The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
81
+ """
82
+
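A hedged end-to-end sketch of the workflow described in this docstring (fit, then transform). Here `gd` stands in for an snpio GenotypeData object loaded elsewhere, and the config keys mirror the fields read in `__init__` below:

```python
from pgsui.impute.unsupervised.imputers.autoencoder import ImputeAutoencoder

imputer = ImputeAutoencoder(
    gd,  # snpio GenotypeData instance (assumed to be loaded already)
    config={"model": {"latent_dim": 16, "gamma": 2.0}, "train": {"max_epochs": 200}},
    overrides={"io.prefix": "ae_example"},  # dot-key overrides take highest precedence
)
imputer.fit()                # trains on 0/1/2 encodings (-1 = missing)
iupac = imputer.transform()  # (n_samples, n_loci) IUPAC strings, missing filled
```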
83
+ def __init__(
84
+ self,
85
+ genotype_data: "GenotypeData",
86
+ *,
87
+ config: Optional[Union["AutoencoderConfig", dict, str]] = None,
88
+ overrides: dict | None = None,
89
+ ) -> None:
90
+ """Initialize the Autoencoder imputer with a unified config interface.
91
+
92
+ This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
93
+
94
+ Args:
95
+ genotype_data: Backing genotype data object.
96
+ config: Structured configuration as dataclass, nested dict, YAML path, or None.
97
+ overrides: Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
98
+ """
99
+ self.model_name = "ImputeAutoencoder"
100
+ self.genotype_data = genotype_data
101
+
102
+ # Normalize config then apply highest-precedence overrides
103
+ cfg = ensure_autoencoder_config(config)
104
+ if overrides:
105
+ cfg = apply_dot_overrides(cfg, overrides)
106
+ self.cfg = cfg
107
+
108
+ # Logger consistent with NLPCA
109
+ logman = LoggerManager(
110
+ __name__,
111
+ prefix=self.cfg.io.prefix,
112
+ debug=self.cfg.io.debug,
113
+ verbose=self.cfg.io.verbose,
114
+ )
115
+ self.logger = logman.get_logger()
116
+
117
+ # BaseNNImputer bootstrapping (device/dirs/logging handled here)
118
+ super().__init__(
119
+ prefix=self.cfg.io.prefix,
120
+ device=self.cfg.train.device,
121
+ verbose=self.cfg.io.verbose,
122
+ debug=self.cfg.io.debug,
123
+ )
124
+
125
+ # Model hook & encoder
126
+ self.Model = AutoencoderModel
127
+ self.pgenc = GenotypeEncoder(genotype_data)
128
+
129
+ # IO / global
130
+ self.seed = self.cfg.io.seed
131
+ self.n_jobs = self.cfg.io.n_jobs
132
+ self.prefix = self.cfg.io.prefix
133
+ self.scoring_averaging = self.cfg.io.scoring_averaging
134
+ self.verbose = self.cfg.io.verbose
135
+ self.debug = self.cfg.io.debug
136
+ self.rng = np.random.default_rng(self.seed)
137
+
138
+ # Model hyperparams
139
+ self.latent_dim = self.cfg.model.latent_dim
140
+ self.dropout_rate = self.cfg.model.dropout_rate
141
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
142
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
143
+ self.layer_schedule = self.cfg.model.layer_schedule
144
+ self.activation = self.cfg.model.hidden_activation
145
+ self.gamma = self.cfg.model.gamma
146
+
147
+ # Train hyperparams
148
+ self.batch_size = self.cfg.train.batch_size
149
+ self.learning_rate = self.cfg.train.learning_rate
150
+ self.l1_penalty = self.cfg.train.l1_penalty
151
+ self.early_stop_gen = self.cfg.train.early_stop_gen
152
+ self.min_epochs = self.cfg.train.min_epochs
153
+ self.epochs = self.cfg.train.max_epochs
154
+ self.validation_split = self.cfg.train.validation_split
155
+ self.beta = self.cfg.train.weights_beta
156
+ self.max_ratio = self.cfg.train.weights_max_ratio
157
+
158
+ # Tuning
159
+ self.tune = self.cfg.tune.enabled
160
+ self.tune_fast = self.cfg.tune.fast
161
+ self.tune_batch_size = self.cfg.tune.batch_size
162
+ self.tune_epochs = self.cfg.tune.epochs
163
+ self.tune_eval_interval = self.cfg.tune.eval_interval
164
+ self.tune_metric = self.cfg.tune.metric
165
+ self.n_trials = self.cfg.tune.n_trials
166
+ self.tune_save_db = self.cfg.tune.save_db
167
+ self.tune_resume = self.cfg.tune.resume
168
+ self.tune_max_samples = self.cfg.tune.max_samples
169
+ self.tune_max_loci = self.cfg.tune.max_loci
170
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 0) # AE unused
171
+ self.tune_patience = self.cfg.tune.patience
172
+
173
+ # Evaluate
174
+ # AE does not optimize latents, so these are unused / fixed
175
+ self.eval_latent_steps = 0
176
+ self.eval_latent_lr = 0.0
177
+ self.eval_latent_weight_decay = 0.0
178
+
179
+ # Plotting (parity with NLPCA PlotConfig)
180
+ self.plot_format = self.cfg.plot.fmt
181
+ self.plot_dpi = self.cfg.plot.dpi
182
+ self.plot_fontsize = self.cfg.plot.fontsize
183
+ self.title_fontsize = self.cfg.plot.fontsize
184
+ self.despine = self.cfg.plot.despine
185
+ self.show_plots = self.cfg.plot.show
186
+
187
+ # Core derived at fit-time
188
+ self.is_haploid: bool | None = None
189
+ self.num_classes_: int | None = None
190
+ self.model_params: Dict[str, Any] = {}
191
+
192
+ def fit(self) -> "ImputeAutoencoder":
193
+ """Fit the autoencoder on 0/1/2 encoded genotypes (missing → -9).
194
+
195
+ This method trains the autoencoder model on the provided genotype data. Genotypes are encoded as 0, 1, and 2, with missing values represented as -1. The data are split into training and validation sets, the model and training parameters are initialized, and training proceeds with optional hyperparameter tuning. After training, the model is evaluated on the validation set and the fitted model and training history are stored.
196
+
197
+ Returns:
198
+ ImputeAutoencoder: Fitted instance.
199
+
200
+ Raises:
201
+ RuntimeError: If training fails and no model is returned.
202
+ """
203
+ self.logger.info(f"Fitting {self.model_name} (0/1/2 AE) ...")
204
+
205
+ # --- Data prep (mirror NLPCA) ---
206
+ X = self.pgenc.genotypes_012.astype(np.float32)
207
+ X[X < 0] = np.nan
208
+ X[np.isnan(X)] = -1
209
+ self.ground_truth_ = X.astype(np.int64)
210
+
211
+ # Ploidy & classes
212
+ self.is_haploid = np.all(
213
+ np.isin(
214
+ self.genotype_data.snp_data,
215
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
216
+ )
217
+ )
218
+ self.ploidy = 1 if self.is_haploid else 2
219
+ self.num_classes_ = 2 if self.is_haploid else 3
220
+ self.logger.info(
221
+ f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
222
+ f"using {self.num_classes_} classes."
223
+ )
224
+
225
+ n_samples, self.num_features_ = X.shape
226
+
227
+ # Model params (decoder outputs L * K logits)
228
+ self.model_params = {
229
+ "n_features": self.num_features_,
230
+ "num_classes": self.num_classes_,
231
+ "latent_dim": self.latent_dim,
232
+ "dropout_rate": self.dropout_rate,
233
+ "activation": self.activation,
234
+ }
235
+
236
+ # Train/Val split
237
+ indices = np.arange(n_samples)
238
+ train_idx, val_idx = train_test_split(
239
+ indices, test_size=self.validation_split, random_state=self.seed
240
+ )
241
+ self.train_idx_, self.test_idx_ = train_idx, val_idx
242
+ self.X_train_ = self.ground_truth_[train_idx]
243
+ self.X_val_ = self.ground_truth_[val_idx]
244
+
245
+ # Plotters/scorers (shared utilities)
246
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
247
+
248
+ # Tuning (optional; AE never needs latent refinement)
249
+ if self.tune:
250
+ self.tune_hyperparameters()
251
+
252
+ # Best params (tuned or default)
253
+ self.best_params_ = getattr(self, "best_params_", self._default_best_params())
254
+
255
+ # Class weights (device-aware)
256
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_).to(
257
+ self.device
258
+ )
259
+
260
+ # DataLoader
261
+ train_loader = self._get_data_loaders(self.X_train_)
262
+
263
+ # Build & train
264
+ model = self.build_model(self.Model, self.best_params_)
265
+ model.apply(self.initialize_weights)
266
+
267
+ loss, trained_model, history = self._train_and_validate_model(
268
+ model=model,
269
+ loader=train_loader,
270
+ lr=self.learning_rate,
271
+ l1_penalty=self.l1_penalty,
272
+ return_history=True,
273
+ class_weights=self.class_weights_,
274
+ X_val=self.X_val_,
275
+ params=self.best_params_,
276
+ prune_metric=self.tune_metric,
277
+ prune_warmup_epochs=5,
278
+ eval_interval=1,
279
+ eval_requires_latents=False,
280
+ eval_latent_steps=0,
281
+ eval_latent_lr=0.0,
282
+ eval_latent_weight_decay=0.0,
283
+ )
284
+
285
+ if trained_model is None:
286
+ msg = "Autoencoder training failed; no model was returned."
287
+ self.logger.error(msg)
288
+ raise RuntimeError(msg)
289
+
290
+ torch.save(
291
+ trained_model.state_dict(),
292
+ self.models_dir / f"final_model_{self.model_name}.pt",
293
+ )
294
+
295
+ self.best_loss_, self.model_, self.history_ = (
296
+ loss,
297
+ trained_model,
298
+ {"Train": history},
299
+ )
300
+ self.is_fit_ = True
301
+
302
+ # Evaluate on validation set (parity with NLPCA reporting)
303
+ self._evaluate_model(self.X_val_, self.model_, self.best_params_)
304
+ self.plotter_.plot_history(self.history_)
305
+ self._save_best_params(self.best_params_)
306
+
307
+ return self
308
+
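A tiny illustration of the data preparation at the top of `fit()`: any negative sentinel or NaN entry collapses to the single missing marker -1 before training.

```python
import numpy as np

X = np.array([[0.0, 2.0, np.nan],
              [1.0, -9.0, 0.0]], dtype=np.float32)
X[X < 0] = np.nan      # sentinel negatives (e.g. -9) become NaN
X[np.isnan(X)] = -1    # all missing entries become -1
# X is now [[ 0.,  2., -1.],
#           [ 1., -1.,  0.]]
```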
309
+ def transform(self) -> np.ndarray:
310
+ """Impute missing genotypes (0/1/2) and return IUPAC strings.
311
+
312
+ This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.
313
+
314
+ Returns:
315
+ np.ndarray: IUPAC strings of shape (n_samples, n_loci).
316
+
317
+ Raises:
318
+ NotFittedError: If called before fit().
319
+ """
320
+ if not getattr(self, "is_fit_", False):
321
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
322
+
323
+ self.logger.info("Imputing entire dataset with AE (0/1/2)...")
324
+ X_to_impute = self.ground_truth_.copy()
325
+
326
+ # Predict with masked inputs (no latent optimization)
327
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
328
+
329
+ # Fill only missing
330
+ missing_mask = X_to_impute == -1
331
+ imputed_array = X_to_impute.copy()
332
+ imputed_array[missing_mask] = pred_labels[missing_mask]
333
+
334
+ # Decode to IUPAC & plot
335
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
336
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
337
+
338
+ plt.rcParams.update(self.plotter_.param_dict)
339
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
340
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
341
+
342
+ return imputed_genotypes
343
+
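The fill step in `transform()` touches only entries that were missing; a small sketch with hypothetical arrays:

```python
import numpy as np

X = np.array([[0, -1],
              [2,  1]])
pred = np.array([[0, 2],
                 [1, 1]])       # model predictions for every cell
mask = X == -1
imputed = X.copy()
imputed[mask] = pred[mask]      # only X[0, 1] changes -> [[0, 2], [2, 1]]
```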
344
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
345
+ """Create DataLoader over indices + integer targets (-1 for missing).
346
+
347
+ This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
348
+
349
+ Args:
350
+ y (np.ndarray): 0/1/2 matrix with -1 for missing.
351
+
352
+ Returns:
353
+ torch.utils.data.DataLoader: Shuffled DataLoader.
354
+ """
355
+ y_tensor = torch.from_numpy(y).long().to(self.device)
356
+ dataset = torch.utils.data.TensorDataset(
357
+ torch.arange(len(y), device=self.device), y_tensor
358
+ )
359
+ return torch.utils.data.DataLoader(
360
+ dataset, batch_size=self.batch_size, shuffle=True
361
+ )
362
+
363
+ def _train_and_validate_model(
364
+ self,
365
+ model: torch.nn.Module,
366
+ loader: torch.utils.data.DataLoader,
367
+ lr: float,
368
+ l1_penalty: float,
369
+ trial: optuna.Trial | None = None,
370
+ return_history: bool = False,
371
+ class_weights: torch.Tensor | None = None,
372
+ *,
373
+ X_val: np.ndarray | None = None,
374
+ params: dict | None = None,
375
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
376
+ prune_warmup_epochs: int = 3,
377
+ eval_interval: int = 1,
378
+ # Evaluation parameters (AE ignores latent refinement knobs)
379
+ eval_requires_latents: bool = False, # AE: always False
380
+ eval_latent_steps: int = 0,
381
+ eval_latent_lr: float = 0.0,
382
+ eval_latent_weight_decay: float = 0.0,
383
+ ) -> Tuple[float, torch.nn.Module | None, list | None]:
384
+ """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
385
+
386
+ This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
387
+
388
+ Args:
389
+ model (torch.nn.Module): Autoencoder model.
390
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
391
+ lr (float): Learning rate.
392
+ l1_penalty (float): L1 regularization coeff.
393
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
394
+ return_history (bool): If True, return train loss history.
395
+ class_weights (torch.Tensor | None): Class weights tensor (on device).
396
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
397
+ params (dict | None): Model params for evaluation.
398
+ prune_metric (str | None): Metric for pruning reports.
399
+ prune_warmup_epochs (int): Pruning warmup epochs.
400
+ eval_interval (int): Eval frequency (epochs).
401
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
402
+ eval_latent_steps (int): Unused for AE.
403
+ eval_latent_lr (float): Unused for AE.
404
+ eval_latent_weight_decay (float): Unused for AE.
405
+
406
+ Returns:
407
+ Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
408
+ """
409
+ if class_weights is None:
410
+ msg = "Must provide class_weights."
411
+ self.logger.error(msg)
412
+ raise TypeError(msg)
413
+
414
+ # Epoch budget mirrors NLPCA config (tuning vs final)
415
+ max_epochs = (
416
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
417
+ )
418
+
419
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
420
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
421
+
422
+ best_loss, best_model, hist = self._execute_training_loop(
423
+ loader=loader,
424
+ optimizer=optimizer,
425
+ scheduler=scheduler,
426
+ model=model,
427
+ l1_penalty=l1_penalty,
428
+ trial=trial,
429
+ return_history=return_history,
430
+ class_weights=class_weights,
431
+ X_val=X_val,
432
+ params=params,
433
+ prune_metric=prune_metric,
434
+ prune_warmup_epochs=prune_warmup_epochs,
435
+ eval_interval=eval_interval,
436
+ eval_requires_latents=False, # AE: no latent inference
437
+ eval_latent_steps=0,
438
+ eval_latent_lr=0.0,
439
+ eval_latent_weight_decay=0.0,
440
+ )
441
+ if return_history:
442
+ return best_loss, best_model, hist
443
+
444
+ return best_loss, best_model, None
445
+
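`_train_and_validate_model` pairs Adam with a cosine-annealed learning rate whose period equals the epoch budget. A standalone sketch of that scheduler behavior, using a hypothetical tiny model:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(optimizer, T_max=100)  # decays lr over 100 epochs

for _ in range(100):
    optimizer.step()   # (actual training step omitted)
    scheduler.step()

print(optimizer.param_groups[0]["lr"])  # ~0.0 (eta_min defaults to 0)
```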
446
+ def _execute_training_loop(
447
+ self,
448
+ loader: torch.utils.data.DataLoader,
449
+ optimizer: torch.optim.Optimizer,
450
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
451
+ model: torch.nn.Module,
452
+ l1_penalty: float,
453
+ trial: optuna.Trial | None,
454
+ return_history: bool,
455
+ class_weights: torch.Tensor,
456
+ *,
457
+ X_val: np.ndarray | None = None,
458
+ params: dict | None = None,
459
+ prune_metric: str | None = None,
460
+ prune_warmup_epochs: int = 3,
461
+ eval_interval: int = 1,
462
+ # Evaluation parameters (AE ignores latent refinement knobs)
463
+ eval_requires_latents: bool = False, # AE: False
464
+ eval_latent_steps: int = 0,
465
+ eval_latent_lr: float = 0.0,
466
+ eval_latent_weight_decay: float = 0.0,
467
+ ) -> Tuple[float, torch.nn.Module, list]:
468
+ """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
469
+
470
+ This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
471
+
472
+ Args:
473
+ loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
474
+ optimizer (torch.optim.Optimizer): Optimizer.
475
+ scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
476
+ model (torch.nn.Module): Autoencoder model.
477
+ l1_penalty (float): L1 regularization coeff.
478
+ trial (optuna.Trial | None): Optuna trial for pruning (optional).
479
+ return_history (bool): If True, return train loss history.
480
+ class_weights (torch.Tensor): Class weights tensor (on device).
481
+ X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
482
+ params (dict | None): Model params for evaluation.
483
+ prune_metric (str | None): Metric for pruning reports.
484
+ prune_warmup_epochs (int): Pruning warmup epochs.
485
+ eval_interval (int): Eval frequency (epochs).
486
+ eval_requires_latents (bool): Ignored for AE (no latent inference).
487
+ eval_latent_steps (int): Unused for AE.
488
+ eval_latent_lr (float): Unused for AE.
489
+ eval_latent_weight_decay (float): Unused for AE.
490
+
491
+ Returns:
492
+ Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
493
+ """
494
+ best_loss = float("inf")
495
+ best_model = None
496
+ history: list[float] = []
497
+
498
+ early_stopping = EarlyStopping(
499
+ patience=self.early_stop_gen,
500
+ min_epochs=self.min_epochs,
501
+ verbose=self.verbose,
502
+ prefix=self.prefix,
503
+ debug=self.debug,
504
+ )
505
+
506
+ # Parity with NLPCA (warm/ramp gamma schedule)
507
+ warm, ramp, gamma_final = 50, 100, self.gamma
508
+
509
+ # Epoch budget mirrors the caller's scheduler T_max
510
+ # (already set to tune_epochs or epochs).
511
+ for epoch in range(scheduler.T_max):
512
+ # Gamma schedule
513
+ if epoch < warm:
514
+ model.gamma = 0.0
515
+ elif epoch < warm + ramp:
516
+ model.gamma = gamma_final * ((epoch - warm) / ramp)
517
+ else:
518
+ model.gamma = gamma_final
519
+
520
+ # ---- one epoch ----
521
+ train_loss = self._train_step(
522
+ loader=loader,
523
+ optimizer=optimizer,
524
+ model=model,
525
+ l1_penalty=l1_penalty,
526
+ class_weights=class_weights,
527
+ )
528
+
529
+ if trial and (np.isnan(train_loss) or np.isinf(train_loss)):
530
+ raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
531
+
532
+ scheduler.step()
533
+ if return_history:
534
+ history.append(train_loss)
535
+
536
+ early_stopping(train_loss, model)
537
+ if early_stopping.early_stop:
538
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
539
+ break
540
+
541
+ # Optuna report/prune on validation metric
542
+ if (
543
+ trial is not None
544
+ and X_val is not None
545
+ and ((epoch + 1) % eval_interval == 0)
546
+ ):
547
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
548
+ metric_val = self._eval_for_pruning(
549
+ model=model,
550
+ X_val=X_val,
551
+ params=params or getattr(self, "best_params_", {}),
552
+ metric=metric_key,
553
+ objective_mode=True,
554
+ do_latent_infer=False, # AE: False
555
+ latent_steps=0,
556
+ latent_lr=0.0,
557
+ latent_weight_decay=0.0,
558
+ latent_seed=(self.seed if self.seed is not None else 123),
559
+ _latent_cache=None, # AE: not used
560
+ _latent_cache_key=None,
561
+ )
562
+ trial.report(metric_val, step=epoch + 1)
563
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
564
+ raise optuna.exceptions.TrialPruned(
565
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
566
+ )
567
+
568
+ best_loss = early_stopping.best_score
569
+ best_model = copy.deepcopy(early_stopping.best_model)
570
+ return best_loss, best_model, history
571
+
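The focal-loss exponent follows the warm/ramp schedule hard-coded above (warm=50, ramp=100 epochs). Expressed as a standalone helper for clarity (a sketch, not part of the class):

```python
def gamma_at(epoch: int, gamma_final: float, warm: int = 50, ramp: int = 100) -> float:
    """Gamma is 0 during warm-up, rises linearly over the ramp, then stays at gamma_final."""
    if epoch < warm:
        return 0.0
    if epoch < warm + ramp:
        return gamma_final * ((epoch - warm) / ramp)
    return gamma_final

# With gamma_final=2.0: epoch 10 -> 0.0, epoch 100 -> 1.0, epoch 200 -> 2.0
```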
572
+ def _train_step(
573
+ self,
574
+ loader: torch.utils.data.DataLoader,
575
+ optimizer: torch.optim.Optimizer,
576
+ model: torch.nn.Module,
577
+ l1_penalty: float,
578
+ class_weights: torch.Tensor,
579
+ ) -> float:
580
+ """One epoch (indices, y_int) → one-hot inputs → logits → masked focal CE.
581
+
582
+ This method performs a single training epoch, processing batches of data from the DataLoader. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified.
583
+
584
+ Args:
585
+ loader (DataLoader): Yields (indices, y_int) where y_int is 0/1/2, -1 for missing.
586
+ optimizer (torch.optim.Optimizer): Optimizer.
587
+ model (torch.nn.Module): Autoencoder model.
588
+ l1_penalty (float): L1 regularization.
589
+ class_weights (torch.Tensor): Class weights for CE.
590
+
591
+ Returns:
592
+ float: Mean training loss for the epoch.
593
+ """
594
+ model.train()
595
+ running = 0.0
596
+
597
+ for _, y_batch in loader:
598
+ optimizer.zero_grad(set_to_none=True)
599
+
600
+ # Inputs: one-hot with zeros for missing; Targets: ints with -1
601
+ x_ohe = self._one_hot_encode_012(y_batch) # (B, L, K)
602
+ logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
603
+
604
+ logits_flat = logits.view(-1, self.num_classes_)
605
+ targets_flat = y_batch.view(-1)
606
+
607
+ ce = F.cross_entropy(
608
+ logits_flat,
609
+ targets_flat,
610
+ weight=class_weights,
611
+ reduction="none",
612
+ ignore_index=-1,
613
+ )
614
+ pt = torch.exp(-ce)
615
+ gamma = getattr(model, "gamma", self.gamma)
616
+ focal = ((1 - pt) ** gamma) * ce
617
+
618
+ valid_mask = targets_flat != -1
619
+ loss = (
620
+ focal[valid_mask].mean()
621
+ if valid_mask.any()
622
+ else torch.tensor(0.0, device=logits.device)
623
+ )
624
+
625
+ if l1_penalty > 0:
626
+ loss = loss + l1_penalty * sum(
627
+ p.abs().sum() for p in model.parameters()
628
+ )
629
+
630
+ loss.backward()
631
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
632
+ optimizer.step()
633
+
634
+ running += float(loss.item())
635
+
636
+ return running / len(loader)
637
+
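The per-batch loss in `_train_step` is a masked focal cross-entropy. A self-contained sketch with hypothetical shapes (flattened batch-by-locus positions, K=3 classes, gamma=2.0):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(6, 3)                    # flattened (batch * loci, classes)
targets = torch.tensor([0, 2, -1, 1, 1, -1])  # -1 marks missing genotypes

ce = F.cross_entropy(logits, targets, reduction="none", ignore_index=-1)
pt = torch.exp(-ce)                           # probability of the true class
focal = ((1 - pt) ** 2.0) * ce                # down-weight easy examples
valid = targets != -1
loss = focal[valid].mean()                    # average over observed genotypes only
```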
638
+ def _predict(
639
+ self,
640
+ model: torch.nn.Module,
641
+ X: np.ndarray | torch.Tensor,
642
+ return_proba: bool = False,
643
+ ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
644
+ """Predict 0/1/2 labels (and probabilities) from masked inputs.
645
+
646
+ This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
647
+
648
+ Args:
649
+ model (torch.nn.Module): Trained model.
650
+ X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
651
+ for missing.
652
+ return_proba (bool): If True, return probabilities.
653
+
654
+ Returns:
655
+ Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
656
+ and probabilities if requested.
657
+ """
658
+ if model is None:
659
+ msg = "Model is not trained. Call fit() before predict()."
660
+ self.logger.error(msg)
661
+ raise NotFittedError(msg)
662
+
663
+ model.eval()
664
+ with torch.no_grad():
665
+ X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
666
+ X_tensor = X_tensor.to(self.device).long()
667
+ x_ohe = self._one_hot_encode_012(X_tensor)
668
+ logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
669
+ probas = torch.softmax(logits, dim=-1)
670
+ labels = torch.argmax(probas, dim=-1)
671
+
672
+ if return_proba:
673
+ return labels.cpu().numpy(), probas.cpu().numpy()
674
+
675
+ return labels.cpu().numpy()
676
+
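`_one_hot_encode_012` is inherited from `BaseNNImputer` and is not shown in this diff. A hedged sketch of the encoding it is described as producing (one-hot inputs with all-zero rows for missing entries); the real helper may differ in detail:

```python
import torch
import torch.nn.functional as F

y = torch.tensor([[0, 2, -1],
                  [1, -1, 0]])                           # (B, L) with -1 = missing
ohe = F.one_hot(y.clamp(min=0), num_classes=3).float()   # (B, L, 3)
ohe[y == -1] = 0.0                                       # zero out missing loci
```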
677
+ def _evaluate_model(
678
+ self,
679
+ X_val: np.ndarray,
680
+ model: torch.nn.Module,
681
+ params: dict,
682
+ objective_mode: bool = False,
683
+ latent_vectors_val: Optional[np.ndarray] = None,
684
+ ) -> Dict[str, float]:
685
+ """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
686
+
687
+ This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
688
+
689
+ Args:
690
+ X_val (np.ndarray): Validation set 0/1/2 matrix with -1
691
+ for missing.
692
+ model (torch.nn.Module): Trained model.
693
+ params (dict): Model parameters.
694
+ objective_mode (bool): If True, suppress logging and reports.
+ latent_vectors_val (Optional[np.ndarray]): Not used by the AE; accepted for signature parity with the latent-based imputers.
695
+
696
+ Returns:
697
+ Dict[str, float]: Dictionary of evaluation metrics.
698
+ """
699
+ pred_labels, pred_probas = self._predict(
700
+ model=model, X=X_val, return_proba=True
701
+ )
702
+
703
+ # mask out true missing AND any non-finite prob rows
704
+ finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N,L)
705
+ eval_mask = (X_val != -1) & finite_mask
706
+
707
+ y_true_flat = X_val[eval_mask].astype(np.int64, copy=False)
708
+ y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
709
+ y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
710
+
711
+ if y_true_flat.size == 0:
712
+ return {self.tune_metric: 0.0}
713
+
714
+ # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
715
+ y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
716
+ row_sums = y_proba_flat.sum(axis=1, keepdims=True)
717
+ row_sums[row_sums == 0] = 1.0
718
+ y_proba_flat = y_proba_flat / row_sums
719
+
720
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
721
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
722
+
723
+ if self.is_haploid:
724
+ y_true_flat = y_true_flat.copy()
725
+ y_pred_flat = y_pred_flat.copy()
726
+ y_true_flat[y_true_flat == 2] = 1
727
+ y_pred_flat[y_pred_flat == 2] = 1
728
+ # collapse probs to 2-class
729
+ proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
730
+ proba_2[:, 0] = y_proba_flat[:, 0]
731
+ proba_2[:, 1] = y_proba_flat[:, 2]
732
+ y_proba_flat = proba_2
733
+
734
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
735
+
736
+ metrics = self.scorers_.evaluate(
737
+ y_true_flat,
738
+ y_pred_flat,
739
+ y_true_ohe,
740
+ y_proba_flat,
741
+ objective_mode,
742
+ self.tune_metric,
743
+ )
744
+
745
+ if not objective_mode:
746
+ self.logger.info(f"Validation Metrics: {metrics}")
747
+
748
+ # Primary report (REF/HET/ALT or REF/ALT)
749
+ self._make_class_reports(
750
+ y_true=y_true_flat,
751
+ y_pred_proba=y_proba_flat,
752
+ y_pred=y_pred_flat,
753
+ metrics=metrics,
754
+ labels=target_names,
755
+ )
756
+
757
+ # IUPAC decode & 10-base integer report (parity with ImputeNLPCA)
758
+ y_true_dec = self.pgenc.decode_012(X_val)
759
+ X_pred = X_val.copy()
760
+ X_pred[eval_mask] = y_pred_flat
761
+ y_pred_dec = self.pgenc.decode_012(
762
+ X_pred.reshape(X_val.shape[0], self.num_features_)
763
+ )
764
+
765
+ encodings_dict = {
766
+ "A": 0,
767
+ "C": 1,
768
+ "G": 2,
769
+ "T": 3,
770
+ "W": 4,
771
+ "R": 5,
772
+ "M": 6,
773
+ "K": 7,
774
+ "Y": 8,
775
+ "S": 9,
776
+ "N": -1,
777
+ }
778
+ y_true_int = self.pgenc.convert_int_iupac(
779
+ y_true_dec, encodings_dict=encodings_dict
780
+ )
781
+ y_pred_int = self.pgenc.convert_int_iupac(
782
+ y_pred_dec, encodings_dict=encodings_dict
783
+ )
784
+
785
+ self._make_class_reports(
786
+ y_true=y_true_int[eval_mask],
787
+ y_pred=y_pred_int[eval_mask],
788
+ metrics=metrics,
789
+ y_pred_proba=None,
790
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
791
+ )
792
+
793
+ return metrics
794
+
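Before scoring, `_evaluate_model` clips the predicted probabilities to [0, 1] and renormalizes each row to sum to 1; rows that sum to zero are divided by 1 to avoid NaNs. The same clean-up in isolation:

```python
import numpy as np

proba = np.array([[0.2, 0.5, 0.3],
                  [0.0, 0.0, 0.0]])     # second row degenerate
proba = np.clip(proba, 0.0, 1.0)
row_sums = proba.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1.0           # avoid division by zero
proba = proba / row_sums                # rows now sum to 1 (or stay all-zero)
```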
795
+ def _objective(self, trial: optuna.Trial) -> float:
796
+ """Optuna objective for AE; mirrors NLPCA study driver without latents.
797
+
798
+ This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
799
+
800
+ Args:
801
+ trial (optuna.Trial): Optuna trial.
802
+
803
+ Returns:
804
+ float: Value of the tuning metric (maximize).
805
+ """
806
+ try:
807
+ # Sample hyperparameters (existing helper; unchanged signature)
808
+ params = self._sample_hyperparameters(trial)
809
+
810
+ # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
811
+ X_train = self.ground_truth_[self.train_idx_]
812
+ X_val = self.ground_truth_[self.test_idx_]
813
+
814
+ class_weights = self._class_weights_from_zygosity(X_train).to(self.device)
815
+ train_loader = self._get_data_loaders(X_train)
816
+
817
+ model = self.build_model(self.Model, params["model_params"])
818
+ model.apply(self.initialize_weights)
819
+
820
+ # Train + prune on metric
821
+ _, model, _ = self._train_and_validate_model(
822
+ model=model,
823
+ loader=train_loader,
824
+ lr=params["lr"],
825
+ l1_penalty=params["l1_penalty"],
826
+ trial=trial,
827
+ return_history=False,
828
+ class_weights=class_weights,
829
+ X_val=X_val,
830
+ params=params,
831
+ prune_metric=self.tune_metric,
832
+ prune_warmup_epochs=5,
833
+ eval_interval=self.tune_eval_interval,
834
+ eval_requires_latents=False,
835
+ eval_latent_steps=0,
836
+ eval_latent_lr=0.0,
837
+ eval_latent_weight_decay=0.0,
838
+ )
839
+
840
+ metrics = self._evaluate_model(X_val, model, params, objective_mode=True)
841
+ self._clear_resources(model, train_loader)
842
+ return metrics[self.tune_metric]
843
+
844
+ except Exception as e:
845
+ # Keep sweeps moving if a trial fails
846
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
847
+
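`_objective` returns the tuning metric to be maximized and relies on pruning to discard failed or unpromising trials. The study driver itself lives in `tune_hyperparameters()` (inherited from `BaseNNImputer` and not shown in this diff); a hedged sketch of how such a driver could call it:

```python
import optuna

# Hypothetical driver; the real tune_hyperparameters() may configure the
# sampler, pruner, storage, and resume behavior differently.
study = optuna.create_study(direction="maximize")
study.optimize(imputer._objective, n_trials=imputer.n_trials)
print(study.best_trial.params)
```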
848
+ def _sample_hyperparameters(
849
+ self, trial: optuna.Trial
850
+ ) -> Dict[str, int | float | str]:
851
+ """Sample AE hyperparameters and compute hidden sizes for model params.
852
+
853
+ This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
854
+
855
+ Args:
856
+ trial (optuna.Trial): Optuna trial object.
857
+
858
+ Returns:
859
+ Dict[str, int | float | str]: Sampled hyperparameters and model_params.
860
+ """
861
+ params = {
862
+ "latent_dim": trial.suggest_int("latent_dim", 2, 64),
863
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
864
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
865
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
866
+ "activation": trial.suggest_categorical(
867
+ "activation", ["relu", "elu", "selu"]
868
+ ),
869
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
870
+ "layer_scaling_factor": trial.suggest_float(
871
+ "layer_scaling_factor", 2.0, 10.0
872
+ ),
873
+ "layer_schedule": trial.suggest_categorical(
874
+ "layer_schedule", ["pyramid", "constant", "linear"]
875
+ ),
876
+ }
877
+
878
+ input_dim = self.num_features_ * self.num_classes_
879
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
880
+ n_inputs=input_dim,
881
+ n_outputs=input_dim,
882
+ n_samples=len(self.train_idx_),
883
+ n_hidden=params["num_hidden_layers"],
884
+ alpha=params["layer_scaling_factor"],
885
+ schedule=params["layer_schedule"],
886
+ )
887
+
888
+ # Keep the latent_dim as the first element,
889
+ # then the interior hidden widths.
890
+ # If there are no interior widths (very small nets),
891
+ # this still leaves [latent_dim].
892
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
893
+
894
+ params["model_params"] = {
895
+ "n_features": self.num_features_,
896
+ "num_classes": self.num_classes_,
897
+ "latent_dim": params["latent_dim"],
898
+ "dropout_rate": params["dropout_rate"],
899
+ "hidden_layer_sizes": hidden_only,
900
+ "activation": params["activation"],
901
+ }
902
+ return params
903
+
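The `hidden_only` slice above keeps every width returned by `_compute_hidden_layer_sizes` except the last one. With a hypothetical helper output:

```python
hidden_layer_sizes = [16, 256, 128, 64]   # hypothetical output of the helper
hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
# -> [16, 256, 128]; the final width is dropped.
```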
904
+ def _set_best_params(
905
+ self, best_params: Dict[str, int | float | str | list]
906
+ ) -> Dict[str, int | float | str | list]:
907
+ """Adopt best params (ImputeNLPCA parity) and return model_params.
908
+
909
+ This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
910
+
911
+ Args:
912
+ best_params (Dict[str, int | float | str | list]): Best hyperparameters from tuning.
913
+
914
+ Returns:
915
+ Dict[str, int | float | str | list]: Model parameters for building the model.
916
+ """
917
+ self.latent_dim = best_params["latent_dim"]
918
+ self.dropout_rate = best_params["dropout_rate"]
919
+ self.learning_rate = best_params["learning_rate"]
920
+ self.l1_penalty = best_params["l1_penalty"]
921
+ self.activation = best_params["activation"]
922
+ self.layer_scaling_factor = best_params["layer_scaling_factor"]
923
+ self.layer_schedule = best_params["layer_schedule"]
924
+
925
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
926
+ n_inputs=self.num_features_ * self.num_classes_,
927
+ n_outputs=self.num_features_ * self.num_classes_,
928
+ n_samples=len(self.train_idx_),
929
+ n_hidden=best_params["num_hidden_layers"],
930
+ alpha=best_params["layer_scaling_factor"],
931
+ schedule=best_params["layer_schedule"],
932
+ )
933
+
934
+ # Keep the latent_dim as the first element,
935
+ # then the interior hidden widths.
936
+ # If there are no interior widths (very small nets),
937
+ # this still leaves [latent_dim].
938
+ hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
939
+
940
+ return {
941
+ "n_features": self.num_features_,
942
+ "latent_dim": self.latent_dim,
943
+ "hidden_layer_sizes": hidden_only,
944
+ "dropout_rate": self.dropout_rate,
945
+ "activation": self.activation,
946
+ "num_classes": self.num_classes_,
947
+ }
948
+
949
+ def _default_best_params(self) -> Dict[str, int | float | str | list]:
950
+ """Default model params when tuning is disabled.
951
+
952
+ This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
953
+
954
+ Returns:
955
+ Dict[str, int | float | str | list]: Default model parameters.
956
+ """
957
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
958
+ n_inputs=self.num_features_ * self.num_classes_,
959
+ n_outputs=self.num_features_ * self.num_classes_,
960
+ n_samples=len(self.ground_truth_),
961
+ n_hidden=self.num_hidden_layers,
962
+ alpha=self.layer_scaling_factor,
963
+ schedule=self.layer_schedule,
964
+ )
965
+ return {
966
+ "n_features": self.num_features_,
967
+ "latent_dim": self.latent_dim,
968
+ "hidden_layer_sizes": hidden_layer_sizes,
969
+ "dropout_rate": self.dropout_rate,
970
+ "activation": self.activation,
971
+ "num_classes": self.num_classes_,
972
+ }