pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/impute/unsupervised/imputers/nlpca.py (new file)
@@ -0,0 +1,1666 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import optuna
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from sklearn.decomposition import PCA
12
+ from sklearn.exceptions import NotFittedError
13
+ from sklearn.model_selection import train_test_split
14
+ from snpio.analysis.genotype_encoder import GenotypeEncoder
15
+ from snpio.utils.logging import LoggerManager
16
+ from torch.optim.lr_scheduler import CosineAnnealingLR
17
+
18
+ from pgsui.data_processing.config import apply_dot_overrides, load_yaml_to_dataclass
19
+ from pgsui.data_processing.containers import NLPCAConfig
20
+ from pgsui.data_processing.transformers import SimMissingTransformer
21
+ from pgsui.impute.unsupervised.base import BaseNNImputer
22
+ from pgsui.impute.unsupervised.callbacks import EarlyStopping
23
+ from pgsui.impute.unsupervised.loss_functions import SafeFocalCELoss
24
+ from pgsui.impute.unsupervised.models.nlpca_model import NLPCAModel
25
+ from pgsui.utils.logging_utils import configure_logger
26
+ from pgsui.utils.pretty_metrics import PrettyMetrics
27
+
28
+ if TYPE_CHECKING:
29
+ from snpio import TreeParser
30
+ from snpio.read_input.genotype_data import GenotypeData
31
+
32
+
33
+ def ensure_nlpca_config(config: NLPCAConfig | dict | str | None) -> NLPCAConfig:
34
+ """Return a concrete NLPCAConfig from dataclass, dict, YAML path, or None.
35
+
36
+ Args:
37
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
38
+
39
+ Returns:
40
+ NLPCAConfig: Concrete configuration instance.
41
+ """
42
+ if config is None:
43
+ return NLPCAConfig()
44
+ if isinstance(config, NLPCAConfig):
45
+ return config
46
+ if isinstance(config, str):
47
+ # YAML path — top-level `preset` key is supported
48
+ return load_yaml_to_dataclass(config, NLPCAConfig)
49
+ if isinstance(config, dict):
50
+ # Flatten dict into dot-keys then overlay onto a fresh instance
51
+ base = NLPCAConfig()
52
+
53
+ def _flatten(prefix: str, d: dict, out: dict) -> dict:
54
+ for k, v in d.items():
55
+ kk = f"{prefix}.{k}" if prefix else k
56
+ if isinstance(v, dict):
57
+ _flatten(kk, v, out)
58
+ else:
59
+ out[kk] = v
60
+ return out
61
+
62
+ # Lift any present preset first
63
+ preset_name = config.pop("preset", None)
64
+ if "io" in config and isinstance(config["io"], dict):
65
+ preset_name = preset_name or config["io"].pop("preset", None)
66
+
67
+ if preset_name:
68
+ base = NLPCAConfig.from_preset(preset_name)
69
+
70
+ flat = _flatten("", config, {})
71
+ return apply_dot_overrides(base, flat)
72
+
73
+ raise TypeError("config must be an NLPCAConfig, dict, YAML path, or None.")
74
+
75
+
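+ # A minimal usage sketch of the normalization above; the preset name and override keys are purely illustrative. A nested dict is flattened to dot-keys before being overlaid on the base (or preset) config:
+ #     cfg = ensure_nlpca_config(None)                      # plain NLPCAConfig() defaults
+ #     cfg = ensure_nlpca_config("nlpca_config.yaml")       # YAML path; a top-level `preset` key is honored
+ #     cfg = ensure_nlpca_config(
+ #         {"preset": "fast", "model": {"latent_dim": 4}, "io": {"prefix": "run1"}}
+ #     )  # preset lifted first, then model.latent_dim and io.prefix applied as dot-key overrides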
76
+ class ImputeNLPCA(BaseNNImputer):
77
+ """Imputes missing genotypes using a Non-linear Principal Component Analysis (NLPCA) model.
78
+
79
+ This class implements an imputer based on Non-linear Principal Component Analysis (NLPCA) using a neural network architecture. It is designed to handle genotype data encoded in 0/1/2 format, where 0 represents the reference allele, 1 represents the heterozygous genotype, and 2 represents the alternate allele. Missing genotypes should be represented as -9 or -1.
80
+
81
+ The NLPCA model consists of an encoder-decoder architecture that learns a low-dimensional latent representation of the genotype data. The model is trained using a focal loss function to address class imbalance, and it can incorporate L1 regularization to promote sparsity in the learned representations.
82
+
83
+ Notes:
84
+ - Supports both haploid and diploid genotype data.
85
+ - Configurable model architecture with options for latent dimension, dropout rate, number of hidden layers, and activation functions.
86
+ - Hyperparameter tuning using Optuna for optimal model performance.
87
+ - Evaluation metrics including accuracy, F1-score, precision, recall, and ROC-AUC.
88
+ - Visualization of training history and genotype distributions.
89
+ - Flexible configuration via dataclass, dictionary, or YAML file.
90
+
91
+ Example:
92
+ >>> from snpio import VCFReader
93
+ >>> from pgsui import ImputeNLPCA
94
+ >>> gdata = VCFReader("genotypes.vcf.gz")
95
+ >>> imputer = ImputeNLPCA(gdata, config="nlpca_config.yaml")
96
+ >>> imputer.fit()
97
+ >>> imputed_genotypes = imputer.transform()
98
+ >>> print(imputed_genotypes)
99
+ [['A' 'G' 'C' ...],
100
+ ['G' 'G' 'C' ...],
101
+ ...
102
+ ['T' 'C' 'A' ...],
103
+ ['C' 'C' 'C' ...]]
104
+ """
105
+
106
+ def __init__(
107
+ self,
108
+ genotype_data: "GenotypeData",
109
+ *,
110
+ tree_parser: Optional["TreeParser"] = None,
111
+ config: NLPCAConfig | dict | str | None = None,
112
+ overrides: dict | None = None,
113
+ simulate_missing: bool | None = None,
114
+ sim_strategy: (
115
+ Literal[
116
+ "random",
117
+ "random_weighted",
118
+ "random_weighted_inv",
119
+ "nonrandom",
120
+ "nonrandom_weighted",
121
+ ]
122
+ | None
123
+ ) = None,
124
+ sim_prop: float | None = None,
125
+ sim_kwargs: dict | None = None,
126
+ ):
127
+ """Initializes the ImputeNLPCA imputer with genotype data and configuration.
128
+
129
+ This constructor sets up the ImputeNLPCA imputer by accepting genotype data and a configuration that can be provided in various formats. It initializes logging, device settings, and model parameters based on the provided configuration.
130
+
131
+ Args:
132
+ genotype_data (GenotypeData): Backing genotype data.
133
+ tree_parser (TreeParser | None): Optional SNPio phylogenetic tree parser for population-specific modes.
134
+ config (NLPCAConfig | dict | str | None): Structured configuration as dataclass, nested dict, YAML path, or None.
135
+ overrides (dict | None): Dot-key overrides (e.g. {'model.latent_dim': 4}).
136
+ simulate_missing (bool | None): Whether to simulate missing data during training. If None, uses config defaults.
137
+ sim_strategy (Literal["random", "random_weighted", "random_weighted_inv", "nonrandom", "nonrandom_weighted"] | None): Strategy for simulating missing data. If None, uses config default.
138
+ sim_prop (float | None): Proportion of data to simulate as missing. If None, uses config default.
139
+ sim_kwargs (dict | None): Additional keyword arguments for missing data simulation (overrides config kwargs).
140
+ """
141
+ self.model_name = "ImputeNLPCA"
142
+ self.genotype_data = genotype_data
143
+ self.tree_parser = tree_parser
144
+
145
+ # Normalize config first, then apply overrides (highest precedence)
146
+ cfg = ensure_nlpca_config(config)
147
+
148
+ if overrides:
149
+ cfg = apply_dot_overrides(cfg, overrides)
150
+
151
+ self.cfg = cfg
152
+
153
+ logman = LoggerManager(
154
+ __name__,
155
+ prefix=self.cfg.io.prefix,
156
+ debug=self.cfg.io.debug,
157
+ verbose=self.cfg.io.verbose,
158
+ )
159
+ self.logger = configure_logger(
160
+ logman.get_logger(), verbose=self.cfg.io.verbose, debug=self.cfg.io.debug
161
+ )
162
+
163
+ # Initialize BaseNNImputer with device/dirs/logging from config
164
+ super().__init__(
165
+ model_name=self.model_name,
166
+ genotype_data=self.genotype_data,
167
+ prefix=self.cfg.io.prefix,
168
+ device=self.cfg.train.device,
169
+ verbose=self.cfg.io.verbose,
170
+ debug=self.cfg.io.debug,
171
+ )
172
+
173
+ self.Model = NLPCAModel
174
+ self.pgenc = GenotypeEncoder(genotype_data)
175
+ self.seed = self.cfg.io.seed
176
+ self.n_jobs = self.cfg.io.n_jobs
177
+ self.prefix = self.cfg.io.prefix
178
+ self.scoring_averaging = self.cfg.io.scoring_averaging
179
+ self.verbose = self.cfg.io.verbose
180
+ self.debug = self.cfg.io.debug
181
+
182
+ self.rng = np.random.default_rng(self.seed)
183
+ self.pos_weights_: torch.Tensor | None = None
184
+
185
+ # Model/train hyperparams
186
+ self.latent_dim = self.cfg.model.latent_dim
187
+ self.dropout_rate = self.cfg.model.dropout_rate
188
+ self.num_hidden_layers = self.cfg.model.num_hidden_layers
189
+ self.layer_scaling_factor = self.cfg.model.layer_scaling_factor
190
+ self.layer_schedule = self.cfg.model.layer_schedule
191
+ self.latent_init: Literal["random", "pca"] = self.cfg.model.latent_init
192
+ self.activation = self.cfg.model.hidden_activation
193
+ self.gamma = self.cfg.model.gamma
194
+
195
+ self.batch_size = self.cfg.train.batch_size
196
+ self.learning_rate: float = self.cfg.train.learning_rate
197
+ self.lr_input_factor = self.cfg.train.lr_input_factor
198
+ self.l1_penalty = self.cfg.train.l1_penalty
199
+ self.early_stop_gen = self.cfg.train.early_stop_gen
200
+ self.min_epochs = self.cfg.train.min_epochs
201
+ self.epochs = self.cfg.train.max_epochs
202
+ self.validation_split = self.cfg.train.validation_split
203
+ self.beta = self.cfg.train.weights_beta
204
+ self.max_ratio = self.cfg.train.weights_max_ratio
205
+
206
+ # Tuning
207
+ self.tune = self.cfg.tune.enabled
208
+ self.tune_fast = self.cfg.tune.fast
209
+ self.tune_proxy_metric_batch = self.cfg.tune.proxy_metric_batch
210
+ self.tune_batch_size = self.cfg.tune.batch_size
211
+ self.tune_epochs = self.cfg.tune.epochs
212
+ self.tune_eval_interval = self.cfg.tune.eval_interval
213
+ self.tune_metric: Literal[
214
+ "pr_macro",
215
+ "f1",
216
+ "accuracy",
217
+ "average_precision",
218
+ "precision",
219
+ "recall",
220
+ "roc_auc",
221
+ ] = self.cfg.tune.metric
222
+ self.n_trials = self.cfg.tune.n_trials
223
+ self.tune_save_db = self.cfg.tune.save_db
224
+ self.tune_resume = self.cfg.tune.resume
225
+ self.tune_max_samples = self.cfg.tune.max_samples
226
+ self.tune_max_loci = self.cfg.tune.max_loci
227
+ self.tune_infer_epochs = getattr(self.cfg.tune, "infer_epochs", 100)
228
+ self.tune_patience = self.cfg.tune.patience
229
+
230
+ # Eval
231
+ self.eval_latent_steps = self.cfg.evaluate.eval_latent_steps
232
+ self.eval_latent_lr = self.cfg.evaluate.eval_latent_lr
233
+ self.eval_latent_weight_decay = self.cfg.evaluate.eval_latent_weight_decay
234
+
235
+ # Plotting (NOTE: PlotConfig has 'show', not 'show_plots')
236
+ self.plot_format = self.cfg.plot.fmt
237
+ self.plot_dpi = self.cfg.plot.dpi
238
+ self.plot_fontsize = self.cfg.plot.fontsize
239
+ self.title_fontsize = self.cfg.plot.fontsize
240
+ self.despine = self.cfg.plot.despine
241
+ self.show_plots = self.cfg.plot.show
242
+
243
+ # Core model config
244
+ self.is_haploid = False
245
+ self.num_classes_ = 3
246
+ self.model_params: Dict[str, Any] = {}
247
+
248
+ sim_cfg = getattr(self.cfg, "sim", None)
249
+ sim_cfg_kwargs = copy.deepcopy(getattr(sim_cfg, "sim_kwargs", None) or {})
250
+
251
+ if sim_kwargs:
252
+ sim_cfg_kwargs.update(sim_kwargs)
253
+
254
+ if sim_cfg is None:
255
+ default_strategy = "random"
256
+ default_prop = 0.10
257
+ else:
258
+ default_strategy = sim_cfg.sim_strategy
259
+ default_prop = sim_cfg.sim_prop
260
+
261
+ if simulate_missing is not None:
+ self.simulate_missing = bool(simulate_missing)
+ elif sim_cfg is not None:
+ self.simulate_missing = sim_cfg.simulate_missing
+ else:
+ self.simulate_missing = False
270
+ self.sim_strategy = sim_strategy or default_strategy
271
+ self.sim_prop = float(sim_prop if sim_prop is not None else default_prop)
272
+ self.sim_kwargs = sim_cfg_kwargs
273
+
274
+ if self.tree_parser is None and self.sim_strategy.startswith("nonrandom"):
275
+ msg = "tree_parser is required for nonrandom and nonrandom_weighted simulated missing strategies."
276
+ self.logger.error(msg)
277
+ raise ValueError(msg)
278
+
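+ # A minimal construction sketch of the precedence described above (config first, then dot-key
+ # overrides, then explicit keyword arguments); the file name and override value are illustrative:
+ #     imputer = ImputeNLPCA(
+ #         gdata,                                  # a SNPio GenotypeData instance
+ #         config="nlpca_config.yaml",             # dataclass, nested dict, YAML path, or None
+ #         overrides={"model.latent_dim": 4},      # dot-key overrides take highest precedence
+ #         simulate_missing=True,
+ #         sim_strategy="random",                  # "nonrandom*" strategies also require tree_parser
+ #         sim_prop=0.10,
+ #     )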
279
+ def fit(self) -> "ImputeNLPCA":
280
+ """Fits the NLPCA model to the 0/1/2 encoded genotype data.
281
+
282
+ This method prepares the data, splits it into training and validation sets, initializes the model, and trains it. If hyperparameter tuning is enabled, it will perform tuning before final training. After training, it evaluates the model on a test set and generates relevant plots.
283
+
284
+ Returns:
285
+ ImputeNLPCA: The fitted imputer instance.
286
+ """
287
+ self.logger.info(f"Fitting {self.model_name} model...")
288
+
289
+ # --- BASE MATRIX AND GROUND TRUTH ---
290
+ X012 = self.pgenc.genotypes_012.astype(np.float32)
291
+ X012[X012 < 0] = np.nan # NaN = original missing
292
+
293
+ # Keep an immutable ground-truth copy in 0/1/2 with -1 for original
294
+ # missing
295
+ GT_full = X012.copy()
296
+ GT_full[np.isnan(GT_full)] = -1
297
+ self.ground_truth_ = GT_full.astype(np.int64)
298
+
299
+ # --- OPTIONAL SIMULATED MISSING VIA SimMissingTransformer ---
300
+ self.sim_mask_global_ = None
301
+ if self.simulate_missing:
302
+ tr = SimMissingTransformer(
303
+ genotype_data=self.genotype_data,
304
+ tree_parser=self.tree_parser,
305
+ prop_missing=self.sim_prop,
306
+ strategy=self.sim_strategy,
307
+ missing_val=-9,
308
+ mask_missing=True,
309
+ verbose=self.verbose,
310
+ tol=None,
311
+ max_tries=None,
312
+ )
313
+ # NOTE: pass NaN-coded missing; transformer handles NaNs correctly
314
+ tr.fit(X012.copy())
315
+
316
+ # Store boolean mask of simulated positions only (excludes original-missing)
317
+ self.sim_mask_global_ = tr.sim_missing_mask_.astype(bool)
318
+
319
+ # Apply simulation to the model’s input copy: encode as -1 for loss
320
+ X_for_model = self.ground_truth_.copy()
321
+ X_for_model[self.sim_mask_global_] = -1
322
+ else:
323
+ X_for_model = self.ground_truth_.copy()
324
+
325
+ # --- Determine Ploidy and Number of Classes ---
326
+ self.is_haploid = bool(
327
+ np.all(
328
+ np.isin(
329
+ self.genotype_data.snp_data,
330
+ ["A", "C", "G", "T", "N", "-", ".", "?"],
331
+ )
332
+ )
333
+ )
334
+
335
+ self.ploidy = 1 if self.is_haploid else 2
336
+
337
+ if self.is_haploid:
+ self.num_classes_ = 2
+ # Mirror the diploid path: two output channels (softmax over REF/ALT for haploid data)
+ self.output_classes_ = 2
+
+ # Remap labels from {0, 2} to {0, 1} in both the ground truth and the model input
+ self.ground_truth_[self.ground_truth_ == 2] = 1
+ X_for_model[X_for_model == 2] = 1
+ self.logger.info("Haploid data detected. Using 2 classes (REF=0, ALT=1).")
344
+ else:
345
+ self.num_classes_ = 3
346
+ # Model head uses two channels; scoring uses num_classes_
347
+ self.output_classes_ = 2
348
+ self.logger.info(
349
+ "Diploid data detected. Using 3 classes (REF=0, HET=1, ALT=2) for scoring; 2 output channels with sigmoid for training."
350
+ )
351
+
352
+ n_samples, self.num_features_ = X_for_model.shape
353
+
354
+ self.model_params = {
355
+ "n_features": self.num_features_,
356
+ "latent_dim": self.latent_dim,
357
+ "dropout_rate": self.dropout_rate,
358
+ "activation": self.activation,
359
+ "gamma": self.gamma,
360
+ "num_classes": self.output_classes_,
361
+ }
362
+
363
+ # --- Train/Test Split ---
364
+ indices = np.arange(n_samples)
365
+ train_idx, test_idx = train_test_split(
366
+ indices, test_size=self.validation_split, random_state=self.seed
367
+ )
368
+ self.train_idx_, self.test_idx_ = train_idx, test_idx
369
+ # Subset matrices for training/eval
370
+ self.X_train_ = X_for_model[train_idx]
371
+ self.X_test_ = X_for_model[test_idx]
372
+ self.GT_train_full_ = self.ground_truth_[train_idx] # pre-mask truth
373
+ self.GT_test_full_ = self.ground_truth_[test_idx]
374
+
375
+ # Slice the simulation mask by split if present
376
+ if self.sim_mask_global_ is not None:
377
+ self.sim_mask_train_ = self.sim_mask_global_[train_idx]
378
+ self.sim_mask_test_ = self.sim_mask_global_[test_idx]
379
+ else:
380
+ self.sim_mask_train_ = None
381
+ self.sim_mask_test_ = None
382
+ # pos weights for multilabel diploid path
383
+ if not self.is_haploid:
384
+ self.pos_weights_ = self._compute_pos_weights(self.X_train_)
385
+ else:
386
+ self.pos_weights_ = None
387
+
388
+ # Tuning, model setup, training (unchanged except DataLoader input)
389
+ self.plotter_, self.scorers_ = self.initialize_plotting_and_scorers()
390
+
391
+ if self.tune:
392
+ self.tune_hyperparameters()
393
+ self.best_params_ = getattr(self, "best_params_", self.model_params.copy())
394
+ else:
395
+ self.best_params_ = self._set_best_params_default()
396
+
397
+ # Class weights from 0/1/2 training data
398
+ self.class_weights_ = self._class_weights_from_zygosity(self.X_train_)
399
+
400
+ if self.latent_init not in {"random", "pca"}:
+ msg = (
+ f"Invalid latent_init '{self.latent_init}'; must be 'random' or 'pca'."
+ )
+ self.logger.error(msg)
+ raise ValueError(msg)
408
+
409
+ li: Literal["random", "pca"] = self.latent_init
410
+
411
+ # Latent vectors for training set (class weights were already computed above)
+ train_latent_vectors = self._create_latent_space(
+ self.best_params_, len(self.X_train_), self.X_train_, li
+ )
416
+ train_loader = self._get_data_loaders(self.X_train_)
417
+
418
+ # Train the final model
419
+ (self.best_loss_, self.model_, self.history_, self.train_latent_vectors_) = (
420
+ self._train_final_model(
421
+ train_loader, self.best_params_, train_latent_vectors
422
+ )
423
+ )
424
+
425
+ self.is_fit_ = True
426
+ self.plotter_.plot_history(self.history_)
427
+
428
+ if self.sim_mask_test_ is not None:
429
+ # Evaluate exactly on simulated-missing sites
430
+ self.logger.info("Evaluating on simulated-missing positions only.")
431
+ self._evaluate_model(
432
+ self.X_test_,
433
+ self.model_,
434
+ self.best_params_,
435
+ eval_mask_override=self.sim_mask_test_,
436
+ )
437
+ else:
438
+ self._evaluate_model(self.X_test_, self.model_, self.best_params_)
439
+
440
+ self._save_best_params(self.best_params_)
441
+
442
+ return self
443
+
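+ # To make the masking bookkeeping in fit() concrete, a small sketch with made-up values:
+ # the ground truth keeps observed 0/1/2 labels with -1 only at originally missing cells,
+ # while the model input additionally hides simulated cells so the stored mask can score them later.
+ #     GT = np.array([[0, 2, -1], [1, 0, 2]])                            # pre-mask truth; -1 = originally missing
+ #     sim_mask = np.array([[False, True, False], [True, False, False]]) # simulated positions only
+ #     X_for_model = GT.copy()
+ #     X_for_model[sim_mask] = -1                                        # hidden from training, recoverable for evaluation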
444
+ def transform(self) -> np.ndarray:
445
+ """Imputes missing genotypes using the trained model.
446
+
447
+ This method uses the trained NLPCA model to impute missing genotypes in the entire dataset. It optimizes latent vectors for all samples, predicts missing values, and fills them in. The imputed genotypes are returned in IUPAC string format.
448
+
449
+ Returns:
450
+ np.ndarray: Imputed genotypes in IUPAC string format.
451
+
452
+ Raises:
453
+ NotFittedError: If the model has not been fitted.
454
+ """
455
+ if not getattr(self, "is_fit_", False):
456
+ raise NotFittedError("Model is not fitted. Call fit() before transform().")
457
+
458
+ self.logger.info(f"Imputing entire dataset with {self.model_name}...")
459
+ X_to_impute = self.ground_truth_.copy()
460
+
461
+ # Optimize latents for the full dataset
462
+ optimized_latents = self._optimize_latents_for_inference(
463
+ X_to_impute, self.model_, self.best_params_
464
+ )
465
+
466
+ # Predict missing values
467
+ pred_labels, _ = self._predict(self.model_, latent_vectors=optimized_latents)
468
+
469
+ # Fill in missing values
470
+ missing_mask = X_to_impute == -1
471
+ imputed_array = X_to_impute.copy()
472
+ imputed_array[missing_mask] = pred_labels[missing_mask]
473
+
474
+ # Decode back to IUPAC strings
475
+ imputed_genotypes = self.pgenc.decode_012(imputed_array)
476
+ if self.show_plots:
477
+ original_genotypes = self.pgenc.decode_012(X_to_impute)
478
+ plt.rcParams.update(self.plotter_.param_dict) # Ensure consistent style
479
+ self.plotter_.plot_gt_distribution(original_genotypes, is_imputed=False)
480
+ self.plotter_.plot_gt_distribution(imputed_genotypes, is_imputed=True)
481
+
482
+ return imputed_genotypes
483
+
484
+ def _train_step(
485
+ self,
486
+ loader: torch.utils.data.DataLoader,
487
+ optimizer: torch.optim.Optimizer,
488
+ latent_optimizer: torch.optim.Optimizer,
489
+ model: torch.nn.Module,
490
+ l1_penalty: float,
491
+ latent_vectors: torch.nn.Parameter,
492
+ class_weights: torch.Tensor,
493
+ ) -> Tuple[float, torch.nn.Parameter]:
494
+ """One epoch with stable focal CE, latent+weight updates, and NaN guards.
495
+
496
+ Args:
497
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
498
+ optimizer (torch.optim.Optimizer): Optimizer for model parameters.
499
+ latent_optimizer (torch.optim.Optimizer): Optimizer for latent vectors.
500
+ model (torch.nn.Module): NLPCA model.
501
+ l1_penalty (float): L1 regularization penalty.
502
+ latent_vectors (torch.nn.Parameter): Latent vectors for samples.
503
+ class_weights (torch.Tensor): Class weights for focal loss.
504
+
505
+ Returns:
506
+ Tuple[float, torch.nn.Parameter]: Average loss and updated latent vectors.
507
+
508
+ Notes:
509
+ - Implements focal cross-entropy loss with class weights.
510
+ - Applies L1 regularization on model weights.
511
+ - Includes guards against NaN/infinite values in logits, loss, and gradients.
512
+ """
513
+ model.train()
514
+ running = 0.0
515
+ used = 0
516
+
517
+ # Ensure latent vectors are trainable
518
+ if not isinstance(latent_vectors, torch.nn.Parameter):
519
+ latent_vectors = torch.nn.Parameter(latent_vectors, requires_grad=True)
520
+
521
+ # Bound gamma to a sane range
522
+ gamma = float(getattr(model, "gamma", getattr(self, "gamma", 0.0)))
523
+ gamma = max(0.0, min(gamma, 10.0))
524
+
525
+ # Normalize class weights to mean≈1 to keep loss scale stable
526
+ if class_weights is not None:
527
+ cw = class_weights.to(self.device)
528
+ cw = cw / cw.mean().clamp_min(1e-8)
529
+ else:
530
+ cw = None
531
+
532
+ nF = getattr(model, "n_features", self.num_features_)
533
+
534
+ criterion = SafeFocalCELoss(gamma=gamma, weight=cw, ignore_index=-1)
535
+
536
+ for batch_indices, y_batch in loader:
537
+ optimizer.zero_grad(set_to_none=True)
538
+ latent_optimizer.zero_grad(set_to_none=True)
539
+
540
+ # Targets
541
+ y_batch = y_batch.to(self.device, non_blocking=True).long()
542
+
543
+ decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
544
+
545
+ if not isinstance(decoder, torch.nn.Module):
546
+ msg = "Model decoder is not a valid torch.nn.Module."
547
+ self.logger.error(msg)
548
+ raise TypeError(msg)
549
+
550
+ # Forward
551
+ z = latent_vectors[batch_indices].to(self.device)
552
+ logits = decoder(z).view(len(batch_indices), nF, self.output_classes_)
553
+
554
+ # Guard upstream explosions
555
+ if not torch.isfinite(logits).all():
556
+ # Skip batch if model already produced non-finite values
557
+ continue
558
+
559
+ if self.is_haploid:
560
+ logits_flat = logits.view(-1, self.output_classes_)
561
+ targets_flat = y_batch.view(-1)
562
+ loss = criterion(logits_flat, targets_flat)
563
+ else:
564
+ targets = self._multi_hot_targets(y_batch)
565
+ bce = F.binary_cross_entropy_with_logits(
566
+ logits, targets, pos_weight=self.pos_weights_, reduction="none"
567
+ )
568
+ mask = (y_batch != -1).unsqueeze(-1).float()
569
+ loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
570
+
571
+ # L1 on model weights only (exclude latents)
572
+ if l1_penalty > 0:
573
+ l1 = torch.stack(
574
+ [p.abs().sum() for p in model.parameters() if p.requires_grad]
575
+ ).sum()
576
+ loss = loss + l1_penalty * l1
577
+
578
+ if not torch.isfinite(loss):
579
+ # Skip pathological batch
580
+ continue
581
+
582
+ loss.backward()
583
+
584
+ # Clip both parameter sets to keep grads finite
585
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
586
+ torch.nn.utils.clip_grad_norm_([latent_vectors], max_norm=1.0)
587
+
588
+ # If any grad is non-finite, skip updates
589
+ bad_grad = False
590
+ for p in model.parameters():
591
+ if p.grad is not None and not torch.isfinite(p.grad).all():
592
+ bad_grad = True
593
+ break
594
+ if (
595
+ not bad_grad
596
+ and latent_vectors.grad is not None
597
+ and not torch.isfinite(latent_vectors.grad).all()
598
+ ):
599
+ bad_grad = True
600
+ if bad_grad:
601
+ optimizer.zero_grad(set_to_none=True)
602
+ latent_optimizer.zero_grad(set_to_none=True)
603
+ continue
604
+
605
+ optimizer.step()
606
+ latent_optimizer.step()
607
+
608
+ running += float(loss.detach().item())
609
+ used += 1
610
+
611
+ if used == 0:
612
+ # Signal upstream that no safe batches were used
613
+ return float("inf"), latent_vectors
614
+
615
+ return running / used, latent_vectors
616
+
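+ # A self-contained sketch of the masked BCE reduction used in the diploid branch above
+ # (shapes and values are illustrative; the mask zeroes out loci whose target label is -1):
+ #     logits = torch.randn(2, 3, 2)                              # (batch, loci, REF/ALT channels)
+ #     y = torch.tensor([[0, 1, -1], [2, 0, 1]])                  # 0/1/2 labels with -1 = missing
+ #     targets = torch.zeros(2, 3, 2)
+ #     targets[..., 0] = ((y == 0) | (y == 1)).float()            # REF channel
+ #     targets[..., 1] = ((y == 2) | (y == 1)).float()            # ALT channel
+ #     bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
+ #     mask = (y != -1).unsqueeze(-1).float()
+ #     loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)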
617
+ def _predict(
618
+ self, model: torch.nn.Module, latent_vectors: torch.Tensor | None = None
619
+ ) -> Tuple[np.ndarray, np.ndarray]:
620
+ """Generates 0/1/2 predictions from latent vectors.
621
+
622
+ This method uses the trained NLPCA model to generate predictions from the latent vectors by passing them through the decoder. It returns both the predicted labels and their associated probabilities.
623
+
624
+ Args:
625
+ model (torch.nn.Module): Trained NLPCA model.
626
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
627
+
628
+ Returns:
629
+ Tuple[np.ndarray, np.ndarray]: Predicted labels and probabilities.
630
+ """
631
+ if model is None or latent_vectors is None:
632
+ raise NotFittedError("Model or latent vectors not available.")
633
+
634
+ model.eval()
635
+
636
+ nF = getattr(model, "n_features", self.num_features_)
637
+
638
+ if not isinstance(model.phase23_decoder, torch.nn.Module):
639
+ msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
640
+ self.logger.error(msg)
641
+ raise TypeError(msg)
642
+
643
+ with torch.no_grad():
644
+ logits = model.phase23_decoder(latent_vectors.to(self.device)).view(
645
+ len(latent_vectors), nF, self.output_classes_
646
+ )
647
+ if self.is_haploid:
648
+ probas = torch.softmax(logits, dim=-1)
649
+ labels = torch.argmax(probas, dim=-1)
650
+ else:
651
+ probas2 = torch.sigmoid(logits)
652
+ p_ref = probas2[..., 0]
653
+ p_alt = probas2[..., 1]
654
+ p_het = p_ref * p_alt
655
+ p_ref_only = p_ref * (1 - p_alt)
656
+ p_alt_only = p_alt * (1 - p_ref)
657
+ probas = torch.stack([p_ref_only, p_het, p_alt_only], dim=-1)
658
+ probas = probas / probas.sum(dim=-1, keepdim=True).clamp_min(1e-8)
659
+ labels = torch.argmax(probas, dim=-1)
660
+
661
+ return labels.cpu().numpy(), probas.cpu().numpy()
662
+
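+ # A worked example of the diploid probability composition above, using illustrative
+ # sigmoid outputs p_ref = 0.9 and p_alt = 0.8 for one locus:
+ #     p_het = 0.9 * 0.8          # 0.72
+ #     p_ref_only = 0.9 * 0.2     # 0.18
+ #     p_alt_only = 0.8 * 0.1     # 0.08
+ #     total = 0.98
+ #     probs = [0.18 / total, 0.72 / total, 0.08 / total]  # ≈ [0.184, 0.735, 0.082] -> argmax = 1 (HET)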
663
+ def _evaluate_model(
664
+ self,
665
+ X_val: np.ndarray,
666
+ model: torch.nn.Module,
667
+ params: dict,
668
+ objective_mode: bool = False,
669
+ latent_vectors_val: torch.Tensor | None = None,
670
+ *,
671
+ eval_mask_override: np.ndarray | None = None,
672
+ ) -> Dict[str, float]:
673
+ """Evaluates the model on a validation set.
674
+
675
+ This method evaluates the trained NLPCA model on a validation dataset by optimizing latent vectors for the validation samples, predicting genotypes, and computing various performance metrics. It can operate in an objective mode that suppresses logging for automated evaluations.
676
+
677
+ Args:
678
+ X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
679
+ model (torch.nn.Module): Trained NLPCA model.
680
+ params (dict): Model parameters.
681
+ objective_mode (bool): If True, suppresses logging and reports only the metric.
682
+ latent_vectors_val (torch.Tensor | None): Pre-optimized latent vectors for validation data.
683
+ eval_mask_override (np.ndarray | None): Boolean mask to specify which entries to evaluate.
684
+
685
+ Returns:
686
+ Dict[str, float]: Dictionary of evaluation metrics.
687
+ """
688
+ if latent_vectors_val is not None:
689
+ test_latent_vectors = latent_vectors_val
690
+ else:
691
+ test_latent_vectors = self._optimize_latents_for_inference(
692
+ X_val, model, params
693
+ )
694
+
695
+ pred_labels, pred_probas = self._predict(
696
+ model=model, latent_vectors=test_latent_vectors
697
+ )
698
+
699
+ if eval_mask_override is not None:
700
+ # Validate row counts to allow feature subsetting during tuning
701
+ if eval_mask_override.shape[0] != X_val.shape[0]:
702
+ msg = (
703
+ f"eval_mask_override rows {eval_mask_override.shape[0]} "
704
+ f"does not match X_val rows {X_val.shape[0]}"
705
+ )
706
+ self.logger.error(msg)
707
+ raise ValueError(msg)
708
+
709
+ # Slice mask columns if override is wider than current X_val (tune_fast)
710
+ if eval_mask_override.shape[1] > X_val.shape[1]:
711
+ eval_mask = eval_mask_override[:, : X_val.shape[1]].astype(bool)
712
+ else:
713
+ eval_mask = eval_mask_override.astype(bool)
714
+ else:
715
+ # Default: score only observed entries
716
+ eval_mask = X_val != -1
717
+
718
+ # y_true should be drawn from the pre-mask ground truth
719
+ # Map X_val back to the correct full ground truth slice
720
+ # FIX: Check shape[0] (n_samples) only.
721
+ if X_val.shape[0] == self.X_test_.shape[0]:
722
+ GT_ref = self.GT_test_full_
723
+ elif X_val.shape[0] == self.X_train_.shape[0]:
724
+ GT_ref = self.GT_train_full_
725
+ else:
726
+ GT_ref = self.ground_truth_
727
+
728
+ # FIX: Slice Ground Truth columns if it is wider than X_val (tune_fast)
729
+ if GT_ref.shape[1] > X_val.shape[1]:
730
+ GT_ref = GT_ref[:, : X_val.shape[1]]
731
+
732
+ # Fallback safeguard
733
+ if GT_ref.shape != X_val.shape:
734
+ GT_ref = X_val
735
+
736
+ y_true_flat = GT_ref[eval_mask]
737
+ pred_labels_flat = pred_labels[eval_mask]
738
+ pred_probas_flat = pred_probas[eval_mask]
739
+
740
+ if y_true_flat.size == 0:
741
+ return {self.tune_metric: 0.0}
742
+
743
+ # For haploids, remap class 2 to 1 for scoring (e.g., f1-score)
744
+ labels_for_scoring = [0, 1] if self.is_haploid else [0, 1, 2]
745
+ target_names = ["REF", "ALT"] if self.is_haploid else ["REF", "HET", "ALT"]
746
+
747
+ y_true_ohe = np.eye(len(labels_for_scoring))[y_true_flat]
748
+
749
+ metrics = self.scorers_.evaluate(
750
+ y_true_flat,
751
+ pred_labels_flat,
752
+ y_true_ohe,
753
+ pred_probas_flat,
754
+ objective_mode,
755
+ self.tune_metric,
756
+ )
757
+
758
+ if not objective_mode:
759
+ pm = PrettyMetrics(
760
+ metrics, precision=3, title=f"{self.model_name} Validation Metrics"
761
+ )
762
+ pm.render() # prints a command-line table
763
+
764
+ self._make_class_reports(
765
+ y_true=y_true_flat,
766
+ y_pred_proba=pred_probas_flat,
767
+ y_pred=pred_labels_flat,
768
+ metrics=metrics,
769
+ labels=target_names,
770
+ )
771
+
772
+ # FIX: Use X_val dimensions for reshaping, not self.num_features_
773
+ y_true_dec = self.pgenc.decode_012(
774
+ GT_ref.reshape(X_val.shape[0], X_val.shape[1])
775
+ )
776
+
777
+ X_pred = X_val.copy()
778
+ X_pred[eval_mask] = pred_labels_flat
779
+
780
+ y_pred_dec = self.pgenc.decode_012(
781
+ X_pred.reshape(X_val.shape[0], X_val.shape[1])
782
+ )
783
+
784
+ encodings_dict = {
785
+ "A": 0,
786
+ "C": 1,
787
+ "G": 2,
788
+ "T": 3,
789
+ "W": 4,
790
+ "R": 5,
791
+ "M": 6,
792
+ "K": 7,
793
+ "Y": 8,
794
+ "S": 9,
795
+ "N": -1,
796
+ }
797
+
798
+ y_true_int = self.pgenc.convert_int_iupac(
799
+ y_true_dec, encodings_dict=encodings_dict
800
+ )
801
+ y_pred_int = self.pgenc.convert_int_iupac(
802
+ y_pred_dec, encodings_dict=encodings_dict
803
+ )
804
+
805
+ # For IUPAC report
806
+ valid_true = y_true_int[eval_mask]
807
+ valid_true = valid_true[valid_true >= 0] # drop -1 (N)
808
+ iupac_label_set = ["A", "C", "G", "T", "W", "R", "M", "K", "Y", "S"]
809
+
810
+ # For numeric report
811
+ if (
812
+ np.intersect1d(np.unique(y_true_flat), labels_for_scoring).size == 0
813
+ or valid_true.size == 0
814
+ ):
815
+ if not objective_mode:
816
+ self.logger.warning(
817
+ "Skipped numeric confusion matrix: no y_true labels present."
818
+ )
819
+ else:
820
+ self._make_class_reports(
821
+ y_true=valid_true,
822
+ y_pred=y_pred_int[eval_mask][y_true_int[eval_mask] >= 0],
823
+ metrics=metrics,
824
+ y_pred_proba=None,
825
+ labels=iupac_label_set,
826
+ )
827
+
828
+ return metrics
829
+
830
+ def _get_data_loaders(self, y: np.ndarray) -> torch.utils.data.DataLoader:
831
+ """Creates a PyTorch DataLoader for the 0/1/2 encoded data.
832
+
833
+ This method constructs a DataLoader from the provided genotype data, which is expected to be in 0/1/2 encoding with -1 for missing values. The DataLoader is used for batching and shuffling the data during model training. It converts the numpy array to a PyTorch tensor and creates a TensorDataset. The DataLoader is configured with the specified batch size and shuffling enabled.
834
+
835
+ Args:
836
+ y (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
837
+
838
+ Returns:
839
+ torch.utils.data.DataLoader: DataLoader for the dataset.
840
+ """
841
+ y_tensor = torch.from_numpy(y).long().to(self.device)
842
+ dataset = torch.utils.data.TensorDataset(
843
+ torch.arange(len(y), device=self.device), y_tensor.to(self.device)
844
+ )
845
+ return torch.utils.data.DataLoader(
846
+ dataset, batch_size=self.batch_size, shuffle=True
847
+ )
848
+
849
+ def _encode_multilabel_inputs(self, y: torch.Tensor) -> torch.Tensor:
850
+ """Two-channel multi-hot for diploid: REF-only, ALT-only; HET sets both."""
851
+ if self.is_haploid:
852
+ return self._one_hot_encode_012(y)
853
+ y = y.to(self.device)
854
+ shape = y.shape + (2,)
855
+ out = torch.zeros(shape, device=self.device, dtype=torch.float32)
856
+ valid = y != -1
857
+ ref_mask = valid & (y != 2)
858
+ alt_mask = valid & (y != 0)
859
+ out[ref_mask, 0] = 1.0
860
+ out[alt_mask, 1] = 1.0
861
+ return out
862
+
863
+ def _multi_hot_targets(self, y: torch.Tensor) -> torch.Tensor:
864
+ """Targets aligned with _encode_multilabel_inputs for diploid training."""
865
+ if self.is_haploid:
866
+ raise RuntimeError("_multi_hot_targets called for haploid data.")
867
+ y = y.to(self.device)
868
+ out = torch.zeros(y.shape + (2,), device=self.device, dtype=torch.float32)
869
+ valid = y != -1
870
+ ref_mask = valid & (y != 2)
871
+ alt_mask = valid & (y != 0)
872
+ out[ref_mask, 0] = 1.0
873
+ out[alt_mask, 1] = 1.0
874
+ return out
875
+
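+ # The two-channel (REF, ALT) convention implemented by both helpers above, spelled out per 0/1/2 label:
+ #     y = torch.tensor([0, 1, 2, -1])
+ #     ref = ((y == 0) | (y == 1)).float()   # [1., 1., 0., 0.]
+ #     alt = ((y == 2) | (y == 1)).float()   # [0., 1., 1., 0.]
+ #     targets = torch.stack([ref, alt], dim=-1)
+ #     # rows: 0 -> [1,0] (hom REF), 1 -> [1,1] (HET), 2 -> [0,1] (hom ALT), -1 -> [0,0] (excluded by the loss mask)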
876
+ def _compute_pos_weights(self, X: np.ndarray) -> torch.Tensor:
877
+ """Balance REF/ALT channels for multilabel BCE."""
878
+ ref_pos = np.count_nonzero((X == 0) | (X == 1))
879
+ alt_pos = np.count_nonzero((X == 2) | (X == 1))
880
+ total_valid = np.count_nonzero(X != -1)
881
+ pos_counts = np.array([ref_pos, alt_pos], dtype=np.float32)
882
+ neg_counts = np.maximum(total_valid - pos_counts, 1.0)
883
+ pos_counts = np.maximum(pos_counts, 1.0)
884
+ weights = neg_counts / pos_counts
885
+ return torch.tensor(weights, device=self.device, dtype=torch.float32)
886
+
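+ # A small numeric illustration of the channel balancing above (counts are made up):
+ # with 100 valid calls split 60/10/30 across REF-hom/HET/ALT-hom, the REF channel has
+ # 70 positives and 30 negatives, the ALT channel 40 positives and 60 negatives.
+ #     pos_weight_ref = 30.0 / 70.0   # ≈ 0.43
+ #     pos_weight_alt = 60.0 / 40.0   # 1.5
+ # Passed as pos_weight to binary_cross_entropy_with_logits, so the rarer ALT channel is up-weighted.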
887
+ def _create_latent_space(
888
+ self,
889
+ params: dict,
890
+ n_samples: int,
891
+ X: np.ndarray,
892
+ latent_init: Literal["random", "pca"],
893
+ ) -> torch.nn.Parameter:
894
+ """Initializes the latent space for the NLPCA model.
895
+
896
+ This method initializes the latent space for the NLPCA model based on the specified initialization method. It supports two methods: 'random' initialization using Xavier uniform distribution, and 'pca' initialization which uses PCA to derive initial latent vectors from the data. The latent vectors are returned as a PyTorch Parameter, allowing them to be optimized during training.
897
+
898
+ Args:
899
+ params (dict): Model parameters including 'latent_dim'.
900
+ n_samples (int): Number of samples in the dataset.
901
+ X (np.ndarray): 0/1/2 encoded genotype data with -1 for missing.
902
+ latent_init (str): Method to initialize latent space ('random' or 'pca').
903
+
904
+ Returns:
905
+ torch.nn.Parameter: Initialized latent vectors as a PyTorch Parameter.
906
+ """
907
+ latent_dim = int(params["latent_dim"])
908
+
909
+ if latent_init == "pca":
910
+ X_pca = X.astype(np.float32, copy=True)
911
+ # mark missing
912
+ X_pca[X_pca < 0] = np.nan
913
+
914
+ # ---- SAFE column means without warnings ----
915
+ valid_counts = np.sum(~np.isnan(X_pca), axis=0)
916
+ col_sums = np.nansum(X_pca, axis=0)
917
+ col_means = np.divide(
918
+ col_sums,
919
+ valid_counts,
920
+ out=np.zeros_like(col_sums, dtype=np.float32),
921
+ where=valid_counts > 0,
922
+ )
923
+
924
+ # impute NaNs with per-column means
925
+ # (all-NaN cols -> 0.0 by the divide above)
926
+ nan_r, nan_c = np.where(np.isnan(X_pca))
927
+ if nan_r.size:
928
+ X_pca[nan_r, nan_c] = col_means[nan_c]
929
+
930
+ # center columns
931
+ X_pca = X_pca - X_pca.mean(axis=0, keepdims=True)
932
+
933
+ # guard: degenerate / all-zero after centering ->
934
+ # fall back to random
935
+ if (not np.isfinite(X_pca).all()) or np.allclose(X_pca, 0.0):
936
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
937
+ torch.nn.init.xavier_uniform_(latents)
938
+ return torch.nn.Parameter(latents, requires_grad=True)
939
+
940
+ # rank-aware component count, at least 1
941
+ try:
942
+ est_rank = np.linalg.matrix_rank(X_pca)
943
+ except Exception:
944
+ est_rank = min(n_samples, X_pca.shape[1])
945
+
946
+ n_components = max(1, min(latent_dim, est_rank, n_samples, X_pca.shape[1]))
947
+
948
+ # use deterministic SVD to avoid power-iteration warnings
949
+ pca = PCA(
950
+ n_components=n_components, svd_solver="full", random_state=self.seed
951
+ )
952
+ initial = pca.fit_transform(X_pca) # (n_samples, n_components)
953
+
954
+ # pad if latent_dim > n_components
955
+ if n_components < latent_dim:
956
+ pad = self.rng.standard_normal(
957
+ size=(n_samples, latent_dim - n_components)
958
+ )
959
+ initial = np.hstack([initial, pad])
960
+
961
+ # standardize latent dims
962
+ initial = (initial - initial.mean(axis=0)) / (initial.std(axis=0) + 1e-6)
963
+
964
+ latents = torch.from_numpy(initial).float().to(self.device)
965
+ return torch.nn.Parameter(latents, requires_grad=True)
966
+
967
+ # --- Random init path (unchanged) ---
968
+ latents = torch.empty(n_samples, latent_dim, device=self.device)
969
+ torch.nn.init.xavier_uniform_(latents)
970
+ return torch.nn.Parameter(latents, requires_grad=True)
971
+
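+ # Illustrative call of the initializer above on a toy matrix (hypothetical fitted-enough `imputer`):
+ # with only three samples the usable rank is small, so extra latent dimensions are padded with
+ # random normals before per-dimension standardization.
+ #     enc = np.array([[0, 2, -1], [1, 0, 2], [0, 1, 2]])
+ #     lat = imputer._create_latent_space({"latent_dim": 8}, enc.shape[0], enc, "pca")
+ #     print(lat.shape)   # torch.Size([3, 8]) -- a trainable torch.nn.Parameter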
972
+ def _objective(self, trial: optuna.Trial) -> float:
973
+ """Objective function for hyperparameter tuning with Optuna.
974
+
975
+ This method defines the objective function used by Optuna for hyperparameter tuning of the NLPCA model. It samples a set of hyperparameters, prepares the training and validation data, initializes the model and latent vectors, and trains the model. After training, it evaluates the model on a validation set and returns the value of the specified tuning metric.
976
+
977
+ Args:
978
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
979
+
980
+ Returns:
981
+ float: The value of the tuning metric to be minimized or maximized.
982
+ """
983
+ try:
984
+ self._prepare_tuning_artifacts()
985
+ trial_params = self._sample_hyperparameters(trial)
986
+ model_params = trial_params["model_params"]
987
+
988
+ nfeat = self._tune_num_features
989
+ if self.tune and self.tune_fast:
990
+ model_params["n_features"] = nfeat
991
+
992
+ lr = trial_params["lr"]
993
+ l1_penalty = trial_params["l1_penalty"]
994
+ lr_input_fac = trial_params["lr_input_factor"]
995
+
996
+ X_train_trial = self._tune_X_train
997
+ X_test_trial = self._tune_X_test
998
+ class_weights = self._tune_class_weights
999
+ train_loader = self._tune_loader
1000
+
1001
+ train_latents = self._create_latent_space(
1002
+ model_params,
1003
+ len(X_train_trial),
1004
+ X_train_trial,
1005
+ trial_params["latent_init"],
1006
+ )
1007
+
1008
+ model = self.build_model(self.Model, model_params)
1009
+ model.n_features = model_params["n_features"]
1010
+ model.apply(self.initialize_weights)
1011
+
1012
+ _, model, __ = self._train_and_validate_model(
1013
+ model=model,
1014
+ loader=train_loader,
1015
+ lr=lr,
1016
+ l1_penalty=l1_penalty,
1017
+ trial=trial,
1018
+ latent_vectors=train_latents,
1019
+ lr_input_factor=lr_input_fac,
1020
+ class_weights=class_weights,
1021
+ X_val=X_test_trial,
1022
+ params=model_params,
1023
+ prune_metric=self.tune_metric,
1024
+ prune_warmup_epochs=10,
1025
+ eval_interval=self.tune_eval_interval,
1026
+ eval_latent_steps=self.eval_latent_steps,
1027
+ eval_latent_lr=self.eval_latent_lr,
1028
+ eval_latent_weight_decay=self.eval_latent_weight_decay,
1029
+ )
1030
+
1031
+ # --- simulate-only eval mask for tuning ---
1032
+ eval_mask = None
1033
+ if (
1034
+ self.simulate_missing
1035
+ and getattr(self, "sim_mask_global_", None) is not None
1036
+ ):
1037
+ if (
1038
+ hasattr(self, "_tune_test_idx")
1039
+ and self.sim_mask_global_ is not None
1040
+ ):
1041
+ eval_mask = self.sim_mask_global_[self._tune_test_idx]
1042
+ elif getattr(self, "sim_mask_test_", None) is not None:
1043
+ eval_mask = self.sim_mask_test_
1044
+
1045
+ metrics = self._evaluate_model(
1046
+ X_test_trial,
1047
+ model,
1048
+ model_params,
1049
+ objective_mode=True,
1050
+ eval_mask_override=eval_mask,
1051
+ )
1052
+
1053
+ self._clear_resources(model, train_loader, latent_vectors=train_latents)
1054
+ return metrics[self.tune_metric]
1055
+ except Exception as e:
1056
+ raise optuna.exceptions.TrialPruned(f"Trial failed with error: {e}")
1057
+
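+ # The Optuna driver itself lives in the base class (tune_hyperparameters, not shown in this hunk);
+ # one plausible wiring of the objective above, for orientation only:
+ #     study = optuna.create_study(direction="maximize")             # tune_metric values are scores, higher is better
+ #     study.optimize(imputer._objective, n_trials=imputer.n_trials)
+ #     best_model_params = imputer._set_best_params(study.best_params)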
1058
+ def _sample_hyperparameters(self, trial: optuna.Trial) -> Dict[str, Any]:
1059
+ """Samples hyperparameters for the simplified NLPCA model.
1060
+
1061
+ This method defines the hyperparameter search space for the NLPCA model and samples a set of hyperparameters using the provided Optuna trial object. It computes the hidden layer sizes based on the sampled parameters and prepares the model parameters dictionary.
1062
+
1063
+ Args:
1064
+ trial (optuna.Trial): An Optuna trial object for hyperparameter suggestions.
1065
+
1066
+ Returns:
1067
+ Dict[str, int | float | str | list]: A dictionary of sampled hyperparameters.
1068
+ """
1069
+ params = {
1070
+ "latent_dim": trial.suggest_int("latent_dim", 4, 16, step=2),
1071
+ "lr": trial.suggest_float("learning_rate", 3e-4, 1e-3, log=True),
1072
+ "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.30),
1073
+ "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 4),
1074
+ "activation": trial.suggest_categorical(
1075
+ "activation", ["relu", "elu", "selu"]
1076
+ ),
1077
+ "gamma": trial.suggest_float("gamma", 0.5, 3.0, step=0.5),
1078
+ "lr_input_factor": trial.suggest_float(
1079
+ "lr_input_factor", 0.3, 3.0, log=True
1080
+ ),
1081
+ "l1_penalty": trial.suggest_float("l1_penalty", 1e-6, 1e-3, log=True),
1082
+ "layer_scaling_factor": trial.suggest_float(
1083
+ "layer_scaling_factor", 2.0, 4.0, step=0.5
1084
+ ),
1085
+ "layer_schedule": trial.suggest_categorical(
1086
+ "layer_schedule", ["pyramid", "linear"]
1087
+ ),
1088
+ "latent_init": trial.suggest_categorical("latent_init", ["random", "pca"]),
1089
+ }
1090
+
1091
+ use_n_features = (
1092
+ self._tune_num_features
1093
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_num_features"))
1094
+ else self.num_features_
1095
+ )
1096
+ use_n_samples = (
1097
+ len(self._tune_train_idx)
1098
+ if (self.tune and self.tune_fast and hasattr(self, "_tune_train_idx"))
1099
+ else len(self.train_idx_)
1100
+ )
1101
+
1102
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1103
+ n_inputs=params["latent_dim"],
1104
+ n_outputs=use_n_features * self.output_classes_,
1105
+ n_samples=use_n_samples,
1106
+ n_hidden=params["num_hidden_layers"],
1107
+ alpha=params["layer_scaling_factor"],
1108
+ schedule=params["layer_schedule"],
1109
+ )
1110
+
1111
+ params["model_params"] = {
1112
+ "n_features": use_n_features,
1113
+ "num_classes": self.output_classes_,
1114
+ "latent_dim": params["latent_dim"],
1115
+ "dropout_rate": params["dropout_rate"],
1116
+ "hidden_layer_sizes": hidden_layer_sizes,
1117
+ "activation": params["activation"],
1118
+ "gamma": params["gamma"],
1119
+ }
1120
+
1121
+ return params
1122
+
1123
+ def _set_best_params(self, best_params: dict) -> dict:
1124
+ """Sets the best hyperparameters found during tuning.
1125
+
1126
+ This method updates the model's attributes with the best hyperparameters obtained from tuning. It also computes the hidden layer sizes based on these parameters and prepares the final model parameters dictionary.
1127
+
1128
+ Args:
1129
+ best_params (dict): Best hyperparameters from tuning.
1130
+
1131
+ Returns:
1132
+ dict: Model parameters configured with the best hyperparameters.
1133
+ """
1134
+ self.latent_dim = best_params["latent_dim"]
1135
+ self.dropout_rate = best_params["dropout_rate"]
1136
+ self.learning_rate = best_params["learning_rate"]
1137
+ self.gamma = best_params["gamma"]
1138
+ self.lr_input_factor = best_params["lr_input_factor"]
1139
+ self.l1_penalty = best_params["l1_penalty"]
1140
+ self.activation = best_params["activation"]
1141
+ self.latent_init = best_params["latent_init"]
1142
+
1143
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1144
+ n_inputs=self.latent_dim,
1145
+ n_outputs=self.num_features_ * self.output_classes_,
1146
+ n_samples=len(self.train_idx_),
1147
+ n_hidden=best_params["num_hidden_layers"],
1148
+ alpha=best_params["layer_scaling_factor"],
1149
+ schedule=best_params["layer_schedule"],
1150
+ )
1151
+
1152
+ return {
1153
+ "n_features": self.num_features_,
1154
+ "latent_dim": self.latent_dim,
1155
+ "hidden_layer_sizes": hidden_layer_sizes,
1156
+ "dropout_rate": self.dropout_rate,
1157
+ "activation": self.activation,
1158
+ "gamma": self.gamma,
1159
+ "num_classes": self.output_classes_,
1160
+ }
1161
+
1162
+ def _set_best_params_default(self) -> Dict[str, int | float | str | list]:
1163
+ """Default (no-tuning) model_params aligned with current attributes.
1164
+
1165
+ This method constructs the model parameters dictionary using the current instance attributes of the ImputeNLPCA class. It computes the sizes of the hidden layers based on the instance's latent dimension, dropout rate, learning rate, and other relevant attributes. The method returns a dictionary containing the model parameters that can be used to build the NLPCA model when no hyperparameter tuning has been performed.
1166
+
1167
+ Returns:
1168
+ Dict[str, int | float | str | list]: model_params payload.
1169
+ """
1170
+ hidden_layer_sizes = self._compute_hidden_layer_sizes(
1171
+ n_inputs=self.latent_dim,
1172
+ n_outputs=self.num_features_ * self.output_classes_,
1173
+ n_samples=len(self.ground_truth_),
1174
+ n_hidden=self.num_hidden_layers,
1175
+ alpha=self.layer_scaling_factor,
1176
+ schedule=self.layer_schedule,
1177
+ )
1178
+
1179
+ return {
1180
+ "n_features": self.num_features_,
1181
+ "latent_dim": self.latent_dim,
1182
+ "hidden_layer_sizes": hidden_layer_sizes,
1183
+ "dropout_rate": self.dropout_rate,
1184
+ "activation": self.activation,
1185
+ "gamma": self.gamma,
1186
+ "num_classes": self.output_classes_,
1187
+ }
1188
+
1189
+ def _train_and_validate_model(
1190
+ self,
1191
+ model: torch.nn.Module,
1192
+ loader: torch.utils.data.DataLoader,
1193
+ lr: float,
1194
+ l1_penalty: float,
1195
+ trial: optuna.Trial | None = None,
1196
+ return_history: bool = False,
1197
+ latent_vectors: torch.nn.Parameter | None = None,
1198
+ lr_input_factor: float = 1.0,
1199
+ class_weights: torch.Tensor | None = None,
1200
+ *,
1201
+ X_val: np.ndarray | None = None,
1202
+ params: dict | None = None,
1203
+ prune_metric: str | None = None, # "f1" | "accuracy" | "pr_macro"
1204
+ prune_warmup_epochs: int = 10,
1205
+ eval_interval: int = 1,
1206
+ eval_latent_steps: int = 50,
1207
+ eval_latent_lr: float = 1e-2,
1208
+ eval_latent_weight_decay: float = 0.0,
1209
+ ) -> Tuple:
1210
+ """Trains and validates the NLPCA model.
1211
+
1212
+ This method trains the provided NLPCA model using the specified training data and hyperparameters. It supports optional integration with Optuna for hyperparameter tuning and pruning based on validation performance. The method initializes optimizers for both the model parameters and latent vectors, sets up a learning rate scheduler, and executes the training loop. It can return the training history if requested.
1213
+
1214
+ Args:
1215
+ model (torch.nn.Module): The NLPCA model to be trained.
1216
+ loader (torch.utils.data.DataLoader): DataLoader for training data.
1217
+ lr (float): Learning rate for the model optimizer.
1218
+ l1_penalty (float): L1 regularization penalty.
1219
+ trial (optuna.Trial | None): Optuna trial for hyperparameter tuning.
1220
+ return_history (bool): Whether to return training history.
1221
+ latent_vectors (torch.nn.Parameter | None): Latent vectors for samples.
1222
+ lr_input_factor (float): Learning rate factor for latent vectors.
1223
+ class_weights (torch.Tensor | None): Class weights for handling class imbalance.
1224
+ X_val (np.ndarray | None): Validation data for pruning.
1225
+ params (dict | None): Model parameters.
1226
+ prune_metric (str | None): Metric for pruning decisions.
1227
+ prune_warmup_epochs (int): Number of epochs before pruning starts.
1228
+ eval_interval (int): Interval (in epochs) for evaluation during training.
1229
+ eval_latent_steps (int): Steps for latent optimization during evaluation.
1230
+ eval_latent_lr (float): Learning rate for latent optimization during evaluation.
1231
+ eval_latent_weight_decay (float): Weight decay for latent optimization during evaluation.
1232
+
1233
+ Returns:
1234
+ Tuple[float, torch.nn.Module, Dict[str, float], torch.nn.Parameter] | Tuple[float, torch.nn.Module, torch.nn.Parameter]: Training loss, trained model, training history (if requested), and optimized latent vectors.
1235
+
1236
+ Raises:
1237
+ TypeError: If latent_vectors or class_weights are not provided.
1238
+ """
1239
+
1240
+ if latent_vectors is None or class_weights is None:
+ msg = "latent_vectors and class_weights must be provided."
+ self.logger.error(msg)
+ raise TypeError(msg)
1244
+
1245
+ latent_optimizer = torch.optim.Adam([latent_vectors], lr=lr * lr_input_factor)
1246
+
1247
+ decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
1248
+
1249
+ if not isinstance(decoder, torch.nn.Module):
1250
+ msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
1251
+ self.logger.error(msg)
1252
+ raise TypeError(msg)
1253
+
1254
+ optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
1255
+ scheduler = CosineAnnealingLR(optimizer, T_max=self.epochs)
1256
+
1257
+ result = self._execute_training_loop(
1258
+ loader=loader,
1259
+ optimizer=optimizer,
1260
+ latent_optimizer=latent_optimizer,
1261
+ scheduler=scheduler,
1262
+ model=model,
1263
+ l1_penalty=l1_penalty,
1264
+ return_history=return_history,
1265
+ latent_vectors=latent_vectors,
1266
+ class_weights=class_weights,
1267
+ trial=trial,
1268
+ X_val=X_val,
1269
+ params=params,
1270
+ prune_metric=prune_metric,
1271
+ prune_warmup_epochs=prune_warmup_epochs,
1272
+ eval_interval=eval_interval,
1273
+ eval_latent_steps=eval_latent_steps,
1274
+ eval_latent_lr=eval_latent_lr,
1275
+ eval_latent_weight_decay=eval_latent_weight_decay,
1276
+ )
1277
+
1278
+ if return_history:
1279
+ return result
1280
+
1281
+ return result[0], result[1], result[3]
1282
+
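The method above couples two optimizers: Adam over the decoder weights and a second Adam over the free per-sample latent vectors, with a cosine-annealing schedule attached only to the decoder optimizer. The sketch below is a minimal, self-contained illustration of that dual-optimizer NLPCA pattern on a toy decoder; every name and shape here (toy_decoder, Z, the genotype tensor) is an illustrative assumption, not the pg-sui API.

import torch

n_samples, latent_dim, n_loci, n_classes = 32, 3, 50, 4

# Toy decoder standing in for the NLPCA decoder network.
toy_decoder = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, n_loci * n_classes),
)

# Per-sample latent codes are free parameters, analogous to `latent_vectors`.
Z = torch.nn.Parameter(torch.randn(n_samples, latent_dim) * 0.1)

lr, lr_input_factor, epochs = 1e-3, 1.0, 100
opt_decoder = torch.optim.Adam(toy_decoder.parameters(), lr=lr)
opt_latent = torch.optim.Adam([Z], lr=lr * lr_input_factor)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt_decoder, T_max=epochs)

# Fake genotype targets; -1 marks missing entries to be ignored by the loss.
y = torch.randint(0, n_classes, (n_samples, n_loci))
y[torch.rand(y.shape) < 0.2] = -1

for epoch in range(epochs):
    opt_decoder.zero_grad(set_to_none=True)
    opt_latent.zero_grad(set_to_none=True)
    logits = toy_decoder(Z).view(n_samples, n_loci, n_classes)
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, n_classes), y.view(-1), ignore_index=-1
    )
    loss.backward()
    opt_decoder.step()  # updates decoder weights
    opt_latent.step()   # updates the latent codes themselves
    scheduler.step()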
+    def _train_final_model(
+        self,
+        loader: torch.utils.data.DataLoader,
+        best_params: dict,
+        initial_latent_vectors: torch.nn.Parameter,
+    ) -> Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]:
+        """Trains the final model using the best hyperparameters.
+
+        This method builds and trains the final NLPCA model using the best hyperparameters obtained from tuning. It initializes the model weights, trains the model on the entire training set, and saves the trained model to disk. It returns the final training loss, trained model, training history, and optimized latent vectors.
+
+        Args:
+            loader (torch.utils.data.DataLoader): DataLoader for training data.
+            best_params (dict): Best hyperparameters for the model.
+            initial_latent_vectors (torch.nn.Parameter): Initial latent vectors for samples.
+
+        Returns:
+            Tuple[float, torch.nn.Module, dict, torch.nn.Parameter]: Final training loss, trained model, training history, and optimized latent vectors.
+
+        Raises:
+            RuntimeError: If model training fails.
+        """
+        self.logger.info(f"Training the final {self.model_name} model...")
+
+        model = self.build_model(self.Model, best_params)
+        model.n_features = best_params["n_features"]
+        model.apply(self.initialize_weights)
+
+        loss, trained_model, history, latent_vectors = self._train_and_validate_model(
+            model=model,
+            loader=loader,
+            lr=self.learning_rate,
+            l1_penalty=self.l1_penalty,
+            return_history=True,
+            latent_vectors=initial_latent_vectors,
+            lr_input_factor=self.lr_input_factor,
+            class_weights=self.class_weights_,
+            X_val=self.X_test_,
+            params=best_params,
+            prune_metric=self.tune_metric,
+            prune_warmup_epochs=10,
+            eval_interval=1,
+            eval_latent_steps=self.eval_latent_steps,
+            eval_latent_lr=self.eval_latent_lr,
+            eval_latent_weight_decay=self.eval_latent_weight_decay,
+        )
+
+        if trained_model is None:
+            msg = "Final model training failed."
+            self.logger.error(msg)
+            raise RuntimeError(msg)
+
+        fn = self.models_dir / "final_model.pt"
+        torch.save(trained_model.state_dict(), fn)
+
+        return (loss, trained_model, {"Train": history}, latent_vectors)
+
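As a usage note on the persistence step above: only the model's state_dict is written to final_model.pt, so reloading later requires rebuilding the architecture before loading the weights. A minimal round-trip sketch, assuming a throwaway module and a hypothetical output directory:

import torch
from pathlib import Path

# Throwaway module standing in for the trained decoder network.
net = torch.nn.Linear(3, 8)

models_dir = Path("pgsui_output_models")  # hypothetical output directory
models_dir.mkdir(parents=True, exist_ok=True)
fn = models_dir / "final_model.pt"

# Persist only the weights, mirroring torch.save(trained_model.state_dict(), fn).
torch.save(net.state_dict(), fn)

# Reloading: rebuild the same architecture, then load the saved weights.
restored = torch.nn.Linear(3, 8)
restored.load_state_dict(torch.load(fn, map_location="cpu"))
restored.eval()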
+    def _execute_training_loop(
+        self,
+        loader,
+        optimizer,
+        latent_optimizer,
+        scheduler,  # do not overwrite; honor caller's scheduler
+        model,
+        l1_penalty,
+        return_history,
+        latent_vectors,
+        class_weights,
+        *,
+        trial: optuna.Trial | None = None,
+        X_val: np.ndarray | None = None,
+        params: dict | None = None,
+        prune_metric: str | None = None,
+        prune_warmup_epochs: int = 10,
+        eval_interval: int = 1,
+        eval_latent_steps: int = 50,
+        eval_latent_lr: float = 1e-2,
+        eval_latent_weight_decay: float = 0.0,
+    ) -> Tuple[float, torch.nn.Module, list, torch.nn.Parameter]:
+        """Train NLPCA with warmup, pruning, and early stopping."""
+        best_model = None
+        history: list[float] = []
+
+        early_stopping = EarlyStopping(
+            patience=self.early_stop_gen,
+            min_epochs=self.min_epochs,
+            verbose=self.verbose,
+            prefix=self.prefix,
+            debug=self.debug,
+        )
+
+        # Epoch budget
+        max_epochs = (
+            self.tune_epochs if (trial is not None and self.tune_fast) else self.epochs
+        )
+
+        # Optional LR warmup for both optimizers
+        warmup_epochs = getattr(self, "lr_warmup_epochs", 5)
+        model_lr0 = optimizer.param_groups[0]["lr"]
+        latent_lr0 = latent_optimizer.param_groups[0]["lr"]
+        model_lr_min = model_lr0 * 0.1
+        latent_lr_min = latent_lr0 * 0.1
+
+        _latent_cache: dict = {}
+        _latent_cache_key = f"{self.prefix}_{self.model_name}_val_latents"
+
+        for epoch in range(max_epochs):
+            # Linear warmup LRs for first few epochs
+            if epoch < warmup_epochs:
+                scale = float(epoch + 1) / warmup_epochs
+                for g in optimizer.param_groups:
+                    g["lr"] = model_lr_min + (model_lr0 - model_lr_min) * scale
+                for g in latent_optimizer.param_groups:
+                    g["lr"] = latent_lr_min + (latent_lr0 - latent_lr_min) * scale
+
+            train_loss, latent_vectors = self._train_step(
+                loader=loader,
+                optimizer=optimizer,
+                latent_optimizer=latent_optimizer,
+                model=model,
+                l1_penalty=l1_penalty,
+                latent_vectors=latent_vectors,
+                class_weights=class_weights,
+            )
+
+            if not np.isfinite(train_loss):
+                if trial:
+                    raise optuna.exceptions.TrialPruned("Epoch loss non-finite.")
+                # Reduce both LRs and continue
+                for g in optimizer.param_groups:
+                    g["lr"] *= 0.5
+                for g in latent_optimizer.param_groups:
+                    g["lr"] *= 0.5
+                continue
+
+            if scheduler is not None:
+                scheduler.step()
+
+            if return_history:
+                history.append(train_loss)
+
+            # Optuna prune on validation metric
+            if (
+                trial is not None
+                and X_val is not None
+                and ((epoch + 1) % eval_interval == 0)
+            ):
+                seed = int(
+                    self.rng.integers(0, 1_000_000) if self.seed is None else self.seed
+                )
+                metric_key = prune_metric or getattr(self, "tune_metric", "f1")
+                do_infer = int(eval_latent_steps) > 0
+                metric_val = self._eval_for_pruning(
+                    model=model,
+                    X_val=X_val,
+                    params=params or getattr(self, "best_params_", {}),
+                    metric=metric_key,
+                    objective_mode=True,
+                    do_latent_infer=do_infer,
+                    latent_steps=eval_latent_steps,
+                    latent_lr=eval_latent_lr,
+                    latent_weight_decay=eval_latent_weight_decay,
+                    latent_seed=seed,
+                    _latent_cache=_latent_cache,
+                    _latent_cache_key=_latent_cache_key,
+                    eval_mask_override=(
+                        self.sim_mask_test_
+                        if (
+                            self.simulate_missing
+                            and getattr(self, "sim_mask_test_", None) is not None
+                            and X_val.shape == self.X_test_.shape
+                        )
+                        else (
+                            self.sim_mask_global_[self._tune_test_idx]
+                            if (
+                                self.simulate_missing
+                                and self.sim_mask_global_ is not None
+                                and hasattr(self, "_tune_test_idx")
+                                and X_val.shape[0] == len(self._tune_test_idx)
+                            )
+                            else None
+                        )
+                    ),
+                )
+
+                trial.report(metric_val, step=epoch + 1)
+
+                if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
+                    raise optuna.exceptions.TrialPruned(
+                        f"Pruned at epoch {epoch + 1}: {metric_key}={metric_val:.4f}"
+                    )
+
+            early_stopping(train_loss, model)
+            if early_stopping.early_stop:
+                self.logger.info(f"Early stopping at epoch {epoch + 1}.")
+                break
+
+        best_loss = early_stopping.best_score
+        best_model = copy.deepcopy(early_stopping.best_model)
+        if best_model is None:
+            best_model = copy.deepcopy(model)
+        return best_loss, best_model, history, latent_vectors
+
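The pruning branch in the loop above follows the standard Optuna pattern: report an intermediate validation metric every eval_interval epochs, but only act on should_prune() after a warmup budget. A minimal, self-contained sketch of that pattern with a synthetic metric (the objective body, numbers, and pruner choice are placeholders, not the package's tuning code):

import optuna

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    metric = 0.0
    prune_warmup_epochs = 10
    for epoch in range(50):
        # Stand-in for one epoch of training plus masked-validation scoring.
        metric = 1.0 - 1.0 / (epoch + 2) - lr
        trial.report(metric, step=epoch + 1)
        # Only allow pruning after the warmup period, as in the loop above.
        if (epoch + 1) >= prune_warmup_epochs and trial.should_prune():
            raise optuna.exceptions.TrialPruned(
                f"Pruned at epoch {epoch + 1}: f1={metric:.4f}"
            )
    return metric

study = optuna.create_study(
    direction="maximize", pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)
)
study.optimize(objective, n_trials=5)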
+    def _optimize_latents_for_inference(
+        self,
+        X_new: np.ndarray,
+        model: torch.nn.Module,
+        params: dict,
+        inference_epochs: int = 200,
+    ) -> torch.Tensor:
+        """Refine latents for new data with guards.
+
+        This method optimizes latent vectors for new data samples through gradient-based refinement. It initializes the latent space and iteratively updates the latent vectors to minimize the reconstruction loss (cross-entropy for haploid data, masked multi-label binary cross-entropy otherwise). Optimization stops early if non-finite values are encountered.
+
+        Args:
+            X_new (np.ndarray): New data in 0/1/2 encoding with -1 for missing values.
+            model (torch.nn.Module): Trained NLPCA model.
+            params (dict): Model parameters.
+            inference_epochs (int): Number of optimization epochs.
+
+        Returns:
+            torch.Tensor: Optimized latent vectors for the new data.
+
+        Raises:
+            TypeError: If the model decoder is not a torch.nn.Module.
+        """
+        if self.tune and self.tune_fast:
+            inference_epochs = min(
+                inference_epochs, getattr(self, "tune_infer_epochs", 20)
+            )
+
+        model.eval()
+        nF = getattr(model, "n_features", self.num_features_)
+
+        z = self._create_latent_space(
+            params, len(X_new), X_new, self.latent_init
+        ).requires_grad_(True)
+        opt = torch.optim.AdamW(
+            [z], lr=self.learning_rate * self.lr_input_factor, eps=1e-7
+        )
+
+        X_new = X_new.astype(np.int64, copy=False)
+        X_new[X_new < 0] = -1
+        y = torch.from_numpy(X_new).long().to(self.device)
+
+        for _ in range(inference_epochs):
+            opt.zero_grad(set_to_none=True)
+
+            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
+
+            if not isinstance(decoder, torch.nn.Module):
+                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
+                self.logger.error(msg)
+                raise TypeError(msg)
+
+            logits = decoder(z).view(len(X_new), nF, self.output_classes_)
+
+            if not torch.isfinite(logits).all():
+                break
+
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_),
+                    y.view(-1),
+                    ignore_index=-1,
+                    reduction="mean",
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+            opt.step()
+
+        return z.detach()
+
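Conceptually, inference-time latent refinement holds the trained decoder fixed and runs gradient descent only on the new samples' latent codes. The sketch below illustrates that idea for the haploid (single-label) case with a toy decoder; the names, shapes, and zero initialization are assumptions, and the masked multi-label loss used for diploid data is sketched separately after the next method.

import numpy as np
import torch
import torch.nn.functional as F

n_new, latent_dim, n_loci, n_classes = 8, 3, 50, 3

# Frozen decoder standing in for the trained phase23_decoder.
decoder = torch.nn.Sequential(
    torch.nn.Linear(latent_dim, 32), torch.nn.ReLU(),
    torch.nn.Linear(32, n_loci * n_classes),
)
for p in decoder.parameters():
    p.requires_grad_(False)

# New genotypes in 0/1/2 encoding, with -1 marking missing calls.
X_new = np.random.randint(-1, 3, size=(n_new, n_loci)).astype(np.int64)
y = torch.from_numpy(X_new).long()

# Only the latent codes are optimized; the decoder weights stay fixed.
z = torch.zeros(n_new, latent_dim, requires_grad=True)
opt = torch.optim.AdamW([z], lr=1e-2, eps=1e-7)

for _ in range(200):
    opt.zero_grad(set_to_none=True)
    logits = decoder(z).view(n_new, n_loci, n_classes)
    loss = F.cross_entropy(
        logits.view(-1, n_classes), y.view(-1), ignore_index=-1, reduction="mean"
    )
    if not torch.isfinite(loss):
        break
    loss.backward()
    torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
    opt.step()

z_refined = z.detach()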
+    def _latent_infer_for_eval(
+        self,
+        model: torch.nn.Module,
+        X_val: np.ndarray,
+        *,
+        steps: int,
+        lr: float,
+        weight_decay: float,
+        seed: int | None,
+        cache: dict | None,
+        cache_key: str | None,
+    ) -> None:
+        """Freeze weights; refine validation latents only (no leakage).
+
+        This method refines latent vectors for validation data by optimizing them while keeping the model weights frozen. It initializes the latent space, optionally reusing cached latent vectors, and iteratively updates the latent vectors to minimize the reconstruction loss (cross-entropy for haploid data, masked multi-label binary cross-entropy otherwise). Optimization stops early if non-finite values are encountered, and the optimized latent vectors are stored in the cache when one is provided.
+
+        Args:
+            model (torch.nn.Module): Trained NLPCA model.
+            X_val (np.ndarray): Validation data in 0/1/2 encoding with -1 for missing.
+            steps (int): Number of optimization steps.
+            lr (float): Learning rate for latent optimization.
+            weight_decay (float): Weight decay for latent optimization.
+            seed (int | None): Random seed for reproducibility; drawn at random if None.
+            cache (dict | None): Cache for storing optimized latent vectors.
+            cache_key (str | None): Suggested cache key; a model-specific key is derived internally and used in its place.
+
+        Raises:
+            TypeError: If the model decoder is not a torch.nn.Module.
+        """
+        if seed is None:
+            seed = np.random.randint(0, 999_999)
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+        model.eval()
+        nF = getattr(model, "n_features", self.num_features_)
+
+        for p in model.parameters():
+            p.requires_grad_(False)
+
+        X_val = X_val.astype(np.int64, copy=False)
+        X_val[X_val < 0] = -1
+        y = torch.from_numpy(X_val).long().to(self.device)
+
+        latent_dim = self._first_linear_in_features(model)
+        cache_key = f"{self.prefix}_nlpca_val_latents_z{latent_dim}_L{self.num_features_}_K{self.output_classes_}"
+
+        if cache is not None and cache_key in cache:
+            z = cache[cache_key].detach().clone().requires_grad_(True)
+        else:
+            z = self._create_latent_space(
+                {"latent_dim": latent_dim},
+                n_samples=X_val.shape[0],
+                X=X_val,
+                latent_init=self.latent_init,
+            ).requires_grad_(True)
+
+        opt = torch.optim.AdamW([z], lr=lr, weight_decay=weight_decay, eps=1e-7)
+
+        for _ in range(max(int(steps), 0)):
+            opt.zero_grad(set_to_none=True)
+
+            decoder: torch.Tensor | torch.nn.Module = model.phase23_decoder
+
+            if not isinstance(decoder, torch.nn.Module):
+                msg = f"{self.model_name} decoder is not a valid torch.nn.Module."
+                self.logger.error(msg)
+                raise TypeError(msg)
+
+            logits = decoder(z).view(X_val.shape[0], nF, self.output_classes_)
+
+            if not torch.isfinite(logits).all():
+                break
+
+            if self.is_haploid:
+                loss = F.cross_entropy(
+                    logits.view(-1, self.output_classes_),
+                    y.view(-1),
+                    ignore_index=-1,
+                    reduction="mean",
+                )
+            else:
+                targets = self._multi_hot_targets(y)
+                bce = F.binary_cross_entropy_with_logits(
+                    logits, targets, pos_weight=self.pos_weights_, reduction="none"
+                )
+                mask = (y != -1).unsqueeze(-1).float()
+                loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)
+
+            if not torch.isfinite(loss):
+                break
+
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_([z], max_norm=1.0)
+
+            if z.grad is None or not torch.isfinite(z.grad).all():
+                break
+
+            opt.step()
+
+        if cache is not None:
+            cache[cache_key] = z.detach().clone()
+
+        for p in model.parameters():
+            p.requires_grad_(True)
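For diploid data, both refinement loops above score reconstructions with a masked multi-label BCE rather than cross-entropy. The sketch below shows that loss in isolation; the two-channel multi-hot table is only an assumed stand-in for the package's _multi_hot_targets helper, whose actual encoding may differ.

import torch
import torch.nn.functional as F

# Assumed two-channel multi-hot encoding of 0/1/2 genotypes (REF/ALT presence):
# 0 -> [1, 0], 1 -> [1, 1], 2 -> [0, 1]. Illustrative only.
def multi_hot(y: torch.Tensor) -> torch.Tensor:
    table = torch.tensor([[1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
    return table[y.clamp(min=0)]

n_samples, n_loci, n_channels = 4, 6, 2
y = torch.randint(-1, 3, (n_samples, n_loci))        # -1 marks missing calls
logits = torch.randn(n_samples, n_loci, n_channels)  # stand-in decoder output
pos_weight = torch.ones(n_channels)                  # class-imbalance weights

targets = multi_hot(y)
bce = F.binary_cross_entropy_with_logits(
    logits, targets, pos_weight=pos_weight, reduction="none"
)
# Zero out loci whose genotype is missing, then average over observed entries.
mask = (y != -1).unsqueeze(-1).float()
loss = (bce * mask).sum() / mask.sum().clamp_min(1e-8)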