pg-sui 1.6.16a3-py3-none-any.whl → 1.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
@@ -1,23 +1,25 @@
1
+ # -*- coding: utf-8 -*-
2
+ from __future__ import annotations
3
+
1
4
  import copy
2
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
5
+ import traceback
6
+ from collections import defaultdict
7
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, cast
3
8
 
4
9
  import matplotlib.pyplot as plt
5
10
  import numpy as np
6
11
  import optuna
7
12
  import torch
8
- import torch.nn.functional as F
9
13
  from sklearn.exceptions import NotFittedError
10
- from sklearn.model_selection import train_test_split
11
14
  from snpio.analysis.genotype_encoder import GenotypeEncoder
12
15
  from snpio.utils.logging import LoggerManager
13
16
  from torch.optim.lr_scheduler import CosineAnnealingLR
14
17
 
15
18
  from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
16
19
  from pgsui.data_processing.containers import AutoencoderConfig
17
- from pgsui.data_processing.transformers import SimMissingTransformer
18
20
  from pgsui.impute.unsupervised.base import BaseNNImputer
19
21
  from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
- from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
22
+ from pgsui.impute.unsupervised.loss_functions import FocalCELoss
21
23
  from pgsui.impute.unsupervised.models.autoencoder_model import AutoencoderModel
22
24
  from pgsui.utils.logging_utils import configure_logger
23
25
  from pgsui.utils.pretty_metrics import PrettyMetrics
@@ -27,30 +29,72 @@ if TYPE_CHECKING:
27
29
  from snpio.read_input.genotype_data import GenotypeData
28
30
 
29
31
 
32
+ def _make_warmup_cosine_scheduler(
33
+ optimizer: torch.optim.Optimizer,
34
+ *,
35
+ max_epochs: int,
36
+ warmup_epochs: int,
37
+ start_factor: float = 0.1,
38
+ ) -> torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR:
39
+ """Create a warmup->cosine LR scheduler.
40
+
41
+ Args:
42
+ optimizer: Optimizer to schedule.
43
+ max_epochs: Total number of epochs.
44
+ warmup_epochs: Number of warmup epochs.
45
+ start_factor: Starting LR factor for warmup.
46
+
47
+ Returns:
48
+ torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR: LR scheduler (SequentialLR if warmup_epochs > 0 else CosineAnnealingLR).
49
+ """
50
+ warmup_epochs = int(max(0, warmup_epochs))
51
+
52
+ if warmup_epochs == 0:
53
+ return CosineAnnealingLR(optimizer, T_max=max_epochs)
54
+
55
+ warmup = torch.optim.lr_scheduler.LinearLR(
56
+ optimizer, start_factor=float(start_factor), total_iters=warmup_epochs
57
+ )
58
+ cosine = CosineAnnealingLR(optimizer, T_max=max(1, max_epochs - warmup_epochs))
59
+
60
+ return torch.optim.lr_scheduler.SequentialLR(
61
+ optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs]
62
+ )
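
For illustration, a minimal sketch of driving the warmup→cosine helper above once per epoch; the toy model, optimizer, and epoch counts are hypothetical, not taken from pg-sui:

    import torch

    model = torch.nn.Linear(8, 3)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # 100 total epochs, the first 10 as linear warmup from 0.1 * lr up to lr.
    scheduler = _make_warmup_cosine_scheduler(
        optimizer, max_epochs=100, warmup_epochs=10, start_factor=0.1
    )

    for epoch in range(100):
        optimizer.step()      # placeholder for a real training epoch
        scheduler.step()      # advance the warmup -> cosine schedule once per epoch
        current_lr = scheduler.get_last_lr()[0]
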
63
+
64
+
30
65
  def ensure_autoencoder_config(
31
66
  config: AutoencoderConfig | dict | str | None,
32
67
  ) -> AutoencoderConfig:
33
68
  """Return a concrete AutoencoderConfig from dataclass, dict, YAML path, or None.
34
69
 
35
- This method normalizes the configuration input for the Autoencoder imputer. It accepts a structured configuration in various formats, including a dataclass instance, a nested dictionary, a YAML file path, or None. The method processes the input accordingly and returns a concrete instance of AutoencoderConfig with all necessary fields populated.
70
+ Notes:
71
+ - Supports top-level preset, or io.preset inside dict/YAML.
72
+ - Does not mutate user-provided dict (deep-copies before processing).
73
+ - Flattens nested dicts into dot-keys and applies them as overrides.
36
74
 
37
75
  Args:
38
- config (AutoencoderConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
76
+ config: AutoencoderConfig instance, dict, YAML path, or None.
39
77
 
40
78
  Returns:
41
- AutoencoderConfig: Concrete configuration instance.
79
+ Concrete AutoencoderConfig.
42
80
  """
43
81
  if config is None:
44
82
  return AutoencoderConfig()
45
83
  if isinstance(config, AutoencoderConfig):
46
84
  return config
47
85
  if isinstance(config, str):
48
- # YAML path — top-level `preset` key is supported
49
86
  return load_yaml_to_dataclass(config, AutoencoderConfig)
50
87
  if isinstance(config, dict):
51
- # Flatten dict into dot-keys then overlay onto a fresh instance
88
+ cfg_in = copy.deepcopy(config)
52
89
  base = AutoencoderConfig()
53
90
 
91
+ preset = cfg_in.pop("preset", None)
92
+ if "io" in cfg_in and isinstance(cfg_in["io"], dict):
93
+ preset = preset or cfg_in["io"].pop("preset", None)
94
+
95
+ if preset:
96
+ base = AutoencoderConfig.from_preset(preset)
97
+
54
98
  def _flatten(prefix: str, d: dict, out: dict) -> dict:
55
99
  for k, v in d.items():
56
100
  kk = f"{prefix}.{k}" if prefix else k
@@ -60,26 +104,24 @@ def ensure_autoencoder_config(
60
104
  out[kk] = v
61
105
  return out
62
106
 
63
- # Lift any present preset first
64
- preset_name = config.pop("preset", None)
65
- if "io" in config and isinstance(config["io"], dict):
66
- preset_name = preset_name or config["io"].pop("preset", None)
67
-
68
- if preset_name:
69
- base = AutoencoderConfig.from_preset(preset_name)
70
-
71
- flat = _flatten("", config, {})
107
+ flat = _flatten("", cfg_in, {})
72
108
  return apply_dot_overrides(base, flat)
73
109
 
74
110
  raise TypeError("config must be an AutoencoderConfig, dict, YAML path, or None.")
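
For illustration, the dict branch above flattens nested keys into dot-paths before applying them as overrides; a self-contained sketch of that idea (standalone names, not the package's own helpers):

    def flatten_to_dot_keys(d: dict, prefix: str = "") -> dict:
        # Flatten {"model": {"latent_dim": 32}} into {"model.latent_dim": 32}.
        out: dict = {}
        for key, value in d.items():
            dotted = f"{prefix}.{key}" if prefix else key
            if isinstance(value, dict):
                out.update(flatten_to_dot_keys(value, dotted))
            else:
                out[dotted] = value
        return out

    flat = flatten_to_dot_keys({"model": {"latent_dim": 32}, "train": {"max_epochs": 200}})
    # flat == {"model.latent_dim": 32, "train.max_epochs": 200}
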
75
111
 
76
112
 
77
113
  class ImputeAutoencoder(BaseNNImputer):
78
- """Impute missing genotypes with a standard Autoencoder on 0/1/2 encodings.
114
+ """Autoencoder imputer for 0/1/2 genotypes.
79
115
 
80
- This imputer uses a feedforward autoencoder architecture to learn compressed and reconstructive representations of genotype data encoded as 0 (homozygous reference), 1 (heterozygous), and 2 (homozygous alternate). Missing genotypes are represented as -1 during training and imputation.
116
+ Trains a feedforward autoencoder on a genotype matrix encoded as 0/1/2 with missing values represented by any negative integer. Missingness is simulated once on the full matrix, then train/val/test splits reuse those masks. It supports haploid and diploid data, focal-CE reconstruction loss (optional scheduling), and Optuna-based hyperparameter tuning. Output is returned as IUPAC strings via ``decode_012``.
81
117
 
82
- The model is trained to minimize a focal cross-entropy loss, which helps to address class imbalance by focusing more on hard-to-classify examples. The architecture includes configurable parameters such as the number of hidden layers, latent dimension size, dropout rate, and activation functions.
118
+ Notes:
119
+ - Simulates missingness once on the full 0/1/2 matrix, then splits indices on clean ground truth.
120
+ - Maintains clean targets and corrupted inputs per train/val/test, plus per-split masks.
121
+ - Haploid harmonization happens after the single simulation (no re-simulation).
122
+ - Training/validation loss is computed only where targets are known (~orig_mask_*).
123
+ - Evaluation is computed only on simulated-missing sites (sim_mask_*).
124
+ - ``transform()`` fills only originally missing sites and hard-errors if decoding yields "N".
83
125
  """
84
126
 
85
127
  def __init__(
@@ -88,8 +130,7 @@ class ImputeAutoencoder(BaseNNImputer):
88
130
  *,
89
131
  tree_parser: Optional["TreeParser"] = None,
90
132
  config: Optional[Union["AutoencoderConfig", dict, str]] = None,
91
- overrides: dict | None = None,
92
- simulate_missing: bool | None = None,
133
+ overrides: Optional[dict] = None,
93
134
  sim_strategy: (
94
135
  Literal[
95
136
  "random",
@@ -100,34 +141,29 @@ class ImputeAutoencoder(BaseNNImputer):
100
141
  ]
101
142
  | None
102
143
  ) = None,
103
- sim_prop: float | None = None,
104
- sim_kwargs: dict | None = None,
144
+ sim_prop: Optional[float] = None,
145
+ sim_kwargs: Optional[dict] = None,
105
146
  ) -> None:
106
147
  """Initialize the Autoencoder imputer with a unified config interface.
107
148
 
108
- This initializer sets up the Autoencoder imputer by processing the provided configuration, initializing logging, and preparing the model and data encoder. It supports configuration input as a dataclass, nested dictionary, YAML file path, or None, with optional dot-key overrides for fine-tuning specific parameters.
109
-
110
149
  Args:
111
- genotype_data ("GenotypeData"): Backing genotype data object.
112
- tree_parser (Optional["TreeParser"]): Optional SNPio phylogenetic tree parser for population-specific modes.
113
- config (Union["AutoencoderConfig", dict, str] | None): Structured configuration as dataclass, nested dict, YAML path, or None.
114
- overrides (dict | None): Optional dot-key overrides with highest precedence (e.g., {'model.latent_dim': 32}).
115
- simulate_missing (bool | None): Whether to simulate missing data during evaluation. If None, uses config default.
116
- sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
117
- sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
118
- sim_kwargs (dict | None): Additional keyword arguments for simulating missing data. If None, uses config default.
150
+ genotype_data (GenotypeData): Backing genotype data object.
151
+ tree_parser (Optional[TreeParser]): Optional SNPio tree parser for nonrandom simulated-missing modes.
152
+ config (Optional[Union[AutoencoderConfig, dict, str]]): AutoencoderConfig, nested dict, YAML path, or None.
153
+ overrides (Optional[dict]): Optional dot-key overrides with highest precedence.
154
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Override sim strategy; if None, uses config default.
155
+ sim_prop (Optional[float]): Override simulated missing proportion; if None, uses config default.
156
+ sim_kwargs (Optional[dict]): Override/extend simulated missing kwargs; if None, uses config default.
119
157
  """
120
158
  self.model_name = "ImputeAutoencoder"
121
159
  self.genotype_data = genotype_data
122
160
  self.tree_parser = tree_parser
123
161
 
124
- # Normalize config then apply highest-precedence overrides
125
162
  cfg = ensure_autoencoder_config(config)
126
163
  if overrides:
127
164
  cfg = apply_dot_overrides(cfg, overrides)
128
165
  self.cfg = cfg
129
166
 
130
- # Logger consistent with NLPCA
131
167
  logman = LoggerManager(
132
168
  __name__,
133
169
  prefix=self.cfg.io.prefix,
@@ -139,8 +175,8 @@ class ImputeAutoencoder(BaseNNImputer):
139
175
  verbose=self.cfg.io.verbose,
140
176
  debug=self.cfg.io.debug,
141
177
  )
178
+ self.logger.propagate = False
142
179
 
143
- # BaseNNImputer bootstrapping (device/dirs/logging handled here)
144
180
  super().__init__(
145
181
  model_name=self.model_name,
146
182
  genotype_data=self.genotype_data,
@@ -151,11 +187,9 @@ class ImputeAutoencoder(BaseNNImputer):
151
187
  )
152
188
 
153
189
  self.Model = AutoencoderModel
154
-
155
- # Model hook & encoder
156
190
  self.pgenc = GenotypeEncoder(genotype_data)
157
191
 
158
- # IO / global
192
+ # I/O and global
159
193
  self.seed = self.cfg.io.seed
160
194
  self.n_jobs = self.cfg.io.n_jobs
161
195
  self.prefix = self.cfg.io.prefix
@@ -163,264 +197,347 @@ class ImputeAutoencoder(BaseNNImputer):
163
197
  self.verbose = self.cfg.io.verbose
164
198
  self.debug = self.cfg.io.debug
165
199
  self.rng = np.random.default_rng(self.seed)
166
- self.pos_weights_: torch.Tensor | None = None
167
200
 
168
- # Simulated-missing controls (config defaults with ctor overrides)
201
+ # Simulation controls (match VAE pattern)
169
202
  sim_cfg = getattr(self.cfg, "sim", None)
170
203
  sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
171
204
  if sim_kwargs:
172
205
  sim_cfg_kwargs.update(sim_kwargs)
173
- self.simulate_missing = (
174
- (
175
- sim_cfg.simulate_missing
176
- if simulate_missing is None
177
- else bool(simulate_missing)
178
- )
179
- if sim_cfg is not None
180
- else bool(simulate_missing)
181
- )
206
+
182
207
  if sim_cfg is None:
183
208
  default_strategy = "random"
184
- default_prop = 0.10
209
+ default_prop = 0.2
185
210
  else:
186
211
  default_strategy = sim_cfg.sim_strategy
187
212
  default_prop = sim_cfg.sim_prop
213
+
214
+ self.simulate_missing = True
188
215
  self.sim_strategy = sim_strategy or default_strategy
189
216
  self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
190
217
  self.sim_kwargs = sim_cfg_kwargs
191
218
 
192
219
  if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
193
- msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
220
+ msg = "tree_parser is required for nonrandom sim strategies."
194
221
  self.logger.error(msg)
195
222
  raise ValueError(msg)
196
223
 
197
- # Model hyperparams
224
+ # Model architecture
198
225
  self.latent_dim = int(self.cfg.model.latent_dim)
199
226
  self.dropout_rate = float(self.cfg.model.dropout_rate)
200
227
  self.num_hidden_layers = int(self.cfg.model.num_hidden_layers)
201
228
  self.layer_scaling_factor = float(self.cfg.model.layer_scaling_factor)
202
- self.layer_schedule: str = str(self.cfg.model.layer_schedule)
203
- self.activation = str(self.cfg.model.hidden_activation)
204
- self.gamma = float(self.cfg.model.gamma)
229
+ self.layer_schedule = str(self.cfg.model.layer_schedule)
230
+ self.activation = str(self.cfg.model.activation)
231
+
232
+ # Training / loss controls (align with VAE fields where present)
233
+ self.power = float(getattr(self.cfg.train, "weights_power", 1.0))
234
+ self.max_ratio = getattr(self.cfg.train, "weights_max_ratio", None)
235
+ self.normalize = bool(getattr(self.cfg.train, "weights_normalize", True))
236
+ self.inverse = bool(getattr(self.cfg.train, "weights_inverse", False))
205
237
 
206
- # Train hyperparams
207
238
  self.batch_size = int(self.cfg.train.batch_size)
208
239
  self.learning_rate = float(self.cfg.train.learning_rate)
209
- self.l1_penalty: float = float(self.cfg.train.l1_penalty)
240
+ self.l1_penalty = float(self.cfg.train.l1_penalty)
210
241
  self.early_stop_gen = int(self.cfg.train.early_stop_gen)
211
242
  self.min_epochs = int(self.cfg.train.min_epochs)
212
243
  self.epochs = int(self.cfg.train.max_epochs)
213
244
  self.validation_split = float(self.cfg.train.validation_split)
214
- self.beta = float(self.cfg.train.weights_beta)
215
- self.max_ratio = float(self.cfg.train.weights_max_ratio)
216
245
 
217
- # Tuning
218
- self.tune = bool(self.cfg.tune.enabled)
219
- self.tune_fast = bool(self.cfg.tune.fast)
220
- self.tune_batch_size = int(self.cfg.tune.batch_size)
221
- self.tune_epochs = int(self.cfg.tune.epochs)
222
- self.tune_eval_interval = int(self.cfg.tune.eval_interval)
223
- self.tune_metric: str = self.cfg.tune.metric
224
-
225
- if self.tune_metric is not None:
226
- self.tune_metric_: (
227
- Literal[
228
- "pr_macro",
229
- "f1",
230
- "accuracy",
231
- "precision",
232
- "recall",
233
- "roc_auc",
234
- "average_precision",
235
- ]
236
- | None
237
- ) = self.cfg.tune.metric
246
+ # Gamma can live in cfg.model or cfg.train depending on your dataclasses
247
+ gamma_raw = getattr(
248
+ self.cfg.train, "gamma", getattr(self.cfg.model, "gamma", 0.0)
249
+ )
250
+ if not isinstance(gamma_raw, (float, int)):
251
+ msg = f"Gamma must be float|int; got {type(gamma_raw)}."
252
+ self.logger.error(msg)
253
+ raise TypeError(msg)
254
+ self.gamma = float(gamma_raw)
255
+ self.gamma_schedule = bool(getattr(self.cfg.train, "gamma_schedule", True))
238
256
 
257
+ # Hyperparameter tuning
258
+ self.tune = bool(self.cfg.tune.enabled)
259
+ self.tune_metric = cast(
260
+ Literal[
261
+ "pr_macro",
262
+ "f1",
263
+ "accuracy",
264
+ "precision",
265
+ "recall",
266
+ "roc_auc",
267
+ "average_precision",
268
+ "mcc",
269
+ "jaccard",
270
+ ],
271
+ self.cfg.tune.metric or "f1",
272
+ )
239
273
  self.n_trials = int(self.cfg.tune.n_trials)
240
274
  self.tune_save_db = bool(self.cfg.tune.save_db)
241
275
  self.tune_resume = bool(self.cfg.tune.resume)
242
- self.tune_max_samples = int(self.cfg.tune.max_samples)
243
- self.tune_max_loci = int(self.cfg.tune.max_loci)
244
- self.tune_infer_epochs = int(
245
- getattr(self.cfg.tune, "infer_epochs", 0)
246
- ) # AE unused
247
276
  self.tune_patience = int(self.cfg.tune.patience)
248
277
 
249
- # Evaluate
250
- # AE does not optimize latents, so these are unused / fixed
251
- self.eval_latent_steps: int = 0
252
- self.eval_latent_lr: float = 0.0
253
- self.eval_latent_weight_decay: float = 0.0
254
-
255
- # Plotting (parity with NLPCA PlotConfig)
256
- self.plot_format: Literal["pdf", "png", "jpg", "jpeg", "svg"] = (
257
- self.cfg.plot.fmt
258
- )
278
+ # Plotting
279
+ self.plot_format = self.cfg.plot.fmt
259
280
  self.plot_dpi = int(self.cfg.plot.dpi)
260
281
  self.plot_fontsize = int(self.cfg.plot.fontsize)
261
282
  self.title_fontsize = int(self.cfg.plot.fontsize)
262
283
  self.despine = bool(self.cfg.plot.despine)
263
284
  self.show_plots = bool(self.cfg.plot.show)
264
285
 
265
- # Core derived at fit-time
266
- self.is_haploid: bool = False
267
- self.num_classes_: int | None = None
286
+ # Fit-time attributes
287
+ self.is_haploid_: bool = False
288
+ self.num_classes_: int = 3
268
289
  self.model_params: Dict[str, Any] = {}
269
- self.sim_mask_global_: np.ndarray | None = None
270
- self.sim_mask_train_: np.ndarray | None = None
271
- self.sim_mask_test_: np.ndarray | None = None
272
290
 
273
- def fit(self) -> "ImputeAutoencoder":
274
- """Fit the autoencoder on 0/1/2 encoded genotypes (missing -> -1).
291
+ self.sim_mask_train_: np.ndarray
292
+ self.sim_mask_val_: np.ndarray
293
+ self.sim_mask_test_: np.ndarray
275
294
 
276
- This method trains the autoencoder model using the provided genotype data. It prepares the data by encoding genotypes as 0, 1, and 2, with missing values represented internally as -1. (When simulated-missing loci are generated via ``SimMissingTransformer`` they are first marked with -9 but are immediately re-encoded as -1 prior to training.) The method splits the data into training and validation sets, initializes the model and training parameters, and performs training with optional hyperparameter tuning. After training, it evaluates the model on the validation set and stores the fitted model and training history.
295
+ self.orig_mask_train_: np.ndarray
296
+ self.orig_mask_val_: np.ndarray
297
+ self.orig_mask_test_: np.ndarray
277
298
 
278
- Returns:
279
- ImputeAutoencoder: Fitted instance.
299
+ def fit(self) -> "ImputeAutoencoder":
300
+ """Fit the Autoencoder imputer model to the genotype data.
301
+
302
+ This method performs the following steps:
303
+ 1. Validates the presence of SNP data in the genotype data.
304
+ 2. Determines ploidy and sets up the number of classes accordingly.
305
+ 3. Cleans the ground truth genotype matrix and simulates missingness.
306
+ 4. Splits the data into training, validation, and test sets.
307
+ 5. Prepares one-hot encoded inputs for the model.
308
+ 6. Initializes plotting utilities and valid-class masks.
309
+ 7. Sets up data loaders for training and validation.
310
+ 8. Performs hyperparameter tuning if enabled, otherwise uses fixed hyperparameters.
311
+ 9. Builds and trains the Autoencoder model.
312
+ 10. Evaluates the trained model on the test set.
313
+ 11. Returns the fitted ImputeAutoencoder instance.
280
314
 
281
- Raises:
282
- NotFittedError: If training fails.
315
+ Returns:
316
+ ImputeAutoencoder: The fitted ImputeAutoencoder instance.
283
317
  """
284
318
  self.logger.info(f"Fitting {self.model_name} model...")
285
319
 
286
- # --- Data prep (mirror NLPCA) ---
287
- X012 = self._get_float_genotypes(copy=True)
288
- GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
289
- self.ground_truth_ = GT_full.astype(np.int64, copy=False)
290
-
291
- self.sim_mask_global_ = None
292
- cache_key = self._sim_mask_cache_key()
293
- if self.simulate_missing:
294
- cached_mask = (
295
- None if cache_key is None else self._sim_mask_cache.get(cache_key)
296
- )
297
- if cached_mask is not None:
298
- self.sim_mask_global_ = cached_mask.copy()
299
- else:
300
- tr = SimMissingTransformer(
301
- genotype_data=self.genotype_data,
302
- tree_parser=self.tree_parser,
303
- prop_missing=self.sim_prop,
304
- strategy=self.sim_strategy,
305
- missing_val=-9,
306
- mask_missing=True,
307
- verbose=self.verbose,
308
- **self.sim_kwargs,
309
- )
310
- tr.fit(X012.copy())
311
- self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
312
- if cache_key is not None:
313
- self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
314
-
315
- X_for_model = self.ground_truth_.copy()
316
- X_for_model[self.sim_mask_global_] = -1
317
- else:
318
- X_for_model = self.ground_truth_.copy()
319
-
320
320
  if self.genotype_data.snp_data is None:
321
- msg = "SNP data is required for Autoencoder imputer."
321
+ msg = f"SNP data is required for {self.model_name}."
322
322
  self.logger.error(msg)
323
- raise TypeError(msg)
323
+ raise AttributeError(msg)
324
324
 
325
- # Ploidy & classes
326
- self.is_haploid = bool(
327
- np.all(
328
- np.isin(
329
- self.genotype_data.snp_data,
330
- ["A", "C", "G", "T", "N", "-", ".", "?"],
331
- )
325
+ self.ploidy = self.cfg.io.ploidy
326
+ self.is_haploid_ = self.ploidy == 1
327
+
328
+ if self.ploidy > 2:
329
+ msg = (
330
+ f"{self.model_name} currently supports only haploid (1) or diploid (2) "
331
+ f"data; got ploidy={self.ploidy}."
332
332
  )
333
- )
334
- self.ploidy = 1 if self.is_haploid else 2
335
- # Scoring still uses 3 labels for diploid (REF/HET/ALT); model head uses 2 logits
336
- self.num_classes_ = 2 if self.is_haploid else 3
337
- self.output_classes_ = 2
338
- self.logger.info(
339
- f"Data is {'haploid' if self.is_haploid else 'diploid'}; "
340
- f"using {self.num_classes_} classes for scoring and {self.output_classes_} output channels."
341
- )
333
+ self.logger.error(msg)
334
+ raise ValueError(msg)
342
335
 
343
- if self.is_haploid:
344
- self.ground_truth_[self.ground_truth_ == 2] = 1
345
- X_for_model[X_for_model == 2] = 1
336
+ self.num_classes_ = 2 if self.is_haploid_ else 3
346
337
 
347
- n_samples, self.num_features_ = X_for_model.shape
338
+ # Clean 0/1/2 ground truth (missing=-1)
339
+ gt_full = self.pgenc.genotypes_012.copy()
340
+ gt_full[gt_full < 0] = -1
341
+ gt_full = np.nan_to_num(gt_full, nan=-1.0)
342
+ self.ground_truth_ = gt_full.astype(np.int8)
343
+ self.num_features_ = int(self.ground_truth_.shape[1])
348
344
 
349
- # Model params (decoder outputs L * K logits)
350
345
  self.model_params = {
351
346
  "n_features": self.num_features_,
352
- "num_classes": self.output_classes_,
347
+ "num_classes": self.num_classes_,
353
348
  "latent_dim": self.latent_dim,
354
349
  "dropout_rate": self.dropout_rate,
355
350
  "activation": self.activation,
356
351
  }
357
352
 
358
- # Train/Val split
359
- indices = np.arange(n_samples)
360
- train_idx, val_idx = train_test_split(
361
- indices, test_size=self.validation_split, random_state=self.seed
353
+ # Simulate missingness ONCE on the full matrix
354
+ X_for_model_full, self.sim_mask_, self.orig_mask_ = self.sim_missing_transform(
355
+ self.ground_truth_
362
356
  )
363
- self.train_idx_, self.test_idx_ = train_idx, val_idx
364
- self.X_train_ = X_for_model[train_idx]
365
- self.X_val_ = X_for_model[val_idx]
366
- self.GT_train_full_ = self.ground_truth_[train_idx]
367
- self.GT_test_full_ = self.ground_truth_[val_idx]
368
-
369
- if self.sim_mask_global_ is not None:
370
- self.sim_mask_train_ = self.sim_mask_global_[train_idx]
371
- self.sim_mask_test_ = self.sim_mask_global_[val_idx]
372
- else:
373
- self.sim_mask_train_ = None
374
- self.sim_mask_test_ = None
375
357
 
376
- # Pos weights for diploid multilabel path (must exist before tuning)
377
- if not self.is_haploid:
378
- self.pos_weights_ = self._compute_pos_weights(self.X_train_)
379
- else:
380
- self.pos_weights_ = None
358
+ # Split indices based on clean ground truth
359
+ self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(
360
+ self.ground_truth_
361
+ )
362
+
363
+ # --- Clean targets per split ---
364
+ X_train_clean = self.ground_truth_[self.train_idx_].copy()
365
+ X_val_clean = self.ground_truth_[self.val_idx_].copy()
366
+ X_test_clean = self.ground_truth_[self.test_idx_].copy()
367
+
368
+ # --- Corrupted inputs per split (from the single simulation) ---
369
+ X_train_corrupted = X_for_model_full[self.train_idx_].copy()
370
+ X_val_corrupted = X_for_model_full[self.val_idx_].copy()
371
+ X_test_corrupted = X_for_model_full[self.test_idx_].copy()
372
+
373
+ # --- Masks per split ---
374
+ self.sim_mask_train_ = self.sim_mask_[self.train_idx_].copy()
375
+ self.sim_mask_val_ = self.sim_mask_[self.val_idx_].copy()
376
+ self.sim_mask_test_ = self.sim_mask_[self.test_idx_].copy()
377
+
378
+ self.orig_mask_train_ = self.orig_mask_[self.train_idx_].copy()
379
+ self.orig_mask_val_ = self.orig_mask_[self.val_idx_].copy()
380
+ self.orig_mask_test_ = self.orig_mask_[self.test_idx_].copy()
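
For illustration, a small NumPy sketch of the mask conventions used here (shapes and rates are made up): sim_mask marks cells hidden for evaluation, orig_mask marks genuinely missing cells, the loss only sees cells with a known target, and metrics only see simulated-missing cells whose truth is known.

    import numpy as np

    rng = np.random.default_rng(0)
    orig_mask = rng.random((4, 6)) < 0.15                 # originally missing genotypes
    sim_mask = (rng.random((4, 6)) < 0.20) & ~orig_mask   # cells hidden only for evaluation

    loss_mask = ~orig_mask             # train/val loss: every cell with a known target
    eval_mask = sim_mask & ~orig_mask  # metrics: simulated-missing cells with known truth

    ground_truth = rng.integers(0, 3, size=(4, 6)).astype(np.int8)
    corrupted = ground_truth.copy()
    corrupted[orig_mask | sim_mask] = -1                  # the model never sees these cells
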
381
+
382
+ # Persist per-split matrices
383
+ self.X_train_clean_ = X_train_clean
384
+ self.X_val_clean_ = X_val_clean
385
+ self.X_test_clean_ = X_test_clean
386
+
387
+ self.X_train_corrupted_ = X_train_corrupted
388
+ self.X_val_corrupted_ = X_val_corrupted
389
+ self.X_test_corrupted_ = X_test_corrupted
390
+
391
+ # Haploid harmonization (do NOT resimulate; just recode values)
392
+ if self.is_haploid_:
393
+
394
+ def _haploidize(arr: np.ndarray) -> np.ndarray:
395
+ out = arr.copy()
396
+ miss = out < 0
397
+ out = np.where(out > 0, 1, out).astype(np.int8, copy=False)
398
+ out[miss] = -1
399
+ return out
400
+
401
+ self.X_train_clean_ = _haploidize(self.X_train_clean_)
402
+ self.X_val_clean_ = _haploidize(self.X_val_clean_)
403
+ self.X_test_clean_ = _haploidize(self.X_test_clean_)
404
+
405
+ self.X_train_corrupted_ = _haploidize(self.X_train_corrupted_)
406
+ self.X_val_corrupted_ = _haploidize(self.X_val_corrupted_)
407
+ self.X_test_corrupted_ = _haploidize(self.X_test_corrupted_)
408
+
409
+ # Convention: X_* are corrupted inputs; y_* are clean targets
410
+ self.X_train_ = self.X_train_corrupted_
411
+ self.y_train_ = self.X_train_clean_
412
+
413
+ self.X_val_ = self.X_val_corrupted_
414
+ self.y_val_ = self.X_val_clean_
415
+
416
+ self.X_test_ = self.X_test_corrupted_
417
+ self.y_test_ = self.X_test_clean_
418
+
419
+ # One-hot for loaders/model input
420
+ X_train_ohe = self._one_hot_encode_012(
421
+ self.X_train_, num_classes=self.num_classes_
422
+ )
423
+ X_val_ohe = self._one_hot_encode_012(self.X_val_, num_classes=self.num_classes_)
381
424
 
382
- # Plotters/scorers (shared utilities)
425
+ # Plotters/scorers + valid-class mask repairs (copied from VAE flow)
383
426
  self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
427
+ self.valid_class_mask_ = self._build_valid_class_mask()
428
+
429
+ loci = getattr(self, "valid_class_mask_conflict_loci_", None)
430
+ if loci is not None and loci.size:
431
+ self._repair_ref_alt_from_iupac(loci)
432
+ self.valid_class_mask_ = self._build_valid_class_mask()
433
+
434
+ train_loader = self._get_data_loaders(
435
+ X_train_ohe.detach().cpu().numpy(),
436
+ self.y_train_,
437
+ ~self.orig_mask_train_,
438
+ self.batch_size,
439
+ shuffle=True,
440
+ )
441
+ val_loader = self._get_data_loaders(
442
+ X_val_ohe.detach().cpu().numpy(),
443
+ self.y_val_,
444
+ ~self.orig_mask_val_,
445
+ self.batch_size,
446
+ shuffle=False,
447
+ )
448
+
449
+ self.train_loader_ = train_loader
450
+ self.val_loader_ = val_loader
384
451
 
385
- # Tuning (optional; AE never needs latent refinement)
452
+ # Hyperparameter tuning or fixed run
386
453
  if self.tune:
387
- self.tune_hyperparameters()
454
+ self.tuned_params_ = self.tune_hyperparameters()
455
+ self.model_tuned_ = True
456
+ else:
457
+ self.model_tuned_ = False
458
+ self.class_weights_ = self._class_weights_from_zygosity(
459
+ self.y_train_,
460
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
461
+ inverse=self.inverse,
462
+ normalize=self.normalize,
463
+ max_ratio=self.max_ratio,
464
+ power=self.power,
465
+ )
466
+ self.tuned_params_ = {
467
+ "latent_dim": self.latent_dim,
468
+ "learning_rate": self.learning_rate,
469
+ "dropout_rate": self.dropout_rate,
470
+ "num_hidden_layers": self.num_hidden_layers,
471
+ "activation": self.activation,
472
+ "l1_penalty": self.l1_penalty,
473
+ "layer_scaling_factor": self.layer_scaling_factor,
474
+ "layer_schedule": self.layer_schedule,
475
+ "gamma": self.gamma,
476
+ "gamma_schedule": self.gamma_schedule,
477
+ "inverse": self.inverse,
478
+ "normalize": self.normalize,
479
+ "power": self.power,
480
+ }
481
+ self.tuned_params_["model_params"] = self.model_params
388
482
 
389
- # Best params (tuned or default)
390
- self.best_params_ = getattr(self, "best_params_", self._default_best_params())
483
+ if getattr(self, "class_weights_", None) is not None:
484
+ self.logger.info(
485
+ f"class_weights={self.class_weights_.detach().cpu().numpy().tolist()}"
486
+ )
391
487
 
392
- # Class weights (device-aware)
393
- self.class_weights_ = self._normalize_class_weights(
394
- self._class_weights_from_zygosity(self.X_train_)
395
- )
488
+ # Always start clean
489
+ self.best_params_ = copy.deepcopy(self.tuned_params_)
396
490
 
397
- # DataLoader
398
- train_loader = self._get_data_loaders(self.X_train_)
491
+ # Final model params (compute hidden sizes using n_inputs=L*K, mirroring VAE)
492
+ input_dim = int(self.num_features_ * self.num_classes_)
493
+ model_params_final = {
494
+ "n_features": int(self.num_features_),
495
+ "num_classes": int(self.num_classes_),
496
+ "latent_dim": int(self.best_params_["latent_dim"]),
497
+ "dropout_rate": float(self.best_params_["dropout_rate"]),
498
+ "activation": str(self.best_params_["activation"]),
499
+ }
500
+ model_params_final["hidden_layer_sizes"] = self._compute_hidden_layer_sizes(
501
+ n_inputs=input_dim,
502
+ n_outputs=int(self.num_classes_),
503
+ n_samples=len(self.train_idx_),
504
+ n_hidden=int(self.best_params_["num_hidden_layers"]),
505
+ latent_dim=int(self.best_params_["latent_dim"]),
506
+ alpha=float(self.best_params_["layer_scaling_factor"]),
507
+ schedule=str(self.best_params_["layer_schedule"]),
508
+ min_size=max(16, 2 * int(self.best_params_["latent_dim"])),
509
+ )
510
+ self.best_params_["model_params"] = model_params_final
399
511
 
400
- # Build & train
401
- model = self.build_model(self.Model, self.best_params_)
512
+ # Build and train
513
+ model = self.build_model(self.Model, self.best_params_["model_params"])
402
514
  model.apply(self.initialize_weights)
403
515
 
516
+ if self.verbose or self.debug:
517
+ self.logger.info("Using model hyperparameters:")
518
+ pm = PrettyMetrics(
519
+ self.best_params_, precision=3, title="Model Hyperparameters"
520
+ )
521
+ pm.render()
522
+
523
+ lr_final = float(self.best_params_["learning_rate"])
524
+ l1_final = float(self.best_params_["l1_penalty"])
525
+ gamma_schedule = bool(
526
+ self.best_params_.get("gamma_schedule", self.gamma_schedule)
527
+ )
528
+
404
529
  loss, trained_model, history = self._train_and_validate_model(
405
530
  model=model,
406
- loader=train_loader,
407
- lr=self.learning_rate,
408
- l1_penalty=self.l1_penalty,
409
- return_history=True,
410
- class_weights=self.class_weights_,
411
- X_val=self.X_val_,
531
+ lr=lr_final,
532
+ l1_penalty=l1_final,
412
533
  params=self.best_params_,
413
- prune_metric=self.tune_metric,
414
- prune_warmup_epochs=10,
415
- eval_interval=1,
416
- eval_requires_latents=False,
417
- eval_latent_steps=0,
418
- eval_latent_lr=0.0,
419
- eval_latent_weight_decay=0.0,
534
+ trial=None,
535
+ class_weights=getattr(self, "class_weights_", None),
536
+ gamma_schedule=gamma_schedule,
420
537
  )
421
538
 
422
539
  if trained_model is None:
423
- msg = "Autoencoder training failed; no model was returned."
540
+ msg = f"{self.model_name} training failed."
424
541
  self.logger.error(msg)
425
542
  raise RuntimeError(msg)
426
543
 
@@ -429,217 +546,194 @@ class ImputeAutoencoder(BaseNNImputer):
429
546
  self.models_dir / f"final_model_{self.model_name}.pt",
430
547
  )
431
548
 
432
- hist: Dict[str, List[float] | Dict[str, List[float]] | None] | None = {
433
- "Train": history
434
- }
435
- self.best_loss_, self.model_, self.history_ = (loss, trained_model, hist)
549
+ if history is None:
550
+ hist = {"Train": []}
551
+ elif isinstance(history, dict):
552
+ hist = dict(history)
553
+ else:
554
+ hist = {"Train": list(history["Train"]), "Val": list(history["Val"])}
555
+
556
+ self.best_loss_ = float(loss)
557
+ self.model_ = trained_model
558
+ self.history_ = hist
436
559
  self.is_fit_ = True
437
560
 
438
- # Evaluate on validation set (parity with NLPCA reporting)
439
- eval_mask = (
440
- self.sim_mask_test_
441
- if (self.simulate_missing and self.sim_mask_test_ is not None)
442
- else None
443
- )
561
+ # Evaluate on simulated-missing sites only
444
562
  self._evaluate_model(
445
- self.X_val_, self.model_, self.best_params_, eval_mask_override=eval_mask
563
+ self.model_,
564
+ X=self.X_test_,
565
+ y=self.y_test_,
566
+ eval_mask=self.sim_mask_test_ & ~self.orig_mask_test_,
567
+ objective_mode=False,
446
568
  )
447
- self.plotter_.plot_history(self.history_)
569
+
570
+ if self.show_plots:
571
+ self.plotter_.plot_history(self.history_)
572
+
448
573
  self._save_best_params(self.best_params_)
449
574
 
575
+ if self.model_tuned_:
576
+ title = f"{self.model_name} Optimized Parameters"
577
+
578
+ if self.verbose or self.debug:
579
+ pm = PrettyMetrics(self.best_params_, precision=2, title=title)
580
+ pm.render()
581
+
582
+ # Save best parameters to a JSON file.
583
+ self._save_best_params(self.best_params_, objective_mode=True)
584
+
450
585
  return self
451
586
 
452
587
  def transform(self) -> np.ndarray:
453
- """Impute missing genotypes (0/1/2) and return IUPAC strings.
588
+ """Impute missing genotypes and return IUPAC strings.
454
589
 
455
- This method imputes missing genotypes in the dataset using the trained autoencoder model. It predicts the most likely genotype (0, 1, or 2) for each missing entry and fills in these values. The imputed genotypes are then decoded back to IUPAC string format for easier interpretation.
590
+ This method performs the following steps:
591
+ 1. Validates that the model has been fitted.
592
+ 2. Uses the trained model to predict missing genotypes for the entire dataset.
593
+ 3. Fills in the missing genotypes in the original dataset with the predicted values from the model.
594
+ 4. Decodes the imputed genotype matrix from 0/1/2 encoding to IUPAC strings.
595
+ 5. Checks for any remaining missing values or decoding issues, raising errors if found.
596
+ 6. Optionally generates and displays plots comparing the original and imputed genotype distributions.
597
+ 7. Returns the imputed IUPAC genotype matrix.
456
598
 
457
599
  Returns:
458
- np.ndarray: IUPAC strings of shape (n_samples, n_loci).
600
+ np.ndarray: IUPAC genotype matrix of shape (n_samples, n_loci).
459
601
 
460
602
  Raises:
461
603
  NotFittedError: If called before fit().
604
+ RuntimeError: If any missing values remain or decoding yields "N".
605
+ RuntimeError: If loci contain 'N' after imputation due to missing REF/ALT metadata.
462
606
  """
463
607
  if not getattr(self, "is_fit_", False):
464
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
608
+ msg = "Model is not fitted. Call fit() before transform()."
609
+ self.logger.error(msg)
610
+ raise NotFittedError(msg)
465
611
 
466
- self.logger.info(f"Imputing entire dataset with {self.model_name}...")
612
+ self.logger.info(f"Imputing entire dataset with {self.model_name} model...")
467
613
  X_to_impute = self.ground_truth_.copy()
468
614
 
469
- # Predict with masked inputs (no latent optimization)
470
- pred_labels, _ = self._predict(self.model_, X=X_to_impute, return_proba=True)
615
+ pred_labels, _ = self._predict(self.model_, X=X_to_impute)
471
616
 
472
- # Fill only missing
473
- missing_mask = X_to_impute == -1
617
+ missing_mask = X_to_impute < 0
474
618
  imputed_array = X_to_impute.copy()
475
619
  imputed_array[missing_mask] = pred_labels[missing_mask]
476
620
 
477
- # Decode to IUPAC & optionally plot
478
- imputed_genotypes = self.pgenc.decode_012(imputed_array)
621
+ if np.any(imputed_array < 0):
622
+ msg = f"[{self.model_name}] Some missing genotypes remain after imputation. This is unexpected."
623
+ self.logger.error(msg)
624
+ raise RuntimeError(msg)
625
+
626
+ decode_input = imputed_array
627
+ if self.is_haploid_:
628
+ decode_input = imputed_array.copy()
629
+ decode_input[decode_input == 1] = 2
630
+
631
+ imputed_genotypes = self.decode_012(decode_input)
632
+
633
+ bad_loci = np.where((imputed_genotypes == "N").any(axis=0))[0]
634
+ if bad_loci.size > 0:
635
+ msg = f"[{self.model_name}] {bad_loci.size} loci contain 'N' after imputation (e.g., first 10 indices: {bad_loci[:10].tolist()}). This occurs when REF/ALT metadata is missing and cannot be inferred from the source data (e.g., loci with 100 percent missing genotypes). Try filtering out these loci before imputation."
636
+ self.logger.error(msg)
637
+ self.logger.debug(
638
+ "All loci with 'N': " + ", ".join(map(str, bad_loci.tolist()))
639
+ )
640
+ raise RuntimeError(msg)
641
+
479
642
  if self.show_plots:
480
- original_genotypes = self.pgenc.decode_012(X_to_impute)
643
+ original_input = X_to_impute
644
+ if self.is_haploid_:
645
+ original_input = X_to_impute.copy()
646
+ original_input[original_input == 1] = 2
647
+
648
+ original_genotypes = self.decode_012(original_input)
649
+
481
650
  plt.rcParams.update(self.plotter_.param_dict)
482
651
  self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
483
652
  self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
484
653
 
485
654
  return imputed_genotypes
486
655
 
487
- def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
488
- """Create DataLoader over indices + integer targets (-1 for missing).
489
-
490
- This method creates a PyTorch DataLoader that yields batches of indices and their corresponding genotype targets encoded as integers (0, 1, 2) with -1 indicating missing values. The DataLoader is shuffled to ensure random sampling during training.
491
-
492
- Args:
493
- y (np.ndarray): 0/1/2 matrix with -1 for missing.
494
-
495
- Returns:
496
- torch.utils.data.DataLoader: Shuffled DataLoader.
497
- """
498
- y_tensor = torch.from_numpy(y).long()
499
- indices = torch.arange(len(y), dtype=torch.long)
500
- dataset = torch.utils.data.TensorDataset(indices, y_tensor)
501
- pin_memory = self.device.type == "cuda"
502
- return torch.utils.data.DataLoader(
503
- dataset,
504
- batch_size=self.batch_size,
505
- shuffle=True,
506
- pin_memory=pin_memory,
507
- )
508
-
509
656
  def _train_and_validate_model(
510
657
  self,
511
658
  model: torch.nn.Module,
512
- loader: torch.utils.data.DataLoader,
659
+ *,
513
660
  lr: float,
514
661
  l1_penalty: float,
515
- trial: optuna.Trial | None = None,
516
- return_history: bool = False,
517
- class_weights: torch.Tensor | None = None,
518
- *,
519
- X_val: np.ndarray | None = None,
520
- params: dict | None = None,
521
- prune_metric: str = "f1", # "f1" | "accuracy" | "pr_macro"
522
- prune_warmup_epochs: int = 10,
523
- eval_interval: int = 1,
524
- # Evaluation parameters (AE ignores latent refinement knobs)
525
- eval_requires_latents: bool = False, # AE: always False
526
- eval_latent_steps: int = 0,
527
- eval_latent_lr: float = 0.0,
528
- eval_latent_weight_decay: float = 0.0,
529
- ) -> Tuple[float, torch.nn.Module | None, list | None]:
530
- """Wrap the AE training loop (no latent optimizer), with Optuna pruning.
531
-
532
- This method orchestrates the training of the autoencoder model using the provided DataLoader. It sets up the optimizer and learning rate scheduler, and executes the training loop with support for early stopping and Optuna pruning based on validation performance. The method returns the best validation loss, the best model state, and optionally the training history.
662
+ trial: Optional[optuna.Trial] = None,
663
+ params: Optional[dict[str, Any]] = None,
664
+ class_weights: Optional[torch.Tensor] = None,
665
+ gamma_schedule: bool = False,
666
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
667
+ """Train and validate the model.
668
+
669
+ This method sets up the optimizer and learning rate scheduler, then executes the training loop with early stopping and optional hyperparameter tuning via Optuna. It returns the best validation loss, the best model, and the training history.
533
670
 
534
671
  Args:
535
672
  model (torch.nn.Module): Autoencoder model.
536
- loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
537
673
  lr (float): Learning rate.
538
- l1_penalty (float): L1 regularization coeff.
539
- trial (optuna.Trial | None): Optuna trial for pruning (optional).
540
- return_history (bool): If True, return train loss history.
541
- class_weights (torch.Tensor | None): Class weights tensor (on device).
542
- X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
543
- params (dict | None): Model params for evaluation.
544
- prune_metric (str): Metric for pruning reports.
545
- prune_warmup_epochs (int): Pruning warmup epochs.
546
- eval_interval (int): Eval frequency (epochs).
547
- eval_requires_latents (bool): Ignored for AE (no latent inference).
548
- eval_latent_steps (int): Unused for AE.
549
- eval_latent_lr (float): Unused for AE.
550
- eval_latent_weight_decay (float): Unused for AE.
674
+ l1_penalty (float): L1 regularization coefficient.
675
+ trial (Optional[optuna.Trial]): Optuna trial (optional).
676
+ params (Optional[dict[str, Any]]): Hyperparams dict (optional).
677
+ class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
678
+ gamma_schedule (bool): Whether to schedule gamma.
551
679
 
552
680
  Returns:
553
- Tuple[float, torch.nn.Module | None, list | None]: (best_loss, best_model, history or None).
681
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best validation loss, best model, history.
554
682
  """
555
- if class_weights is None:
556
- msg = "Must provide class_weights."
557
- self.logger.error(msg)
558
- raise TypeError(msg)
683
+ max_epochs = int(self.epochs)
684
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
559
685
 
560
- # Epoch budget mirrors NLPCA config (tuning vs final)
561
- max_epochs = (
562
- self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
686
+ scheduler = _make_warmup_cosine_scheduler(
687
+ optimizer, max_epochs=max_epochs, warmup_epochs=int(0.1 * max_epochs)
563
688
  )
564
689
 
565
- optimizer = torch.optim.Adam(model.parameters(), lr=lr)
566
- scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
567
-
568
690
  best_loss, best_model, hist = self._execute_training_loop(
569
- loader=loader,
570
691
  optimizer=optimizer,
571
692
  scheduler=scheduler,
572
693
  model=model,
573
694
  l1_penalty=l1_penalty,
574
695
  trial=trial,
575
- return_history=return_history,
576
- class_weights=class_weights,
577
- X_val=X_val,
578
696
  params=params,
579
- prune_metric=prune_metric,
580
- prune_warmup_epochs=prune_warmup_epochs,
581
- eval_interval=eval_interval,
582
- eval_requires_latents=False, # AE: no latent inference
583
- eval_latent_steps=0,
584
- eval_latent_lr=0.0,
585
- eval_latent_weight_decay=0.0,
697
+ class_weights=class_weights,
698
+ gamma_schedule=gamma_schedule,
586
699
  )
587
- if return_history:
588
- return best_loss, best_model, hist
589
-
590
- return best_loss, best_model, None
700
+ return best_loss, best_model, hist
591
701
 
592
702
  def _execute_training_loop(
593
703
  self,
594
- loader: torch.utils.data.DataLoader,
704
+ *,
595
705
  optimizer: torch.optim.Optimizer,
596
- scheduler: CosineAnnealingLR,
706
+ scheduler: (
707
+ torch.optim.lr_scheduler.CosineAnnealingLR
708
+ | torch.optim.lr_scheduler.SequentialLR
709
+ ),
597
710
  model: torch.nn.Module,
598
711
  l1_penalty: float,
599
- trial: optuna.Trial | None,
600
- return_history: bool,
601
- class_weights: torch.Tensor,
602
- *,
603
- X_val: np.ndarray | None = None,
604
- params: dict | None = None,
605
- prune_metric: str = "f1",
606
- prune_warmup_epochs: int = 10,
607
- eval_interval: int = 1,
608
- # Evaluation parameters (AE ignores latent refinement knobs)
609
- eval_requires_latents: bool = False, # AE: False
610
- eval_latent_steps: int = 0,
611
- eval_latent_lr: float = 0.0,
612
- eval_latent_weight_decay: float = 0.0,
613
- ) -> Tuple[float, torch.nn.Module, list]:
614
- """Train AE with focal CE (gamma warm/ramp) + early stopping & pruning.
615
-
616
- This method executes the training loop for the autoencoder model, performing one epoch at a time. It computes the focal cross-entropy loss while ignoring masked (missing) values and applies L1 regularization if specified. The method incorporates early stopping based on validation performance and supports Optuna pruning to terminate unpromising trials early. It returns the best validation loss, the best model state, and optionally the training history.
712
+ trial: Optional[optuna.Trial] = None,
713
+ params: Optional[dict[str, Any]] = None,
714
+ class_weights: Optional[torch.Tensor] = None,
715
+ gamma_schedule: bool = False,
716
+ ) -> tuple[float, torch.nn.Module, dict[str, list[float]]]:
717
+ """Train AE (masked focal CE) with EarlyStopping + Optuna pruning.
617
718
 
618
719
  Args:
619
- loader (torch.utils.data.DataLoader): Batches (indices, y_int) where y_int is 0/1/2; -1 for missing.
620
- optimizer (torch.optim.Optimizer): Optimizer.
621
- scheduler (torch.optim.lr_scheduler._LRScheduler): LR scheduler.
720
+ optimizer (torch.optim.Optimizer): Optimizer for training.
721
+ scheduler (torch.optim.lr_scheduler.CosineAnnealingLR | torch.optim.lr_scheduler.SequentialLR): LR scheduler.
622
722
  model (torch.nn.Module): Autoencoder model.
623
- l1_penalty (float): L1 regularization coeff.
624
- trial (optuna.Trial | None): Optuna trial for pruning (optional).
625
- return_history (bool): If True, return train loss history.
626
- class_weights (torch.Tensor): Class weights tensor (on device).
627
- X_val (np.ndarray | None): Validation matrix (0/1/2 with -1 for missing).
628
- params (dict | None): Model params for evaluation.
629
- prune_metric (str): Metric for pruning reports.
630
- prune_warmup_epochs (int): Pruning warmup epochs.
631
- eval_interval (int): Eval frequency (epochs).
632
- eval_requires_latents (bool): Ignored for AE (no latent inference).
633
- eval_latent_steps (int): Unused for AE.
634
- eval_latent_lr (float): Unused for AE.
635
- eval_latent_weight_decay (float): Unused for AE.
723
+ l1_penalty (float): L1 regularization coefficient.
724
+ trial (Optional[optuna.Trial]): Optuna trial (optional).
725
+ params (Optional[dict[str, Any]]): Hyperparams dict (optional).
726
+ class_weights (Optional[torch.Tensor]): Class weights for focal CE (optional).
727
+ gamma_schedule (bool): Whether to schedule gamma.
636
728
 
637
729
  Returns:
638
- Tuple[float, torch.nn.Module, list]: Best validation loss, best model, and training history.
730
+ tuple[float, torch.nn.Module, dict[str, list[float]]]: Best loss, best model, and training history.
731
+
732
+ Notes:
733
+ - Computes loss only where targets are known (~orig_mask_*).
734
+ - Evaluates metrics only on simulated-missing sites (sim_mask_*).
639
735
  """
640
- best_loss = float("inf")
641
- best_model = None
642
- history: list[float] = []
736
+ history: dict[str, list[float]] = defaultdict(list)
643
737
 
644
738
  early_stopping = EarlyStopping(
645
739
  patience=self.early_stop_gen,
@@ -649,157 +743,157 @@ class ImputeAutoencoder(BaseNNImputer):
649
743
  debug=self.debug,
650
744
  )
651
745
 
652
- gamma_val = self.gamma
653
- if isinstance(gamma_val, (list, tuple)):
654
- if len(gamma_val) == 0:
655
- raise ValueError("gamma list is empty.")
656
- gamma_val = gamma_val[0]
657
-
658
- gamma_final = float(gamma_val)
659
- gamma_warm, gamma_ramp = 50, 100
660
-
661
- # Optional LR warmup
662
- warmup_epochs = int(getattr(self, "lr_warmup_epochs", 5))
663
- base_lr = float(optimizer.param_groups[0]["lr"])
664
- min_lr = base_lr * 0.1
665
-
666
- max_epochs = int(getattr(scheduler, "T_max", getattr(self, "epochs", 100)))
667
-
668
- for epoch in range(max_epochs):
669
- # focal γ schedule (for stable training)
670
- if epoch < gamma_warm:
671
- model.gamma = 0.0 # type: ignore
672
- elif epoch < gamma_warm + gamma_ramp:
673
- model.gamma = gamma_final * ((epoch - gamma_warm) / gamma_ramp) # type: ignore
746
+ gamma_target, gamma_warm, gamma_ramp = self._anneal_config(
747
+ params, "gamma", default=self.gamma, max_epochs=self.epochs
748
+ )
749
+ gamma_target = float(gamma_target)
750
+
751
+ cw = class_weights
752
+ if cw is not None and cw.device != self.device:
753
+ cw = cw.to(self.device)
754
+
755
+ for epoch in range(int(self.epochs)):
756
+ if gamma_schedule:
757
+ gamma_current = self._update_anneal_schedule(
758
+ gamma_target,
759
+ warm=gamma_warm,
760
+ ramp=gamma_ramp,
761
+ epoch=epoch,
762
+ init_val=0.0,
763
+ )
764
+ gamma_val = float(gamma_current)
674
765
  else:
675
- model.gamma = gamma_final # type: ignore
766
+ gamma_val = gamma_target
676
767
 
677
- # LR warmup
678
- if epoch < warmup_epochs:
679
- scale = float(epoch + 1) / warmup_epochs
680
- for g in optimizer.param_groups:
681
- g["lr"] = min_lr + (base_lr - min_lr) * scale
768
+ ce_criterion = FocalCELoss(
769
+ alpha=cw, gamma=gamma_val, ignore_index=-1, reduction="mean"
770
+ )
682
771
 
683
772
  train_loss = self._train_step(
684
- loader=loader,
773
+ loader=self.train_loader_,
685
774
  optimizer=optimizer,
686
775
  model=model,
776
+ ce_criterion=ce_criterion,
687
777
  l1_penalty=l1_penalty,
688
- class_weights=class_weights,
689
778
  )
690
779
 
691
- # Abort or prune on non-finite epoch loss
692
780
  if not np.isfinite(train_loss):
693
781
  if trial is not None:
694
- raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
695
- # Soft reset suggestion: reduce LR and continue, or break
696
- self.logger.warning(
697
- "Non-finite epoch loss. Reducing LR by 10 percent and continuing."
698
- )
699
- for g in optimizer.param_groups:
700
- g["lr"] *= 0.9
701
- continue
782
+ msg = f"[{self.model_name}] Trial {trial.number} training loss non-finite."
783
+ self.logger.warning(msg)
784
+ raise optuna.exceptions.TrialPruned(msg)
785
+ msg = f"[{self.model_name}] Training loss is non-finite at epoch {epoch + 1}."
786
+ self.logger.error(msg)
787
+ raise RuntimeError(msg)
788
+
789
+ val_loss = self._val_step(
790
+ loader=self.val_loader_,
791
+ model=model,
792
+ ce_criterion=ce_criterion,
793
+ l1_penalty=l1_penalty,
794
+ )
702
795
 
703
796
  scheduler.step()
704
- if return_history:
705
- history.append(train_loss)
797
+ history["Train"].append(float(train_loss))
798
+ history["Val"].append(float(val_loss))
706
799
 
707
- early_stopping(train_loss, model)
800
+ early_stopping(val_loss, model)
708
801
  if early_stopping.early_stop:
709
- self.logger.info(f"Early stopping at epoch {epoch + 1}.")
802
+ self.logger.debug(
803
+ f"[{self.model_name}] Early stopping at epoch {epoch + 1}."
804
+ )
710
805
  break
711
806
 
712
- # Optuna report/prune on validation metric
713
- if (
714
- trial is not None
715
- and X_val is not None
716
- and ((epoch + 1) % eval_interval == 0)
717
- ):
718
- metric_key = prune_metric or getattr(self, "tune_metric", "f1")
719
- mask_override = None
720
- if (
721
- self.simulate_missing
722
- and getattr(self, "sim_mask_test_", None) is not None
723
- and getattr(self, "X_val_", None) is not None
724
- and X_val.shape == self.X_val_.shape
725
- ):
726
- mask_override = self.sim_mask_test_
727
- metric_val = self._eval_for_pruning(
807
+ if trial is not None:
808
+ metric_vals = self._evaluate_model(
728
809
  model=model,
729
- X_val=X_val,
730
- params=params or getattr(self, "best_params_", {}),
731
- metric=metric_key,
810
+ X=self.X_val_,
811
+ y=self.y_val_,
812
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
732
813
  objective_mode=True,
733
- do_latent_infer=False, # AE: False
734
- latent_steps=0,
735
- latent_lr=0.0,
736
- latent_weight_decay=0.0,
737
- latent_seed=self.seed, # type: ignore
738
- _latent_cache=None, # AE: not used
739
- _latent_cache_key=None,
740
- eval_mask_override=mask_override,
741
814
  )
742
- trial.report(metric_val, step=epoch + 1)
743
- if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
815
+ trial.report(metric_vals[self.tune_metric], step=epoch + 1)
816
+ if trial.should_prune():
744
817
  raise optuna.exceptions.TrialPruned(
745
- f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.5f}"
818
+ f"[{self.model_name}] Trial {trial.number} pruned at epoch {epoch + 1}."
746
819
  )
747
820
 
748
- best_loss = early_stopping.best_score
749
- if early_stopping.best_model is not None:
750
- best_model = copy.deepcopy(early_stopping.best_model)
751
- else:
752
- best_model = copy.deepcopy(model)
753
- return best_loss, best_model, history
821
+ best_loss = float(early_stopping.best_score)
822
+ if early_stopping.best_state_dict is not None:
823
+ model.load_state_dict(early_stopping.best_state_dict)
824
+
825
+ return best_loss, model, dict(history)
754
826
 
755
827
  def _train_step(
756
828
  self,
757
829
  loader: torch.utils.data.DataLoader,
758
830
  optimizer: torch.optim.Optimizer,
759
831
  model: torch.nn.Module,
832
+ ce_criterion: torch.nn.Module,
833
+ *,
760
834
  l1_penalty: float,
761
- class_weights: torch.Tensor,
762
835
  ) -> float:
763
- """One epoch with stable focal CE and NaN/Inf guards."""
836
+ """Single epoch train step (masked focal CE + optional L1).
837
+
838
+ Args:
839
+ loader (torch.utils.data.DataLoader): Training data loader.
840
+ optimizer (torch.optim.Optimizer): Optimizer for training.
841
+ model (torch.nn.Module): Autoencoder model.
842
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
843
+ l1_penalty (float): L1 regularization coefficient.
844
+
845
+ Returns:
846
+ float: Average training loss over the epoch.
847
+
848
+ Notes:
849
+ Expects loader batches as (X_ohe, y_int, mask_bool) where:
850
+ - X_ohe: (B, L, C) float (or float-castable) one-hot inputs
851
+ - y_int: (B, L) int, with -1 for unknown targets
852
+ - mask_bool: (B, L) bool selecting which positions contribute to loss
853
+ """
764
854
  model.train()
765
855
  running = 0.0
766
856
  num_batches = 0
857
+
858
+ nF_model = int(getattr(model, "n_features", self.num_features_))
859
+ nC_model = int(getattr(model, "num_classes", self.num_classes_))
767
860
  l1_params = tuple(p for p in model.parameters() if p.requires_grad)
768
- if class_weights is not None and class_weights.device != self.device:
769
- class_weights = class_weights.to(self.device)
770
-
771
- # Use model.gamma if present, else self.gamma
772
- gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
773
- gamma = float(torch.tensor(gamma).clamp(min=0.0, max=10.0)) # sane bound
774
- ce_criterion = SafeFocalCELoss(
775
- gamma=gamma, weight=class_weights, ignore_index=-1
776
- )
777
861
 
778
- for _, y_batch in loader:
862
+ for X_batch, y_batch, m_batch in loader:
779
863
  optimizer.zero_grad(set_to_none=True)
780
- y_batch = y_batch.to(self.device, non_blocking=True)
781
-
782
- # Inputs: one-hot with zeros for missing; Targets: long ints with -1 for missing
783
- if self.is_haploid:
784
- x_in = self._one_hot_encode_012(y_batch) # (B, L, 2)
785
- logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
786
- logits_flat = logits.view(-1, self.output_classes_)
787
- targets_flat = y_batch.view(-1).long()
788
- if not torch.isfinite(logits_flat).all():
789
- continue
790
- loss = ce_criterion(logits_flat, targets_flat)
791
- else:
792
- x_in = self._encode_multilabel_inputs(y_batch) # (B, L, 2)
793
- logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
794
- if not torch.isfinite(logits).all():
795
- continue
796
- pos_w = getattr(self, "pos_weights_", None)
797
- targets = self._multi_hot_targets(y_batch) # float, same shape
798
- bce = F.binary_cross_entropy_with_logits(
799
- logits, targets, pos_weight=pos_w, reduction="none"
864
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
865
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
866
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
867
+
868
+ if (
869
+ X_batch.dim() != 3
870
+ or X_batch.shape[1] != nF_model
871
+ or X_batch.shape[2] != nC_model
872
+ ):
873
+ msg = (
874
+ f"Train batch X shape mismatch: expected (B,{nF_model},{nC_model}), "
875
+ f"got {tuple(X_batch.shape)}."
800
876
  )
801
- mask = (y_batch != -1).unsqueeze(-1).float()
802
- loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
877
+ self.logger.error(msg)
878
+ raise ValueError(msg)
879
+
880
+ logits_flat = model(X_batch)
881
+ expected = (X_batch.shape[0], nF_model * nC_model)
882
+ if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
883
+ try:
884
+ logits_flat = logits_flat.view(*expected)
885
+ except Exception as e:
886
+ msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
887
+ self.logger.error(msg)
888
+ raise ValueError(msg) from e
889
+
890
+ logits = logits_flat.view(-1, nF_model, nC_model)
891
+ logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
892
+
893
+ targets_masked = y_batch.view(-1)
894
+ targets_masked = targets_masked[m_batch.view(-1)]
895
+
896
+ loss = ce_criterion(logits_masked, targets_masked)
803
897
 
804
898
  if l1_penalty > 0:
805
899
  l1 = torch.zeros((), device=self.device)
@@ -807,247 +901,234 @@ class ImputeAutoencoder(BaseNNImputer):
807
901
  l1 = l1 + p.abs().sum()
808
902
  loss = loss + l1_penalty * l1
809
903
 
810
- # Final guard
811
904
  if not torch.isfinite(loss):
812
905
  continue
813
906
 
814
907
  loss.backward()
815
-
816
- # Clip to prevent exploding grads
817
908
  torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
818
-
819
- # If grads blew up to non-finite, skip update
820
- if any(
821
- (not torch.isfinite(p.grad).all())
822
- for p in model.parameters()
823
- if p.grad is not None
824
- ):
825
- optimizer.zero_grad(set_to_none=True)
826
- continue
827
-
828
909
  optimizer.step()
829
910
 
830
911
  running += float(loss.detach().item())
831
912
  num_batches += 1
832
913
 
833
- if num_batches == 0:
834
- return float("inf") # signal upstream that epoch had no usable batches
835
- return running / num_batches
914
+ return float("inf") if num_batches == 0 else running / num_batches
915
+
916
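To make the batch layout described in the _train_step docstring above concrete, here is a small, self-contained sketch of building a loader in the (X_ohe, y_int, mask_bool) format. It assumes three genotype classes (REF/HET/ALT) and is illustrative only, not the package's own data pipeline.

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def make_loader(geno_012: np.ndarray, batch_size: int = 32) -> DataLoader:
        """geno_012: (N, L) integer matrix of 0/1/2 genotypes with -1 for missing."""
        y = torch.from_numpy(geno_012).long()               # (N, L) targets; -1 marks unknown sites
        mask = y >= 0                                        # (N, L) bool: positions counted in the loss
        x = torch.zeros(*y.shape, 3, dtype=torch.float32)    # (N, L, 3) one-hot input
        x.scatter_(-1, y.clamp(min=0).unsqueeze(-1), 1.0)    # set the observed class channel to 1
        x[~mask] = 0.0                                        # missing sites get all-zero channels
        return DataLoader(TensorDataset(x, y, mask), batch_size=batch_size, shuffle=True)

    loader = make_loader(np.array([[0, 1, -1], [2, -1, 0]]))  # toy example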
+ def _val_step(
917
+ self,
918
+ loader: torch.utils.data.DataLoader,
919
+ model: torch.nn.Module,
920
+ ce_criterion: torch.nn.Module,
921
+ *,
922
+ l1_penalty: float,
923
+ ) -> float:
924
+ """Validation step (masked focal CE + optional L1).
925
+
926
+ Args:
927
+ loader (torch.utils.data.DataLoader): Validation data loader.
928
+ model (torch.nn.Module): Autoencoder model.
929
+ ce_criterion (torch.nn.Module): Cross-entropy loss function.
930
+ l1_penalty (float): L1 regularization coefficient.
931
+
932
+ Returns:
933
+ float: Average validation loss over the epoch.
934
+ """
935
+ model.eval()
936
+ running = 0.0
937
+ num_batches = 0
938
+
939
+ nF_model = self.num_features_
940
+ nC_model = self.num_classes_
941
+ l1_params = tuple(p for p in model.parameters() if p.requires_grad)
942
+
943
+ with torch.no_grad():
944
+ for X_batch, y_batch, m_batch in loader:
945
+ X_batch = X_batch.to(self.device, non_blocking=True).float()
946
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
947
+ m_batch = m_batch.to(self.device, non_blocking=True).bool()
948
+
949
+ logits_flat = model(X_batch)
950
+ expected = (X_batch.shape[0], nF_model * nC_model)
951
+
952
+ if logits_flat.dim() != 2 or tuple(logits_flat.shape) != expected:
953
+ try:
954
+ logits_flat = logits_flat.view(*expected)
955
+ except Exception as e:
956
+ msg = f"Model logits expected shape {expected}, got {tuple(logits_flat.shape)}."
957
+ self.logger.error(msg)
958
+ raise ValueError(msg) from e
959
+
960
+ logits = logits_flat.view(-1, nF_model, nC_model)
961
+ logits_masked = logits.view(-1, nC_model)[m_batch.view(-1)]
962
+ targets_masked = y_batch.view(-1)[m_batch.view(-1)]
963
+
964
+ if targets_masked.numel() == 0:
965
+ continue
966
+
967
+ loss = ce_criterion(logits_masked, targets_masked)
968
+
969
+ if l1_penalty > 0:
970
+ l1 = torch.zeros((), device=self.device)
971
+ for p in l1_params:
972
+ l1 = l1 + p.abs().sum()
973
+ loss = loss + l1_penalty * l1
974
+
975
+ if not torch.isfinite(loss):
976
+ continue
977
+
978
+ running += float(loss.item())
979
+ num_batches += 1
980
+
981
+ return float("inf") if num_batches == 0 else running / num_batches
836
982
 
837
983
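Both _train_step and _val_step above reduce to the same masked selection: flatten the per-site logits and targets, keep only positions flagged by the batch mask, and hand those to the criterion. A compact illustration, with torch.nn.CrossEntropyLoss standing in for the package's FocalCELoss and toy shapes:

    import torch

    B, L, C = 8, 100, 3                               # batch size, loci, genotype classes (toy sizes)
    logits_flat = torch.randn(B, L * C)               # flat model output, as returned by the network
    logits = logits_flat.view(-1, L, C)               # reshape to (B, L, C)
    targets = torch.randint(0, C, (B, L))             # 0/1/2 genotype targets
    mask = torch.rand(B, L) < 0.2                     # e.g. simulated-missing positions to score

    logits_masked = logits.view(-1, C)[mask.view(-1)]     # (n_masked, C)
    targets_masked = targets.view(-1)[mask.view(-1)]      # (n_masked,)

    criterion = torch.nn.CrossEntropyLoss()               # stand-in for FocalCELoss
    if targets_masked.numel() > 0:
        loss = criterion(logits_masked, targets_masked)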
  def _predict(
838
984
  self,
839
985
  model: torch.nn.Module,
840
986
  X: np.ndarray | torch.Tensor,
987
+ *,
841
988
  return_proba: bool = False,
842
- ) -> Tuple[np.ndarray, np.ndarray] | np.ndarray:
843
- """Predict 0/1/2 labels (and probabilities) from masked inputs.
844
-
845
- This method generates predictions from the trained autoencoder model for the provided input data. It processes the input data, performs a forward pass through the model, and computes the predicted genotype labels (0, 1, or 2) along with their associated probabilities if requested.
989
+ ) -> tuple[np.ndarray, np.ndarray | None]:
990
+ """Predict categorical genotype labels from logits.
846
991
 
847
992
  Args:
848
993
  model (torch.nn.Module): Trained model.
849
- X (np.ndarray | torch.Tensor): 0/1/2 matrix with -1
850
- for missing.
851
- return_proba (bool): If True, return probabilities.
994
+ X (np.ndarray | torch.Tensor): 2D 0/1/2 matrix with -1 for missing, or 3D one-hot (B, L, K).
995
+ return_proba (bool): If True, return probabilities (B, L, K).
852
996
 
853
997
  Returns:
854
- Tuple[np.ndarray, np.ndarray] | np.ndarray: Predicted labels,
855
- and probabilities if requested.
998
+ tuple[np.ndarray, np.ndarray | None]: Predicted labels and, if return_proba is True, class probabilities (otherwise None).
856
999
  """
857
1000
  if model is None:
858
1001
  msg = "Model is not trained. Call fit() before predict()."
859
1002
  self.logger.error(msg)
860
1003
  raise NotFittedError(msg)
861
1004
 
1005
+ nF = self.num_features_
1006
+ nC = self.num_classes_
1007
+
1008
+ if isinstance(X, torch.Tensor):
1009
+ X_tensor = X
1010
+ else:
1011
+ X_tensor = torch.from_numpy(X)
1012
+ X_tensor = X_tensor.float()
1013
+
1014
+ if X_tensor.device != self.device:
1015
+ X_tensor = X_tensor.to(self.device)
1016
+
1017
+ if X_tensor.dim() == 2:
1018
+ # 0/1/2 matrix -> one-hot for model input
1019
+ X_tensor = self._one_hot_encode_012(X_tensor, num_classes=nC)
1020
+ X_tensor = X_tensor.float()
1021
+ if X_tensor.device != self.device:
1022
+ X_tensor = X_tensor.to(self.device)
1023
+
1024
+ elif X_tensor.dim() != 3:
1025
+ msg = f"_predict expects 2D 0/1/2 inputs or 3D one-hot inputs; got shape {tuple(X_tensor.shape)}."
1026
+ self.logger.error(msg)
1027
+ raise ValueError(msg)
1028
+
1029
+ if X_tensor.shape[1] != nF or X_tensor.shape[2] != nC:
1030
+ msg = f"_predict input shape mismatch: expected (B, {nF}, {nC}), got {tuple(X_tensor.shape)}."
1031
+ self.logger.error(msg)
1032
+ raise ValueError(msg)
1033
+
1034
+ X_tensor = X_tensor.reshape(X_tensor.shape[0], nF * nC)
1035
+
862
1036
  model.eval()
863
1037
  with torch.no_grad():
864
- X_tensor = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
865
- X_tensor = X_tensor.to(self.device).long()
866
- if self.is_haploid:
867
- x_ohe = self._one_hot_encode_012(X_tensor)
868
- logits = model(x_ohe).view(-1, self.num_features_, self.output_classes_)
869
- probas = torch.softmax(logits, dim=-1)
870
- labels = torch.argmax(probas, dim=-1)
871
- else:
872
- x_in = self._encode_multilabel_inputs(X_tensor)
873
- logits = model(x_in).view(-1, self.num_features_, self.output_classes_)
874
- probas_2 = torch.sigmoid(logits)
875
- p_ref = probas_2[..., 0]
876
- p_alt = probas_2[..., 1]
877
- p_het = p_ref * p_alt
878
- p_ref_only = p_ref * (1 - p_alt)
879
- p_alt_only = p_alt * (1 - p_ref)
880
- stacked = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
881
- stacked = stacked / stacked.sum(dim=-1, keepdim=True).clamp_min(1e-8)
882
- probas = stacked
883
- labels = torch.argmax(stacked, dim=-1)
1038
+ logits_flat = model(X_tensor)
1039
+ logits = logits_flat.view(-1, nF, nC)
1040
+
1041
+ probas = torch.softmax(logits, dim=-1)
1042
+ labels = torch.argmax(probas, dim=-1)
884
1043
 
885
1044
  if return_proba:
886
1045
  return labels.cpu().numpy(), probas.cpu().numpy()
887
-
888
- return labels.cpu().numpy()
889
-
890
- def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
891
- """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
892
- if self.is_haploid:
893
- return self._one_hot_encode_012(y)
894
- y = y.to(self.device)
895
- shape = y.shape + (2,)
896
- out = torch.zeros(shape, device=self.device, dtype=torch.float32)
897
- valid = y != -1
898
- ref_mask = valid & (y != 2)
899
- alt_mask = valid & (y != 0)
900
- out[ref_mask, 0] = 1.0
901
- out[alt_mask, 1] = 1.0
902
- return out
903
-
904
- def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
905
- """Targets aligned with _encode_multilabel_inputs for diploid training."""
906
- if self.is_haploid:
907
- # One-hot CE path expects integer targets; handled upstream.
908
- raise RuntimeError("_multi_hot_targets called for haploid data.")
909
- y = y.to(self.device)
910
- out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
911
- valid = y != -1
912
- ref_mask = valid & (y != 2)
913
- alt_mask = valid & (y != 0)
914
- out[ref_mask, 0] = 1.0
915
- out[alt_mask, 1] = 1.0
916
- return out
917
-
918
- def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
919
- """Balance REF/ALT channels for multilabel BCE."""
920
- ref_pos = np.count_nonzero((X == 0) | (X == 1))
921
- alt_pos = np.count_nonzero((X == 2) | (X == 1))
922
- total_valid = np.count_nonzero(X != -1)
923
- pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
924
- neg_counts = np.maximum(total_valid - pos_counts, 1.0)
925
- pos_counts = np.maximum(pos_counts, 1.0)
926
- weights = neg_counts / pos_counts
927
- return torch.tensor(weights, device=self.device, dtype=torch.float32)
1046
+ return labels.cpu().numpy(), None
928
1047
 
929
1048
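As a usage illustration for the labels and probabilities returned above: a _predict-style helper yields a (N, L) label matrix and (N, L, K) probabilities, and only the originally missing genotypes need to be overwritten. The helper below is hypothetical glue code, not part of pg-sui.

    import numpy as np

    def impute_missing(X_012: np.ndarray, labels: np.ndarray, probas: np.ndarray):
        """X_012: (N, L) with -1 for missing; labels: (N, L); probas: (N, L, K)."""
        missing = X_012 == -1
        filled = X_012.copy()
        filled[missing] = labels[missing]          # keep observed calls untouched
        confidence = probas.max(axis=-1)           # per-site max class probability
        return filled, confidence[missing]         # imputed matrix + confidence at the filled sites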
  def _evaluate_model(
930
1049
  self,
931
- X_val: np.ndarray,
932
1050
  model: torch.nn.Module,
933
- params: dict,
934
- objective_mode: bool = False,
935
- latent_vectors_val: Optional[np.ndarray] = None,
1051
+ X: np.ndarray,
1052
+ y: np.ndarray,
1053
+ eval_mask: np.ndarray,
936
1054
  *,
937
- eval_mask_override: np.ndarray | None = None,
1055
+ objective_mode: bool = False,
938
1056
  ) -> Dict[str, float]:
939
1057
  """Evaluate on 0/1/2; then IUPAC decoding and 10-base integer reports.
940
1058
 
941
- This method evaluates the trained autoencoder model on a validation set, computing various classification metrics based on the predicted and true genotypes. It handles both haploid and diploid data appropriately and generates detailed classification reports for both genotype and IUPAC/10-base integer encodings.
942
-
943
1059
  Args:
944
- X_val (np.ndarray): Validation set 0/1/2 matrix with -1
945
- for missing.
946
1060
  model (torch.nn.Module): Trained model.
947
- params (dict): Model parameters.
948
- objective_mode (bool): If True, suppress logging and reports.
949
- latent_vectors_val (Optional[np.ndarray]): Unused for AE.
950
- eval_mask_override (np.ndarray | None): Optional mask to override default evaluation mask.
1061
+ X (np.ndarray): 2D 0/1/2 matrix with -1 for missing.
1062
+ y (np.ndarray): 2D 0/1/2 ground truth matrix with -1 for missing.
1063
+ eval_mask (np.ndarray): 2D boolean mask selecting sites to evaluate.
1064
+ objective_mode (bool): If True, suppress detailed reports and plots.
951
1065
 
952
1066
  Returns:
953
1067
  Dict[str, float]: Dictionary of evaluation metrics.
954
1068
  """
955
- pred_labels, pred_probas = self._predict(
956
- model=model, X=X_val, return_proba=True
957
- )
1069
+ if model is None:
1070
+ msg = "Model passed to _evaluate_model() is not fitted. Call fit() before evaluation."
1071
+ self.logger.error(msg)
1072
+ raise NotFittedError(msg)
958
1073
 
959
- finite_mask = np.all(np.isfinite(pred_probas), axis=-1) # (N, L)
960
-
961
- # FIX 1: Check ROWS (shape[0]) only. X_val might be a feature subset.
962
- if (
963
- hasattr(self, "X_val_")
964
- and getattr(self, "X_val_", None) is not None
965
- and X_val.shape[0] == self.X_val_.shape[0]
966
- ):
967
- GT_ref = getattr(self, "GT_test_full_", self.ground_truth_)
968
- elif (
969
- hasattr(self, "X_train_")
970
- and getattr(self, "X_train_", None) is not None
971
- and X_val.shape[0] == self.X_train_.shape[0]
972
- ):
973
- GT_ref = getattr(self, "GT_train_full_", self.ground_truth_)
974
- else:
975
- GT_ref = self.ground_truth_
976
-
977
- # FIX 2: Handle Feature Mismatch (e.g., tune_fast feature subsetting)
978
- # If the GT source has more columns than X_val, slice it to match.
979
- if GT_ref.shape[1] > X_val.shape[1]:
980
- GT_ref = GT_ref[:, : X_val.shape[1]]
981
-
982
- # Fallback if rows mismatch (unlikely after Fix 1, but safe to keep)
983
- if GT_ref.shape != X_val.shape:
984
- # If completely different, we can't use the ground truth object.
985
- # Fall back to X_val (this implies only observed values are scored)
986
- GT_ref = X_val
987
-
988
- if eval_mask_override is not None:
989
- # FIX 3: Allow override mask to be sliced if it's too wide
990
- if eval_mask_override.shape[0] != X_val.shape[0]:
991
- msg = (
992
- f"eval_mask_override rows {eval_mask_override.shape[0]} "
993
- f"does not match X_val rows {X_val.shape[0]}"
994
- )
995
- self.logger.error(msg)
996
- raise ValueError(msg)
1074
+ pred_labels, pred_probas = self._predict(model=model, X=X, return_proba=True)
997
1075
 
998
- if eval_mask_override.shape[1] > X_val.shape[1]:
999
- eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
1000
- else:
1001
- eval_mask = eval_mask_override.astype(bool)
1002
- else:
1003
- eval_mask = X_val != -1
1076
+ if pred_probas is None:
1077
+ msg = "Predicted probabilities are None in _evaluate_model()."
1078
+ self.logger.error(msg)
1079
+ raise ValueError(msg)
1004
1080
 
1005
- # Combine masks
1006
- eval_mask = eval_mask & finite_mask & (GT_ref != -1)
1081
+ y_true_flat = y[eval_mask].astype(np.int8, copy=False)
1082
+ y_pred_flat = pred_labels[eval_mask].astype(np.int8, copy=False)
1083
+ y_proba_flat = pred_probas[eval_mask].astype(np.float32, copy=False)
1007
1084
 
1008
- y_true_flat = GT_ref[eval_mask].astype(np.int64, copy=False)
1009
- y_pred_flat = pred_labels[eval_mask].astype(np.int64, copy=False)
1010
- y_proba_flat = pred_probas[eval_mask].astype(np.float64, copy=False)
1085
+ valid = y_true_flat >= 0
1086
+ y_true_flat = y_true_flat[valid]
1087
+ y_pred_flat = y_pred_flat[valid]
1088
+ y_proba_flat = y_proba_flat[valid]
1011
1089
 
1012
1090
  if y_true_flat.size == 0:
1013
- self.tune_metric = "f1" if self.tune_metric is None else self.tune_metric
1014
1091
  return {self.tune_metric: 0.0}
1015
1092
 
1016
- # ensure valid probability simplex after masking (no NaNs/Infs, sums=1)
1093
+ if y_proba_flat.ndim != 2:
1094
+ msg = f"Expected y_proba_flat to be 2D (n_eval, n_classes); got {y_proba_flat.shape}."
1095
+ self.logger.error(msg)
1096
+ raise ValueError(msg)
1097
+
1098
+ K = int(y_proba_flat.shape[1])
1099
+ if self.is_haploid_:
1100
+ if K not in (2, 3):
1101
+ msg = f"Haploid evaluation expects 2 or 3 classes; got {K}."
1102
+ self.logger.error(msg)
1103
+ raise ValueError(msg)
1104
+ else:
1105
+ if K != 3:
1106
+ msg = f"Diploid evaluation expects 3 classes; got {K}."
1107
+ self.logger.error(msg)
1108
+ raise ValueError(msg)
1109
+
1110
+ if self.is_haploid_:
1111
+ y_true_flat = (y_true_flat > 0).astype(np.int8, copy=False)
1112
+ y_pred_flat = (y_pred_flat > 0).astype(np.int8, copy=False)
1113
+
1114
+ if K == 3:
1115
+ proba_2 = np.empty((y_proba_flat.shape[0], 2), dtype=y_proba_flat.dtype)
1116
+ proba_2[:, 0] = y_proba_flat[:, 0]
1117
+ proba_2[:, 1] = y_proba_flat[:, 1] + y_proba_flat[:, 2]
1118
+ y_proba_flat = proba_2
1119
+
1120
+ labels_for_scoring = [0, 1]
1121
+ target_names = ["REF", "ALT"]
1122
+ else:
1123
+ labels_for_scoring = [0, 1, 2]
1124
+ target_names = ["REF", "HET", "ALT"]
1125
+
1017
1126
  y_proba_flat = np.clip(y_proba_flat, 0.0, 1.0)
1018
1127
  row_sums = y_proba_flat.sum(axis=1, keepdims=True)
1019
- row_sums[row_sums == 0] = 1.0
1128
+ row_sums[row_sums == 0.0] = 1.0
1020
1129
  y_proba_flat = y_proba_flat / row_sums
1021
1130
 
1022
- labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
1023
- target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
1024
-
1025
- if self.is_haploid:
1026
- y_true_flat = y_true_flat.copy()
1027
- y_pred_flat = y_pred_flat.copy()
1028
- y_true_flat[y_true_flat == 2] = 1
1029
- y_pred_flat[y_pred_flat == 2] = 1
1030
- # collapse probs to 2-class
1031
- proba_2 = np.zeros((len(y_proba_flat), 2), dtype=y_proba_flat.dtype)
1032
- proba_2[:, 0] = y_proba_flat[:, 0]
1033
- proba_2[:, 1] = y_proba_flat[:, 2]
1034
- y_proba_flat = proba_2
1035
-
1036
- y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
1037
-
1038
- tune_metric_tmp: Literal[
1039
- "pr_macro",
1040
- "roc_auc",
1041
- "average_precision",
1042
- "accuracy",
1043
- "f1",
1044
- "precision",
1045
- "recall",
1046
- ]
1047
- if self.tune_metric_ is not None:
1048
- tune_metric_tmp = self.tune_metric_
1049
- else:
1050
- tune_metric_tmp = "f1" # Default if not tuning
1131
+ y_true_ohe = np.eye(len(labels_for_scoring), dtype=np.int8)[y_true_flat]
1051
1132
 
1052
1133
  metrics = self.scorers_.evaluate(
1053
1134
  y_true_flat,
@@ -1055,16 +1136,29 @@ class ImputeAutoencoder(BaseNNImputer):
1055
1136
  y_true_ohe,
1056
1137
  y_proba_flat,
1057
1138
  objective_mode,
1058
- tune_metric_tmp,
1139
+ cast(
1140
+ Literal[
1141
+ "pr_macro",
1142
+ "roc_auc",
1143
+ "accuracy",
1144
+ "f1",
1145
+ "average_precision",
1146
+ "precision",
1147
+ "recall",
1148
+ "mcc",
1149
+ "jaccard",
1150
+ ],
1151
+ self.tune_metric,
1152
+ ),
1059
1153
  )
1060
1154
 
1061
1155
  if not objective_mode:
1062
- pm = PrettyMetrics(
1063
- metrics, precision=3, title=f"{self.model_name} Validation Metrics"
1064
- )
1065
- pm.render() # prints a command-line table
1156
+ if self.verbose or self.debug:
1157
+ pm = PrettyMetrics(
1158
+ metrics, precision=2, title=f"{self.model_name} Validation Metrics"
1159
+ )
1160
+ pm.render()
1066
1161
 
1067
- # Primary report (REF/HET/ALT or REF/ALT)
1068
1162
  self._make_class_reports(
1069
1163
  y_true=y_true_flat,
1070
1164
  y_pred_proba=y_proba_flat,
@@ -1073,18 +1167,15 @@ class ImputeAutoencoder(BaseNNImputer):
1073
1167
  labels=target_names,
1074
1168
  )
1075
1169
 
1076
- # IUPAC decode & 10-base integer reports
1077
- # Now safe because GT_ref has been sliced to match X_val dimensions
1078
- y_true_dec = self.pgenc.decode_012(
1079
- GT_ref.reshape(X_val.shape[0], X_val.shape[1])
1080
- )
1081
- X_pred = X_val.copy()
1082
- X_pred[eval_mask] = y_pred_flat
1170
+ y_true_matrix = np.array(y, copy=True)
1171
+ y_pred_matrix = np.array(pred_labels, copy=True)
1083
1172
 
1084
- # Use X_val.shape[1] (current features) not self.num_features_ (original features)
1085
- y_pred_dec = self.pgenc.decode_012(
1086
- X_pred.reshape(X_val.shape[0], X_val.shape[1])
1087
- )
1173
+ if self.is_haploid_:
1174
+ y_true_matrix = np.where(y_true_matrix > 0, 2, y_true_matrix)
1175
+ y_pred_matrix = np.where(y_pred_matrix > 0, 2, y_pred_matrix)
1176
+
1177
+ y_true_dec = self.decode_012(y_true_matrix)
1178
+ y_pred_dec = self.decode_012(y_pred_matrix)
1088
1179
 
1089
1180
  encodings_dict = {
1090
1181
  "A": 0,
@@ -1123,239 +1214,177 @@ class ImputeAutoencoder(BaseNNImputer):
1123
1214
  return metrics
1124
1215
 
1125
1216
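Before the tuning objective below, a small numpy sketch of the evaluation-mask and haploid-collapse logic used in _evaluate_model above: only sites that were simulated missing but not missing in the original data are scored, and for haploid data the three-class probabilities are folded into REF vs. ALT. The toy arrays are assumptions for illustration.

    import numpy as np

    y_true = np.array([[0, 2, 1], [2, 0, -1]])
    y_pred = np.array([[0, 2, 2], [1, 0, 0]])
    probas = np.full((2, 3, 3), 1 / 3)
    sim_mask = np.array([[True, True, False], [True, False, True]])
    orig_mask = np.array([[False, False, False], [False, False, True]])

    eval_mask = sim_mask & ~orig_mask                  # simulated-missing sites with known truth
    t, p, pr = y_true[eval_mask], y_pred[eval_mask], probas[eval_mask]

    # Haploid collapse: anything non-REF counts as ALT; fold HET+ALT probability mass together.
    t_bin, p_bin = (t > 0).astype(np.int8), (p > 0).astype(np.int8)
    pr_bin = np.stack([pr[:, 0], pr[:, 1] + pr[:, 2]], axis=1)
    pr_bin /= pr_bin.sum(axis=1, keepdims=True)        # keep rows on the probability simplex
    accuracy = float((t_bin == p_bin).mean())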
  def _objective(self, trial: optuna.Trial) -> float:
1126
- """Optuna objective for AE; mirrors NLPCA study driver without latents.
1127
-
1128
- This method defines the objective function for hyperparameter tuning using Optuna. It samples hyperparameters, prepares the training and validation data, builds and trains the autoencoder model, and evaluates its performance on the validation set. The method returns the value of the tuning metric to be maximized.
1217
+ """Optuna objective for AE (mirrors VAE flow, excluding KL-specific parts).
1129
1218
 
1130
1219
  Args:
1131
- trial (optuna.Trial): Optuna trial.
1220
+ trial (optuna.Trial): Optuna trial object.
1132
1221
 
1133
1222
  Returns:
1134
- float: Value of the tuning metric (maximize).
1223
+ float: Value of the tuning metric to optimize.
1135
1224
  """
1136
1225
  try:
1137
- # Sample hyperparameters (existing helper; unchanged signature)
1138
1226
  params = self._sample_hyperparameters(trial)
1139
1227
 
1140
- # Optionally sub-sample for fast tuning (same keys used by NLPCA if you adopt them)
1141
- X_train = getattr(self, "X_train_", self.ground_truth_[self.train_idx_])
1142
- X_val = getattr(self, "X_val_", self.ground_truth_[self.test_idx_])
1143
-
1144
- class_weights = self._normalize_class_weights(
1145
- self._class_weights_from_zygosity(X_train)
1146
- )
1147
- train_loader = self._get_data_loaders(X_train)
1148
-
1149
1228
  model = self.build_model(self.Model, params["model_params"])
1150
1229
  model.apply(self.initialize_weights)
1151
1230
 
1152
- lr: float = float(params["lr"])
1153
- l1_penalty: float = float(params["l1_penalty"])
1231
+ lr = float(params["learning_rate"])
1232
+ l1_penalty = float(params["l1_penalty"])
1233
+
1234
+ class_weights = self._class_weights_from_zygosity(
1235
+ self.y_train_,
1236
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1237
+ inverse=params["inverse"],
1238
+ normalize=params["normalize"],
1239
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1240
+ power=params["power"],
1241
+ )
1154
1242
 
1155
- # Train + prune on metric
1156
- _, model, __ = self._train_and_validate_model(
1243
+ loss, model, _hist = self._train_and_validate_model(
1157
1244
  model=model,
1158
- loader=train_loader,
1159
1245
  lr=lr,
1160
1246
  l1_penalty=l1_penalty,
1247
+ params=params,
1161
1248
  trial=trial,
1162
- return_history=False,
1163
1249
  class_weights=class_weights,
1164
- X_val=X_val,
1165
- params=params,
1166
- prune_metric=self.tune_metric,
1167
- prune_warmup_epochs=10,
1168
- eval_interval=self.tune_eval_interval,
1169
- eval_requires_latents=False,
1170
- eval_latent_steps=0,
1171
- eval_latent_lr=0.0,
1172
- eval_latent_weight_decay=0.0,
1250
+ gamma_schedule=params["gamma_schedule"],
1173
1251
  )
1174
1252
 
1175
- eval_mask = (
1176
- self.sim_mask_test_
1177
- if (
1178
- self.simulate_missing
1179
- and getattr(self, "sim_mask_test_", None) is not None
1180
- )
1181
- else None
1182
- )
1253
+ if model is None or not np.isfinite(loss):
1254
+ msg = "Model training returned None or non-finite loss in tuning objective."
1255
+ self.logger.error(msg)
1256
+ raise RuntimeError(msg)
1183
1257
 
1184
- if model is not None:
1185
- metrics = self._evaluate_model(
1186
- X_val,
1187
- model,
1188
- params,
1189
- objective_mode=True,
1190
- eval_mask_override=eval_mask,
1191
- )
1192
- self._clear_resources(model, train_loader)
1193
- else:
1194
- raise TypeError("Model training failed; no model was returned.")
1258
+ metrics = self._evaluate_model(
1259
+ model=model,
1260
+ X=self.X_val_,
1261
+ y=self.y_val_,
1262
+ eval_mask=self.sim_mask_val_ & ~self.orig_mask_val_,
1263
+ objective_mode=True,
1264
+ )
1195
1265
 
1196
- return metrics[self.tune_metric]
1266
+ self._clear_resources(model)
1267
+ return float(metrics[self.tune_metric])
1197
1268
 
1198
1269
  except Exception as e:
1199
- # Keep sweeps moving if a trial fails
1200
- raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
1201
-
1202
- def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
1203
- """Sample AE hyperparameters and compute hidden sizes for model params.
1270
+ err_type = type(e).__name__
1271
+ self.logger.warning(
1272
+ f"Trial {trial.number} failed due to exception {err_type}: {e}"
1273
+ )
1274
+ self.logger.debug(traceback.format_exc())
1275
+ raise optuna.exceptions.TrialPruned(
1276
+ f"Trial {trial.number} failed due to an exception. {err_type}: {e}. "
1277
+ "Enable debug logging for full traceback."
1278
+ ) from e
1204
1279
 
1205
- This method samples hyperparameters for the autoencoder model using Optuna's trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
1280
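For context before the sampler below, this is how an objective like _objective above is typically driven. The snippet is standard Optuna usage with a stand-in objective, not a documented pg-sui entry point; the two search ranges simply echo parameters sampled below.

    import optuna

    def objective(trial: optuna.Trial) -> float:
        # Placeholder standing in for ImputeAutoencoder._objective: sample, train, score.
        lr = trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True)
        latent_dim = trial.suggest_int("latent_dim", 2, 32)
        return 1.0 / (1.0 + lr * latent_dim)   # dummy score just to keep the sketch runnable

    pruner = optuna.pruners.MedianPruner(n_warmup_steps=10)
    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=20)
    print(study.best_params)                    # e.g. {"learning_rate": ..., "latent_dim": ...}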
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
1281
+ """Sample AE hyperparameters; hidden sizes mirror VAE helper (excluding KL).
1206
1282
 
1207
1283
  Args:
1208
1284
  trial (optuna.Trial): Optuna trial object.
1209
1285
 
1210
1286
  Returns:
1211
- Dict[str, int | float | str | bool]: Sampled hyperparameters and model_params.
1287
+ dict: Sampled hyperparameters.
1212
1288
  """
1213
1289
  params = {
1214
- "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
1215
- "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
1216
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
1217
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
1290
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
1291
+ "learning_rate": trial.suggest_float("learning_rate", 3e-6, 1e-3, log=True),
1292
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.025),
1293
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 20),
1218
1294
  "activation": trial.suggest_categorical(
1219
1295
  "activation", ["relu", "elu", "selu", "leaky_relu"]
1220
1296
  ),
1221
1297
  "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1222
1298
  "layer_scaling_factor": trial.suggest_float(
1223
- "layer_scaling_factor", 2.0, 4.0, step=0.5
1299
+ "layer_scaling_factor", 2.0, 10.0, step=0.025
1224
1300
  ),
1225
1301
  "layer_schedule": trial.suggest_categorical(
1226
1302
  "layer_schedule", ["pyramid", "linear"]
1227
1303
  ),
1304
+ "power": trial.suggest_float("power", 0.1, 2.0, step=0.1),
1305
+ "normalize": trial.suggest_categorical("normalize", [True, False]),
1306
+ "inverse": trial.suggest_categorical("inverse", [True, False]),
1307
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0, step=0.1),
1308
+ "gamma_schedule": trial.suggest_categorical(
1309
+ "gamma_schedule", [True, False]
1310
+ ),
1228
1311
  }
1229
1312
 
1230
- nF: int = self.num_features_
1231
- nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1313
+ nF = int(self.num_features_)
1314
+ nC = int(self.num_classes_)
1232
1315
  input_dim = nF * nC
1316
+
1233
1317
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1234
1318
  n_inputs=input_dim,
1235
- n_outputs=input_dim,
1319
+ n_outputs=nC,
1236
1320
  n_samples=len(self.train_idx_),
1237
- n_hidden=params["num_hidden_layers"],
1238
- alpha=params["layer_scaling_factor"],
1239
- schedule=params["layer_schedule"],
1321
+ n_hidden=int(params["num_hidden_layers"]),
1322
+ latent_dim=int(params["latent_dim"]),
1323
+ alpha=float(params["layer_scaling_factor"]),
1324
+ schedule=str(params["layer_schedule"]),
1240
1325
  )
1241
1326
 
1242
- # Keep the latent_dim as the first element,
1243
- # then the interior hidden widths.
1244
- # If there are no interior widths (very small nets),
1245
- # this still leaves [latent_dim].
1246
- hidden_only: list[int] = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1247
-
1248
1327
  params["model_params"] = {
1249
- "n_features": int(self.num_features_),
1250
- "num_classes": int(
1251
- getattr(self, "output_classes_", self.num_classes_ or 3)
1252
- ),
1328
+ "n_features": nF,
1329
+ "num_classes": nC,
1253
1330
  "latent_dim": int(params["latent_dim"]),
1254
1331
  "dropout_rate": float(params["dropout_rate"]),
1255
- "hidden_layer_sizes": hidden_only,
1332
+ "hidden_layer_sizes": hidden_layer_sizes,
1256
1333
  "activation": str(params["activation"]),
1257
1334
  }
1258
1335
  return params
1259
1336
 
1260
- def _set_best_params(
1261
- self, best_params: Dict[str, int | float | str | List[int]]
1262
- ) -> Dict[str, int | float | str | List[int]]:
1263
- """Adopt best params (ImputeNLPCA parity) and return model_params.
1264
-
1265
- This method sets the best hyperparameters found during tuning and computes the hidden layer sizes for the autoencoder model. It prepares the final model parameters dictionary to be used for building the model.
1337
+ def _set_best_params(self, params: dict) -> dict:
1338
+ """Update instance fields from tuned params and return model_params dict.
1266
1339
 
1267
1340
  Args:
1268
- best_params (Dict[str, int | float | str | List[int]]): Best hyperparameters from tuning.
1341
+ params (dict): Best hyperparameters from tuning.
1269
1342
 
1270
1343
  Returns:
1271
- Dict[str, int | float | str | List[int]]: Model parameters for building the model.
1344
+ dict: Model parameters for building the final model.
1272
1345
  """
1273
- bp = {}
1274
- for k, v in best_params.items():
1275
- if not isinstance(v, list):
1276
- if k in {"latent_dim", "num_hidden_layers"}:
1277
- bp[k] = int(v)
1278
- elif k in {
1279
- "dropout_rate",
1280
- "learning_rate",
1281
- "l1_penalty",
1282
- "layer_scaling_factor",
1283
- }:
1284
- bp[k] = float(v)
1285
- elif k in {"activation", "layer_schedule"}:
1286
- if k == "layer_schedule":
1287
- if v not in {"pyramid", "constant", "linear"}:
1288
- raise ValueError(f"Invalid layer_schedule: {v}")
1289
- bp[k] = v
1290
- else:
1291
- bp[k] = str(v)
1292
- else:
1293
- bp[k] = v # keep lists as-is
1294
-
1295
- self.latent_dim: int = bp["latent_dim"]
1296
- self.dropout_rate: float = bp["dropout_rate"]
1297
- self.learning_rate: float = bp["learning_rate"]
1298
- self.l1_penalty: float = bp["l1_penalty"]
1299
- self.activation: str = bp["activation"]
1300
- self.layer_scaling_factor: float = bp["layer_scaling_factor"]
1301
- self.layer_schedule: str = bp["layer_schedule"]
1302
-
1303
- nF: int = self.num_features_
1304
- nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1305
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
1306
- n_inputs=nF * nC,
1307
- n_outputs=nF * nC,
1308
- n_samples=len(self.train_idx_),
1309
- n_hidden=bp["num_hidden_layers"],
1310
- alpha=bp["layer_scaling_factor"],
1311
- schedule=bp["layer_schedule"],
1346
+ self.latent_dim = int(params["latent_dim"])
1347
+ self.dropout_rate = float(params["dropout_rate"])
1348
+ self.learning_rate = float(params["learning_rate"])
1349
+ self.l1_penalty = float(params["l1_penalty"])
1350
+ self.activation = str(params["activation"])
1351
+ self.layer_scaling_factor = float(params["layer_scaling_factor"])
1352
+ self.layer_schedule = str(params["layer_schedule"])
1353
+
1354
+ self.power = float(params["power"])
1355
+ self.normalize = bool(params["normalize"])
1356
+ self.inverse = bool(params["inverse"])
1357
+ self.gamma = float(params["gamma"])
1358
+ self.gamma_schedule = bool(params["gamma_schedule"])
1359
+
1360
+ self.class_weights_ = self._class_weights_from_zygosity(
1361
+ self.y_train_,
1362
+ train_mask=self.sim_mask_train_ & ~self.orig_mask_train_,
1363
+ inverse=self.inverse,
1364
+ normalize=self.normalize,
1365
+ max_ratio=self.max_ratio if self.max_ratio is not None else 5.0,
1366
+ power=self.power,
1312
1367
  )
1313
1368
 
1314
- # Keep the latent_dim as the first element,
1315
- # then the interior hidden widths.
1316
- # If there are no interior widths (very small nets),
1317
- # this still leaves [latent_dim].
1318
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
1319
-
1320
- return {
1321
- "n_features": self.num_features_,
1322
- "latent_dim": self.latent_dim,
1323
- "hidden_layer_sizes": hidden_only,
1324
- "dropout_rate": self.dropout_rate,
1325
- "activation": self.activation,
1326
- "num_classes": nC,
1327
- }
1328
-
1329
- def _default_best_params(self) -> Dict[str, int | float | str | list]:
1330
- """Default model params when tuning is disabled.
1331
-
1332
- This method computes the default model parameters for the autoencoder when hyperparameter tuning is not performed. It calculates the hidden layer sizes based on the initial configuration.
1333
-
1334
- Returns:
1335
- Dict[str, int | float | str | list]: Default model parameters.
1336
- """
1337
- nF: int = self.num_features_
1338
- # Use the number of output channels passed to the model (2 for diploid multilabel)
1339
- # instead of the scoring classes (3) to keep layer shapes aligned.
1340
- nC: int = int(getattr(self, "output_classes_", self.num_classes_ or 3))
1341
- ls = self.layer_schedule
1342
-
1343
- if ls not in {"pyramid", "constant", "linear"}:
1344
- raise ValueError(f"Invalid layer_schedule: {ls}")
1369
+ nF = int(self.num_features_)
1370
+ nC = int(self.num_classes_)
1371
+ input_dim = nF * nC
1345
1372
 
1346
1373
  hidden_layer_sizes = self._compute_hidden_layer_sizes(
1347
- n_inputs=nF * nC,
1348
- n_outputs=nF * nC,
1349
- n_samples=len(self.ground_truth_),
1350
- n_hidden=self.num_hidden_layers,
1351
- alpha=self.layer_scaling_factor,
1352
- schedule=ls,
1374
+ n_inputs=input_dim,
1375
+ n_outputs=nC,
1376
+ n_samples=len(self.train_idx_),
1377
+ n_hidden=int(params["num_hidden_layers"]),
1378
+ latent_dim=int(params["latent_dim"]),
1379
+ alpha=float(params["layer_scaling_factor"]),
1380
+ schedule=str(params["layer_schedule"]),
1353
1381
  )
1382
+
1354
1383
  return {
1355
- "n_features": self.num_features_,
1384
+ "n_features": nF,
1385
+ "num_classes": nC,
1356
1386
  "latent_dim": self.latent_dim,
1357
1387
  "hidden_layer_sizes": hidden_layer_sizes,
1358
1388
  "dropout_rate": self.dropout_rate,
1359
1389
  "activation": self.activation,
1360
- "num_classes": nC,
1361
1390
  }
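The _class_weights_from_zygosity calls in _objective and _set_best_params above take inverse, normalize, power, and max_ratio arguments. A rough sketch of what inverse-frequency class weighting with a power exponent and a max-ratio cap could look like follows; the actual helper is not shown in this diff and may differ.

    import numpy as np
    import torch

    def zygosity_class_weights(y_012: np.ndarray, train_mask: np.ndarray, power: float = 1.0,
                               max_ratio: float = 5.0, normalize: bool = True) -> torch.Tensor:
        vals = y_012[train_mask]
        vals = vals[vals >= 0]
        counts = np.bincount(vals, minlength=3).astype(np.float64)   # REF / HET / ALT counts
        counts = np.maximum(counts, 1.0)                             # avoid division by zero
        w = (counts.sum() / counts) ** power                         # tempered inverse frequency
        w = np.minimum(w, w.min() * max_ratio)                       # cap the spread between classes
        if normalize:
            w = w / w.mean()
        return torch.tensor(w, dtype=torch.float32)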