pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pg-sui might be problematic; see the registry listing for details.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
@@ -0,0 +1,1264 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Tuple
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import optuna
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from sklearn.decomposition import PCA
11
+ from sklearn.exceptions import NotFittedError
12
+ from sklearn.model_selection import train_test_split
13
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
14
+ from snpio.utils.logging import LoggerManager
15
+ from torch.optim.lr_scheduler import CosineAnnealingLR
16
+
17
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
18
+ from pgsui.data_processing.containers import NLPCAConfig
19
+ from pgsui.impute.unsupervised.base import BaseNNImputer
20
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
21
+ from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
22
+
23
+ if TYPE_CHECKING:
24
+ from snpio.read_input.genotype_data import GenotypeData
25
+
26
+
27
+ def ensure_nlpca_config(config: NLPCAConfig | dict | str | None) -> NLPCAConfig:
28
+ """Return a concrete NLPCAConfig from dataclass, dict, YAML path, or None.
29
+
30
+ Args:
31
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
32
+
33
+ Returns:
34
+ NLPCAConfig: Concrete configuration instance.
35
+ """
36
+ if config is None:
37
+ return NLPCAConfig()
38
+ if isinstance(config, NLPCAConfig):
39
+ return config
40
+ if isinstance(config, str):
41
+ # YAML path — top-level `preset` key is supported
42
+ return load_yaml_to_dataclass(
43
+ config,
44
+ NLPCAConfig,
45
+ preset_builder=NLPCAConfig.from_preset,
46
+ )
47
+ if isinstance(config, dict):
48
+ # Flatten dict into dot-keys then overlay onto a fresh instance
49
+ base = NLPCAConfig()
50
+
51
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
52
+ for k, v in d.items():
53
+ kk = f"{prefix}.{k}" if prefix else k
54
+ if isinstance(v, dict):
55
+ _flatten(kk, v, out)
56
+ else:
57
+ out[kk] = v
58
+ return out
59
+
60
+ # Lift any present preset first
61
+ preset_name = config.pop("preset", None)
62
+ if "io" in config and isinstance(config["io"], dict):
63
+ preset_name = preset_name or config["io"].pop("preset", None)
64
+
65
+ if preset_name:
66
+ base = NLPCAConfig.from_preset(preset_name)
67
+
68
+ flat = _flatten("", config, {})
69
+ return apply_dot_overrides(base, flat)
70
+
71
+ raise TypeError("config must be an NLPCAConfig, dict, YAML path, or None.")
72
+
73
+
74
+ class ImputeNLPCA(BaseNNImputer):
75
+ """Imputes missing genotypes using a Non-linear Principal Component Analysis (NLPCA) model.
76
+
77
+ This class implements an imputer based on Non-linear Principal Component Analysis (NLPCA) using a neural network architecture. It is designed to handle genotype data encoded in 0/1/2 format, where 0 represents the reference allele, 1 represents the heterozygous genotype, and 2 represents the alternate allele. Missing genotypes should be represented as -9 or -1.
78
+
79
+ The NLPCA model consists of an encoder-decoder architecture that learns a low-dimensional latent representation of the genotype data. The model is trained using a focal loss function to address class imbalance, and it can incorporate L1 regularization to promote sparsity in the learned representations.
80
+
81
+ Notes:
82
+ - Supports both haploid and diploid genotype data.
83
+ - Configurable model architecture with options for latent dimension, dropout rate, number of hidden layers, and activation functions.
84
+ - Hyperparameter tuning using Optuna for optimal model performance.
85
+ - Evaluation metrics including accuracy, F1-score, precision, recall, and ROC-AUC.
86
+ - Visualization of training history and genotype distributions.
87
+ - Flexible configuration via dataclass, dictionary, or YAML file.
88
+
89
+ Example:
90
+ >>> from snpio import VCFReader
91
+ >>> from pgsui import ImputeNLPCA
92
+ >>> gdata = VCFReader("genotypes.vcf.gz")
93
+ >>> imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml")
94
+ >>> imputer.fit()
95
+ >>> imputed_genotypes = imputer.transform()
96
+ >>> print(imputed_genotypes)
97
+ [['A' 'G' 'C' ...],
98
+ ['G' 'G' 'C' ...],
99
+ ...
100
+ ['T' 'C' 'A' ...],
101
+ ['C' 'C' 'C' ...]]
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ genotype_data: "GenotypeData",
107
+ *,
108
+ config: NLPCAConfig | dict | str | None = None,
109
+ overrides: dict | None = None,
110
+ ):
111
+ """Initializes the ImputeNLPCA imputer with genotype data and configuration.
112
+
113
+ This constructor sets up the ImputeNLPCA imputer by accepting genotype data and a configuration that can be provided in various formats. It initializes logging, device settings, and model parameters based on the provided configuration.
114
+
115
+ Args:
116
+ genotype_data (GenotypeData): Backing genotype data.
117
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
118
+ overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
119
+ """
120
+ self.model_name = "ImputeNLPCA"
121
+ self.genotype_data = genotype_data
122
+
123
+ # Normalize config first, then apply overrides (highest precedence)
124
+ cfg = ensure_nlpca_config(config)
125
+ if overrides:
126
+ cfg = apply_dot_overrides(cfg, overrides)
127
+
128
+ self.cfg = cfg
129
+
130
+ logman = LoggerManager(
131
+ __name__,
132
+ prefix=self.cfg.io.prefix,
133
+ debug=self.cfg.io.debug,
134
+ verbose=self.cfg.io.verbose,
135
+ )
136
+ self.logger = logman.get_logger()
137
+
138
+ # Initialize BaseNNImputer with device/dirs/logging from config
139
+ super().__init__(
140
+ prefix=self.cfg.io.prefix,
141
+ device=self.cfg.train.device,
142
+ verbose=self.cfg.io.verbose,
143
+ debug=self.cfg.io.debug,
144
+ )
145
+
146
+ self.Model = NLPCAModel
147
+ self.pgenc = GenotypeEncoder(genotype_data)
148
+ self.seed = self.cfg.io.seed
149
+ self.n_jobs = self.cfg.io.n_jobs
150
+ self.prefix = self.cfg.io.prefix
151
+ self.scoring_averaging = self.cfg.io.scoring_averaging
152
+ self.verbose = self.cfg.io.verbose
153
+ self.debug = self.cfg.io.debug
154
+
155
+ self.rng = np.random.default_rng(self.seed)
156
+
157
+ # Model/train hyperparams
158
+ self.latent_dim = self.cfg.model.latent_dim
159
+ self.dropout_rate = self.cfg.model.dropout_rate
160
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
161
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
162
+ self.layer_schedule = self.cfg.model.layer_schedule
163
+ self.latent_init = self.cfg.model.latent_init
164
+ self.activation = self.cfg.model.hidden_activation
165
+ self.gamma = self.cfg.model.gamma
166
+
167
+ self.batch_size = self.cfg.train.batch_size
168
+ self.learning_rate = self.cfg.train.learning_rate
169
+ self.lr_input_factor = self.cfg.train.lr_input_factor
170
+ self.l1_penalty = self.cfg.train.l1_penalty
171
+ self.early_stop_gen = self.cfg.train.early_stop_gen
172
+ self.min_epochs = self.cfg.train.min_epochs
173
+ self.epochs = self.cfg.train.max_epochs
174
+ self.validation_split = self.cfg.train.validation_split
175
+ self.beta = self.cfg.train.weights_beta
176
+ self.max_ratio = self.cfg.train.weights_max_ratio
177
+
178
+ # Tuning
179
+ self.tune = self.cfg.tune.enabled
180
+ self.tune_fast = self.cfg.tune.fast
181
+ self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
182
+ self.tune_batch_size = self.cfg.tune.batch_size
183
+ self.tune_epochs = self.cfg.tune.epochs
184
+ self.tune_eval_interval = self.cfg.tune.eval_interval
185
+ self.tune_metric = self.cfg.tune.metric
186
+ self.n_trials = self.cfg.tune.n_trials
187
+ self.tune_save_db = self.cfg.tune.save_db
188
+ self.tune_resume = self.cfg.tune.resume
189
+ self.tune_max_samples = self.cfg.tune.max_samples
190
+ self.tune_max_loci = self.cfg.tune.max_loci
191
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
192
+ self.tune_patience = self.cfg.tune.patience
193
+
194
+ # Eval
195
+ self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
196
+ self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
197
+ self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
198
+
199
+ # Plotting (note: PlotConfig has 'show', not 'show_plots')
200
+ self.plot_format = self.cfg.plot.fmt
201
+ self.plot_dpi = self.cfg.plot.dpi
202
+ self.plot_fontsize = self.cfg.plot.fontsize
203
+ self.title_fontsize = self.cfg.plot.fontsize
204
+ self.despine = self.cfg.plot.despine
205
+ self.show_plots = self.cfg.plot.show
206
+
207
+ # Core model config
208
+ self.is_haploid = None
209
+ self.num_classes_ = None
210
+ self.model_params: Dict[str, Any] = {}
211
+
212
+ def fit(self) -> "ImputeNLPCA":
213
+ """Fits the NLPCA model to the 0/1/2 encoded genotype data.
214
+
215
+ This method prepares the data, splits it into training and validation sets, initializes the model, and trains it. If hyperparameter tuning is enabled, it will perform tuning before final training. After training, it evaluates the model on a test set and generates relevant plots.
216
+
217
+ Returns:
218
+ ImputeNLPCA: The fitted imputer instance.
219
+ """
220
+ self.logger.info(f"Fitting {self.model_name} model...")
221
+
222
+ # --- DATA PREPARATION ---
223
+ X = self.pgenc.genotypes_012.astype(np.float32)
224
+ X[X < 0] = np.nan # Normalize any negative missing codes (e.g., -9) to NaN
225
+ X[np.isnan(X)] = -1 # Re-encode all missing values as -1, as expected by the loss function
226
+ self.ground_truth_ = X.astype(np.int64)
227
+
228
+ # --- Determine Ploidy and Number of Classes ---
229
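+ # If only single-base calls and missing symbols appear (no IUPAC heterozygote codes), treat the data as haploid.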
+ self.is_haploid = np.all(
230
+ np.isin(
231
+ self.genotype_data.snp_data, ["A", "C", "G", "T", "N", "-", ".", "?"]
232
+ )
233
+ )
234
+
235
+ self.ploidy = 1 if self.is_haploid else 2
236
+
237
+ if self.is_haploid:
238
+ self.num_classes_ = 2
239
+
240
+ # Remap labels from {0, 2} to {0, 1}
241
+ self.ground_truth_[self.ground_truth_ == 2] = 1
242
+ self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
243
+ else:
244
+ self.num_classes_ = 3
245
+
246
+ self.logger.info(
247
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2)."
248
+ )
249
+
250
+ n_samples, self.num_features_ = X.shape
251
+
252
+ self.model_params = {
253
+ "n_features": self.num_features_,
254
+ "latent_dim": self.latent_dim,
255
+ "dropout_rate": self.dropout_rate,
256
+ "activation": self.activation,
257
+ "gamma": self.gamma,
258
+ "num_classes": self.num_classes_,
259
+ }
260
+
261
+ # --- Train/Test Split ---
262
+ indices = np.arange(n_samples)
263
+ train_idx, test_idx = train_test_split(
264
+ indices, test_size=self.validation_split, random_state=self.seed
265
+ )
266
+ self.train_idx_, self.test_idx_ = train_idx, test_idx
267
+ self.X_train_ = self.ground_truth_[train_idx]
268
+ self.X_test_ = self.ground_truth_[test_idx]
269
+
270
+ # --- Tuning & Model Setup ---
271
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
272
+
273
+ if self.tune:
274
+ self.tune_hyperparameters()
275
+ self.best_params_ = getattr(self, "best_params_", self.model_params.copy())
276
+ else:
277
+ self.best_params_ = self._set_best_params_default()
278
+
279
+ # Class weights from 0/1/2 training data
280
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
281
+
282
+ # Latent vectors for training set
283
+ train_latent_vectors = self._create_latent_space(
284
+ self.best_params_, len(self.X_train_), self.X_train_, self.latent_init
285
+ )
286
+
287
+ train_loader = self._get_data_loaders(self.X_train_)
288
+
289
+ # Train the final model
290
+ (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
291
+ self._train_final_model(
292
+ train_loader, self.best_params_, train_latent_vectors
293
+ )
294
+ )
295
+
296
+ self.is_fit_ = True
297
+ self.plotter_.plot_history(self.history_)
298
+ self._evaluate_model(self.X_test_, self.model_, self.best_params_)
299
+ self._save_best_params(self.best_params_)
300
+ return self
301
+
302
+ def transform(self) -> np.ndarray:
303
+ """Imputes missing genotypes using the trained model.
304
+
305
+ This method uses the trained NLPCA model to impute missing genotypes in the entire dataset. It optimizes latent vectors for all samples, predicts missing values, and fills them in. The imputed genotypes are returned in IUPAC string format.
306
+
307
+ Returns:
308
+ np.ndarray: Imputed genotypes in IUPAC string format.
309
+
310
+ Raises:
311
+ NotFittedError: If the model has not been fitted.
312
+ """
313
+ if not getattr(self, "is_fit_", False):
314
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
315
+
316
+ self.logger.info("Imputing entire dataset...")
317
+ X_to_impute = self.ground_truth_.copy()
318
+
319
+ # Optimize latents for the full dataset
320
+ optimized_latents = self._optimize_latents_for_inference(
321
+ X_to_impute, self.model_, self.best_params_
322
+ )
323
+
324
+ # Predict missing values
325
+ pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
326
+
327
+ # Fill in missing values
328
+ missing_mask = X_to_impute == -1
329
+ imputed_array = X_to_impute.copy()
330
+ imputed_array[missing_mask] = pred_labels[missing_mask]
331
+
332
+ # Decode back to IUPAC strings
333
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
334
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
335
+
336
+ # Plot distributions
337
+ plt.rcParams.update(self.plotter_.param_dict) # Ensure consistent style
338
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
339
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
340
+
341
+ return imputed_genotypes
342
+
343
+ def _train_step(
344
+ self,
345
+ loader: torch.utils.data.DataLoader,
346
+ optimizer: torch.optim.Optimizer,
347
+ latent_optimizer: torch.optim.Optimizer,
348
+ model: torch.nn.Module,
349
+ l1_penalty: float,
350
+ latent_vectors: torch.nn.Parameter,
351
+ class_weights: torch.Tensor,
352
+ ) -> Tuple[float, torch.nn.Parameter]:
353
+ """Performs one epoch of training.
354
+
355
+ This method executes a single training epoch for the NLPCA model. It processes batches of data, computes the focal loss while handling missing values, applies L1 regularization if specified, and updates both the model parameters and latent vectors using their respective optimizers.
356
+
357
+ Args:
358
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
359
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
360
+ latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
361
+ model (torch.nn.Module): The NLPCA model.
362
+ l1_penalty (float): L1 regularization penalty.
363
+ latent_vectors (torch.nn.Parameter): Latent vectors for samples.
364
+ class_weights (torch.Tensor): Class weights for handling class imbalance.
365
+
366
+ Returns:
367
+ Tuple[float, torch.nn.Parameter]: Average training loss and updated latent vectors.
368
+ """
369
+ model.train()
370
+ running_loss = 0.0
371
+
372
+ nF = getattr(model, "n_features", self.num_features_)
373
+
374
+ for batch_indices, y_batch in loader:
375
+ optimizer.zero_grad(set_to_none=True)
376
+ latent_optimizer.zero_grad(set_to_none=True)
377
+
378
+ logits = model.phase23_decoder(latent_vectors[batch_indices]).view(
379
+ len(batch_indices), nF, self.num_classes_
380
+ )
381
+
382
+ # --- Simplified Focal Loss on 0/1/2 Classes ---
383
+ logits_flat = logits.view(-1, self.num_classes_)
384
+ targets_flat = y_batch.view(-1)
385
+
386
+ ce_loss = F.cross_entropy(
387
+ logits_flat,
388
+ targets_flat,
389
+ weight=class_weights,
390
+ reduction="none",
391
+ ignore_index=-1,
392
+ )
393
+
394
+ pt = torch.exp(-ce_loss)
395
+ gamma = getattr(model, "gamma", self.gamma)
396
+ focal_loss = ((1 - pt) ** gamma) * ce_loss
397
+
398
+ valid_mask = targets_flat != -1
399
+ # Keep the loss a tensor even when a batch has no observed genotypes so backward() still works.
+ loss = focal_loss[valid_mask].mean() if valid_mask.any() else logits_flat.sum() * 0.0
400
+
401
+ if l1_penalty > 0:
402
+ loss += l1_penalty * sum(p.abs().sum() for p in model.parameters())
403
+
404
+ loss.backward()
405
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
406
+ torch.nn.utils.clip_grad_norm_([latent_vectors], max_norm=1.0)
407
+ optimizer.step()
408
+ latent_optimizer.step()
409
+
410
+ running_loss += loss.item()
411
+
412
+ return running_loss / len(loader), latent_vectors
413
+
414
+ def _predict(
415
+ self, model: torch.nn.Module, latent_vectors: torch.nn.Parameter | None = None
416
+ ) -> Tuple[np.ndarray, np.ndarray]:
417
+ """Generates 0/1/2 predictions from latent vectors.
418
+
419
+ This method uses the trained NLPCA model to generate predictions from the latent vectors by passing them through the decoder. It returns both the predicted labels and their associated probabilities.
420
+
421
+ Args:
422
+ model (torch.nn.Module): Trained NLPCA model.
423
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
424
+
425
+ Returns:
426
+ Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
427
+ """
428
+ if model is None or latent_vectors is None:
429
+ raise NotFittedError("Model or latent vectors not available.")
430
+
431
+ model.eval()
432
+
433
+ nF = getattr(model, "n_features", self.num_features_)
434
+
435
+ with torch.no_grad():
436
+ logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
437
+ len(latent_vectors), nF, self.num_classes_
438
+ )
439
+ probas = torch.softmax(logits, dim=-1)
440
+ labels = torch.argmax(probas, dim=-1)
441
+
442
+ return labels.cpu().numpy(), probas.cpu().numpy()
443
+
444
+ def _evaluate_model(
445
+ self,
446
+ X_val: np.ndarray,
447
+ model: torch.nn.Module,
448
+ params: dict,
449
+ objective_mode: bool = False,
450
+ latent_vectors_val: torch.Tensor | None = None,
451
+ ) -> Dict[str, float]:
452
+ """Evaluates the model on a validation set.
453
+
454
+ This method evaluates the trained NLPCA model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
455
+
456
+ Args:
457
+ X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
458
+ model (torch.nn.Module): Trained NLPCA model.
459
+ params (dict): Model parameters.
460
+ objective_mode (bool): If True, suppresses logging and reports only the metric.
461
+ latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
462
+
463
+ Returns:
464
+ Dict[str, float]: Dictionary of evaluation metrics.
465
+ """
466
+ if latent_vectors_val is not None:
467
+ test_latent_vectors = latent_vectors_val
468
+ else:
469
+ test_latent_vectors = self._optimize_latents_for_inference(
470
+ X_val, model, params
471
+ )
472
+
473
+ # Decode predictions from the (optimized) latent vectors and score only the observed entries.
474
+ pred_labels, pred_probas = self._predict(
475
+ model=model, latent_vectors=test_latent_vectors
476
+ )
477
+
478
+ eval_mask = X_val != -1
479
+ y_true_flat = X_val[eval_mask]
480
+ pred_labels_flat = pred_labels[eval_mask]
481
+ pred_probas_flat = pred_probas[eval_mask]
482
+
483
+ if y_true_flat.size == 0:
484
+ return {self.tune_metric: 0.0}
485
+
486
+ # Haploid labels were already remapped to {0, 1} in fit(); pick label sets and names accordingly.
487
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
488
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
489
+
490
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
491
+
492
+ metrics = self.scorers_.evaluate(
493
+ y_true_flat,
494
+ pred_labels_flat,
495
+ y_true_ohe,
496
+ pred_probas_flat,
497
+ objective_mode,
498
+ self.tune_metric,
499
+ )
500
+
501
+ if not objective_mode:
502
+ self.logger.info(f"Validation Metrics: {metrics}")
503
+
504
+ self._make_class_reports(
505
+ y_true=y_true_flat,
506
+ y_pred_proba=pred_probas_flat,
507
+ y_pred=pred_labels_flat,
508
+ metrics=metrics,
509
+ labels=target_names,
510
+ )
511
+
512
+ y_true_dec = self.pgenc.decode_012(X_val)
513
+ X_pred = X_val.copy()
514
+ X_pred[eval_mask] = pred_labels_flat
515
+
516
+ nF_eval = X_val.shape[1]
517
+ y_pred_dec = self.pgenc.decode_012(X_pred.reshape(X_val.shape[0], nF_eval))
518
+
519
+ encodings_dict = {
520
+ "A": 0,
521
+ "C": 1,
522
+ "G": 2,
523
+ "T": 3,
524
+ "W": 4,
525
+ "R": 5,
526
+ "M": 6,
527
+ "K": 7,
528
+ "Y": 8,
529
+ "S": 9,
530
+ "N": -1,
531
+ }
532
+
533
+ y_true_int = self.pgenc.convert_int_iupac(
534
+ y_true_dec, encodings_dict=encodings_dict
535
+ )
536
+ y_pred_int = self.pgenc.convert_int_iupac(
537
+ y_pred_dec, encodings_dict=encodings_dict
538
+ )
539
+
540
+ self._make_class_reports(
541
+ y_true=y_true_int[eval_mask],
542
+ y_pred=y_pred_int[eval_mask],
543
+ metrics=metrics,
544
+ y_pred_proba=None,
545
+ labels=["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"],
546
+ )
547
+
548
+ return metrics
549
+
550
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
551
+ """Creates a PyTorch DataLoader for the 0/1/2 encoded data.
552
+
553
+ This method constructs a DataLoader from the provided genotype data, which is expected to be in 0/1/2 encoding with -1 for missing values. The DataLoader is used for batching and shuffling the data during model training. It converts the numpy array to a PyTorch tensor and creates a TensorDataset. The DataLoader is configured with the specified batch size and shuffling enabled.
554
+
555
+ Args:
556
+ y (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
557
+
558
+ Returns:
559
+ torch.utils.data.DataLoader: DataLoader for the dataset.
560
+ """
561
+ y_tensor = torch.from_numpy(y).long().to(self.device)
562
+ dataset = torch.utils.data.TensorDataset(
563
+ torch.arange(len(y), device=self.device), y_tensor.to(self.device)
564
+ )
565
+ return torch.utils.data.DataLoader(
566
+ dataset, batch_size=self.batch_size, shuffle=True
567
+ )
568
+
569
+ def _create_latent_space(
570
+ self,
571
+ params: dict,
572
+ n_samples: int,
573
+ X: np.ndarray,
574
+ latent_init: Literal["random", "pca"],
575
+ ) -> torch.nn.Parameter:
576
+ """Initializes the latent space for the NLPCA model.
577
+
578
+ This method initializes the latent space for the NLPCA model based on the specified initialization method. It supports two methods: 'random' initialization using Xavier uniform distribution, and 'pca' initialization which uses PCA to derive initial latent vectors from the data. The latent vectors are returned as a PyTorch Parameter, allowing them to be optimized during training.
579
+
580
+ Args:
581
+ params (dict): Model parameters including 'latent_dim'.
582
+ n_samples (int): Number of samples in the dataset.
583
+ X (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
584
+ latent_init (str): Method to initialize latent space ('random' or 'pca').
585
+
586
+ Returns:
587
+ torch.nn.Parameter: Initialized latent vectors as a PyTorch Parameter.
588
+ """
589
+ latent_dim = int(params["latent_dim"])
590
+
591
+ if latent_init == "pca":
592
+ X_pca = X.astype(np.float32, copy=True)
593
+ # mark missing
594
+ X_pca[X_pca < 0] = np.nan
595
+
596
+ # ---- SAFE column means without warnings ----
597
+ valid_counts = np.sum(~np.isnan(X_pca), axis=0)
598
+ col_sums = np.nansum(X_pca, axis=0)
599
+ col_means = np.divide(
600
+ col_sums,
601
+ valid_counts,
602
+ out=np.zeros_like(col_sums, dtype=np.float32),
603
+ where=valid_counts > 0,
604
+ )
605
+
606
+ # impute NaNs with per-column means
607
+ # (all-NaN cols -> 0.0 by the divide above)
608
+ nan_r, nan_c = np.where(np.isnan(X_pca))
609
+ if nan_r.size:
610
+ X_pca[nan_r, nan_c] = col_means[nan_c]
611
+
612
+ # center columns
613
+ X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
614
+
615
+ # guard: degenerate / all-zero after centering ->
616
+ # fall back to random
617
+ if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
618
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
619
+ torch.nn.init.xavier_uniform_(latents)
620
+ return torch.nn.Parameter(latents, requires_grad=True)
621
+
622
+ # rank-aware component count, at least 1
623
+ try:
624
+ est_rank = np.linalg.matrix_rank(X_pca)
625
+ except Exception:
626
+ est_rank = min(n_samples, X_pca.shape[1])
627
+
628
+ n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
629
+
630
+ # use deterministic SVD to avoid power-iteration warnings
631
+ pca = PCA(
632
+ n_components=n_components, svd_solver="full", random_state=self.seed
633
+ )
634
+ initial = pca.fit_transform(X_pca) # (n_samples, n_components)
635
+
636
+ # pad if latent_dim > n_components
637
+ if n_components < latent_dim:
638
+ pad = self.rng.standard_normal(
639
+ size=(n_samples, latent_dim - n_components)
640
+ )
641
+ initial = np.hstack([initial, pad])
642
+
643
+ # standardize latent dims
644
+ initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
645
+
646
+ latents = torch.from_numpy(initial).float().to(self.device)
647
+ return torch.nn.Parameter(latents, requires_grad=True)
648
+
649
+ # --- Random init path (unchanged) ---
650
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
651
+ torch.nn.init.xavier_uniform_(latents)
652
+ return torch.nn.Parameter(latents, requires_grad=True)
653
+
654
+ def _objective(self, trial: optuna.Trial) -> float:
655
+ """Objective function for hyperparameter tuning with Optuna.
656
+
657
+ This method defines the objective function used by Optuna for hyperparameter tuning of the NLPCA model. It samples a set of hyperparameters, prepares the training and validation data, initializes the model and latent vectors, and trains the model. After training, it evaluates the model on a validation set and returns the value of the specified tuning metric.
658
+
659
+ Args:
660
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
661
+
662
+ Returns:
663
+ float: The value of the tuning metric to be minimized or maximized.
664
+ """
665
+ self._prepare_tuning_artifacts()
666
+ trial_params = self._sample_hyperparameters(trial)
667
+ model_params = dict(trial_params["model_params"])
668
+
669
+ if self.tune and self.tune_fast:
670
+ model_params["n_features"] = self._tune_num_features
671
+
672
+ lr = trial_params["lr"]
673
+ l1_penalty = trial_params["l1_penalty"]
674
+ lr_input_fac = trial_params["lr_input_factor"]
675
+
676
+ X_train_trial = self._tune_X_train
677
+ X_test_trial = self._tune_X_test
678
+ class_weights = self._tune_class_weights
679
+ train_loader = self._tune_loader
680
+
681
+ train_latents = self._create_latent_space(
682
+ model_params,
683
+ len(X_train_trial),
684
+ X_train_trial,
685
+ trial_params["latent_init"],
686
+ )
687
+
688
+ model = self.build_model(self.Model, model_params)
689
+ model.n_features = model_params["n_features"]
690
+ model.apply(self.initialize_weights)
691
+
692
+ # Train this trial's model; latent refinement during pruning evaluations is disabled (eval_latent_steps=0).
693
+ _, model, _ = self._train_and_validate_model(
694
+ model=model,
695
+ loader=train_loader,
696
+ lr=lr,
697
+ l1_penalty=l1_penalty,
698
+ trial=trial,
699
+ latent_vectors=train_latents,
700
+ lr_input_factor=lr_input_fac,
701
+ class_weights=class_weights,
702
+ X_val=X_test_trial,
703
+ params=model_params,
704
+ prune_metric=self.tune_metric,
705
+ prune_warmup_epochs=5,
706
+ eval_interval=self.tune_eval_interval,
707
+ eval_latent_steps=0,
708
+ eval_latent_lr=0.0,
709
+ eval_latent_weight_decay=0.0,
710
+ )
711
+
712
+ metrics = self._evaluate_model(
713
+ X_test_trial, model, model_params, objective_mode=True
714
+ )
715
+
716
+ self._clear_resources(model, train_loader, latent_vectors=train_latents)
717
+ return metrics[self.tune_metric]
718
+
719
+ def _sample_hyperparameters(
720
+ self, trial: optuna.Trial
721
+ ) -> Dict[str, int | float | str | list]:
722
+ """Samples hyperparameters for the simplified NLPCA model.
723
+
724
+ This method defines the hyperparameter search space for the NLPCA model and samples a set of hyperparameters using the provided Optuna trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
725
+
726
+ Args:
727
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
728
+
729
+ Returns:
730
+ Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
731
+ """
732
+ params = {
733
+ "latent_dim": trial.suggest_int("latent_dim", 2, 32),
734
+ "lr": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
735
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5, step=0.05),
736
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 16),
737
+ "activation": trial.suggest_categorical(
738
+ "activation", ["relu", "elu", "selu", "leaky_relu"]
739
+ ),
740
+ "gamma": trial.suggest_float("gamma", 0.1, 5.0, step=0.1),
741
+ "lr_input_factor": trial.suggest_float(
742
+ "lr_input_factor", 0.1, 10.0, log=True
743
+ ),
744
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
745
+ "layer_scaling_factor": trial.suggest_float(
746
+ "layer_scaling_factor", 2.0, 10.0
747
+ ),
748
+ "layer_schedule": trial.suggest_categorical(
749
+ "layer_schedule", ["pyramid", "constant", "linear"]
750
+ ),
751
+ "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
752
+ }
753
+
754
+ use_n_features = (
755
+ self._tune_num_features
756
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
757
+ else self.num_features_
758
+ )
759
+ use_n_samples = (
760
+ len(self._tune_train_idx)
761
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_train_idx"))
762
+ else len(self.train_idx_)
763
+ )
764
+
765
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
766
+ n_inputs=params["latent_dim"],
767
+ n_outputs=use_n_features * self.num_classes_,
768
+ n_samples=use_n_samples,
769
+ n_hidden=params["num_hidden_layers"],
770
+ alpha=params["layer_scaling_factor"],
771
+ schedule=params["layer_schedule"],
772
+ )
773
+
774
+ params["model_params"] = {
775
+ "n_features": use_n_features,
776
+ "num_classes": self.num_classes_,
777
+ "latent_dim": params["latent_dim"],
778
+ "dropout_rate": params["dropout_rate"],
779
+ "hidden_layer_sizes": hidden_layer_sizes,
780
+ "activation": params["activation"],
781
+ "gamma": params["gamma"],
782
+ }
783
+
784
+ return params
785
+
786
+ def _set_best_params(
787
+ self, best_params: Dict[str, int | float | str | list]
788
+ ) -> Dict[str, int | float | str | list]:
789
+ """Sets the best hyperparameters found during tuning.
790
+
791
+ This method updates the model's attributes with the best hyperparameters obtained from tuning. It also computes the hidden layer sizes based on these parameters and prepares the final model parameters dictionary.
792
+
793
+ Args:
794
+ best_params (dict): Best hyperparameters from tuning.
795
+
796
+ Returns:
797
+ Dict[str, int | float | str | list]: Model parameters configured with the best hyperparameters.
798
+ """
799
+ self.latent_dim = best_params["latent_dim"]
800
+ self.dropout_rate = best_params["dropout_rate"]
801
+ self.learning_rate = best_params["learning_rate"]
802
+ self.gamma = best_params["gamma"]
803
+ self.lr_input_factor = best_params["lr_input_factor"]
804
+ self.l1_penalty = best_params["l1_penalty"]
805
+ self.activation = best_params["activation"]
806
+
807
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
808
+ n_inputs=self.latent_dim,
809
+ n_outputs=self.num_features_ * self.num_classes_,
810
+ n_samples=len(self.train_idx_),
811
+ n_hidden=best_params["num_hidden_layers"],
812
+ alpha=best_params["layer_scaling_factor"],
813
+ schedule=best_params["layer_schedule"],
814
+ )
815
+
816
+ return {
817
+ "n_features": self.num_features_,
818
+ "latent_dim": self.latent_dim,
819
+ "hidden_layer_sizes": hidden_layer_sizes,
820
+ "dropout_rate": self.dropout_rate,
821
+ "activation": self.activation,
822
+ "gamma": self.gamma,
823
+ "num_classes": self.num_classes_,
824
+ }
825
+
826
+ def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
827
+ """Default (no-tuning) model_params aligned with current attributes.
828
+
829
+ This method constructs the model parameters dictionary using the current instance attributes of the ImputeNLPCA class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the NLPCA model when no hyperparameter tuning has been performed.
830
+
831
+ Returns:
832
+ Dict[str, int | float | str | list]: model_params payload.
833
+ """
834
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
835
+ n_inputs=self.latent_dim,
836
+ n_outputs=self.num_features_ * self.num_classes_,
837
+ n_samples=len(self.ground_truth_),
838
+ n_hidden=self.num_hidden_layers,
839
+ alpha=self.layer_scaling_factor,
840
+ schedule=self.layer_schedule,
841
+ )
842
+
843
+ return {
844
+ "n_features": self.num_features_,
845
+ "latent_dim": self.latent_dim,
846
+ "hidden_layer_sizes": hidden_layer_sizes,
847
+ "dropout_rate": self.dropout_rate,
848
+ "activation": self.activation,
849
+ "gamma": self.gamma,
850
+ "num_classes": self.num_classes_,
851
+ }
852
+
853
+ def _train_and_validate_model(
854
+ self,
855
+ model: torch.nn.Module,
856
+ loader: torch.utils.data.DataLoader,
857
+ lr: float,
858
+ l1_penalty: float,
859
+ trial: optuna.Trial | None = None,
860
+ return_history: bool = False,
861
+ latent_vectors: torch.nn.Parameter | None = None,
862
+ lr_input_factor: float = 1.0,
863
+ class_weights: torch.Tensor | None = None,
864
+ *,
865
+ X_val: np.ndarray | None = None,
866
+ params: dict | None = None,
867
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
868
+ prune_warmup_epochs: int = 3,
869
+ eval_interval: int = 1,
870
+ eval_latent_steps: int = 50,
871
+ eval_latent_lr: float = 1e-2,
872
+ eval_latent_weight_decay: float = 0.0,
873
+ ) -> Tuple:
874
+ """Trains and validates the NLPCA model.
875
+
876
+ This method trains the provided NLPCA model using the specified training data and hyperparameters. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method initializes optimizers for both the model parameters and latent vectors, sets up a learning rate scheduler, and executes the training loop. It can return the training history if requested.
877
+
878
+ Args:
879
+ model (torch.nn.Module): The NLPCA model to be trained.
880
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
881
+ lr (float): Learning rate for the model optimizer.
882
+ l1_penalty (float): L1 regularization penalty.
883
+ trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
884
+ return_history (bool): Whether to return training history.
885
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
886
+ lr_input_factor (float): Learning rate factor for latent vectors.
887
+ class_weights (torch.Tensor | None): Class weights for handling class imbalance.
888
+ X_val (np.ndarray | None): Validation data for pruning.
889
+ params (dict | None): Model parameters.
890
+ prune_metric (str | None): Metric for pruning decisions.
891
+ prune_warmup_epochs (int): Number of epochs before pruning starts.
892
+ eval_interval (int): Interval (in epochs) for evaluation during training.
893
+ eval_latent_steps (int): Steps for latent optimization during evaluation.
894
+ eval_latent_lr (float): Learning rate for latent optimization during evaluation.
895
+ eval_latent_weight_decay (float): Weight decay for latent optimization during evaluation.
896
+
897
+ Returns:
898
+ Tuple[float, torch.nn.Module, Dict[str, float], torch.nn.Parameter] | Tuple[float, torch.nn.Module, torch.nn.Parameter]: Training loss, trained model, training history (if requested), and optimized latent vectors.
899
+
900
+ Raises:
901
+ TypeError: If latent_vectors or class_weights are not provided.
902
+ """
903
+
904
+ if latent_vectors is None or class_weights is None:
905
+ msg = "latent_vectors and class_weights must be provided."
906
+ self.logger.error(msg)
907
+ raise TypeError(msg)
908
+
909
+ latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
910
+
911
+ optimizer = torch.optim.Adam(model.phase23_decoder.parameters(), lr=lr)
912
+ scheduler = CosineAnnealingLR(optimizer, T_max=self.epochs)
913
+
914
+ result = self._execute_training_loop(
915
+ loader=loader,
916
+ optimizer=optimizer,
917
+ latent_optimizer=latent_optimizer,
918
+ scheduler=scheduler,
919
+ model=model,
920
+ l1_penalty=l1_penalty,
921
+ return_history=return_history,
922
+ latent_vectors=latent_vectors,
923
+ class_weights=class_weights,
924
+ trial=trial,
925
+ X_val=X_val,
926
+ params=params,
927
+ prune_metric=prune_metric,
928
+ prune_warmup_epochs=prune_warmup_epochs,
929
+ eval_interval=eval_interval,
930
+ eval_latent_steps=eval_latent_steps,
931
+ eval_latent_lr=eval_latent_lr,
932
+ eval_latent_weight_decay=eval_latent_weight_decay,
933
+ )
934
+
935
+ if return_history:
936
+ return result
937
+
938
+ return result[0], result[1], result[3]
939
+
940
+ def _train_final_model(
941
+ self,
942
+ loader: torch.utils.data.DataLoader,
943
+ best_params: dict,
944
+ initial_latent_vectors: torch.nn.Parameter,
945
+ ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
946
+ """Trains the final model using the best hyperparameters.
947
+
948
+ This method builds and trains the final NLPCA model using the best hyperparameters obtained from tuning. It initializes the model weights, trains the model on the entire training set, and saves the trained model to disk. It returns the final training loss, trained model, training history, and optimized latent vectors.
949
+
950
+ Args:
951
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
952
+ best_params (dict): Best hyperparameters for the model.
953
+ initial_latent_vectors (torch.nn.Parameter): Initial latent vectors for samples.
954
+
955
+ Returns:
956
+ Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: Final training loss, trained model, training history, and optimized latent vectors.
957
+ Raises:
958
+ RuntimeError: If model training fails.
959
+ """
960
+ self.logger.info("Training the final model...")
961
+
962
+ model = self.build_model(self.Model, best_params)
963
+ model.n_features = best_params["n_features"]
964
+ model.apply(self.initialize_weights)
965
+
966
+ loss, trained_model, history, latent_vectors = self._train_and_validate_model(
967
+ model=model,
968
+ loader=loader,
969
+ lr=self.learning_rate,
970
+ l1_penalty=self.l1_penalty,
971
+ return_history=True,
972
+ latent_vectors=initial_latent_vectors,
973
+ lr_input_factor=self.lr_input_factor,
974
+ class_weights=self.class_weights_,
975
+ X_val=self.X_test_,
976
+ params=best_params,
977
+ prune_metric=self.tune_metric,
978
+ prune_warmup_epochs=5,
979
+ eval_interval=1,
980
+ eval_latent_steps=50,
981
+ eval_latent_lr=self.learning_rate * self.lr_input_factor,
982
+ eval_latent_weight_decay=0.0,
983
+ )
984
+
985
+ if trained_model is None:
986
+ msg = "Final model training failed."
987
+ self.logger.error(msg)
988
+ raise RuntimeError(msg)
989
+
990
+ fn = self.models_dir / "final_model.pt"
991
+ torch.save(trained_model.state_dict(), fn)
992
+
993
+ return (loss, trained_model, {"Train": history}, latent_vectors)
994
+
995
+ def _execute_training_loop(
996
+ self,
997
+ loader,
998
+ optimizer,
999
+ latent_optimizer,
1000
+ scheduler,
1001
+ model,
1002
+ l1_penalty,
1003
+ return_history,
1004
+ latent_vectors,
1005
+ class_weights,
1006
+ *,
1007
+ trial: optuna.Trial | None = None,
1008
+ X_val: np.ndarray | None = None,
1009
+ params: dict | None = None,
1010
+ prune_metric: str | None = None,
1011
+ prune_warmup_epochs: int = 3,
1012
+ eval_interval: int = 1,
1013
+ eval_latent_steps: int = 50,
1014
+ eval_latent_lr: float = 1e-2,
1015
+ eval_latent_weight_decay: float = 0.0,
1016
+ ) -> Tuple[float, torch.nn.Module, list, torch.nn.Parameter]:
1017
+ """Executes the training loop with optional Optuna pruning.
1018
+
1019
+ This method runs the training loop for the NLPCA model, performing multiple epochs of training. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method tracks training history, applies early stopping, and returns the best model and training metrics.
1020
+
1021
+ Args:
1022
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1023
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
1024
+ latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
1025
+ scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
1026
+ model (torch.nn.Module): The NLPCA model.
1027
+ l1_penalty (float): L1 regularization penalty.
1028
+ return_history (bool): Whether to return training history.
1029
+ latent_vectors (torch.nn.Parameter): Latent vectors for samples.
1030
+ class_weights (torch.Tensor): Class weights for handling class imbalance.
1032
+ trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
1033
+ X_val (np.ndarray | None): Validation data for pruning.
1034
+ params (dict | None): Model parameters.
1035
+ prune_metric (str | None): Metric to monitor for pruning.
1036
+ prune_warmup_epochs (int): Epochs to wait before pruning.
1037
+ eval_interval (int): Epoch interval for evaluation.
1038
+ eval_latent_steps (int): Steps to refine latents during eval.
1039
+ eval_latent_lr (float): Learning rate for latent refinement during eval.
1040
+ eval_latent_weight_decay (float): Weight decay for latent refinement during eval.
1041
+
1042
+ Returns:
1043
+ Tuple[float, torch.nn.Module, list, torch.nn.Parameter]: Best loss, best model, training history, and optimized latent vectors.
1044
+
1045
+ Raises:
1046
+ optuna.exceptions.TrialPruned: If the trial is pruned based on validation performance.
1047
+ """
1048
+ best_model = None
1049
+ train_history = []
1050
+ early_stopping = EarlyStopping(
1051
+ patience=self.early_stop_gen,
1052
+ min_epochs=self.min_epochs,
1053
+ verbose=self.verbose,
1054
+ prefix=self.prefix,
1055
+ debug=self.debug,
1056
+ )
1057
+
1058
+ # compute the epoch budget used by the loop
1059
+ max_epochs = (
1060
+ self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
1061
+ )
1062
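+ # Re-create the scheduler so T_max matches this epoch budget (the scheduler argument passed in is replaced).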
+ scheduler = CosineAnnealingLR(optimizer, T_max=max_epochs)
1063
+
1064
+ # just above the for-epoch loop
1065
+ _latent_cache: dict = {}
1066
+ _latent_cache_key = f"{self.prefix}_{self.model_name}_val_latents"
1067
+
1068
+ for epoch in range(max_epochs):
1069
+ train_loss, latent_vectors = self._train_step(
1070
+ loader,
1071
+ optimizer,
1072
+ latent_optimizer,
1073
+ model,
1074
+ l1_penalty,
1075
+ latent_vectors,
1076
+ class_weights,
1077
+ )
1078
+ scheduler.step()
1079
+
1080
+ if np.isnan(train_loss) or np.isinf(train_loss):
1081
+ raise optuna.exceptions.TrialPruned("Loss is NaN or Inf.")
1082
+
1083
+ if return_history:
1084
+ train_history.append(train_loss)
1085
+
1086
+ if (
1087
+ trial is not None
1088
+ and X_val is not None
1089
+ and ((epoch + 1) % eval_interval == 0)
1090
+ ):
1091
+ metric_key = prune_metric or getattr(self, "tune_metric", "f1")
1092
+
1093
+ do_infer = (eval_latent_steps or 0) > 0
1094
+ metric_val = self._eval_for_pruning(
1095
+ model=model,
1096
+ X_val=X_val,
1097
+ params=params or getattr(self, "best_params_", {}),
1098
+ metric=metric_key,
1099
+ objective_mode=True,
1100
+ do_latent_infer=do_infer,
1101
+ latent_steps=eval_latent_steps,
1102
+ latent_lr=eval_latent_lr,
1103
+ latent_weight_decay=eval_latent_weight_decay,
1104
+ latent_seed=(self.seed if self.seed is not None else None),
1105
+ _latent_cache=_latent_cache,
1106
+ _latent_cache_key=_latent_cache_key,
1107
+ )
1108
+
1109
+ trial.report(metric_val, step=epoch + 1)
1110
+ if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
1111
+ raise optuna.exceptions.TrialPruned(
1112
+ f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.3f}"
1113
+ )
1114
+
1115
+ early_stopping(train_loss, model)
1116
+ if early_stopping.early_stop:
1117
+ self.logger.info(f"Early stopping at epoch {epoch + 1}.")
1118
+ break
1119
+
1120
+ best_loss = early_stopping.best_score
1121
+ best_model = model # reuse instance
1122
+ best_model.load_state_dict(early_stopping.best_model.state_dict())
1123
+ return best_loss, best_model, train_history, latent_vectors
1124
+
1125
+ def _optimize_latents_for_inference(
1126
+ self,
1127
+ X_new: np.ndarray,
1128
+ model: torch.nn.Module,
1129
+ params: dict,
1130
+ inference_epochs: int = 200,
1131
+ ) -> torch.Tensor:
1132
+ """Optimizes latent vectors for new, unseen data.
1133
+
1134
+ This method optimizes latent vectors for new data samples that were not part of the training set. It initializes latent vectors and performs gradient-based optimization to minimize the reconstruction loss using the trained NLPCA model. The optimized latent vectors are returned for further predictions.
1135
+
1136
+ Args:
1137
+ X_new (np.ndarray): New data in 0/1/2 encoding with -1 for missing values.
1138
+ model (torch.nn.Module): Trained NLPCA model.
1139
+ params (dict): Model parameters.
1140
+ inference_epochs (int): Number of epochs to optimize latent vectors.
1141
+
1142
+ Returns:
1143
+ torch.Tensor: Optimized latent vectors for the new data.
1144
+ """
1145
+ if self.tune and self.tune_fast:
1146
+ inference_epochs = min(
1147
+ inference_epochs, getattr(self, "tune_infer_epochs", 20)
1148
+ )
1149
+
1150
+ model.eval()
1151
+
1152
+ nF = getattr(model, "n_features", self.num_features_)
1153
+
1154
+ new_latent_vectors = self._create_latent_space(
1155
+ params, len(X_new), X_new, self.latent_init
1156
+ )
1157
+ latent_optimizer = torch.optim.Adam(
1158
+ [new_latent_vectors], lr=self.learning_rate * self.lr_input_factor
1159
+ )
1160
+ y_target = torch.from_numpy(X_new).long().to(self.device)
1161
+
1162
+ for _ in range(inference_epochs):
1163
+ latent_optimizer.zero_grad()
1164
+ logits = model.phase23_decoder(new_latent_vectors).view(
1165
+ len(X_new), nF, self.num_classes_
1166
+ )
1167
+ loss = F.cross_entropy(
1168
+ logits.view(-1, self.num_classes_), y_target.view(-1), ignore_index=-1
1169
+ )
1170
+ if torch.isnan(loss):
1171
+ self.logger.warning("Inference loss is NaN; stopping.")
1172
+ break
1173
+ loss.backward()
1174
+ latent_optimizer.step()
1175
+
1176
+ return new_latent_vectors.detach()
1177
+
1178
+ def _latent_infer_for_eval(
1179
+ self,
1180
+ model: torch.nn.Module,
1181
+ X_val: np.ndarray,
1182
+ *,
1183
+ steps: int,
1184
+ lr: float,
1185
+ weight_decay: float,
1186
+ seed: int,
1187
+ cache: dict | None,
1188
+ cache_key: str | None,
1189
+ ) -> None:
1190
+ """Freeze weights; refine validation latents only (no leakage).
1191
+
1192
+ This method refines latent vectors for validation data by optimizing them while keeping the model weights frozen. It uses gradient-based optimization to minimize the reconstruction loss on the validation set. The optimized latent vectors can be cached for future use.
1193
+
1194
+ Args:
1195
+ model (torch.nn.Module): Trained NLPCA model.
1196
+ X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
1197
+ steps (int): Number of optimization steps for latent refinement.
1198
+ lr (float): Learning rate for latent optimization.
1199
+ weight_decay (float): Weight decay for latent optimization.
1200
+ seed (int): Random seed for reproducibility.
1201
+ cache (dict | None): Cache for storing optimized latents.
1202
+ cache_key (str | None): Key for storing/retrieving latents in the cache.
1203
+
1204
+ Returns:
1205
+ None. Updates cache in place if provided.
1206
+ """
1207
+ if seed is None:
1208
+ seed = np.random.randint(0, 999999)
1209
+
1210
+ torch.manual_seed(seed)
1211
+ np.random.seed(seed)
1212
+
1213
+ model.eval()
1214
+
1215
+ nF = getattr(model, "n_features", self.num_features_)
1216
+
1217
+ for p in model.parameters():
1218
+ p.requires_grad_(False)
1219
+
1220
+ X_val = X_val.astype(np.int64, copy=False)
1221
+ X_val[X_val < 0] = -1
1222
+ y_target = torch.from_numpy(X_val).long().to(self.device)
1223
+
1224
+ # Get latent_dim from the *model actually being evaluated*
1225
+ latent_dim_model = self._first_linear_in_features(model)
1226
+
1227
+ # Make a cache key that is specific to this latent size (and feature schema)
1228
+ cache_key = (
1229
+ f"{self.prefix}_nlpca_val_latents_"
1230
+ f"z{latent_dim_model}_L{self.num_features_}_K{self.num_classes_}"
1231
+ )
1232
+
1233
+ # Warm-start from cache if available *and* shape-compatible
1234
+ if cache is not None and cache_key in cache:
1235
+ val_latents = cache[cache_key].detach().clone().requires_grad_(True)
1236
+ else:
1237
+ val_latents = self._create_latent_space(
1238
+ {"latent_dim": latent_dim_model}, # use model's latent size
1239
+ n_samples=X_val.shape[0],
1240
+ X=X_val,
1241
+ latent_init=self.latent_init,
1242
+ ).requires_grad_(True)
1243
+
1244
+ opt = torch.optim.AdamW([val_latents], lr=lr, weight_decay=weight_decay)
1245
+
1246
+ for _ in range(max(int(steps), 0)):
1247
+ opt.zero_grad(set_to_none=True)
1248
+ logits = model.phase23_decoder(val_latents).view(
1249
+ X_val.shape[0], nF, self.num_classes_
1250
+ )
1251
+ loss = F.cross_entropy(
1252
+ logits.view(-1, self.num_classes_),
1253
+ y_target.view(-1),
1254
+ ignore_index=-1,
1255
+ reduction="mean",
1256
+ )
1257
+ loss.backward()
1258
+ opt.step()
1259
+
1260
+ if cache is not None:
1261
+ cache[cache_key] = val_latents.detach().clone()
1262
+
1263
+ for p in model.parameters():
1264
+ p.requires_grad_(True)
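
The config-handling path above (ensure_nlpca_config plus the dot-key overrides applied in ImputeNLPCA.__init__) accepts a dataclass, a nested dict, a YAML path, or None. Below is a minimal usage sketch, not part of the released package: the import path is inferred from the file layout listed above, and the field names (model.latent_dim, train.learning_rate) are taken from the attribute accesses shown in the diff.

from pgsui.data_processing.config import apply_dot_overrides
from pgsui.impute.unsupervised.imputers.nlpca import ImputeNLPCA, ensure_nlpca_config

# Nested dict: an optional top-level "preset" key is lifted first, the rest is
# flattened to dot-keys and overlaid on a fresh (or preset-derived) NLPCAConfig.
cfg = ensure_nlpca_config({"model": {"latent_dim": 4}, "train": {"learning_rate": 1e-3}})

# Explicit dot-key overrides take highest precedence, mirroring ImputeNLPCA.__init__.
cfg = apply_dot_overrides(cfg, {"model.latent_dim": 8})

# Equivalent end-to-end call (gdata is a snpio GenotypeData object, e.g. from VCFReader):
# imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml", overrides={"model.latent_dim": 8})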
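
The masked focal loss used in _train_step can also be read in isolation. This is an illustrative sketch under the same conventions (0/1/2 targets with -1 marking missing), not an exact copy of the method: the real code takes gamma and the class weights from the model/config and adds an optional L1 term.

import torch
import torch.nn.functional as F

def masked_focal_loss(logits, targets, class_weights=None, gamma=2.0):
    # logits: (N, num_classes); targets: (N,) in {0, 1, 2} with -1 for missing.
    ce = F.cross_entropy(logits, targets, weight=class_weights, reduction="none", ignore_index=-1)
    pt = torch.exp(-ce)  # probability the model assigns to the true class
    focal = ((1.0 - pt) ** gamma) * ce  # down-weight easy, well-classified genotypes
    valid = targets != -1
    # Average over observed genotypes only; return a graph-connected zero if none exist.
    return focal[valid].mean() if valid.any() else logits.sum() * 0.0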