pg-sui 1.6.14.dev9__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. pg_sui-1.7.0.dist-info/METADATA +288 -0
  2. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +591 -126
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +218 -533
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1268 -530
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +869 -764
  16. pgsui/impute/unsupervised/imputers/vae.py +928 -696
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pg_sui-1.6.14.dev9.dist-info/METADATA +0 -344
  27. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1554
  28. pgsui/impute/unsupervised/imputers/ubp.py +0 -1575
  29. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  30. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  31. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  32. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  33. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  34. {pg_sui-1.6.14.dev9.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,22 +1,25 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
1
4
  import copy
2
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
5
+ import traceback
6
+ from collections import defaultdict
7
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
3
8
 
4
9
  import matplotlib.pyplot as plt
5
10
  import numpy as np
6
11
  import optuna
7
12
  import torch
8
13
  from sklearn.exceptions import NotFittedError
9
- from sklearn.model_selection import train_test_split
10
14
  from snpio.analysis.genotype_encoder import GenotypeEncoder
11
15
  from snpio.utils.logging import LoggerManager
12
16
  from torch.optim.lr_scheduler import CosineAnnealingLR
13
17
 
14
18
  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
15
19
  from pgsui.data_processing.containers import AutoencoderConfig
16
- from pgsui.data_processing.transformers import SimMissingTransformer
17
20
  from pgsui.impute.unsupervised.base import BaseNNImputer
18
21
  from pgsui.impute.unsupervised.callbacks import EarlyStopping
19
- from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
22
+ from pgsui.impute.unsupervised.loss_functions import FocalCELoss
20
23
  from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
21
24
  from pgsui.utils.logging_utils import configure_logger
22
25
  from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -26,30 +29,72 @@ if TYPE_CHECKING:
26
29
  from snpio.read_input.genotype_data import GenotypeData
27
30
 
28
31
 
32
+ def _make_warmup_cosine_scheduler(
33
+ optimizer: torch.optim.Optimizer,
34
+ *,
35
+ max_epochs: int,
36
+ warmup_epochs: int,
37
+ start_factor: float = 0.1,
38
+ ) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
39
+ """Create a warmup->cosine LR scheduler.
40
+
41
+ Args:
42
+ optimizer: Optimizer to schedule.
43
+ max_epochs: Total number of epochs.
44
+ warmup_epochs: Number of warmup epochs.
45
+ start_factor: Starting LR factor for warmup.
46
+
47
+ Returns:
48
+ torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: LR scheduler (SequentialLR if warmup_epochs > 0 else CosineAnnealingLR).
49
+ """
50
+ warmup_epochs = int(max(0, warmup_epochs))
51
+
52
+ if warmup_epochs == 0:
53
+ return CosineAnnealingLR(optimizer, T_max=max_epochs)
54
+
55
+ warmup = torch.optim.lr_scheduler.LinearLR(
56
+ optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
57
+ )
58
+ cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
59
+
60
+ return torch.optim.lr_scheduler.SequentialLR(
61
+ optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
62
+ )
63
+
64
+
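
For reference, the warmup-to-cosine composition used above can be exercised standalone. The sketch below is not part of the package; it uses a throwaway SGD optimizer and placeholder epoch counts, and mirrors the same LinearLR → CosineAnnealingLR chaining via SequentialLR:

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# Throwaway parameter and optimizer purely to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)

max_epochs, warmup_epochs = 20, 2
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_epochs)
cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

for epoch in range(max_epochs):
    optimizer.step()                       # one training epoch would run here
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # linear ramp to 1e-3, then cosine decay
```
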
29
65
  def ensure_autoencoder_config(
30
66
  config: AutoencoderConfig | dict | str | None,
31
67
  ) -> AutoencoderConfig:
32
68
  """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
33
69
 
34
- This method normalizes the configuration input for the Autoencoder imputer. It accepts a structured configuration in various formats, including a dataclass instance, a nested dictionary, a YAML file path, or None. The method processes the input accordingly and returns a concrete instance of AutoencoderConfig with all necessary fields populated.
70
+ Notes:
71
+ - Supports top-level preset, or io.preset inside dict/YAML.
72
+ - Does not mutate user-provided dict (deep-copies before processing).
73
+ - Flattens nested dicts into dot-keys and applies them as overrides.
35
74
 
36
75
  Args:
37
- config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
76
+ config: AutoencoderConfig instance, dict, YAML path, or None.
38
77
 
39
78
  Returns:
40
- AutoencoderConfig: Concrete configuration instance.
79
+ Concrete AutoencoderConfig.
41
80
  """
42
81
  if config is None:
43
82
  return AutoencoderConfig()
44
83
  if isinstance(config, AutoencoderConfig):
45
84
  return config
46
85
  if isinstance(config, str):
47
- # YAML path — top-level `preset` key is supported
48
86
  return load_yaml_to_dataclass(config, AutoencoderConfig)
49
87
  if isinstance(config, dict):
50
- # Flatten dict into dot-keys then overlay onto a fresh instance
88
+ cfg_in = copy.deepcopy(config)
51
89
  base = AutoencoderConfig()
52
90
 
91
+ preset = cfg_in.pop("preset", None)
92
+ if "io" in cfg_in and isinstance(cfg_in["io"], dict):
93
+ preset = preset or cfg_in["io"].pop("preset", None)
94
+
95
+ if preset:
96
+ base = AutoencoderConfig.from_preset(preset)
97
+
53
98
  def _flatten(prefix: str, d: dict, out: dict) -> dict:
54
99
  for k, v in d.items():
55
100
  kk = f"{prefix}.{k}" if prefix else k
@@ -59,26 +104,24 @@ def ensure_autoencoder_config(
59
104
  out[kk] = v
60
105
  return out
61
106
 
62
- # Lift any present preset first
63
- preset_name = config.pop("preset", None)
64
- if "io" in config and isinstance(config["io"], dict):
65
- preset_name = preset_name or config["io"].pop("preset", None)
66
-
67
- if preset_name:
68
- base = AutoencoderConfig.from_preset(preset_name)
69
-
70
- flat = _flatten("", config, {})
107
+ flat = _flatten("", cfg_in, {})
71
108
  return apply_dot_overrides(base, flat)
72
109
 
73
110
  raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
74
111
 
75
112
 
76
113
  class ImputeAutoencoder(BaseNNImputer):
77
- """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.
114
+ """Autoencoder imputer for 0/1/2 genotypes.
78
115
 
79
- This imputer uses a feedforward autoencoder architecture to learn compressed and reconstructive representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate). Missing genotypes are represented as -1 during training and imputation.
116
+ Trains a feedforward autoencoder on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. Missingness is simulated once on the full matrix, then train/val/test splits reuse those masks. It supports haploid and diploid data, focal-CE reconstruction loss (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
80
117
 
81
- The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
118
+ Notes:
119
+ - Simulates missingness once on the full 0/1/2 matrix, then splits indices on clean ground truth.
120
+ - Maintains clean targets and corrupted inputs per train/val/test, plus per-split masks.
121
+ - Haploid harmonization happens after the single simulation (no re-simulation).
122
+ - Training/validation loss is computed only where targets are known (~orig_mask_*).
123
+ - Evaluation is computed only on simulated-missing sites (sim_mask_*).
124
+ - ``transform()`` fills only originally missing sites and hard-errors if decoding yields "N".
82
125
  """
83
126
 
84
127
  def __init__(
@@ -87,8 +130,7 @@ class ImputeAutoencoder(BaseNNImputer):
87
130
  *,
88
131
  tree_parser: Optional["TreeParser"] = None,
89
132
  config: Optional[Union["AutoencoderConfig", dict, str]] = None,
90
- overrides: dict | None = None,
91
- simulate_missing: bool | None = None,
133
+ overrides: Optional[dict] = None,
92
134
  sim_strategy: (
93
135
  Literal[
94
136
  "random",
@@ -99,34 +141,29 @@ class ImputeAutoencoder(BaseNNImputer):
99
141
  ]
100
142
  | None
101
143
  ) = None,
102
- sim_prop: float | None = None,
103
- sim_kwargs: dict | None = None,
144
+ sim_prop: Optional[float] = None,
145
+ sim_kwargs: Optional[dict] = None,
104
146
  ) -> None:
105
147
  """Initialize the Autoencoder imputer with a unified config interface.
106
148
 
107
- This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
108
-
109
149
  Args:
110
- genotype_data ("GenotypeData"): Backing genotype data object.
111
- tree_parser (Optional["TreeParser"]): Optional SNPio phylogenetic tree parser for population-specific modes.
112
- config (Union["AutoencoderConfig", dict, str] | None): Structured configuration as dataclass, nested dict, YAML path, or None.
113
- overrides (dict | None): Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
114
- simulate_missing (bool | None): Whether to simulate missing data during evaluation. If None, uses config default.
115
- sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
116
- sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
117
- sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
150
+ genotype_data (GenotypeData): Backing genotype data object.
151
+ tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
152
+ config (Optional[Union[AutoencoderConfig, dict, str]]): AutoencoderConfig, nested dict, YAML path, or None.
153
+ overrides (Optional[dict]): Optional dot-key overrides with highest precedence.
154
+ sim_strategy (Optional[Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]]): Override sim strategy; if None, uses config default.
155
+ sim_prop (Optional[float]): Override simulated missing proportion; if None, uses config default.
156
+ sim_kwargs (Optional[dict]): Override/extend simulated missing kwargs; if None, uses config default.
118
157
  """
119
158
  self.model_name = "ImputeAutoencoder"
120
159
  self.genotype_data = genotype_data
121
160
  self.tree_parser = tree_parser
122
161
 
123
- # Normalize config then apply highest-precedence overrides
124
162
  cfg = ensure_autoencoder_config(config)
125
163
  if overrides:
126
164
  cfg = apply_dot_overrides(cfg, overrides)
127
165
  self.cfg = cfg
128
166
 
129
- # Logger consistent with NLPCA
130
167
  logman = LoggerManager(
131
168
  __name__,
132
169
  prefix=self.cfg.io.prefix,
@@ -138,8 +175,8 @@ class ImputeAutoencoder(BaseNNImputer):
138
175
  verbose=self.cfg.io.verbose,
139
176
  debug=self.cfg.io.debug,
140
177
  )
178
+ self.logger.propagate = False
141
179
 
142
- # BaseNNImputer bootstrapping (device/dirs/logging handled here)
143
180
  super().__init__(
144
181
  model_name=self.model_name,
145
182
  genotype_data=self.genotype_data,
@@ -150,11 +187,9 @@ class ImputeAutoencoder(BaseNNImputer):
150
187
  )
151
188
 
152
189
  self.Model = AutoencoderModel
153
-
154
- # Model hook & encoder
155
190
  self.pgenc = GenotypeEncoder(genotype_data)
156
191
 
157
- # IO / global
192
+ # I/O and global
158
193
  self.seed = self.cfg.io.seed
159
194
  self.n_jobs = self.cfg.io.n_jobs
160
195
  self.prefix = self.cfg.io.prefix
@@ -163,186 +198,150 @@ class ImputeAutoencoder(BaseNNImputer):
163
198
  self.debug = self.cfg.io.debug
164
199
  self.rng = np.random.default_rng(self.seed)
165
200
 
166
- # Simulated-missing controls (config defaults with ctor overrides)
201
+ # Simulation controls (match VAE pattern)
167
202
  sim_cfg = getattr(self.cfg, "sim", None)
168
203
  sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
169
204
  if sim_kwargs:
170
205
  sim_cfg_kwargs.update(sim_kwargs)
171
- self.simulate_missing = (
172
- (
173
- sim_cfg.simulate_missing
174
- if simulate_missing is None
175
- else bool(simulate_missing)
176
- )
177
- if sim_cfg is not None
178
- else bool(simulate_missing)
179
- )
206
+
180
207
  if sim_cfg is None:
181
208
  default_strategy = "random"
182
- default_prop = 0.10
209
+ default_prop = 0.2
183
210
  else:
184
211
  default_strategy = sim_cfg.sim_strategy
185
212
  default_prop = sim_cfg.sim_prop
213
+
214
+ self.simulate_missing = True
186
215
  self.sim_strategy = sim_strategy or default_strategy
187
216
  self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
188
217
  self.sim_kwargs = sim_cfg_kwargs
189
218
 
190
219
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
191
- msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
220
+ msg = "tree_parser is required for nonrandom sim strategies."
192
221
  self.logger.error(msg)
193
222
  raise ValueError(msg)
194
223
 
195
- # Model hyperparams
224
+ # Model architecture
196
225
  self.latent_dim = int(self.cfg.model.latent_dim)
197
226
  self.dropout_rate = float(self.cfg.model.dropout_rate)
198
227
  self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
199
228
  self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
200
- self.layer_schedule: str = str(self.cfg.model.layer_schedule)
201
- self.activation = str(self.cfg.model.hidden_activation)
202
- self.gamma = float(self.cfg.model.gamma)
229
+ self.layer_schedule = str(self.cfg.model.layer_schedule)
230
+ self.activation = str(self.cfg.model.activation)
231
+
232
+ # Training / loss controls (align with VAE fields where present)
233
+ self.power = float(getattr(self.cfg.train, "weights_power", 1.0))
234
+ self.max_ratio = getattr(self.cfg.train, "weights_max_ratio", None)
235
+ self.normalize = bool(getattr(self.cfg.train, "weights_normalize", True))
236
+ self.inverse = bool(getattr(self.cfg.train, "weights_inverse", False))
203
237
 
204
- # Train hyperparams
205
238
  self.batch_size = int(self.cfg.train.batch_size)
206
239
  self.learning_rate = float(self.cfg.train.learning_rate)
207
- self.l1_penalty: float = float(self.cfg.train.l1_penalty)
240
+ self.l1_penalty = float(self.cfg.train.l1_penalty)
208
241
  self.early_stop_gen = int(self.cfg.train.early_stop_gen)
209
242
  self.min_epochs = int(self.cfg.train.min_epochs)
210
243
  self.epochs = int(self.cfg.train.max_epochs)
211
244
  self.validation_split = float(self.cfg.train.validation_split)
212
- self.beta = float(self.cfg.train.weights_beta)
213
- self.max_ratio = float(self.cfg.train.weights_max_ratio)
214
245
 
215
- # Tuning
216
- self.tune = bool(self.cfg.tune.enabled)
217
- self.tune_fast = bool(self.cfg.tune.fast)
218
- self.tune_batch_size = int(self.cfg.tune.batch_size)
219
- self.tune_epochs = int(self.cfg.tune.epochs)
220
- self.tune_eval_interval = int(self.cfg.tune.eval_interval)
221
- self.tune_metric: str = self.cfg.tune.metric
222
-
223
- if self.tune_metric is not None:
224
- self.tune_metric_: (
225
- Literal[
226
- "pr_macro",
227
- "f1",
228
- "accuracy",
229
- "precision",
230
- "recall",
231
- "roc_auc",
232
- "average_precision",
233
- ]
234
- | None
235
- ) = self.cfg.tune.metric
246
+ # Gamma can live in cfg.model or cfg.train depending on your dataclasses
247
+ gamma_raw = getattr(
248
+ self.cfg.train, "gamma", getattr(self.cfg.model, "gamma", 0.0)
249
+ )
250
+ if not isinstance(gamma_raw, (float, int)):
251
+ msg = f"Gamma must be float|int; got {type(gamma_raw)}."
252
+ self.logger.error(msg)
253
+ raise TypeError(msg)
254
+ self.gamma = float(gamma_raw)
255
+ self.gamma_schedule = bool(getattr(self.cfg.train, "gamma_schedule", True))
236
256
 
257
+ # Hyperparameter tuning
258
+ self.tune = bool(self.cfg.tune.enabled)
259
+ self.tune_metric = cast(
260
+ Literal[
261
+ "pr_macro",
262
+ "f1",
263
+ "accuracy",
264
+ "precision",
265
+ "recall",
266
+ "roc_auc",
267
+ "average_precision",
268
+ "mcc",
269
+ "jaccard",
270
+ ],
271
+ self.cfg.tune.metric or "f1",
272
+ )
237
273
  self.n_trials = int(self.cfg.tune.n_trials)
238
274
  self.tune_save_db = bool(self.cfg.tune.save_db)
239
275
  self.tune_resume = bool(self.cfg.tune.resume)
240
- self.tune_max_samples = int(self.cfg.tune.max_samples)
241
- self.tune_max_loci = int(self.cfg.tune.max_loci)
242
- self.tune_infer_epochs = int(
243
- getattr(self.cfg.tune, "infer_epochs", 0)
244
- ) # AE unused
245
276
  self.tune_patience = int(self.cfg.tune.patience)
246
277
 
247
- # Evaluate
248
- # AE does not optimize latents, so these are unused / fixed
249
- self.eval_latent_steps: int = 0
250
- self.eval_latent_lr: float = 0.0
251
- self.eval_latent_weight_decay: float = 0.0
252
-
253
- # Plotting (parity with NLPCA PlotConfig)
254
- self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
255
- self.cfg.plot.fmt
256
- )
278
+ # Plotting
279
+ self.plot_format = self.cfg.plot.fmt
257
280
  self.plot_dpi = int(self.cfg.plot.dpi)
258
281
  self.plot_fontsize = int(self.cfg.plot.fontsize)
259
282
  self.title_fontsize = int(self.cfg.plot.fontsize)
260
283
  self.despine = bool(self.cfg.plot.despine)
261
284
  self.show_plots = bool(self.cfg.plot.show)
262
285
 
263
- # Core derived at fit-time
264
- self.is_haploid: bool = False
265
- self.num_classes_: int | None = None
286
+ # Fit-time attributes
287
+ self.is_haploid_: bool = False
288
+ self.num_classes_: int = 3
266
289
  self.model_params: Dict[str, Any] = {}
267
- self.sim_mask_global_: np.ndarray | None = None
268
- self.sim_mask_train_: np.ndarray | None = None
269
- self.sim_mask_test_: np.ndarray | None = None
270
290
 
271
- def fit(self) -> "ImputeAutoencoder":
272
- """Fit the autoencoder on 0/1/2 encoded genotypes (missing -> -1).
291
+ self.sim_mask_train_: np.ndarray
292
+ self.sim_mask_val_: np.ndarray
293
+ self.sim_mask_test_: np.ndarray
273
294
 
274
- This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented internally as -1. (When simulated-missing loci are generated via ``SimMissingTransformer`` they are first marked with -9 but are immediately re-encoded as -1 prior to training.) The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.
295
+ self.orig_mask_train_: np.ndarray
296
+ self.orig_mask_val_: np.ndarray
297
+ self.orig_mask_test_: np.ndarray
275
298
 
276
- Returns:
277
- ImputeAutoencoder: Fitted instance.
299
+ def fit(self) -> "ImputeAutoencoder":
300
+ """Fit the Autoencoder imputer model to the genotype data.
301
+
302
+ This method performs the following steps:
303
+ 1. Validates the presence of SNP data in the genotype data.
304
+ 2. Determines ploidy and sets up the number of classes accordingly.
305
+ 3. Cleans the ground truth genotype matrix and simulates missingness.
306
+ 4. Splits the data into training, validation, and test sets.
307
+ 5. Prepares one-hot encoded inputs for the model.
308
+ 6. Initializes plotting utilities and valid-class masks.
309
+ 7. Sets up data loaders for training and validation.
310
+ 8. Performs hyperparameter tuning if enabled, otherwise uses fixed hyperparameters.
311
+ 9. Builds and trains the Autoencoder model.
312
+ 10. Evaluates the trained model on the test set.
313
+ 11. Returns the fitted ImputeAutoencoder instance.
278
314
 
279
- Raises:
280
- NotFittedError: If training fails.
315
+ Returns:
316
+ ImputeAutoencoder: The fitted ImputeAutoencoder instance.
281
317
  """
282
318
  self.logger.info(f"Fitting {self.model_name} model...")
283
319
 
284
- # --- Data prep (mirror NLPCA) ---
285
- X012 = self._get_float_genotypes(copy=True)
286
- GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
287
- self.ground_truth_ = GT_full.astype(np.int64, copy=False)
288
-
289
- self.sim_mask_global_ = None
290
- cache_key = self._sim_mask_cache_key()
291
- if self.simulate_missing:
292
- cached_mask = (
293
- None if cache_key is None else self._sim_mask_cache.get(cache_key)
294
- )
295
- if cached_mask is not None:
296
- self.sim_mask_global_ = cached_mask.copy()
297
- else:
298
- tr = SimMissingTransformer(
299
- genotype_data=self.genotype_data,
300
- tree_parser=self.tree_parser,
301
- prop_missing=self.sim_prop,
302
- strategy=self.sim_strategy,
303
- missing_val=-9,
304
- mask_missing=True,
305
- verbose=self.verbose,
306
- **self.sim_kwargs,
307
- )
308
- tr.fit(X012.copy())
309
- self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
310
- if cache_key is not None:
311
- self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
312
-
313
- X_for_model = self.ground_truth_.copy()
314
- X_for_model[self.sim_mask_global_] = -1
315
- else:
316
- X_for_model = self.ground_truth_.copy()
317
-
318
320
  if self.genotype_data.snp_data is None:
319
- msg = "SNP data is required for Autoencoder imputer."
321
+ msg = f"SNP data is required for {self.model_name}."
320
322
  self.logger.error(msg)
321
- raise TypeError(msg)
323
+ raise AttributeError(msg)
322
324
 
323
- # Ploidy & classes
324
- self.is_haploid = bool(
325
- np.all(
326
- np.isin(
327
- self.genotype_data.snp_data,
328
- ["A", "C", "G", "T", "N", "-", ".", "?"],
329
- )
325
+ self.ploidy = self.cfg.io.ploidy
326
+ self.is_haploid_ = self.ploidy == 1
327
+
328
+ if self.ploidy > 2:
329
+ msg = (
330
+ f"{self.model_name} currently supports only haploid (1) or diploid (2) "
331
+ f"data; got ploidy={self.ploidy}."
330
332
  )
331
- )
332
- self.ploidy = 1 if self.is_haploid else 2
333
- self.num_classes_ = 2 if self.is_haploid else 3
334
- self.logger.info(
335
- f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
336
- f"using {self.num_classes_} classes."
337
- )
333
+ self.logger.error(msg)
334
+ raise ValueError(msg)
338
335
 
339
- if self.is_haploid:
340
- self.ground_truth_[self.ground_truth_ == 2] = 1
341
- X_for_model[X_for_model == 2] = 1
336
+ self.num_classes_ = 2 if self.is_haploid_ else 3
342
337
 
343
- n_samples, self.num_features_ = X_for_model.shape
338
+ # Clean 0/1/2 ground truth (missing=-1)
339
+ gt_full = self.pgenc.genotypes_012.copy()
340
+ gt_full[gt_full < 0] = -1
341
+ gt_full = np.nan_to_num(gt_full, nan=-1.0)
342
+ self.ground_truth_ = gt_full.astype(np.int8)
343
+ self.num_features_ = int(self.ground_truth_.shape[1])
344
344
 
345
- # Model params (decoder outputs L * K logits)
346
345
  self.model_params = {
347
346
  "n_features": self.num_features_,
348
347
  "num_classes": self.num_classes_,
@@ -351,66 +350,194 @@ class ImputeAutoencoder(BaseNNImputer):
351
350
  "activation": self.activation,
352
351
  }
353
352
 
354
- # Train/Val split
355
- indices = np.arange(n_samples)
356
- train_idx, val_idx = train_test_split(
357
- indices, test_size=self.validation_split, random_state=self.seed
353
+ # Simulate missingness ONCE on the full matrix
354
+ X_for_model_full, self.sim_mask_, self.orig_mask_ = self.sim_missing_transform(
355
+ self.ground_truth_
358
356
  )
359
- self.train_idx_, self.test_idx_ = train_idx, val_idx
360
- self.X_train_ = X_for_model[train_idx]
361
- self.X_val_ = X_for_model[val_idx]
362
- self.GT_train_full_ = self.ground_truth_[train_idx]
363
- self.GT_test_full_ = self.ground_truth_[val_idx]
364
-
365
- if self.sim_mask_global_ is not None:
366
- self.sim_mask_train_ = self.sim_mask_global_[train_idx]
367
- self.sim_mask_test_ = self.sim_mask_global_[val_idx]
368
- else:
369
- self.sim_mask_train_ = None
370
- self.sim_mask_test_ = None
371
357
 
372
- # Plotters/scorers (shared utilities)
358
+ # Split indices based on clean ground truth
359
+ self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
360
+ self.ground_truth_
361
+ )
362
+
363
+ # --- Clean targets per split ---
364
+ X_train_clean = self.ground_truth_[self.train_idx_].copy()
365
+ X_val_clean = self.ground_truth_[self.val_idx_].copy()
366
+ X_test_clean = self.ground_truth_[self.test_idx_].copy()
367
+
368
+ # --- Corrupted inputs per split (from the single simulation) ---
369
+ X_train_corrupted = X_for_model_full[self.train_idx_].copy()
370
+ X_val_corrupted = X_for_model_full[self.val_idx_].copy()
371
+ X_test_corrupted = X_for_model_full[self.test_idx_].copy()
372
+
373
+ # --- Masks per split ---
374
+ self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
375
+ self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
376
+ self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
377
+
378
+ self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
379
+ self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
380
+ self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
381
+
382
+ # Persist per-split matrices
383
+ self.X_train_clean_ = X_train_clean
384
+ self.X_val_clean_ = X_val_clean
385
+ self.X_test_clean_ = X_test_clean
386
+
387
+ self.X_train_corrupted_ = X_train_corrupted
388
+ self.X_val_corrupted_ = X_val_corrupted
389
+ self.X_test_corrupted_ = X_test_corrupted
390
+
391
+ # Haploid harmonization (do NOT resimulate; just recode values)
392
+ if self.is_haploid_:
393
+
394
+ def _haploidize(arr: np.ndarray) -> np.ndarray:
395
+ out = arr.copy()
396
+ miss = out < 0
397
+ out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
398
+ out[miss] = -1
399
+ return out
400
+
401
+ self.X_train_clean_ = _haploidize(self.X_train_clean_)
402
+ self.X_val_clean_ = _haploidize(self.X_val_clean_)
403
+ self.X_test_clean_ = _haploidize(self.X_test_clean_)
404
+
405
+ self.X_train_corrupted_ = _haploidize(self.X_train_corrupted_)
406
+ self.X_val_corrupted_ = _haploidize(self.X_val_corrupted_)
407
+ self.X_test_corrupted_ = _haploidize(self.X_test_corrupted_)
408
+
409
+ # Convention: X_* are corrupted inputs; y_* are clean targets
410
+ self.X_train_ = self.X_train_corrupted_
411
+ self.y_train_ = self.X_train_clean_
412
+
413
+ self.X_val_ = self.X_val_corrupted_
414
+ self.y_val_ = self.X_val_clean_
415
+
416
+ self.X_test_ = self.X_test_corrupted_
417
+ self.y_test_ = self.X_test_clean_
418
+
419
+ # One-hot for loaders/model input
420
+ X_train_ohe = self._one_hot_encode_012(
421
+ self.X_train_, num_classes=self.num_classes_
422
+ )
423
+ X_val_ohe = self._one_hot_encode_012(self.X_val_, num_classes=self.num_classes_)
424
+
425
+ # Plotters/scorers + valid-class mask repairs (copied from VAE flow)
373
426
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
427
+ self.valid_class_mask_ = self._build_valid_class_mask()
428
+
429
+ loci = getattr(self, "valid_class_mask_conflict_loci_", None)
430
+ if loci is not None and loci.size:
431
+ self._repair_ref_alt_from_iupac(loci)
432
+ self.valid_class_mask_ = self._build_valid_class_mask()
433
+
434
+ train_loader = self._get_data_loaders(
435
+ X_train_ohe.detach().cpu().numpy(),
436
+ self.y_train_,
437
+ ~self.orig_mask_train_,
438
+ self.batch_size,
439
+ shuffle=True,
440
+ )
441
+ val_loader = self._get_data_loaders(
442
+ X_val_ohe.detach().cpu().numpy(),
443
+ self.y_val_,
444
+ ~self.orig_mask_val_,
445
+ self.batch_size,
446
+ shuffle=False,
447
+ )
448
+
449
+ self.train_loader_ = train_loader
450
+ self.val_loader_ = val_loader
374
451
 
375
- # Tuning (optional; AE never needs latent refinement)
452
+ # Hyperparameter tuning or fixed run
376
453
  if self.tune:
377
- self.tune_hyperparameters()
454
+ self.tuned_params_ = self.tune_hyperparameters()
455
+ self.model_tuned_ = True
456
+ else:
457
+ self.model_tuned_ = False
458
+ self.class_weights_ = self._class_weights_from_zygosity(
459
+ self.y_train_,
460
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
461
+ inverse=self.inverse,
462
+ normalize=self.normalize,
463
+ max_ratio=self.max_ratio,
464
+ power=self.power,
465
+ )
466
+ self.tuned_params_ = {
467
+ "latent_dim": self.latent_dim,
468
+ "learning_rate": self.learning_rate,
469
+ "dropout_rate": self.dropout_rate,
470
+ "num_hidden_layers": self.num_hidden_layers,
471
+ "activation": self.activation,
472
+ "l1_penalty": self.l1_penalty,
473
+ "layer_scaling_factor": self.layer_scaling_factor,
474
+ "layer_schedule": self.layer_schedule,
475
+ "gamma": self.gamma,
476
+ "gamma_schedule": self.gamma_schedule,
477
+ "inverse": self.inverse,
478
+ "normalize": self.normalize,
479
+ "power": self.power,
480
+ }
481
+ self.tuned_params_["model_params"] = self.model_params
378
482
 
379
- # Best params (tuned or default)
380
- self.best_params_ = getattr(self, "best_params_", self._default_best_params())
483
+ if self.class_weights_ is not None:
484
+ self.logger.info(
485
+ f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
486
+ )
381
487
 
382
- # Class weights (device-aware)
383
- self.class_weights_ = self._normalize_class_weights(
384
- self._class_weights_from_zygosity(self.X_train_)
385
- )
488
+ # Always start clean
489
+ self.best_params_ = copy.deepcopy(self.tuned_params_)
386
490
 
387
- # DataLoader
388
- train_loader = self._get_data_loaders(self.X_train_)
491
+ # Final model params (compute hidden sizes using n_inputs=L*K, mirroring VAE)
492
+ input_dim = int(self.num_features_ * self.num_classes_)
493
+ model_params_final = {
494
+ "n_features": int(self.num_features_),
495
+ "num_classes": int(self.num_classes_),
496
+ "latent_dim": int(self.best_params_["latent_dim"]),
497
+ "dropout_rate": float(self.best_params_["dropout_rate"]),
498
+ "activation": str(self.best_params_["activation"]),
499
+ }
500
+ model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
501
+ n_inputs=input_dim,
502
+ n_outputs=int(self.num_classes_),
503
+ n_samples=len(self.train_idx_),
504
+ n_hidden=int(self.best_params_["num_hidden_layers"]),
505
+ latent_dim=int(self.best_params_["latent_dim"]),
506
+ alpha=float(self.best_params_["layer_scaling_factor"]),
507
+ schedule=str(self.best_params_["layer_schedule"]),
508
+ min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
509
+ )
510
+ self.best_params_["model_params"] = model_params_final
389
511
 
390
- # Build & train
391
- model = self.build_model(self.Model, self.best_params_)
512
+ # Build and train
513
+ model = self.build_model(self.Model, self.best_params_["model_params"])
392
514
  model.apply(self.initialize_weights)
393
515
 
516
+ if self.verbose or self.debug:
517
+ self.logger.info("Using model hyperparameters:")
518
+ pm = PrettyMetrics(
519
+ self.best_params_, precision=3, title="Model Hyperparameters"
520
+ )
521
+ pm.render()
522
+
523
+ lr_final = float(self.best_params_["learning_rate"])
524
+ l1_final = float(self.best_params_["l1_penalty"])
525
+ gamma_schedule = bool(
526
+ self.best_params_.get("gamma_schedule", self.gamma_schedule)
527
+ )
528
+
394
529
  loss, trained_model, history = self._train_and_validate_model(
395
530
  model=model,
396
- loader=train_loader,
397
- lr=self.learning_rate,
398
- l1_penalty=self.l1_penalty,
399
- return_history=True,
400
- class_weights=self.class_weights_,
401
- X_val=self.X_val_,
531
+ lr=lr_final,
532
+ l1_penalty=l1_final,
402
533
  params=self.best_params_,
403
- prune_metric=self.tune_metric,
404
- prune_warmup_epochs=5,
405
- eval_interval=1,
406
- eval_requires_latents=False,
407
- eval_latent_steps=0,
408
- eval_latent_lr=0.0,
409
- eval_latent_weight_decay=0.0,
534
+ trial=None,
535
+ class_weights=getattr(self, "class_weights_", None),
536
+ gamma_schedule=gamma_schedule,
410
537
  )
411
538
 
412
539
  if trained_model is None:
413
- msg = "Autoencoder training failed; no model was returned."
540
+ msg = f"{self.model_name} training failed."
414
541
  self.logger.error(msg)
415
542
  raise RuntimeError(msg)
416
543
 
@@ -419,217 +546,194 @@ class ImputeAutoencoder(BaseNNImputer):
419
546
  self.models_dir / f"final_model_{self.model_name}.pt",
420
547
  )
421
548
 
422
- hist: Dict[str, List[float] | Dict[str, List[float]] | None] | None = {
423
- "Train": history
424
- }
425
- self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
549
+ if history is None:
550
+ hist = {"Train": []}
551
+ elif isinstance(history, dict):
552
+ hist = dict(history)
553
+ else:
554
+ hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
555
+
556
+ self.best_loss_ = float(loss)
557
+ self.model_ = trained_model
558
+ self.history_ = hist
426
559
  self.is_fit_ = True
427
560
 
428
- # Evaluate on validation set (parity with NLPCA reporting)
429
- eval_mask = (
430
- self.sim_mask_test_
431
- if (self.simulate_missing and self.sim_mask_test_ is not None)
432
- else None
433
- )
561
+ # Evaluate on simulated-missing sites only
434
562
  self._evaluate_model(
435
- self.X_val_, self.model_, self.best_params_, eval_mask_override=eval_mask
563
+ self.model_,
564
+ X=self.X_test_,
565
+ y=self.y_test_,
566
+ eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
567
+ objective_mode=False,
436
568
  )
437
- self.plotter_.plot_history(self.history_)
569
+
570
+ if self.show_plots:
571
+ self.plotter_.plot_history(self.history_)
572
+
438
573
  self._save_best_params(self.best_params_)
439
574
 
575
+ if self.model_tuned_:
576
+ title = f"{self.model_name} Optimized Parameters"
577
+
578
+ if self.verbose or self.debug:
579
+ pm = PrettyMetrics(self.best_params_, precision=2, title=title)
580
+ pm.render()
581
+
582
+ # Save best parameters to a JSON file.
583
+ self._save_best_params(self.best_params_, objective_mode=True)
584
+
440
585
  return self
441
586
 
442
587
  def transform(self) -> np.ndarray:
443
- """Impute missing genotypes (0/1/2) and return IUPAC strings.
588
+ """Impute missing genotypes and return IUPAC strings.
444
589
 
445
- This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.
590
+ This method performs the following steps:
591
+ 1. Validates that the model has been fitted.
592
+ 2. Uses the trained model to predict missing genotypes for the entire dataset.
593
+ 3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
594
+ 4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
595
+ 5. Checks for any remaining missing values or decoding issues, raising errors if found.
596
+ 6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
597
+ 7. Returns the imputed IUPAC genotype matrix.
446
598
 
447
599
  Returns:
448
- np.ndarray: IUPAC strings of shape (n_samples, n_loci).
600
+ np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
449
601
 
450
602
  Raises:
451
603
  NotFittedError: If called before fit().
604
+ RuntimeError: If any missing values remain after imputation.
605
+ RuntimeError: If decoded loci contain "N" because REF/ALT metadata is missing and cannot be inferred.
452
606
  """
453
607
  if not getattr(self, "is_fit_", False):
454
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
608
+ msg = "Model is not fitted. Call fit() before transform()."
609
+ self.logger.error(msg)
610
+ raise NotFittedError(msg)
455
611
 
456
- self.logger.info(f"Imputing entire dataset with {self.model_name}...")
612
+ self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
457
613
  X_to_impute = self.ground_truth_.copy()
458
614
 
459
- # Predict with masked inputs (no latent optimization)
460
- pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
615
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute)
461
616
 
462
- # Fill only missing
463
- missing_mask = X_to_impute == -1
617
+ missing_mask = X_to_impute < 0
464
618
  imputed_array = X_to_impute.copy()
465
619
  imputed_array[missing_mask] = pred_labels[missing_mask]
466
620
 
467
- # Decode to IUPAC & optionally plot
468
- imputed_genotypes = self.pgenc.decode_012(imputed_array)
621
+ if np.any(imputed_array < 0):
622
+ msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
623
+ self.logger.error(msg)
624
+ raise RuntimeError(msg)
625
+
626
+ decode_input = imputed_array
627
+ if self.is_haploid_:
628
+ decode_input = imputed_array.copy()
629
+ decode_input[decode_input == 1] = 2
630
+
631
+ imputed_genotypes = self.decode_012(decode_input)
632
+
633
+ bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
634
+ if bad_loci.size > 0:
635
+ msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
636
+ self.logger.error(msg)
637
+ self.logger.debug(
638
+ "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
639
+ )
640
+ raise RuntimeError(msg)
641
+
469
642
  if self.show_plots:
470
- original_genotypes = self.pgenc.decode_012(X_to_impute)
643
+ original_input = X_to_impute
644
+ if self.is_haploid_:
645
+ original_input = X_to_impute.copy()
646
+ original_input[original_input == 1] = 2
647
+
648
+ original_genotypes = self.decode_012(original_input)
649
+
471
650
  plt.rcParams.update(self.plotter_.param_dict)
472
651
  self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
473
652
  self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
474
653
 
475
654
  return imputed_genotypes
476
655
 
477
- def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
478
- """Create DataLoader over indices + integer targets (-1 for missing).
479
-
480
- This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
481
-
482
- Args:
483
- y (np.ndarray): 0/1/2 matrix with -1 for missing.
484
-
485
- Returns:
486
- torch.utils.data.DataLoader: Shuffled DataLoader.
487
- """
488
- y_tensor = torch.from_numpy(y).long()
489
- indices = torch.arange(len(y), dtype=torch.long)
490
- dataset = torch.utils.data.TensorDataset(indices, y_tensor)
491
- pin_memory = self.device.type == "cuda"
492
- return torch.utils.data.DataLoader(
493
- dataset,
494
- batch_size=self.batch_size,
495
- shuffle=True,
496
- pin_memory=pin_memory,
497
- )
498
-
499
656
  def _train_and_validate_model(
500
657
  self,
501
658
  model: torch.nn.Module,
502
- loader: torch.utils.data.DataLoader,
659
+ *,
503
660
  lr: float,
504
661
  l1_penalty: float,
505
- trial: optuna.Trial | None = None,
506
- return_history: bool = False,
507
- class_weights: torch.Tensor | None = None,
508
- *,
509
- X_val: np.ndarray | None = None,
510
- params: dict | None = None,
511
- prune_metric: str = "f1", # "f1" | "accuracy" | "pr_macro"
512
- prune_warmup_epochs: int = 3,
513
- eval_interval: int = 1,
514
- # Evaluation parameters (AE ignores latent refinement knobs)
515
- eval_requires_latents: bool = False, # AE: always False
516
- eval_latent_steps: int = 0,
517
- eval_latent_lr: float = 0.0,
518
- eval_latent_weight_decay: float = 0.0,
519
- ) -> Tuple[float, torch.nn.Module | None, list | None]:
520
- """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
521
-
522
- This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
662
+ trial: Optional[optuna.Trial] = None,
663
+ params: Optional[dict[str, Any]] = None,
664
+ class_weights: Optional[torch.Tensor] = None,
665
+ gamma_schedule: bool = False,
666
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
667
+ """Train and validate the model.
668
+
669
+ This method sets up the optimizer and learning rate scheduler, then executes the training loop with early stopping and optional hyperparameter tuning via Optuna. It returns the best validation loss, the best model, and the training history.
523
670
 
524
671
  Args:
525
672
  model (torch.nn.Module): Autoencoder model.
526
- loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
527
673
  lr (float): Learning rate.
528
- l1_penalty (float): L1 regularization coeff.
529
- trial (optuna.Trial | None): Optuna trial for pruning (optional).
530
- return_history (bool): If True, return train loss history.
531
- class_weights (torch.Tensor | None): Class weights tensor (on device).
532
- X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
533
- params (dict | None): Model params for evaluation.
534
- prune_metric (str): Metric for pruning reports.
535
- prune_warmup_epochs (int): Pruning warmup epochs.
536
- eval_interval (int): Eval frequency (epochs).
537
- eval_requires_latents (bool): Ignored for AE (no latent inference).
538
- eval_latent_steps (int): Unused for AE.
539
- eval_latent_lr (float): Unused for AE.
540
- eval_latent_weight_decay (float): Unused for AE.
674
+ l1_penalty (float): L1 regularization coefficient.
675
+ trial (Optional[optuna.Trial]): Optuna trial (optional).
676
+ params (Optional[dict[str, Any]]): Hyperparams dict (optional).
677
+ class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
678
+ gamma_schedule (bool): Whether to schedule gamma.
541
679
 
542
680
  Returns:
543
- Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
681
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, history.
544
682
  """
545
- if class_weights is None:
546
- msg = "Must provide class_weights."
547
- self.logger.error(msg)
548
- raise TypeError(msg)
683
+ max_epochs = int(self.epochs)
684
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
549
685
 
550
- # Epoch budget mirrors NLPCA config (tuning vs final)
551
- max_epochs = (
552
- self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
686
+ scheduler = _make_warmup_cosine_scheduler(
687
+ optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
553
688
  )
554
689
 
555
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
556
- scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
557
-
558
690
  best_loss, best_model, hist = self._execute_training_loop(
559
- loader=loader,
560
691
  optimizer=optimizer,
561
692
  scheduler=scheduler,
562
693
  model=model,
563
694
  l1_penalty=l1_penalty,
564
695
  trial=trial,
565
- return_history=return_history,
566
- class_weights=class_weights,
567
- X_val=X_val,
568
696
  params=params,
569
- prune_metric=prune_metric,
570
- prune_warmup_epochs=prune_warmup_epochs,
571
- eval_interval=eval_interval,
572
- eval_requires_latents=False, # AE: no latent inference
573
- eval_latent_steps=0,
574
- eval_latent_lr=0.0,
575
- eval_latent_weight_decay=0.0,
697
+ class_weights=class_weights,
698
+ gamma_schedule=gamma_schedule,
576
699
  )
577
- if return_history:
578
- return best_loss, best_model, hist
579
-
580
- return best_loss, best_model, None
700
+ return best_loss, best_model, hist
581
701
 
582
702
  def _execute_training_loop(
583
703
  self,
584
- loader: torch.utils.data.DataLoader,
704
+ *,
585
705
  optimizer: torch.optim.Optimizer,
586
- scheduler: CosineAnnealingLR,
706
+ scheduler: (
707
+ torch.optim.lr_scheduler.CosineAnnealingLR
708
+ | torch.optim.lr_scheduler.SequentialLR
709
+ ),
587
710
  model: torch.nn.Module,
588
711
  l1_penalty: float,
589
- trial: optuna.Trial | None,
590
- return_history: bool,
591
- class_weights: torch.Tensor,
592
- *,
593
- X_val: np.ndarray | None = None,
594
- params: dict | None = None,
595
- prune_metric: str = "f1",
596
- prune_warmup_epochs: int = 3,
597
- eval_interval: int = 1,
598
- # Evaluation parameters (AE ignores latent refinement knobs)
599
- eval_requires_latents: bool = False, # AE: False
600
- eval_latent_steps: int = 0,
601
- eval_latent_lr: float = 0.0,
602
- eval_latent_weight_decay: float = 0.0,
603
- ) -> Tuple[float, torch.nn.Module, list]:
604
- """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
605
-
606
- This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
712
+ trial: Optional[optuna.Trial] = None,
713
+ params: Optional[dict[str, Any]] = None,
714
+ class_weights: Optional[torch.Tensor] = None,
715
+ gamma_schedule: bool = False,
716
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
717
+ """Train AE (masked focal CE) with EarlyStopping + Optuna pruning.
607
718
 
608
719
  Args:
609
- loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
610
- optimizer (torch.optim.Optimizer): Optimizer.
611
- scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
720
+ optimizer (torch.optim.Optimizer): Optimizer for training.
721
+ scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): LR scheduler.
612
722
  model (torch.nn.Module): Autoencoder model.
613
- l1_penalty (float): L1 regularization coeff.
614
- trial (optuna.Trial | None): Optuna trial for pruning (optional).
615
- return_history (bool): If True, return train loss history.
616
- class_weights (torch.Tensor): Class weights tensor (on device).
617
- X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
618
- params (dict | None): Model params for evaluation.
619
- prune_metric (str): Metric for pruning reports.
620
- prune_warmup_epochs (int): Pruning warmup epochs.
621
- eval_interval (int): Eval frequency (epochs).
622
- eval_requires_latents (bool): Ignored for AE (no latent inference).
623
- eval_latent_steps (int): Unused for AE.
624
- eval_latent_lr (float): Unused for AE.
625
- eval_latent_weight_decay (float): Unused for AE.
723
+ l1_penalty (float): L1 regularization coefficient.
724
+ trial (Optional[optuna.Trial]): Optuna trial (optional).
725
+ params (Optional[dict[str, Any]]): Hyperparams dict (optional).
726
+ class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
727
+ gamma_schedule (bool): Whether to schedule gamma.
626
728
 
627
729
  Returns:
628
- Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
730
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best loss, best model, and training history.
731
+
732
+ Notes:
733
+ - Computes loss only where targets are known (~orig_mask_*).
734
+ - Evaluates metrics only on simulated-missing sites (sim_mask_*).
629
735
  """
630
- best_loss = float("inf")
631
- best_model = None
632
- history: list[float] = []
736
+ history: dict[str, list[float]] = defaultdict(list)
633
737
 
634
738
  early_stopping = EarlyStopping(
635
739
  patience=self.early_stop_gen,
@@ -639,146 +743,157 @@ class ImputeAutoencoder(BaseNNImputer):
639
743
  debug=self.debug,
640
744
  )
641
745
 
642
- gamma_val = self.gamma
643
- if isinstance(gamma_val, (list, tuple)):
644
- if len(gamma_val) == 0:
645
- raise ValueError("gamma list is empty.")
646
- gamma_val = gamma_val[0]
647
-
648
- gamma_final = float(gamma_val)
649
- gamma_warm, gamma_ramp = 50, 100
650
-
651
- # Optional LR warmup
652
- warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
653
- base_lr = float(optimizer.param_groups[0]["lr"])
654
- min_lr = base_lr * 0.1
655
-
656
- max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
657
-
658
- for epoch in range(max_epochs):
659
- # focal γ schedule (for stable training)
660
- if epoch < gamma_warm:
661
- model.gamma = 0.0 # type: ignore
662
- elif epoch < gamma_warm + gamma_ramp:
663
- model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore
746
+ gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
747
+ params, "gamma", default=self.gamma, max_epochs=self.epochs
748
+ )
749
+ gamma_target = float(gamma_target)
750
+
751
+ cw = class_weights
752
+ if cw is not None and cw.device != self.device:
753
+ cw = cw.to(self.device)
754
+
755
+ for epoch in range(int(self.epochs)):
756
+ if gamma_schedule:
757
+ gamma_current = self._update_anneal_schedule(
758
+ gamma_target,
759
+ warm=gamma_warm,
760
+ ramp=gamma_ramp,
761
+ epoch=epoch,
762
+ init_val=0.0,
763
+ )
764
+ gamma_val = float(gamma_current)
664
765
  else:
665
- model.gamma = gamma_final # type: ignore
766
+ gamma_val = gamma_target
666
767
 
667
- # LR warmup
668
- if epoch < warmup_epochs:
669
- scale = float(epoch + 1) / warmup_epochs
670
- for g in optimizer.param_groups:
671
- g["lr"] = min_lr + (base_lr - min_lr) * scale
768
+ ce_criterion = FocalCELoss(
769
+ alpha=cw, gamma=gamma_val, ignore_index=-1, reduction="mean"
770
+ )
672
771
 
673
772
  train_loss = self._train_step(
674
- loader=loader,
773
+ loader=self.train_loader_,
675
774
  optimizer=optimizer,
676
775
  model=model,
776
+ ce_criterion=ce_criterion,
677
777
  l1_penalty=l1_penalty,
678
- class_weights=class_weights,
679
778
  )
680
779
 
681
- # Abort or prune on non-finite epoch loss
682
780
  if not np.isfinite(train_loss):
683
781
  if trial is not None:
684
- raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
685
- # Soft reset suggestion: reduce LR and continue, or break
686
- self.logger.warning(
687
- "Non-finite epoch loss. Reducing LR by 10 percent and continuing."
688
- )
689
- for g in optimizer.param_groups:
690
- g["lr"] *= 0.9
691
- continue
782
+ msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
783
+ self.logger.warning(msg)
784
+ raise optuna.exceptions.TrialPruned(msg)
785
+ msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
786
+ self.logger.error(msg)
787
+ raise RuntimeError(msg)
788
+
789
+ val_loss = self._val_step(
790
+ loader=self.val_loader_,
791
+ model=model,
792
+ ce_criterion=ce_criterion,
793
+ l1_penalty=l1_penalty,
794
+ )
692
795
 
693
796
  scheduler.step()
694
- if return_history:
695
- history.append(train_loss)
797
+ history["Train"].append(float(train_loss))
798
+ history["Val"].append(float(val_loss))
696
799
 
697
- early_stopping(train_loss, model)
800
+ early_stopping(val_loss, model)
698
801
  if early_stopping.early_stop:
699
- self.logger.info(f"Early stopping at epoch {epoch + 1}.")
802
+ self.logger.debug(
803
+ f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
804
+ )
700
805
  break
701
806
 
702
- # Optuna report/prune on validation metric
703
- if (
704
- trial is not None
705
- and X_val is not None
706
- and ((epoch + 1) % eval_interval == 0)
707
- ):
708
- metric_key = prune_metric or getattr(self, "tune_metric", "f1")
709
- mask_override = None
710
- if (
711
- self.simulate_missing
712
- and getattr(self, "sim_mask_test_", None) is not None
713
- and getattr(self, "X_val_", None) is not None
714
- and X_val.shape == self.X_val_.shape
715
- ):
716
- mask_override = self.sim_mask_test_
717
- metric_val = self._eval_for_pruning(
807
+ if trial is not None:
808
+ metric_vals = self._evaluate_model(
718
809
  model=model,
719
- X_val=X_val,
720
- params=params or getattr(self, "best_params_", {}),
721
- metric=metric_key,
810
+ X=self.X_val_,
811
+ y=self.y_val_,
812
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
722
813
  objective_mode=True,
723
- do_latent_infer=False, # AE: False
724
- latent_steps=0,
725
- latent_lr=0.0,
726
- latent_weight_decay=0.0,
727
- latent_seed=self.seed, # type: ignore
728
- _latent_cache=None, # AE: not used
729
- _latent_cache_key=None,
730
- eval_mask_override=mask_override,
731
814
  )
732
- trial.report(metric_val, step=epoch + 1)
733
- if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
815
+ trial.report(metric_vals[self.tune_metric], step=epoch + 1)
816
+ if trial.should_prune():
734
817
  raise optuna.exceptions.TrialPruned(
735
- f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
818
+ f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
736
819
  )
737
820
 
738
- best_loss = early_stopping.best_score
739
- if early_stopping.best_model is not None:
740
- best_model = copy.deepcopy(early_stopping.best_model)
741
- else:
742
- best_model = copy.deepcopy(model)
743
- return best_loss, best_model, history
821
+ best_loss = float(early_stopping.best_score)
822
+ if early_stopping.best_state_dict is not None:
823
+ model.load_state_dict(early_stopping.best_state_dict)
824
+
825
+ return best_loss, model, dict(history)
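
The gamma annealing delegated to `_anneal_config` / `_update_anneal_schedule` follows a hold-then-ramp pattern (an earlier revision hard-coded 50 warmup and 100 ramp epochs). A minimal standalone sketch of that schedule, with placeholder warm/ramp/target values:

```python
def gamma_at_epoch(epoch: int, target: float, warm: int, ramp: int) -> float:
    """Hold gamma at 0 during warmup, ramp linearly, then hold at the target."""
    if epoch < warm:
        return 0.0
    if epoch < warm + ramp:
        return target * (epoch - warm) / ramp
    return target

# Placeholder values; the base-class helpers derive warm/ramp from max_epochs.
schedule = [gamma_at_epoch(e, target=2.0, warm=5, ramp=10) for e in range(25)]
```
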
744
826
 
745
827
  def _train_step(
746
828
  self,
747
829
  loader: torch.utils.data.DataLoader,
748
830
  optimizer: torch.optim.Optimizer,
749
831
  model: torch.nn.Module,
832
+ ce_criterion: torch.nn.Module,
833
+ *,
750
834
  l1_penalty: float,
751
- class_weights: torch.Tensor,
752
835
  ) -> float:
753
- """One epoch with stable focal CE and NaN/Inf guards."""
836
+ """Single epoch train step (masked focal CE + optional L1).
837
+
838
+ Args:
839
+ loader (torch.utils.data.DataLoader): Training data loader.
840
+ optimizer (torch.optim.Optimizer): Optimizer for training.
841
+ model (torch.nn.Module): Autoencoder model.
842
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
843
+ l1_penalty (float): L1 regularization coefficient.
844
+
845
+ Returns:
846
+ float: Average training loss over the epoch.
847
+
848
+ Notes:
849
+ Expects loader batches as (X_ohe, y_int, mask_bool) where:
850
+ - X_ohe: (B, L, C) float/compatible
851
+ - y_int: (B, L) int, with -1 for unknown targets
852
+ - mask_bool: (B, L) bool selecting which positions contribute to loss
853
+ """
754
854
  model.train()
755
855
  running = 0.0
756
856
  num_batches = 0
757
- l1_params = tuple(p for p in model.parameters() if p.requires_grad)
758
- if class_weights is not None and class_weights.device != self.device:
759
- class_weights = class_weights.to(self.device)
760
857
 
761
- # Use model.gamma if present, else self.gamma
762
- gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
763
- gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0)) # sane bound
764
- criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
858
+ nF_model = int(getattr(model, "n_features", self.num_features_))
859
+ nC_model = int(getattr(model, "num_classes", self.num_classes_))
860
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
765
861
 
766
- for _, y_batch in loader:
862
+ for X_batch, y_batch, m_batch in loader:
767
863
  optimizer.zero_grad(set_to_none=True)
768
- y_batch = y_batch.to(self.device, non_blocking=True)
864
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
865
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
866
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
769
867
 
770
- # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
771
- x_ohe = self._one_hot_encode_012(y_batch) # (B, L, K)
772
- logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
773
- logits_flat = logits.view(-1, self.num_classes_)
774
- targets_flat = y_batch.view(-1).long()
868
+ if (
869
+ X_batch.dim() != 3
870
+ or X_batch.shape[1] != nF_model
871
+ or X_batch.shape[2] != nC_model
872
+ ):
873
+ msg = (
874
+ f"Train batch X shape mismatch: expected (B,{nF_model},{nC_model}), "
875
+ f"got {tuple(X_batch.shape)}."
876
+ )
877
+ self.logger.error(msg)
878
+ raise ValueError(msg)
775
879
 
776
- # Upfront guards on inputs
777
- if not torch.isfinite(logits_flat).all():
778
- # Skip this batch if model already produced non-finite
779
- continue
880
+ logits_flat = model(X_batch)
881
+ expected = (X_batch.shape[0], nF_model * nC_model)
882
+ if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
883
+ try:
884
+ logits_flat = logits_flat.view(*expected)
885
+ except Exception as e:
886
+ msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
887
+ self.logger.error(msg)
888
+ raise ValueError(msg) from e
780
889
 
781
- loss = criterion(logits_flat, targets_flat)
890
+ logits = logits_flat.view(-1, nF_model, nC_model)
891
+ logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
892
+
893
+ targets_masked = y_batch.view(-1)
894
+ targets_masked = targets_masked[m_batch.view(-1)]
895
+
896
+ loss = ce_criterion(logits_masked, targets_masked)
782
897
 
783
898
  if l1_penalty > 0:
784
899
  l1 = torch.zeros((), device=self.device)
@@ -786,194 +901,234 @@ class ImputeAutoencoder(BaseNNImputer):
786
901
  l1 = l1 + p.abs().sum()
787
902
  loss = loss + l1_penalty * l1
788
903
 
789
- # Final guard
790
904
  if not torch.isfinite(loss):
791
905
  continue
792
906
 
793
907
  loss.backward()
794
-
795
- # Clip to prevent exploding grads
796
908
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
797
-
798
- # If grads blew up to non-finite, skip update
799
- if any(
800
- (not torch.isfinite(p.grad).all())
801
- for p in model.parameters()
802
- if p.grad is not None
803
- ):
804
- optimizer.zero_grad(set_to_none=True)
805
- continue
806
-
807
909
  optimizer.step()
808
910
 
809
911
  running += float(loss.detach().item())
810
912
  num_batches += 1
811
913
 
812
- if num_batches == 0:
813
- return float("inf") # signal upstream that epoch had no usable batches
814
- return running / num_batches
914
+ return float("inf") if num_batches == 0 else running / num_batches
915
+
916
+ def _val_step(
917
+ self,
918
+ loader: torch.utils.data.DataLoader,
919
+ model: torch.nn.Module,
920
+ ce_criterion: torch.nn.Module,
921
+ *,
922
+ l1_penalty: float,
923
+ ) -> float:
924
+ """Validation step (masked focal CE + optional L1).
925
+
926
+ Args:
927
+ loader (torch.utils.data.DataLoader): Validation data loader.
928
+ model (torch.nn.Module): Autoencoder model.
929
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
930
+ l1_penalty (float): L1 regularization coefficient.
931
+
932
+ Returns:
933
+ float: Average validation loss over the epoch.
934
+ """
935
+ model.eval()
936
+ running = 0.0
937
+ num_batches = 0
938
+
939
+ nF_model = self.num_features_
940
+ nC_model = self.num_classes_
941
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
942
+
943
+ with torch.no_grad():
944
+ for X_batch, y_batch, m_batch in loader:
945
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
946
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
947
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
948
+
949
+ logits_flat = model(X_batch)
950
+ expected = (X_batch.shape[0], nF_model * nC_model)
951
+
952
+ if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
953
+ try:
954
+ logits_flat = logits_flat.view(*expected)
955
+ except Exception as e:
956
+ msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
957
+ self.logger.error(msg)
958
+ raise ValueError(msg) from e
959
+
960
+ logits = logits_flat.view(-1, nF_model, nC_model)
961
+ logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
962
+ targets_masked = y_batch.view(-1)[m_batch.view(-1)]
963
+
964
+ if targets_masked.numel() == 0:
965
+ continue
966
+
967
+ loss = ce_criterion(logits_masked, targets_masked)
968
+
969
+ if l1_penalty > 0:
970
+ l1 = torch.zeros((), device=self.device)
971
+ for p in l1_params:
972
+ l1 = l1 + p.abs().sum()
973
+ loss = loss + l1_penalty * l1
974
+
975
+ if not torch.isfinite(loss):
976
+ continue
977
+
978
+ running += float(loss.item())
979
+ num_batches += 1
980
+
981
+ return float("inf") if num_batches == 0 else running / num_batches
815
982
 
816
983
  def _predict(
817
984
  self,
818
985
  model: torch.nn.Module,
819
986
  X: np.ndarray | torch.Tensor,
987
+ *,
820
988
  return_proba: bool = False,
821
- ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
822
- """Predict 0/1/2 labels (and probabilities) from masked inputs.
823
-
824
- This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
989
+ ) -> tuple[np.ndarray, np.ndarray | None]:
990
+ """Predict categorical genotype labels from logits.
825
991
 
826
992
  Args:
827
993
  model (torch.nn.Module): Trained model.
828
- X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
829
- for missing.
830
- return_proba (bool): If True, return probabilities.
994
+ X (np.ndarray | torch.Tensor): 2D 0/1/2 matrix with -1 for missing, or 3D one-hot (B, L, K).
995
+ return_proba (bool): If True, return probabilities (B, L, K).
831
996
 
832
997
  Returns:
833
- Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
834
- and probabilities if requested.
998
+ tuple[np.ndarray, np.ndarray | None]: Predicted labels and optionally probabilities.
835
999
  """
836
1000
  if model is None:
837
1001
  msg = "Model is not trained. Call fit() before predict()."
838
1002
  self.logger.error(msg)
839
1003
  raise NotFittedError(msg)
840
1004
 
1005
+ nF = self.num_features_
1006
+ nC = self.num_classes_
1007
+
1008
+ if isinstance(X, torch.Tensor):
1009
+ X_tensor = X
1010
+ else:
1011
+ X_tensor = torch.from_numpy(X)
1012
+ X_tensor = X_tensor.float()
1013
+
1014
+ if X_tensor.device != self.device:
1015
+ X_tensor = X_tensor.to(self.device)
1016
+
1017
+ if X_tensor.dim() == 2:
1018
+ # 0/1/2 matrix -> one-hot for model input
1019
+ X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
1020
+ X_tensor = X_tensor.float()
1021
+ if X_tensor.device != self.device:
1022
+ X_tensor = X_tensor.to(self.device)
1023
+
1024
+ elif X_tensor.dim() != 3:
1025
+ msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
1026
+ self.logger.error(msg)
1027
+ raise ValueError(msg)
1028
+
1029
+ if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
1030
+ msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
1031
+ self.logger.error(msg)
1032
+ raise ValueError(msg)
1033
+
1034
+ X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
1035
+
841
1036
  model.eval()
842
1037
  with torch.no_grad():
843
- X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
844
- X_tensor = X_tensor.to(self.device).long()
845
- x_ohe = self._one_hot_encode_012(X_tensor)
846
- logits = model(x_ohe).view(-1, self.num_features_, self.num_classes_)
1038
+ logits_flat = model(X_tensor)
1039
+ logits = logits_flat.view(-1, nF, nC)
1040
+
847
1041
  probas = torch.softmax(logits, dim=-1)
848
1042
  labels = torch.argmax(probas, dim=-1)
849
1043
 
850
1044
  if return_proba:
851
1045
  return labels.cpu().numpy(), probas.cpu().numpy()
852
-
853
- return labels.cpu().numpy()
1046
+ return labels.cpu().numpy(), None
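`_predict` accepts either a 2D 0/1/2 matrix (with -1 for missing) or a pre-built 3D one-hot tensor, and always flattens to `(B, nF * nC)` before the forward pass. The conversion is handled by `_one_hot_encode_012`, which is defined elsewhere in the class hierarchy; a hedged sketch of the encoding it is assumed to perform, with missing sites becoming all-zero rows:

```python
# Hedged sketch: 0/1/2 -> one-hot with -1 (missing) mapped to an all-zero row,
# then flattened to the (B, nF * nC) layout _predict feeds to the model. This
# assumes _one_hot_encode_012 behaves this way; the base class is authoritative.
import torch
import torch.nn.functional as F

X = torch.tensor([[0, 1, -1],
                  [2, -1, 0]])                         # (B=2, L=3), -1 = missing
num_classes = 3
valid = (X >= 0).unsqueeze(-1)                         # (B, L, 1)
ohe = F.one_hot(X.clamp(min=0), num_classes).float()   # (B, L, C)
ohe = ohe * valid                                      # zero out missing sites
flat = ohe.reshape(X.shape[0], -1)                     # (B, L * C)
print(flat.shape)                                      # torch.Size([2, 9])
```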
854
1047
 
855
1048
  def _evaluate_model(
856
1049
  self,
857
- X_val: np.ndarray,
858
1050
  model: torch.nn.Module,
859
- params: dict,
860
- objective_mode: bool = False,
861
- latent_vectors_val: Optional[np.ndarray] = None,
1051
+ X: np.ndarray,
1052
+ y: np.ndarray,
1053
+ eval_mask: np.ndarray,
862
1054
  *,
863
- eval_mask_override: np.ndarray | None = None,
1055
+ objective_mode: bool = False,
864
1056
  ) -> Dict[str, float]:
865
1057
  """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
866
1058
 
867
- This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
868
-
869
1059
  Args:
870
- X_val (np.ndarray): Validation set 0/1/2 matrix with -1
871
- for missing.
872
1060
  model (torch.nn.Module): Trained model.
873
- params (dict): Model parameters.
874
- objective_mode (bool): If True, suppress logging and reports.
875
- latent_vectors_val (Optional[np.ndarray]): Unused for AE.
876
- eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
1061
+ X (np.ndarray): 2D 0/1/2 matrix with -1 for missing.
1062
+ y (np.ndarray): 2D 0/1/2 ground truth matrix with -1 for missing.
1063
+ eval_mask (np.ndarray): 2D boolean mask selecting sites to evaluate.
1064
+ objective_mode (bool): If True, suppress detailed reports and plots.
877
1065
 
878
1066
  Returns:
879
1067
  Dict[str, float]: Dictionary of evaluation metrics.
880
1068
  """
881
- pred_labels, pred_probas = self._predict(
882
- model=model, X=X_val, return_proba=True
883
- )
1069
+ if model is None:
1070
+ msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
1071
+ self.logger.error(msg)
1072
+ raise NotFittedError(msg)
884
1073
 
885
- finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
886
-
887
- # FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
888
- if (
889
- hasattr(self, "X_val_")
890
- and getattr(self, "X_val_", None) is not None
891
- and X_val.shape[0] == self.X_val_.shape[0]
892
- ):
893
- GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
894
- elif (
895
- hasattr(self, "X_train_")
896
- and getattr(self, "X_train_", None) is not None
897
- and X_val.shape[0] == self.X_train_.shape[0]
898
- ):
899
- GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
900
- else:
901
- GT_ref = self.ground_truth_
902
-
903
- # FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
904
- # If the GT source has more columns than X_val, slice it to match.
905
- if GT_ref.shape[1] > X_val.shape[1]:
906
- GT_ref = GT_ref[:, : X_val.shape[1]]
907
-
908
- # Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
909
- if GT_ref.shape != X_val.shape:
910
- # If completely different, we can't use the ground truth object.
911
- # Fall back to X_val (this implies only observed values are scored)
912
- GT_ref = X_val
913
-
914
- if eval_mask_override is not None:
915
- # FIX 3: Allow override mask to be sliced if it's too wide
916
- if eval_mask_override.shape[0] != X_val.shape[0]:
917
- msg = (
918
- f"eval_mask_override rows {eval_mask_override.shape[0]} "
919
- f"does not match X_val rows {X_val.shape[0]}"
920
- )
921
- self.logger.error(msg)
922
- raise ValueError(msg)
1074
+ pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
923
1075
 
924
- if eval_mask_override.shape[1] > X_val.shape[1]:
925
- eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
926
- else:
927
- eval_mask = eval_mask_override.astype(bool)
928
- else:
929
- eval_mask = X_val != -1
1076
+ if pred_probas is None:
1077
+ msg = "Predicted probabilities are None in _evaluate_model()."
1078
+ self.logger.error(msg)
1079
+ raise ValueError(msg)
930
1080
 
931
- # Combine masks
932
- eval_mask = eval_mask & finite_mask & (GT_ref != -1)
1081
+ y_true_flat = y[eval_mask].astype(np.int8, copy=False)
1082
+ y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
1083
+ y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
933
1084
 
934
- y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
935
- y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
936
- y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
1085
+ valid = y_true_flat >= 0
1086
+ y_true_flat = y_true_flat[valid]
1087
+ y_pred_flat = y_pred_flat[valid]
1088
+ y_proba_flat = y_proba_flat[valid]
937
1089
 
938
1090
  if y_true_flat.size == 0:
939
- self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
940
1091
  return {self.tune_metric: 0.0}
941
1092
 
942
- # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
1093
+ if y_proba_flat.ndim != 2:
1094
+ msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got {y_proba_flat.shape}."
1095
+ self.logger.error(msg)
1096
+ raise ValueError(msg)
1097
+
1098
+ K = int(y_proba_flat.shape[1])
1099
+ if self.is_haploid_:
1100
+ if K not in (2, 3):
1101
+ msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
1102
+ self.logger.error(msg)
1103
+ raise ValueError(msg)
1104
+ else:
1105
+ if K != 3:
1106
+ msg = f"Diploid evaluation expects 3 classes; got {K}."
1107
+ self.logger.error(msg)
1108
+ raise ValueError(msg)
1109
+
1110
+ if self.is_haploid_:
1111
+ y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
1112
+ y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
1113
+
1114
+ if K == 3:
1115
+ proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
1116
+ proba_2[:, 0] = y_proba_flat[:, 0]
1117
+ proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
1118
+ y_proba_flat = proba_2
1119
+
1120
+ labels_for_scoring = [0, 1]
1121
+ target_names = ["REF", "ALT"]
1122
+ else:
1123
+ labels_for_scoring = [0, 1, 2]
1124
+ target_names = ["REF", "HET", "ALT"]
1125
+
943
1126
  y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
944
1127
  row_sums = y_proba_flat.sum(axis=1, keepdims=True)
945
- row_sums[row_sums == 0] = 1.0
1128
+ row_sums[row_sums == 0.0] = 1.0
946
1129
  y_proba_flat = y_proba_flat / row_sums
947
1130
 
948
- labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
949
- target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
950
-
951
- if self.is_haploid:
952
- y_true_flat = y_true_flat.copy()
953
- y_pred_flat = y_pred_flat.copy()
954
- y_true_flat[y_true_flat == 2] = 1
955
- y_pred_flat[y_pred_flat == 2] = 1
956
- # collapse probs to 2-class
957
- proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
958
- proba_2[:, 0] = y_proba_flat[:, 0]
959
- proba_2[:, 1] = y_proba_flat[:, 2]
960
- y_proba_flat = proba_2
961
-
962
- y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
963
-
964
- tune_metric_tmp: Literal[
965
- "pr_macro",
966
- "roc_auc",
967
- "average_precision",
968
- "accuracy",
969
- "f1",
970
- "precision",
971
- "recall",
972
- ]
973
- if self.tune_metric_ is not None:
974
- tune_metric_tmp = self.tune_metric_
975
- else:
976
- tune_metric_tmp = "f1" # Default if not tuning
1131
+ y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
977
1132
 
978
1133
  metrics = self.scorers_.evaluate(
979
1134
  y_true_flat,
@@ -981,16 +1136,29 @@ class ImputeAutoencoder(BaseNNImputer):
981
1136
  y_true_ohe,
982
1137
  y_proba_flat,
983
1138
  objective_mode,
984
- tune_metric_tmp,
1139
+ cast(
1140
+ Literal[
1141
+ "pr_macro",
1142
+ "roc_auc",
1143
+ "accuracy",
1144
+ "f1",
1145
+ "average_precision",
1146
+ "precision",
1147
+ "recall",
1148
+ "mcc",
1149
+ "jaccard",
1150
+ ],
1151
+ self.tune_metric,
1152
+ ),
985
1153
  )
986
1154
 
987
1155
  if not objective_mode:
988
- pm = PrettyMetrics(
989
- metrics, precision=3, title=f"{self.model_name} Validation Metrics"
990
- )
991
- pm.render() # prints a command-line table
1156
+ if self.verbose or self.debug:
1157
+ pm = PrettyMetrics(
1158
+ metrics, precision=2, title=f"{self.model_name} Validation Metrics"
1159
+ )
1160
+ pm.render()
992
1161
 
993
- # Primary report (REF/HET/ALT or REF/ALT)
994
1162
  self._make_class_reports(
995
1163
  y_true=y_true_flat,
996
1164
  y_pred_proba=y_proba_flat,
@@ -999,18 +1167,15 @@ class ImputeAutoencoder(BaseNNImputer):
999
1167
  labels=target_names,
1000
1168
  )
1001
1169
 
1002
- # IUPAC decode & 10-base integer reports
1003
- # Now safe because GT_ref has been sliced to match X_val dimensions
1004
- y_true_dec = self.pgenc.decode_012(
1005
- GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1006
- )
1007
- X_pred = X_val.copy()
1008
- X_pred[eval_mask] = y_pred_flat
1170
+ y_true_matrix = np.array(y, copy=True)
1171
+ y_pred_matrix = np.array(pred_labels, copy=True)
1009
1172
 
1010
- # Use X_val.shape[1] (current features) not self.num_features_ (original features)
1011
- y_pred_dec = self.pgenc.decode_012(
1012
- X_pred.reshape(X_val.shape[0], X_val.shape[1])
1013
- )
1173
+ if self.is_haploid_:
1174
+ y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
1175
+ y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
1176
+
1177
+ y_true_dec = self.decode_012(y_true_matrix)
1178
+ y_pred_dec = self.decode_012(y_pred_matrix)
1014
1179
 
1015
1180
  encodings_dict = {
1016
1181
  "A": 0,
@@ -1049,237 +1214,177 @@ class ImputeAutoencoder(BaseNNImputer):
1049
1214
  return metrics
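For haploid inputs, the evaluation above folds the 3-class REF/HET/ALT outputs into a binary REF/ALT problem: labels are binarized with `> 0`, the HET and ALT probability columns are summed, and every probability row is then clipped and renormalized. The same arithmetic in isolation:

```python
# The haploid collapse and row renormalization from _evaluate_model, in isolation.
import numpy as np

proba3 = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])        # (n, 3): REF / HET / ALT
labels3 = proba3.argmax(axis=1)             # [0, 2]

proba2 = np.empty((proba3.shape[0], 2), dtype=proba3.dtype)
proba2[:, 0] = proba3[:, 0]                 # REF stays
proba2[:, 1] = proba3[:, 1] + proba3[:, 2]  # HET + ALT -> ALT
labels2 = (labels3 > 0).astype(np.int8)     # [0, 1]

proba2 = np.clip(proba2, 0.0, 1.0)
row_sums = proba2.sum(axis=1, keepdims=True)
row_sums[row_sums == 0.0] = 1.0
proba2 = proba2 / row_sums                  # each row sums to 1
print(labels2, proba2)
```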
1050
1215
 
1051
1216
  def _objective(self, trial: optuna.Trial) -> float:
1052
- """Optuna objective for AE; mirrors NLPCA study driver without latents.
1053
-
1054
- This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
1217
+ """Optuna objective for AE (mirrors VAE flow, excluding KL-specific parts).
1055
1218
 
1056
1219
  Args:
1057
- trial (optuna.Trial): Optuna trial.
1220
+ trial (optuna.Trial): Optuna trial object.
1058
1221
 
1059
1222
  Returns:
1060
- float: Value of the tuning metric (maximize).
1223
+ float: Value of the tuning metric to optimize.
1061
1224
  """
1062
1225
  try:
1063
- # Sample hyperparameters (existing helper; unchanged signature)
1064
1226
  params = self._sample_hyperparameters(trial)
1065
1227
 
1066
- # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
1067
- X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1068
- X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1069
-
1070
- class_weights = self._normalize_class_weights(
1071
- self._class_weights_from_zygosity(X_train)
1072
- )
1073
- train_loader = self._get_data_loaders(X_train)
1074
-
1075
1228
  model = self.build_model(self.Model, params["model_params"])
1076
1229
  model.apply(self.initialize_weights)
1077
1230
 
1078
- lr: float = float(params["lr"])
1079
- l1_penalty: float = float(params["l1_penalty"])
1231
+ lr = float(params["learning_rate"])
1232
+ l1_penalty = float(params["l1_penalty"])
1233
+
1234
+ class_weights = self._class_weights_from_zygosity(
1235
+ self.y_train_,
1236
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1237
+ inverse=params["inverse"],
1238
+ normalize=params["normalize"],
1239
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1240
+ power=params["power"],
1241
+ )
1080
1242
 
1081
- # Train + prune on metric
1082
- _, model, __ = self._train_and_validate_model(
1243
+ loss, model, _hist = self._train_and_validate_model(
1083
1244
  model=model,
1084
- loader=train_loader,
1085
1245
  lr=lr,
1086
1246
  l1_penalty=l1_penalty,
1247
+ params=params,
1087
1248
  trial=trial,
1088
- return_history=False,
1089
1249
  class_weights=class_weights,
1090
- X_val=X_val,
1091
- params=params,
1092
- prune_metric=self.tune_metric,
1093
- prune_warmup_epochs=5,
1094
- eval_interval=self.tune_eval_interval,
1095
- eval_requires_latents=False,
1096
- eval_latent_steps=0,
1097
- eval_latent_lr=0.0,
1098
- eval_latent_weight_decay=0.0,
1250
+ gamma_schedule=params["gamma_schedule"],
1099
1251
  )
1100
1252
 
1101
- eval_mask = (
1102
- self.sim_mask_test_
1103
- if (
1104
- self.simulate_missing
1105
- and getattr(self, "sim_mask_test_", None) is not None
1106
- )
1107
- else None
1108
- )
1253
+ if model is None or not np.isfinite(loss):
1254
+ msg = "Model training returned None or non-finite loss in tuning objective."
1255
+ self.logger.error(msg)
1256
+ raise RuntimeError(msg)
1109
1257
 
1110
- if model is not None:
1111
- metrics = self._evaluate_model(
1112
- X_val,
1113
- model,
1114
- params,
1115
- objective_mode=True,
1116
- eval_mask_override=eval_mask,
1117
- )
1118
- self._clear_resources(model, train_loader)
1119
- else:
1120
- raise TypeError("Model training failed; no model was returned.")
1258
+ metrics = self._evaluate_model(
1259
+ model=model,
1260
+ X=self.X_val_,
1261
+ y=self.y_val_,
1262
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
1263
+ objective_mode=True,
1264
+ )
1121
1265
 
1122
- return metrics[self.tune_metric]
1266
+ self._clear_resources(model)
1267
+ return float(metrics[self.tune_metric])
1123
1268
 
1124
1269
  except Exception as e:
1125
- # Keep sweeps moving if a trial fails
1126
- raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
1127
-
1128
- def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
1129
- """Sample AE hyperparameters and compute hidden sizes for model params.
1270
+ err_type = type(e).__name__
1271
+ self.logger.warning(
1272
+ f"Trial {trial.number} failed due to exception {err_type}: {e}"
1273
+ )
1274
+ self.logger.debug(traceback.format_exc())
1275
+ raise optuna.exceptions.TrialPruned(
1276
+ f"Trial {trial.number} failed due to an exception. {err_type}: {e}. "
1277
+ "Enable debug logging for full traceback."
1278
+ ) from e
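`_objective` builds, trains, and scores one candidate model per trial; per-epoch pruning happens inside `_train_and_validate_model` via `trial.report(...)` and `trial.should_prune()`, and any exception is converted to `TrialPruned` so the sweep keeps moving. The study driver lives in `BaseNNImputer` and is not part of this hunk; a hedged sketch of how an objective like this is typically driven:

```python
# Hedged sketch of a study driver for an objective like _objective above; the
# actual sampler, pruner, and trial budget are configured in BaseNNImputer.
import optuna


def run_study(objective, n_trials: int = 50) -> optuna.Study:
    study = optuna.create_study(
        direction="maximize",  # the tuning metric (e.g., f1) is maximized
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
    )
    study.optimize(objective, n_trials=n_trials)
    return study
```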
1130
1279
 
1131
- This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
1280
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1281
+ """Sample AE hyperparameters; hidden sizes mirror VAE helper (excluding KL).
1132
1282
 
1133
1283
  Args:
1134
1284
  trial (optuna.Trial): Optuna trial object.
1135
1285
 
1136
1286
  Returns:
1137
- Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
1287
+ dict: Sampled hyperparameters.
1138
1288
  """
1139
1289
  params = {
1140
- "latent_dim": trial.suggest_int("latent_dim", 2, 64),
1141
- "lr": trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
1142
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.6),
1143
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 8),
1290
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
1291
+ "learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
1292
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
1293
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
1144
1294
  "activation": trial.suggest_categorical(
1145
- "activation", ["relu", "elu", "selu"]
1295
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
1146
1296
  ),
1147
- "l1_penalty": trial.suggest_float("l1_penalty", 1e-7, 1e-2, log=True),
1297
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1148
1298
  "layer_scaling_factor": trial.suggest_float(
1149
- "layer_scaling_factor", 2.0, 10.0
1299
+ "layer_scaling_factor", 2.0, 10.0, step=0.025
1150
1300
  ),
1151
1301
  "layer_schedule": trial.suggest_categorical(
1152
- "layer_schedule", ["pyramid", "constant", "linear"]
1302
+ "layer_schedule", ["pyramid", "linear"]
1303
+ ),
1304
+ "power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
1305
+ "normalize": trial.suggest_categorical("normalize", [True, False]),
1306
+ "inverse": trial.suggest_categorical("inverse", [True, False]),
1307
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0, step=0.1),
1308
+ "gamma_schedule": trial.suggest_categorical(
1309
+ "gamma_schedule", [True, False]
1153
1310
  ),
1154
1311
  }
1155
1312
 
1156
- nF: int = self.num_features_
1157
- nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1313
+ nF = int(self.num_features_)
1314
+ nC = int(self.num_classes_)
1158
1315
  input_dim = nF * nC
1316
+
1159
1317
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1160
1318
  n_inputs=input_dim,
1161
- n_outputs=input_dim,
1319
+ n_outputs=nC,
1162
1320
  n_samples=len(self.train_idx_),
1163
- n_hidden=params["num_hidden_layers"],
1164
- alpha=params["layer_scaling_factor"],
1165
- schedule=params["layer_schedule"],
1321
+ n_hidden=int(params["num_hidden_layers"]),
1322
+ latent_dim=int(params["latent_dim"]),
1323
+ alpha=float(params["layer_scaling_factor"]),
1324
+ schedule=str(params["layer_schedule"]),
1166
1325
  )
1167
1326
 
1168
- # Keep the latent_dim as the first element,
1169
- # then the interior hidden widths.
1170
- # If there are no interior widths (very small nets),
1171
- # this still leaves [latent_dim].
1172
- hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1173
-
1174
1327
  params["model_params"] = {
1175
- "n_features": int(self.num_features_),
1176
- "num_classes": (
1177
- int(self.num_classes_) if self.num_classes_ is not None else 3
1178
- ),
1328
+ "n_features": nF,
1329
+ "num_classes": nC,
1179
1330
  "latent_dim": int(params["latent_dim"]),
1180
1331
  "dropout_rate": float(params["dropout_rate"]),
1181
- "hidden_layer_sizes": hidden_only,
1332
+ "hidden_layer_sizes": hidden_layer_sizes,
1182
1333
  "activation": str(params["activation"]),
1183
1334
  }
1184
1335
  return params
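`model_params["hidden_layer_sizes"]` comes from `_compute_hidden_layer_sizes`, a base-class helper (not in this hunk) that derives a list of hidden widths from the flattened input size, `latent_dim`, `layer_scaling_factor`, and `layer_schedule`. A hypothetical illustration of what a "pyramid" (geometric) versus "linear" taper could produce; the real helper's formula may differ:

```python
# Hypothetical taper schedules between an input width and the latent width;
# illustration only, not the base class's _compute_hidden_layer_sizes.
import numpy as np


def toy_hidden_sizes(input_dim: int, latent_dim: int, n_hidden: int,
                     schedule: str = "pyramid") -> list:
    if schedule == "pyramid":   # geometric taper
        widths = np.geomspace(input_dim, latent_dim, num=n_hidden + 2)
    else:                       # linear taper
        widths = np.linspace(input_dim, latent_dim, num=n_hidden + 2)
    return [int(round(w)) for w in widths[1:-1]]  # interior hidden widths only


print(toy_hidden_sizes(300, 8, 3, "pyramid"))  # [121, 49, 20]
print(toy_hidden_sizes(300, 8, 3, "linear"))   # [227, 154, 81]
```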
1185
1336
 
1186
- def _set_best_params(
1187
- self, best_params: Dict[str, int | float | str | List[int]]
1188
- ) -> Dict[str, int | float | str | List[int]]:
1189
- """Adopt best params (ImputeNLPCA parity) and return model_params.
1190
-
1191
- This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
1337
+ def _set_best_params(self, params: dict) -> dict:
1338
+ """Update instance fields from tuned params and return model_params dict.
1192
1339
 
1193
1340
  Args:
1194
- best_params (Dict[str, int | float | str | List[int]]): Best hyperparameters from tuning.
1341
+ params (dict): Best hyperparameters from tuning.
1195
1342
 
1196
1343
  Returns:
1197
- Dict[str, int | float | str | List[int]]: Model parameters for building the model.
1344
+ dict: Model parameters for building the final model.
1198
1345
  """
1199
- bp = {}
1200
- for k, v in best_params.items():
1201
- if not isinstance(v, list):
1202
- if k in {"latent_dim", "num_hidden_layers"}:
1203
- bp[k] = int(v)
1204
- elif k in {
1205
- "dropout_rate",
1206
- "learning_rate",
1207
- "l1_penalty",
1208
- "layer_scaling_factor",
1209
- }:
1210
- bp[k] = float(v)
1211
- elif k in {"activation", "layer_schedule"}:
1212
- if k == "layer_schedule":
1213
- if v not in {"pyramid", "constant", "linear"}:
1214
- raise ValueError(f"Invalid layer_schedule: {v}")
1215
- bp[k] = v
1216
- else:
1217
- bp[k] = str(v)
1218
- else:
1219
- bp[k] = v # keep lists as-is
1220
-
1221
- self.latent_dim: int = bp["latent_dim"]
1222
- self.dropout_rate: float = bp["dropout_rate"]
1223
- self.learning_rate: float = bp["learning_rate"]
1224
- self.l1_penalty: float = bp["l1_penalty"]
1225
- self.activation: str = bp["activation"]
1226
- self.layer_scaling_factor: float = bp["layer_scaling_factor"]
1227
- self.layer_schedule: str = bp["layer_schedule"]
1228
-
1229
- nF: int = self.num_features_
1230
- nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1231
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
1232
- n_inputs=nF * nC,
1233
- n_outputs=nF * nC,
1234
- n_samples=len(self.train_idx_),
1235
- n_hidden=bp["num_hidden_layers"],
1236
- alpha=bp["layer_scaling_factor"],
1237
- schedule=bp["layer_schedule"],
1346
+ self.latent_dim = int(params["latent_dim"])
1347
+ self.dropout_rate = float(params["dropout_rate"])
1348
+ self.learning_rate = float(params["learning_rate"])
1349
+ self.l1_penalty = float(params["l1_penalty"])
1350
+ self.activation = str(params["activation"])
1351
+ self.layer_scaling_factor = float(params["layer_scaling_factor"])
1352
+ self.layer_schedule = str(params["layer_schedule"])
1353
+
1354
+ self.power = float(params["power"])
1355
+ self.normalize = bool(params["normalize"])
1356
+ self.inverse = bool(params["inverse"])
1357
+ self.gamma = float(params["gamma"])
1358
+ self.gamma_schedule = bool(params["gamma_schedule"])
1359
+
1360
+ self.class_weights_ = self._class_weights_from_zygosity(
1361
+ self.y_train_,
1362
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1363
+ inverse=self.inverse,
1364
+ normalize=self.normalize,
1365
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1366
+ power=self.power,
1238
1367
  )
1239
1368
 
1240
- # Keep the latent_dim as the first element,
1241
- # then the interior hidden widths.
1242
- # If there are no interior widths (very small nets),
1243
- # this still leaves [latent_dim].
1244
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1245
-
1246
- return {
1247
- "n_features": self.num_features_,
1248
- "latent_dim": self.latent_dim,
1249
- "hidden_layer_sizes": hidden_only,
1250
- "dropout_rate": self.dropout_rate,
1251
- "activation": self.activation,
1252
- "num_classes": nC,
1253
- }
1254
-
1255
- def _default_best_params(self) -> Dict[str, int | float | str | list]:
1256
- """Default model params when tuning is disabled.
1257
-
1258
- This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
1259
-
1260
- Returns:
1261
- Dict[str, int | float | str | list]: Default model parameters.
1262
- """
1263
- nF: int = self.num_features_
1264
- nC: int = int(self.num_classes_) if self.num_classes_ is not None else 3
1265
- ls = self.layer_schedule
1266
-
1267
- if ls not in {"pyramid", "constant", "linear"}:
1268
- raise ValueError(f"Invalid layer_schedule: {ls}")
1369
+ nF = int(self.num_features_)
1370
+ nC = int(self.num_classes_)
1371
+ input_dim = nF * nC
1269
1372
 
1270
1373
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1271
- n_inputs=nF * nC,
1272
- n_outputs=nF * nC,
1273
- n_samples=len(self.ground_truth_),
1274
- n_hidden=self.num_hidden_layers,
1275
- alpha=self.layer_scaling_factor,
1276
- schedule=ls,
1374
+ n_inputs=input_dim,
1375
+ n_outputs=nC,
1376
+ n_samples=len(self.train_idx_),
1377
+ n_hidden=int(params["num_hidden_layers"]),
1378
+ latent_dim=int(params["latent_dim"]),
1379
+ alpha=float(params["layer_scaling_factor"]),
1380
+ schedule=str(params["layer_schedule"]),
1277
1381
  )
1382
+
1278
1383
  return {
1279
- "n_features": self.num_features_,
1384
+ "n_features": nF,
1385
+ "num_classes": nC,
1280
1386
  "latent_dim": self.latent_dim,
1281
1387
  "hidden_layer_sizes": hidden_layer_sizes,
1282
1388
  "dropout_rate": self.dropout_rate,
1283
1389
  "activation": self.activation,
1284
- "num_classes": nC,
1285
1390
  }
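Earlier in `_set_best_params`, `class_weights_` is recomputed with the tuned `inverse`, `normalize`, `power`, and `max_ratio` knobs. `_class_weights_from_zygosity` is defined in the base class and not shown in this diff; a hedged sketch of the kind of inverse-frequency weighting those parameters imply:

```python
# Hedged sketch of inverse-frequency class weights with power/max_ratio knobs;
# the base class's _class_weights_from_zygosity may compute this differently.
import numpy as np
import torch


def toy_class_weights(y, mask, num_classes=3, inverse=True, power=1.0,
                      normalize=True, max_ratio=5.0):
    counts = np.bincount(y[mask & (y >= 0)], minlength=num_classes).astype(float)
    counts = np.maximum(counts, 1.0)                 # avoid divide-by-zero
    w = (1.0 / counts) ** power if inverse else counts ** power
    w = np.minimum(w, w.min() * max_ratio)           # cap the weight spread
    if normalize:
        w = w * (num_classes / w.sum())              # mean weight close to 1
    return torch.tensor(w, dtype=torch.float32)


y = np.array([[0, 0, 1, 2], [0, 2, 2, -1]])
mask = np.ones_like(y, dtype=bool)
print(toy_class_weights(y, mask))
```

Tensors of this shape (one weight per genotype class) are what a weighted cross-entropy criterion would typically consume.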