pg-sui: pg_sui-1.6.16a3-py3-none-any.whl → pg_sui-1.7.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/METADATA +26 -30
  2. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/RECORD +29 -33
  3. pgsui/__init__.py +0 -8
  4. pgsui/_version.py +2 -2
  5. pgsui/cli.py +577 -125
  6. pgsui/data_processing/config.py +1 -2
  7. pgsui/data_processing/containers.py +203 -530
  8. pgsui/data_processing/transformers.py +44 -20
  9. pgsui/impute/deterministic/imputers/mode.py +475 -182
  10. pgsui/impute/deterministic/imputers/ref_allele.py +454 -147
  11. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +4 -3
  12. pgsui/impute/supervised/imputers/random_forest.py +3 -2
  13. pgsui/impute/unsupervised/base.py +1269 -534
  14. pgsui/impute/unsupervised/callbacks.py +28 -33
  15. pgsui/impute/unsupervised/imputers/autoencoder.py +870 -841
  16. pgsui/impute/unsupervised/imputers/vae.py +931 -787
  17. pgsui/impute/unsupervised/loss_functions.py +156 -202
  18. pgsui/impute/unsupervised/models/autoencoder_model.py +7 -49
  19. pgsui/impute/unsupervised/models/vae_model.py +40 -221
  20. pgsui/impute/unsupervised/nn_scorers.py +53 -13
  21. pgsui/utils/classification_viz.py +240 -97
  22. pgsui/utils/misc.py +201 -3
  23. pgsui/utils/plotting.py +73 -58
  24. pgsui/utils/pretty_metrics.py +2 -6
  25. pgsui/utils/scorers.py +39 -0
  26. pgsui/impute/unsupervised/imputers/nlpca.py +0 -1666
  27. pgsui/impute/unsupervised/imputers/ubp.py +0 -1660
  28. pgsui/impute/unsupervised/models/nlpca_model.py +0 -206
  29. pgsui/impute/unsupervised/models/ubp_model.py +0 -200
  30. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/WHEEL +0 -0
  31. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/entry_points.txt +0 -0
  32. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/licenses/LICENSE +0 -0
  33. {pg_sui-1.6.16a3.dist-info → pg_sui-1.7.0.dist-info}/top_level.txt +0 -0
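Note: the 1.7.0 wheel removes the NLPCA and UBP imputer modules outright (entries 26-29 above). A minimal, hypothetical guard for downstream code that still imports the removed classes; this assumes they are simply absent in 1.7.0 and names no replacement, since the diff shows none:

    # Hypothetical compatibility shim for code written against pg-sui <= 1.6.x.
    # The module below is deleted in 1.7.0 (entry 27), so the import fails there.
    try:
        from pgsui.impute.unsupervised.imputers.ubp import ImputeUBP
    except ImportError:  # pg-sui >= 1.7.0
        ImputeUBP = None  # caller must select a different imputer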
pgsui/impute/unsupervised/imputers/ubp.py (file deleted)
@@ -1,1660 +0,0 @@
1
- import copy
2
- from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
3
-
4
- from fastapi import params
5
- import numpy as np
6
- import optuna
7
- import torch
8
- import torch.nn.functional as F
9
- from sklearn.decomposition import PCA
10
- from sklearn.exceptions import NotFittedError
11
- from sklearn.model_selection import train_test_split
12
- from snpio.analysis.genotype_encoder import GenotypeEncoder
13
- from snpio.utils.logging import LoggerManager
14
-
15
- from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
16
- from pgsui.data_processing.containers import UBPConfig
17
- from pgsui.data_processing.transformers import SimMissingTransformer
18
- from pgsui.impute.unsupervised.base import BaseNNImputer
19
- from pgsui.impute.unsupervised.callbacks import EarlyStopping
20
- from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
21
- from pgsui.impute.unsupervised.models.ubp_model import UBPModel
22
- from pgsui.utils.logging_utils import configure_logger
23
- from pgsui.utils.pretty_metrics import PrettyMetrics
24
-
25
- if TYPE_CHECKING:
26
- from snpio import TreeParser
27
- from snpio.read_input.genotype_data import GenotypeData
28
-
29
-
30
- def ensure_ubp_config(config: UBPConfig | dict | str | None) -> UBPConfig:
31
- """Return a concrete UBPConfig from dataclass, dict, YAML path, or None.
32
-
33
-     This function normalizes the input configuration for the UBP imputer. It accepts a UBPConfig instance, a dictionary, a YAML file path, or None. If None is provided, it returns a default UBPConfig instance. If a YAML path is given, it loads the configuration from the file, supporting top-level presets. If a dictionary is provided, it flattens any nested structures and applies dot-key overrides to a base configuration, which can also be influenced by a preset if specified. The function ensures that the final output is a fully populated UBPConfig instance.
34
-
35
- Args:
36
- config: UBPConfig | dict | YAML path | None.
37
-
38
- Returns:
39
- UBPConfig: Normalized configuration instance.
40
- """
41
- if config is None:
42
- return UBPConfig()
43
- if isinstance(config, UBPConfig):
44
- return config
45
- if isinstance(config, str):
46
- # YAML path — support top-level `preset`
47
- return load_yaml_to_dataclass(config, UBPConfig)
48
- if isinstance(config, dict):
49
- base = UBPConfig()
50
-
51
- def _flatten(prefix: str, d: dict, out: dict) -> dict:
52
- for k, v in d.items():
53
- kk = f"{prefix}.{k}" if prefix else k
54
- if isinstance(v, dict):
55
- _flatten(kk, v, out)
56
- else:
57
- out[kk] = v
58
- return out
59
-
60
- preset_name = config.pop("preset", None)
61
- if "io" in config and isinstance(config["io"], dict):
62
- preset_name = preset_name or config["io"].pop("preset", None)
63
- if preset_name:
64
- base = UBPConfig.from_preset(preset_name)
65
-
66
- flat = _flatten("", config, {})
67
- return apply_dot_overrides(base, flat)
68
-
69
- raise TypeError("config must be a UBPConfig, dict, YAML path, or None.")
70
-
71
-
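The ensure_ubp_config helper above accepts a UBPConfig instance, a nested dict, a YAML path, or None; nested dicts are flattened to dot-keyed overrides, and a top-level `preset` is resolved through UBPConfig.from_preset. A hedged usage sketch (the preset name "fast" is an assumption; model.latent_dim and train.learning_rate match fields read in __init__ below):

    # Sketch only: "fast" is an assumed preset name, not taken from the package.
    cfg = ensure_ubp_config(
        {
            "preset": "fast",                  # resolved via UBPConfig.from_preset(...)
            "model": {"latent_dim": 8},        # flattened to "model.latent_dim"
            "train": {"learning_rate": 1e-3},  # flattened to "train.learning_rate"
        }
    )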
72
- class ImputeUBP(BaseNNImputer):
73
- """UBP imputer for 0/1/2 genotypes with a three-phase decoder schedule.
74
-
75
- This imputer follows the training recipe from Unsupervised Backpropagation:
76
-
77
- 1. Phase 1 (joint warm start): Learn latent codes and the shallow linear decoder together.
78
- 2. Phase 2 (deep decoder reset): Reinitialize the deeper decoder, freeze the latent codes, and train only the decoder parameters.
79
- 3. Phase 3 (joint fine-tune): Unfreeze everything and jointly refine latent codes plus the deep decoder before evaluation/reporting.
80
-
81
- References:
82
- - Gashler, Michael S., Smith, Michael R., Morris, R., and Martinez, T. (2016) Missing Value Imputation with Unsupervised Backpropagation. Computational Intelligence, 32: 196-215. doi: 10.1111/coin.12048.
83
- """
84
-
85
- def __init__(
86
- self,
87
- genotype_data: "GenotypeData",
88
- *,
89
- tree_parser: Optional["TreeParser"] = None,
90
- config: UBPConfig | dict | str | None = None,
91
- overrides: dict | None = None,
92
- simulate_missing: bool | None = None,
93
- sim_strategy: (
94
- Literal[
95
- "random",
96
- "random_weighted",
97
- "random_weighted_inv",
98
- "nonrandom",
99
- "nonrandom_weighted",
100
- ]
101
- | None
102
- ) = None,
103
- sim_prop: float | None = None,
104
- sim_kwargs: dict | None = None,
105
- ):
106
- """Initialize the UBP imputer via dataclass/dict/YAML config with overrides.
107
-
108
- This constructor allows for flexible initialization of the UBP imputer by accepting various forms of configuration input. It ensures that the configuration is properly normalized and any specified overrides are applied. The method also sets up logging and initializes various attributes related to the model, training, tuning, and evaluation based on the provided configuration.
109
-
110
- Args:
111
- genotype_data (GenotypeData): Backing genotype data object.
112
-             tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser, required for the nonrandom sim_strategy modes.
113
- config (UBPConfig | dict | str | None): UBP configuration.
114
- overrides (dict | None): Flat dot-key overrides applied after `config`.
115
- simulate_missing (bool | None): Whether to simulate missing data during training.
116
- sim_strategy (Literal[...] | None): Simulated missing strategy if simulating.
117
- sim_prop (float | None): Proportion of data to simulate as missing if simulating.
118
- sim_kwargs (dict | None): Additional kwargs for SimMissingTransformer.
119
- """
120
- self.model_name = "ImputeUBP"
121
- self.genotype_data = genotype_data
122
- self.tree_parser = tree_parser
123
-
124
- # ---- normalize config, then apply overrides ----
125
- cfg = ensure_ubp_config(config)
126
- if overrides:
127
- cfg = apply_dot_overrides(cfg, overrides)
128
- self.cfg = cfg
129
-
130
- # ---- logging ----
131
- logman = LoggerManager(
132
- __name__,
133
- prefix=self.cfg.io.prefix,
134
- debug=self.cfg.io.debug,
135
- verbose=self.cfg.io.verbose,
136
- )
137
- self.logger = configure_logger(
138
- logman.get_logger(),
139
- verbose=self.cfg.io.verbose,
140
- debug=self.cfg.io.debug,
141
- )
142
-
143
- # ---- Base init ----
144
- super().__init__(
145
- model_name=self.model_name,
146
- genotype_data=self.genotype_data,
147
- prefix=self.cfg.io.prefix,
148
- device=self.cfg.train.device,
149
- verbose=self.cfg.io.verbose,
150
- debug=self.cfg.io.debug,
151
- )
152
-
153
- # ---- model/meta ----
154
- self.Model = UBPModel
155
- self.pgenc = GenotypeEncoder(genotype_data)
156
-
157
- self.seed = self.cfg.io.seed
158
- self.n_jobs = self.cfg.io.n_jobs
159
- self.prefix = self.cfg.io.prefix
160
- self.scoring_averaging = self.cfg.io.scoring_averaging
161
- self.verbose = self.cfg.io.verbose
162
- self.debug = self.cfg.io.debug
163
- self.rng = np.random.default_rng(self.seed)
164
- self.pos_weights_: torch.Tensor | None = None
165
-
166
- # Simulated-missing controls (config defaults w/ overrides)
167
- sim_cfg = getattr(self.cfg, "sim", None)
168
- sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
169
- if sim_kwargs:
170
- sim_cfg_kwargs.update(sim_kwargs)
171
- if sim_cfg is None:
172
- default_sim_flag = bool(simulate_missing)
173
- default_strategy = "random"
174
- default_prop = 0.10
175
- else:
176
- default_sim_flag = sim_cfg.simulate_missing
177
- default_strategy = sim_cfg.sim_strategy
178
- default_prop = sim_cfg.sim_prop
179
- self.simulate_missing = (
180
- default_sim_flag if simulate_missing is None else bool(simulate_missing)
181
- )
182
- self.sim_strategy = sim_strategy or default_strategy
183
- self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
184
- self.sim_kwargs = sim_cfg_kwargs
185
-
186
- # ---- model hyperparams ----
187
- self.latent_dim = self.cfg.model.latent_dim
188
- self.dropout_rate = self.cfg.model.dropout_rate
189
- self.num_hidden_layers = self.cfg.model.num_hidden_layers
190
- self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
191
- self.layer_schedule = self.cfg.model.layer_schedule
192
- self.latent_init: Literal["pca", "random"] = self.cfg.model.latent_init
193
- self.activation = self.cfg.model.hidden_activation
194
- self.gamma = self.cfg.model.gamma
195
-
196
- # ---- training ----
197
- self.batch_size = self.cfg.train.batch_size
198
- self.learning_rate = self.cfg.train.learning_rate
199
- self.lr_input_factor = self.cfg.train.lr_input_factor
200
- self.l1_penalty = self.cfg.train.l1_penalty
201
- self.early_stop_gen = self.cfg.train.early_stop_gen
202
- self.min_epochs = self.cfg.train.min_epochs
203
- self.epochs = self.cfg.train.max_epochs
204
- self.validation_split = self.cfg.train.validation_split
205
- self.beta = self.cfg.train.weights_beta
206
- self.max_ratio = self.cfg.train.weights_max_ratio
207
-
208
- # ---- tuning ----
209
- self.tune = self.cfg.tune.enabled
210
- self.tune_fast = self.cfg.tune.fast
211
- self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
212
- self.tune_batch_size = self.cfg.tune.batch_size
213
- self.tune_epochs = self.cfg.tune.epochs
214
- self.tune_eval_interval = self.cfg.tune.eval_interval
215
- self.tune_metric: Literal[
216
- "pr_macro",
217
- "f1",
218
- "accuracy",
219
- "average_precision",
220
- "precision",
221
- "recall",
222
- "roc_auc",
223
- ] = self.cfg.tune.metric
224
- self.n_trials = self.cfg.tune.n_trials
225
- self.tune_save_db = self.cfg.tune.save_db
226
- self.tune_resume = self.cfg.tune.resume
227
- self.tune_max_samples = self.cfg.tune.max_samples
228
- self.tune_max_loci = self.cfg.tune.max_loci
229
- self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
230
- self.tune_patience = self.cfg.tune.patience
231
-
232
- # ---- evaluation ----
233
- self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
234
- self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
235
- self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
236
-
237
- # ---- plotting ----
238
- self.plot_format = self.cfg.plot.fmt
239
- self.plot_dpi = self.cfg.plot.dpi
240
- self.plot_fontsize = self.cfg.plot.fontsize
241
- self.title_fontsize = self.cfg.plot.fontsize
242
- self.despine = self.cfg.plot.despine
243
- self.show_plots = self.cfg.plot.show
244
-
245
- # ---- core runtime ----
246
- self.is_haploid = False
247
- self.num_classes_ = False
248
- self.model_params: Dict[str, Any] = {}
249
- self.sim_mask_global_: np.ndarray | None = None
250
- self.sim_mask_train_: np.ndarray | None = None
251
- self.sim_mask_test_: np.ndarray | None = None
252
-
253
- if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
254
- msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
255
- self.logger.error(msg)
256
- raise ValueError(msg)
257
-
258
- def fit(self) -> "ImputeUBP":
259
- """Fit the UBP decoder on 0/1/2 encodings (missing = -1) via three phases.
260
-
261
- 1. Phase 1 initializes latent vectors alongside the linear decoder.
262
- 2. Phase 2 resets and trains the deeper decoder while latents remain fixed.
263
- 3. Phase 3 jointly fine-tunes latents plus the deep decoder before evaluation.
264
-
265
- Returns:
266
- ImputeUBP: Fitted instance.
267
-
268
- Raises:
269
- NotFittedError: If training fails.
270
- """
271
- self.logger.info(f"Fitting {self.model_name} model...")
272
-
273
- # --- Use 0/1/2 with -1 for missing ---
274
- X012 = self._get_float_genotypes(copy=True)
275
- GT_full = np.nan_to_num(X012, nan=-1.0, copy=True)
276
- self.ground_truth_ = GT_full.astype(np.int64, copy=False)
277
-
278
- cache_key = self._sim_mask_cache_key()
279
- self.sim_mask_global_ = None
280
- if self.simulate_missing:
281
- cached_mask = (
282
- None if cache_key is None else self._sim_mask_cache.get(cache_key)
283
- )
284
- if cached_mask is not None:
285
- self.sim_mask_global_ = cached_mask.copy()
286
- else:
287
- tr = SimMissingTransformer(
288
- genotype_data=self.genotype_data,
289
- tree_parser=self.tree_parser,
290
- prop_missing=self.sim_prop,
291
- strategy=self.sim_strategy,
292
- missing_val=-9,
293
- mask_missing=True,
294
- verbose=self.verbose,
295
- **self.sim_kwargs,
296
- )
297
- tr.fit(X012.copy())
298
- self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
299
- if cache_key is not None:
300
- self._sim_mask_cache[cache_key] = self.sim_mask_global_.copy()
301
-
302
- X_for_model = self.ground_truth_.copy()
303
- if self.sim_mask_global_ is not None:
304
- X_for_model[self.sim_mask_global_] = -1
305
-
306
- # --- Determine ploidy (haploid vs diploid) and classes ---
307
- self.is_haploid = bool(
308
- np.all(
309
- np.isin(
310
- self.genotype_data.snp_data,
311
- ["A", "C", "G", "T", "N", "-", ".", "?"],
312
- )
313
- )
314
- )
315
- self.ploidy = 1 if self.is_haploid else 2
316
-
317
- if self.is_haploid:
318
- self.num_classes_ = 2
319
- self.ground_truth_[self.ground_truth_ == 2] = 1
320
- X_for_model[X_for_model == 2] = 1
321
- self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
322
- else:
323
- self.num_classes_ = 3
324
- self.logger.info(
325
- "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring."
326
- )
327
- # Model head always uses two channels; scoring uses num_classes_
328
- self.output_classes_ = 2
329
-
330
- n_samples, self.num_features_ = X_for_model.shape
331
-
332
- # --- model params (decoder: Z -> L * num_classes) ---
333
- self.model_params = {
334
- "n_features": self.num_features_,
335
- "num_classes": self.output_classes_,
336
- "latent_dim": self.latent_dim,
337
- "dropout_rate": self.dropout_rate,
338
- "activation": self.activation,
339
- # hidden_layer_sizes injected later
340
- }
341
-
342
- # --- split ---
343
- indices = np.arange(n_samples)
344
- train_idx, test_idx = train_test_split(
345
- indices, test_size=self.validation_split, random_state=self.seed
346
- )
347
- self.train_idx_, self.test_idx_ = train_idx, test_idx
348
- self.X_train_ = X_for_model[train_idx]
349
- self.X_test_ = X_for_model[test_idx]
350
- self.GT_train_full_ = self.ground_truth_[train_idx]
351
- self.GT_test_full_ = self.ground_truth_[test_idx]
352
-
353
- if self.sim_mask_global_ is not None:
354
- self.sim_mask_train_ = self.sim_mask_global_[train_idx]
355
- self.sim_mask_test_ = self.sim_mask_global_[test_idx]
356
- else:
357
- self.sim_mask_train_ = None
358
- self.sim_mask_test_ = None
359
-
360
- # pos weights for diploid multilabel path
361
- if not self.is_haploid:
362
- self.pos_weights_ = self._compute_pos_weights(self.X_train_)
363
- else:
364
- self.pos_weights_ = None
365
-
366
- # --- plotting/scorers & tuning ---
367
- self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
368
- if self.tune:
369
- self.tune_hyperparameters()
370
-
371
- # Fall back to default model params when none have been selected yet.
372
- if not getattr(self, "best_params_", None):
373
- self.best_params_ = self._set_best_params_default()
374
-
375
- # --- class weights for 0/1/2 ---
376
- self.class_weights_ = self._normalize_class_weights(
377
- self._class_weights_from_zygosity(self.X_train_)
378
- )
379
-
380
- # --- latent init & loader ---
381
- train_latent_vectors = self._create_latent_space(
382
- self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
383
- )
384
- train_loader = self._get_data_loaders(self.X_train_)
385
-
386
- # --- final training (three-phase under the hood) ---
387
- (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
388
- self._train_final_model(
389
- loader=train_loader,
390
- best_params=self.best_params_,
391
- initial_latent_vectors=train_latent_vectors,
392
- )
393
- )
394
-
395
- self.is_fit_ = True
396
- self.plotter_.plot_history(self.history_)
397
- eval_mask = (
398
- self.sim_mask_test_
399
- if (self.simulate_missing and self.sim_mask_test_ is not None)
400
- else None
401
- )
402
- self._evaluate_model(
403
- self.X_test_,
404
- self.model_,
405
- self.best_params_,
406
- eval_mask_override=eval_mask,
407
- )
408
- self._save_best_params(self.best_params_)
409
- return self
410
-
411
- def transform(self) -> np.ndarray:
412
- """Impute missing genotypes (0/1/2) and return IUPAC strings.
413
-
414
- This method first checks if the model has been fitted. It then imputes the entire dataset by optimizing latent vectors for the ground truth data and predicting the missing genotypes using the trained UBP model. The imputed genotypes are decoded to IUPAC format, and genotype distributions are plotted only when ``self.show_plots`` is enabled.
415
-
416
- Returns:
417
- np.ndarray: IUPAC single-character array (n_samples x L).
418
-
419
- Raises:
420
- NotFittedError: If called before fit().
421
- """
422
- if not getattr(self, "is_fit_", False):
423
- raise NotFittedError("Model is not fitted. Call fit() before transform().")
424
-
425
- self.logger.info(f"Imputing entire dataset with {self.model_name}...")
426
- X_to_impute = self.ground_truth_.copy()
427
-
428
- optimized_latents = self._optimize_latents_for_inference(
429
- X_to_impute, self.model_, self.best_params_
430
- )
431
-
432
- if not isinstance(optimized_latents, torch.nn.Parameter):
433
- optimized_latents = torch.nn.Parameter(
434
- optimized_latents, requires_grad=False
435
- )
436
-
437
- pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
438
-
439
- missing_mask = X_to_impute == -1
440
- imputed_array = X_to_impute.copy()
441
- imputed_array[missing_mask] = pred_labels[missing_mask]
442
-
443
- # Decode to IUPAC for return & optional plots
444
- imputed_genotypes = self.pgenc.decode_012(imputed_array)
445
- if self.show_plots:
446
- original_genotypes = self.pgenc.decode_012(X_to_impute)
447
- self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
448
- self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
449
- return imputed_genotypes
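Together, fit() and transform() give the class its two-call workflow. A hypothetical driver for the code being removed here, assuming an already-loaded SNPio GenotypeData object `gd`:

    # Hypothetical usage sketch; `gd` is a snpio GenotypeData object loaded elsewhere.
    imputer = ImputeUBP(gd, config={"train": {"max_epochs": 100}})
    imputer.fit()                 # three-phase training plus held-out evaluation
    iupac = imputer.transform()   # (n_samples x n_loci) array of IUPAC characters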
450
-
451
- def _train_step(
452
- self,
453
- loader: torch.utils.data.DataLoader,
454
- optimizer: torch.optim.Optimizer,
455
- latent_optimizer: torch.optim.Optimizer,
456
- model: torch.nn.Module,
457
- l1_penalty: float,
458
- latent_vectors: torch.nn.Parameter,
459
- class_weights: torch.Tensor,
460
- phase: int,
461
- ) -> Tuple[float, torch.nn.Parameter]:
462
- """One epoch with stable focal CE, grad clipping, and NaN guards.
463
-
464
- Returns:
465
- Tuple[float, torch.nn.Parameter]: Mean loss and updated latents.
466
- """
467
- model.train()
468
- running, used = 0.0, 0
469
-
470
- if not isinstance(latent_vectors, torch.nn.Parameter):
471
- latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)
472
-
473
- gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
474
- gamma = max(0.0, min(gamma, 10.0))
475
- l1_params = tuple(p for p in model.parameters() if p.requires_grad)
476
- if class_weights is not None and class_weights.device != self.device:
477
- class_weights = class_weights.to(self.device)
478
-
479
- criterion = SafeFocalCELoss(gamma=gamma, weight=class_weights, ignore_index=-1)
480
- decoder: torch.Tensor | torch.nn.Module = (
481
- model.phase1_decoder if phase == 1 else model.phase23_decoder
482
- )
483
-
484
- if not isinstance(decoder, torch.nn.Module):
485
- msg = f"{self.model_name} Decoder is not a torch.nn.Module."
486
- self.logger.error(msg)
487
- raise TypeError(msg)
488
-
489
- for batch_indices, y_batch in loader:
490
- optimizer.zero_grad(set_to_none=True)
491
- latent_optimizer.zero_grad(set_to_none=True)
492
-
493
- batch_indices = batch_indices.to(latent_vectors.device, non_blocking=True)
494
- z = latent_vectors[batch_indices]
495
- y = y_batch.to(self.device, non_blocking=True).long()
496
-
497
- logits = decoder(z).view(
498
- len(batch_indices), self.num_features_, self.output_classes_
499
- )
500
-
501
- # Guard upstream explosions
502
- if not torch.isfinite(logits).all():
503
- continue
504
-
505
- if self.is_haploid:
506
- loss = criterion(logits.view(-1, self.output_classes_), y.view(-1))
507
- else:
508
- targets = self._multi_hot_targets(y)
509
- bce = F.binary_cross_entropy_with_logits(
510
- logits, targets, pos_weight=self.pos_weights_, reduction="none"
511
- )
512
- mask = (y != -1).unsqueeze(-1).float()
513
- loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
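    # Annotation (inference, not part of the removed file): _multi_hot_targets is not
    # shown in this diff, but the composition used in _predict (p_het = p_ref * p_alt)
    # implies a two-channel multi-hot target, i.e. 0 -> [1, 0], 1 -> [1, 1], 2 -> [0, 1];
    # entries with y == -1 are excluded from the BCE by the mask above.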
514
-
515
- if l1_penalty > 0:
516
- l1 = torch.zeros((), device=self.device)
517
- for p in l1_params:
518
- l1 = l1 + p.abs().sum()
519
- loss = loss + l1_penalty * l1
520
-
521
- if not torch.isfinite(loss):
522
- continue
523
-
524
- loss.backward()
525
-
526
- # Clip returns the Total Norm
527
- model_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
528
- latent_norm = torch.nn.utils.clip_grad_norm_([latent_vectors], 1.0)
529
-
530
- # Skip update on non-finite grads
531
- # Check norms instead of iterating all parameters
532
- if torch.isfinite(model_norm) and torch.isfinite(latent_norm):
533
- optimizer.step()
534
- if phase != 2:
535
- latent_optimizer.step()
536
- else:
537
- # Logic to handle bad grads (zero out, skip, etc)
538
- optimizer.zero_grad(set_to_none=True)
539
- latent_optimizer.zero_grad(set_to_none=True)
540
-
541
- running += float(loss.detach().item())
542
- used += 1
543
-
544
- return (running / used if used > 0 else float("inf")), latent_vectors
545
-
546
- def _predict(
547
- self,
548
- model: torch.nn.Module,
549
- latent_vectors: Optional[torch.nn.Parameter | torch.Tensor] = None,
550
- ) -> Tuple[np.ndarray, np.ndarray]:
551
- """Predict 0/1/2 labels & probabilities from latents via phase23 decoder. This method requires a trained model and latent vectors.
552
-
553
- Args:
554
- model (torch.nn.Module): Trained model.
555
- latent_vectors (torch.nn.Parameter | None): Latent vectors.
556
-
557
- Returns:
558
- Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
559
- """
560
- if model is None or latent_vectors is None:
561
- msg = "Model and latent vectors must be provided for prediction. Fit the model first."
562
- self.logger.error(msg)
563
- raise NotFittedError(msg)
564
-
565
- model.eval()
566
- nF = getattr(model, "n_features", self.num_features_)
567
- with torch.no_grad():
568
- decoder = model.phase23_decoder
569
-
570
- if not isinstance(decoder, torch.nn.Module):
571
- msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
572
- self.logger.error(msg)
573
- raise TypeError(msg)
574
-
575
- logits = decoder(latent_vectors.to(self.device)).view(
576
- len(latent_vectors), nF, self.output_classes_
577
- )
578
- if self.is_haploid:
579
- probas = torch.softmax(logits, dim=-1)
580
- labels = torch.argmax(probas, dim=-1)
581
- else:
582
- probas2 = torch.sigmoid(logits)
583
- p_ref = probas2[..., 0]
584
- p_alt = probas2[..., 1]
585
- p_het = p_ref * p_alt
586
- p_ref_only = p_ref * (1 - p_alt)
587
- p_alt_only = p_alt * (1 - p_ref)
588
- probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
589
- probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
590
- labels = torch.argmax(probas, dim=-1)
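    # Worked example (annotation, not part of the removed file): with per-channel
    # sigmoids p_ref = 0.9 and p_alt = 0.4, the stacked vector [p_ref_only, p_het,
    # p_alt_only] is [0.54, 0.36, 0.04]; renormalizing gives ~[0.574, 0.383, 0.043],
    # so argmax selects REF (class 0).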
591
-
592
- return labels.cpu().numpy(), probas.cpu().numpy()
593
-
594
- def _evaluate_model(
595
- self,
596
- X_val: np.ndarray,
597
- model: torch.nn.Module,
598
- params: dict,
599
- objective_mode: bool = False,
600
- latent_vectors_val: torch.Tensor | None = None,
601
- *,
602
- eval_mask_override: np.ndarray | None = None,
603
- ) -> Dict[str, float]:
604
- """Evaluates the model on a validation set.
605
-
606
- This method evaluates the trained UBP model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
607
-
608
- Args:
609
- X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
610
- model (torch.nn.Module): Trained UBP model.
611
- params (dict): Model parameters.
612
- objective_mode (bool): If True, suppresses logging and reports only the metric.
613
- latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
614
- eval_mask_override (np.ndarray | None): Boolean mask to specify which entries to evaluate.
615
-
616
- Returns:
617
- Dict[str, float]: Dictionary of evaluation metrics.
618
- """
619
- if latent_vectors_val is not None:
620
- test_latent_vectors = latent_vectors_val
621
- else:
622
- test_latent_vectors = self._optimize_latents_for_inference(
623
- X_val, model, params
624
- )
625
-
626
- pred_labels, pred_probas = self._predict(
627
- model=model, latent_vectors=test_latent_vectors
628
- )
629
-
630
- if eval_mask_override is not None:
631
- # Validate row counts to allow feature subsetting during tuning
632
- if eval_mask_override.shape[0] != X_val.shape[0]:
633
- msg = (
634
- f"eval_mask_override rows {eval_mask_override.shape[0]} "
635
- f"does not match X_val rows {X_val.shape[0]}"
636
- )
637
- self.logger.error(msg)
638
- raise ValueError(msg)
639
-
640
- # FIX: Slice mask columns if override is wider than current X_val (tune_fast)
641
- if eval_mask_override.shape[1] > X_val.shape[1]:
642
- eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
643
- else:
644
- eval_mask = eval_mask_override.astype(bool)
645
- else:
646
- # Default: score only observed entries
647
- eval_mask = X_val != -1
648
-
649
- # y_true should be drawn from the pre-mask ground truth
650
- # Map X_val back to the correct full ground truth slice
651
- # FIX: Check shape[0] (n_samples) only.
652
- if X_val.shape[0] == self.X_test_.shape[0]:
653
- GT_ref = self.GT_test_full_
654
- elif X_val.shape[0] == self.X_train_.shape[0]:
655
- GT_ref = self.GT_train_full_
656
- else:
657
- GT_ref = self.ground_truth_
658
-
659
- # FIX: Slice Ground Truth columns if it is wider than X_val (tune_fast)
660
- if GT_ref.shape[1] > X_val.shape[1]:
661
- GT_ref = GT_ref[:, : X_val.shape[1]]
662
-
663
- # Fallback safeguard
664
- if GT_ref.shape != X_val.shape:
665
- GT_ref = X_val
666
-
667
- y_true_flat = GT_ref[eval_mask]
668
- pred_labels_flat = pred_labels[eval_mask]
669
- pred_probas_flat = pred_probas[eval_mask]
670
-
671
- if y_true_flat.size == 0:
672
- return {self.tune_metric: 0.0}
673
-
674
- # For haploids, remap class 2 to 1 for scoring (e.g., f1-score)
675
- labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
676
- target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
677
-
678
- y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
679
-
680
- metrics = self.scorers_.evaluate(
681
- y_true_flat,
682
- pred_labels_flat,
683
- y_true_ohe,
684
- pred_probas_flat,
685
- objective_mode,
686
- self.tune_metric,
687
- )
688
-
689
- if not objective_mode:
690
- pm = PrettyMetrics(
691
- metrics, precision=3, title=f"{self.model_name} Validation Metrics"
692
- )
693
- pm.render() # prints a command-line table
694
-
695
- self._make_class_reports(
696
- y_true=y_true_flat,
697
- y_pred_proba=pred_probas_flat,
698
- y_pred=pred_labels_flat,
699
- metrics=metrics,
700
- labels=target_names,
701
- )
702
-
703
- # FIX: Use X_val dimensions for reshaping, not self.num_features_
704
- y_true_dec = self.pgenc.decode_012(
705
- GT_ref.reshape(X_val.shape[0], X_val.shape[1])
706
- )
707
-
708
- X_pred = X_val.copy()
709
- X_pred[eval_mask] = pred_labels_flat
710
-
711
- y_pred_dec = self.pgenc.decode_012(
712
- X_pred.reshape(X_val.shape[0], X_val.shape[1])
713
- )
714
-
715
- encodings_dict = {
716
- "A": 0,
717
- "C": 1,
718
- "G": 2,
719
- "T": 3,
720
- "W": 4,
721
- "R": 5,
722
- "M": 6,
723
- "K": 7,
724
- "Y": 8,
725
- "S": 9,
726
- "N": -1,
727
- }
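    # Annotation: the two-base IUPAC ambiguity codes above encode heterozygous calls
    # (W = A/T, R = A/G, M = A/C, K = G/T, Y = C/T, S = C/G); N maps to -1 (missing).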
728
-
729
- y_true_int = self.pgenc.convert_int_iupac(
730
- y_true_dec, encodings_dict=encodings_dict
731
- )
732
- y_pred_int = self.pgenc.convert_int_iupac(
733
- y_pred_dec, encodings_dict=encodings_dict
734
- )
735
-
736
- # For IUPAC report
737
- valid_true = y_true_int[eval_mask]
738
- valid_true = valid_true[valid_true >= 0] # drop -1 (N)
739
- iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
740
-
741
- # For numeric report
742
- if (
743
- np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
744
- or valid_true.size == 0
745
- ):
746
- if not objective_mode:
747
- self.logger.warning(
748
- "Skipped numeric confusion matrix: no y_true labels present."
749
- )
750
- else:
751
- self._make_class_reports(
752
- y_true=valid_true,
753
- y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
754
- metrics=metrics,
755
- y_pred_proba=None,
756
- labels=iupac_label_set,
757
- )
758
-
759
- return metrics
760
-
761
- def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
762
- """Create DataLoader over indices + 0/1/2 target matrix.
763
-
764
- This method creates a PyTorch DataLoader for the given genotype matrix, which contains 0/1/2 encodings with -1 for missing values. The DataLoader is constructed to yield batches of data during training, where each batch consists of indices and the corresponding genotype values. The genotype matrix is converted to a PyTorch tensor and moved to the appropriate device (CPU or GPU) before being wrapped in a TensorDataset. The DataLoader is configured to shuffle the data and use the specified batch size.
765
-
766
- Args:
767
- y (np.ndarray): (n_samples x L) int matrix with -1 missing.
768
-
769
- Returns:
770
- torch.utils.data.DataLoader: Shuffled mini-batches.
771
- """
772
- y_tensor = torch.from_numpy(y).long()
773
- indices = torch.arange(len(y), dtype=torch.long)
774
- dataset = torch.utils.data.TensorDataset(indices, y_tensor)
775
- pin_memory = self.device.type == "cuda"
776
- return torch.utils.data.DataLoader(
777
- dataset,
778
- batch_size=self.batch_size,
779
- shuffle=True,
780
- pin_memory=pin_memory,
781
- )
782
-
783
- def _objective(self, trial: optuna.Trial) -> float:
784
- """Optuna objective using the UBP training loop.
785
-
786
- This method defines the objective function for hyperparameter tuning using Optuna. It prepares the necessary artifacts for tuning, samples a set of hyperparameters for the current trial, and trains the UBP model using these hyperparameters. The model is evaluated on a validation set, and the specified tuning metric is returned as the objective value. If any exception occurs during the process, the trial is pruned.
787
- """
788
- try:
789
- self._prepare_tuning_artifacts()
790
- trial_params = self._sample_hyperparameters(trial)
791
- model_params = trial_params["model_params"]
792
-
793
- nfeat = self._tune_num_features
794
- if self.tune and self.tune_fast:
795
- model_params["n_features"] = nfeat
796
-
797
- X_train_trial = getattr(
798
- self, "X_train_", self.ground_truth_[self.train_idx_]
799
- )
800
- X_test_trial = getattr(self, "X_test_", self.ground_truth_[self.test_idx_])
801
-
802
- class_weights = self._normalize_class_weights(
803
- self._class_weights_from_zygosity(X_train_trial)
804
- )
805
- if not self.is_haploid:
806
- self.pos_weights_ = self._compute_pos_weights(X_train_trial)
807
- else:
808
- self.pos_weights_ = None
809
- train_loader = self._get_data_loaders(X_train_trial)
810
-
811
- train_latent_vectors = self._create_latent_space(
812
- model_params,
813
- len(X_train_trial),
814
- X_train_trial,
815
- trial_params["latent_init"],
816
- )
817
-
818
- model = self.build_model(self.Model, model_params)
819
- model.n_features = model_params["n_features"]
820
- model.apply(self.initialize_weights)
821
-
822
- _, model, __ = self._train_and_validate_model(
823
- model=model,
824
- loader=train_loader,
825
- lr=trial_params["lr"],
826
- l1_penalty=trial_params["l1_penalty"],
827
- trial=trial,
828
- return_history=False,
829
- latent_vectors=train_latent_vectors,
830
- lr_input_factor=trial_params["lr_input_factor"],
831
- class_weights=class_weights,
832
- X_val=X_test_trial,
833
- params=model_params,
834
- prune_metric=self.tune_metric,
835
- prune_warmup_epochs=10,
836
- eval_interval=self.tune_eval_interval,
837
- eval_requires_latents=True,
838
- eval_latent_steps=self.eval_latent_steps,
839
- eval_latent_lr=self.eval_latent_lr,
840
- eval_latent_weight_decay=self.eval_latent_weight_decay,
841
- )
842
-
843
- eval_mask = (
844
- self.sim_mask_test_
845
- if (
846
- self.simulate_missing
847
- and getattr(self, "sim_mask_test_", None) is not None
848
- )
849
- else None
850
- )
851
- metrics = self._evaluate_model(
852
- X_test_trial,
853
- model,
854
- model_params,
855
- objective_mode=True,
856
- eval_mask_override=eval_mask,
857
- )
858
- self._clear_resources(
859
- model, train_loader, latent_vectors=train_latent_vectors
860
- )
861
- return metrics[self.tune_metric]
862
- except Exception as e:
863
- raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
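The objective returns the configured tuning metric and converts any failure into a pruned trial. A minimal sketch of how such an objective is typically driven; in practice fit() calls self.tune_hyperparameters() (defined in the base class, not shown in this diff), which presumably wraps an equivalent of the standard Optuna pattern:

    # Sketch under assumptions: the study setup here is illustrative, not the package's own code.
    import optuna

    study = optuna.create_study(
        direction="maximize", pruner=optuna.pruners.MedianPruner()
    )
    study.optimize(imputer._objective, n_trials=imputer.n_trials)
    best = study.best_trial.params  # later folded back in via _set_best_params(...)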
864
-
865
- def _sample_hyperparameters(self, trial: optuna.Trial) -> dict:
866
- """Sample UBP hyperparameters; compute hidden sizes for model_params.
867
-
868
- This method samples a set of hyperparameters for the UBP model using the provided Optuna trial object. It defines a search space for various hyperparameters, including latent dimension, learning rate, dropout rate, number of hidden layers, activation function, and others. After sampling the hyperparameters, it computes the sizes of the hidden layers based on the sampled values and constructs the model parameters dictionary. The method returns a dictionary containing all sampled hyperparameters along with the computed model parameters.
869
-
870
- Args:
871
- trial (optuna.Trial): Current trial.
872
-
873
- Returns:
874
- Dict[str, int | float | str | list]: Sampled hyperparameters.
875
- """
876
- params = {
877
- "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
878
- "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
879
- "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30, step=0.05),
880
- "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 6),
881
- "activation": trial.suggest_categorical(
882
- "activation", ["relu", "elu", "selu", "leaky_relu"]
883
- ),
884
- "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
885
- "lr_input_factor": trial.suggest_float(
886
- "lr_input_factor", 0.3, 3.0, log=True
887
- ),
888
- "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
889
- "layer_scaling_factor": trial.suggest_float(
890
- "layer_scaling_factor", 2.0, 4.0, step=0.5
891
- ),
892
- "layer_schedule": trial.suggest_categorical(
893
- "layer_schedule", ["pyramid", "linear"]
894
- ),
895
- "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
896
- }
897
-
898
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
899
- n_inputs=params["latent_dim"],
900
- n_outputs=self.num_features_ * self.output_classes_,
901
- n_samples=len(self.train_idx_),
902
- n_hidden=params["num_hidden_layers"],
903
- alpha=params["layer_scaling_factor"],
904
- schedule=params["layer_schedule"],
905
- )
906
- # Keep the latent_dim as the first element,
907
- # then the interior hidden widths.
908
- # If there are no interior widths (very small nets),
909
- # this still leaves [latent_dim].
910
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
911
-
912
- params["model_params"] = {
913
- "n_features": self.num_features_,
914
- "num_classes": self.output_classes_,
915
- "latent_dim": params["latent_dim"],
916
- "dropout_rate": params["dropout_rate"],
917
- "hidden_layer_sizes": hidden_only,
918
- "activation": params["activation"],
919
- }
920
-
921
- return params
922
-
923
- def _set_best_params(self, best_params: dict) -> dict:
924
- """Set best params onto instance; return model_params payload.
925
-
926
- This method sets the best hyperparameters found during tuning onto the instance attributes of the ImputeUBP class. It extracts the relevant hyperparameters from the provided dictionary and updates the corresponding instance variables. Additionally, it computes the sizes of the hidden layers based on the best hyperparameters and constructs the model parameters dictionary. The method returns a dictionary containing the model parameters that can be used to build the UBP model.
927
-
928
- Args:
929
- best_params (dict): Best hyperparameters.
930
-
931
- Returns:
932
- dict: model_params payload.
933
-
934
- Raises:
935
- ValueError: If best_params is missing required keys.
936
- """
937
- self.latent_dim = best_params["latent_dim"]
938
- self.dropout_rate = best_params["dropout_rate"]
939
- self.learning_rate = best_params["learning_rate"]
940
- self.gamma = best_params["gamma"]
941
- self.lr_input_factor = best_params["lr_input_factor"]
942
- self.l1_penalty = best_params["l1_penalty"]
943
- self.activation = best_params["activation"]
944
- self.latent_init = best_params["latent_init"]
945
-
946
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
947
- n_inputs=self.latent_dim,
948
- n_outputs=self.num_features_ * self.output_classes_,
949
- n_samples=len(self.train_idx_),
950
- n_hidden=best_params["num_hidden_layers"],
951
- alpha=best_params["layer_scaling_factor"],
952
- schedule=best_params["layer_schedule"],
953
- )
954
-
955
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
956
-
957
- return {
958
- "n_features": self.num_features_,
959
- "latent_dim": self.latent_dim,
960
- "hidden_layer_sizes": hidden_only,
961
- "dropout_rate": self.dropout_rate,
962
- "activation": self.activation,
963
- "gamma": self.gamma,
964
- "num_classes": self.output_classes_,
965
- }
966
-
967
- def _set_best_params_default(self) -> dict:
968
- """Default (no-tuning) model_params aligned with current attributes.
969
-
970
- This method constructs the model parameters dictionary using the current instance attributes of the ImputeUBP class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the UBP model when no hyperparameter tuning has been performed.
971
-
972
- Returns:
973
- dict: model_params payload.
974
- """
975
- hidden_layer_sizes = self._compute_hidden_layer_sizes(
976
- n_inputs=self.latent_dim,
977
- n_outputs=self.num_features_ * self.output_classes_,
978
- n_samples=len(self.ground_truth_),
979
- n_hidden=self.num_hidden_layers,
980
- alpha=self.layer_scaling_factor,
981
- schedule=self.layer_schedule,
982
- )
983
-
984
- hidden_only = [hidden_layer_sizes[0]] + hidden_layer_sizes[1:-1]
985
-
986
- return {
987
- "n_features": self.num_features_,
988
- "latent_dim": self.latent_dim,
989
- "hidden_layer_sizes": hidden_only,
990
- "dropout_rate": self.dropout_rate,
991
- "activation": self.activation,
992
- "gamma": self.gamma,
993
- "num_classes": self.output_classes_,
994
- }
995
-
996
- def _train_and_validate_model(
997
- self,
998
- model: torch.nn.Module,
999
- loader: torch.utils.data.DataLoader,
1000
- lr: float,
1001
- l1_penalty: float,
1002
- trial: optuna.Trial | None = None,
1003
- return_history: bool = False,
1004
- latent_vectors: torch.nn.Parameter | None = None,
1005
- lr_input_factor: float = 1.0,
1006
- class_weights: torch.Tensor | None = None,
1007
- *,
1008
- X_val: np.ndarray | None = None,
1009
- params: dict | None = None,
1010
- prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
1011
- prune_warmup_epochs: int = 10,
1012
- eval_interval: int = 1,
1013
- eval_requires_latents: bool = True, # UBP needs latent eval
1014
- eval_latent_steps: int = 50,
1015
- eval_latent_lr: float = 1e-2,
1016
- eval_latent_weight_decay: float = 0.0,
1017
- ) -> tuple:
1018
- """Train & validate UBP model with three-phase loop.
1019
-
1020
- This method trains and validates the UBP model using a three-phase training loop. It sets up the latent optimizer and invokes the training loop, which includes pre-training, fine-tuning, and joint training phases. The method ensures that the necessary latent vectors and class weights are provided before proceeding with training. It also incorporates new parameters for evaluation and pruning during training. The final best loss, best model, training history, and optimized latent vectors are returned.
1021
-
1022
- Args:
1023
- model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
1024
- loader (torch.utils.data.DataLoader): DataLoader for training data.
1025
- lr (float): Learning rate for decoder.
1026
- l1_penalty (float): L1 regularization weight.
1027
- trial (optuna.Trial | None): Current trial or None.
1028
- return_history (bool): If True, return loss history.
1029
- latent_vectors (torch.nn.Parameter | None): Trainable Z.
1030
- lr_input_factor (float): LR factor for latents.
1031
- class_weights (torch.Tensor | None): Class weights for 0/1/2.
1032
- X_val (np.ndarray | None): Validation set for pruning/eval.
1033
- params (dict | None): Model params for eval.
1034
- prune_metric (str | None): Metric to monitor for pruning.
1035
- prune_warmup_epochs (int): Epochs before pruning starts.
1036
- eval_interval (int): Epochs between evaluations.
1037
- eval_requires_latents (bool): If True, optimize latents for eval.
1038
- eval_latent_steps (int): Latent optimization steps for eval.
1039
- eval_latent_lr (float): Latent optimization LR for eval.
1040
- eval_latent_weight_decay (float): Latent optimization weight decay for eval.
1041
-
1042
- Returns:
1043
- Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
1044
-
1045
- Raises:
1046
- TypeError: If latent_vectors or class_weights are
1047
- not provided.
1048
- ValueError: If X_val is not provided for evaluation.
1049
- RuntimeError: If eval_latent_steps is not positive.
1050
- """
1051
- if latent_vectors is None or class_weights is None:
1052
- msg = "Must provide latent_vectors and class_weights."
1053
- self.logger.error(msg)
1054
- raise TypeError(msg)
1055
-
1056
- latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
1057
-
1058
- result = self._execute_training_loop(
1059
- loader=loader,
1060
- latent_optimizer=latent_optimizer,
1061
- lr=lr,
1062
- model=model,
1063
- l1_penalty=l1_penalty,
1064
- trial=trial,
1065
- return_history=return_history,
1066
- latent_vectors=latent_vectors,
1067
- class_weights=class_weights,
1068
- # NEW ↓↓↓
1069
- X_val=X_val,
1070
- params=params,
1071
- prune_metric=prune_metric,
1072
- prune_warmup_epochs=prune_warmup_epochs,
1073
- eval_interval=eval_interval,
1074
- eval_requires_latents=eval_requires_latents,
1075
- eval_latent_steps=eval_latent_steps,
1076
- eval_latent_lr=eval_latent_lr,
1077
- eval_latent_weight_decay=eval_latent_weight_decay,
1078
- )
1079
-
1080
- if return_history:
1081
- return result
1082
-
1083
- return result[0], result[1], result[3]
1084
-
1085
- def _train_final_model(
1086
- self,
1087
- loader: torch.utils.data.DataLoader,
1088
- best_params: dict,
1089
- initial_latent_vectors: torch.nn.Parameter,
1090
- ) -> tuple:
1091
- """Train final UBP model with best params; save weights to disk.
1092
-
1093
- This method trains the final UBP model using the best hyperparameters found during tuning. It builds the model with the specified parameters, initializes the weights, and invokes the training and validation process. The method saves the trained model's state dictionary to disk and returns the final loss, trained model, training history, and optimized latent vectors.
1094
-
1095
- Args:
1096
- loader (torch.utils.data.DataLoader): DataLoader for training data.
1097
- best_params (Dict[str, int | float | str | list]): Best hyperparameters.
1098
- initial_latent_vectors (torch.nn.Parameter): Initialized latent vectors.
1099
-
1100
- Returns:
1101
- Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (loss, model, {"Train": history}, latents).
1102
- """
1103
- self.logger.info(f"Training the final {self.model_name} model...")
1104
-
1105
- model = self.build_model(self.Model, best_params)
1106
- model.n_features = best_params["n_features"]
1107
- model.apply(self.initialize_weights)
1108
-
1109
- loss, trained_model, history, latent_vectors = self._train_and_validate_model(
1110
- model=model,
1111
- loader=loader,
1112
- lr=self.learning_rate,
1113
- l1_penalty=self.l1_penalty,
1114
- return_history=True,
1115
- latent_vectors=initial_latent_vectors,
1116
- lr_input_factor=self.lr_input_factor,
1117
- class_weights=self.class_weights_,
1118
- X_val=self.X_test_,
1119
- params=best_params,
1120
- prune_metric=self.tune_metric,
1121
- prune_warmup_epochs=10,
1122
- eval_interval=1,
1123
- eval_requires_latents=True,
1124
- eval_latent_steps=self.eval_latent_steps,
1125
- eval_latent_lr=self.eval_latent_lr,
1126
- eval_latent_weight_decay=self.eval_latent_weight_decay,
1127
- )
1128
-
1129
- if trained_model is None:
1130
- msg = "Final model training failed."
1131
- self.logger.error(msg)
1132
- raise RuntimeError(msg)
1133
-
1134
- fout = self.models_dir / "final_model.pt"
1135
- torch.save(trained_model.state_dict(), fout)
1136
- return loss, trained_model, {"Train": history}, latent_vectors
1137
-
1138
- def _execute_training_loop(
1139
- self,
1140
- loader: torch.utils.data.DataLoader,
1141
- latent_optimizer: torch.optim.Optimizer,
1142
- lr: float,
1143
- model: torch.nn.Module,
1144
- l1_penalty: float,
1145
- trial: optuna.Trial | None,
1146
- return_history: bool,
1147
- latent_vectors: torch.nn.Parameter,
1148
- class_weights: torch.Tensor,
1149
- *,
1150
- X_val: np.ndarray | None = None,
1151
- params: dict | None = None,
1152
- prune_metric: str | None = None,
1153
- prune_warmup_epochs: int = 10,
1154
- eval_interval: int = 1,
1155
- eval_requires_latents: bool = True,
1156
- eval_latent_steps: int = 50,
1157
- eval_latent_lr: float = 1e-2,
1158
- eval_latent_weight_decay: float = 0.0,
1159
- ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
1160
- """Three-phase UBP with numeric guards, LR warmup, and pruning.
1161
-
1162
- This method executes the three-phase training loop for the UBP model, incorporating numeric stability guards, learning rate warmup, and Optuna pruning. It iterates through three training phases: pre-training the phase 1 decoder, fine-tuning the phase 2 and 3 decoders, and joint training of all components. The method monitors training loss, applies early stopping, and evaluates the model on a validation set for pruning purposes. The final best loss, best model, training history, and optimized latent vectors are returned.
1163
-
1164
- Args:
1165
- loader (torch.utils.data.DataLoader): DataLoader for training data.
1166
- latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
1167
- lr (float): Learning rate for decoder.
1168
- model (torch.nn.Module): UBP model with phase1_decoder & phase23_decoder.
1169
- l1_penalty (float): L1 regularization weight.
1170
- trial (optuna.Trial | None): Current trial or None.
1171
- return_history (bool): If True, return loss history.
1172
- latent_vectors (torch.nn.Parameter): Trainable Z.
1173
- class_weights (torch.Tensor): Class weights for
1174
- 0/1/2.
1175
- X_val (np.ndarray | None): Validation set for pruning/eval.
1176
- params (dict | None): Model params for eval.
1177
- prune_metric (str | None): Metric to monitor for pruning.
1178
- prune_warmup_epochs (int): Epochs before pruning starts.
1179
- eval_interval (int): Epochs between evaluations.
1180
- eval_requires_latents (bool): If True, optimize latents for eval.
1181
- eval_latent_steps (int): Latent optimization steps for eval.
1182
- eval_latent_lr (float): Latent optimization LR for eval.
1183
- eval_latent_weight_decay (float): Latent optimization weight decay for eval.
1184
-
1185
- Returns:
1186
- Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: (best_loss, best_model, history, latents).
1187
-
1188
- Raises:
1189
- ValueError: If X_val is not provided for evaluation.
1190
- RuntimeError: If eval_latent_steps is not positive.
1191
- """
1192
- history: dict[str, list[float]] = {}
1193
- final_best_loss, final_best_model = float("inf"), None
1194
-
1195
- warm, ramp, gamma_final = 50, 100, torch.tensor(self.gamma, device=self.device)
1196
-
1197
- # Schema-aware latent cache for eval
1198
- _latent_cache: dict = {}
1199
- nF = getattr(model, "n_features", self.num_features_)
1200
- cache_key_root = f"{self.prefix}_ubp_val_latents_L{nF}_K{self.output_classes_}"
1201
-
1202
- E = int(self.epochs)
1203
- phase_epochs = {
1204
- 1: max(1, int(0.15 * E)),
1205
- 2: max(1, int(0.35 * E)),
1206
- 3: max(1, E - int(0.15 * E) - int(0.35 * E)),
1207
- }
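    # Worked example (annotation, not part of the removed file): with E = 100 total
    # epochs, phase_epochs = {1: 15, 2: 35, 3: 50}, i.e. ~15% joint warm start,
    # ~35% decoder-only refit, and the remainder for joint fine-tuning.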
1208
-
1209
- for phase in (1, 2, 3):
1210
- steps_this_phase = phase_epochs[phase]
1211
- warmup_epochs = getattr(self, "lr_warmup_epochs", 5) if phase == 1 else 0
1212
-
1213
- early_stopping = EarlyStopping(
1214
- patience=self.early_stop_gen,
1215
- min_epochs=self.min_epochs,
1216
- verbose=self.verbose,
1217
- prefix=self.prefix,
1218
- debug=self.debug,
1219
- )
1220
-
1221
- if phase == 2:
1222
- self._reset_weights(model)
1223
-
1224
- decoder: torch.Tensor | torch.nn.Module = (
1225
- model.phase1_decoder if phase == 1 else model.phase23_decoder
1226
- )
1227
-
1228
- if not isinstance(decoder, torch.nn.Module):
1229
- msg = f"{self.model_name} Decoder is not a torch.nn.Module."
1230
- self.logger.error(msg)
1231
- raise TypeError(msg)
1232
-
1233
- decoder_params = decoder.parameters()
1234
- optimizer = torch.optim.AdamW(decoder_params, lr=lr, eps=1e-7)
1235
- scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
1236
- optimizer, T_max=steps_this_phase
1237
- )
1238
-
1239
- # Cache base LRs for warmup
1240
- dec_lr0 = optimizer.param_groups[0]["lr"]
1241
- lat_lr0 = latent_optimizer.param_groups[0]["lr"]
1242
- dec_lr_min, lat_lr_min = dec_lr0 * 0.1, lat_lr0 * 0.1
1243
-
1244
- phase_hist: list[float] = []
1245
- gamma_init = torch.tensor(0.0, device=self.device)
1246
-
1247
- for epoch in range(steps_this_phase):
1248
- # Focal gamma warm/ramp
1249
- if epoch < warm:
1250
- model.gamma = gamma_init.cpu().numpy().item()
1251
- elif epoch < warm + ramp:
1252
- model.gamma = gamma_final * ((epoch - warm) / ramp)
1253
- else:
1254
- model.gamma = gamma_final
1255
-
1256
- # Linear warmup for both optimizers
1257
- if warmup_epochs and epoch < warmup_epochs:
1258
- scale = float(epoch + 1) / warmup_epochs
1259
- for g in optimizer.param_groups:
1260
- g["lr"] = dec_lr_min + (dec_lr0 - dec_lr_min) * scale
1261
- for g in latent_optimizer.param_groups:
1262
- g["lr"] = lat_lr_min + (lat_lr0 - lat_lr_min) * scale
1263
-
1264
- train_loss, latent_vectors = self._train_step(
1265
- loader=loader,
1266
- optimizer=optimizer,
1267
- latent_optimizer=latent_optimizer,
1268
- model=model,
1269
- l1_penalty=l1_penalty,
1270
- latent_vectors=latent_vectors,
1271
- class_weights=class_weights,
1272
- phase=phase,
1273
- )
1274
-
1275
- if not np.isfinite(train_loss):
1276
- if trial:
1277
- raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
1278
- # reduce LRs and continue
1279
- for g in optimizer.param_groups:
1280
- g["lr"] *= 0.5
1281
- for g in latent_optimizer.param_groups:
1282
- g["lr"] *= 0.5
1283
- continue
1284
-
1285
- scheduler.step()
1286
- if return_history:
1287
- phase_hist.append(train_loss)
1288
-
1289
- early_stopping(train_loss, model)
1290
- if early_stopping.early_stop:
1291
- self.logger.info(
1292
- f"Early stopping at epoch {epoch + 1} (phase {phase})."
1293
- )
1294
- break
1295
-
1296
- # Validation + pruning
1297
- if (
1298
- trial is not None
1299
- and X_val is not None
1300
- and ((epoch + 1) % eval_interval == 0)
1301
- ):
1302
- metric_key = prune_metric or getattr(self, "tune_metric", "f1")
1303
- zdim = self._first_linear_in_features(model)
1304
- schema_key = f"{cache_key_root}_z{zdim}"
1305
- mask_override = None
1306
- if (
1307
- self.simulate_missing
1308
- and getattr(self, "sim_mask_test_", None) is not None
1309
- and getattr(self, "X_test_", None) is not None
1310
- and X_val.shape == self.X_test_.shape
1311
- ):
1312
- mask_override = self.sim_mask_test_
1313
-
1314
- metric_val = self._eval_for_pruning(
1315
- model=model,
1316
- X_val=X_val,
1317
- params=params or getattr(self, "best_params_", {}),
1318
- metric=metric_key,
1319
- objective_mode=True,
1320
- do_latent_infer=eval_requires_latents,
1321
- latent_steps=eval_latent_steps,
1322
- latent_lr=eval_latent_lr,
1323
- latent_weight_decay=eval_latent_weight_decay,
1324
- latent_seed=self.seed, # type: ignore
1325
- _latent_cache=_latent_cache,
1326
- _latent_cache_key=schema_key,
1327
- eval_mask_override=mask_override,
1328
- )
1329
-
1330
- if phase == 3:
1331
- trial.report(metric_val, step=epoch + 1)
1332
- if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
1333
- raise optuna.exceptions.TrialPruned(
1334
- f"Pruned at epoch {epoch + 1} (phase {phase}): {metric_key}={metric_val:.5f}"
1335
- )
1336
-
1337
- history[f"Phase {phase}"] = phase_hist
1338
- final_best_loss = early_stopping.best_score
1339
- if early_stopping.best_model is not None:
1340
- final_best_model = copy.deepcopy(early_stopping.best_model)
1341
- else:
1342
- final_best_model = copy.deepcopy(model)
1343
-
1344
- if final_best_model is None:
1345
- final_best_model = copy.deepcopy(model)
1346
-
1347
- return final_best_loss, final_best_model, history, latent_vectors
1348
-
1349
- def _optimize_latents_for_inference(
1350
- self,
1351
- X_new: np.ndarray,
1352
- model: torch.nn.Module,
1353
- params: dict,
1354
- inference_epochs: int = 200,
1355
- ) -> torch.Tensor:
1356
- """Optimize latents for new 0/1/2 data with guards.
1357
-
1358
- This method optimizes the latent vectors for new genotype data using the trained UBP model. It initializes the latent space based on the provided data and iteratively updates the latent vectors to minimize the cross-entropy loss between the model's predictions and the true genotype values. The optimization process includes numeric stability guards to ensure that gradients and losses remain finite. The optimized latent vectors are returned as a PyTorch tensor.
1359
-
1360
- Args:
1361
- X_new (np.ndarray): New 0/1/2 data with -1 for missing.
1362
- model (torch.nn.Module): Trained UBP model.
1363
- params (dict): Model params.
1364
- inference_epochs (int): Number of optimization epochs.
1365
-
1366
- Returns:
1367
- torch.Tensor: Optimized latent vectors.
1368
- """
1369
- model.eval()
1370
- nF = getattr(model, "n_features", self.num_features_)
1371
-
1372
- if self.tune and self.tune_fast:
1373
- inference_epochs = min(
1374
- inference_epochs, getattr(self, "tune_infer_epochs", 20)
1375
- )
1376
-
1377
- X_new = X_new.astype(np.int64, copy=False)
1378
- X_new[X_new < 0] = -1
1379
- y = torch.from_numpy(X_new).long().to(self.device)
1380
-
1381
- z = self._create_latent_space(
1382
- params, len(X_new), X_new, self.latent_init
1383
- ).requires_grad_(True)
1384
- opt = torch.optim.AdamW(
1385
- [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
1386
- )
1387
-
1388
- for _ in range(inference_epochs):
1389
- decoder = model.phase23_decoder
1390
-
1391
- if not isinstance(decoder, torch.nn.Module):
1392
- msg = f"{self.model_name} Decoder is not a torch.nn.Module."
1393
- self.logger.error(msg)
1394
- raise TypeError(msg)
1395
-
1396
- opt.zero_grad(set_to_none=True)
1397
- logits = decoder(z).view(len(X_new), nF, self.output_classes_)
1398
-
1399
- if not torch.isfinite(logits).all():
1400
- break
1401
-
1402
- if self.is_haploid:
1403
- loss = F.cross_entropy(
1404
- logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
1405
- )
1406
- else:
1407
- targets = self._multi_hot_targets(y)
1408
- bce = F.binary_cross_entropy_with_logits(
1409
- logits, targets, pos_weight=self.pos_weights_, reduction="none"
1410
- )
1411
- mask = (y != -1).unsqueeze(-1).float()
1412
- loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
1413
-
1414
- if not torch.isfinite(loss):
1415
- break
1416
-
1417
- loss.backward()
1418
-
1419
- torch.nn.utils.clip_grad_norm_([z], 1.0)
1420
-
1421
- if z.grad is None or not torch.isfinite(z.grad).all():
1422
- break
1423
-
1424
- opt.step()
1425
-
1426
- return z.detach()
1427
-
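For context, here is a minimal, self-contained sketch of the decoder-frozen latent optimization that `_optimize_latents_for_inference` performs, shown only for the haploid cross-entropy branch. The linear decoder, data sizes, epoch count, and learning rate are synthetic stand-ins, not the UBP architecture or its defaults.

```python
import numpy as np
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, n_loci, n_classes, latent_dim = 8, 20, 3, 4

decoder = torch.nn.Linear(latent_dim, n_loci * n_classes)  # stand-in decoder
for p in decoder.parameters():
    p.requires_grad_(False)  # the network stays fixed; only the latents move

X_new = np.random.randint(-1, 3, size=(n_samples, n_loci))  # 0/1/2 with -1 missing
y = torch.from_numpy(X_new).long()

z = torch.nn.Parameter(torch.randn(n_samples, latent_dim))
opt = torch.optim.AdamW([z], lr=0.05)

for _ in range(100):
    opt.zero_grad(set_to_none=True)
    logits = decoder(z).view(n_samples, n_loci, n_classes)
    # ignore_index=-1 masks missing genotypes, as in the haploid branch above.
    loss = F.cross_entropy(logits.view(-1, n_classes), y.view(-1), ignore_index=-1)
    loss.backward()
    torch.nn.utils.clip_grad_norm_([z], 1.0)
    opt.step()

print("final masked CE loss:", float(loss))
```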
1428
- def _create_latent_space(
1429
- self,
1430
- params: dict,
1431
- n_samples: int,
1432
- X: np.ndarray,
1433
- latent_init: Literal["random", "pca"],
1434
- ) -> torch.nn.Parameter:
1435
- """Initialize latent space via random Xavier or PCA on 0/1/2 matrix.
1436
-
1437
- This method initializes the latent space for the UBP model using either random Xavier initialization or PCA-based initialization. The choice of initialization strategy is determined by the latent_init parameter. If PCA is selected, the method handles missing values by imputing them with column means before performing PCA. The resulting latent vectors are standardized and converted to a PyTorch parameter that can be optimized during training.
1438
-
1439
- Args:
1440
- params (dict): Contains 'latent_dim'.
1441
- n_samples (int): Number of samples.
1442
- X (np.ndarray): (n_samples x L) 0/1/2 with -1 missing.
1443
- latent_init (Literal["random","pca"]): Init strategy.
1444
-
1445
- Returns:
1446
- torch.nn.Parameter: Trainable latent matrix.
1447
- """
1448
- latent_dim = int(params["latent_dim"])
1449
-
1450
- if latent_init == "pca":
1451
- X_pca = X.astype(np.float32, copy=True)
1452
- # mark missing
1453
- X_pca[X_pca < 0] = np.nan
1454
-
1455
- # ---- SAFE column means without warnings ----
1456
- valid_counts = np.sum(~np.isnan(X_pca), axis=0)
1457
- col_sums = np.nansum(X_pca, axis=0)
1458
- col_means = np.divide(
1459
- col_sums,
1460
- valid_counts,
1461
- out=np.zeros_like(col_sums, dtype=np.float32),
1462
- where=valid_counts > 0,
1463
- )
1464
-
1465
- # impute NaNs with per-column means
1466
- # (all-NaN cols -> 0.0 by the divide above)
1467
- nan_r, nan_c = np.where(np.isnan(X_pca))
1468
- if nan_r.size:
1469
- X_pca[nan_r, nan_c] = col_means[nan_c]
1470
-
1471
- # center columns
1472
- X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
1473
-
1474
- # guard: degenerate / all-zero after centering ->
1475
- # fall back to random
1476
- if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
1477
- latents = torch.empty(n_samples, latent_dim, device=self.device)
1478
- torch.nn.init.xavier_uniform_(latents)
1479
- return torch.nn.Parameter(latents, requires_grad=True)
1480
-
1481
- # rank-aware component count, at least 1
1482
- try:
1483
- est_rank = np.linalg.matrix_rank(X_pca)
1484
- except Exception:
1485
- est_rank = min(n_samples, X_pca.shape[1])
1486
-
1487
- n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
1488
-
1489
- # randomized SVD with a fixed random_state for reproducibility
1490
- pca = PCA(
1491
- n_components=n_components,
1492
- svd_solver="randomized",
1493
- random_state=self.seed,
1494
- )
1495
- initial = pca.fit_transform(X_pca) # (n_samples, n_components)
1496
-
1497
- # pad if latent_dim > n_components
1498
- if n_components < latent_dim:
1499
- pad = self.rng.standard_normal(
1500
- size=(n_samples, latent_dim - n_components)
1501
- )
1502
- initial = np.hstack([initial, pad])
1503
-
1504
- # standardize latent dims
1505
- initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
1506
-
1507
- latents = torch.from_numpy(initial).float().to(self.device)
1508
- return torch.nn.Parameter(latents, requires_grad=True)
1509
-
1510
- else:
1511
- latents = torch.empty(n_samples, latent_dim, device=self.device)
1512
- torch.nn.init.xavier_uniform_(latents)
1513
- return torch.nn.Parameter(latents, requires_grad=True)
1514
-
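As a companion to the PCA branch above, this is a compact sketch of column-mean imputation followed by rank-aware PCA on a synthetic 0/1/2 matrix with -1 for missing. The matrix size, missing rate, and `latent_dim` are arbitrary illustration choices.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(42)
X = rng.integers(0, 3, size=(30, 50)).astype(np.float32)  # 0/1/2 genotypes
X[rng.random(X.shape) < 0.1] = -1.0                       # ~10% missing
latent_dim = 8

X[X < 0] = np.nan
valid_counts = np.sum(~np.isnan(X), axis=0)
col_means = np.divide(
    np.nansum(X, axis=0),
    valid_counts,
    out=np.zeros(X.shape[1], dtype=np.float32),
    where=valid_counts > 0,
)
nan_r, nan_c = np.where(np.isnan(X))
X[nan_r, nan_c] = col_means[nan_c]          # impute with per-column means
X -= X.mean(axis=0, keepdims=True)          # center columns

# Rank-aware component count, padded back up to latent_dim with noise if needed.
n_components = int(max(1, min(latent_dim, np.linalg.matrix_rank(X), *X.shape)))
pca = PCA(n_components=n_components, svd_solver="randomized", random_state=42)
latents = pca.fit_transform(X)
if n_components < latent_dim:
    pad = rng.standard_normal((X.shape[0], latent_dim - n_components))
    latents = np.hstack([latents, pad])
latents = (latents - latents.mean(axis=0)) / (latents.std(axis=0) + 1e-6)
print(latents.shape)  # (30, 8)
```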
1515
- def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
1516
- """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
1517
- if self.is_haploid:
1518
- raise RuntimeError("_multi_hot_targets called for haploid data.")
1519
- y = y.to(self.device)
1520
- out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
1521
- valid = y != -1
1522
- ref_mask = valid & (y != 2)
1523
- alt_mask = valid & (y != 0)
1524
- out[ref_mask, 0] = 1.0
1525
- out[alt_mask, 1] = 1.0
1526
- return out
1527
-
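A tiny worked example of that two-channel encoding, with arbitrarily chosen genotype values:

```python
import torch

# 0 (hom REF) -> [1, 0], 1 (het) -> [1, 1], 2 (hom ALT) -> [0, 1], -1 (missing) -> [0, 0]
y = torch.tensor([[0, 1, 2, -1]])
out = torch.zeros(y.shape + (2,))
valid = y != -1
out[valid & (y != 2), 0] = 1.0  # REF allele present
out[valid & (y != 0), 1] = 1.0  # ALT allele present
print(out)
# tensor([[[1., 0.],
#          [1., 1.],
#          [0., 1.],
#          [0., 0.]]])
```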
1528
- def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
1529
- """Balance REF/ALT channels for multilabel BCE."""
1530
- ref_pos = np.count_nonzero((X == 0) | (X == 1))
1531
- alt_pos = np.count_nonzero((X == 2) | (X == 1))
1532
- total_valid = np.count_nonzero(X != -1)
1533
- pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
1534
- neg_counts = np.maximum(total_valid - pos_counts, 1.0)
1535
- pos_counts = np.maximum(pos_counts, 1.0)
1536
- weights = neg_counts / pos_counts
1537
- return torch.tensor(weights, device=self.device, dtype=torch.float32)
1538
-
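A small worked example of the per-channel `pos_weight = negatives / positives` rule above, on a toy genotype matrix whose counts are easy to verify by hand:

```python
import numpy as np

X = np.array([[0, 0, 1, 2, -1],
              [0, 1, 2, 2, -1]])                 # 0/1/2 with -1 missing
ref_pos = np.count_nonzero((X == 0) | (X == 1))  # genotypes carrying REF -> 5
alt_pos = np.count_nonzero((X == 2) | (X == 1))  # genotypes carrying ALT -> 5
total_valid = np.count_nonzero(X != -1)          # 8 observed genotypes

pos = np.maximum(np.array([ref_pos, alt_pos], dtype=float), 1.0)
neg = np.maximum(total_valid - pos, 1.0)
print(neg / pos)  # [0.6 0.6] -> pos_weight for the REF and ALT channels
```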
1539
- def _reset_weights(self, model: torch.nn.Module) -> None:
1540
- """Selectively resets only the weights of the phase 2/3 decoder.
1541
-
1542
- This method targets only the `phase23_decoder` attribute of the UBPModel, leaving the `phase1_decoder` and any other model components untouched, so the deep decoder can be re-initialized ahead of the later training phases without disturbing the rest of the model.
1543
-
1544
- Args:
1545
- model (torch.nn.Module): The PyTorch model whose parameters are to be reset.
1546
- """
1547
- if hasattr(model, "phase23_decoder"):
1548
- decoder = model.phase23_decoder
1549
- if not isinstance(decoder, torch.nn.Module):
1550
- msg = f"{self.model_name} phase23_decoder is not a torch.nn.Module."
1551
- self.logger.error(msg)
1552
- raise TypeError(msg)
1553
- # Iterate through only the modules of the second decoder
1554
- for layer in decoder.modules():
1555
- if hasattr(layer, "reset_parameters") and isinstance(
1556
- layer.reset_parameters, torch.nn.Module
1557
- ):
1558
- layer.reset_parameters()
1559
- else:
1560
- self.logger.warning(
1561
- "Model does not have a 'phase23_decoder' attribute; skipping weight reset."
1562
- )
1563
-
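A minimal sketch of that selective re-initialization, using a hypothetical two-decoder module named `Toy` in place of the real UBPModel:

```python
import torch


class Toy(torch.nn.Module):
    """Hypothetical stand-in with a shallow phase-1 decoder and a deep phase-2/3 decoder."""

    def __init__(self) -> None:
        super().__init__()
        self.phase1_decoder = torch.nn.Linear(4, 8)
        self.phase23_decoder = torch.nn.Sequential(
            torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8)
        )


model = Toy()
before = model.phase1_decoder.weight.clone()

# Reset only the phase-2/3 decoder; callable() skips layers without the hook (e.g. ReLU).
for layer in model.phase23_decoder.modules():
    if callable(getattr(layer, "reset_parameters", None)):
        layer.reset_parameters()

assert torch.equal(before, model.phase1_decoder.weight)  # phase-1 weights untouched
```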
1564
- def _latent_infer_for_eval(
1565
- self,
1566
- model: torch.nn.Module,
1567
- X_val: np.ndarray,
1568
- *,
1569
- steps: int,
1570
- lr: float,
1571
- weight_decay: float,
1572
- seed: int,
1573
- cache: dict | None,
1574
- cache_key: str | None,
1575
- ) -> None:
1576
- """Freeze network; refine validation latents only with guards.
1577
-
1578
- This method refines the latent vectors for the validation dataset using the trained UBP model. It freezes the model parameters so that only the latents are updated, and optimizes them to minimize the masked loss between the model's predictions and the observed genotypes (cross-entropy for haploid data, multi-label BCE for diploid data). Numeric stability checks stop the refinement early if logits, losses, or gradients become non-finite. If a cache is provided, the optimized latents are stored under a schema-derived key and reused as a warm start on subsequent calls.
1579
-
1580
- Args:
1581
- model (torch.nn.Module): Trained UBP model.
1582
- X_val (np.ndarray): Validation set of 0/1/2 genotypes with -1 for missing.
1583
- steps (int): Number of optimization steps.
1584
- lr (float): Learning rate for latent optimization.
1585
- weight_decay (float): Weight decay for latent optimization.
1586
- seed (int): Random seed for reproducibility.
1587
- cache (dict | None): Optional cache for latent vectors.
1588
- cache_key (str | None): Key for storing/retrieving from cache.
1589
- """
1590
- if seed is None:
1591
- seed = np.random.randint(0, 999_999)
1592
-
1593
- torch.manual_seed(seed)
1594
- np.random.seed(seed)
1595
-
1596
- model.eval()
1597
- for p in model.parameters():
1598
- p.requires_grad_(False)
1599
-
1600
- nF = getattr(model, "n_features", self.num_features_)
1601
- X_val = X_val.astype(np.int64, copy=False)
1602
- X_val[X_val < 0] = -1
1603
- y = torch.from_numpy(X_val).long().to(self.device)
1604
-
1605
- zdim = self._first_linear_in_features(model)
1606
- schema_key = (
1607
- f"{self.prefix}_ubp_val_latents_z{zdim}_L{nF}_K{self.output_classes_}"
1608
- )
1609
-
1610
- if cache is not None and schema_key in cache:
1611
- z = cache[schema_key].detach().clone().requires_grad_(True)
1612
- else:
1613
- z = self._create_latent_space(
1614
- {"latent_dim": zdim}, X_val.shape[0], X_val, self.latent_init
1615
- ).requires_grad_(True)
1616
-
1617
- opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
1618
-
1619
- for _ in range(max(int(steps), 0)):
1620
- opt.zero_grad(set_to_none=True)
1621
-
1622
- decoder = model.phase23_decoder
1623
-
1624
- if not isinstance(decoder, torch.nn.Module):
1625
- msg = f"{self.model_name} Decoder is not a torch.nn.Module."
1626
- self.logger.error(msg)
1627
- raise TypeError(msg)
1628
-
1629
- logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
1630
- if not torch.isfinite(logits).all():
1631
- break
1632
- if self.is_haploid:
1633
- loss = F.cross_entropy(
1634
- logits.view(-1, self.output_classes_), y.view(-1), ignore_index=-1
1635
- )
1636
- else:
1637
- targets = self._multi_hot_targets(y)
1638
- bce = F.binary_cross_entropy_with_logits(
1639
- logits, targets, pos_weight=self.pos_weights_, reduction="none"
1640
- )
1641
- mask = (y != -1).unsqueeze(-1).float()
1642
- loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
1643
-
1644
- if not torch.isfinite(loss):
1645
- break
1646
-
1647
- loss.backward()
1648
-
1649
- torch.nn.utils.clip_grad_norm_([z], 1.0)
1650
-
1651
- if z.grad is None or not torch.isfinite(z.grad).all():
1652
- break
1653
-
1654
- opt.step()
1655
-
1656
- if cache is not None:
1657
- cache[schema_key] = z.detach().clone()
1658
-
1659
- for p in model.parameters():
1660
- p.requires_grad_(True)
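Finally, a self-contained sketch of the masked multi-label BCE used in the diploid branch of this method; the logits, targets, and `pos_weight` values are random stand-ins with matching shapes, not outputs of a trained model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_samples, n_loci = 4, 6
y = torch.randint(-1, 3, (n_samples, n_loci))      # 0/1/2 genotypes with -1 missing

targets = torch.zeros(y.shape + (2,))
valid = y != -1
targets[valid & (y != 2), 0] = 1.0                 # REF channel
targets[valid & (y != 0), 1] = 1.0                 # ALT channel

logits = torch.randn(n_samples, n_loci, 2)
pos_weight = torch.tensor([1.0, 1.5])              # e.g. from _compute_pos_weights

bce = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight, reduction="none"
)
mask = valid.unsqueeze(-1).float()                 # zero out missing genotypes
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
print(float(loss))
```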