pg-sui 0.2.0__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +101 -79
  2. pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.0.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +909 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1424 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1118 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1228 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.0.dist-info/RECORD +0 -75
  83. pg_sui-0.2.0.dist-info/top_level.txt +0 -3
  84. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  85. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  88. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  89. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  90. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  93. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  94. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  95. pgsui/example_data/trees/test.iqtree +0 -376
  96. pgsui/example_data/trees/test.qmat +0 -5
  97. pgsui/example_data/trees/test.rate +0 -2033
  98. pgsui/example_data/trees/test.tre +0 -1
  99. pgsui/example_data/trees/test_n10.rate +0 -19
  100. pgsui/example_data/trees/test_n100.rate +0 -109
  101. pgsui/example_data/trees/test_n500.rate +0 -509
  102. pgsui/example_data/trees/test_siterates.txt +0 -2024
  103. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  104. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  105. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  106. pgsui/example_data/vcf_files/test.vcf +0 -244
  107. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  108. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  109. pgsui/impute/estimators.py +0 -1268
  110. pgsui/impute/impute.py +0 -1463
  111. pgsui/impute/simple_imputers.py +0 -1431
  112. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  113. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  114. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  115. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  116. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  117. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  118. pgsui/pg_sui.py +0 -261
  119. pgsui/utils/sequence_tools.py +0 -407
  120. simulation/sim_benchmarks.py +0 -333
  121. simulation/sim_treeparams.py +0 -475
  122. test/__init__.py +0 -0
  123. test/pg_sui_simtest.py +0 -215
  124. test/pg_sui_testing.py +0 -523
  125. test/test.py +0 -151
  126. test/test_pgsui.py +0 -374
  127. test/test_tkc.py +0 -185
@@ -0,0 +1,1554 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import optuna
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from sklearn.decomposition import PCA
12
+ from sklearn.exceptions import NotFittedError
13
+ from sklearn.model_selection import train_test_split
14
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
15
+ from snpio.utils.logging import LoggerManager
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR
17
+
18
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
19
+ from pgsui.data_processing.containers import NLPCAConfig
20
+ from pgsui.data_processing.transformers import SimMissingTransformer
21
+ from pgsui.impute.unsupervised.base import BaseNNImputer
22
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
23
+ from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
24
+ from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
25
+ from pgsui.utils.logging_utils import configure_logger
26
+ from pgsui.utils.pretty_metrics import PrettyMetrics
27
+
28
+ if TYPE_CHECKING:
29
+ from snpio import TreeParser
30
+ from snpio.read_input.genotype_data import GenotypeData
31
+
32
+
33
+ def ensure_nlpca_config(config: NLPCAConfig | dict | str | None) -> NLPCAConfig:
34
+ """Return a concrete NLPCAConfig from dataclass, dict, YAML path, or None.
35
+
36
+ Args:
37
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
38
+
39
+ Returns:
40
+ NLPCAConfig: Concrete configuration instance.
41
+ """
42
+ if config is None:
43
+ return NLPCAConfig()
44
+ if isinstance(config, NLPCAConfig):
45
+ return config
46
+ if isinstance(config, str):
47
+ # YAML path — top-level `preset` key is supported
48
+ return load_yaml_to_dataclass(config, NLPCAConfig)
49
+ if isinstance(config, dict):
50
+ # Flatten dict into dot-keys then overlay onto a fresh instance
51
+ base = NLPCAConfig()
52
+
53
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
54
+ for k, v in d.items():
55
+ kk = f"{prefix}.{k}" if prefix else k
56
+ if isinstance(v, dict):
57
+ _flatten(kk, v, out)
58
+ else:
59
+ out[kk] = v
60
+ return out
61
+
62
+ # Lift any present preset first
63
+ preset_name = config.pop("preset", None)
64
+ if "io" in config and isinstance(config["io"], dict):
65
+ preset_name = preset_name or config["io"].pop("preset", None)
66
+
67
+ if preset_name:
68
+ base = NLPCAConfig.from_preset(preset_name)
69
+
70
+ flat = _flatten("", config, {})
71
+ return apply_dot_overrides(base, flat)
72
+
73
+ raise TypeError("config must be an NLPCAConfig, dict, YAML path, or None.")
74
+
75
+
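For illustration, a minimal sketch of the accepted configuration forms (the preset name and override values below are hypothetical, not values shipped with the package):

    cfg_default = ensure_nlpca_config(None)                    # library defaults
    cfg_yaml = ensure_nlpca_config("nlpca_config.yaml")        # YAML path; may carry a top-level `preset`
    cfg_dict = ensure_nlpca_config(
        {"preset": "balanced", "model": {"latent_dim": 4}, "io": {"prefix": "run1"}}
    )  # any preset is applied first, then the nested keys act as dot-key overrides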
76
+ class ImputeNLPCA(BaseNNImputer):
77
+ """Imputes missing genotypes using a Non-linear Principal Component Analysis (NLPCA) model.
78
+
79
+ This class implements an imputer based on Non-linear Principal Component Analysis (NLPCA) using a neural network architecture. It is designed to handle genotype data encoded in 0/1/2 format, where 0 represents the homozygous reference genotype, 1 the heterozygous genotype, and 2 the homozygous alternate genotype. Missing genotypes should be represented as -9 or -1.
80
+
81
+ The NLPCA model learns a low-dimensional latent representation of the genotype data by jointly optimizing per-sample latent vectors and a decoder network; no encoder is used, and the latent codes are fit directly by backpropagation. The model is trained with a focal cross-entropy loss to address class imbalance, and it can apply L1 regularization to the decoder weights to promote sparsity.
82
+
83
+ Notes:
84
+ - Supports both haploid and diploid genotype data.
85
+ - Configurable model architecture with options for latent dimension, dropout rate, number of hidden layers, and activation functions.
86
+ - Hyperparameter tuning using Optuna for optimal model performance.
87
+ - Evaluation metrics including accuracy, F1-score, precision, recall, and ROC-AUC.
88
+ - Visualization of training history and genotype distributions.
89
+ - Flexible configuration via dataclass, dictionary, or YAML file.
90
+
91
+ Example:
92
+ >>> from snpio import VCFReader
93
+ >>> from pgsui import ImputeNLPCA
94
+ >>> gdata = VCFReader("genotypes.vcf.gz")
95
+ >>> imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml")
96
+ >>> imputer.fit()
97
+ >>> imputed_genotypes = imputer.transform()
98
+ >>> print(imputed_genotypes)
99
+ [['A' 'G' 'C' ...],
100
+ ['G' 'G' 'C' ...],
101
+ ...
102
+ ['T' 'C' 'A' ...],
103
+ ['C' 'C' 'C' ...]]
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ genotype_data: "GenotypeData",
109
+ *,
110
+ tree_parser: Optional["TreeParser"] = None,
111
+ config: NLPCAConfig | dict | str | None = None,
112
+ overrides: dict | None = None,
113
+ simulate_missing: bool = False,
114
+ sim_strategy: Literal[
115
+ "random",
116
+ "random_weighted",
117
+ "random_weighted_inv",
118
+ "nonrandom",
119
+ "nonrandom_weighted",
120
+ ] = "random",
121
+ sim_prop: float = 0.10,
122
+ sim_kwargs: dict | None = None,
123
+ ):
124
+ """Initializes the ImputeNLPCA imputer with genotype data and configuration.
125
+
126
+ This constructor sets up the ImputeNLPCA imputer by accepting genotype data and a configuration that can be provided in various formats. It initializes logging, device settings, and model parameters based on the provided configuration.
127
+
128
+ Args:
129
+ genotype_data (GenotypeData): Backing genotype data.
130
+ tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for population-specific modes.
131
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
132
+ overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
133
+ simulate_missing (bool): Whether to simulate missing data during training.
134
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"]): Strategy for simulating missing data.
135
+ sim_prop (float): Proportion of data to simulate as missing.
136
+ sim_kwargs (dict | None): Additional keyword arguments for missing data simulation.
137
+ """
138
+ self.model_name = "ImputeNLPCA"
139
+ self.genotype_data = genotype_data
140
+ self.tree_parser = tree_parser
141
+
142
+ # Normalize config first, then apply overrides (highest precedence)
143
+ cfg = ensure_nlpca_config(config)
144
+ if overrides:
145
+ cfg = apply_dot_overrides(cfg, overrides)
146
+
147
+ self.cfg = cfg
148
+
149
+ logman = LoggerManager(
150
+ __name__,
151
+ prefix=self.cfg.io.prefix,
152
+ debug=self.cfg.io.debug,
153
+ verbose=self.cfg.io.verbose,
154
+ )
155
+ self.logger = configure_logger(
156
+ logman.get_logger(),
157
+ verbose=self.cfg.io.verbose,
158
+ debug=self.cfg.io.debug,
159
+ )
160
+
161
+ # Initialize BaseNNImputer with device/dirs/logging from config
162
+ super().__init__(
163
+ model_name=self.model_name,
164
+ genotype_data=self.genotype_data,
165
+ prefix=self.cfg.io.prefix,
166
+ device=self.cfg.train.device,
167
+ verbose=self.cfg.io.verbose,
168
+ debug=self.cfg.io.debug,
169
+ )
170
+
171
+ self.Model = NLPCAModel
172
+ self.pgenc = GenotypeEncoder(genotype_data)
173
+ self.seed = self.cfg.io.seed
174
+ self.n_jobs = self.cfg.io.n_jobs
175
+ self.prefix = self.cfg.io.prefix
176
+ self.scoring_averaging = self.cfg.io.scoring_averaging
177
+ self.verbose = self.cfg.io.verbose
178
+ self.debug = self.cfg.io.debug
179
+
180
+ self.rng = np.random.default_rng(self.seed)
181
+
182
+ # Model/train hyperparams
183
+ self.latent_dim = self.cfg.model.latent_dim
184
+ self.dropout_rate = self.cfg.model.dropout_rate
185
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
186
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
187
+ self.layer_schedule = self.cfg.model.layer_schedule
188
+ self.latent_init: Literal["random", "pca"] = self.cfg.model.latent_init
189
+ self.activation = self.cfg.model.hidden_activation
190
+ self.gamma = self.cfg.model.gamma
191
+
192
+ self.batch_size = self.cfg.train.batch_size
193
+ self.learning_rate: float = self.cfg.train.learning_rate
194
+ self.lr_input_factor = self.cfg.train.lr_input_factor
195
+ self.l1_penalty = self.cfg.train.l1_penalty
196
+ self.early_stop_gen = self.cfg.train.early_stop_gen
197
+ self.min_epochs = self.cfg.train.min_epochs
198
+ self.epochs = self.cfg.train.max_epochs
199
+ self.validation_split = self.cfg.train.validation_split
200
+ self.beta = self.cfg.train.weights_beta
201
+ self.max_ratio = self.cfg.train.weights_max_ratio
202
+
203
+ # Tuning
204
+ self.tune = self.cfg.tune.enabled
205
+ self.tune_fast = self.cfg.tune.fast
206
+ self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
207
+ self.tune_batch_size = self.cfg.tune.batch_size
208
+ self.tune_epochs = self.cfg.tune.epochs
209
+ self.tune_eval_interval = self.cfg.tune.eval_interval
210
+ self.tune_metric: Literal[
211
+ "pr_macro",
212
+ "f1",
213
+ "accuracy",
214
+ "average_precision",
215
+ "precision",
216
+ "recall",
217
+ "roc_auc",
218
+ ] = self.cfg.tune.metric
219
+ self.n_trials = self.cfg.tune.n_trials
220
+ self.tune_save_db = self.cfg.tune.save_db
221
+ self.tune_resume = self.cfg.tune.resume
222
+ self.tune_max_samples = self.cfg.tune.max_samples
223
+ self.tune_max_loci = self.cfg.tune.max_loci
224
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
225
+ self.tune_patience = self.cfg.tune.patience
226
+
227
+ # Eval
228
+ self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
229
+ self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
230
+ self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
231
+
232
+ # Plotting (NOTE: PlotConfig has 'show', not 'show_plots')
233
+ self.plot_format = self.cfg.plot.fmt
234
+ self.plot_dpi = self.cfg.plot.dpi
235
+ self.plot_fontsize = self.cfg.plot.fontsize
236
+ self.title_fontsize = self.cfg.plot.fontsize
237
+ self.despine = self.cfg.plot.despine
238
+ self.show_plots = self.cfg.plot.show
239
+
240
+ # Core model config
241
+ self.is_haploid = False
242
+ self.num_classes_ = 3
243
+ self.model_params: Dict[str, Any] = {}
244
+
245
+ self.simulate_missing = simulate_missing
246
+ self.sim_strategy = sim_strategy
247
+ self.sim_prop = float(sim_prop)
248
+ self.sim_kwargs = sim_kwargs or {}
249
+
250
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
251
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
252
+ self.logger.error(msg)
253
+ raise ValueError(msg)
254
+
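As a usage sketch (assuming `gdata` is a SNPio GenotypeData object and `tparser` a TreeParser; both names are placeholders), construction with simulated missing data and dot-key overrides might look like:

    imputer = ImputeNLPCA(
        gdata,
        tree_parser=tparser,                  # required only for the "nonrandom*" strategies
        config="nlpca_config.yaml",           # or an NLPCAConfig, a nested dict, or None
        overrides={"model.latent_dim": 4, "train.batch_size": 64},
        simulate_missing=True,
        sim_strategy="nonrandom_weighted",
        sim_prop=0.10,
    )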
255
+ def fit(self) -> "ImputeNLPCA":
256
+ """Fits the NLPCA model to the 0/1/2 encoded genotype data.
257
+
258
+ This method prepares the data, splits it into training and validation sets, initializes the model, and trains it. If hyperparameter tuning is enabled, it will perform tuning before final training. After training, it evaluates the model on a test set and generates relevant plots.
259
+
260
+ Returns:
261
+ ImputeNLPCA: The fitted imputer instance.
262
+ """
263
+ self.logger.info(f"Fitting {self.model_name} model...")
264
+
265
+ # --- BASE MATRIX AND GROUND TRUTH ---
266
+ X012 = self.pgenc.genotypes_012.astype(np.float32)
267
+ X012[X012 < 0] = np.nan # NaN = original missing
268
+
269
+ # Keep an immutable ground-truth copy in 0/1/2 with -1 for original
270
+ # missing
271
+ GT_full = X012.copy()
272
+ GT_full[np.isnan(GT_full)] = -1
273
+ self.ground_truth_ = GT_full.astype(np.int64)
274
+
275
+ # --- OPTIONAL SIMULATED MISSING VIA SimMissingTransformer ---
276
+ self.sim_mask_global_ = None
277
+ if self.simulate_missing:
278
+ tr = SimMissingTransformer(
279
+ genotype_data=self.genotype_data,
280
+ tree_parser=self.tree_parser,
281
+ prop_missing=self.sim_prop,
282
+ strategy=self.sim_strategy,
283
+ missing_val=-9,
284
+ mask_missing=True,
285
+ verbose=self.verbose,
286
+ tol=None,
287
+ max_tries=None,
288
+ )
289
+ # NOTE: pass NaN-coded missing; transformer handles NaNs correctly
290
+ tr.fit(X012.copy())
291
+
292
+ # Store boolean mask of simulated positions only (excludes original-missing)
293
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
294
+
295
+ # Apply simulation to the model’s input copy: encode as -1 for loss
296
+ X_for_model = self.ground_truth_.copy()
297
+ X_for_model[self.sim_mask_global_] = -1
298
+ else:
299
+ X_for_model = self.ground_truth_.copy()
300
+
301
+ # --- Determine Ploidy and Number of Classes ---
302
+ self.is_haploid = bool(
303
+ np.all(
304
+ np.isin(
305
+ self.genotype_data.snp_data,
306
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
307
+ )
308
+ )
309
+ )
310
+
311
+ self.ploidy = 1 if self.is_haploid else 2
312
+
313
+ if self.is_haploid:
314
+ self.num_classes_ = 2
315
+
316
+ # Remap labels from {0, 2} to {0, 1}
317
+ self.ground_truth_[self.ground_truth_ == 2] = 1
318
+ X_for_model[X_for_model == 2] = 1 # keep the model input consistent with the remapped ground truth
319
+ self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
320
+ else:
321
+ self.num_classes_ = 3
322
+
323
+ self.logger.info(
324
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
325
+ )
326
+
327
+ n_samples, self.num_features_ = X_for_model.shape
328
+
329
+ self.model_params = {
330
+ "n_features": self.num_features_,
331
+ "latent_dim": self.latent_dim,
332
+ "dropout_rate": self.dropout_rate,
333
+ "activation": self.activation,
334
+ "gamma": self.gamma,
335
+ "num_classes": self.num_classes_,
336
+ }
337
+
338
+ # --- Train/Test Split ---
339
+ indices = np.arange(n_samples)
340
+ train_idx, test_idx = train_test_split(
341
+ indices, test_size=self.validation_split, random_state=self.seed
342
+ )
343
+ self.train_idx_, self.test_idx_ = train_idx, test_idx
344
+ # Subset matrices for training/eval
345
+ self.X_train_ = X_for_model[train_idx]
346
+ self.X_test_ = X_for_model[test_idx]
347
+ self.GT_train_full_ = self.ground_truth_[train_idx] # pre-mask truth
348
+ self.GT_test_full_ = self.ground_truth_[test_idx]
349
+
350
+ # Slice the simulation mask by split if present
351
+ if self.sim_mask_global_ is not None:
352
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
353
+ self.sim_mask_test_ = self.sim_mask_global_[test_idx]
354
+ else:
355
+ self.sim_mask_train_ = None
356
+ self.sim_mask_test_ = None
357
+
358
+ # Tuning, model setup, training (unchanged except DataLoader input)
359
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
360
+
361
+ if self.tune:
362
+ self.tune_hyperparameters()
363
+ self.best_params_ = getattr(self, "best_params_", self.model_params.copy())
364
+ else:
365
+ self.best_params_ = self._set_best_params_default()
366
+
367
+ # Class weights from 0/1/2 training data
368
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
369
+
370
+ if self.latent_init not in {"random", "pca"}:
373
+ msg = (
374
+ f"Invalid latent_init '{self.latent_init}'; must be 'random' or 'pca'."
375
+ )
376
+ self.logger.error(msg)
377
+ raise ValueError(msg)
378
+
379
+ li: Literal["random", "pca"] = self.latent_init
380
+
381
+ # Latent vectors for training set
383
+ train_latent_vectors = self._create_latent_space(
384
+ self.best_params_, len(self.X_train_), self.X_train_, li
385
+ )
386
+ train_loader = self._get_data_loaders(self.X_train_)
387
+
388
+ # Train the final model
389
+ (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
390
+ self._train_final_model(
391
+ train_loader, self.best_params_, train_latent_vectors
392
+ )
393
+ )
394
+
395
+ self.is_fit_ = True
396
+ self.plotter_.plot_history(self.history_)
397
+
398
+ if self.sim_mask_test_ is not None:
399
+ # Evaluate exactly on simulated-missing sites
400
+ self.logger.info("Evaluating on simulated-missing positions only.")
401
+ self._evaluate_model(
402
+ self.X_test_,
403
+ self.model_,
404
+ self.best_params_,
405
+ eval_mask_override=self.sim_mask_test_,
406
+ )
407
+ else:
408
+ self._evaluate_model(self.X_test_, self.model_, self.best_params_)
409
+
410
+ self._save_best_params(self.best_params_)
411
+
412
+ return self
413
+
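A small NumPy sketch (illustrative values only) of the encoding convention used in fit(): genotypes are 0/1/2, originally missing cells become -1 in the ground truth, and simulated-missing cells are additionally set to -1 in the model input while a boolean mask remembers where they were:

    import numpy as np

    X012 = np.array([[0.0, 1.0, np.nan], [2.0, np.nan, 1.0]], dtype=np.float32)
    gt = X012.copy()
    gt[np.isnan(gt)] = -1                      # -1 marks originally missing genotypes
    gt = gt.astype(np.int64)

    sim_mask = np.array([[False, True, False], [False, False, True]])  # hypothetical simulated cells
    X_for_model = gt.copy()
    X_for_model[sim_mask] = -1                 # hidden from the loss; recoverable through the mask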
414
+ def transform(self) -> np.ndarray:
415
+ """Imputes missing genotypes using the trained model.
416
+
417
+ This method uses the trained NLPCA model to impute missing genotypes in the entire dataset. It optimizes latent vectors for all samples, predicts missing values, and fills them in. The imputed genotypes are returned in IUPAC string format.
418
+
419
+ Returns:
420
+ np.ndarray: Imputed genotypes in IUPAC string format.
421
+
422
+ Raises:
423
+ NotFittedError: If the model has not been fitted.
424
+ """
425
+ if not getattr(self, "is_fit_", False):
426
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
427
+
428
+ self.logger.info(f"Imputing entire dataset with {self.model_name}...")
429
+ X_to_impute = self.ground_truth_.copy()
430
+
431
+ # Optimize latents for the full dataset
432
+ optimized_latents = self._optimize_latents_for_inference(
433
+ X_to_impute, self.model_, self.best_params_
434
+ )
435
+
436
+ # Predict missing values
437
+ pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
438
+
439
+ # Fill in missing values
440
+ missing_mask = X_to_impute == -1
441
+ imputed_array = X_to_impute.copy()
442
+ imputed_array[missing_mask] = pred_labels[missing_mask]
443
+
444
+ # Decode back to IUPAC strings
445
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
446
+ if self.show_plots:
447
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
448
+ plt.rcParams.update(self.plotter_.param_dict) # Ensure consistent style
449
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
450
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
451
+
452
+ return imputed_genotypes
453
+
454
+ def _train_step(
455
+ self,
456
+ loader: torch.utils.data.DataLoader,
457
+ optimizer: torch.optim.Optimizer,
458
+ latent_optimizer: torch.optim.Optimizer,
459
+ model: torch.nn.Module,
460
+ l1_penalty: float,
461
+ latent_vectors: torch.nn.Parameter,
462
+ class_weights: torch.Tensor,
463
+ ) -> Tuple[float, torch.nn.Parameter]:
464
+ """One epoch with stable focal CE, latent+weight updates, and NaN guards.
465
+
466
+ Args:
467
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
468
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
469
+ latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
470
+ model (torch.nn.Module): NLPCA model.
471
+ l1_penalty (float): L1 regularization penalty.
472
+ latent_vectors (torch.nn.Parameter): Latent vectors for samples.
473
+ class_weights (torch.Tensor): Class weights for focal loss.
474
+
475
+ Returns:
476
+ Tuple[float, torch.nn.Parameter]: Average loss and updated latent vectors.
477
+
478
+ Notes:
479
+ - Implements focal cross-entropy loss with class weights.
480
+ - Applies L1 regularization on model weights.
481
+ - Includes guards against NaN/infinite values in logits, loss, and gradients.
482
+ """
483
+ model.train()
484
+ running = 0.0
485
+ used = 0
486
+
487
+ # Ensure latent vectors are trainable
488
+ if not isinstance(latent_vectors, torch.nn.Parameter):
489
+ latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)
490
+
491
+ # Bound gamma to a sane range
492
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
493
+ gamma = max(0.0, min(gamma, 10.0))
494
+
495
+ # Normalize class weights to mean≈1 to keep loss scale stable
496
+ if class_weights is not None:
497
+ cw = class_weights.to(self.device)
498
+ cw = cw / cw.mean().clamp_min(1e-8)
499
+ else:
500
+ cw = None
501
+
502
+ nF = getattr(model, "n_features", self.num_features_)
503
+
504
+ criterion = SafeFocalCELoss(gamma=gamma, weight=cw, ignore_index=-1)
505
+
506
+ for batch_indices, y_batch in loader:
507
+ optimizer.zero_grad(set_to_none=True)
508
+ latent_optimizer.zero_grad(set_to_none=True)
509
+
510
+ # Targets
511
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
512
+
513
+ decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
514
+
515
+ if not isinstance(decoder, torch.nn.Module):
516
+ msg = "Model decoder is not a valid torch.nn.Module."
517
+ self.logger.error(msg)
518
+ raise TypeError(msg)
519
+
520
+ # Forward
521
+ z = latent_vectors[batch_indices].to(self.device)
522
+ logits = decoder(z).view(len(batch_indices), nF, self.num_classes_)
523
+
524
+ # Guard upstream explosions
525
+ if not torch.isfinite(logits).all():
526
+ # Skip batch if model already produced non-finite values
527
+ continue
528
+
529
+ logits_flat = logits.view(-1, self.num_classes_)
530
+ targets_flat = y_batch.view(-1)
531
+
532
+ loss = criterion(logits_flat, targets_flat)
533
+
534
+ # L1 on model weights only (exclude latents)
535
+ if l1_penalty > 0:
536
+ l1 = torch.stack(
537
+ [p.abs().sum() for p in model.parameters() if p.requires_grad]
538
+ ).sum()
539
+ loss = loss + l1_penalty * l1
540
+
541
+ if not torch.isfinite(loss):
542
+ # Skip pathological batch
543
+ continue
544
+
545
+ loss.backward()
546
+
547
+ # Clip both parameter sets to keep grads finite
548
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
549
+ torch.nn.utils.clip_grad_norm_([latent_vectors], max_norm=1.0)
550
+
551
+ # If any grad is non-finite, skip updates
552
+ bad_grad = False
553
+ for p in model.parameters():
554
+ if p.grad is not None and not torch.isfinite(p.grad).all():
555
+ bad_grad = True
556
+ break
557
+ if (
558
+ not bad_grad
559
+ and latent_vectors.grad is not None
560
+ and not torch.isfinite(latent_vectors.grad).all()
561
+ ):
562
+ bad_grad = True
563
+ if bad_grad:
564
+ optimizer.zero_grad(set_to_none=True)
565
+ latent_optimizer.zero_grad(set_to_none=True)
566
+ continue
567
+
568
+ optimizer.step()
569
+ latent_optimizer.step()
570
+
571
+ running += float(loss.detach().item())
572
+ used += 1
573
+
574
+ if used == 0:
575
+ # Signal upstream that no safe batches were used
576
+ return float("inf"), latent_vectors
577
+
578
+ return running / used, latent_vectors
579
+
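For reference, a generic focal cross-entropy in the spirit of the loss used here; this is an illustrative stand-in, not the package's SafeFocalCELoss:

    import torch
    import torch.nn.functional as F

    def focal_ce(logits, targets, gamma=2.0, weight=None, ignore_index=-1):
        # Skip ignored targets (e.g. -1 for missing genotypes).
        valid = targets != ignore_index
        if not valid.any():
            return logits.new_tensor(0.0)
        logits, targets = logits[valid], targets[valid]
        logp = F.log_softmax(logits, dim=-1)
        logp_t = logp.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
        focal = (1.0 - logp_t.exp()) ** gamma * (-logp_t)
        if weight is not None:
            focal = focal * weight[targets]                       # per-class weighting
        return focal.mean()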
580
+ def _predict(
581
+ self, model: torch.nn.Module, latent_vectors: torch.Tensor | None = None
582
+ ) -> Tuple[np.ndarray, np.ndarray]:
583
+ """Generates 0/1/2 predictions from latent vectors.
584
+
585
+ This method uses the trained NLPCA model to generate predictions from the latent vectors by passing them through the decoder. It returns both the predicted labels and their associated probabilities.
586
+
587
+ Args:
588
+ model (torch.nn.Module): Trained NLPCA model.
589
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
590
+
591
+ Returns:
592
+ Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
593
+ """
594
+ if model is None or latent_vectors is None:
595
+ raise NotFittedError("Model or latent vectors not available.")
596
+
597
+ model.eval()
598
+
599
+ nF = getattr(model, "n_features", self.num_features_)
600
+
601
+ if not isinstance(model.phase23_decoder, torch.nn.Module):
602
+ msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
603
+ self.logger.error(msg)
604
+ raise TypeError(msg)
605
+
606
+ with torch.no_grad():
607
+ logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
608
+ len(latent_vectors), nF, self.num_classes_
609
+ )
610
+ probas = torch.softmax(logits, dim=-1)
611
+ labels = torch.argmax(probas, dim=-1)
612
+
613
+ return labels.cpu().numpy(), probas.cpu().numpy()
614
+
615
+ def _evaluate_model(
616
+ self,
617
+ X_val: np.ndarray,
618
+ model: torch.nn.Module,
619
+ params: dict,
620
+ objective_mode: bool = False,
621
+ latent_vectors_val: torch.Tensor | None = None,
622
+ *,
623
+ eval_mask_override: np.ndarray | None = None,
624
+ ) -> Dict[str, float]:
625
+ """Evaluates the model on a validation set.
626
+
627
+ This method evaluates the trained NLPCA model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
628
+
629
+ Args:
630
+ X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
631
+ model (torch.nn.Module): Trained NLPCA model.
632
+ params (dict): Model parameters.
633
+ objective_mode (bool): If True, suppresses logging and reports only the metric.
634
+ latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
635
+ eval_mask_override (np.ndarray | None): Boolean mask to specify which entries to evaluate.
636
+
637
+ Returns:
638
+ Dict[str, float]: Dictionary of evaluation metrics.
639
+ """
640
+ if latent_vectors_val is not None:
641
+ test_latent_vectors = latent_vectors_val
642
+ else:
643
+ test_latent_vectors = self._optimize_latents_for_inference(
644
+ X_val, model, params
645
+ )
646
+
647
+ pred_labels, pred_probas = self._predict(
648
+ model=model, latent_vectors=test_latent_vectors
649
+ )
650
+
651
+ if eval_mask_override is not None:
652
+ # Validate row counts to allow feature subsetting during tuning
653
+ if eval_mask_override.shape[0] != X_val.shape[0]:
654
+ msg = (
655
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
656
+ f"does not match X_val rows {X_val.shape[0]}"
657
+ )
658
+ self.logger.error(msg)
659
+ raise ValueError(msg)
660
+
661
+ # Slice mask columns if override is wider than current X_val (tune_fast)
662
+ if eval_mask_override.shape[1] > X_val.shape[1]:
663
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
664
+ else:
665
+ eval_mask = eval_mask_override.astype(bool)
666
+ else:
667
+ # Default: score only observed entries
668
+ eval_mask = X_val != -1
669
+
670
+ # y_true should be drawn from the pre-mask ground truth
671
+ # Map X_val back to the correct full ground truth slice
672
+ # FIX: Check shape[0] (n_samples) only.
673
+ if X_val.shape[0] == self.X_test_.shape[0]:
674
+ GT_ref = self.GT_test_full_
675
+ elif X_val.shape[0] == self.X_train_.shape[0]:
676
+ GT_ref = self.GT_train_full_
677
+ else:
678
+ GT_ref = self.ground_truth_
679
+
680
+ # FIX: Slice Ground Truth columns if it is wider than X_val (tune_fast)
681
+ if GT_ref.shape[1] > X_val.shape[1]:
682
+ GT_ref = GT_ref[:, : X_val.shape[1]]
683
+
684
+ # Fallback safeguard
685
+ if GT_ref.shape != X_val.shape:
686
+ GT_ref = X_val
687
+
688
+ y_true_flat = GT_ref[eval_mask]
689
+ pred_labels_flat = pred_labels[eval_mask]
690
+ pred_probas_flat = pred_probas[eval_mask]
691
+
692
+ if y_true_flat.size == 0:
693
+ return {self.tune_metric: 0.0}
694
+
695
+ # For haploids, remap class 2 to 1 for scoring (e.g., f1-score)
696
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
697
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
698
+
699
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
700
+
701
+ metrics = self.scorers_.evaluate(
702
+ y_true_flat,
703
+ pred_labels_flat,
704
+ y_true_ohe,
705
+ pred_probas_flat,
706
+ objective_mode,
707
+ self.tune_metric,
708
+ )
709
+
710
+ if not objective_mode:
711
+ pm = PrettyMetrics(
712
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
713
+ )
714
+ pm.render() # prints a command-line table
715
+
716
+ self._make_class_reports(
717
+ y_true=y_true_flat,
718
+ y_pred_proba=pred_probas_flat,
719
+ y_pred=pred_labels_flat,
720
+ metrics=metrics,
721
+ labels=target_names,
722
+ )
723
+
724
+ # FIX: Use X_val dimensions for reshaping, not self.num_features_
725
+ y_true_dec = self.pgenc.decode_012(
726
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
727
+ )
728
+
729
+ X_pred = X_val.copy()
730
+ X_pred[eval_mask] = pred_labels_flat
731
+
732
+ y_pred_dec = self.pgenc.decode_012(
733
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
734
+ )
735
+
736
+ encodings_dict = {
737
+ "A": 0,
738
+ "C": 1,
739
+ "G": 2,
740
+ "T": 3,
741
+ "W": 4,
742
+ "R": 5,
743
+ "M": 6,
744
+ "K": 7,
745
+ "Y": 8,
746
+ "S": 9,
747
+ "N": -1,
748
+ }
749
+
750
+ y_true_int = self.pgenc.convert_int_iupac(
751
+ y_true_dec, encodings_dict=encodings_dict
752
+ )
753
+ y_pred_int = self.pgenc.convert_int_iupac(
754
+ y_pred_dec, encodings_dict=encodings_dict
755
+ )
756
+
757
+ # For IUPAC report
758
+ valid_true = y_true_int[eval_mask]
759
+ valid_true = valid_true[valid_true >= 0] # drop -1 (N)
760
+ iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
761
+
762
+ # For numeric report
763
+ if (
764
+ np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
765
+ or valid_true.size == 0
766
+ ):
767
+ if not objective_mode:
768
+ self.logger.warning(
769
+ "Skipped numeric confusion matrix: no y_true labels present."
770
+ )
771
+ else:
772
+ self._make_class_reports(
773
+ y_true=valid_true,
774
+ y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
775
+ metrics=metrics,
776
+ y_pred_proba=None,
777
+ labels=iupac_label_set,
778
+ )
779
+
780
+ return metrics
781
+
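Conceptually, the simulated-missing evaluation reduces to scoring predictions only at the cells the transformer hid. A self-contained sketch with made-up arrays:

    import numpy as np
    from sklearn.metrics import f1_score

    GT_test = np.array([[0, 1, 2], [2, 0, 1]])             # hypothetical ground truth (0/1/2)
    pred_labels = np.array([[0, 1, 1], [2, 0, 1]])         # hypothetical model predictions
    sim_mask = np.array([[False, True, True], [True, False, False]])

    y_true = GT_test[sim_mask]                              # truth at cells hidden during training
    y_pred = pred_labels[sim_mask]
    print(f1_score(y_true, y_pred, average="macro"))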
782
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
783
+ """Creates a PyTorch DataLoader for the 0/1/2 encoded data.
784
+
785
+ This method converts the provided genotype data, expected in 0/1/2 encoding with -1 for missing values, into a PyTorch tensor, wraps it together with sample indices in a TensorDataset, and returns a DataLoader that batches and shuffles the data during training using the configured batch size.
786
+
787
+ Args:
788
+ y (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
789
+
790
+ Returns:
791
+ torch.utils.data.DataLoader: DataLoader for the dataset.
792
+ """
793
+ y_tensor = torch.from_numpy(y).long().to(self.device)
794
+ dataset = torch.utils.data.TensorDataset(
795
+ torch.arange(len(y), device=self.device), y_tensor.to(self.device)
796
+ )
797
+ return torch.utils.data.DataLoader(
798
+ dataset, batch_size=self.batch_size, shuffle=True
799
+ )
800
+
801
+ def _create_latent_space(
802
+ self,
803
+ params: dict,
804
+ n_samples: int,
805
+ X: np.ndarray,
806
+ latent_init: Literal["random", "pca"],
807
+ ) -> torch.nn.Parameter:
808
+ """Initializes the latent space for the NLPCA model.
809
+
810
+ This method initializes the latent space for the NLPCA model based on the specified initialization method. It supports two methods: 'random' initialization using Xavier uniform distribution, and 'pca' initialization which uses PCA to derive initial latent vectors from the data. The latent vectors are returned as a PyTorch Parameter, allowing them to be optimized during training.
811
+
812
+ Args:
813
+ params (dict): Model parameters including 'latent_dim'.
814
+ n_samples (int): Number of samples in the dataset.
815
+ X (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
816
+ latent_init (str): Method to initialize latent space ('random' or 'pca').
817
+
818
+ Returns:
819
+ torch.nn.Parameter: Initialized latent vectors as a PyTorch Parameter.
820
+ """
821
+ latent_dim = int(params["latent_dim"])
822
+
823
+ if latent_init == "pca":
824
+ X_pca = X.astype(np.float32, copy=True)
825
+ # mark missing
826
+ X_pca[X_pca < 0] = np.nan
827
+
828
+ # ---- SAFE column means without warnings ----
829
+ valid_counts = np.sum(~np.isnan(X_pca), axis=0)
830
+ col_sums = np.nansum(X_pca, axis=0)
831
+ col_means = np.divide(
832
+ col_sums,
833
+ valid_counts,
834
+ out=np.zeros_like(col_sums, dtype=np.float32),
835
+ where=valid_counts > 0,
836
+ )
837
+
838
+ # impute NaNs with per-column means
839
+ # (all-NaN cols -> 0.0 by the divide above)
840
+ nan_r, nan_c = np.where(np.isnan(X_pca))
841
+ if nan_r.size:
842
+ X_pca[nan_r, nan_c] = col_means[nan_c]
843
+
844
+ # center columns
845
+ X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
846
+
847
+ # guard: degenerate / all-zero after centering ->
848
+ # fall back to random
849
+ if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
850
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
851
+ torch.nn.init.xavier_uniform_(latents)
852
+ return torch.nn.Parameter(latents, requires_grad=True)
853
+
854
+ # rank-aware component count, at least 1
855
+ try:
856
+ est_rank = np.linalg.matrix_rank(X_pca)
857
+ except Exception:
858
+ est_rank = min(n_samples, X_pca.shape[1])
859
+
860
+ n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
861
+
862
+ # use deterministic SVD to avoid power-iteration warnings
863
+ pca = PCA(
864
+ n_components=n_components, svd_solver="full", random_state=self.seed
865
+ )
866
+ initial = pca.fit_transform(X_pca) # (n_samples, n_components)
867
+
868
+ # pad if latent_dim > n_components
869
+ if n_components < latent_dim:
870
+ pad = self.rng.standard_normal(
871
+ size=(n_samples, latent_dim - n_components)
872
+ )
873
+ initial = np.hstack([initial, pad])
874
+
875
+ # standardize latent dims
876
+ initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
877
+
878
+ latents = torch.from_numpy(initial).float().to(self.device)
879
+ return torch.nn.Parameter(latents, requires_grad=True)
880
+
881
+ # --- Random init path (unchanged) ---
882
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
883
+ torch.nn.init.xavier_uniform_(latents)
884
+ return torch.nn.Parameter(latents, requires_grad=True)
885
+
886
+ def _objective(self, trial: optuna.Trial) -> float:
887
+ """Objective function for hyperparameter tuning with Optuna.
888
+
889
+ This method defines the objective function used by Optuna for hyperparameter tuning of the NLPCA model. It samples a set of hyperparameters, prepares the training and validation data, initializes the model and latent vectors, and trains the model. After training, it evaluates the model on a validation set and returns the value of the specified tuning metric.
890
+
891
+ Args:
892
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
893
+
894
+ Returns:
895
+ float: The value of the tuning metric to be minimized or maximized.
896
+ """
897
+ self._prepare_tuning_artifacts()
898
+ trial_params = self._sample_hyperparameters(trial)
899
+ model_params = trial_params["model_params"]
900
+
901
+ nfeat = self._tune_num_features
902
+ if self.tune and self.tune_fast:
903
+ model_params["n_features"] = nfeat
904
+
905
+ lr = trial_params["lr"]
906
+ l1_penalty = trial_params["l1_penalty"]
907
+ lr_input_fac = trial_params["lr_input_factor"]
908
+
909
+ X_train_trial = self._tune_X_train
910
+ X_test_trial = self._tune_X_test
911
+ class_weights = self._tune_class_weights
912
+ train_loader = self._tune_loader
913
+
914
+ train_latents = self._create_latent_space(
915
+ model_params, len(X_train_trial), X_train_trial, trial_params["latent_init"]
916
+ )
917
+
918
+ model = self.build_model(self.Model, model_params)
919
+ model.n_features = model_params["n_features"]
920
+ model.apply(self.initialize_weights)
921
+
922
+ _, model, __ = self._train_and_validate_model(
923
+ model=model,
924
+ loader=train_loader,
925
+ lr=lr,
926
+ l1_penalty=l1_penalty,
927
+ trial=trial,
928
+ latent_vectors=train_latents,
929
+ lr_input_factor=lr_input_fac,
930
+ class_weights=class_weights,
931
+ X_val=X_test_trial,
932
+ params=model_params,
933
+ prune_metric=self.tune_metric,
934
+ prune_warmup_epochs=5,
935
+ eval_interval=self.tune_eval_interval,
936
+ eval_latent_steps=self.eval_latent_steps,
937
+ eval_latent_lr=self.eval_latent_lr,
938
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
939
+ )
940
+
941
+ # --- simulate-only eval mask for tuning ---
942
+ eval_mask = None
943
+ if (
944
+ self.simulate_missing
945
+ and getattr(self, "sim_mask_global_", None) is not None
946
+ ):
947
+ if hasattr(self, "_tune_test_idx") and self.sim_mask_global_ is not None:
948
+ eval_mask = self.sim_mask_global_[self._tune_test_idx]
949
+ elif getattr(self, "sim_mask_test_", None) is not None:
950
+ eval_mask = self.sim_mask_test_
951
+
952
+ metrics = self._evaluate_model(
953
+ X_test_trial,
954
+ model,
955
+ model_params,
956
+ objective_mode=True,
957
+ eval_mask_override=eval_mask,
958
+ )
959
+
960
+ self._clear_resources(model, train_loader, latent_vectors=train_latents)
961
+ return metrics[self.tune_metric]
962
+
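The objective above is typically driven by an Optuna study; tune_hyperparameters presumably sets this up internally, but under that assumption a generic driver looks like:

    import optuna

    study = optuna.create_study(
        direction="maximize",                            # tune_metric is a score, higher is better
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),
    )
    study.optimize(imputer._objective, n_trials=imputer.n_trials)
    print(study.best_params, study.best_value)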
963
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
964
+ """Samples hyperparameters for the simplified NLPCA model.
965
+
966
+ This method defines the hyperparameter search space for the NLPCA model and samples a set of hyperparameters using the provided Optuna trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
967
+
968
+ Args:
969
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
970
+
971
+ Returns:
972
+ Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
973
+ """
974
+ params = {
975
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
976
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
977
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.05),
978
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 16),
979
+ "activation": trial.suggest_categorical(
980
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
981
+ ),
982
+ "gamma": trial.suggest_float("gamma", 0.1, 5.0, step=0.1),
983
+ "lr_input_factor": trial.suggest_float(
984
+ "lr_input_factor", 0.1, 10.0, log=True
985
+ ),
986
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
987
+ "layer_scaling_factor": trial.suggest_float(
988
+ "layer_scaling_factor", 2.0, 10.0
989
+ ),
990
+ "layer_schedule": trial.suggest_categorical(
991
+ "layer_schedule", ["pyramid", "constant", "linear"]
992
+ ),
993
+ "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
994
+ }
995
+
996
+ use_n_features = (
997
+ self._tune_num_features
998
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
999
+ else self.num_features_
1000
+ )
1001
+ use_n_samples = (
1002
+ len(self._tune_train_idx)
1003
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_train_idx"))
1004
+ else len(self.train_idx_)
1005
+ )
1006
+
1007
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1008
+ n_inputs=params["latent_dim"],
1009
+ n_outputs=use_n_features * self.num_classes_,
1010
+ n_samples=use_n_samples,
1011
+ n_hidden=params["num_hidden_layers"],
1012
+ alpha=params["layer_scaling_factor"],
1013
+ schedule=params["layer_schedule"],
1014
+ )
1015
+
1016
+ params["model_params"] = {
1017
+ "n_features": use_n_features,
1018
+ "num_classes": self.num_classes_,
1019
+ "latent_dim": params["latent_dim"],
1020
+ "dropout_rate": params["dropout_rate"],
1021
+ "hidden_layer_sizes": hidden_layer_sizes,
1022
+ "activation": params["activation"],
1023
+ "gamma": params["gamma"],
1024
+ }
1025
+
1026
+ return params
1027
+
1028
+ def _set_best_params(self, best_params: dict) -> dict:
1029
+ """Sets the best hyperparameters found during tuning.
1030
+
1031
+ This method updates the model's attributes with the best hyperparameters obtained from tuning. It also computes the hidden layer sizes based on these parameters and prepares the final model parameters dictionary.
1032
+
1033
+ Args:
1034
+ best_params (dict): Best hyperparameters from tuning.
1035
+
1036
+ Returns:
1037
+ dict: Model parameters configured with the best hyperparameters.
1038
+ """
1039
+ self.latent_dim = best_params["latent_dim"]
1040
+ self.dropout_rate = best_params["dropout_rate"]
1041
+ self.learning_rate = best_params["learning_rate"]
1042
+ self.gamma = best_params["gamma"]
1043
+ self.lr_input_factor = best_params["lr_input_factor"]
1044
+ self.l1_penalty = best_params["l1_penalty"]
1045
+ self.activation = best_params["activation"]
1046
+
1047
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1048
+ n_inputs=self.latent_dim,
1049
+ n_outputs=self.num_features_ * self.num_classes_,
1050
+ n_samples=len(self.train_idx_),
1051
+ n_hidden=best_params["num_hidden_layers"],
1052
+ alpha=best_params["layer_scaling_factor"],
1053
+ schedule=best_params["layer_schedule"],
1054
+ )
1055
+
1056
+ return {
1057
+ "n_features": self.num_features_,
1058
+ "latent_dim": self.latent_dim,
1059
+ "hidden_layer_sizes": hidden_layer_sizes,
1060
+ "dropout_rate": self.dropout_rate,
1061
+ "activation": self.activation,
1062
+ "gamma": self.gamma,
1063
+ "num_classes": self.num_classes_,
1064
+ }
1065
+
1066
+ def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
1067
+ """Default (no-tuning) model_params aligned with current attributes.
1068
+
1069
+ This method constructs the model parameters dictionary from the current instance attributes of the ImputeNLPCA class. It computes the hidden layer sizes from the instance's latent dimension, number of hidden layers, layer scaling factor, and layer schedule, and returns a dictionary of model parameters that can be used to build the NLPCA model when no hyperparameter tuning has been performed.
1070
+
1071
+ Returns:
1072
+ Dict[str, int | float | str | list]: model_params payload.
1073
+ """
1074
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1075
+ n_inputs=self.latent_dim,
1076
+ n_outputs=self.num_features_ * self.num_classes_,
1077
+ n_samples=len(self.ground_truth_),
1078
+ n_hidden=self.num_hidden_layers,
1079
+ alpha=self.layer_scaling_factor,
1080
+ schedule=self.layer_schedule,
1081
+ )
1082
+
1083
+ return {
1084
+ "n_features": self.num_features_,
1085
+ "latent_dim": self.latent_dim,
1086
+ "hidden_layer_sizes": hidden_layer_sizes,
1087
+ "dropout_rate": self.dropout_rate,
1088
+ "activation": self.activation,
1089
+ "gamma": self.gamma,
1090
+ "num_classes": self.num_classes_,
1091
+ }
1092
+
1093
+ def _train_and_validate_model(
1094
+ self,
1095
+ model: torch.nn.Module,
1096
+ loader: torch.utils.data.DataLoader,
1097
+ lr: float,
1098
+ l1_penalty: float,
1099
+ trial: optuna.Trial | None = None,
1100
+ return_history: bool = False,
1101
+ latent_vectors: torch.nn.Parameter | None = None,
1102
+ lr_input_factor: float = 1.0,
1103
+ class_weights: torch.Tensor | None = None,
1104
+ *,
1105
+ X_val: np.ndarray | None = None,
1106
+ params: dict | None = None,
1107
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
1108
+ prune_warmup_epochs: int = 3,
1109
+ eval_interval: int = 1,
1110
+ eval_latent_steps: int = 50,
1111
+ eval_latent_lr: float = 1e-2,
1112
+ eval_latent_weight_decay: float = 0.0,
1113
+ ) -> Tuple:
1114
+ """Trains and validates the NLPCA model.
1115
+
1116
+ This method trains the provided NLPCA model using the specified training data and hyperparameters. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method initializes optimizers for both the model parameters and latent vectors, sets up a learning rate scheduler, and executes the training loop. It can return the training history if requested.
1117
+
1118
+ Args:
1119
+ model (torch.nn.Module): The NLPCA model to be trained.
1120
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1121
+ lr (float): Learning rate for the model optimizer.
1122
+ l1_penalty (float): L1 regularization penalty.
1123
+ trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
1124
+ return_history (bool): Whether to return training history.
1125
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
1126
+ lr_input_factor (float): Learning rate factor for latent vectors.
1127
+ class_weights (torch.Tensor | None): Class weights for handling class imbalance.
1128
+ X_val (np.ndarray | None): Validation data for pruning.
1129
+ params (dict | None): Model parameters.
1130
+ prune_metric (str | None): Metric for pruning decisions.
1131
+ prune_warmup_epochs (int): Number of epochs before pruning starts.
1132
+ eval_interval (int): Interval (in epochs) for evaluation during training.
1133
+ eval_latent_steps (int): Steps for latent optimization during evaluation.
1134
+ eval_latent_lr (float): Learning rate for latent optimization during evaluation.
1135
+ eval_latent_weight_decay (float): Weight decay for latent optimization during evaluation.
1136
+
1137
+ Returns:
1138
+ Tuple[float, torch.nn.Module, Dict[str, float], torch.nn.Parameter] | Tuple[float, torch.nn.Module, torch.nn.Parameter]: Training loss, trained model, training history (if requested), and optimized latent vectors.
1139
+
1140
+ Raises:
1141
+ TypeError: If latent_vectors or class_weights are not provided.
1142
+ """
1143
+
1144
+ if latent_vectors is None or class_weights is None:
1145
+ msg = "latent_vectors and class_weights must be provided."
1146
+ self.logger.error(msg)
1147
+ raise TypeError("Must provide latent_vectors and class_weights.")
1148
+
1149
+ latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
1150
+
1151
+ decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
1152
+
1153
+ if not isinstance(decoder, torch.nn.Module):
1154
+ msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
1155
+ self.logger.error(msg)
1156
+ raise TypeError(msg)
1157
+
1158
+ optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
1159
+ scheduler = CosineAnnealingLR(optimizer, T_max=self.epochs)
1160
+
1161
+ result = self._execute_training_loop(
1162
+ loader=loader,
1163
+ optimizer=optimizer,
1164
+ latent_optimizer=latent_optimizer,
1165
+ scheduler=scheduler,
1166
+ model=model,
1167
+ l1_penalty=l1_penalty,
1168
+ return_history=return_history,
1169
+ latent_vectors=latent_vectors,
1170
+ class_weights=class_weights,
1171
+ trial=trial,
1172
+ X_val=X_val,
1173
+ params=params,
1174
+ prune_metric=prune_metric,
1175
+ prune_warmup_epochs=prune_warmup_epochs,
1176
+ eval_interval=eval_interval,
1177
+ eval_latent_steps=eval_latent_steps,
1178
+ eval_latent_lr=eval_latent_lr,
1179
+ eval_latent_weight_decay=eval_latent_weight_decay,
1180
+ )
1181
+
1182
+ if return_history:
1183
+ return result
1184
+
1185
+ return result[0], result[1], result[3]
1186
+
1187
+ def _train_final_model(
1188
+ self,
1189
+ loader: torch.utils.data.DataLoader,
1190
+ best_params: dict,
1191
+ initial_latent_vectors: torch.nn.Parameter,
1192
+ ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
1193
+ """Trains the final model using the best hyperparameters.
1194
+
1195
+ This method builds and trains the final NLPCA model using the best hyperparameters obtained from tuning. It initializes the model weights, trains the model on the entire training set, and saves the trained model to disk. It returns the final training loss, trained model, training history, and optimized latent vectors.
1196
+
1197
+ Args:
1198
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1199
+ best_params (dict): Best hyperparameters for the model.
1200
+ initial_latent_vectors (torch.nn.Parameter): Initial latent vectors for samples.
1201
+
1202
+ Returns:
1203
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: Final training loss, trained model, training history, and optimized latent vectors.
1204
+ Raises:
1205
+ RuntimeError: If model training fails.
1206
+ """
1207
+ self.logger.info(f"Training the final {self.model_name} model...")
1208
+
1209
+ model = self.build_model(self.Model, best_params)
1210
+ model.n_features = best_params["n_features"]
1211
+ model.apply(self.initialize_weights)
1212
+
1213
+ loss, trained_model, history, latent_vectors = self._train_and_validate_model(
1214
+ model=model,
1215
+ loader=loader,
1216
+ lr=self.learning_rate,
1217
+ l1_penalty=self.l1_penalty,
1218
+ return_history=True,
1219
+ latent_vectors=initial_latent_vectors,
1220
+ lr_input_factor=self.lr_input_factor,
1221
+ class_weights=self.class_weights_,
1222
+ X_val=self.X_test_,
1223
+ params=best_params,
1224
+ prune_metric=self.tune_metric,
1225
+ prune_warmup_epochs=5,
1226
+ eval_interval=1,
1227
+ eval_latent_steps=self.eval_latent_steps,
1228
+ eval_latent_lr=self.eval_latent_lr,
1229
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
1230
+ )
1231
+
1232
+ if trained_model is None:
1233
+ msg = "Final model training failed."
1234
+ self.logger.error(msg)
1235
+ raise RuntimeError(msg)
1236
+
1237
+ fn = self.models_dir / "final_model.pt"
1238
+ torch.save(trained_model.state_dict(), fn)
1239
+
1240
+ return (loss, trained_model, {"Train": history}, latent_vectors)
1241
+
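The final weights are saved as a state_dict; a reload sketch, assuming the NLPCAModel constructor accepts the keyword arguments stored in best_params (as build_model suggests); best_params here is a placeholder:

    import torch
    from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel

    model = NLPCAModel(**best_params)                    # hypothetical: kwargs mirror build_model's
    state = torch.load("final_model.pt", map_location="cpu")
    model.load_state_dict(state)
    model.eval()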
1242
+ def _execute_training_loop(
1243
+ self,
1244
+ loader,
1245
+ optimizer,
1246
+ latent_optimizer,
1247
+ scheduler, # do not overwrite; honor caller's scheduler
1248
+ model,
1249
+ l1_penalty,
1250
+ return_history,
1251
+ latent_vectors,
1252
+ class_weights,
1253
+ *,
1254
+ trial: optuna.Trial | None = None,
1255
+ X_val: np.ndarray | None = None,
1256
+ params: dict | None = None,
1257
+ prune_metric: str | None = None,
1258
+ prune_warmup_epochs: int = 3,
1259
+ eval_interval: int = 1,
1260
+ eval_latent_steps: int = 50,
1261
+ eval_latent_lr: float = 1e-2,
1262
+ eval_latent_weight_decay: float = 0.0,
1263
+ ) -> Tuple[float, torch.nn.Module, list, torch.nn.Parameter]:
1264
+ """Train NLPCA with warmup, pruning, and early stopping."""
1265
+ best_model = None
1266
+ history: list[float] = []
1267
+
1268
+ early_stopping = EarlyStopping(
1269
+ patience=self.early_stop_gen,
1270
+ min_epochs=self.min_epochs,
1271
+ verbose=self.verbose,
1272
+ prefix=self.prefix,
1273
+ debug=self.debug,
1274
+ )
1275
+
1276
+ # Epoch budget
1277
+ max_epochs = (
1278
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
1279
+ )
1280
+
1281
+ # Optional LR warmup for both optimizers
1282
+ warmup_epochs = getattr(self, "lr_warmup_epochs", 5)
1283
+ model_lr0 = optimizer.param_groups[0]["lr"]
1284
+ latent_lr0 = latent_optimizer.param_groups[0]["lr"]
1285
+ model_lr_min = model_lr0 * 0.1
1286
+ latent_lr_min = latent_lr0 * 0.1
1287
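+        # Each learning rate ramps linearly from a 10% floor of its configured value up to the full value over the warmup epochs.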
+
+        _latent_cache: dict = {}
+        _latent_cache_key = f"{self.prefix}_{self.model_name}_val_latents"
+
+        for epoch in range(max_epochs):
+            # Linear warmup LRs for first few epochs
+            if epoch < warmup_epochs:
+                scale = float(epoch + 1) / warmup_epochs
+                for g in optimizer.param_groups:
+                    g["lr"] = model_lr_min + (model_lr0 - model_lr_min) * scale
+                for g in latent_optimizer.param_groups:
+                    g["lr"] = latent_lr_min + (latent_lr0 - latent_lr_min) * scale
+
+            train_loss, latent_vectors = self._train_step(
+                loader=loader,
+                optimizer=optimizer,
+                latent_optimizer=latent_optimizer,
+                model=model,
+                l1_penalty=l1_penalty,
+                latent_vectors=latent_vectors,
+                class_weights=class_weights,
+            )
+
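+            # A non-finite epoch loss prunes the trial during tuning; otherwise halve both learning rates and retry.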
+            if not np.isfinite(train_loss):
+                if trial:
+                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
+                # Reduce both LRs and continue
+                for g in optimizer.param_groups:
+                    g["lr"] *= 0.5
+                for g in latent_optimizer.param_groups:
+                    g["lr"] *= 0.5
+                continue
+
+            if scheduler is not None:
+                scheduler.step()
+
+            if return_history:
+                history.append(train_loss)
+
+            # Optuna prune on validation metric
+            if (
+                trial is not None
+                and X_val is not None
+                and ((epoch + 1) % eval_interval == 0)
+            ):
+                seed = int(
+                    self.rng.integers(0, 1_000_000) if self.seed is None else self.seed
+                )
+                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
+                do_infer = int(eval_latent_steps) > 0
+                metric_val = self._eval_for_pruning(
+                    model=model,
+                    X_val=X_val,
+                    params=params or getattr(self, "best_params_", {}),
+                    metric=metric_key,
+                    objective_mode=True,
+                    do_latent_infer=do_infer,
+                    latent_steps=eval_latent_steps,
+                    latent_lr=eval_latent_lr,
+                    latent_weight_decay=eval_latent_weight_decay,
+                    latent_seed=seed,
+                    _latent_cache=_latent_cache,
+                    _latent_cache_key=_latent_cache_key,
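+                    # Prefer the held-out test split's simulated-missing mask when X_val matches its shape;
+                    # otherwise fall back to the tuning-split rows of the global mask, or use no override.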
+                    eval_mask_override=(
+                        self.sim_mask_test_
+                        if (
+                            self.simulate_missing
+                            and getattr(self, "sim_mask_test_", None) is not None
+                            and X_val.shape == self.X_test_.shape
+                        )
+                        else (
+                            self.sim_mask_global_[self._tune_test_idx]
+                            if (
+                                self.simulate_missing
+                                and self.sim_mask_global_ is not None
+                                and hasattr(self, "_tune_test_idx")
+                                and X_val.shape[0] == len(self._tune_test_idx)
+                            )
+                            else None
+                        )
+                    ),
+                )
+
+                trial.report(metric_val, step=epoch + 1)
+
+                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
+                    raise optuna.exceptions.TrialPruned(
+                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.4f}"
+                    )
+
+            early_stopping(train_loss, model)
+            if early_stopping.early_stop:
+                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
+                break
+
+        best_loss = early_stopping.best_score
+        best_model = copy.deepcopy(early_stopping.best_model)
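+        # Fall back to the last-epoch weights if early stopping never captured a best snapshot.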
+        if best_model is None:
+            best_model = copy.deepcopy(model)
+        return best_loss, best_model, history, latent_vectors
+
+    def _optimize_latents_for_inference(
+        self,
+        X_new: np.ndarray,
+        model: torch.nn.Module,
+        params: dict,
+        inference_epochs: int = 200,
+    ) -> torch.Tensor:
+        """Refine latents for new data with guards.
+
+        This method optimizes latent vectors for new samples through gradient-based refinement. It initializes the latent space and iteratively updates the latent vectors to minimize the cross-entropy reconstruction loss, with safeguards that stop optimization if non-finite values appear.
+
+        Args:
+            X_new (np.ndarray): New data in 0/1/2 encoding with -1 for missing.
+            model (torch.nn.Module): Trained NLPCA model.
+            params (dict): Model parameters.
+            inference_epochs (int): Number of optimization epochs.
+
+        Returns:
+            torch.Tensor: Optimized latent vectors for the new data.
+        """
+        if self.tune and self.tune_fast:
+            inference_epochs = min(
+                inference_epochs, getattr(self, "tune_infer_epochs", 20)
+            )
+
+        model.eval()
+        nF = getattr(model, "n_features", self.num_features_)
+
+        z = self._create_latent_space(
+            params, len(X_new), X_new, self.latent_init
+        ).requires_grad_(True)
+        opt = torch.optim.AdamW(
+            [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
+        )
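+        # Only the latent vectors `z` are updated; the decoder weights stay fixed because the optimizer tracks `z` alone.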
+
+        X_new = X_new.astype(np.int64, copy=False)
+        X_new[X_new < 0] = -1
+        y = torch.from_numpy(X_new).long().to(self.device)
+
+        for _ in range(inference_epochs):
+            opt.zero_grad(set_to_none=True)
+
+            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
+
+            if not isinstance(decoder, torch.nn.Module):
+                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
+                self.logger.error(msg)
+                raise TypeError(msg)
+
+            logits = decoder(z).view(len(X_new), nF, self.num_classes_)
+
+            if not torch.isfinite(logits).all():
+                break
+
+            loss = F.cross_entropy(
+                logits.view(-1, self.num_classes_),
+                y.view(-1),
+                ignore_index=-1,
+                reduction="mean",
+            )
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+            opt.step()
+
+        return z.detach()
+
+    def _latent_infer_for_eval(
+        self,
+        model: torch.nn.Module,
+        X_val: np.ndarray,
+        *,
+        steps: int,
+        lr: float,
+        weight_decay: float,
+        seed: int,
+        cache: dict | None,
+        cache_key: str | None,
+    ) -> None:
+        """Freeze weights; refine validation latents only (no leakage).
+
+        This method refines latent vectors for validation data while keeping the model weights frozen. It initializes the latent space, optionally reusing cached latent vectors, and iteratively updates the latents to minimize the cross-entropy reconstruction loss. Safeguards stop optimization if non-finite values appear, and the optimized latents can be stored in a cache.
+
+        Args:
+            model (torch.nn.Module): Trained NLPCA model.
+            X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
+            steps (int): Number of optimization steps.
+            lr (float): Learning rate for latent optimization.
+            weight_decay (float): Weight decay for latent optimization.
+            seed (int): Random seed for reproducibility.
+            cache (dict | None): Cache for storing optimized latent vectors.
+            cache_key (str | None): Key for storing/retrieving from cache.
+        """
+        if seed is None:
+            seed = np.random.randint(0, 999_999)
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+        model.eval()
+        nF = getattr(model, "n_features", self.num_features_)
+
+        for p in model.parameters():
+            p.requires_grad_(False)
+
+        X_val = X_val.astype(np.int64, copy=False)
+        X_val[X_val < 0] = -1
+        y = torch.from_numpy(X_val).long().to(self.device)
+
+        latent_dim = self._first_linear_in_features(model)
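+        # Rebuild the cache key from latent dim, locus count, and class count so stale latents from an incompatible configuration are never reused.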
+        cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.num_classes_}"
+
+        if cache is not None and cache_key in cache:
+            z = cache[cache_key].detach().clone().requires_grad_(True)
+        else:
+            z = self._create_latent_space(
+                {"latent_dim": latent_dim},
+                n_samples=X_val.shape[0],
+                X=X_val,
+                latent_init=self.latent_init,
+            ).requires_grad_(True)
+
+        opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
+
+        for _ in range(max(int(steps), 0)):
+            opt.zero_grad(set_to_none=True)
+
+            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
+
+            if not isinstance(decoder, torch.nn.Module):
+                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
+                self.logger.error(msg)
+                raise TypeError(msg)
+
+            logits = decoder(z).view(X_val.shape[0], nF, self.num_classes_)
+
+            if not torch.isfinite(logits).all():
+                break
+
+            loss = F.cross_entropy(
+                logits.view(-1, self.num_classes_),
+                y.view(-1),
+                ignore_index=-1,
+                reduction="mean",
+            )
+
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
+
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+
+            opt.step()
+
+        if cache is not None:
+            cache[cache_key] = z.detach().clone()
+
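+        # Restore gradient tracking on the model weights so downstream training is unaffected.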
+        for p in model.parameters():
+            p.requires_grad_(True)